"""Async SQLAlchemy engine, session factory, and lightweight SQLite migrations."""

import logging
from typing import AsyncGenerator, Set

from sqlalchemy import event, text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker

from app.core.config import settings

logger = logging.getLogger(__name__)

# Create async engine
engine = create_async_engine(
    settings.database_url,
    echo=settings.debug,
    future=True,
)


def set_sqlite_pragma(dbapi_connection, connection_record):
    """Enable SQLite foreign-key enforcement on a newly opened DBAPI connection.

    ``PRAGMA foreign_keys`` is a per-connection setting in SQLite, so it is
    applied from the engine's ``connect`` event rather than once at startup.
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        cursor.close()


# The PRAGMA only exists in SQLite; registering the listener against another
# backend would make every new connection fail.
if engine.sync_engine.dialect.name == "sqlite":
    event.listen(engine.sync_engine, "connect", set_sqlite_pragma)

# Factory for AsyncSession objects. expire_on_commit=False keeps loaded
# attributes usable after commit instead of expiring them (an expired attribute
# would need a reload, which AsyncSession cannot do implicitly).
AsyncSessionLocal = sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

Base = declarative_base()

# Import models to ensure they are registered with Base.metadata
# Models are imported elsewhere (route_service, oauth2 module) to avoid circular imports
# from app.modules.endpoints.models.endpoint_model import Endpoint
# from app.modules.oauth2.models import OAuthClient, OAuthToken, OAuthUser


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Dependency for getting async database session.

    Yields one session per caller; ``async with`` closes it when the
    generator is finalized, so no explicit ``close()`` is needed.
    """
    async with AsyncSessionLocal() as session:
        yield session


async def init_db() -> None:
    """Initialize database: patch existing tables, then create missing tables.

    The ad-hoc migrations run first so tables created by an older version of
    the app gain the columns/indexes added since. ``create_all`` then creates
    any table that does not exist yet; it never alters existing tables.
    Everything runs inside the single transaction opened by ``engine.begin()``.
    """
    async with engine.begin() as conn:
        # Migrate existing tables (add missing columns)
        await conn.run_sync(_migrate_endpoints_table)
        await conn.run_sync(_migrate_oauth_tokens_table)
        await conn.run_sync(Base.metadata.create_all)


def _is_sqlite(connection) -> bool:
    """Return True when *connection* uses the SQLite dialect."""
    return connection.dialect.name == "sqlite"


def _table_exists(connection, table_name: str) -> bool:
    """Return True if a table named *table_name* is present in ``sqlite_master``."""
    result = connection.execute(
        text("SELECT name FROM sqlite_master WHERE type='table' AND name=:name"),
        {"name": table_name},
    )
    return result.fetchone() is not None


def _index_exists(connection, index_name: str) -> bool:
    """Return True if an index named *index_name* is present in ``sqlite_master``."""
    result = connection.execute(
        text("SELECT name FROM sqlite_master WHERE type='index' AND name=:name"),
        {"name": index_name},
    )
    return result.fetchone() is not None


def _column_names(connection, table_name: str) -> Set[str]:
    """Return the column names of *table_name* using ``PRAGMA table_info``.

    PRAGMA arguments cannot be bound parameters, so *table_name* is
    interpolated directly — only pass hard-coded, trusted table names.
    """
    rows = connection.execute(text(f"PRAGMA table_info({table_name})")).fetchall()
    # Each table_info row is (cid, name, type, notnull, dflt_value, pk).
    return {row[1] for row in rows}


# Columns added to `endpoints` after its first release: (name, type + default).
_ENDPOINTS_NEW_COLUMNS = (
    ("requires_oauth", "BOOLEAN DEFAULT 0"),
    ("oauth_scopes", "TEXT DEFAULT '[]'"),
)

# Indexes expected on `oauth_tokens`: (index name, indexed columns).
_OAUTH_TOKENS_INDEXES = (
    # Single-column indexes
    ("ix_oauth_tokens_client_id", ("client_id",)),
    ("ix_oauth_tokens_user_id", ("user_id",)),
    ("ix_oauth_tokens_expires_at", ("expires_at",)),
    # Composite indexes
    ("ix_oauth_tokens_client_expires", ("client_id", "expires_at")),
    ("ix_oauth_tokens_user_expires", ("user_id", "expires_at")),
)


def _migrate_endpoints_table(connection):
    """Add OAuth columns to endpoints table if they don't exist.

    Runs synchronously via ``AsyncConnection.run_sync``. SQLite-only: other
    backends are skipped with a warning because the checks rely on
    ``sqlite_master`` and ``PRAGMA``.
    """
    if not _is_sqlite(connection):
        logger.warning("Non-SQLite database; skipping endpoints table migration")
        return
    if not _table_exists(connection, "endpoints"):
        logger.info("endpoints table does not exist yet; skipping migration")
        return

    existing = _column_names(connection, "endpoints")
    for column, ddl in _ENDPOINTS_NEW_COLUMNS:
        if column in existing:
            continue
        connection.execute(text(f"ALTER TABLE endpoints ADD COLUMN {column} {ddl}"))
        logger.info("Added column '%s' to endpoints table", column)


def _migrate_oauth_tokens_table(connection):
    """Add updated_at column and indexes to oauth_tokens table if missing.

    Runs synchronously via ``AsyncConnection.run_sync``. SQLite-only: other
    backends are skipped with a warning because the checks rely on
    ``sqlite_master`` and ``PRAGMA``.
    """
    if not _is_sqlite(connection):
        logger.warning("Non-SQLite database; skipping oauth_tokens table migration")
        return
    if not _table_exists(connection, "oauth_tokens"):
        logger.info("oauth_tokens table does not exist yet; skipping migration")
        return

    if "updated_at" not in _column_names(connection, "oauth_tokens"):
        # SQLite refuses ALTER TABLE ... ADD COLUMN with a non-constant default
        # such as CURRENT_TIMESTAMP on a populated table ("Cannot add a column
        # with non-constant default"), so add the column bare and backfill it.
        # NOTE(review): rows inserted later get no DB-level default for this
        # column; presumably the ORM model sets it — confirm.
        connection.execute(
            text("ALTER TABLE oauth_tokens ADD COLUMN updated_at TIMESTAMP")
        )
        connection.execute(
            text(
                "UPDATE oauth_tokens SET updated_at = CURRENT_TIMESTAMP "
                "WHERE updated_at IS NULL"
            )
        )
        logger.info("Added column 'updated_at' to oauth_tokens table")

    # Index and column names below are hard-coded constants, not user input.
    for index_name, columns in _OAUTH_TOKENS_INDEXES:
        if _index_exists(connection, index_name):
            continue
        column_list = ", ".join(columns)
        connection.execute(
            text(f"CREATE INDEX {index_name} ON oauth_tokens ({column_list})")
        )
        logger.info("Created index '%s'", index_name)