"""
OAuth2 Services for token management, client validation, and grant flow handling.
"""

import logging
import secrets
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Union

from fastapi import HTTPException, status
from jose import JWTError, jwt
from jose.exceptions import ExpiredSignatureError
from sqlalchemy.ext.asyncio import AsyncSession

from config import settings
from middleware.auth_middleware import verify_password

from .auth_code_store import authorization_code_store
from .repositories import OAuthClientRepository, OAuthTokenRepository, OAuthUserRepository
from .schemas import OAuthClientResponse, OAuthTokenCreate

logger = logging.getLogger(__name__)


class TokenService:
    """Service for JWT token generation, validation, and revocation checking."""

    ALGORITHM = "HS256"

    def __init__(self, session: AsyncSession):
        self.session = session
        self.token_repo = OAuthTokenRepository(session)

    def create_access_token(
        self,
        subject: str,
        client_id: str,
        scopes: List[str],
        token_type: str = "Bearer",
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """
        Create a JWT access token.

        Args:
            subject: The token subject (user ID or client ID).
            client_id: OAuth client identifier.
            scopes: List of granted scopes.
            token_type: Token type (default "Bearer").
            expires_delta: Optional custom expiration delta. A falsy value
                (None or a zero timedelta) falls back to
                ``settings.oauth2_access_token_expire_minutes``.

        Returns:
            JWT token string.
        """
        if not expires_delta:
            expires_delta = timedelta(minutes=settings.oauth2_access_token_expire_minutes)
        return self._create_token(subject, client_id, scopes, token_type, expires_delta)

    def create_refresh_token(
        self,
        subject: str,
        client_id: str,
        scopes: List[str],
        token_type: str = "Refresh",
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """
        Create a JWT refresh token.

        Args:
            subject: The token subject (user ID or client ID).
            client_id: OAuth client identifier.
            scopes: List of granted scopes.
            token_type: Token type (default "Refresh").
            expires_delta: Optional custom expiration delta. A falsy value
                (None or a zero timedelta) falls back to
                ``settings.oauth2_refresh_token_expire_days``.

        Returns:
            JWT token string.
        """
        if not expires_delta:
            expires_delta = timedelta(days=settings.oauth2_refresh_token_expire_days)
        return self._create_token(subject, client_id, scopes, token_type, expires_delta)

    def _create_token(
        self,
        subject: str,
        client_id: str,
        scopes: List[str],
        token_type: str,
        expires_delta: timedelta,
    ) -> str:
        """Encode and sign a JWT carrying the claim set shared by access and refresh tokens."""
        # One timestamp for both claims so ``exp - iat`` equals ``expires_delta`` exactly.
        issued_at = datetime.utcnow()
        payload = {
            "sub": subject,
            "client_id": client_id,
            "scopes": scopes,
            "exp": issued_at + expires_delta,
            "token_type": token_type,
            "iat": issued_at,
            "jti": self._generate_jti(),
        }
        return jwt.encode(payload, settings.secret_key, algorithm=self.ALGORITHM)

    async def verify_token(self, token: str) -> Dict[str, Any]:
        """
        Verify a JWT token and return its payload.

        This method validates the token signature and expiration. It also
        checks that the token has not been revoked, i.e. that a record for
        it still exists in the database.

        Args:
            token: JWT token string.

        Returns:
            Token payload dict if valid.

        Raises:
            HTTPException with status 401 if token is invalid, expired, or revoked.
        """
        try:
            payload = jwt.decode(token, settings.secret_key, algorithms=[self.ALGORITHM])
        except JWTError as exc:
            logger.warning("JWT validation failed: %s", exc)
            raise self._invalid_token("Invalid token") from exc

        # jose only checks ``exp`` when the claim is present, so its absence
        # must be rejected explicitly.
        exp_timestamp = payload.get("exp")
        if exp_timestamp is None:
            raise self._invalid_token("Token missing expiration")
        if datetime.utcfromtimestamp(exp_timestamp) < datetime.utcnow():
            raise self._invalid_token("Token expired")

        # Revocation is implemented as deleting the record, so a missing
        # record means the token was revoked.
        token_record = await self.token_repo.get_by_access_token(token)
        if token_record is None:
            raise self._invalid_token("Token revoked")

        # The stored expiry should match the JWT's; checked defensively.
        if token_record.expires_at < datetime.utcnow():
            raise self._invalid_token("Token expired")

        return payload

    @staticmethod
    def _invalid_token(detail: str) -> HTTPException:
        """Build the RFC 6750 ``invalid_token`` 401 error with the given detail."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": 'Bearer error="invalid_token"'},
        )

    def decode_token(self, token: str) -> Dict[str, Any]:
        """
        Decode a JWT token without verification (for introspection only).

        Warning: This does NOT validate signature or expiration. Use only when
        the token has already been verified via verify_token().

        Args:
            token: JWT token string.

        Returns:
            Token payload dict.
        """
        return jwt.get_unverified_claims(token)

    async def store_token(self, token_data: OAuthTokenCreate) -> bool:
        """
        Store a token record in the database.

        Args:
            token_data: OAuthTokenCreate schema with token details.

        Returns:
            True if the repository returned a created record, False otherwise.
        """
        token_record = await self.token_repo.create(token_data.dict())
        return token_record is not None

    async def revoke_token(self, token: str) -> bool:
        """
        Revoke a token via the repository's ``revoke_by_access_token``.

        Args:
            token: Access token string.

        Returns:
            The repository's result (True if revocation succeeded).
        """
        return await self.token_repo.revoke_by_access_token(token)

    def _generate_jti(self) -> str:
        """Generate a unique JWT ID (jti)."""
        return secrets.token_urlsafe(32)


class OAuthService:
    """Service implementing OAuth2 grant flows."""

    def __init__(self, session: AsyncSession):
        self.session = session
        self.client_repo = OAuthClientRepository(session)
        self.token_repo = OAuthTokenRepository(session)
        self.user_repo = OAuthUserRepository(session)
        self.token_service = TokenService(session)

    async def authorize_code_flow(
        self,
        client_id: str,
        redirect_uri: str,
        scope: Optional[List[str]] = None,
        state: Optional[str] = None,
        user_id: Optional[int] = None,
    ) -> Dict[str, str]:
        """
        Handle authorization code grant flow (RFC 6749 §4.1).

        Args:
            client_id: Client identifier.
            redirect_uri: Redirect URI; must match one of the client's registered URIs.
            scope: Requested scopes. Defaults to the client's scopes when omitted.
            state: Opaque value for CSRF protection, echoed back unchanged.
            user_id: Resource owner ID (if authenticated).

        Returns:
            Dictionary with the authorization code and ``state`` (if provided).

        Raises:
            HTTPException with status 400 for an unknown/inactive client,
            an unregistered redirect URI, or an invalid scope.
        """
        client = await self.client_repo.get_by_client_id(client_id)
        if not client or not client.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid client",
            )

        if redirect_uri not in client.redirect_uris:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid redirect URI",
            )

        if scope:
            scope_service = ScopeService(self.session)
            if not scope_service.validate_scopes(scope, client.scopes):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid scope",
                )

        # Short-lived, single-use authorization code.
        code = secrets.token_urlsafe(32)
        granted_scopes = scope or client.scopes

        expires_at = datetime.utcnow() + timedelta(
            minutes=settings.oauth2_authorization_code_expire_minutes
        )
        data = {
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "scopes": granted_scopes,
            "user_id": user_id,
            "expires_at": expires_at,
        }
        await authorization_code_store.store_code(code, data)
        # Only a prefix of the code is logged so logs cannot be replayed.
        logger.debug("Generated authorization code %s... for client %s", code[:8], client_id)

        result = {"code": code}
        if state:
            result["state"] = state
        return result

    async def exchange_code_for_tokens(
        self,
        code: str,
        client_id: str,
        redirect_uri: str,
    ) -> Dict[str, Any]:
        """
        Exchange an authorization code for access and refresh tokens (RFC 6749 §4.1.3).

        Args:
            code: Authorization code received from the client.
            client_id: Client identifier (must match the code's client_id).
            redirect_uri: Redirect URI used in the authorization request (must match).

        Returns:
            Token response dict (RFC 6749 §5.1) with access token, token type,
            ``expires_in``, refresh token, and space-separated scope.

        Raises:
            HTTPException with status 400 for an invalid/expired code or a
            mismatched client/redirect_uri, and status 500 if the issued
            token could not be stored.
        """
        data = await authorization_code_store.get_code(code)
        if data is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired authorization code",
            )

        # Defensive: enforce the expiry recorded by authorize_code_flow in case
        # the store hands back a stale entry. The isinstance guard avoids
        # comparing against a serialized value.
        # NOTE(review): confirm whether the store already expires entries itself.
        code_expires_at = data.get("expires_at")
        if isinstance(code_expires_at, datetime) and code_expires_at < datetime.utcnow():
            await authorization_code_store.delete_code(code)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired authorization code",
            )

        if data["client_id"] != client_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Client mismatch",
            )
        if data["redirect_uri"] != redirect_uri:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Redirect URI mismatch",
            )

        # Codes are single-use: consume before issuing tokens.
        await authorization_code_store.delete_code(code)

        user_id = data.get("user_id")
        subject = str(user_id) if user_id is not None else client_id
        return await self._issue_tokens(
            subject=subject,
            client_id=client_id,
            scopes=data["scopes"],
            user_id=user_id,
            include_refresh_token=True,
        )

    async def client_credentials_flow(
        self,
        client_id: str,
        client_secret: str,
        scope: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Handle client credentials grant flow (RFC 6749 §4.4).

        Args:
            client_id: Client identifier.
            client_secret: Client secret.
            scope: Requested scopes. Defaults to the client's scopes when omitted.

        Returns:
            Token response dict with access token and metadata (no refresh token).

        Raises:
            HTTPException with status 401 for invalid client credentials,
            400 for an inactive client or invalid scope, and 500 if the
            issued token could not be stored.
        """
        client_service = ClientService(self.session)
        if not await client_service.validate_client(client_id, client_secret):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid client credentials",
                headers={"WWW-Authenticate": "Basic"},
            )

        client = await self.client_repo.get_by_client_id(client_id)
        if not client or not client.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid client",
            )

        if scope:
            scope_service = ScopeService(self.session)
            if not scope_service.validate_scopes(scope, client.scopes):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid scope",
                )
        else:
            scope = client.scopes

        return await self._issue_tokens(
            subject=client_id,
            client_id=client_id,
            scopes=scope,
            user_id=None,
            include_refresh_token=False,
        )

    async def refresh_token_flow(
        self,
        refresh_token: str,
        client_id: str,
        client_secret: str,
        scope: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Handle refresh token grant flow (RFC 6749 §6).

        The refresh token's lifetime is taken from its own JWT ``exp`` claim.
        The DB record's ``expires_at`` is not used here because it holds the
        access-token expiry. The DB record is still required to exist, which
        detects revoked or rotated tokens.

        Args:
            refresh_token: Valid refresh token.
            client_id: Client identifier.
            client_secret: Client secret.
            scope: Optional requested scopes (must be subset of original).

        Returns:
            Token response dict with a new access token and a rotated refresh token.

        Raises:
            HTTPException with status 401 for invalid client credentials,
            400 for an invalid/expired/foreign refresh token or invalid scope,
            and 500 if the new token could not be stored.
        """
        client_service = ClientService(self.session)
        if not await client_service.validate_client(client_id, client_secret):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid client credentials",
                headers={"WWW-Authenticate": "Basic"},
            )

        # Signature + refresh-token expiry (days), independent of the
        # access-token expiry stored on the DB record.
        self._verify_refresh_jwt(refresh_token)

        # Existence check = revocation / rotation check.
        token_record = await self.token_repo.get_by_refresh_token(refresh_token)
        if not token_record:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid refresh token",
            )

        if token_record.client_id != client_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Refresh token does not belong to client",
            )

        if scope:
            scope_service = ScopeService(self.session)
            if not scope_service.validate_scopes(scope, token_record.scopes):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid scope",
                )
        else:
            scope = token_record.scopes

        # `is not None` so a user_id of 0 still yields a user-subject token.
        subject = (
            str(token_record.user_id)
            if token_record.user_id is not None
            else token_record.client_id
        )

        # NOTE(review): RFC 6749 §6 requires a rotated refresh token to keep the
        # ORIGINAL scope. It is issued with the (possibly narrowed) scope here,
        # because the record stores a single scope list for both tokens.
        response = await self._issue_tokens(
            subject=subject,
            client_id=client_id,
            scopes=scope,
            user_id=token_record.user_id,
            include_refresh_token=True,
        )
        # Revoke only after the replacement is safely stored, so a storage
        # failure does not leave the client without any valid refresh token.
        await self.token_repo.revoke_token(token_record.id)
        return response

    @staticmethod
    def _verify_refresh_jwt(refresh_token: str) -> Dict[str, Any]:
        """Validate a refresh token's signature and ``exp``; map failures to 400."""
        try:
            return jwt.decode(
                refresh_token, settings.secret_key, algorithms=[TokenService.ALGORITHM]
            )
        except ExpiredSignatureError as exc:
            # Must precede JWTError: ExpiredSignatureError is a subclass of it.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Refresh token expired",
            ) from exc
        except JWTError as exc:
            logger.warning("Refresh token validation failed: %s", exc)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid refresh token",
            ) from exc

    async def _issue_tokens(
        self,
        *,
        subject: str,
        client_id: str,
        scopes: List[str],
        user_id: Optional[int],
        include_refresh_token: bool,
    ) -> Dict[str, Any]:
        """Mint, persist, and format an RFC 6749 §5.1 token response."""
        access_ttl = timedelta(minutes=settings.oauth2_access_token_expire_minutes)
        access_token = self.token_service.create_access_token(
            subject=subject,
            client_id=client_id,
            scopes=scopes,
            expires_delta=access_ttl,
        )
        refresh_token = (
            self.token_service.create_refresh_token(
                subject=subject,
                client_id=client_id,
                scopes=scopes,
            )
            if include_refresh_token
            else None
        )

        token_data = OAuthTokenCreate(
            access_token=access_token,
            refresh_token=refresh_token,
            token_type="Bearer",
            expires_at=datetime.utcnow() + access_ttl,
            scopes=scopes,
            client_id=client_id,
            user_id=user_id,
        )
        # verify_token rejects tokens without a DB record, so an unstored
        # token would be unusable; fail loudly instead of returning it.
        if not await self.token_service.store_token(token_data):
            logger.error("Failed to store issued token for client %s", client_id)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to store token",
            )

        response: Dict[str, Any] = {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": settings.oauth2_access_token_expire_minutes * 60,
        }
        if refresh_token:
            response["refresh_token"] = refresh_token
        response["scope"] = " ".join(scopes) if scopes else ""
        return response


class ClientService:
    """Service for OAuth client validation and secret verification."""

    def __init__(self, session: AsyncSession):
        self.session = session
        self.client_repo = OAuthClientRepository(session)

    async def validate_client(self, client_id: str, client_secret: str) -> bool:
        """
        Validate client credentials.

        Args:
            client_id: Client identifier.
            client_secret: Client secret (plaintext).

        Returns:
            True if the client exists, is active, and the secret matches;
            False otherwise.
        """
        client = await self.client_repo.get_by_client_id(client_id)
        if not client or not client.is_active:
            return False
        return await self.verify_client_secret(client_secret, client.client_secret)

    async def verify_client_secret(self, plain_secret: str, hashed_secret: str) -> bool:
        """
        Verify a client secret against its stored hash via ``verify_password``.

        Args:
            plain_secret: Plaintext secret.
            hashed_secret: Stored hashed secret (presumably bcrypt — see verify_password).

        Returns:
            True if secret matches, False otherwise.
        """
        return verify_password(plain_secret, hashed_secret)

    async def get_client_scopes(self, client_id: str) -> List[str]:
        """
        Retrieve allowed scopes for a client.

        Args:
            client_id: Client identifier.

        Returns:
            List of scopes allowed for the client.

        Raises:
            HTTPException with status 400 if the client is not found.
        """
        client = await self.client_repo.get_by_client_id(client_id)
        if not client:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid client",
            )
        return client.scopes


class ScopeService:
    """Service for scope validation and management."""

    def __init__(self, session: AsyncSession):
        self.session = session

    def validate_scopes(self, requested_scopes: List[str], allowed_scopes: List[str]) -> bool:
        """
        Validate that requested scopes are a subset of allowed scopes.

        Args:
            requested_scopes: List of scopes being requested.
            allowed_scopes: List of scopes allowed for the client (None is
                treated as "no scopes allowed").

        Returns:
            True if all requested scopes are allowed (or none are requested),
            False otherwise.
        """
        if not requested_scopes:
            return True
        return set(requested_scopes) <= set(allowed_scopes or ())

    def check_scope_access(self, token_scopes: List[str], required_scopes: List[str]) -> bool:
        """
        Check if token scopes satisfy required scopes.

        Args:
            token_scopes: Scopes granted to the token (None is treated as empty).
            required_scopes: Scopes required for the endpoint.

        Returns:
            True if the token has all required scopes (or none are required),
            False otherwise.
        """
        if not required_scopes:
            return True
        return set(required_scopes) <= set(token_scopes or ())