128 lines
3.6 KiB
Python
128 lines
3.6 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from typing import Optional
|
|
|
|
import pytz
|
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
|
from apscheduler.triggers.cron import CronTrigger
|
|
from apscheduler.triggers.date import DateTrigger
|
|
from apscheduler.triggers.interval import IntervalTrigger
|
|
from telegram import Bot
|
|
|
|
from bot.models.database import Session
|
|
from bot.models.reminder import Reminder
|
|
from bot.scheduler.executor import execute_reminder
|
|
|
|
# Timezone handed to the scheduler and to every trigger built in this module.
SHANGHAI_TZ = pytz.timezone("Asia/Shanghai")

logger = logging.getLogger(__name__)

# Module-wide singletons; both are assigned by init_scheduler().
_scheduler: Optional[AsyncIOScheduler] = None
_bot: Optional[Bot] = None
|
|
|
|
|
|
def init_scheduler(bot: Bot) -> AsyncIOScheduler:
    """Create the module-wide scheduler, register active reminders, then start it.

    Args:
        bot: Telegram bot that is passed to every scheduled reminder job.

    Returns:
        The started scheduler, which is also stored in the module global.

    NOTE(review): calling this a second time replaces the global without
    shutting down the previously created scheduler.
    """
    global _scheduler, _bot

    _bot = bot
    scheduler = AsyncIOScheduler(timezone=SHANGHAI_TZ)
    _scheduler = scheduler

    # Jobs are registered while the scheduler is still stopped; start() below
    # then activates all of them together.
    _load_all_reminders()
    scheduler.start()
    logger.info("Scheduler started")
    return scheduler
|
|
|
|
|
|
def shutdown_scheduler() -> None:
    """Stop the scheduler without waiting for running jobs.

    Does nothing when the scheduler was never created or is not running.
    """
    scheduler = _scheduler
    if not scheduler or not scheduler.running:
        return
    scheduler.shutdown(wait=False)
    logger.info("Scheduler stopped")
|
|
|
|
|
|
def _load_all_reminders() -> None:
    """Register a scheduler job for every active reminder in the database.

    Each reminder is scheduled independently. If building or adding one job
    raises (for example because of a malformed time field), the error is
    logged with its traceback and the remaining reminders are still loaded.
    A single bad row therefore cannot stop the scheduler from starting.
    The scoped session is always released afterwards.
    """
    session = Session()
    try:
        loaded = 0
        failed = 0
        for reminder in Reminder.get_active(session):
            try:
                _add_job(reminder)
            except Exception:
                # Startup boundary: record the failure and keep going rather
                # than aborting init_scheduler() for every other reminder.
                failed += 1
                logger.exception("Failed to schedule reminder %s", reminder.id)
            else:
                loaded += 1
        logger.info("Loaded %d active reminders", loaded)
        if failed:
            logger.warning("Skipped %d reminders that could not be scheduled", failed)
    finally:
        Session.remove()
|
|
|
|
|
|
def add_reminder_job(reminder_id: int) -> None:
    """(Re)schedule the job for the reminder whose primary key is *reminder_id*.

    Ids that do not match any row are silently ignored. The scoped session is
    released on every path.
    """
    session = Session()
    try:
        row = session.get(Reminder, reminder_id)
        if not row:
            return
        _add_job(row)
    finally:
        Session.remove()
|
|
|
|
|
|
def remove_reminder_job(reminder_id: int) -> None:
    """Unschedule the job belonging to *reminder_id*, if there is one.

    Does nothing before init_scheduler() has run, or when no job is
    registered under the reminder's job id.
    """
    scheduler = _scheduler
    if scheduler is None:
        return

    job_id = f"reminder_{reminder_id}"
    if not scheduler.get_job(job_id):
        return

    scheduler.remove_job(job_id)
    logger.info("Removed job %s", job_id)
|
|
|
|
|
|
def _add_job(reminder: Reminder) -> None:
    """Schedule or reschedule the job that fires *reminder*.

    Any job already registered under this reminder's id is removed first.
    As a result, a reminder for which _build_trigger() returns None (for
    example a one-shot time that has already passed) ends up with no job.
    Does nothing until init_scheduler() has set both the scheduler and the bot.
    """
    scheduler, bot = _scheduler, _bot
    if scheduler is None or bot is None:
        return

    job_id = f"reminder_{reminder.id}"
    if scheduler.get_job(job_id):
        scheduler.remove_job(job_id)

    trigger = _build_trigger(reminder)
    if trigger is not None:
        scheduler.add_job(
            execute_reminder,
            trigger=trigger,
            id=job_id,
            kwargs={"reminder_id": reminder.id, "bot": bot},
            replace_existing=True,
            # A run that starts up to this many seconds late still fires.
            misfire_grace_time=60,
        )
        logger.info("Scheduled job %s (type=%s)", job_id, reminder.reminder_type)
|
|
|
|
|
|
def _parse_hhmm(value: str) -> tuple[int, int]:
    """Parse an ``"HH:MM"`` time-of-day string into ``(hour, minute)``.

    Raises:
        ValueError: *value* is not two ``:``-separated integers forming a
            valid 24-hour time. The message includes the offending value.
    """
    error = f"invalid time of day {value!r}; expected 'HH:MM'"
    parts = value.strip().split(":")
    if len(parts) != 2:
        raise ValueError(error)
    try:
        hour, minute = int(parts[0]), int(parts[1])
    except ValueError:
        raise ValueError(error) from None
    if not (0 <= hour <= 23 and 0 <= minute <= 59):
        raise ValueError(error)
    return hour, minute


def _build_trigger(
    reminder: Reminder,
) -> Optional[DateTrigger | CronTrigger | IntervalTrigger]:
    """Translate a reminder's schedule fields into an APScheduler trigger.

    Returns None, meaning "do not schedule", when any of these hold:
    - the reminder type is unknown;
    - a required field is empty;
    - a one-shot time has already passed;
    - the interval is not a positive number of minutes.

    Raises:
        ValueError: ``daily_time`` is set but is not a valid ``"HH:MM"`` string.
    """
    rtype = reminder.reminder_type

    if rtype == "once":
        run_time = reminder.once_time
        if run_time is None:
            return None
        # NOTE(review): naive timestamps are interpreted as UTC here; confirm
        # that this matches how once_time is stored.
        if run_time.tzinfo is None:
            run_time = pytz.utc.localize(run_time)
        # Aware datetimes compare by absolute instant, so no conversion is needed.
        if run_time <= datetime.now(timezone.utc):
            return None  # already passed
        return DateTrigger(run_date=run_time, timezone=SHANGHAI_TZ)

    if rtype == "daily":
        if not reminder.daily_time:
            return None
        hour, minute = _parse_hhmm(reminder.daily_time)
        return CronTrigger(hour=hour, minute=minute, timezone=SHANGHAI_TZ)

    if rtype == "weekly":
        if not reminder.weekly_days or not reminder.daily_time:
            return None
        hour, minute = _parse_hhmm(reminder.daily_time)
        # weekly_days (e.g. "0,2,4") is passed to APScheduler unchanged. In
        # APScheduler 3.x, day_of_week 0 is Monday, not Sunday as in crontab.
        # NOTE(review): confirm the stored day numbers use that convention.
        return CronTrigger(
            day_of_week=reminder.weekly_days,
            hour=hour,
            minute=minute,
            timezone=SHANGHAI_TZ,
        )

    if rtype == "interval":
        minutes = reminder.interval_minutes
        # A zero or negative period is not a meaningful repeat interval, so
        # such reminders are treated as unschedulable.
        if not minutes or minutes <= 0:
            return None
        return IntervalTrigger(minutes=minutes, timezone=SHANGHAI_TZ)

    return None
|