# 34 lines, 1.1 KiB, Python
import os
|
|
from datetime import datetime
|
|
from sqlalchemy import Column, Integer, Float, BigInteger, String, DateTime
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
|
from sqlalchemy.orm import sessionmaker, declarative_base
|
|
|
|
# Database connection URL, read from the DB_URL environment variable.
# It is passed to create_async_engine() further down, so it has to name an
# async-capable driver (e.g. "postgresql+asyncpg://..." or
# "sqlite+aiosqlite:///...").
DB_URL = os.getenv("DB_URL")
if not DB_URL:
    # Fail fast with an actionable message. Without this check an unset or
    # empty variable would still abort the import, but only later inside
    # create_async_engine(), with an error that does not mention DB_URL.
    raise RuntimeError(
        "DB_URL environment variable is not set; expected an async "
        "SQLAlchemy URL such as 'postgresql+asyncpg://user:pass@host/db'"
    )


# Declarative base class shared by every ORM model in this module; its
# .metadata collects the table definitions that init_db() creates.
Base = declarative_base(name="Base")


class Record(Base):
    """One spending entry: which user spent how much, in what category, and when.

    Mapped to the ``records`` table.
    """

    __tablename__ = 'records'

    id = Column(Integer, primary_key=True)
    # Identifies the user the record belongs to; indexed for per-user lookups.
    user_id = Column(BigInteger, index=True)
    # How much money was spent (NOT NULL).
    # NOTE(review): Float is binary floating point and can introduce rounding
    # error for currency; Numeric/Decimal is the usual choice. Switching would
    # change the Python type callers receive, so it is left as-is here.
    amount = Column(Float, nullable=False)
    # Spending category, declared as String(100); whether the length limit is
    # enforced depends on the database backend.
    category = Column(String(100))
    # When the transaction happened. If not supplied, SQLAlchemy fills it at
    # INSERT time by calling datetime.now(): the Python process's local
    # wall-clock time as a naive datetime (no tzinfo).
    # NOTE(review): consider a timezone-aware UTC timestamp if users or servers
    # span time zones.
    transaction_time = Column(DateTime, default=datetime.now)

    def __repr__(self):
        """Return a debug representation showing the loaded column values."""
        # Read from the instance __dict__ (where SQLAlchemy keeps loaded column
        # values) instead of plain attribute access: on an expired instance,
        # attribute access would trigger a lazy refresh, i.e. implicit IO,
        # which raises under AsyncSession. Unloaded columns show as None.
        loaded = vars(self)
        fields = ", ".join(
            f"{name}={loaded.get(name)!r}"
            for name in ("id", "user_id", "amount", "category", "transaction_time")
        )
        return f"{type(self).__name__}({fields})"


# Async SQLAlchemy engine connected to DB_URL.
engine = create_async_engine(
    DB_URL,
    echo=False,  # do not log emitted SQL statements
)


# Factory producing AsyncSession objects bound to `engine`.
# expire_on_commit=False keeps already-loaded attributes readable after
# commit(); otherwise the next attribute access would need a database round
# trip, which an AsyncSession cannot perform implicitly.
AsyncSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


async def init_db():
    """Create the tables registered on ``Base.metadata`` that do not exist yet.

    Opens a transaction with ``engine.begin()`` and runs the synchronous
    ``MetaData.create_all`` through ``run_sync``. Tables that already exist
    are skipped, not altered.
    """
    async with engine.begin() as connection:
        # create_all is a sync API; run_sync executes it on the async connection.
        create_tables = Base.metadata.create_all
        await connection.run_sync(create_tables)