649 lines
No EOL
22 KiB
Python
649 lines
No EOL
22 KiB
Python
"""
|
|
OAuth2 Services for token management, client validation, and grant flow handling.
|
|
"""
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Union, Any
|
|
from jose import jwt, JWTError
|
|
from fastapi import HTTPException, status
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from config import settings
|
|
from middleware.auth_middleware import verify_password
|
|
from .repositories import OAuthClientRepository, OAuthTokenRepository, OAuthUserRepository
|
|
from .schemas import OAuthTokenCreate, OAuthClientResponse
|
|
from .auth_code_store import authorization_code_store
|
|
|
|
# Module-level logger named after this module (``__name__``), used by the
# services below for debug/warning output.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TokenService:
    """Service for JWT token generation, validation, and revocation checking.

    Tokens are JWTs signed with ``settings.secret_key`` using :attr:`ALGORITHM`.
    Access tokens are expected to be persisted via :meth:`store_token`;
    :meth:`verify_token` treats an access token with no matching database row
    as revoked.
    """

    ALGORITHM = "HS256"

    def __init__(self, session: AsyncSession):
        self.session = session
        self.token_repo = OAuthTokenRepository(session)

    def create_access_token(
        self,
        subject: str,
        client_id: str,
        scopes: List[str],
        token_type: str = "Bearer",
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """
        Create a JWT access token.

        Args:
            subject: The token subject (user ID or client ID).
            client_id: OAuth client identifier.
            scopes: List of granted scopes.
            token_type: Token type (default "Bearer").
            expires_delta: Optional custom lifetime. When omitted (``None``),
                ``settings.oauth2_access_token_expire_minutes`` is used. An
                explicit ``timedelta(0)`` is honoured rather than replaced by
                the default.

        Returns:
            JWT token string.
        """
        if expires_delta is None:
            expires_delta = timedelta(minutes=settings.oauth2_access_token_expire_minutes)
        return self._encode_token(subject, client_id, scopes, token_type, expires_delta)

    def create_refresh_token(
        self,
        subject: str,
        client_id: str,
        scopes: List[str],
        token_type: str = "Refresh",
        expires_delta: Optional[timedelta] = None,
    ) -> str:
        """
        Create a JWT refresh token.

        Args:
            subject: The token subject (user ID or client ID).
            client_id: OAuth client identifier.
            scopes: List of granted scopes.
            token_type: Token type (default "Refresh").
            expires_delta: Optional custom lifetime. When omitted (``None``),
                ``settings.oauth2_refresh_token_expire_days`` is used. An
                explicit ``timedelta(0)`` is honoured rather than replaced by
                the default.

        Returns:
            JWT token string.
        """
        if expires_delta is None:
            expires_delta = timedelta(days=settings.oauth2_refresh_token_expire_days)
        return self._encode_token(subject, client_id, scopes, token_type, expires_delta)

    def _encode_token(
        self,
        subject: str,
        client_id: str,
        scopes: List[str],
        token_type: str,
        lifetime: timedelta,
    ) -> str:
        """Build and sign a JWT whose ``exp`` lies ``lifetime`` after its ``iat``."""
        # Read the clock once so ``iat`` and ``exp`` derive from the same instant.
        issued_at = datetime.utcnow()
        payload = {
            "sub": subject,
            "client_id": client_id,
            "scopes": scopes,
            "exp": issued_at + lifetime,
            "token_type": token_type,
            "iat": issued_at,
            "jti": self._generate_jti(),
        }
        return jwt.encode(payload, settings.secret_key, algorithm=self.ALGORITHM)

    @staticmethod
    def _invalid_token(detail: str) -> HTTPException:
        """Build the 401 error raised for any token that fails verification."""
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer error=\"invalid_token\""},
        )

    async def verify_token(self, token: str) -> Dict[str, Any]:
        """
        Verify a JWT access token and return its payload.

        Checks, in order:
        1. the signature and (if present) ``exp`` via ``jwt.decode``;
        2. that an ``exp`` claim is present and not in the past;
        3. that the token still has a database row (missing row = revoked);
        4. that the row's ``expires_at`` has not passed.

        Args:
            token: JWT token string.

        Returns:
            Token payload dict if valid.

        Raises:
            HTTPException: 401 if the token is invalid, expired, or revoked.
        """
        try:
            payload = jwt.decode(token, settings.secret_key, algorithms=[self.ALGORITHM])
        except JWTError as e:
            logger.warning("JWT validation failed: %s", e)
            raise self._invalid_token("Invalid token") from e

        # jwt.decode only validates ``exp`` when the claim exists, so insist on it
        # here; the explicit comparison is a belt-and-braces re-check.
        exp_timestamp = payload.get("exp")
        if exp_timestamp is None:
            raise self._invalid_token("Token missing expiration")
        if datetime.utcfromtimestamp(exp_timestamp) < datetime.utcnow():
            raise self._invalid_token("Token expired")

        # A token deleted from the database is considered revoked.
        token_record = await self.token_repo.get_by_access_token(token)
        if token_record is None:
            raise self._invalid_token("Token revoked")

        # The stored expiry should agree with the JWT's; enforce it as well.
        if token_record.expires_at < datetime.utcnow():
            raise self._invalid_token("Token expired")

        return payload

    def decode_token(self, token: str) -> Dict[str, Any]:
        """
        Decode a JWT token without verification (for introspection only).

        Warning: This does NOT validate signature or expiration. Use only when
        the token has already been verified via verify_token().

        Args:
            token: JWT token string.

        Returns:
            Token payload dict.
        """
        return jwt.get_unverified_claims(token)

    async def store_token(self, token_data: OAuthTokenCreate) -> bool:
        """
        Store a token record in the database.

        Args:
            token_data: OAuthTokenCreate schema with token details.

        Returns:
            True if the repository returned a record, False otherwise.
        """
        token_record = await self.token_repo.create(token_data.dict())
        return token_record is not None

    async def revoke_token(self, token: str) -> bool:
        """
        Revoke an access token by removing it via the token repository.

        Args:
            token: Access token string.

        Returns:
            The repository's result (True if revocation succeeded).
        """
        return await self.token_repo.revoke_by_access_token(token)

    def _generate_jti(self) -> str:
        """Generate a unique JWT ID (jti) from 32 cryptographically random bytes."""
        import secrets

        return secrets.token_urlsafe(32)
|
|
|
|
|
|
class OAuthService:
    """Service implementing OAuth2 grant flows.

    Provides the authorization code grant (issuing and exchanging codes), the
    client credentials grant, and the refresh token grant. Issued tokens are
    persisted through :class:`TokenService`.
    """

    def __init__(self, session: AsyncSession):
        self.session = session
        self.client_repo = OAuthClientRepository(session)
        self.token_repo = OAuthTokenRepository(session)
        self.user_repo = OAuthUserRepository(session)
        self.token_service = TokenService(session)

    async def authorize_code_flow(
        self,
        client_id: str,
        redirect_uri: str,
        scope: Optional[List[str]] = None,
        state: Optional[str] = None,
        user_id: Optional[int] = None,
    ) -> Dict[str, str]:
        """
        Handle authorization code grant flow (RFC 6749 §4.1).

        Args:
            client_id: Client identifier.
            redirect_uri: Redirect URI; must exactly match one of the client's
                registered URIs.
            scope: Requested scopes. When omitted, the client's scopes are granted.
            state: Opaque value for CSRF protection, echoed back when provided.
            user_id: Resource owner ID (if authenticated).

        Returns:
            Dictionary with authorization code and state (if provided).

        Raises:
            HTTPException: 400 for an unknown/inactive client, an unregistered
                redirect URI, or scopes outside the client's allowed scopes.
        """
        client = await self.client_repo.get_by_client_id(client_id)
        if not client or not client.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid client",
            )

        if redirect_uri not in client.redirect_uris:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid redirect URI",
            )

        if scope:
            scope_service = ScopeService(self.session)
            if not scope_service.validate_scopes(scope, client.scopes):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid scope",
                )

        import secrets

        code = secrets.token_urlsafe(32)

        # No scope requested -> grant the client's configured scopes.
        granted_scopes = scope or client.scopes

        expires_at = datetime.utcnow() + timedelta(
            minutes=settings.oauth2_authorization_code_expire_minutes
        )
        data = {
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "scopes": granted_scopes,
            "user_id": user_id,
            "expires_at": expires_at,
        }
        await authorization_code_store.store_code(code, data)

        # Only a prefix of the code is logged to avoid leaking the secret value.
        logger.debug(f"Generated authorization code {code[:8]}... for client {client_id}")

        result = {"code": code}
        if state:
            result["state"] = state
        return result

    async def exchange_code_for_tokens(
        self,
        code: str,
        client_id: str,
        redirect_uri: str,
    ) -> Dict[str, Any]:
        """
        Exchange an authorization code for access and refresh tokens (RFC 6749 §4.1.3).

        Args:
            code: Authorization code received from the client.
            client_id: Client identifier (must match the code's client_id).
            redirect_uri: Redirect URI used in the authorization request (must match).

        Returns:
            Dictionary with access token, refresh token, token type, expiration, and scope.

        Raises:
            HTTPException: 400 for an unknown or expired code, or a mismatched
                client_id / redirect_uri.
        """
        data = await authorization_code_store.get_code(code)
        if data is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired authorization code",
            )

        # authorize_code_flow records an ``expires_at`` with every code; enforce
        # it here instead of relying solely on the store to evict stale codes.
        # Only compared when it is still a datetime (the store's serialization
        # format is not visible from here).
        code_expires_at = data.get("expires_at")
        if isinstance(code_expires_at, datetime) and code_expires_at < datetime.utcnow():
            await authorization_code_store.delete_code(code)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid or expired authorization code",
            )

        if data["client_id"] != client_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Client mismatch",
            )
        if data["redirect_uri"] != redirect_uri:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Redirect URI mismatch",
            )

        # Codes are single-use: consume it before issuing tokens.
        await authorization_code_store.delete_code(code)

        scopes = data["scopes"]
        user_id = data.get("user_id")
        subject = str(user_id) if user_id is not None else client_id

        access_token = self.token_service.create_access_token(
            subject=subject,
            client_id=client_id,
            scopes=scopes,
        )
        # The authorization code grant also issues a refresh token.
        refresh_token = self.token_service.create_refresh_token(
            subject=subject,
            client_id=client_id,
            scopes=scopes,
        )

        expires_at = datetime.utcnow() + timedelta(minutes=settings.oauth2_access_token_expire_minutes)
        token_data = OAuthTokenCreate(
            access_token=access_token,
            refresh_token=refresh_token,
            token_type="Bearer",
            expires_at=expires_at,
            scopes=scopes,
            client_id=client_id,
            user_id=user_id,
        )
        await self.token_service.store_token(token_data)

        # Token response per RFC 6749 §5.1.
        return {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": settings.oauth2_access_token_expire_minutes * 60,
            "refresh_token": refresh_token,
            "scope": " ".join(scopes) if scopes else "",
        }

    async def client_credentials_flow(
        self,
        client_id: str,
        client_secret: str,
        scope: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Handle client credentials grant flow (RFC 6749 §4.4).

        Args:
            client_id: Client identifier.
            client_secret: Client secret.
            scope: Requested scopes. When omitted, the client's scopes are granted.

        Returns:
            Dictionary with access token and metadata (no refresh token).

        Raises:
            HTTPException: 401 for invalid client credentials; 400 for an
                unknown/inactive client or disallowed scopes.
        """
        client_service = ClientService(self.session)
        if not await client_service.validate_client(client_id, client_secret):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid client credentials",
                headers={"WWW-Authenticate": "Basic"},
            )

        client = await self.client_repo.get_by_client_id(client_id)
        if not client or not client.is_active:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid client",
            )

        if scope:
            scope_service = ScopeService(self.session)
            if not scope_service.validate_scopes(scope, client.scopes):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid scope",
                )
        else:
            scope = client.scopes

        # The client acts on its own behalf, so it is also the token subject.
        access_token = self.token_service.create_access_token(
            subject=client_id,
            client_id=client_id,
            scopes=scope,
        )

        expires_at = datetime.utcnow() + timedelta(minutes=settings.oauth2_access_token_expire_minutes)
        token_data = OAuthTokenCreate(
            access_token=access_token,
            refresh_token=None,
            token_type="Bearer",
            expires_at=expires_at,
            scopes=scope,
            client_id=client_id,
            user_id=None,
        )
        await self.token_service.store_token(token_data)

        return {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": settings.oauth2_access_token_expire_minutes * 60,
            "scope": " ".join(scope) if scope else "",
        }

    async def refresh_token_flow(
        self,
        refresh_token: str,
        client_id: str,
        client_secret: str,
        scope: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Handle refresh token grant flow (RFC 6749 §6).

        The refresh token must exist in the database, belong to ``client_id``,
        and carry a valid signature and unexpired ``exp`` claim. On success a
        new access/refresh token pair is issued, and the old database record is
        revoked (rotation).

        Args:
            refresh_token: Valid refresh token.
            client_id: Client identifier.
            client_secret: Client secret.
            scope: Optional requested scopes (must be subset of original).

        Returns:
            Dictionary with new access token and new refresh token.

        Raises:
            HTTPException: 401 for invalid client credentials; 400 for an
                unknown, foreign, invalid, or expired refresh token, or
                disallowed scopes.
        """
        from jose.exceptions import ExpiredSignatureError

        client_service = ClientService(self.session)
        if not await client_service.validate_client(client_id, client_secret):
            # Same challenge header as the client credentials grant (RFC 6749 §5.2).
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid client credentials",
                headers={"WWW-Authenticate": "Basic"},
            )

        token_record = await self.token_repo.get_by_refresh_token(refresh_token)
        if not token_record:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid refresh token",
            )

        if token_record.client_id != client_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Refresh token does not belong to client",
            )

        # The record's ``expires_at`` is the *access* token's expiry (minutes),
        # so it must not gate the refresh token -- doing so made refresh tokens
        # unusable exactly when the access token had expired. Validate the
        # refresh JWT itself instead: its signature and its own ``exp``.
        try:
            jwt.decode(
                refresh_token,
                settings.secret_key,
                algorithms=[TokenService.ALGORITHM],
            )
        except ExpiredSignatureError as exc:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Refresh token expired",
            ) from exc
        except JWTError as exc:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid refresh token",
            ) from exc

        # Requested scopes may only narrow the originally granted ones.
        if scope:
            scope_service = ScopeService(self.session)
            if not scope_service.validate_scopes(scope, token_record.scopes):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Invalid scope",
                )
        else:
            scope = token_record.scopes

        # ``is not None`` (not truthiness) so a user_id of 0 is still the subject,
        # matching exchange_code_for_tokens.
        subject = (
            str(token_record.user_id)
            if token_record.user_id is not None
            else token_record.client_id
        )

        access_token = self.token_service.create_access_token(
            subject=subject,
            client_id=client_id,
            scopes=scope,
        )
        # Rotate the refresh token on every use.
        new_refresh_token = self.token_service.create_refresh_token(
            subject=subject,
            client_id=client_id,
            scopes=scope,
        )

        expires_at = datetime.utcnow() + timedelta(minutes=settings.oauth2_access_token_expire_minutes)
        new_token_data = OAuthTokenCreate(
            access_token=access_token,
            refresh_token=new_refresh_token,
            token_type="Bearer",
            expires_at=expires_at,
            scopes=scope,
            client_id=client_id,
            user_id=token_record.user_id,
        )
        await self.token_service.store_token(new_token_data)
        await self.token_repo.revoke_token(token_record.id)

        response = {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": settings.oauth2_access_token_expire_minutes * 60,
            "scope": " ".join(scope) if scope else "",
        }
        if new_refresh_token:
            response["refresh_token"] = new_refresh_token

        return response
|
|
|
|
|
|
class ClientService:
    """Look up OAuth clients and check their secrets against the stored hash."""

    def __init__(self, session: AsyncSession):
        self.session = session
        self.client_repo = OAuthClientRepository(session)

    async def validate_client(self, client_id: str, client_secret: str) -> bool:
        """
        Check whether ``client_id``/``client_secret`` identify an active client.

        Args:
            client_id: Client identifier.
            client_secret: Client secret (plaintext).

        Returns:
            True only when the client exists, is active, and the secret matches
            the stored hash; False otherwise.
        """
        record = await self.client_repo.get_by_client_id(client_id)
        if record and record.is_active:
            return await self.verify_client_secret(client_secret, record.client_secret)
        return False

    async def verify_client_secret(self, plain_secret: str, hashed_secret: str) -> bool:
        """
        Compare a plaintext secret with its stored hash.

        Delegates to ``verify_password`` from the auth middleware.

        Args:
            plain_secret: Plaintext secret.
            hashed_secret: Stored hashed secret.

        Returns:
            The result of ``verify_password`` (True when the secret matches).
        """
        matches = verify_password(plain_secret, hashed_secret)
        return matches

    async def get_client_scopes(self, client_id: str) -> List[str]:
        """
        Fetch the scopes configured for a client.

        Args:
            client_id: Client identifier.

        Returns:
            The client's ``scopes`` attribute.

        Raises:
            HTTPException: 400 when no client with that identifier exists.
        """
        record = await self.client_repo.get_by_client_id(client_id)
        if record:
            return record.scopes
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid client",
        )
|
|
|
|
|
|
class ScopeService:
    """Membership checks between requested, allowed, and required scope lists."""

    def __init__(self, session: AsyncSession):
        self.session = session

    def validate_scopes(self, requested_scopes: List[str], allowed_scopes: List[str]) -> bool:
        """
        Report whether every requested scope appears in ``allowed_scopes``.

        Args:
            requested_scopes: Scopes being requested; empty/None requests nothing.
            allowed_scopes: Scopes the client is permitted to receive.

        Returns:
            True when nothing is requested or all requested scopes are allowed.
        """
        if not requested_scopes:
            return True
        for wanted in requested_scopes:
            if wanted not in allowed_scopes:
                return False
        return True

    def check_scope_access(self, token_scopes: List[str], required_scopes: List[str]) -> bool:
        """
        Report whether a token's scopes cover everything an endpoint requires.

        Args:
            token_scopes: Scopes granted to the token.
            required_scopes: Scopes the endpoint demands; empty/None demands nothing.

        Returns:
            True when nothing is required or every required scope was granted.
        """
        if not required_scopes:
            return True
        missing_any = any(needed not in token_scopes for needed in required_scopes)
        return not missing_any