105 lines
4.2 KiB
Python
105 lines
4.2 KiB
Python
import logging
from typing import AsyncIterator

from sqlalchemy import event, text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker

from config import settings
|
|
|
|
logger = logging.getLogger(__name__)

# Module-level async engine built from the configured URL; SQL statement
# echoing is tied to the debug setting.
engine = create_async_engine(
    settings.database_url,
    future=True,
    echo=settings.debug,
)
|
|
|
|
# Enable SQLite foreign key constraints on every new DBAPI connection.
@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """Turn on ``PRAGMA foreign_keys`` for each freshly opened connection.

    SQLite keeps foreign-key enforcement off by default, and the setting is
    per-connection, so it has to be issued on every connect. The pragma is
    SQLite-specific, so the listener is a no-op for any other backend.

    Args:
        dbapi_connection: the raw DBAPI connection that was just opened.
        connection_record: the pool's record for that connection (unused).
    """
    # The listener is attached regardless of the configured URL; without this
    # guard a non-SQLite database_url would make every connection attempt fail.
    if engine.dialect.name != "sqlite":
        return
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # Release the cursor even if the PRAGMA fails.
        cursor.close()
|
|
|
|
# Factory producing AsyncSession objects bound to ``engine``. With
# expire_on_commit=False, loaded objects are not expired after commit, so
# their attributes stay readable without a reload.
AsyncSessionLocal = sessionmaker(
    bind=engine,
    expire_on_commit=False,
    class_=AsyncSession,
)

# Declarative base class that ORM models register their tables on.
Base = declarative_base()
|
|
|
|
# Import models to ensure they are registered with Base.metadata
|
|
from models import Endpoint, OAuthClient, OAuthToken, OAuthUser
|
|
|
|
|
|
async def get_db() -> AsyncIterator[AsyncSession]:
    """Dependency for getting async database session.

    Yields a single session from ``AsyncSessionLocal``. The session is closed
    when the generator finishes, whether the consumer completes normally or
    an exception is thrown into the generator. Presumably consumed as a
    ``yield`` dependency (e.g. FastAPI ``Depends``) — confirm against callers.

    Yields:
        AsyncSession: a session valid for the duration of the caller's work.
    """
    # ``async with`` closes the session on every exit path, so the previous
    # explicit ``finally: await session.close()`` was a redundant second close.
    async with AsyncSessionLocal() as session:
        yield session
|
|
|
|
|
|
async def init_db():
    """Bring the schema up to date: patch existing tables, then create missing ones."""
    # The migration helpers return early for tables that do not exist yet, so
    # running them before create_all means they only ever touch pre-existing
    # tables; create_all then creates whatever is still missing.
    migrations = (_migrate_endpoints_table, _migrate_oauth_tokens_table)
    # engine.begin() runs everything in one transaction committed on exit.
    async with engine.begin() as conn:
        for migrate in migrations:
            await conn.run_sync(migrate)
        await conn.run_sync(Base.metadata.create_all)
|
|
|
|
def _migrate_endpoints_table(connection):
    """Add the OAuth columns to an existing ``endpoints`` table.

    Logs and returns without changes when the table does not exist yet. Each
    column is added only when ``PRAGMA table_info`` does not already list it,
    so repeated runs are harmless.

    Args:
        connection: synchronous SQLAlchemy connection (as passed by ``run_sync``).
    """
    table_row = connection.execute(
        text("SELECT name FROM sqlite_master WHERE type='table' AND name='endpoints'")
    ).fetchone()
    if not table_row:
        logger.info("endpoints table does not exist yet; skipping migration")
        return

    # PRAGMA table_info returns one row per column; index 1 is the column name.
    pragma_rows = connection.execute(text("PRAGMA table_info(endpoints)")).fetchall()
    existing = {info[1] for info in pragma_rows}

    # (column name, DDL that adds it), applied in this order.
    pending = (
        ("requires_oauth", "ALTER TABLE endpoints ADD COLUMN requires_oauth BOOLEAN DEFAULT 0"),
        ("oauth_scopes", "ALTER TABLE endpoints ADD COLUMN oauth_scopes TEXT DEFAULT '[]'"),
    )
    for column_name, ddl in pending:
        if column_name in existing:
            continue
        connection.execute(text(ddl))
        logger.info(f"Added column '{column_name}' to endpoints table")
|
|
|
|
def _migrate_oauth_tokens_table(connection):
    """Add updated_at column and indexes to oauth_tokens table if missing.

    Logs and returns without changes when the table does not exist yet.
    Every step first checks the current schema (``PRAGMA table_info`` for
    columns, ``sqlite_master`` for index names), so repeated runs are harmless.

    Args:
        connection: synchronous SQLAlchemy connection (as passed by ``run_sync``).
    """
    table_row = connection.execute(
        text("SELECT name FROM sqlite_master WHERE type='table' AND name='oauth_tokens'")
    ).fetchone()
    if not table_row:
        logger.info("oauth_tokens table does not exist yet; skipping migration")
        return

    # PRAGMA table_info returns one row per column; index 1 is the column name.
    columns = {
        row[1]
        for row in connection.execute(text("PRAGMA table_info(oauth_tokens)")).fetchall()
    }

    if "updated_at" not in columns:
        # SQLite rejects ALTER TABLE ... ADD COLUMN with a non-constant default
        # such as CURRENT_TIMESTAMP ("Cannot add a column with non-constant
        # default"), so add the column bare and backfill existing rows instead.
        # NOTE(review): the column therefore has no SQL-level default; this
        # assumes the ORM model supplies updated_at on insert — confirm.
        connection.execute(text("ALTER TABLE oauth_tokens ADD COLUMN updated_at TIMESTAMP"))
        connection.execute(
            text("UPDATE oauth_tokens SET updated_at = CURRENT_TIMESTAMP WHERE updated_at IS NULL")
        )
        logger.info("Added column 'updated_at' to oauth_tokens table")

    # Index names are schema-wide in SQLite, so look them up by name only.
    existing_indexes = {
        row[0]
        for row in connection.execute(
            text("SELECT name FROM sqlite_master WHERE type='index'")
        ).fetchall()
    }

    # (index name, indexed columns): single-column lookups first, then the
    # composite (owner, expiry) indexes. All names/columns are fixed literals,
    # so building the DDL with an f-string is safe here.
    wanted_indexes = [
        ("ix_oauth_tokens_client_id", ("client_id",)),
        ("ix_oauth_tokens_user_id", ("user_id",)),
        ("ix_oauth_tokens_expires_at", ("expires_at",)),
        ("ix_oauth_tokens_client_expires", ("client_id", "expires_at")),
        ("ix_oauth_tokens_user_expires", ("user_id", "expires_at")),
    ]
    for name, index_columns in wanted_indexes:
        if name in existing_indexes:
            continue
        connection.execute(
            text(f"CREATE INDEX {name} ON oauth_tokens ({', '.join(index_columns)})")
        )
        logger.info("Created index '%s'", name)
|