mockapi/app/core/database.py
2026-03-16 09:00:26 +00:00

106 lines
4.3 KiB
Python

import logging
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import text, event

from app.core.config import settings
logger = logging.getLogger(__name__)

# Application-wide async engine built from the configured database URL.
# SQL echoing is switched on whenever the debug setting is on.
engine = create_async_engine(
    settings.database_url,
    future=True,
    echo=settings.debug,
)
# Enable SQLite foreign key constraints on every new DBAPI connection.
@event.listens_for(engine.sync_engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    """Turn on ``PRAGMA foreign_keys`` for a freshly opened connection.

    Registered on the engine's ``"connect"`` event, so it runs for each new
    DBAPI connection. ``connection_record`` is required by the event
    signature and is not used here.

    NOTE(review): the PRAGMA is issued unconditionally, so this assumes
    ``settings.database_url`` points at SQLite -- confirm before switching
    backends.
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        # Release the cursor even when the PRAGMA itself raises.
        cursor.close()
# Factory producing AsyncSession instances bound to the shared engine.
# expire_on_commit=False: instances are not expired when a commit happens.
AsyncSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

# Declarative base; ``Base.metadata`` is the schema that init_db creates.
Base = declarative_base()
# Import models to ensure they are registered with Base.metadata
from app.modules.endpoints.models.endpoint_model import Endpoint
from app.modules.oauth2.models import OAuthClient, OAuthToken, OAuthUser
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Dependency for getting async database session.

    This is an async generator: it yields exactly one ``AsyncSession``
    created from ``AsyncSessionLocal``, and the session is closed when the
    generator is finalized.

    Yields:
        AsyncSession: the session for the caller to use.
    """
    # ``async with`` closes the session on exit (normal or exceptional), so
    # no extra try/finally + close() around the yield is needed.
    async with AsyncSessionLocal() as session:
        yield session
async def init_db():
    """Initialize database, create tables.

    All steps run inside one ``engine.begin()`` block. The table-specific
    migrations run first; each one returns early when its table does not
    exist yet. ``Base.metadata.create_all`` runs last and creates whatever
    tables are still missing.
    """
    schema_steps = (
        _migrate_endpoints_table,
        _migrate_oauth_tokens_table,
        Base.metadata.create_all,
    )
    async with engine.begin() as conn:
        for step in schema_steps:
            await conn.run_sync(step)
def _migrate_endpoints_table(connection):
    """Add OAuth columns to endpoints table if they don't exist.

    ``connection`` is the synchronous connection handed in by
    ``conn.run_sync`` from ``init_db``. The checks read ``sqlite_master``
    and ``PRAGMA table_info``, so this function is SQLite-specific. It logs
    and returns without changes when the table has not been created yet.
    """
    table_row = connection.execute(
        text("SELECT name FROM sqlite_master WHERE type='table' AND name='endpoints'")
    ).fetchone()
    if not table_row:
        logger.info("endpoints table does not exist yet; skipping migration")
        return

    # Column names come from the second field of each PRAGMA table_info row.
    existing = {row[1] for row in connection.execute(text("PRAGMA table_info(endpoints)")).fetchall()}

    # (column, DDL to add it, log line), applied in this order.
    pending = (
        (
            "requires_oauth",
            "ALTER TABLE endpoints ADD COLUMN requires_oauth BOOLEAN DEFAULT 0",
            "Added column 'requires_oauth' to endpoints table",
        ),
        (
            "oauth_scopes",
            "ALTER TABLE endpoints ADD COLUMN oauth_scopes TEXT DEFAULT '[]'",
            "Added column 'oauth_scopes' to endpoints table",
        ),
    )
    for column, ddl, message in pending:
        if column in existing:
            continue
        connection.execute(text(ddl))
        logger.info(message)
def _migrate_oauth_tokens_table(connection):
    """Add updated_at column and indexes to oauth_tokens table if missing.

    ``connection`` is the synchronous connection handed in by
    ``conn.run_sync`` from ``init_db``. The checks read ``sqlite_master``
    and ``PRAGMA table_info``, so this function is SQLite-specific. It logs
    and returns without changes when the table has not been created yet.
    """
    table_row = connection.execute(
        text("SELECT name FROM sqlite_master WHERE type='table' AND name='oauth_tokens'")
    ).fetchone()
    if not table_row:
        logger.info("oauth_tokens table does not exist yet; skipping migration")
        return

    # Column names come from the second field of each PRAGMA table_info row.
    columns = {row[1] for row in connection.execute(text("PRAGMA table_info(oauth_tokens)")).fetchall()}
    if "updated_at" not in columns:
        # SQLite's ALTER TABLE ADD COLUMN rejects a non-constant default such
        # as CURRENT_TIMESTAMP when the table already has rows ("Cannot add a
        # column with non-constant default") -- and a populated table is the
        # case this migration exists for. Add the column without a default,
        # then backfill existing rows with the value the default would have
        # given them.
        # NOTE(review): migrated tables therefore have no column-level default
        # for updated_at; confirm the ORM model supplies it on insert.
        connection.execute(text("ALTER TABLE oauth_tokens ADD COLUMN updated_at TIMESTAMP"))
        connection.execute(text("UPDATE oauth_tokens SET updated_at = CURRENT_TIMESTAMP"))
        logger.info("Added column 'updated_at' to oauth_tokens table")

    # (index name, indexed columns): single-column indexes, then composites.
    indexes = [
        *((f"ix_oauth_tokens_{column}", (column,)) for column in ("client_id", "user_id", "expires_at")),
        ("ix_oauth_tokens_client_expires", ("client_id", "expires_at")),
        ("ix_oauth_tokens_user_expires", ("user_id", "expires_at")),
    ]
    for name, index_columns in indexes:
        already_there = connection.execute(
            text("SELECT name FROM sqlite_master WHERE type='index' AND name=:name"),
            {"name": name},
        ).fetchone()
        if already_there:
            continue
        # Identifiers cannot be bound parameters; names and columns here are
        # fixed literals from the list above, never external input.
        column_list = ", ".join(index_columns)
        connection.execute(text(f"CREATE INDEX {name} ON oauth_tokens ({column_list})"))
        logger.info("Created index '%s'", name)