from typing import List, Optional
from datetime import datetime, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, and_
from sqlalchemy.exc import SQLAlchemyError
import logging

# Import database first to resolve circular import
import database
from models.oauth_models import OAuthClient, OAuthToken, OAuthUser

logger = logging.getLogger(__name__)


class OAuthClientRepository:
    """Repository for performing CRUD operations on the OAuthClient model.

    Every write method commits on the injected session itself and rolls the
    session back when a ``SQLAlchemyError`` is raised. Database errors are
    logged and reported through the return value (``None``, ``[]`` or
    ``False``) instead of being propagated to the caller.
    """

    def __init__(self, session: AsyncSession):
        """Store the session that all queries of this repository run on.

        Args:
            session: The async SQLAlchemy session to use.
        """
        self.session = session

    async def create(self, client_data: dict) -> Optional[OAuthClient]:
        """Create and persist a new OAuth client.

        Args:
            client_data: Keyword arguments passed to the ``OAuthClient``
                constructor.

        Returns:
            The refreshed OAuthClient instance, or None if a
            ``SQLAlchemyError`` occurred (the session is rolled back).
        """
        try:
            client = OAuthClient(**client_data)
            self.session.add(client)
            await self.session.commit()
            # Reload the row so database-generated values are populated.
            await self.session.refresh(client)
            return client
        except SQLAlchemyError as e:
            logger.error(f"Failed to create OAuth client: {e}")
            await self.session.rollback()
            return None

    async def get_by_id(self, client_id: int) -> Optional[OAuthClient]:
        """Retrieve a client by its ``id`` column.

        Args:
            client_id: The value to match against ``OAuthClient.id``.

        Returns:
            The matching OAuthClient, or None if no row matches or a
            ``SQLAlchemyError`` occurred.
        """
        try:
            stmt = select(OAuthClient).where(OAuthClient.id == client_id)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error(f"Failed to fetch client by id {client_id}: {e}")
            return None

    async def get_by_client_id(self, client_id_str: str) -> Optional[OAuthClient]:
        """Retrieve a client by its ``client_id`` string identifier.

        Args:
            client_id_str: The value to match against
                ``OAuthClient.client_id``.

        Returns:
            The matching OAuthClient, or None if no row matches or a
            ``SQLAlchemyError`` occurred (which includes more than one
            matching row).
        """
        try:
            stmt = select(OAuthClient).where(OAuthClient.client_id == client_id_str)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error(f"Failed to fetch client by client_id {client_id_str}: {e}")
            return None

    async def get_all(self, skip: int = 0, limit: int = 100) -> List[OAuthClient]:
        """Retrieve one page of clients.

        Rows are ordered by ``id`` so that consecutive pages are stable;
        OFFSET/LIMIT without an ORDER BY may repeat or skip rows between
        pages. (Assumes ``id`` is unique -- TODO confirm against the model.)

        Args:
            skip: Number of records to skip.
            limit: Maximum number of records to return.

        Returns:
            List of OAuthClient objects; empty if a ``SQLAlchemyError``
            occurred.
        """
        try:
            stmt = (
                select(OAuthClient)
                .order_by(OAuthClient.id)
                .offset(skip)
                .limit(limit)
            )
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.error(f"Failed to fetch all clients: {e}")
            return []

    async def update(self, client_id: int, client_data: dict) -> Optional[OAuthClient]:
        """Update an existing client with a single UPDATE ... RETURNING.

        Args:
            client_id: The value to match against ``OAuthClient.id``.
            client_data: Column/value pairs to set.

        Returns:
            The refreshed OAuthClient, or None if no row matched or a
            ``SQLAlchemyError`` occurred (the session is rolled back).
        """
        try:
            stmt = (
                update(OAuthClient)
                .where(OAuthClient.id == client_id)
                .values(**client_data)
                .returning(OAuthClient)
            )
            result = await self.session.execute(stmt)
            await self.session.commit()
            # The returned row is read after the commit; refresh() then
            # reloads the instance's state from the database.
            client = result.scalar_one_or_none()
            if client:
                await self.session.refresh(client)
            return client
        except SQLAlchemyError as e:
            logger.error(f"Failed to update client {client_id}: {e}")
            await self.session.rollback()
            return None

    async def delete(self, client_id: int) -> bool:
        """Delete a client by its ``id`` column.

        Args:
            client_id: The value to match against ``OAuthClient.id``.

        Returns:
            True if at least one row was deleted; False if nothing matched
            or a ``SQLAlchemyError`` occurred (the session is rolled back).
        """
        try:
            stmt = delete(OAuthClient).where(OAuthClient.id == client_id)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            logger.error(f"Failed to delete client {client_id}: {e}")
            await self.session.rollback()
            return False


class OAuthTokenRepository:
    """Repository for performing CRUD operations on the OAuthToken model.

    Every write method commits on the injected session itself and rolls the
    session back when a ``SQLAlchemyError`` is raised. Database errors are
    logged and reported through the return value (``None``, ``[]`` or
    ``False``) instead of being propagated to the caller. Token values are
    never written to the log.
    """

    def __init__(self, session: AsyncSession):
        """Store the session that all queries of this repository run on.

        Args:
            session: The async SQLAlchemy session to use.
        """
        self.session = session

    async def create(self, token_data: dict) -> Optional[OAuthToken]:
        """Create and persist a new OAuth token.

        Args:
            token_data: Keyword arguments passed to the ``OAuthToken``
                constructor.

        Returns:
            The refreshed OAuthToken instance, or None if a
            ``SQLAlchemyError`` occurred (the session is rolled back).
        """
        try:
            token = OAuthToken(**token_data)
            self.session.add(token)
            await self.session.commit()
            # Reload the row so database-generated values are populated.
            await self.session.refresh(token)
            return token
        except SQLAlchemyError as e:
            logger.error(f"Failed to create OAuth token: {e}")
            await self.session.rollback()
            return None

    async def get_by_id(self, token_id: int) -> Optional[OAuthToken]:
        """Retrieve a token by its ``id`` column.

        Args:
            token_id: The value to match against ``OAuthToken.id``.

        Returns:
            The matching OAuthToken, or None if no row matches or a
            ``SQLAlchemyError`` occurred.
        """
        try:
            stmt = select(OAuthToken).where(OAuthToken.id == token_id)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error(f"Failed to fetch token by id {token_id}: {e}")
            return None

    async def get_by_access_token(self, access_token: str) -> Optional[OAuthToken]:
        """Retrieve a token by its access token string.

        Args:
            access_token: The value to match against
                ``OAuthToken.access_token``.

        Returns:
            The matching OAuthToken, or None if no row matches or a
            ``SQLAlchemyError`` occurred (which includes more than one
            matching row).
        """
        try:
            stmt = select(OAuthToken).where(OAuthToken.access_token == access_token)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error(f"Failed to fetch token by access token: {e}")
            return None

    async def get_by_refresh_token(self, refresh_token: str) -> Optional[OAuthToken]:
        """Retrieve a token by its refresh token string.

        Args:
            refresh_token: The value to match against
                ``OAuthToken.refresh_token``.

        Returns:
            The matching OAuthToken, or None if no row matches or a
            ``SQLAlchemyError`` occurred (which includes more than one
            matching row).
        """
        try:
            stmt = select(OAuthToken).where(OAuthToken.refresh_token == refresh_token)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error(f"Failed to fetch token by refresh token: {e}")
            return None

    async def get_expired_tokens(self) -> List[OAuthToken]:
        """Retrieve all tokens whose ``expires_at`` lies in the past.

        Returns:
            List of expired OAuthToken objects; empty if a
            ``SQLAlchemyError`` occurred.
        """
        # datetime.utcnow() is deprecated since Python 3.12. Build the same
        # naive UTC timestamp from an aware "now" so the comparison value is
        # unchanged.
        # NOTE(review): this presumes expires_at is stored as naive UTC, as
        # the previous utcnow() comparison did -- confirm against the model.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        try:
            stmt = select(OAuthToken).where(OAuthToken.expires_at < now_utc)
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.error(f"Failed to fetch expired tokens: {e}")
            return []

    async def revoke_token(self, token_id: int) -> bool:
        """Revoke a token by ID; this is an alias for :meth:`delete`.

        Args:
            token_id: The value to match against ``OAuthToken.id``.

        Returns:
            True if at least one row was deleted, False otherwise.
        """
        return await self.delete(token_id)

    async def revoke_by_access_token(self, access_token: str) -> bool:
        """Revoke (delete) a token by its access token string.

        Args:
            access_token: The value to match against
                ``OAuthToken.access_token``.

        Returns:
            True if at least one row was deleted; False if nothing matched
            or a ``SQLAlchemyError`` occurred (the session is rolled back).
        """
        try:
            stmt = delete(OAuthToken).where(OAuthToken.access_token == access_token)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            logger.error(f"Failed to revoke token by access token: {e}")
            await self.session.rollback()
            return False

    async def get_all(self, skip: int = 0, limit: int = 100) -> List[OAuthToken]:
        """Retrieve one page of tokens.

        Rows are ordered by ``id`` so that consecutive pages are stable;
        OFFSET/LIMIT without an ORDER BY may repeat or skip rows between
        pages. (Assumes ``id`` is unique -- TODO confirm against the model.)

        Args:
            skip: Number of records to skip.
            limit: Maximum number of records to return.

        Returns:
            List of OAuthToken objects; empty if a ``SQLAlchemyError``
            occurred.
        """
        try:
            stmt = (
                select(OAuthToken)
                .order_by(OAuthToken.id)
                .offset(skip)
                .limit(limit)
            )
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.error(f"Failed to fetch all tokens: {e}")
            return []

    async def update(self, token_id: int, token_data: dict) -> Optional[OAuthToken]:
        """Update an existing token with a single UPDATE ... RETURNING.

        Args:
            token_id: The value to match against ``OAuthToken.id``.
            token_data: Column/value pairs to set.

        Returns:
            The refreshed OAuthToken, or None if no row matched or a
            ``SQLAlchemyError`` occurred (the session is rolled back).
        """
        try:
            stmt = (
                update(OAuthToken)
                .where(OAuthToken.id == token_id)
                .values(**token_data)
                .returning(OAuthToken)
            )
            result = await self.session.execute(stmt)
            await self.session.commit()
            # The returned row is read after the commit; refresh() then
            # reloads the instance's state from the database.
            token = result.scalar_one_or_none()
            if token:
                await self.session.refresh(token)
            return token
        except SQLAlchemyError as e:
            logger.error(f"Failed to update token {token_id}: {e}")
            await self.session.rollback()
            return None

    async def delete(self, token_id: int) -> bool:
        """Delete a token by its ``id`` column.

        Args:
            token_id: The value to match against ``OAuthToken.id``.

        Returns:
            True if at least one row was deleted; False if nothing matched
            or a ``SQLAlchemyError`` occurred (the session is rolled back).
        """
        try:
            stmt = delete(OAuthToken).where(OAuthToken.id == token_id)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            logger.error(f"Failed to delete token {token_id}: {e}")
            await self.session.rollback()
            return False


class OAuthUserRepository:
    """Repository for performing CRUD operations on the OAuthUser model.

    Every write method commits on the injected session itself and rolls the
    session back when a ``SQLAlchemyError`` is raised. Database errors are
    logged and reported through the return value (``None``, ``[]`` or
    ``False``) instead of being propagated to the caller.
    """

    def __init__(self, session: AsyncSession):
        """Store the session that all queries of this repository run on.

        Args:
            session: The async SQLAlchemy session to use.
        """
        self.session = session

    async def create(self, user_data: dict) -> Optional[OAuthUser]:
        """Create and persist a new OAuth user.

        Args:
            user_data: Keyword arguments passed to the ``OAuthUser``
                constructor.

        Returns:
            The refreshed OAuthUser instance, or None if a
            ``SQLAlchemyError`` occurred (the session is rolled back).
        """
        try:
            user = OAuthUser(**user_data)
            self.session.add(user)
            await self.session.commit()
            # Reload the row so database-generated values are populated.
            await self.session.refresh(user)
            return user
        except SQLAlchemyError as e:
            logger.error(f"Failed to create OAuth user: {e}")
            await self.session.rollback()
            return None

    async def get_by_id(self, user_id: int) -> Optional[OAuthUser]:
        """Retrieve a user by its ``id`` column.

        Args:
            user_id: The value to match against ``OAuthUser.id``.

        Returns:
            The matching OAuthUser, or None if no row matches or a
            ``SQLAlchemyError`` occurred.
        """
        try:
            stmt = select(OAuthUser).where(OAuthUser.id == user_id)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error(f"Failed to fetch user by id {user_id}: {e}")
            return None

    async def get_by_username(self, username: str) -> Optional[OAuthUser]:
        """Retrieve a user by username.

        Args:
            username: The value to match against ``OAuthUser.username``.

        Returns:
            The matching OAuthUser, or None if no row matches or a
            ``SQLAlchemyError`` occurred (which includes more than one
            matching row).
        """
        try:
            stmt = select(OAuthUser).where(OAuthUser.username == username)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error(f"Failed to fetch user by username {username}: {e}")
            return None

    async def get_by_email(self, email: str) -> Optional[OAuthUser]:
        """Retrieve a user by email address.

        Args:
            email: The value to match against ``OAuthUser.email``.

        Returns:
            The matching OAuthUser, or None if no row matches or a
            ``SQLAlchemyError`` occurred (which includes more than one
            matching row).
        """
        try:
            stmt = select(OAuthUser).where(OAuthUser.email == email)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error(f"Failed to fetch user by email {email}: {e}")
            return None

    async def get_all(self, skip: int = 0, limit: int = 100) -> List[OAuthUser]:
        """Retrieve one page of users.

        Rows are ordered by ``id`` so that consecutive pages are stable;
        OFFSET/LIMIT without an ORDER BY may repeat or skip rows between
        pages. (Assumes ``id`` is unique -- TODO confirm against the model.)

        Args:
            skip: Number of records to skip.
            limit: Maximum number of records to return.

        Returns:
            List of OAuthUser objects; empty if a ``SQLAlchemyError``
            occurred.
        """
        try:
            stmt = (
                select(OAuthUser)
                .order_by(OAuthUser.id)
                .offset(skip)
                .limit(limit)
            )
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.error(f"Failed to fetch all users: {e}")
            return []

    async def update(self, user_id: int, user_data: dict) -> Optional[OAuthUser]:
        """Update an existing user with a single UPDATE ... RETURNING.

        Args:
            user_id: The value to match against ``OAuthUser.id``.
            user_data: Column/value pairs to set.

        Returns:
            The refreshed OAuthUser, or None if no row matched or a
            ``SQLAlchemyError`` occurred (the session is rolled back).
        """
        try:
            stmt = (
                update(OAuthUser)
                .where(OAuthUser.id == user_id)
                .values(**user_data)
                .returning(OAuthUser)
            )
            result = await self.session.execute(stmt)
            await self.session.commit()
            # The returned row is read after the commit; refresh() then
            # reloads the instance's state from the database.
            user = result.scalar_one_or_none()
            if user:
                await self.session.refresh(user)
            return user
        except SQLAlchemyError as e:
            logger.error(f"Failed to update user {user_id}: {e}")
            await self.session.rollback()
            return None

    async def delete(self, user_id: int) -> bool:
        """Delete a user by its ``id`` column.

        Args:
            user_id: The value to match against ``OAuthUser.id``.

        Returns:
            True if at least one row was deleted; False if nothing matched
            or a ``SQLAlchemyError`` occurred (the session is rolled back).
        """
        try:
            stmt = delete(OAuthUser).where(OAuthUser.id == user_id)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            logger.error(f"Failed to delete user {user_id}: {e}")
            await self.session.rollback()
            return False