419 lines
18 KiB
Python
419 lines
18 KiB
Python
import asyncio
|
||
import os
|
||
import re
|
||
import io
|
||
import csv
|
||
import json
|
||
from datetime import datetime, timedelta, timezone
|
||
|
||
from aiogram import Bot, Dispatcher, types, F
|
||
from aiogram.filters import Command, StateFilter
|
||
from aiogram.fsm.context import FSMContext
|
||
from aiogram.fsm.state import State, StatesGroup, default_state
|
||
from aiogram.fsm.storage.memory import MemoryStorage
|
||
from aiogram.client.session.aiohttp import AiohttpSession
|
||
from aiogram.utils.keyboard import InlineKeyboardBuilder
|
||
from aiogram.types import BotCommand, BufferedInputFile
|
||
from sqlalchemy import select, delete, update
|
||
|
||
# 导入提供的函数和数据库逻辑
|
||
from vlm import call_qwen_vlm
|
||
from database import init_db, AsyncSessionLocal, Record
|
||
|
||
# --- 配置 ---
|
||
# Built-in expense categories offered as inline-keyboard choices when a category is edited.
DEFAULT_CATEGORIES = "餐饮 交通 购物 娱乐 住宿 办公 学习 医疗 居家 人情 运动 其它".split()
|
||
|
||
class RecordState(StatesGroup):
    """FSM states used by the record-entry and record-editing flows."""

    # --- flow for a freshly recognized, not yet saved record ---
    # The recognition result is shown and awaits confirmation.
    waiting_confirm = State()
    # The user is typing a replacement amount before saving.
    editing_new_amt = State()

    # --- flow for a record that already exists in the database ---
    # The user is typing a new amount for a stored record.
    editing_old_amt = State()
    # The user is picking a new category for a stored record.
    editing_old_cat = State()
|
||
|
||
# Optional proxy URL for reaching the Telegram API, read from the BOT_PROXY env var.
proxy = os.getenv("BOT_PROXY")
# Without a proxy, pass session=None so Bot falls back to its own default session.
if proxy:
    session = AiohttpSession(proxy=proxy)
else:
    session = None

bot = Bot(token=os.getenv("TG_TOKEN"), session=session)
# FSM data is kept in process memory (MemoryStorage), so it does not survive a restart.
dp = Dispatcher(storage=MemoryStorage())
|
||
|
||
# --- 辅助函数 ---
|
||
|
||
def render_confirm_text(data):
    """Build the Markdown confirmation card for a recognized, unsaved record.

    ``data`` is a mapping that may hold ``amount``, ``category`` and
    ``transaction_time``; missing keys fall back to 0, "其它" and "未知".
    The amount is passed through ``float`` and shown with two decimals, so a
    non-numeric amount raises ``ValueError`` (or ``TypeError`` for ``None``).
    """
    amount = float(data.get('amount', 0))
    category = data.get('category', '其它')
    when = data.get('transaction_time', '未知')
    divider = "━━━━━━━━━━━━━━━\n"
    parts = [
        "📝 **AI 识别结果**\n",
        divider,
        f"💰 **金额**:`{amount:.2f}` 元\n",
        f"🏷 **类别**:`{category}`\n",
        f"📅 **时间**:`{when}`\n",
        divider,
        "💡 请检查,如有误请点击下方按钮修改。",
    ]
    return "".join(parts)
|
||
|
||
def get_confirm_kb():
    """Return the inline keyboard shown under an unsaved recognition result.

    Rows: "save" alone, then "edit amount" / "edit category", then "cancel".
    """
    kb = InlineKeyboardBuilder()
    buttons = (
        ("✅ 确认保存", "save_new"),
        ("💰 改金额", "edit_new_amt"),
        ("📂 改类别", "edit_new_cat"),
        ("❌ 取消", "cancel_action"),
    )
    for label, payload in buttons:
        kb.button(text=label, callback_data=payload)
    kb.adjust(1, 2, 1)
    return kb.as_markup()
|
||
|
||
# ================= 1. 指令处理器 =================
|
||
|
||
@dp.message(Command("start"), StateFilter("*"))
async def cmd_start(message: types.Message, state: FSMContext):
    """Handle /start from any state: drop pending FSM data and send the usage guide."""
    await state.clear()
    welcome_parts = [
        "👋 **AI 记账助手**\n\n",
        "● **发图**:识别小票或截图\n",
        "● **发文**:如“昨天晚上火锅花了238元”\n\n",
        "📅 **管理指令**:\n",
        "/edit - 管理最近账单\n",
        "/export - 导出 CSV 报表\n",
        "━━━━━━━━━━━━━━━\n",
        "请发送内容开始记账:",
    ]
    await message.answer("".join(welcome_parts), parse_mode="Markdown")
|
||
|
||
@dp.message(Command("edit"), StateFilter("*"))
async def cmd_edit(message: types.Message, state: FSMContext):
    """Handle /edit from any state: reset the FSM and show the first page of records."""
    await state.clear()
    requester = message.from_user.id
    await show_record_list(message, requester, page=0)
|
||
|
||
def _build_export_csv(records):
    """Serialize records plus summary sections into CSV bytes.

    Layout: one row per record (ID, time, amount, category), then an overall
    summary (record count, total amount), then per-category totals sorted by
    amount descending with their share of the total. Returns
    ``(csv_bytes, total_amount)``; the bytes are UTF-8 with a BOM so Excel
    detects the encoding of the Chinese text.
    """
    csv_buffer = io.StringIO()
    writer = csv.writer(csv_buffer)
    writer.writerow(["ID", "时间", "金额", "类别"])

    total_amount = 0.0
    category_stats = {}
    # Single pass: write detail rows and accumulate both totals in record order.
    for r in records:
        writer.writerow([
            r.id,
            r.transaction_time.strftime("%Y-%m-%d %H:%M:%S"),
            f"{r.amount:.2f}",
            r.category,
        ])
        total_amount += r.amount
        category_stats[r.category] = category_stats.get(r.category, 0) + r.amount

    # Overall summary.
    writer.writerow([])
    writer.writerow(["统计", "", "", ""])
    writer.writerow(["总记录数", len(records), "", ""])
    writer.writerow(["总金额", f"{total_amount:.2f}", "", ""])

    # Per-category summary, largest amount first.
    writer.writerow([])
    writer.writerow(["类别统计", "金额", "占比", ""])
    for cat, amt in sorted(category_stats.items(), key=lambda x: x[1], reverse=True):
        percentage = (amt / total_amount * 100) if total_amount > 0 else 0
        writer.writerow([cat, f"{amt:.2f}", f"{percentage:.1f}%", ""])

    # UTF-8 BOM so Excel recognizes the Chinese text correctly.
    return csv_buffer.getvalue().encode('utf-8-sig'), total_amount


@dp.message(Command("export"), StateFilter("*"))
async def cmd_export(message: types.Message, state: FSMContext):
    """Handle /export from any state: send all of the caller's records as a CSV file.

    Records are ordered by transaction time, newest first. If the user has no
    records, the progress message is replaced with a notice instead.
    """
    await state.clear()
    user_id = message.from_user.id

    processing_msg = await message.answer("📊 正在生成报表...")

    async with AsyncSessionLocal() as session:
        stmt = (
            select(Record)
            .where(Record.user_id == user_id)
            .order_by(Record.transaction_time.desc())
        )
        res = await session.execute(stmt)
        records = res.scalars().all()

    if not records:
        await processing_msg.edit_text("📭 暂无数据可导出。")
        return

    csv_bytes, total_amount = _build_export_csv(records)
    filename = f"账单_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
    file = BufferedInputFile(csv_bytes, filename=filename)

    await processing_msg.delete()
    await message.answer_document(
        file,
        caption=f"📊 **导出完成**\n\n📝 总记录:`{len(records)}` 条\n💰 总金额:`{total_amount:.2f}` 元",
        parse_mode="Markdown",
    )
|
||
|
||
# ================= 2. 核心:记账输入处理 (VLM) =================
|
||
|
||
@dp.message(F.photo | F.text, StateFilter(default_state))
async def handle_input(message: types.Message, state: FSMContext):
    """Send a photo or text to the VLM and show the parsed record for confirmation.

    Text starting with "/" is ignored here. On success the parsed amount,
    category and transaction time are stored in FSM data, the progress message
    becomes the confirmation card, and the state moves to
    ``RecordState.waiting_confirm``. On any failure the progress message is
    replaced with the error text and the state is left unchanged.
    """
    if message.text and message.text.startswith("/"):
        return

    processing_msg = await message.answer("🧠 AI 正在分析账单...")
    try:
        if message.photo:
            # photo[-1] is used; in practice this is the largest available size.
            file_info = await bot.get_file(message.photo[-1].file_id)
            photo_bytes = io.BytesIO()
            await bot.download_file(file_info.file_path, photo_bytes)
            vlm_res_str = await call_qwen_vlm(photo_bytes.getvalue(), is_image=True)
        else:
            vlm_res_str = await call_qwen_vlm(message.text, is_image=False)

        # NOTE(review): assumes call_qwen_vlm returns a bare JSON object string — confirm.
        data = json.loads(vlm_res_str)
        if not isinstance(data, dict):
            raise ValueError(f"unexpected VLM payload type: {type(data).__name__}")

        await state.update_data(
            amount=data.get("amount", 0.0),
            category=data.get("category", "其它"),
            transaction_time=data.get("transaction_time", "未知"),
        )
        # Edit the progress message in place instead of delete + answer: if rendering
        # or sending fails, the except branch below still has a live message to edit.
        await processing_msg.edit_text(
            render_confirm_text(data),
            reply_markup=get_confirm_kb(),
            parse_mode="Markdown",
        )
        await state.set_state(RecordState.waiting_confirm)
    except Exception as e:
        await processing_msg.edit_text(f"❌ 解析失败: {e}")
|
||
|
||
# ================= 3. 新增记录的确认与修改 =================
|
||
|
||
# Formats tried, in order, when converting the stored transaction_time string.
_TIME_FORMATS = ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M", "%Y-%m-%d")


def _parse_transaction_time(value):
    """Parse a transaction-time string, falling back to the current local time.

    Tries each format in ``_TIME_FORMATS``. Anything else, such as a non-string,
    the "未知" placeholder stored when the VLM gives no time, or an unrecognized
    format, yields naive ``datetime.now()``. The result is naive, matching what
    ``strptime`` returns.
    """
    if isinstance(value, str):
        text = value.strip()
        for fmt in _TIME_FORMATS:
            try:
                return datetime.strptime(text, fmt)
            except ValueError:
                continue
    return datetime.now()


@dp.callback_query(F.data == "save_new", RecordState.waiting_confirm)
async def cb_save_new(callback: types.CallbackQuery, state: FSMContext):
    """Persist the confirmed record for the caller, then clear the FSM.

    The amount is stored as its absolute value. A transaction time that is
    missing, the "未知" placeholder, or in an unknown format is replaced by the
    current time rather than making ``strptime`` raise and the save impossible.
    """
    data = await state.get_data()
    amount = abs(float(data['amount']))

    async with AsyncSessionLocal() as session:
        new_rec = Record(
            user_id=callback.from_user.id,
            amount=amount,
            category=data['category'],
            transaction_time=_parse_transaction_time(data['transaction_time']),
        )
        session.add(new_rec)
        await session.commit()

    await callback.message.edit_text(f"✅ 已存入:{amount}元 ({data['category']})")
    await state.clear()
    # Telegram keeps the button's progress indicator until the query is answered.
    await callback.answer()
|
||
|
||
@dp.callback_query(F.data == "edit_new_amt", RecordState.waiting_confirm)
async def cb_edit_amt_before_save(callback: types.CallbackQuery, state: FSMContext):
    """Prompt for a replacement amount for the unsaved record.

    Sends a new prompt message with a "back" button (``return_to_confirm``) and
    switches to ``RecordState.editing_new_amt``.
    """
    # Keyboard with a single "back" button.
    builder = InlineKeyboardBuilder()
    builder.button(text="🔙 返回", callback_data="return_to_confirm")

    await callback.message.answer("✍️ 请输入新金额(数字):", reply_markup=builder.as_markup())
    await state.set_state(RecordState.editing_new_amt)
    # Stop the client's progress indicator on the pressed button.
    await callback.answer()
|
||
|
||
@dp.message(RecordState.editing_new_amt)
async def process_new_amt_input(message: types.Message, state: FSMContext):
    """Accept a typed amount for the unsaved record and re-show the confirmation card.

    Only non-negative decimal numbers (e.g. ``12`` or ``12.5``) are accepted;
    anything else, including non-text messages, gets a retry hint and the
    state stays ``editing_new_amt``.
    """
    # message.text is None for photos, stickers, etc.; treat those as invalid
    # input instead of crashing on None.strip().
    raw = (message.text or "").strip()
    if not re.match(r"^\d+(\.\d+)?$", raw):
        return await message.answer("⚠️ 请输入数字。")

    await state.update_data(amount=float(raw))
    data = await state.get_data()
    await state.set_state(RecordState.waiting_confirm)
    await message.answer(render_confirm_text(data), reply_markup=get_confirm_kb(), parse_mode="Markdown")
|
||
|
||
@dp.callback_query(F.data == "edit_new_cat", RecordState.waiting_confirm)
async def cb_choose_cat_before_save(callback: types.CallbackQuery):
    """Replace the confirmation card with a category picker for the unsaved record.

    Each category button sends ``set_new_cat_<category>``; a back button on its
    own row sends ``return_to_confirm``.
    """
    builder = InlineKeyboardBuilder()
    for cat in DEFAULT_CATEGORIES:
        builder.button(text=cat, callback_data=f"set_new_cat_{cat}")
    builder.adjust(3)
    # Back button on its own row below the category grid.
    builder.row(types.InlineKeyboardButton(text="🔙 返回", callback_data="return_to_confirm"))

    await callback.message.edit_text("📂 请选择新类别:", reply_markup=builder.as_markup())
    await callback.answer()
|
||
|
||
@dp.callback_query(F.data.startswith("set_new_cat_"), RecordState.waiting_confirm)
async def cb_set_cat_before_save(callback: types.CallbackQuery, state: FSMContext):
    """Apply the picked category to the unsaved record and show the card again."""
    new_cat = callback.data[len("set_new_cat_"):]
    # Callback data is client-supplied; only accept the categories we offered.
    if new_cat not in DEFAULT_CATEGORIES:
        return await callback.answer("⚠️ 无效的类别")

    await state.update_data(category=new_cat)
    data = await state.get_data()
    await callback.message.edit_text(render_confirm_text(data), reply_markup=get_confirm_kb(), parse_mode="Markdown")
    await callback.answer()
|
||
|
||
# "Back" pressed while still in waiting_confirm (e.g. from the category picker).
@dp.callback_query(F.data == "return_to_confirm", RecordState.waiting_confirm)
async def cb_return_to_confirm_from_cat(callback: types.CallbackQuery, state: FSMContext):
    """Edit the current message back into the confirmation card for the unsaved record."""
    data = await state.get_data()
    await callback.message.edit_text(render_confirm_text(data), reply_markup=get_confirm_kb(), parse_mode="Markdown")
    await callback.answer()
|
||
|
||
# "Back" pressed while the bot is waiting for a typed amount.
@dp.callback_query(F.data == "return_to_confirm", RecordState.editing_new_amt)
async def cb_return_to_confirm_from_amt(callback: types.CallbackQuery, state: FSMContext):
    """Abandon the amount edit and send a fresh confirmation card.

    Restores ``RecordState.waiting_confirm`` so the card's buttons work again.
    """
    data = await state.get_data()
    await state.set_state(RecordState.waiting_confirm)
    await callback.message.answer(render_confirm_text(data), reply_markup=get_confirm_kb(), parse_mode="Markdown")
    await callback.answer()
|
||
|
||
# ================= 4. 账单列表展示 =================
|
||
|
||
async def show_record_list(message_or_call, user_id, page=0):
    """Display one page (5 rows) of ``user_id``'s records, highest id first.

    ``message_or_call`` is either a ``types.Message`` (a new message is sent)
    or a callback query (its message is edited in place). Each listed record
    gets a numbered ``manage_<id>`` button; ``page_<n>`` buttons navigate and
    ``close_menu`` closes the list.
    """
    limit = 5
    page = max(0, page)  # a negative page would produce a negative OFFSET
    offset = page * limit

    async with AsyncSessionLocal() as session:
        # Fetch one extra row to learn whether a next page exists. Checking
        # len == limit alone shows a dead "next" button whenever the total is an
        # exact multiple of the page size.
        stmt = (
            select(Record)
            .where(Record.user_id == user_id)
            .order_by(Record.id.desc())
            .limit(limit + 1)
            .offset(offset)
        )
        res = await session.execute(stmt)
        fetched = res.scalars().all()

    has_next = len(fetched) > limit
    records = fetched[:limit]
    is_message = isinstance(message_or_call, types.Message)

    if not records and page == 0:
        text = "📭 暂无账单。"
        if is_message:
            return await message_or_call.answer(text)
        return await message_or_call.message.edit_text(text)

    list_text = [f"📊 **近期账单 (第 {page+1} 页)**", "━━━━━━━━━━━━━━━"]
    builder = InlineKeyboardBuilder()
    for i, r in enumerate(records):
        list_text.append(f"{i+1}️⃣ `{r.transaction_time.strftime('%m-%d')}` | **{r.category}** | `{r.amount:.2f}`")
        builder.button(text=f"{i+1}", callback_data=f"manage_{r.id}")
    builder.adjust(5)

    nav_btns = []
    if page > 0:
        nav_btns.append(types.InlineKeyboardButton(text="⬅️ 上一页", callback_data=f"page_{page-1}"))
    if has_next:
        nav_btns.append(types.InlineKeyboardButton(text="下一页 ➡️", callback_data=f"page_{page+1}"))
    if nav_btns:
        builder.row(*nav_btns)
    builder.row(types.InlineKeyboardButton(text="❌ 关闭并返回", callback_data="close_menu"))

    final_text = "\n".join(list_text)
    markup = builder.as_markup()
    if is_message:
        await message_or_call.answer(final_text, reply_markup=markup, parse_mode="Markdown")
    else:
        await message_or_call.message.edit_text(final_text, reply_markup=markup, parse_mode="Markdown")
|
||
|
||
# ================= 5. 【修复部分】修改数据库已有记录的逻辑 =================
|
||
|
||
# Step 1: "edit amount" pressed on a saved record's detail view.
@dp.callback_query(F.data.startswith("fld_amt_"))
async def cb_edit_old_amt(callback: types.CallbackQuery, state: FSMContext):
    """Remember the target record and prompt for its new amount.

    Callback data is ``fld_amt_<id>``. The prompt is sent as a new message
    whose cancel button routes back to ``manage_<id>``.
    """
    parts = callback.data.split("_")
    record_id = int(parts[2])

    await state.update_data(edit_rid=record_id)
    await state.set_state(RecordState.editing_old_amt)

    cancel_kb = InlineKeyboardBuilder()
    cancel_kb.button(text="❌ 取消修改", callback_data=f"manage_{record_id}")

    prompt = f"✍️ 请输入 ID:{record_id} 的新金额:"
    await callback.message.answer(prompt, reply_markup=cancel_kb.as_markup())
    await callback.answer()
|
||
|
||
# Step 2: receive the typed amount and write it to the database.
@dp.message(RecordState.editing_old_amt)
async def process_old_amt_input(message: types.Message, state: FSMContext):
    """Validate the typed amount and update the record chosen in cb_edit_old_amt.

    The UPDATE is restricted to records owned by the sender, so a forged
    ``fld_amt_<id>`` callback cannot modify another user's data. If no row
    matched, the user is told the record does not exist.
    """
    # message.text is None for photos, stickers, etc.
    raw = (message.text or "").strip()
    if not re.match(r"^\d+(\.\d+)?$", raw):
        return await message.answer("⚠️ 请输入有效的数字。")

    data = await state.get_data()
    rid = data['edit_rid']
    new_amt = float(raw)

    async with AsyncSessionLocal() as session:
        result = await session.execute(
            update(Record)
            .where(Record.id == rid, Record.user_id == message.from_user.id)
            .values(amount=new_amt)
        )
        await session.commit()

    await state.clear()
    if result.rowcount == 0:
        return await message.answer("记录不存在")

    await message.answer(f"✅ ID:{rid} 金额已修改为 {new_amt}")
    # Show the record's detail view again after the edit.
    await refresh_record_detail(message, rid)
|
||
|
||
@dp.callback_query(F.data.startswith("fld_cat_"))
async def cb_edit_old_cat(callback: types.CallbackQuery, state: FSMContext):
    """Remember the target record and replace the detail view with a category picker.

    Callback data is ``fld_cat_<id>``. Category buttons send
    ``set_old_cat_<category>``; the last row routes back to ``manage_<id>``.
    """
    rid = int(callback.data.split("_")[2])
    await state.update_data(edit_rid=rid)
    await state.set_state(RecordState.editing_old_cat)

    builder = InlineKeyboardBuilder()
    for cat in DEFAULT_CATEGORIES:
        builder.button(text=cat, callback_data=f"set_old_cat_{cat}")
    builder.adjust(3)  # three category buttons per row

    # .row() keeps the back button on its own row below the grid.
    builder.row(types.InlineKeyboardButton(text="🔙 取消并返回详情", callback_data=f"manage_{rid}"))

    await callback.message.edit_text(
        f"📂 请选择 ID:{rid} 的新类别:",
        reply_markup=builder.as_markup(),
    )
    await callback.answer()
|
||
|
||
# Step 4: a category button was pressed; write it to the database.
@dp.callback_query(F.data.startswith("set_old_cat_"), RecordState.editing_old_cat)
async def cb_set_old_cat(callback: types.CallbackQuery, state: FSMContext):
    """Update the stored record's category, then send its refreshed detail view.

    The category must be one of DEFAULT_CATEGORIES, because callback data is
    client-supplied. The UPDATE is limited to the caller's own records.
    """
    new_cat = callback.data[len("set_old_cat_"):]
    if new_cat not in DEFAULT_CATEGORIES:
        return await callback.answer("⚠️ 无效的类别")

    data = await state.get_data()
    rid = data['edit_rid']

    async with AsyncSessionLocal() as session:
        result = await session.execute(
            update(Record)
            .where(Record.id == rid, Record.user_id == callback.from_user.id)
            .values(category=new_cat)
        )
        await session.commit()

    await state.clear()
    if result.rowcount == 0:
        return await callback.answer("记录不存在")

    await callback.answer(f"✅ 类别已更新为 {new_cat}")
    await refresh_record_detail(callback.message, rid)
|
||
|
||
# Helper: after an edit, send the record's detail view as a new message.
async def refresh_record_detail(message: types.Message, rid: int):
    """Send a new message showing record ``rid`` with its action buttons.

    Returns silently if no record with that id exists.
    """
    async with AsyncSessionLocal() as session:
        lookup = await session.execute(select(Record).where(Record.id == rid))
        record = lookup.scalar_one_or_none()
    if not record:
        return

    actions = (
        ("💰 改金额", f"fld_amt_{rid}"),
        ("📂 改类别", f"fld_cat_{rid}"),
        ("🗑 删除", f"del_{rid}"),
        ("🔙 返回列表", "page_0"),
    )
    kb = InlineKeyboardBuilder()
    for label, payload in actions:
        kb.button(text=label, callback_data=payload)
    kb.adjust(2, 2, 1)

    header = f"🛠 **账单详情 (ID: {rid})**\n━━━━━━━━━━━━━━━\n"
    body = f"时间:`{record.transaction_time}`\n金额:`{record.amount:.2f}` 元\n类别:`{record.category}`"
    await message.answer(header + body, reply_markup=kb.as_markup(), parse_mode="Markdown")
|
||
|
||
# ================= 原有逻辑保持不变 =================
|
||
|
||
@dp.callback_query(F.data.startswith("manage_"))
async def cb_manage_record(callback: types.CallbackQuery, state: FSMContext):
    """Show the detail view (edit/delete buttons) for one of the caller's records.

    This is also the target of the cancel buttons on the amount/category edit
    prompts. Any pending old-record edit state is therefore dropped; otherwise
    the user's next message would still be consumed as an amount for the
    abandoned edit. Only records owned by the caller are shown.
    """
    current = await state.get_state()
    if current in (RecordState.editing_old_amt.state, RecordState.editing_old_cat.state):
        await state.clear()

    rid = int(callback.data.split("_")[1])
    async with AsyncSessionLocal() as session:
        res = await session.execute(
            select(Record).where(Record.id == rid, Record.user_id == callback.from_user.id)
        )
        r = res.scalar_one_or_none()
    if not r:
        return await callback.answer("记录不存在")

    builder = InlineKeyboardBuilder()
    builder.button(text="💰 改金额", callback_data=f"fld_amt_{rid}")
    builder.button(text="📂 改类别", callback_data=f"fld_cat_{rid}")
    builder.button(text="🗑 删除", callback_data=f"del_{rid}")
    builder.button(text="🔙 返回列表", callback_data="page_0")
    builder.adjust(2, 2, 1)

    text = (f"🛠 **账单详情 (ID: {rid})**\n━━━━━━━━━━━━━━━\n"
            f"时间:`{r.transaction_time}`\n金额:`{r.amount:.2f}` 元\n类别:`{r.category}`")
    await callback.message.edit_text(text, reply_markup=builder.as_markup(), parse_mode="Markdown")
    await callback.answer()
|
||
|
||
@dp.callback_query(F.data.startswith("del_"))
async def cb_del(callback: types.CallbackQuery):
    """Delete one of the caller's records and return to the first list page.

    The DELETE is restricted to the caller's own records, so forged
    ``del_<id>`` data cannot remove another user's entry. The toast reports
    whether a row was actually deleted.
    """
    rid = int(callback.data.split("_")[1])
    async with AsyncSessionLocal() as session:
        result = await session.execute(
            delete(Record).where(Record.id == rid, Record.user_id == callback.from_user.id)
        )
        await session.commit()

    if result.rowcount == 0:
        await callback.answer("记录不存在")
    else:
        await callback.answer("🗑 记录已删除")
    await show_record_list(callback, callback.from_user.id, page=0)
|
||
|
||
# Pagination and the remaining callbacks.
@dp.callback_query(F.data.startswith("page_"))
async def cb_pagination(callback: types.CallbackQuery):
    """Render the requested page (callback data ``page_<n>``) of the caller's records.

    Malformed or negative page numbers fall back to the first page.
    """
    try:
        page = int(callback.data.split("_")[1])
    except (IndexError, ValueError):
        page = 0
    await show_record_list(callback, callback.from_user.id, page=max(0, page))
    await callback.answer()
|
||
|
||
@dp.callback_query(F.data == "close_menu")
async def cb_close_menu(callback: types.CallbackQuery, state: FSMContext):
    """Close the record list: clear the FSM and replace the list with a notice."""
    await state.clear()
    await callback.message.edit_text("✅ 已退出列表。")
    await callback.answer()
|
||
|
||
@dp.callback_query(F.data == "cancel_action", StateFilter("*"))
async def cb_cancel(callback: types.CallbackQuery, state: FSMContext):
    """Cancel the current flow from any state and mark the message as cancelled."""
    await state.clear()
    await callback.message.edit_text("❌ 操作已取消。")
    await callback.answer()
|
||
|
||
async def main():
    """Initialize the database, publish the bot's command list, then start polling."""
    await init_db()
    command_menu = [
        BotCommand(command="start", description="开始"),
        BotCommand(command="edit", description="列表管理"),
        BotCommand(command="export", description="导出数据"),
    ]
    await bot.set_my_commands(command_menu)
    await dp.start_polling(bot)


if __name__ == "__main__":
    asyncio.run(main())