426 lines
No EOL
15 KiB
Python
426 lines
No EOL
15 KiB
Python
"""
|
|
OAuth2 Controller - Implements OAuth2 and OpenID Connect endpoints as per RFC specifications.
|
|
"""
|
|
import logging
from typing import Optional, List, Dict, Any, NoReturn, Tuple
from urllib.parse import urlencode, urlsplit, urlunsplit

from fastapi import (
    APIRouter,
    Depends,
    HTTPException,
    status,
    Request,
    Query,
    Form,
    Header,
    Response,
)
from fastapi.responses import RedirectResponse, JSONResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from sqlalchemy.ext.asyncio import AsyncSession

from database import get_db
from config import settings
from .services import (
    OAuthService,
    TokenService,
    ClientService,
    ScopeService,
)
from .dependencies import get_current_token_payload
from .auth_code_store import authorization_code_store
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router for every endpoint in this file; all paths are served under /oauth.
router = APIRouter(prefix="/oauth", tags=["oauth2"])

# HTTP Basic client authentication (RFC 6749 §2.3.1). auto_error=False makes the
# dependency yield None instead of rejecting the request when no Basic header is
# sent, so get_client_credentials can fall back to form-body credentials.
security = HTTPBasic(auto_error=False)
|
|
|
|
|
|
def parse_scopes(scope_str: Optional[str]) -> List[str]:
    """Split a space-delimited OAuth2 scope string into individual scopes.

    ``None`` and the empty string both yield ``[]``. The string is split on
    single spaces; each piece is stripped of surrounding whitespace and pieces
    that end up empty (from doubled, leading or trailing spaces) are dropped.
    Order is preserved and duplicates are kept.
    """
    if scope_str:
        pieces = (piece.strip() for piece in scope_str.split(" "))
        return [piece for piece in pieces if piece]
    return []
|
|
|
|
|
|
def oauth_error_response(
    error: str,
    error_description: Optional[str] = None,
    status_code: int = status.HTTP_400_BAD_REQUEST,
    headers: Optional[Dict[str, str]] = None,
) -> NoReturn:
    """Raise an ``HTTPException`` whose detail uses the OAuth2 error format.

    Despite its name this helper never returns a response: it always raises.
    It is annotated ``NoReturn`` so type checkers know that code following a
    call is unreachable.

    Args:
        error: OAuth2 error code, e.g. ``"invalid_request"`` or ``"invalid_client"``.
        error_description: Optional human-readable explanation; left out of the
            detail when ``None`` or empty.
        status_code: HTTP status code to use (default 400).
        headers: Optional extra response headers, e.g. ``WWW-Authenticate`` on a
            401 ``invalid_client`` response (RFC 6749 §5.2).

    Raises:
        HTTPException: always, with ``detail`` set to
            ``{"error": error}`` plus ``"error_description"`` when given.
    """
    detail: Dict[str, str] = {"error": error}
    if error_description:
        detail["error_description"] = error_description
    raise HTTPException(status_code=status_code, detail=detail, headers=headers)
|
|
|
|
|
|
async def get_client_credentials(
    request: Request,
    credentials: Optional[HTTPBasicCredentials] = Depends(security),
    client_id: Optional[str] = Form(None),
    client_secret: Optional[str] = Form(None),
) -> Tuple[str, str]:
    """
    Extract client credentials from the HTTP Basic header or the form body.

    HTTP Basic (``client_secret_basic``) takes priority over form-body
    parameters (``client_secret_post``), per RFC 6749 §2.3.1. This dependency
    only extracts the credentials; it does not check them against the stored
    clients.

    Args:
        request: Incoming request (unused; kept for interface compatibility).
        credentials: Parsed Basic credentials, or ``None`` when the request
            carries no Basic ``Authorization`` header.
        client_id: ``client_id`` form field, if any.
        client_secret: ``client_secret`` form field, if any.

    Returns:
        ``(client_id, client_secret)``.

    Raises:
        HTTPException: 401 ``invalid_client`` when neither source supplies
            both values. The response carries ``WWW-Authenticate: Basic``, which
            RFC 7235 §3.1 requires on every 401 response.
    """
    if credentials:
        # NOTE(review): RFC 6749 §2.3.1 says Basic-auth client_id/secret are
        # form-urlencoded before base64; they are used here without decoding.
        return credentials.username, credentials.password

    # Fallback to body parameters; both must be present and non-empty.
    if client_id and client_secret:
        return client_id, client_secret

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail={"error": "invalid_client", "error_description": "Client authentication failed"},
        headers={"WWW-Authenticate": 'Basic realm="oauth"'},
    )
|
|
|
|
|
|
# ---------- Authorization Endpoint (Authorization Code Grant) ----------
|
|
def _append_query_params(uri: str, params: Dict[str, str]) -> str:
    """Return *uri* with *params* added to its query component.

    A query already present on *uri* is retained, as RFC 6749 §3.1.2 requires
    when adding parameters to a redirection URI, and a fragment (if any) stays
    at the end rather than swallowing the new parameters.
    """
    scheme, netloc, path, query, fragment = urlsplit(uri)
    merged_query = "&".join(part for part in (query, urlencode(params)) if part)
    return urlunsplit((scheme, netloc, path, merged_query, fragment))


@router.get("/authorize", response_class=RedirectResponse)
async def authorize(
    request: Request,
    response_type: str = Query(..., alias="response_type"),
    client_id: str = Query(..., alias="client_id"),
    redirect_uri: str = Query(..., alias="redirect_uri"),
    scope: Optional[str] = Query(None, alias="scope"),
    state: Optional[str] = Query(None, alias="state"),
    db: AsyncSession = Depends(get_db),
):
    """
    OAuth2 Authorization Endpoint (RFC 6749 §4.1).

    Rejects any ``response_type`` other than ``"code"``, asks
    ``OAuthService.authorize_code_flow`` to issue an authorization code, and
    answers with a 302 redirect to ``redirect_uri`` carrying ``code`` (and
    ``state`` when available) in the query string. Any query already present
    on ``redirect_uri`` is preserved.

    This function does not itself validate ``client_id``/``redirect_uri``;
    presumably ``authorize_code_flow`` does (TODO confirm).

    Raises:
        HTTPException: 400 ``unsupported_response_type``; errors raised by the
            service are re-raised, with non-OAuth-shaped details wrapped as
            ``invalid_request``.
    """
    # Only the authorization code grant is supported for now.
    if response_type != "code":
        oauth_error_response(
            "unsupported_response_type",
            "Only 'code' response_type is supported",
        )

    scope_list = parse_scopes(scope)

    # SECURITY TODO: no end-user authentication is integrated yet, so every
    # request is treated as user 1. This must be replaced before production.
    user_id = 1

    oauth_service = OAuthService(db)
    try:
        result = await oauth_service.authorize_code_flow(
            client_id=client_id,
            redirect_uri=redirect_uri,
            scope=scope_list,
            state=state,
            user_id=user_id,
        )
    except HTTPException as e:
        error_detail = e.detail
        if isinstance(error_detail, dict) and "error" in error_detail:
            # Already in OAuth2 error format.
            raise
        raise HTTPException(
            status_code=e.status_code,
            detail={"error": "invalid_request", "error_description": str(error_detail)},
        ) from e

    params: Dict[str, str] = {"code": result["code"]}
    # RFC 6749 §4.1.2: echo the client's state. Prefer the service's value,
    # but fall back to the request's own state if the service omitted it.
    response_state = result.get("state") or state
    if response_state:
        params["state"] = response_state

    redirect_url = _append_query_params(redirect_uri, params)
    # Log only the target URI: redirect_url embeds the authorization code,
    # which is a credential and must not end up in logs.
    logger.debug("Redirecting authorization response to %s", redirect_uri)
    return RedirectResponse(url=redirect_url, status_code=status.HTTP_302_FOUND)
|
|
|
|
|
|
# Optional: POST /authorize for consent submission (placeholder)
|
|
@router.post("/authorize", response_class=RedirectResponse)
async def authorize_post(
    request: Request,
    response_type: str = Form(...),
    client_id: str = Form(...),
    redirect_uri: str = Form(...),
    scope: Optional[str] = Form(None),
    state: Optional[str] = Form(None),
    db: AsyncSession = Depends(get_db),
):
    """Handle a form-encoded consent submission (placeholder).

    Takes the same parameters as ``GET /authorize``, but from the request body,
    and currently runs exactly the same logic by delegating to ``authorize``.

    NOTE(review): no consent check or CSRF token is handled here yet.
    """
    authorization_request = {
        "response_type": response_type,
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "scope": scope,
        "state": state,
    }
    return await authorize(request=request, db=db, **authorization_request)
|
|
|
|
|
|
# ---------- Token Endpoint ----------
|
|
def _raise_as_oauth_error(exc: HTTPException, default_error: str) -> NoReturn:
    """Re-raise a service ``HTTPException`` in OAuth2 error form.

    If ``exc.detail`` is already an OAuth2 error dict (has an ``"error"`` key)
    *exc* is re-raised unchanged. Otherwise a new ``HTTPException`` with the
    same status code and headers is raised, with detail
    ``{"error": default_error, "error_description": str(exc.detail)}``, chained
    to *exc*.
    """
    detail = exc.detail
    if isinstance(detail, dict) and "error" in detail:
        raise exc
    raise HTTPException(
        status_code=exc.status_code,
        detail={"error": default_error, "error_description": str(detail)},
        headers=exc.headers,
    ) from exc


@router.post("/token", response_class=JSONResponse)
async def token(
    request: Request,
    grant_type: str = Form(...),
    code: Optional[str] = Form(None),
    redirect_uri: Optional[str] = Form(None),
    refresh_token: Optional[str] = Form(None),
    scope: Optional[str] = Form(None),
    db: AsyncSession = Depends(get_db),
    # Client credentials via dependency
    client_credentials: Tuple[str, str] = Depends(get_client_credentials),
):
    """
    OAuth2 Token Endpoint (RFC 6749 §3.2).

    Supported grants:
      * ``authorization_code``: requires ``code`` and ``redirect_uri``. The
        client secret is checked here via ``ClientService.validate_client``,
        because ``exchange_code_for_tokens`` only receives ``client_id``.
      * ``client_credentials``: delegates to ``client_credentials_flow``.
      * ``refresh_token``: requires ``refresh_token``; delegates to
        ``refresh_token_flow``.

    On success returns the service's token dict as JSON, with ``token_type``
    defaulted to ``"Bearer"``. The response also carries
    ``Cache-Control: no-store`` and ``Pragma: no-cache``, as RFC 6749 §5.1
    requires.

    Raises:
        HTTPException: 400 ``invalid_request`` for missing parameters;
            401 ``invalid_client`` for bad credentials on the code grant;
            400 ``unsupported_grant_type``; service errors are re-raised, with
            non-OAuth-shaped details wrapped as ``invalid_grant`` /
            ``invalid_client``.
    """
    client_id, client_secret = client_credentials
    scope_list = parse_scopes(scope)
    oauth_service = OAuthService(db)

    token_response: Optional[Dict[str, Any]] = None

    if grant_type == "authorization_code":
        if not code:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"error": "invalid_request", "error_description": "Missing 'code' parameter"},
            )
        if not redirect_uri:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"error": "invalid_request", "error_description": "Missing 'redirect_uri' parameter"},
            )

        # exchange_code_for_tokens is never given the secret, so authenticate
        # the client here (RFC 6749 §4.1.3); otherwise anyone holding a code
        # and a client_id could redeem it.
        client_service = ClientService(db)
        if not await client_service.validate_client(client_id, client_secret):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail={"error": "invalid_client", "error_description": "Client authentication failed"},
                headers={"WWW-Authenticate": 'Basic realm="oauth"'},
            )

        try:
            token_response = await oauth_service.exchange_code_for_tokens(
                code=code,
                client_id=client_id,
                redirect_uri=redirect_uri,
            )
        except HTTPException as e:
            _raise_as_oauth_error(e, "invalid_grant")

    elif grant_type == "client_credentials":
        try:
            token_response = await oauth_service.client_credentials_flow(
                client_id=client_id,
                client_secret=client_secret,
                scope=scope_list,
            )
        except HTTPException as e:
            _raise_as_oauth_error(e, "invalid_client")

    elif grant_type == "refresh_token":
        if not refresh_token:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"error": "invalid_request", "error_description": "Missing 'refresh_token' parameter"},
            )
        try:
            token_response = await oauth_service.refresh_token_flow(
                refresh_token=refresh_token,
                client_id=client_id,
                client_secret=client_secret,
                scope=scope_list,
            )
        except HTTPException as e:
            _raise_as_oauth_error(e, "invalid_grant")

    else:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "unsupported_grant_type"},
        )

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if token_response is None:
        logger.error("Token service returned no response for grant_type=%s", grant_type)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": "server_error", "error_description": "Token issuance failed"},
        )

    # Default token_type to Bearer without mutating the service's dict.
    body = {"token_type": "Bearer", **token_response}
    return JSONResponse(
        content=body,
        status_code=status.HTTP_200_OK,
        headers={"Cache-Control": "no-store", "Pragma": "no-cache"},
    )
|
|
|
|
|
|
# ---------- UserInfo Endpoint (OpenID Connect) ----------
|
|
@router.get("/userinfo")
async def userinfo(
    payload: Dict[str, Any] = Depends(get_current_token_payload),
    db: AsyncSession = Depends(get_db),
):
    """
    OpenID Connect UserInfo Endpoint (OIDC Core §5.3).

    Builds a claims object from the access-token payload supplied by
    ``get_current_token_payload``: ``sub``, ``client_id`` and a space-joined
    ``scope`` string. It adds placeholder ``email`` / ``name`` claims when the
    ``email`` / ``profile`` scopes are present.

    NOTE(review): the ``openid`` scope is not checked here, and ``db`` is not
    used yet.
    """
    subject = payload.get("sub")
    granted_scopes = payload.get("scopes", [])

    claims = {
        "sub": subject,
        "client_id": payload.get("client_id"),
        "scope": " ".join(granted_scopes) if granted_scopes else "",
    }

    # TODO: replace these placeholders with real data from the OAuthUser table.
    if "email" in granted_scopes:
        claims["email"] = f"{subject}@example.com"  # placeholder
    if "profile" in granted_scopes:
        claims["name"] = f"User {subject}"  # placeholder

    return JSONResponse(content=claims)
|
|
|
|
|
|
# ---------- Token Introspection Endpoint (RFC 7662) ----------
|
|
@router.post("/introspect", response_class=JSONResponse)
async def introspect(
    request: Request,
    token: str = Form(...),
    token_type_hint: Optional[str] = Form(None),
    db: AsyncSession = Depends(get_db),
    client_credentials: Tuple[str, str] = Depends(get_client_credentials),
):
    """
    OAuth2 Token Introspection Endpoint (RFC 7662).

    Requires client authentication (any valid client for now). Always answers
    with ``{"active": bool}``. For an active token it adds the metadata members
    that are known: ``client_id``, ``sub``, ``scope``, ``token_type``, ``exp``,
    ``iat`` and ``jti``. Members the token payload does not provide are
    omitted rather than sent as JSON ``null``; RFC 7662 §2.2 makes them
    OPTIONAL.

    ``token_type_hint`` is accepted but not used.

    Raises:
        HTTPException: 401 ``invalid_client`` when client authentication fails.
    """
    client_id, client_secret = client_credentials

    # Validate client credentials
    client_service = ClientService(db)
    if not await client_service.validate_client(client_id, client_secret):
        oauth_error_response(
            "invalid_client",
            "Client authentication failed",
            status.HTTP_401_UNAUTHORIZED,
        )

    token_service = TokenService(db)

    # Verify the token (per TokenService: signature, expiration and revocation).
    try:
        payload = await token_service.verify_token(token)
        active = True
    except HTTPException:
        # Invalid, expired or revoked tokens are reported as inactive, not as errors.
        active = False
        payload = None

    # Build introspection response according to RFC 7662 §2.2
    response: Dict[str, Any] = {"active": active}
    if active and payload:
        # `or []` guards against a payload that carries "scopes": None.
        scopes = payload.get("scopes") or []
        metadata = {
            "client_id": payload.get("client_id"),
            "sub": payload.get("sub"),
            "scope": " ".join(scopes),
            "token_type": payload.get("token_type", "Bearer"),
            "exp": payload.get("exp"),
            "iat": payload.get("iat"),
            "jti": payload.get("jti"),
        }
        response.update({key: value for key, value in metadata.items() if value is not None})

    return JSONResponse(content=response)
|
|
|
|
|
|
# ---------- Token Revocation Endpoint (RFC 7009) ----------
|
|
@router.post("/revoke", response_class=Response)
async def revoke(
    request: Request,
    token: str = Form(...),
    token_type_hint: Optional[str] = Form(None),
    db: AsyncSession = Depends(get_db),
    client_credentials: Tuple[str, str] = Depends(get_client_credentials),
):
    """
    OAuth2 Token Revocation Endpoint (RFC 7009).

    Requires client authentication. When the token can be verified and its
    payload names a ``client_id``, only that client may revoke it (RFC 7009
    §2.1). Tokens that fail verification are still passed to
    ``TokenService.revoke_token`` without an ownership check, as before.

    Returns HTTP 200 with an empty body whether or not a token was actually
    revoked; RFC 7009 §2.2 says invalid tokens do not cause an error response.

    Raises:
        HTTPException: 401 ``invalid_client`` when client authentication fails;
            400 ``unauthorized_client`` when the token belongs to another client.
    """
    client_id, client_secret = client_credentials

    # Validate client credentials
    client_service = ClientService(db)
    if not await client_service.validate_client(client_id, client_secret):
        oauth_error_response(
            "invalid_client",
            "Client authentication failed",
            status.HTTP_401_UNAUTHORIZED,
        )

    token_service = TokenService(db)

    # Ownership check: verify_token yields the payload (with "client_id") for
    # tokens it accepts and raises HTTPException otherwise.
    try:
        payload = await token_service.verify_token(token)
    except HTTPException:
        # NOTE(review): unverifiable tokens (e.g. possibly refresh tokens, if
        # verify_token only handles access tokens) skip the ownership check.
        payload = None

    owner = payload.get("client_id") if payload else None
    if owner is not None and owner != client_id:
        logger.warning("Client %s attempted to revoke a token issued to another client", client_id)
        oauth_error_response("unauthorized_client", "Token was not issued to this client")

    success = await token_service.revoke_token(token)
    if not success:
        # Token might already be revoked or not found
        logger.warning("Token revocation failed for token (client: %s)", client_id)

    # RFC 7009 §2.2: successful revocation returns HTTP 200 with empty body
    return Response(status_code=status.HTTP_200_OK)
|
|
|
|
|
|
# ---------- OpenID Connect Discovery Endpoint ----------
|
|
@router.get("/.well-known/openid-configuration")
async def openid_configuration(request: Request):
    """
    OpenID Connect Discovery Endpoint (OIDC Discovery §4).

    Returns provider metadata. Endpoint URLs are built from the request's base
    URL plus the ``/oauth`` prefix. Issuer, scopes and grant types come from
    ``settings``.

    NOTE(review): OIDC Discovery lists ``jwks_uri`` and
    ``id_token_signing_alg_values_supported`` as REQUIRED; the values below are
    placeholders until ID tokens are implemented.
    """
    base_url = str(request.base_url).rstrip("/")

    endpoints = {
        "authorization_endpoint": f"{base_url}/oauth/authorize",
        "token_endpoint": f"{base_url}/oauth/token",
        "userinfo_endpoint": f"{base_url}/oauth/userinfo",
        "introspection_endpoint": f"{base_url}/oauth/introspect",
        "revocation_endpoint": f"{base_url}/oauth/revoke",
    }

    return JSONResponse(
        content={
            "issuer": settings.oauth2_issuer,
            **endpoints,
            "jwks_uri": None,  # Not implemented yet
            "scopes_supported": settings.oauth2_supported_scopes,
            "response_types_supported": ["code"],
            "grant_types_supported": settings.oauth2_supported_grant_types,
            "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
            "id_token_signing_alg_values_supported": [],  # Not using ID tokens yet
            "subject_types_supported": ["public"],
            "claims_supported": ["sub", "client_id", "scope"],
        }
    )