mockapi/oauth2/repositories.py
2026-03-16 05:47:01 +00:00

492 lines
No EOL
15 KiB
Python

import logging
from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy import select, update, delete, and_
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

# Import database first to resolve circular import
import database
from models.oauth_models import OAuthClient, OAuthToken, OAuthUser
# Module-scoped logger shared by every repository class in this file.
logger = logging.getLogger(name=__name__)
class OAuthClientRepository:
    """Repository for performing CRUD operations on OAuthClient model.

    Write methods commit on success and roll the session back on
    ``SQLAlchemyError``. Database failures are logged and reported through
    the return value (``None``, ``[]`` or ``False``) instead of being raised.
    """

    def __init__(self, session: AsyncSession):
        """
        Args:
            session: Async SQLAlchemy session used by every method; write
                methods commit it directly.
        """
        self.session = session

    async def create(self, client_data: dict) -> Optional[OAuthClient]:
        """
        Create a new OAuth client.

        Args:
            client_data: Column values passed as keyword arguments to
                ``OAuthClient``.

        Returns:
            The persisted OAuthClient, re-read from the database after
            commit, or None if the database operation failed.
        """
        try:
            client = OAuthClient(**client_data)
            self.session.add(client)
            await self.session.commit()
            # Re-read the row so attributes reflect what was actually stored.
            await self.session.refresh(client)
            return client
        except SQLAlchemyError as e:
            logger.error("Failed to create OAuth client: %s", e)
            await self.session.rollback()
            return None

    async def get_by_id(self, client_id: int) -> Optional[OAuthClient]:
        """
        Retrieve a client by its ID.

        Args:
            client_id: The client's primary-key ID.

        Returns:
            OAuthClient if found, None if not found or the query failed.
        """
        try:
            stmt = select(OAuthClient).where(OAuthClient.id == client_id)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error("Failed to fetch client by id %s: %s", client_id, e)
            return None

    async def get_by_client_id(self, client_id_str: str) -> Optional[OAuthClient]:
        """
        Retrieve a client by its client_id (string identifier).

        Args:
            client_id_str: The client identifier string.

        Returns:
            OAuthClient if exactly one row matches; None if none match, more
            than one matches (``MultipleResultsFound`` is a
            ``SQLAlchemyError``), or the query failed.
        """
        try:
            stmt = select(OAuthClient).where(OAuthClient.client_id == client_id_str)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error("Failed to fetch client by client_id %s: %s", client_id_str, e)
            return None

    async def get_all(self, skip: int = 0, limit: int = 100) -> List[OAuthClient]:
        """
        Retrieve clients with offset/limit pagination.

        Args:
            skip: Number of records to skip.
            limit: Maximum number of records to return.

        Returns:
            List of OAuthClient objects, or an empty list if the query failed.
        """
        try:
            stmt = select(OAuthClient).offset(skip).limit(limit)
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.error("Failed to fetch all clients: %s", e)
            return []

    async def update(self, client_id: int, client_data: dict) -> Optional[OAuthClient]:
        """
        Update an existing client.

        Args:
            client_id: The client's primary-key ID.
            client_data: Column values to set. When empty, no UPDATE is
                issued and the current row is returned unchanged.

        Returns:
            The updated OAuthClient; None if no client has that ID or the
            database operation failed.
        """
        if not client_data:
            # SQLAlchemy cannot execute an UPDATE with no SET values; treat an
            # empty patch as a no-op instead of logging a spurious failure.
            return await self.get_by_id(client_id)
        try:
            stmt = (
                update(OAuthClient)
                .where(OAuthClient.id == client_id)
                .values(**client_data)
                .returning(OAuthClient)
            )
            result = await self.session.execute(stmt)
            # Consume the RETURNING rows before committing so they are read
            # while the transaction that produced them is still open.
            client = result.scalar_one_or_none()
            await self.session.commit()
            if client is not None:
                # Re-read the row so attributes reflect the committed values.
                await self.session.refresh(client)
            return client
        except SQLAlchemyError as e:
            logger.error("Failed to update client %s: %s", client_id, e)
            await self.session.rollback()
            return None

    async def delete(self, client_id: int) -> bool:
        """
        Delete a client by ID.

        Args:
            client_id: The client's primary-key ID.

        Returns:
            True if at least one row was deleted; False if nothing matched or
            the database operation failed.
        """
        try:
            stmt = delete(OAuthClient).where(OAuthClient.id == client_id)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            logger.error("Failed to delete client %s: %s", client_id, e)
            await self.session.rollback()
            return False
class OAuthTokenRepository:
    """Repository for performing CRUD operations on OAuthToken model.

    Write methods commit on success and roll the session back on
    ``SQLAlchemyError``. Database failures are logged and reported through
    the return value (``None``, ``[]`` or ``False``) instead of being raised.
    Token strings themselves are never written to the log.
    """

    def __init__(self, session: AsyncSession):
        """
        Args:
            session: Async SQLAlchemy session used by every method; write
                methods commit it directly.
        """
        self.session = session

    async def create(self, token_data: dict) -> Optional[OAuthToken]:
        """
        Create a new OAuth token.

        Args:
            token_data: Column values passed as keyword arguments to
                ``OAuthToken``.

        Returns:
            The persisted OAuthToken, re-read from the database after commit,
            or None if the database operation failed.
        """
        try:
            token = OAuthToken(**token_data)
            self.session.add(token)
            await self.session.commit()
            # Re-read the row so attributes reflect what was actually stored.
            await self.session.refresh(token)
            return token
        except SQLAlchemyError as e:
            logger.error("Failed to create OAuth token: %s", e)
            await self.session.rollback()
            return None

    async def get_by_id(self, token_id: int) -> Optional[OAuthToken]:
        """
        Retrieve a token by its ID.

        Args:
            token_id: The token's primary-key ID.

        Returns:
            OAuthToken if found, None if not found or the query failed.
        """
        try:
            stmt = select(OAuthToken).where(OAuthToken.id == token_id)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error("Failed to fetch token by id %s: %s", token_id, e)
            return None

    async def get_by_access_token(self, access_token: str) -> Optional[OAuthToken]:
        """
        Retrieve a token by its access token.

        Args:
            access_token: The access token string.

        Returns:
            OAuthToken if exactly one row matches; None if none match, more
            than one matches, or the query failed.
        """
        try:
            stmt = select(OAuthToken).where(OAuthToken.access_token == access_token)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            # Deliberately omit the token value from the log.
            logger.error("Failed to fetch token by access token: %s", e)
            return None

    async def get_by_refresh_token(self, refresh_token: str) -> Optional[OAuthToken]:
        """
        Retrieve a token by its refresh token.

        Args:
            refresh_token: The refresh token string.

        Returns:
            OAuthToken if exactly one row matches; None if none match, more
            than one matches, or the query failed.
        """
        try:
            stmt = select(OAuthToken).where(OAuthToken.refresh_token == refresh_token)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            # Deliberately omit the token value from the log.
            logger.error("Failed to fetch token by refresh token: %s", e)
            return None

    async def get_expired_tokens(self) -> List[OAuthToken]:
        """
        Retrieve all tokens whose ``expires_at`` is earlier than now.

        Returns:
            List of expired OAuthToken objects, or an empty list if the query
            failed. Rows with a NULL ``expires_at`` never match the ``<``
            comparison and are therefore not returned.
        """
        # Naive UTC "now": the same value the deprecated datetime.utcnow()
        # produced, so the comparison semantics are unchanged.
        # NOTE(review): assumes expires_at is stored as naive UTC — confirm.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        try:
            stmt = select(OAuthToken).where(OAuthToken.expires_at < now)
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.error("Failed to fetch expired tokens: %s", e)
            return []

    async def revoke_token(self, token_id: int) -> bool:
        """
        Revoke a token by ID; revocation is a hard delete of the row.

        Args:
            token_id: The token's primary-key ID.

        Returns:
            True if a row was deleted, False otherwise.
        """
        return await self.delete(token_id)

    async def revoke_by_access_token(self, access_token: str) -> bool:
        """
        Revoke (delete) every token row with the given access token.

        Args:
            access_token: The access token string.

        Returns:
            True if at least one row was deleted; False if nothing matched or
            the database operation failed.
        """
        try:
            stmt = delete(OAuthToken).where(OAuthToken.access_token == access_token)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            # Deliberately omit the token value from the log.
            logger.error("Failed to revoke token by access token: %s", e)
            await self.session.rollback()
            return False

    async def get_all(self, skip: int = 0, limit: int = 100) -> List[OAuthToken]:
        """
        Retrieve tokens with offset/limit pagination.

        Args:
            skip: Number of records to skip.
            limit: Maximum number of records to return.

        Returns:
            List of OAuthToken objects, or an empty list if the query failed.
        """
        try:
            stmt = select(OAuthToken).offset(skip).limit(limit)
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.error("Failed to fetch all tokens: %s", e)
            return []

    async def update(self, token_id: int, token_data: dict) -> Optional[OAuthToken]:
        """
        Update an existing token.

        Args:
            token_id: The token's primary-key ID.
            token_data: Column values to set. When empty, no UPDATE is issued
                and the current row is returned unchanged.

        Returns:
            The updated OAuthToken; None if no token has that ID or the
            database operation failed.
        """
        if not token_data:
            # SQLAlchemy cannot execute an UPDATE with no SET values; treat an
            # empty patch as a no-op instead of logging a spurious failure.
            return await self.get_by_id(token_id)
        try:
            stmt = (
                update(OAuthToken)
                .where(OAuthToken.id == token_id)
                .values(**token_data)
                .returning(OAuthToken)
            )
            result = await self.session.execute(stmt)
            # Consume the RETURNING rows before committing so they are read
            # while the transaction that produced them is still open.
            token = result.scalar_one_or_none()
            await self.session.commit()
            if token is not None:
                # Re-read the row so attributes reflect the committed values.
                await self.session.refresh(token)
            return token
        except SQLAlchemyError as e:
            logger.error("Failed to update token %s: %s", token_id, e)
            await self.session.rollback()
            return None

    async def delete(self, token_id: int) -> bool:
        """
        Delete a token by ID.

        Args:
            token_id: The token's primary-key ID.

        Returns:
            True if at least one row was deleted; False if nothing matched or
            the database operation failed.
        """
        try:
            stmt = delete(OAuthToken).where(OAuthToken.id == token_id)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            logger.error("Failed to delete token %s: %s", token_id, e)
            await self.session.rollback()
            return False
class OAuthUserRepository:
    """Repository for performing CRUD operations on OAuthUser model.

    Write methods commit on success and roll the session back on
    ``SQLAlchemyError``. Database failures are logged and reported through
    the return value (``None``, ``[]`` or ``False``) instead of being raised.
    """

    def __init__(self, session: AsyncSession):
        """
        Args:
            session: Async SQLAlchemy session used by every method; write
                methods commit it directly.
        """
        self.session = session

    async def create(self, user_data: dict) -> Optional[OAuthUser]:
        """
        Create a new OAuth user.

        Args:
            user_data: Column values passed as keyword arguments to
                ``OAuthUser``.

        Returns:
            The persisted OAuthUser, re-read from the database after commit,
            or None if the database operation failed.
        """
        try:
            user = OAuthUser(**user_data)
            self.session.add(user)
            await self.session.commit()
            # Re-read the row so attributes reflect what was actually stored.
            await self.session.refresh(user)
            return user
        except SQLAlchemyError as e:
            logger.error("Failed to create OAuth user: %s", e)
            await self.session.rollback()
            return None

    async def get_by_id(self, user_id: int) -> Optional[OAuthUser]:
        """
        Retrieve a user by its ID.

        Args:
            user_id: The user's primary-key ID.

        Returns:
            OAuthUser if found, None if not found or the query failed.
        """
        try:
            stmt = select(OAuthUser).where(OAuthUser.id == user_id)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error("Failed to fetch user by id %s: %s", user_id, e)
            return None

    async def get_by_username(self, username: str) -> Optional[OAuthUser]:
        """
        Retrieve a user by username.

        Args:
            username: The username string.

        Returns:
            OAuthUser if exactly one row matches; None if none match, more
            than one matches, or the query failed.
        """
        try:
            stmt = select(OAuthUser).where(OAuthUser.username == username)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error("Failed to fetch user by username %s: %s", username, e)
            return None

    async def get_by_email(self, email: str) -> Optional[OAuthUser]:
        """
        Retrieve a user by email.

        Args:
            email: The email address.

        Returns:
            OAuthUser if exactly one row matches; None if none match, more
            than one matches, or the query failed.
        """
        try:
            stmt = select(OAuthUser).where(OAuthUser.email == email)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error("Failed to fetch user by email %s: %s", email, e)
            return None

    async def get_all(self, skip: int = 0, limit: int = 100) -> List[OAuthUser]:
        """
        Retrieve users with offset/limit pagination.

        Args:
            skip: Number of records to skip.
            limit: Maximum number of records to return.

        Returns:
            List of OAuthUser objects, or an empty list if the query failed.
        """
        try:
            stmt = select(OAuthUser).offset(skip).limit(limit)
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.error("Failed to fetch all users: %s", e)
            return []

    async def update(self, user_id: int, user_data: dict) -> Optional[OAuthUser]:
        """
        Update an existing user.

        Args:
            user_id: The user's primary-key ID.
            user_data: Column values to set. When empty, no UPDATE is issued
                and the current row is returned unchanged.

        Returns:
            The updated OAuthUser; None if no user has that ID or the
            database operation failed.
        """
        if not user_data:
            # SQLAlchemy cannot execute an UPDATE with no SET values; treat an
            # empty patch as a no-op instead of logging a spurious failure.
            return await self.get_by_id(user_id)
        try:
            stmt = (
                update(OAuthUser)
                .where(OAuthUser.id == user_id)
                .values(**user_data)
                .returning(OAuthUser)
            )
            result = await self.session.execute(stmt)
            # Consume the RETURNING rows before committing so they are read
            # while the transaction that produced them is still open.
            user = result.scalar_one_or_none()
            await self.session.commit()
            if user is not None:
                # Re-read the row so attributes reflect the committed values.
                await self.session.refresh(user)
            return user
        except SQLAlchemyError as e:
            logger.error("Failed to update user %s: %s", user_id, e)
            await self.session.rollback()
            return None

    async def delete(self, user_id: int) -> bool:
        """
        Delete a user by ID.

        Args:
            user_id: The user's primary-key ID.

        Returns:
            True if at least one row was deleted; False if nothing matched or
            the database operation failed.
        """
        try:
            stmt = delete(OAuthUser).where(OAuthUser.id == user_id)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            logger.error("Failed to delete user %s: %s", user_id, e)
            await self.session.rollback()
            return False