mockapi/repositories/endpoint_repository.py
2026-03-16 05:47:01 +00:00

161 lines
No EOL
5.1 KiB
Python

from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.exc import SQLAlchemyError
import logging
from models.endpoint_model import Endpoint
# Module-level logger, named after this module, shared by the repository below.
logger = logging.getLogger(name=__name__)
class EndpointRepository:
    """Repository for performing CRUD operations on the Endpoint model.

    Every method catches ``SQLAlchemyError``, logs it with its traceback,
    rolls the session back and returns a neutral value (``None``, ``[]`` or
    ``False``) instead of raising. Write methods commit on success.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    async def _handle_error(self, message: str, *args: object) -> None:
        """Log the exception currently being handled and roll back the session.

        Must be called from inside an ``except`` block so that
        ``logger.exception`` can attach the traceback. The rollback is done
        for reads as well as writes: the error is swallowed here, so the
        caller cannot know that the session's transaction may be unusable
        after a failed statement (backend-dependent) until it is rolled back.
        """
        logger.exception(message, *args)
        await self.session.rollback()

    async def create(self, endpoint_data: dict) -> Optional[Endpoint]:
        """
        Create a new endpoint and commit it.

        Args:
            endpoint_data: Dictionary with endpoint fields, passed as keyword
                arguments to ``Endpoint``.

        Returns:
            The persisted Endpoint (refreshed after commit) if successful,
            None if a database error occurred.
        """
        try:
            endpoint = Endpoint(**endpoint_data)
            self.session.add(endpoint)
            await self.session.commit()
            # Reload the row so the returned object reflects database state.
            await self.session.refresh(endpoint)
            return endpoint
        except SQLAlchemyError as e:
            await self._handle_error("Failed to create endpoint: %s", e)
            return None

    async def get_by_id(self, endpoint_id: int) -> Optional[Endpoint]:
        """
        Retrieve an endpoint by its ID.

        Args:
            endpoint_id: The endpoint ID.

        Returns:
            Endpoint if found, None if not found or on a database error.
        """
        try:
            stmt = select(Endpoint).where(Endpoint.id == endpoint_id)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            await self._handle_error(
                "Failed to fetch endpoint by id %s: %s", endpoint_id, e
            )
            return None

    async def get_all(self, skip: int = 0, limit: int = 100) -> List[Endpoint]:
        """
        Retrieve endpoints with pagination, ordered by ID.

        Args:
            skip: Number of records to skip.
            limit: Maximum number of records to return.

        Returns:
            List of Endpoint objects (empty on a database error).
        """
        try:
            # OFFSET/LIMIT without ORDER BY gives nondeterministic pages;
            # order by the ID so consecutive pages neither repeat nor skip rows.
            stmt = (
                select(Endpoint)
                .order_by(Endpoint.id)
                .offset(skip)
                .limit(limit)
            )
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            await self._handle_error("Failed to fetch all endpoints: %s", e)
            return []

    async def get_active(self) -> List[Endpoint]:
        """
        Retrieve all endpoints whose ``is_active`` flag is true.

        Returns:
            List of active Endpoint objects (empty on a database error).
        """
        try:
            # ``==`` is overloaded by SQLAlchemy to build a SQL comparison;
            # ``is True`` would evaluate in Python and not produce a filter.
            stmt = select(Endpoint).where(Endpoint.is_active == True)  # noqa: E712
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            await self._handle_error("Failed to fetch active endpoints: %s", e)
            return []

    async def update(self, endpoint_id: int, endpoint_data: dict) -> Optional[Endpoint]:
        """
        Update an existing endpoint and commit.

        Args:
            endpoint_id: The endpoint ID.
            endpoint_data: Dictionary of column values to set. When empty, no
                UPDATE is issued and the current row is returned unchanged.

        Returns:
            Updated Endpoint if a row matched, None if no row matched or a
            database error occurred.
        """
        if not endpoint_data:
            # Nothing to SET: avoid issuing an UPDATE without values and
            # simply return the current state of the row.
            return await self.get_by_id(endpoint_id)
        try:
            stmt = (
                update(Endpoint)
                .where(Endpoint.id == endpoint_id)
                .values(**endpoint_data)
                .returning(Endpoint)
            )
            result = await self.session.execute(stmt)
            # Consume the RETURNING row before committing so we do not depend
            # on the result still being readable after the transaction ends.
            endpoint = result.scalar_one_or_none()
            await self.session.commit()
            if endpoint is not None:
                # Reload after commit so the caller gets current attribute values.
                await self.session.refresh(endpoint)
            return endpoint
        except SQLAlchemyError as e:
            await self._handle_error(
                "Failed to update endpoint %s: %s", endpoint_id, e
            )
            return None

    async def delete(self, endpoint_id: int) -> bool:
        """
        Delete an endpoint by ID and commit.

        Args:
            endpoint_id: The endpoint ID.

        Returns:
            True if at least one row was deleted, False if no row matched or
            a database error occurred.
        """
        try:
            stmt = delete(Endpoint).where(Endpoint.id == endpoint_id)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            await self._handle_error(
                "Failed to delete endpoint %s: %s", endpoint_id, e
            )
            return False

    async def get_by_route_and_method(self, route: str, method: str) -> Optional[Endpoint]:
        """
        Retrieve an endpoint by route and HTTP method.

        Args:
            route: The endpoint route (path), matched exactly.
            method: HTTP method (GET, POST, etc.); upper-cased before matching.

        Returns:
            Endpoint if exactly one matches, None if none matches or on a
            database error (including more than one matching row).
        """
        try:
            stmt = select(Endpoint).where(
                Endpoint.route == route,
                Endpoint.method == method.upper(),
            )
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            await self._handle_error(
                "Failed to fetch endpoint %s %s: %s", method, route, e
            )
            return None