161 lines
No EOL
5.1 KiB
Python
161 lines
No EOL
5.1 KiB
Python
from typing import List, Optional
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, update, delete
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
import logging
|
|
|
|
from models.endpoint_model import Endpoint
|
|
|
|
|
|
# Module-scoped logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class EndpointRepository:
    """Repository for performing CRUD operations on Endpoint model.

    Write methods (create, update, delete) commit on success. Every method
    catches SQLAlchemyError, logs it with its traceback, rolls the session
    back, and returns an "empty" value (None, [] or False) instead of
    raising. Other exception types propagate to the caller.
    """

    def __init__(self, session: AsyncSession):
        """
        Bind the repository to an async SQLAlchemy session.

        Args:
            session: Session used for every query, commit and rollback.
        """
        self.session = session

    async def _rollback(self) -> None:
        """Roll back the session after a failed operation, never raising.

        A failed statement can leave the transaction unusable (PostgreSQL,
        for example, rejects further commands until it is rolled back), so
        every error path resets it. If the rollback itself fails, that is
        logged rather than propagated so the public methods keep their
        "return a sentinel on failure" contract.
        """
        try:
            await self.session.rollback()
        except SQLAlchemyError:
            logger.exception("Rollback after failed endpoint operation also failed")

    async def create(self, endpoint_data: dict) -> Optional[Endpoint]:
        """
        Create a new endpoint.

        Args:
            endpoint_data: Dictionary with endpoint fields, passed as keyword
                arguments to the Endpoint constructor.

        Returns:
            The persisted Endpoint if successful, None if a SQLAlchemy error
            occurred (the session is rolled back in that case).

        Exceptions other than SQLAlchemyError (for example from the Endpoint
        constructor) propagate to the caller.
        """
        try:
            endpoint = Endpoint(**endpoint_data)
            self.session.add(endpoint)
            await self.session.commit()
            # Re-read the row so DB-generated values (e.g. the primary key)
            # are loaded on the instance after the commit.
            await self.session.refresh(endpoint)
            return endpoint
        except SQLAlchemyError as e:
            logger.exception("Failed to create endpoint: %s", e)
            await self._rollback()
            return None

    async def get_by_id(self, endpoint_id: int) -> Optional[Endpoint]:
        """
        Retrieve an endpoint by its ID.

        Args:
            endpoint_id: The endpoint ID.

        Returns:
            Endpoint if found, None if not found or a SQLAlchemy error
            occurred (the session is rolled back in that case).
        """
        try:
            stmt = select(Endpoint).where(Endpoint.id == endpoint_id)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.exception("Failed to fetch endpoint by id %s: %s", endpoint_id, e)
            await self._rollback()
            return None

    async def get_all(self, skip: int = 0, limit: int = 100) -> List[Endpoint]:
        """
        Retrieve all endpoints with pagination.

        Args:
            skip: Number of records to skip (SQL OFFSET).
            limit: Maximum number of records to return (SQL LIMIT).

        Returns:
            List of Endpoint objects; empty if none match or a SQLAlchemy
            error occurred (the session is rolled back in that case).
        """
        try:
            stmt = select(Endpoint).offset(skip).limit(limit)
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.exception("Failed to fetch all endpoints: %s", e)
            await self._rollback()
            return []

    async def get_active(self) -> List[Endpoint]:
        """
        Retrieve all active endpoints.

        Returns:
            List of Endpoint objects whose is_active column is true; empty if
            none match or a SQLAlchemy error occurred (the session is rolled
            back in that case).
        """
        try:
            # `==` is deliberate: SQLAlchemy overloads it on columns to build a
            # SQL comparison, whereas `is True` would be a plain Python check.
            stmt = select(Endpoint).where(Endpoint.is_active == True)  # noqa: E712
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.exception("Failed to fetch active endpoints: %s", e)
            await self._rollback()
            return []

    async def update(self, endpoint_id: int, endpoint_data: dict) -> Optional[Endpoint]:
        """
        Update an existing endpoint.

        Args:
            endpoint_id: The endpoint ID.
            endpoint_data: Dictionary of column values to update. An empty
                dict is a no-op: the current row is returned unchanged.

        Returns:
            The updated Endpoint if a row with endpoint_id exists, None if no
            such row exists or a SQLAlchemy error occurred (the session is
            rolled back in that case).
        """
        if not endpoint_data:
            # An UPDATE with no SET clause is not valid SQL; there is nothing
            # to change, so just report the current state of the row.
            return await self.get_by_id(endpoint_id)
        try:
            # `update` here is the module-level sqlalchemy construct, not this
            # method: class scope is not searched from inside method bodies.
            stmt = (
                update(Endpoint)
                .where(Endpoint.id == endpoint_id)
                .values(**endpoint_data)
                .returning(Endpoint)
            )
            result = await self.session.execute(stmt)
            # Consume the RETURNING row before committing, while the
            # transaction that produced it is still open.
            endpoint = result.scalar_one_or_none()
            await self.session.commit()
            if endpoint is not None:
                await self.session.refresh(endpoint)
            return endpoint
        except SQLAlchemyError as e:
            logger.exception("Failed to update endpoint %s: %s", endpoint_id, e)
            await self._rollback()
            return None

    async def delete(self, endpoint_id: int) -> bool:
        """
        Delete an endpoint by ID.

        Args:
            endpoint_id: The endpoint ID.

        Returns:
            True if at least one row was deleted, False if no row matched or
            a SQLAlchemy error occurred (the session is rolled back in that
            case).
        """
        try:
            # `delete` resolves to the module-level sqlalchemy construct.
            stmt = delete(Endpoint).where(Endpoint.id == endpoint_id)
            result = await self.session.execute(stmt)
            deleted = result.rowcount > 0
            await self.session.commit()
            return deleted
        except SQLAlchemyError as e:
            logger.exception("Failed to delete endpoint %s: %s", endpoint_id, e)
            await self._rollback()
            return False

    async def get_by_route_and_method(self, route: str, method: str) -> Optional[Endpoint]:
        """
        Retrieve an endpoint by route and HTTP method.

        Args:
            route: The endpoint route (path), matched exactly.
            method: HTTP method (GET, POST, etc.); upper-cased before matching.

        Returns:
            Endpoint if exactly one matches, None if none match, more than
            one matches (scalar_one_or_none raises MultipleResultsFound, a
            SQLAlchemyError), or another SQLAlchemy error occurred (the
            session is rolled back in those error cases).
        """
        try:
            stmt = select(Endpoint).where(
                Endpoint.route == route,
                Endpoint.method == method.upper(),
            )
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.exception("Failed to fetch endpoint %s %s: %s", method, route, e)
            await self._rollback()
            return None