import asyncio
import json
import logging
import time
from typing import Dict, Any, Optional, Tuple, Callable
from uuid import uuid4

import jinja2
from fastapi import FastAPI, Request, Response, status, HTTPException
from fastapi.routing import APIRoute
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.config import settings
from app.modules.endpoints.models.endpoint_model import Endpoint
from app.modules.endpoints.repositories.endpoint_repository import EndpointRepository
from app.modules.endpoints.services.template_service import TemplateService
from app.modules.oauth2.services import TokenService, ScopeService

logger = logging.getLogger(__name__)


class RouteManager:
    """
    Manages dynamic route registration and removal for the FastAPI application.

    All changes to ``app.routes`` and ``registered_routes`` happen while holding
    ``_routes_lock``. The public coroutines acquire the lock themselves and then
    delegate to lock-free private helpers (``_add_route`` / ``_remove_route``),
    so ``refresh_routes`` can apply a whole diff under a single acquisition
    (``asyncio.Lock`` is not reentrant).
    """

    __slots__ = ('app', 'async_session_factory', 'template_service', 'registered_routes', '_routes_lock')

    MAX_BODY_SIZE = 1024 * 1024  # 1 MB

    def __init__(self, app: FastAPI, async_session_factory: Callable[[], AsyncSession]):
        self.app = app
        self.async_session_factory = async_session_factory
        self.template_service = TemplateService()
        # Maps (route, METHOD) to route_id, the APIRoute name used to find the route again for removal
        self.registered_routes: Dict[Tuple[str, str], str] = {}
        # Serialises every mutation of app.routes / registered_routes
        self._routes_lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Public registration API
    # ------------------------------------------------------------------

    async def register_endpoint(self, endpoint: Endpoint) -> bool:
        """
        Register a single endpoint as a route in the FastAPI app.

        If a route is already registered for the same (route, method) pair it is
        replaced by the new one.

        Args:
            endpoint: The Endpoint model instance.

        Returns:
            True if registration succeeded, False otherwise.
        """
        async with self._routes_lock:
            return self._add_route(endpoint)

    async def unregister_endpoint(self, route: str, method: str) -> bool:
        """
        Remove a previously registered route.

        Args:
            route: The endpoint route.
            method: HTTP method (case-insensitive).

        Returns:
            True if the route was found and removed from the app, False otherwise.
            The bookkeeping entry is dropped even when the app route is missing.
        """
        async with self._routes_lock:
            return self._remove_route(route, method)

    async def refresh_routes(self) -> int:
        """
        Reload all active endpoints from repository and register them.

        Removes any previously registered routes that are no longer active and
        registers active endpoints whose (route, method) is not yet registered.

        Returns:
            Number of active routes after refresh.
        """
        # Fetch active endpoints using a fresh session
        # NOTE(review): the Endpoint objects outlive this session and are read later
        # by request handlers; assumes every attribute they use is already loaded — confirm.
        async with self.async_session_factory() as session:
            repository = EndpointRepository(session)
            active_endpoints = await repository.get_active()

        active_keys = {(e.route, e.method.upper()) for e in active_endpoints}

        # Compute and apply the diff under one lock acquisition so a concurrent
        # refresh/register cannot interleave between the two steps.
        async with self._routes_lock:
            stale_keys = [key for key in self.registered_routes if key not in active_keys]
            for route, method in stale_keys:
                self._remove_route(route, method)

            # NOTE: only keys are compared, so an active endpoint that is already
            # registered keeps its existing handler even if other fields changed.
            for endpoint in active_endpoints:
                if (endpoint.route, endpoint.method.upper()) not in self.registered_routes:
                    self._add_route(endpoint)

            total = len(self.registered_routes)

        logger.info("Routes refreshed. Total active routes: %s", total)
        return total

    # ------------------------------------------------------------------
    # Lock-free helpers (caller must hold _routes_lock)
    # ------------------------------------------------------------------

    def _add_route(self, endpoint: Endpoint) -> bool:
        """
        Add an APIRoute for *endpoint*, replacing any route already registered for
        the same (route, method). Caller must hold ``_routes_lock``.

        The new route is added before the old one is removed, so a failed
        registration leaves the previous handler in place.

        Returns:
            True if registration succeeded, False otherwise.
        """
        try:
            method = endpoint.method.upper()
            # Unique name so this exact route can be located again for removal
            route_id = f"{method}_{endpoint.route}_{uuid4().hex[:8]}"

            # Closure binds this endpoint's configuration to the handler
            async def endpoint_handler(request: Request) -> Response:
                return await self._handle_request(request, endpoint)

            self.app.add_api_route(
                endpoint.route,
                endpoint_handler,
                methods=[method],
                name=route_id,
                response_model=None,  # handler returns a raw Response
            )
        except (ValueError, RuntimeError, TypeError, AttributeError) as e:
            logger.error("Failed to register endpoint %s: %s", endpoint, e, exc_info=settings.debug)
            return False

        key = (endpoint.route, method)
        previous_id = self.registered_routes.get(key)
        if previous_id is not None:
            # Otherwise the old route would stay earlier in app.routes, keep matching
            # requests, and have no id left to remove it by.
            self._remove_route_by_id(previous_id)
        self.registered_routes[key] = route_id
        self._invalidate_openapi_schema()
        logger.info("Registered endpoint %s %s", method, endpoint.route)
        return True

    def _remove_route(self, route: str, method: str) -> bool:
        """
        Unregister (route, method). Caller must hold ``_routes_lock``.

        Returns:
            True if the APIRoute was found and removed, False otherwise.
        """
        method = method.upper()
        route_id = self.registered_routes.pop((route, method), None)
        if route_id is None:
            logger.warning("Route %s %s not registered", method, route)
            return False

        found = self._remove_route_by_id(route_id)
        if found:
            logger.info("Unregistered endpoint %s %s", method, route)
        else:
            logger.warning("Route with ID %s not found in FastAPI routes", route_id)
        return found

    def _remove_route_by_id(self, route_id: str) -> bool:
        """Delete the APIRoute named *route_id* from the app; return whether it existed."""
        routes = self.app.routes
        for index, candidate in enumerate(routes):
            if isinstance(candidate, APIRoute) and candidate.name == route_id:
                # Safe: we return immediately after mutating the list
                del routes[index]
                self._invalidate_openapi_schema()
                return True
        return False

    def _invalidate_openapi_schema(self) -> None:
        """Drop FastAPI's cached OpenAPI schema so it is rebuilt with the current routes."""
        self.app.openapi_schema = None

    # ------------------------------------------------------------------
    # Request handling
    # ------------------------------------------------------------------

    async def _handle_request(self, request: Request, endpoint: Endpoint) -> Response:
        """
        Generic request handler for a registered endpoint.

        Args:
            request: FastAPI Request object.
            endpoint: Endpoint configuration.

        Returns:
            FastAPI Response object; a JSON 500 response if template rendering fails.

        Raises:
            HTTPException: propagated from OAuth validation or body extraction.
        """
        # OAuth2 token validation if endpoint requires it
        if endpoint.requires_oauth:
            await self._validate_oauth_token(request, endpoint)

        # Apply artificial delay if configured (a missing delay is treated as none)
        delay_ms = endpoint.delay_ms or 0
        if delay_ms > 0:
            await asyncio.sleep(delay_ms / 1000.0)

        context = await self._build_template_context(request, endpoint)

        try:
            rendered_body = self.template_service.render(endpoint.response_body, context)
        except jinja2.TemplateError as e:
            logger.error(
                "Template rendering failed for endpoint %s: %s", endpoint.id, e, exc_info=settings.debug
            )
            return Response(
                content=json.dumps({"error": "Template rendering failed"}),
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                media_type="application/json",
            )

        return Response(
            content=rendered_body,
            status_code=endpoint.response_code,
            headers=dict(endpoint.headers or {}),
            media_type=endpoint.content_type,
        )

    async def _build_template_context(self, request: Request, endpoint: Endpoint) -> Dict[str, Any]:
        """
        Build the template context from all variable sources.

        Sources, in the order applied (later sources win for unprefixed names):
            - Path parameters (``path_<name>`` and bare ``<name>``)
            - Query parameters (``query_<name>`` and bare)
            - Request headers, lower-cased (``header_<name>`` and bare)
            - Request body: a JSON object is spread as ``body_<key>`` and bare keys;
              any other body is stored under ``body``
            - System variables (timestamp, request_id, etc.)
            - Endpoint variables

        Args:
            request: FastAPI Request object.
            endpoint: Endpoint configuration.

        Returns:
            Dictionary of template variables.
        """
        context: Dict[str, Any] = {}

        path_params = dict(request.path_params)
        context.update({f"path_{k}": v for k, v in path_params.items()})
        context.update(path_params)

        query_params = dict(request.query_params)
        context.update({f"query_{k}": v for k, v in query_params.items()})
        context.update(query_params)

        headers = dict(request.headers)
        context.update({f"header_{k.lower()}": v for k, v in headers.items()})
        context.update({k.lower(): v for k, v in headers.items()})

        body = await self._extract_request_body(request)
        if body is not None:
            if isinstance(body, dict):
                context.update({f"body_{k}": v for k, v in body.items()})
                context.update(body)
            else:
                context["body"] = body

        context.update(self._get_system_variables(request))

        # NOTE(review): applied last, so these override request-supplied values of
        # the same name; if they are meant as fallback defaults they should be
        # applied first — confirm intent.
        if endpoint.variables:
            context.update(endpoint.variables)

        return context

    async def _extract_request_body(self, request: Request) -> Optional[Any]:
        """
        Extract request body as JSON if possible, otherwise as text.

        The request stream is consumed here.

        Returns:
            Parsed JSON (dict/list/...) or the body decoded as UTF-8 text (invalid
            bytes dropped), or None if there is no body.

        Raises:
            HTTPException: 413 if the declared or actual body size exceeds MAX_BODY_SIZE.
        """
        # Reject early on an oversized declared length; a malformed header is
        # ignored here and the streamed size check below still applies.
        content_length = request.headers.get("content-length")
        if content_length:
            try:
                declared_size: Optional[int] = int(content_length)
            except ValueError:
                declared_size = None
            if declared_size is not None and declared_size > self.MAX_BODY_SIZE:
                raise HTTPException(status_code=413, detail="Request body too large")

        # Stream and stop as soon as the limit is exceeded: request.body() would
        # buffer the whole payload first, which a chunked upload without a
        # Content-Length header could use to exhaust memory.
        buffer = bytearray()
        async for chunk in request.stream():
            buffer.extend(chunk)
            if len(buffer) > self.MAX_BODY_SIZE:
                raise HTTPException(status_code=413, detail="Request body too large")

        if not buffer:
            return None
        body_bytes = bytes(buffer)

        # Media types are case-insensitive
        content_type = request.headers.get("content-type", "").lower()
        if "application/json" in content_type:
            try:
                return json.loads(body_bytes.decode("utf-8"))
            except (UnicodeDecodeError, json.JSONDecodeError):
                # Not UTF-8 or not valid JSON: fall back to raw text below
                pass

        return body_bytes.decode("utf-8", errors="ignore")

    async def _validate_oauth_token(self, request: Request, endpoint: Endpoint) -> Dict[str, Any]:
        """
        Validate OAuth2 Bearer token for endpoints that require authentication.

        Args:
            request: FastAPI Request object.
            endpoint: Endpoint configuration.

        Returns:
            Validated token payload.

        Raises:
            HTTPException: 401 for a missing or malformed Authorization header,
                403 for insufficient scope, 500 for unexpected verification errors;
                HTTPExceptions raised by TokenService.verify_token propagate unchanged.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            logger.warning("OAuth2 token missing for endpoint %s %s", endpoint.method, endpoint.route)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing Authorization header",
                headers={"WWW-Authenticate": "Bearer"},
            )

        # Expect exactly "<scheme> <token>" with a case-insensitive "Bearer" scheme
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid Authorization header format. Expected: Bearer <token>",
                headers={"WWW-Authenticate": "Bearer error=\"invalid_token\""},
            )
        token = parts[1]

        async with self.async_session_factory() as session:
            token_service = TokenService(session)
            try:
                payload = await token_service.verify_token(token)
            except HTTPException:
                raise  # Re-raise token validation errors
            except Exception as e:
                logger.error("Unexpected error during token validation: %s", e, exc_info=settings.debug)
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Internal server error",
                ) from e

            # Check scopes if endpoint specifies required scopes
            if endpoint.oauth_scopes:
                scope_service = ScopeService(session)
                token_scopes = payload.get("scopes", [])
                if not scope_service.check_scope_access(token_scopes, endpoint.oauth_scopes):
                    logger.warning(
                        "Insufficient scopes for endpoint %s %s. Token scopes: %s, required: %s",
                        endpoint.method,
                        endpoint.route,
                        token_scopes,
                        endpoint.oauth_scopes,
                    )
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="Insufficient scope",
                        headers={
                            "WWW-Authenticate": f"Bearer error=\"insufficient_scope\", scope=\"{' '.join(endpoint.oauth_scopes)}\""
                        },
                    )

        logger.debug(
            "OAuth2 token validated for endpoint %s %s, client_id: %s, scopes: %s",
            endpoint.method,
            endpoint.route,
            payload.get('client_id'),
            payload.get('scopes'),
        )
        return payload

    def _get_system_variables(self, request: Request) -> Dict[str, Any]:
        """
        Generate system variables (e.g., timestamp, request ID).

        Returns:
            Dictionary with ``timestamp`` (epoch seconds, float), ``datetime``
            (local time, "%Y-%m-%d %H:%M:%S"), a fresh ``request_id`` UUID,
            ``method``, ``url`` and ``client_host`` (None if unknown).
        """
        return {
            "timestamp": time.time(),
            "datetime": time.strftime("%Y-%m-%d %H:%M:%S"),
            "request_id": str(uuid4()),
            "method": request.method,
            "url": str(request.url),
            "client_host": request.client.host if request.client else None,
        }