492 lines
No EOL
15 KiB
Python
492 lines
No EOL
15 KiB
Python
import logging
from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy import and_, delete, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

# Import database first to resolve circular import
from app.core import database
from app.modules.oauth2.models import OAuthClient, OAuthToken, OAuthUser
|
|
|
|
|
|
# Module-level logger named after this module; the repository classes in this
# file use it to report database errors they catch.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OAuthClientRepository:
    """Repository for performing CRUD operations on OAuthClient model.

    Every method catches ``SQLAlchemyError``, logs it, and returns a neutral
    value (``None``, ``[]`` or ``False``) instead of raising. Methods that
    write commit on success and roll the session back on failure.
    """

    def __init__(self, session: AsyncSession):
        """
        Bind the repository to a session.

        Args:
            session: Async session used for every statement this repository runs.
        """
        self.session = session

    async def create(self, client_data: dict) -> Optional[OAuthClient]:
        """
        Create a new OAuth client.

        Args:
            client_data: Dictionary with client fields; passed as keyword
                arguments to the OAuthClient constructor.

        Returns:
            OAuthClient instance if successful, None otherwise.
        """
        try:
            client = OAuthClient(**client_data)
            self.session.add(client)
            await self.session.commit()
            # Reload the instance after the commit so the returned object
            # carries the persisted column values.
            await self.session.refresh(client)
            return client
        except SQLAlchemyError as e:
            logger.error("Failed to create OAuth client: %s", e)
            await self.session.rollback()
            return None

    async def get_by_id(self, client_id: int) -> Optional[OAuthClient]:
        """
        Retrieve a client by its ID.

        Args:
            client_id: The client ID.

        Returns:
            OAuthClient if found, None otherwise (including on database error).
        """
        try:
            stmt = select(OAuthClient).where(OAuthClient.id == client_id)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error("Failed to fetch client by id %s: %s", client_id, e)
            return None

    async def get_by_client_id(self, client_id_str: str) -> Optional[OAuthClient]:
        """
        Retrieve a client by its client_id (unique string identifier).

        Args:
            client_id_str: The client identifier string.

        Returns:
            OAuthClient if found, None otherwise (including on database error).
        """
        try:
            stmt = select(OAuthClient).where(OAuthClient.client_id == client_id_str)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error(
                "Failed to fetch client by client_id %s: %s", client_id_str, e
            )
            return None

    async def get_all(self, skip: int = 0, limit: int = 100) -> List[OAuthClient]:
        """
        Retrieve all clients with pagination.

        Args:
            skip: Number of records to skip.
            limit: Maximum number of records to return.

        Returns:
            List of OAuthClient objects ordered by ID; empty list on database
            error.
        """
        try:
            # OFFSET/LIMIT without ORDER BY leaves the row order up to the
            # database, so consecutive pages could overlap or skip rows.
            stmt = (
                select(OAuthClient)
                .order_by(OAuthClient.id)
                .offset(skip)
                .limit(limit)
            )
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.error("Failed to fetch all clients: %s", e)
            return []

    async def update(self, client_id: int, client_data: dict) -> Optional[OAuthClient]:
        """
        Update an existing client.

        Args:
            client_id: The client ID.
            client_data: Dictionary of fields to update. When empty, no UPDATE
                is issued and the client is returned as currently stored.

        Returns:
            Updated OAuthClient if successful, None if no row matched or a
            database error occurred.
        """
        if not client_data:
            # Nothing to SET: skip the UPDATE and report the row as it stands.
            return await self.get_by_id(client_id)

        try:
            stmt = (
                update(OAuthClient)
                .where(OAuthClient.id == client_id)
                .values(**client_data)
                .returning(OAuthClient)
            )
            result = await self.session.execute(stmt)
            # Read the RETURNING row before committing so the result does not
            # have to outlive the transaction that produced it.
            client = result.scalar_one_or_none()
            await self.session.commit()
            if client is not None:
                await self.session.refresh(client)
            return client
        except SQLAlchemyError as e:
            logger.error("Failed to update client %s: %s", client_id, e)
            await self.session.rollback()
            return None

    async def delete(self, client_id: int) -> bool:
        """
        Delete a client by ID.

        Args:
            client_id: The client ID.

        Returns:
            True if at least one row was deleted, False otherwise (including
            on database error).
        """
        try:
            stmt = delete(OAuthClient).where(OAuthClient.id == client_id)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            logger.error("Failed to delete client %s: %s", client_id, e)
            await self.session.rollback()
            return False
|
|
|
|
|
|
class OAuthTokenRepository:
    """Repository for performing CRUD operations on OAuthToken model.

    Methods catch ``SQLAlchemyError``, log it, and return a neutral value
    (``None``, ``[]`` or ``False``) instead of raising. Methods that write
    commit on success and roll the session back on failure.
    """

    def __init__(self, session: AsyncSession):
        """
        Bind the repository to a session.

        Args:
            session: Async session used for every statement this repository runs.
        """
        self.session = session

    async def create(self, token_data: dict) -> Optional[OAuthToken]:
        """
        Create a new OAuth token.

        Args:
            token_data: Dictionary with token fields; passed as keyword
                arguments to the OAuthToken constructor.

        Returns:
            OAuthToken instance if successful, None otherwise.
        """
        try:
            token = OAuthToken(**token_data)
            self.session.add(token)
            await self.session.commit()
            # Reload the instance after the commit so the returned object
            # carries the persisted column values.
            await self.session.refresh(token)
            return token
        except SQLAlchemyError as e:
            logger.error("Failed to create OAuth token: %s", e)
            await self.session.rollback()
            return None

    async def get_by_id(self, token_id: int) -> Optional[OAuthToken]:
        """
        Retrieve a token by its ID.

        Args:
            token_id: The token ID.

        Returns:
            OAuthToken if found, None otherwise (including on database error).
        """
        try:
            stmt = select(OAuthToken).where(OAuthToken.id == token_id)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error("Failed to fetch token by id %s: %s", token_id, e)
            return None

    async def get_by_access_token(self, access_token: str) -> Optional[OAuthToken]:
        """
        Retrieve a token by its access token.

        Args:
            access_token: The access token string.

        Returns:
            OAuthToken if found, None otherwise (including on database error).
        """
        try:
            stmt = select(OAuthToken).where(OAuthToken.access_token == access_token)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            # The token value itself is deliberately kept out of the log line.
            logger.error("Failed to fetch token by access token: %s", e)
            return None

    async def get_by_refresh_token(self, refresh_token: str) -> Optional[OAuthToken]:
        """
        Retrieve a token by its refresh token.

        Args:
            refresh_token: The refresh token string.

        Returns:
            OAuthToken if found, None otherwise (including on database error).
        """
        try:
            stmt = select(OAuthToken).where(OAuthToken.refresh_token == refresh_token)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            # The token value itself is deliberately kept out of the log line.
            logger.error("Failed to fetch token by refresh token: %s", e)
            return None

    async def get_expired_tokens(self) -> List[OAuthToken]:
        """
        Retrieve all expired tokens.

        A token counts as expired when its ``expires_at`` is earlier than the
        current UTC time.

        Returns:
            List of expired OAuthToken objects; empty list on database error.
        """
        try:
            # Naive UTC "now": the same value datetime.utcnow() produced, but
            # without the call that is deprecated since Python 3.12.
            # NOTE(review): assumes expires_at holds naive UTC timestamps, as
            # the original comparison did - confirm against the model.
            now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
            stmt = select(OAuthToken).where(OAuthToken.expires_at < now_utc)
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.error("Failed to fetch expired tokens: %s", e)
            return []

    async def revoke_token(self, token_id: int) -> bool:
        """
        Revoke (delete) a token by ID.

        Thin alias for :meth:`delete`.

        Args:
            token_id: The token ID.

        Returns:
            True if deletion succeeded, False otherwise.
        """
        return await self.delete(token_id)

    async def revoke_by_access_token(self, access_token: str) -> bool:
        """
        Revoke (delete) a token by access token.

        Args:
            access_token: The access token string.

        Returns:
            True if at least one row was deleted, False otherwise (including
            on database error).
        """
        try:
            stmt = delete(OAuthToken).where(OAuthToken.access_token == access_token)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            logger.error("Failed to revoke token by access token: %s", e)
            await self.session.rollback()
            return False

    async def get_all(self, skip: int = 0, limit: int = 100) -> List[OAuthToken]:
        """
        Retrieve all tokens with pagination.

        Args:
            skip: Number of records to skip.
            limit: Maximum number of records to return.

        Returns:
            List of OAuthToken objects ordered by ID; empty list on database
            error.
        """
        try:
            # OFFSET/LIMIT without ORDER BY leaves the row order up to the
            # database, so consecutive pages could overlap or skip rows.
            stmt = (
                select(OAuthToken)
                .order_by(OAuthToken.id)
                .offset(skip)
                .limit(limit)
            )
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.error("Failed to fetch all tokens: %s", e)
            return []

    async def update(self, token_id: int, token_data: dict) -> Optional[OAuthToken]:
        """
        Update an existing token.

        Args:
            token_id: The token ID.
            token_data: Dictionary of fields to update. When empty, no UPDATE
                is issued and the token is returned as currently stored.

        Returns:
            Updated OAuthToken if successful, None if no row matched or a
            database error occurred.
        """
        if not token_data:
            # Nothing to SET: skip the UPDATE and report the row as it stands.
            return await self.get_by_id(token_id)

        try:
            stmt = (
                update(OAuthToken)
                .where(OAuthToken.id == token_id)
                .values(**token_data)
                .returning(OAuthToken)
            )
            result = await self.session.execute(stmt)
            # Read the RETURNING row before committing so the result does not
            # have to outlive the transaction that produced it.
            token = result.scalar_one_or_none()
            await self.session.commit()
            if token is not None:
                await self.session.refresh(token)
            return token
        except SQLAlchemyError as e:
            logger.error("Failed to update token %s: %s", token_id, e)
            await self.session.rollback()
            return None

    async def delete(self, token_id: int) -> bool:
        """
        Delete a token by ID.

        Args:
            token_id: The token ID.

        Returns:
            True if at least one row was deleted, False otherwise (including
            on database error).
        """
        try:
            stmt = delete(OAuthToken).where(OAuthToken.id == token_id)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            logger.error("Failed to delete token %s: %s", token_id, e)
            await self.session.rollback()
            return False
|
|
|
|
|
|
class OAuthUserRepository:
    """Repository for performing CRUD operations on OAuthUser model.

    Every method catches ``SQLAlchemyError``, logs it, and returns a neutral
    value (``None``, ``[]`` or ``False``) instead of raising. Methods that
    write commit on success and roll the session back on failure.
    """

    def __init__(self, session: AsyncSession):
        """
        Bind the repository to a session.

        Args:
            session: Async session used for every statement this repository runs.
        """
        self.session = session

    async def create(self, user_data: dict) -> Optional[OAuthUser]:
        """
        Create a new OAuth user.

        Args:
            user_data: Dictionary with user fields; passed as keyword
                arguments to the OAuthUser constructor.

        Returns:
            OAuthUser instance if successful, None otherwise.
        """
        try:
            user = OAuthUser(**user_data)
            self.session.add(user)
            await self.session.commit()
            # Reload the instance after the commit so the returned object
            # carries the persisted column values.
            await self.session.refresh(user)
            return user
        except SQLAlchemyError as e:
            logger.error("Failed to create OAuth user: %s", e)
            await self.session.rollback()
            return None

    async def get_by_id(self, user_id: int) -> Optional[OAuthUser]:
        """
        Retrieve a user by its ID.

        Args:
            user_id: The user ID.

        Returns:
            OAuthUser if found, None otherwise (including on database error).
        """
        try:
            stmt = select(OAuthUser).where(OAuthUser.id == user_id)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error("Failed to fetch user by id %s: %s", user_id, e)
            return None

    async def get_by_username(self, username: str) -> Optional[OAuthUser]:
        """
        Retrieve a user by username.

        Args:
            username: The username string.

        Returns:
            OAuthUser if found, None otherwise (including on database error).
        """
        try:
            stmt = select(OAuthUser).where(OAuthUser.username == username)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error("Failed to fetch user by username %s: %s", username, e)
            return None

    async def get_by_email(self, email: str) -> Optional[OAuthUser]:
        """
        Retrieve a user by email.

        Args:
            email: The email address.

        Returns:
            OAuthUser if found, None otherwise (including on database error).
        """
        try:
            stmt = select(OAuthUser).where(OAuthUser.email == email)
            result = await self.session.execute(stmt)
            return result.scalar_one_or_none()
        except SQLAlchemyError as e:
            logger.error("Failed to fetch user by email %s: %s", email, e)
            return None

    async def get_all(self, skip: int = 0, limit: int = 100) -> List[OAuthUser]:
        """
        Retrieve all users with pagination.

        Args:
            skip: Number of records to skip.
            limit: Maximum number of records to return.

        Returns:
            List of OAuthUser objects ordered by ID; empty list on database
            error.
        """
        try:
            # OFFSET/LIMIT without ORDER BY leaves the row order up to the
            # database, so consecutive pages could overlap or skip rows.
            stmt = (
                select(OAuthUser)
                .order_by(OAuthUser.id)
                .offset(skip)
                .limit(limit)
            )
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except SQLAlchemyError as e:
            logger.error("Failed to fetch all users: %s", e)
            return []

    async def update(self, user_id: int, user_data: dict) -> Optional[OAuthUser]:
        """
        Update an existing user.

        Args:
            user_id: The user ID.
            user_data: Dictionary of fields to update. When empty, no UPDATE
                is issued and the user is returned as currently stored.

        Returns:
            Updated OAuthUser if successful, None if no row matched or a
            database error occurred.
        """
        if not user_data:
            # Nothing to SET: skip the UPDATE and report the row as it stands.
            return await self.get_by_id(user_id)

        try:
            stmt = (
                update(OAuthUser)
                .where(OAuthUser.id == user_id)
                .values(**user_data)
                .returning(OAuthUser)
            )
            result = await self.session.execute(stmt)
            # Read the RETURNING row before committing so the result does not
            # have to outlive the transaction that produced it.
            user = result.scalar_one_or_none()
            await self.session.commit()
            if user is not None:
                await self.session.refresh(user)
            return user
        except SQLAlchemyError as e:
            logger.error("Failed to update user %s: %s", user_id, e)
            await self.session.rollback()
            return None

    async def delete(self, user_id: int) -> bool:
        """
        Delete a user by ID.

        Args:
            user_id: The user ID.

        Returns:
            True if at least one row was deleted, False otherwise (including
            on database error).
        """
        try:
            stmt = delete(OAuthUser).where(OAuthUser.id == user_id)
            result = await self.session.execute(stmt)
            await self.session.commit()
            return result.rowcount > 0
        except SQLAlchemyError as e:
            logger.error("Failed to delete user %s: %s", user_id, e)
            await self.session.rollback()
            return False