import logging
from typing import List, Optional

from sqlalchemy import delete, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

from app.modules.endpoints.models.endpoint_model import Endpoint

logger = logging.getLogger(__name__)


class EndpointRepository:
    """Repository for performing CRUD operations on the Endpoint model.

    Every method catches ``SQLAlchemyError``, logs it with its traceback and
    returns a neutral value (``None``, ``[]`` or ``False``) instead of
    raising. Write operations additionally roll the session back on failure.
    """

    def __init__(self, session: AsyncSession):
        """Store the async session used by all repository operations.

        Args:
            session: The SQLAlchemy ``AsyncSession`` to execute statements on.
        """
        self.session = session

    async def create(self, endpoint_data: dict) -> Optional[Endpoint]:
        """Create a new endpoint and commit it.

        Args:
            endpoint_data: Dictionary of keyword arguments passed to the
                ``Endpoint`` constructor.

        Returns:
            The refreshed Endpoint instance if successful, None otherwise.
        """
        try:
            endpoint = Endpoint(**endpoint_data)
            self.session.add(endpoint)
            await self.session.commit()
            # Reload the row so database-generated values are populated.
            await self.session.refresh(endpoint)
            return endpoint
        except SQLAlchemyError as e:
            logger.exception("Failed to create endpoint: %s", e)
            await self.session.rollback()
            return None

    async def get_by_id(self, endpoint_id: int) -> Optional[Endpoint]:
        """Retrieve an endpoint by its ID.

        Args:
            endpoint_id: The endpoint ID.

        Returns:
            Endpoint if found, None if not found or on a database error.
        """
        try:
            stmt = select(Endpoint).where(Endpoint.id == endpoint_id)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.exception(
                "Failed to fetch endpoint by id %s: %s", endpoint_id, e
            )
            return None

    async def get_all(self, skip: int = 0, limit: int = 100) -> List[Endpoint]:
        """Retrieve endpoints with pagination, ordered by ID.

        Args:
            skip: Number of records to skip.
            limit: Maximum number of records to return.

        Returns:
            List of Endpoint objects; empty on a database error.
        """
        try:
            # OFFSET/LIMIT without ORDER BY gives no guaranteed row order, so
            # consecutive pages could overlap or skip rows. Ordering by the
            # ID column makes pagination deterministic.
            stmt = (
                select(Endpoint)
                .order_by(Endpoint.id)
                .offset(skip)
                .limit(limit)
            )
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.exception("Failed to fetch all endpoints: %s", e)
            return []

    async def get_active(self) -> List[Endpoint]:
        """Retrieve all endpoints whose ``is_active`` column is true.

        Returns:
            List of active Endpoint objects; empty on a database error.
        """
        try:
            # "== True" is intentional: it builds a SQL comparison expression
            # rather than evaluating a Python boolean.
            stmt = select(Endpoint).where(
                Endpoint.is_active == True  # noqa: E712
            )
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.exception("Failed to fetch active endpoints: %s", e)
            return []

    async def update(
        self, endpoint_id: int, endpoint_data: dict
    ) -> Optional[Endpoint]:
        """Update an existing endpoint.

        Args:
            endpoint_id: The endpoint ID.
            endpoint_data: Dictionary of fields to update. When empty, no
                UPDATE is issued and the current row is returned unchanged.

        Returns:
            Updated Endpoint if successful, None if no row matched or on a
            database error.
        """
        if not endpoint_data:
            # An UPDATE with no values to set is not a meaningful statement;
            # treat it as a no-op instead of sending it to the database.
            return await self.get_by_id(endpoint_id)

        try:
            stmt = (
                update(Endpoint)
                .where(Endpoint.id == endpoint_id)
                .values(**endpoint_data)
                .returning(Endpoint)
            )
            result = await self.session.execute(stmt)
            # Read the RETURNING row before committing so the fetch does not
            # depend on the result staying usable after the transaction ends.
            endpoint = result.scalar_one_or_none()
            await self.session.commit()
            if endpoint is not None:
                # Reload the instance after the commit so the caller gets
                # the persisted state.
                await self.session.refresh(endpoint)
            return endpoint
        except SQLAlchemyError as e:
            logger.exception(
                "Failed to update endpoint %s: %s", endpoint_id, e
            )
            await self.session.rollback()
            return None

    async def delete(self, endpoint_id: int) -> bool:
        """Delete an endpoint by ID.

        Args:
            endpoint_id: The endpoint ID.

        Returns:
            True if at least one row was deleted, False if nothing matched
            or on a database error.
        """
        try:
            stmt = delete(Endpoint).where(Endpoint.id == endpoint_id)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            logger.exception(
                "Failed to delete endpoint %s: %s", endpoint_id, e
            )
            await self.session.rollback()
            return False

    async def get_by_route_and_method(
        self, route: str, method: str
    ) -> Optional[Endpoint]:
        """Retrieve an endpoint by route and HTTP method.

        Args:
            route: The endpoint route (path).
            method: HTTP method (GET, POST, etc.); upper-cased before
                comparison.

        Returns:
            Endpoint if found, None if not found or on a database error.
        """
        try:
            stmt = select(Endpoint).where(
                Endpoint.route == route,
                Endpoint.method == method.upper(),
            )
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.exception(
                "Failed to fetch endpoint %s %s: %s", method, route, e
            )
            return None