295 lines
11 KiB
Python
295 lines
11 KiB
Python
from datetime import datetime, timezone
|
||
|
||
import pytz
|
||
from telegram import ReplyKeyboardRemove, Update
|
||
from telegram.ext import (
|
||
CommandHandler,
|
||
ContextTypes,
|
||
ConversationHandler,
|
||
MessageHandler,
|
||
filters,
|
||
)
|
||
|
||
from bot.models.database import Session
|
||
from bot.models.reminder import Reminder
|
||
from bot.models.user import User
|
||
from bot.scheduler.job_manager import add_reminder_job
|
||
from bot.states import (
|
||
CHOOSE_TYPE,
|
||
CONFIRM,
|
||
INPUT_DESC,
|
||
INPUT_HOLIDAY,
|
||
INPUT_INTERVAL_MINUTES,
|
||
INPUT_INTERVAL_WINDOW,
|
||
INPUT_TIME,
|
||
INPUT_TITLE,
|
||
INPUT_WEEKLY_DAYS,
|
||
)
|
||
from bot.utils.keyboards import (
|
||
confirm_keyboard,
|
||
main_keyboard,
|
||
reminder_type_keyboard,
|
||
yes_no_keyboard,
|
||
)
|
||
|
||
|
||
async def new_reminder(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int:
    """Start the "new reminder" conversation.

    Discards any leftover draft in ``context.user_data`` and asks the user to
    pick a reminder type from the type keyboard.

    Returns:
        ``CHOOSE_TYPE`` -- the next conversation state.
    """
    draft = context.user_data
    draft.clear()  # every conversation starts from an empty draft
    keyboard = reminder_type_keyboard()
    await update.message.reply_text(
        "让我们创建一个新提醒!\n\n请选择提醒类型:", reply_markup=keyboard
    )
    return CHOOSE_TYPE
|
||
|
||
|
||
async def choose_type(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int:
    """Record the reminder type picked from the keyboard and ask for a title.

    "取消" aborts through ``cancel``. Text that is not one of the known type
    labels re-prompts and stays in ``CHOOSE_TYPE``.

    Returns:
        ``CHOOSE_TYPE`` on invalid input, the result of ``cancel`` when
        aborting, otherwise ``INPUT_TITLE``.
    """
    choice = update.message.text.strip()
    if choice == "取消":
        return await cancel(update, context)

    # Keyboard label -> internal reminder_type value stored in the draft.
    labels_to_types = {
        "一次性": "once",
        "每日": "daily",
        "每周": "weekly",
        "间隔": "interval",
    }
    if choice not in labels_to_types:
        await update.message.reply_text("请选择有效的提醒类型。")
        return CHOOSE_TYPE

    context.user_data["reminder_type"] = labels_to_types[choice]
    await update.message.reply_text(
        "请输入提醒标题(简短描述):",
        reply_markup=ReplyKeyboardRemove(),
    )
    return INPUT_TITLE
|
||
|
||
|
||
async def input_title(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int:
    """Store the reminder title and prompt for an optional description.

    Blank (whitespace-only) titles are rejected and the user is asked again.

    Returns:
        ``INPUT_TITLE`` if the title is blank, otherwise ``INPUT_DESC``.
    """
    title = update.message.text.strip()
    if not title:
        await update.message.reply_text("标题不能为空,请重新输入:")
        return INPUT_TITLE

    context.user_data["title"] = title
    # Single-quoted literal: the prompt contains ASCII double quotes around
    # the skip keyword, which terminated the old double-quoted string early.
    await update.message.reply_text('请输入提醒描述(可选,直接发送"跳过"):')
    return INPUT_DESC
|
||
|
||
|
||
async def input_desc(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int:
    """Store the optional description, then prompt for the schedule input.

    Empty text or "跳过" leaves ``description`` unset. The follow-up prompt
    depends on the ``reminder_type`` already in the draft; an unknown type
    gets no prompt.

    Returns:
        ``INPUT_TIME`` in every case.
    """
    answer = update.message.text.strip()
    if answer and answer != "跳过":
        context.user_data["description"] = answer

    # reminder_type -> prompt for the first schedule input of that type.
    prompts = {
        "once": "请输入提醒时间(格式:YYYY-MM-DD HH:MM)\n例如:2026-03-10 14:30",
        "daily": "请输入每日提醒时间(格式:HH:MM)\n例如:09:00",
        "weekly": "请输入星期几提醒(用逗号分隔,周一=1,周日=7)\n例如:1,3,5 表示周一、周三、周五",
        "interval": "请输入间隔分钟数(例如:30 表示每30分钟)",
    }
    prompt = prompts.get(context.user_data["reminder_type"])
    if prompt is not None:
        await update.message.reply_text(prompt)
    return INPUT_TIME
|
||
|
||
|
||
def _parse_hhmm(text: str) -> str:
    """Parse ``H:M`` / ``HH:MM`` text into a zero-padded ``"HH:MM"`` string.

    Raises:
        ValueError: if the text is not two integers separated by ``:``, or
            the hour/minute is out of range.
    """
    h, m = map(int, text.split(":"))
    if not (0 <= h < 24 and 0 <= m < 60):
        raise ValueError(f"time out of range: {text!r}")
    return f"{h:02d}:{m:02d}"


async def input_time(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int:
    """Handle the schedule input for the reminder type chosen earlier.

    * once     -- ``YYYY-MM-DD HH:MM`` interpreted in Asia/Shanghai. It must be
                  in the future and is stored as an aware UTC datetime in
                  ``once_time``.
    * daily    -- ``HH:MM`` stored in ``daily_time``.
    * weekly   -- two messages in this state. First the weekdays (1=Mon ..
                  7=Sun, comma separated), stored 0-based in ``weekly_days``.
                  Then ``HH:MM``, stored in ``daily_time``.
    * interval -- a positive minute count stored in ``interval_minutes``;
                  continues to ``INPUT_INTERVAL_WINDOW``.

    Malformed input re-prompts and stays in ``INPUT_TIME``.

    Returns:
        ``INPUT_TIME``, ``INPUT_INTERVAL_WINDOW``, or the state returned by
        ``ask_holiday`` once the schedule is complete.
    """
    text = update.message.text.strip()
    data = context.user_data
    reminder_type = data["reminder_type"]
    tz = pytz.timezone("Asia/Shanghai")

    try:
        if reminder_type == "once":
            dt_aware = tz.localize(datetime.strptime(text, "%Y-%m-%d %H:%M"))
            if dt_aware <= datetime.now(timezone.utc).astimezone(tz):
                await update.message.reply_text("时间必须是未来时间,请重新输入:")
                return INPUT_TIME
            data["once_time"] = dt_aware.astimezone(timezone.utc)

        elif reminder_type == "daily":
            data["daily_time"] = _parse_hhmm(text)

        elif reminder_type == "weekly":
            if "weekly_days" not in data:
                # Step 1: the weekday list. Duplicates are dropped and days sorted.
                days = sorted({int(d.strip()) for d in text.split(",")})
                if not all(1 <= d <= 7 for d in days):
                    raise ValueError(f"weekday out of range: {text!r}")
                # Convert to APScheduler format (Mon=0, Sun=6)
                data["weekly_days"] = ",".join(str(d - 1) for d in days)
                await update.message.reply_text("请输入提醒时间(格式:HH:MM)\n例如:09:00")
                return INPUT_TIME
            # Step 2: days are already stored, so this message is the time of day.
            # Previously this message was re-parsed as weekdays and always failed.
            data["daily_time"] = _parse_hhmm(text)

        elif reminder_type == "interval":
            minutes = int(text)
            if minutes <= 0:
                raise ValueError(f"interval must be positive: {text!r}")
            data["interval_minutes"] = minutes
            await update.message.reply_text(
                "请输入时间窗口(格式:HH:MM-HH:MM)\n例如:09:00-22:00"
            )
            return INPUT_INTERVAL_WINDOW

        else:
            # Unknown type: stay in this state without replying.
            return INPUT_TIME

    except ValueError:
        # Only parsing/validation errors count as bad input; other failures
        # (e.g. sending a message) propagate instead of being masked.
        await update.message.reply_text("格式错误,请重新输入:")
        return INPUT_TIME

    return await ask_holiday(update, context)
|
||
|
||
|
||
async def input_interval_window(
    update: Update, context: ContextTypes.DEFAULT_TYPE
) -> int:
    """Parse the active window ``HH:MM-HH:MM`` for an interval reminder.

    Stores ``interval_start_time`` and ``interval_end_time`` as zero-padded
    ``"HH:MM"`` strings, then moves on to the holiday question. Malformed
    input re-prompts and stays in ``INPUT_INTERVAL_WINDOW``.

    NOTE(review): a start later than the end is accepted as-is; how the
    scheduler interprets such a window is not visible here.
    """
    text = update.message.text.strip()
    try:
        start, end = text.split("-")
        sh, sm = map(int, start.strip().split(":"))
        eh, em = map(int, end.strip().split(":"))
        if not (0 <= sh < 24 and 0 <= sm < 60 and 0 <= eh < 24 and 0 <= em < 60):
            raise ValueError(f"time out of range: {text!r}")
    except ValueError:
        # Only parse/validation failures are reported as a format error.
        await update.message.reply_text("格式错误,请重新输入(例如:09:00-22:00):")
        return INPUT_INTERVAL_WINDOW

    context.user_data["interval_start_time"] = f"{sh:02d}:{sm:02d}"
    context.user_data["interval_end_time"] = f"{eh:02d}:{em:02d}"
    # Outside the try: a failure here is not the user's input being malformed.
    return await ask_holiday(update, context)
|
||
|
||
|
||
async def ask_holiday(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int:
    """Ask whether the reminder should be skipped on Chinese public holidays.

    Shows the yes/no keyboard. ``context`` is unused here; it is kept so the
    signature matches the other conversation steps.

    Returns:
        ``INPUT_HOLIDAY`` -- the next conversation state.
    """
    question = "是否在中国节假日跳过提醒?"
    await update.message.reply_text(question, reply_markup=yes_no_keyboard())
    return INPUT_HOLIDAY
|
||
|
||
|
||
async def input_holiday(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int:
    """Record the holiday-skip answer and show the confirmation summary.

    "取消" aborts through ``cancel``. Only the exact answer "是" enables
    skipping; any other text is stored as ``False``.

    Returns:
        ``CONFIRM``, or the result of ``cancel`` when aborting.
    """
    answer = update.message.text.strip()
    if answer == "取消":
        return await cancel(update, context)

    draft = context.user_data
    draft["skip_holidays"] = answer == "是"

    await update.message.reply_text(
        "请确认提醒信息:\n\n" + _build_summary(draft),
        reply_markup=confirm_keyboard(),
    )
    return CONFIRM
|
||
|
||
|
||
async def confirm(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int:
    """Persist the confirmed reminder, schedule it, and end the conversation.

    "❌ 取消" aborts. Any text other than "✅ 确认创建" re-prompts in
    ``CONFIRM``. On confirmation the draft in ``context.user_data`` is saved
    as a ``Reminder`` owned by the ``User`` from ``User.get_or_create``, and
    its id is passed to ``add_reminder_job``. Failures while saving or
    scheduling are rolled back, logged, and reported to the user. The draft
    is cleared in both cases.

    Returns:
        ``CONFIRM`` to re-prompt, the result of ``cancel`` when aborting,
        otherwise ``ConversationHandler.END``.
    """
    import logging

    text = update.message.text.strip()
    if text == "❌ 取消":
        return await cancel(update, context)

    if text != "✅ 确认创建":
        await update.message.reply_text("请点击按钮确认或取消。")
        return CONFIRM

    data = context.user_data
    user = update.effective_user
    session = Session()
    try:
        user_obj = User.get_or_create(session, user.id, user.username)

        reminder = Reminder(
            user_id=user_obj.id,
            title=data["title"],
            description=data.get("description"),
            reminder_type=data["reminder_type"],
            once_time=data.get("once_time"),
            daily_time=data.get("daily_time"),
            weekly_days=data.get("weekly_days"),
            interval_minutes=data.get("interval_minutes"),
            interval_start_time=data.get("interval_start_time"),
            interval_end_time=data.get("interval_end_time"),
            skip_holidays=data.get("skip_holidays", False),
            is_active=True,
        )
        session.add(reminder)
        session.commit()

        # NOTE(review): if scheduling fails, the commit above has already
        # happened, so the rollback below cannot undo the saved row.
        add_reminder_job(reminder.id)
    except Exception:
        session.rollback()
        # Boundary handler: keep the bot alive, but never fail silently.
        logging.getLogger(__name__).exception("Failed to create reminder")
        await update.message.reply_text(
            "❌ 创建失败,请稍后重试。", reply_markup=main_keyboard()
        )
    else:
        # Outside the try, so a failed send is not misreported as a failed creation.
        await update.message.reply_text(
            "✅ 提醒创建成功!", reply_markup=main_keyboard()
        )
    finally:
        Session.remove()

    context.user_data.clear()
    return ConversationHandler.END
|
||
|
||
|
||
async def cancel(update: Update, context: ContextTypes.DEFAULT_TYPE) -> int:
    """Abort the conversation: drop the draft, restore the main keyboard, end.

    Returns:
        ``ConversationHandler.END``.
    """
    context.user_data.clear()
    keyboard = main_keyboard()
    await update.message.reply_text("已取消创建提醒。", reply_markup=keyboard)
    return ConversationHandler.END
|
||
|
||
|
||
def _build_summary(data: dict) -> str:
|
||
lines = [f"标题:{data['title']}"]
|
||
if data.get("description"):
|
||
lines.append(f"描述:{data['description']}")
|
||
|
||
rtype = data["reminder_type"]
|
||
type_names = {"once": "一次性", "daily": "每日", "weekly": "每周", "interval": "间隔"}
|
||
lines.append(f"类型:{type_names[rtype]}")
|
||
|
||
if rtype == "once":
|
||
dt = data["once_time"]
|
||
tz = pytz.timezone("Asia/Shanghai")
|
||
lines.append(f"时间:{dt.astimezone(tz).strftime('%Y-%m-%d %H:%M')}")
|
||
elif rtype == "daily":
|
||
lines.append(f"时间:每天 {data['daily_time']}")
|
||
elif rtype == "weekly":
|
||
day_names = ["周一", "周二", "周三", "周四", "周五", "周六", "周日"]
|
||
days = ", ".join(day_names[int(d)] for d in data["weekly_days"].split(","))
|
||
lines.append(f"时间:{days} {data['daily_time']}")
|
||
elif rtype == "interval":
|
||
lines.append(
|
||
f"间隔:每 {data['interval_minutes']} 分钟\n"
|
||
f"时间窗口:{data['interval_start_time']} - {data['interval_end_time']}"
|
||
)
|
||
|
||
lines.append(f"跳过节假日:{'是' if data.get('skip_holidays') else '否'}")
|
||
return "\n".join(lines)
|
||
|
||
|
||
# Build the ConversationHandler.
# Every step accepts non-command text only; commands are excluded so that
# /cancel reaches the fallback handler.
reminder_conv_handler = ConversationHandler(
    entry_points=[
        CommandHandler("new", new_reminder),
        MessageHandler(filters.Regex("^➕ 新建提醒$"), new_reminder),
    ],
    states={
        step: [MessageHandler(filters.TEXT & ~filters.COMMAND, callback)]
        for step, callback in (
            (CHOOSE_TYPE, choose_type),
            (INPUT_TITLE, input_title),
            (INPUT_DESC, input_desc),
            (INPUT_TIME, input_time),
            (INPUT_INTERVAL_WINDOW, input_interval_window),
            (INPUT_HOLIDAY, input_holiday),
            (CONFIRM, confirm),
        )
    },
    fallbacks=[CommandHandler("cancel", cancel)],
)
|