mockapi/services/route_service.py
2026-03-16 05:47:01 +00:00

370 lines
No EOL
14 KiB
Python

import asyncio
import json
import logging
import time
from typing import Dict, Any, Optional, Tuple, Callable
from uuid import uuid4
import jinja2
from fastapi import FastAPI, Request, Response, status, HTTPException
from fastapi.routing import APIRoute
from sqlalchemy.ext.asyncio import AsyncSession
from config import settings
from models.endpoint_model import Endpoint
from repositories.endpoint_repository import EndpointRepository
from services.template_service import TemplateService
from oauth2.services import TokenService, ScopeService
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class RouteManager:
    """
    Manages dynamic route registration and removal for the FastAPI application.

    Every registered endpoint is tracked in ``registered_routes`` under the key
    ``(route, METHOD)``. The stored value is the unique route name handed to
    FastAPI, which is what is used later to locate the route for removal.
    Changes to the route table are serialized with an ``asyncio.Lock``.
    """

    __slots__ = ('app', 'async_session_factory', 'template_service', 'registered_routes', '_routes_lock')

    MAX_BODY_SIZE = 1024 * 1024  # 1 MB

    def __init__(self, app: FastAPI, async_session_factory: Callable[[], AsyncSession]):
        """
        Args:
            app: The FastAPI application whose routes are managed.
            async_session_factory: Zero-argument callable whose result is used
                as an async context manager yielding a database session.
        """
        self.app = app
        self.async_session_factory = async_session_factory
        self.template_service = TemplateService()
        # Maps (route, method) to route_id (used by FastAPI for removal)
        self.registered_routes: Dict[Tuple[str, str], str] = {}
        self._routes_lock = asyncio.Lock()

    def _remove_route(self, route_id: str) -> bool:
        """
        Remove the APIRoute named ``route_id`` from the app's route list.

        This helper does not acquire ``_routes_lock``; callers are expected to
        hold it already (asyncio.Lock is not re-entrant).

        Args:
            route_id: The route name that was passed to ``add_api_route``.
        Returns:
            True if a matching route was found and removed, False otherwise.
        """
        # Iterate over a copy so removing from the live list is safe.
        for candidate in list(self.app.routes):
            if isinstance(candidate, APIRoute) and candidate.name == route_id:
                self.app.routes.remove(candidate)
                return True
        return False

    async def register_endpoint(self, endpoint: Endpoint) -> bool:
        """
        Register a single endpoint as a route in the FastAPI app.

        If a route is already registered for the same (route, method), the new
        route is added first and the stale one is then removed, so the app never
        keeps an orphaned route that can no longer be unregistered.

        Args:
            endpoint: The Endpoint model instance.
        Returns:
            True if registration succeeded, False otherwise.
        """
        async with self._routes_lock:
            try:
                method = endpoint.method.upper()
                key = (endpoint.route, method)
                # Create a unique route identifier for FastAPI
                route_id = f"{method}_{endpoint.route}_{uuid4().hex[:8]}"

                # Handler closure: captures this endpoint's configuration.
                async def endpoint_handler(request: Request) -> Response:
                    return await self._handle_request(request, endpoint)

                self.app.add_api_route(
                    endpoint.route,
                    endpoint_handler,
                    methods=[method],
                    name=route_id,
                    response_model=None,  # We'll return raw Response
                )

                # Only drop the previous route once the new one is in place, so
                # a failed add_api_route leaves the old registration intact.
                stale_route_id = self.registered_routes.get(key)
                if stale_route_id is not None:
                    self._remove_route(stale_route_id)

                self.registered_routes[key] = route_id
                logger.info(f"Registered endpoint {method} {endpoint.route}")
                return True
            except (ValueError, RuntimeError, TypeError, AttributeError) as e:
                logger.error(f"Failed to register endpoint {endpoint}: {e}", exc_info=settings.debug)
                return False

    async def unregister_endpoint(self, route: str, method: str) -> bool:
        """
        Remove a previously registered route.

        The bookkeeping entry is always dropped, even when the route itself can
        no longer be found in the app.

        Args:
            route: The endpoint route.
            method: HTTP method (case-insensitive).
        Returns:
            True if the route was found in the app and removed, False otherwise.
        """
        async with self._routes_lock:
            method = method.upper()
            route_id = self.registered_routes.pop((route, method), None)
            if route_id is None:
                logger.warning(f"Route {method} {route} not registered")
                return False

            found = self._remove_route(route_id)
            if found:
                logger.info(f"Unregistered endpoint {method} {route}")
            else:
                logger.warning(f"Route with ID {route_id} not found in FastAPI routes")
            return found

    async def refresh_routes(self) -> int:
        """
        Reload all active endpoints from repository and register them.

        Removes any previously registered routes that are no longer active.
        Endpoints whose (route, method) is already registered are left as they
        are; they are not re-registered by this method.

        Returns:
            Number of registered routes after refresh.
        """
        # Fetch active endpoints using a fresh session
        async with self.async_session_factory() as session:
            repository = EndpointRepository(session)
            active_endpoints = await repository.get_active()
            active_keys = {(e.route, e.method.upper()) for e in active_endpoints}

        # Take a consistent snapshot of what has to change while holding the lock.
        async with self._routes_lock:
            to_unregister = [
                key for key in self.registered_routes if key not in active_keys
            ]
            to_register = [
                endpoint
                for endpoint in active_endpoints
                if (endpoint.route, endpoint.method.upper()) not in self.registered_routes
            ]

        # Apply the changes without holding the lock: each call below acquires
        # it itself, and asyncio.Lock is not re-entrant.
        for route, method in to_unregister:
            await self.unregister_endpoint(route, method)
        for endpoint in to_register:
            # Failures are logged inside register_endpoint.
            await self.register_endpoint(endpoint)

        logger.info(f"Routes refreshed. Total active routes: {len(self.registered_routes)}")
        return len(self.registered_routes)

    async def _handle_request(self, request: Request, endpoint: Endpoint) -> Response:
        """
        Generic request handler for a registered endpoint.

        Steps: optional OAuth2 validation, optional artificial delay, template
        context construction, template rendering, response assembly.

        Args:
            request: FastAPI Request object.
            endpoint: Endpoint configuration.
        Returns:
            FastAPI Response object. If template rendering fails, a 500 JSON
            error response is returned instead of the configured one.
        Raises:
            HTTPException: Propagated from OAuth2 validation or body extraction.
        """
        # OAuth2 token validation if endpoint requires it
        if endpoint.requires_oauth:
            await self._validate_oauth_token(request, endpoint)

        # Apply artificial delay if configured; a missing (None) delay means none.
        delay_ms = endpoint.delay_ms or 0
        if delay_ms > 0:
            await asyncio.sleep(delay_ms / 1000.0)

        # Gather variable sources
        context = await self._build_template_context(request, endpoint)

        try:
            # Render response body using Jinja2 template
            rendered_body = self.template_service.render(endpoint.response_body, context)
        except jinja2.TemplateError as e:
            logger.error(f"Template rendering failed for endpoint {endpoint.id}: {e}", exc_info=settings.debug)
            return Response(
                content=json.dumps({"error": "Template rendering failed"}),
                status_code=500,
                media_type="application/json",
            )

        # Build response with custom headers (copied so the endpoint's own
        # mapping is never handed to the Response).
        return Response(
            content=rendered_body,
            status_code=endpoint.response_code,
            headers=dict(endpoint.headers or {}),
            media_type=endpoint.content_type,
        )

    async def _build_template_context(self, request: Request, endpoint: Endpoint) -> Dict[str, Any]:
        """
        Build the template context from all variable sources.

        Sources, in the order they are applied (later sources overwrite earlier
        ones on key collisions):
            - Path parameters (from request.path_params)
            - Query parameters (from request.query_params)
            - Request headers (keys lower-cased)
            - Request body (JSON or raw text)
            - System variables (timestamp, request_id, etc.)
            - Endpoint default variables

        Path, query, header and JSON-object body values are exposed twice: under
        their plain name and under a ``path_`` / ``query_`` / ``header_`` /
        ``body_`` prefixed name. A non-dict body is exposed as ``body``.

        Args:
            request: FastAPI Request object.
            endpoint: Endpoint configuration.
        Returns:
            Dictionary of template variables.
        """
        context: Dict[str, Any] = {}

        # Path parameters
        context.update({f"path_{k}": v for k, v in request.path_params.items()})
        context.update(request.path_params)

        # Query parameters
        query_params = dict(request.query_params)
        context.update({f"query_{k}": v for k, v in query_params.items()})
        context.update(query_params)

        # Request headers
        headers = dict(request.headers)
        context.update({f"header_{k.lower()}": v for k, v in headers.items()})
        context.update({k.lower(): v for k, v in headers.items()})

        # Request body
        body = await self._extract_request_body(request)
        if body is not None:
            if isinstance(body, dict):
                context.update({f"body_{k}": v for k, v in body.items()})
                context.update(body)
            else:
                context["body"] = body

        # System variables
        context.update(self._get_system_variables(request))

        # Endpoint default variables
        # NOTE(review): these are applied last and therefore override request
        # values with the same name - confirm that is the intended precedence.
        if endpoint.variables:
            context.update(endpoint.variables)

        return context

    async def _extract_request_body(self, request: Request) -> Optional[Any]:
        """
        Extract request body as JSON if possible, otherwise as text.

        The declared Content-Length is checked before reading; the actual size
        is checked again after the body has been read.

        Returns:
            Parsed JSON value when the content type contains
            ``application/json`` and the body parses, otherwise the body decoded
            as UTF-8 text (undecodable bytes dropped), or None if no body.
        Raises:
            HTTPException: 413 if the declared or actual body size exceeds
                MAX_BODY_SIZE.
        """
        # Check content-length header
        content_length = request.headers.get("content-length")
        if content_length:
            try:
                declared_size = int(content_length)
            except ValueError:
                declared_size = None  # Ignore malformed content-length
            if declared_size is not None and declared_size > self.MAX_BODY_SIZE:
                raise HTTPException(status_code=413, detail="Request body too large")

        content_type = request.headers.get("content-type", "")

        # Read body bytes once
        body_bytes = await request.body()
        if not body_bytes:
            return None

        # Check actual body size
        if len(body_bytes) > self.MAX_BODY_SIZE:
            raise HTTPException(status_code=413, detail="Request body too large")

        if "application/json" in content_type:
            try:
                return json.loads(body_bytes.decode("utf-8"))
            except ValueError:
                # Covers both json.JSONDecodeError and UnicodeDecodeError
                # (each is a ValueError subclass): fall back to raw text.
                pass

        # Return raw body as string
        return body_bytes.decode("utf-8", errors="ignore")

    async def _validate_oauth_token(self, request: Request, endpoint: Endpoint) -> Dict[str, Any]:
        """
        Validate OAuth2 Bearer token for endpoints that require authentication.

        Args:
            request: FastAPI Request object.
            endpoint: Endpoint configuration.
        Returns:
            Validated token payload.
        Raises:
            HTTPException: 401 for a missing or malformed Authorization header,
                403 for insufficient scopes, 500 for unexpected errors during
                verification. HTTPExceptions raised by the token service are
                propagated unchanged.
        """
        # Extract Bearer token from Authorization header
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            logger.warning(f"OAuth2 token missing for endpoint {endpoint.method} {endpoint.route}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing Authorization header",
                headers={"WWW-Authenticate": "Bearer"},
            )

        # Check Bearer scheme
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid Authorization header format. Expected: Bearer <token>",
                headers={"WWW-Authenticate": "Bearer error=\"invalid_token\""},
            )
        token = parts[1]

        # Create a database session and validate token
        async with self.async_session_factory() as session:
            token_service = TokenService(session)
            try:
                payload = await token_service.verify_token(token)
            except HTTPException:
                raise  # Re-raise token validation errors
            except Exception as e:
                logger.error(f"Unexpected error during token validation: {e}", exc_info=settings.debug)
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Internal server error",
                ) from e

            # Check scopes if endpoint specifies required scopes
            if endpoint.oauth_scopes:
                scope_service = ScopeService(session)
                token_scopes = payload.get("scopes", [])
                if not scope_service.check_scope_access(token_scopes, endpoint.oauth_scopes):
                    logger.warning(
                        f"Insufficient scopes for endpoint {endpoint.method} {endpoint.route}. "
                        f"Token scopes: {token_scopes}, required: {endpoint.oauth_scopes}"
                    )
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="Insufficient scope",
                        headers={"WWW-Authenticate": f"Bearer error=\"insufficient_scope\", scope=\"{' '.join(endpoint.oauth_scopes)}\""},
                    )

        logger.debug(
            f"OAuth2 token validated for endpoint {endpoint.method} {endpoint.route}, "
            f"client_id: {payload.get('client_id')}, scopes: {payload.get('scopes')}"
        )
        return payload

    def _get_system_variables(self, request: Request) -> Dict[str, Any]:
        """
        Generate system variables (e.g., timestamp, request ID).

        ``timestamp`` and ``datetime`` are derived from a single clock reading,
        so they always describe the same instant. ``datetime`` is formatted in
        local time.

        Returns:
            Dictionary of system variables.
        """
        now = time.time()
        return {
            "timestamp": now,
            "datetime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(now)),
            "request_id": str(uuid4()),
            "method": request.method,
            "url": str(request.url),
            "client_host": request.client.host if request.client else None,
        }