370 lines
No EOL
14 KiB
Python
370 lines
No EOL
14 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
from typing import Dict, Any, Optional, Tuple, Callable
|
|
from uuid import uuid4
|
|
import jinja2
|
|
|
|
from fastapi import FastAPI, Request, Response, status, HTTPException
|
|
from fastapi.routing import APIRoute
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from app.core.config import settings
|
|
|
|
from app.modules.endpoints.models.endpoint_model import Endpoint
|
|
from app.modules.endpoints.repositories.endpoint_repository import EndpointRepository
|
|
from app.modules.endpoints.services.template_service import TemplateService
|
|
from app.modules.oauth2.services import TokenService, ScopeService
|
|
|
|
|
|
# Module logger, named after this module's dotted import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class RouteManager:
    """
    Manages dynamic route registration and removal for the FastAPI application.

    Each registered ``Endpoint`` becomes a FastAPI route whose handler optionally
    validates an OAuth2 bearer token, applies an artificial delay, renders the
    endpoint's response template and returns it as a raw ``Response``.

    Every change to ``registered_routes`` and to the application's route table is
    serialized through ``_routes_lock``. ``asyncio.Lock`` is not re-entrant, so
    code that already holds the lock calls the ``_*_locked`` helpers instead of
    the public coroutines.
    """

    __slots__ = ('app', 'async_session_factory', 'template_service', 'registered_routes', '_routes_lock')

    # Maximum accepted request body size in bytes (1 MB); larger bodies get HTTP 413.
    MAX_BODY_SIZE = 1024 * 1024

    def __init__(self, app: FastAPI, async_session_factory: Callable[[], AsyncSession]):
        """
        Args:
            app: The FastAPI application whose route table is managed.
            async_session_factory: Zero-argument callable returning a new session,
                used as ``async with async_session_factory() as session``.
        """
        self.app = app
        self.async_session_factory = async_session_factory
        self.template_service = TemplateService()
        # Maps (route, METHOD) to the FastAPI route name (route_id) used for removal.
        self.registered_routes: Dict[Tuple[str, str], str] = {}
        # Serializes all mutations of registered_routes and app.routes.
        self._routes_lock = asyncio.Lock()

    async def register_endpoint(self, endpoint: Endpoint) -> bool:
        """
        Register a single endpoint as a route in the FastAPI app.

        If this manager already registered a route for the same (route, method)
        pair, that route is replaced, so the app never holds two handlers for one
        pair and the previous route is not left orphaned.

        Args:
            endpoint: The Endpoint model instance.

        Returns:
            True if registration succeeded, False otherwise.
        """
        async with self._routes_lock:
            return self._register_locked(endpoint)

    def _register_locked(self, endpoint: Endpoint) -> bool:
        """Register (or replace) ``endpoint``'s route; the caller must hold ``_routes_lock``."""
        try:
            method = endpoint.method.upper()
            key = (endpoint.route, method)
            # Unique FastAPI route name so this exact route can be located and removed later.
            route_id = f"{method}_{endpoint.route}_{uuid4().hex[:8]}"

            # Closure binds this endpoint's configuration to its handler.
            async def endpoint_handler(request: Request) -> Response:
                return await self._handle_request(request, endpoint)

            self.app.add_api_route(
                endpoint.route,
                endpoint_handler,
                methods=[method],
                name=route_id,
                response_model=None,  # The handler returns a raw Response.
            )
        except (ValueError, RuntimeError, TypeError, AttributeError) as e:
            logger.error(f"Failed to register endpoint {endpoint}: {e}", exc_info=settings.debug)
            return False

        # Remove the previous route for this pair only after the new one was added,
        # so a failed registration leaves the old route serving. There is no await
        # between the add and the removal, so no request can see both routes.
        previous_id = self.registered_routes.get(key)
        if previous_id is not None:
            self._remove_app_route(previous_id)

        self.registered_routes[key] = route_id
        self._invalidate_openapi_schema()
        logger.info(f"Registered endpoint {method} {endpoint.route}")
        return True

    async def unregister_endpoint(self, route: str, method: str) -> bool:
        """
        Remove a previously registered route.

        Args:
            route: The endpoint route.
            method: HTTP method (case-insensitive).

        Returns:
            True if removal succeeded, False otherwise.
        """
        async with self._routes_lock:
            return self._unregister_locked(route, method)

    def _unregister_locked(self, route: str, method: str) -> bool:
        """Remove the route for (route, method); the caller must hold ``_routes_lock``."""
        method = method.upper()
        key = (route, method)

        # Always drop the mapping (cleanup), even if the app route has vanished.
        route_id = self.registered_routes.pop(key, None)
        if route_id is None:
            logger.warning(f"Route {method} {route} not registered")
            return False

        found = self._remove_app_route(route_id)
        if found:
            self._invalidate_openapi_schema()
            logger.info(f"Unregistered endpoint {method} {route}")
        else:
            logger.warning(f"Route with ID {route_id} not found in FastAPI routes")
        return found

    def _remove_app_route(self, route_id: str) -> bool:
        """Delete the APIRoute named ``route_id`` from the app; return whether it was found."""
        routes = self.app.routes
        for index, candidate in enumerate(routes):
            if isinstance(candidate, APIRoute) and candidate.name == route_id:
                del routes[index]
                return True
        return False

    def _invalidate_openapi_schema(self) -> None:
        """Drop FastAPI's cached OpenAPI schema so the docs reflect the current routes."""
        # FastAPI.openapi() regenerates the schema only when openapi_schema is unset.
        if getattr(self.app, "openapi_schema", None) is not None:
            self.app.openapi_schema = None

    async def refresh_routes(self) -> int:
        """
        Reload all active endpoints from the repository and register them.

        Removes any previously registered routes that are no longer active and
        registers active endpoints that have no route yet. Pairs that are already
        registered keep their current handler; call ``register_endpoint`` to apply
        configuration changes to them.

        The diff and its application happen under one hold of ``_routes_lock``, so
        concurrent refreshes cannot both register the same endpoint.

        Returns:
            Number of routes registered by this manager after the refresh.
        """
        # Fetch active endpoints using a fresh session.
        async with self.async_session_factory() as session:
            repository = EndpointRepository(session)
            active_endpoints = await repository.get_active()

        active_keys = {(e.route, e.method.upper()) for e in active_endpoints}

        async with self._routes_lock:
            # Snapshot the keys first: unregistering mutates registered_routes.
            stale_keys = [key for key in self.registered_routes if key not in active_keys]
            for route, method in stale_keys:
                self._unregister_locked(route, method)

            pending = [
                endpoint
                for endpoint in active_endpoints
                if (endpoint.route, endpoint.method.upper()) not in self.registered_routes
            ]
            failures = 0
            for endpoint in pending:
                if not self._register_locked(endpoint):
                    failures += 1

            total = len(self.registered_routes)

        if failures:
            logger.warning(f"{failures} active endpoint(s) failed to register during refresh")
        logger.info(f"Routes refreshed. Total active routes: {total}")
        return total

    async def _handle_request(self, request: Request, endpoint: Endpoint) -> Response:
        """
        Generic request handler for a registered endpoint.

        Args:
            request: FastAPI Request object.
            endpoint: Endpoint configuration.

        Returns:
            FastAPI Response object; a JSON 500 response if template rendering fails.

        Raises:
            HTTPException: From OAuth2 validation (401/403/500) or body extraction (413).
        """
        # OAuth2 token validation if the endpoint requires it.
        if endpoint.requires_oauth:
            await self._validate_oauth_token(request, endpoint)

        # Apply the artificial delay if configured (a missing value means no delay).
        delay_ms = endpoint.delay_ms or 0
        if delay_ms > 0:
            await asyncio.sleep(delay_ms / 1000.0)

        context = await self._build_template_context(request, endpoint)

        try:
            rendered_body = self.template_service.render(endpoint.response_body, context)
        except jinja2.TemplateError as e:
            logger.error(f"Template rendering failed for endpoint {endpoint.id}: {e}", exc_info=settings.debug)
            return Response(
                content=json.dumps({"error": "Template rendering failed"}),
                status_code=500,
                media_type="application/json",
            )

        # Copy so the endpoint's stored headers are never mutated by the response.
        return Response(
            content=rendered_body,
            status_code=endpoint.response_code,
            headers=dict(endpoint.headers or {}),
            media_type=endpoint.content_type,
        )

    async def _build_template_context(self, request: Request, endpoint: Endpoint) -> Dict[str, Any]:
        """
        Build the template context from all variable sources.

        Sources, applied in this order (later sources overwrite earlier ones on
        key collisions):
            - Path parameters (as ``path_<name>`` and bare ``<name>``)
            - Query parameters (as ``query_<name>`` and bare ``<name>``)
            - Request headers, lower-cased (as ``header_<name>`` and bare ``<name>``)
            - Request body: a JSON object is spread (``body_<key>`` and bare keys);
              any other non-empty body is stored under ``body``
            - System variables (timestamp, request_id, etc.)
            - Endpoint variables, which therefore take precedence over everything

        Args:
            request: FastAPI Request object.
            endpoint: Endpoint configuration.

        Returns:
            Dictionary of template variables.
        """
        context: Dict[str, Any] = {}

        path_params = request.path_params
        context.update({f"path_{k}": v for k, v in path_params.items()})
        context.update(path_params)

        query_params = dict(request.query_params)
        context.update({f"query_{k}": v for k, v in query_params.items()})
        context.update(query_params)

        headers = dict(request.headers)
        context.update({f"header_{k.lower()}": v for k, v in headers.items()})
        context.update({k.lower(): v for k, v in headers.items()})

        body = await self._extract_request_body(request)
        if body is not None:
            if isinstance(body, dict):
                context.update({f"body_{k}": v for k, v in body.items()})
                context.update(body)
            else:
                context["body"] = body

        context.update(self._get_system_variables(request))

        if endpoint.variables:
            context.update(endpoint.variables)

        return context

    async def _extract_request_body(self, request: Request) -> Optional[Any]:
        """
        Extract the request body as JSON if possible, otherwise as text.

        The body is streamed and rejected as soon as it exceeds ``MAX_BODY_SIZE``,
        so a request without a (truthful) Content-Length is never buffered in full.

        Returns:
            Parsed JSON (any JSON value) for JSON content types, otherwise the body
            decoded as UTF-8 with undecodable bytes dropped; None if there is no body.

        Raises:
            HTTPException: 413 if the declared or actual body size exceeds the limit.
        """
        # Fast rejection based on the declared size.
        content_length = request.headers.get("content-length")
        if content_length:
            try:
                declared_size = int(content_length)
            except ValueError:
                declared_size = None  # Malformed header; rely on the streamed check.
            if declared_size is not None and declared_size > self.MAX_BODY_SIZE:
                raise HTTPException(status_code=413, detail="Request body too large")

        # Enforce the limit on the bytes actually received.
        buffer = bytearray()
        async for chunk in request.stream():
            buffer.extend(chunk)
            if len(buffer) > self.MAX_BODY_SIZE:
                raise HTTPException(status_code=413, detail="Request body too large")

        if not buffer:
            return None
        body_bytes = bytes(buffer)

        content_type = request.headers.get("content-type", "").lower()
        media_type = content_type.split(";", 1)[0].strip()
        if "application/json" in content_type or media_type.endswith("+json"):
            try:
                return json.loads(body_bytes.decode("utf-8"))
            except (json.JSONDecodeError, UnicodeDecodeError):
                pass  # Not valid UTF-8 JSON; fall back to raw text below.

        return body_bytes.decode("utf-8", errors="ignore")

    async def _validate_oauth_token(self, request: Request, endpoint: Endpoint) -> Dict[str, Any]:
        """
        Validate the OAuth2 Bearer token for endpoints that require authentication.

        Args:
            request: FastAPI Request object.
            endpoint: Endpoint configuration.

        Returns:
            Validated token payload as returned by ``TokenService.verify_token``.

        Raises:
            HTTPException: 401 for a missing or malformed Authorization header (and any
                HTTPException raised by token verification), 403 for insufficient
                scopes, 500 for unexpected verification errors.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            logger.warning(f"OAuth2 token missing for endpoint {endpoint.method} {endpoint.route}")
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Missing Authorization header",
                headers={"WWW-Authenticate": "Bearer"},
            )

        # Expect exactly "Bearer <token>" (scheme compared case-insensitively).
        parts = auth_header.split()
        if len(parts) != 2 or parts[0].lower() != "bearer":
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid Authorization header format. Expected: Bearer <token>",
                headers={"WWW-Authenticate": "Bearer error=\"invalid_token\""},
            )

        token = parts[1]

        async with self.async_session_factory() as session:
            token_service = TokenService(session)
            try:
                payload = await token_service.verify_token(token)
            except HTTPException:
                raise  # Propagate the token service's own validation errors unchanged.
            except Exception as e:
                logger.error(f"Unexpected error during token validation: {e}", exc_info=settings.debug)
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail="Internal server error",
                ) from e

            # Check scopes only if the endpoint declares required scopes.
            if endpoint.oauth_scopes:
                scope_service = ScopeService(session)
                token_scopes = payload.get("scopes", [])
                if not scope_service.check_scope_access(token_scopes, endpoint.oauth_scopes):
                    logger.warning(
                        f"Insufficient scopes for endpoint {endpoint.method} {endpoint.route}. "
                        f"Token scopes: {token_scopes}, required: {endpoint.oauth_scopes}"
                    )
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="Insufficient scope",
                        headers={"WWW-Authenticate": f"Bearer error=\"insufficient_scope\", scope=\"{' '.join(endpoint.oauth_scopes)}\""},
                    )

        logger.debug(f"OAuth2 token validated for endpoint {endpoint.method} {endpoint.route}, client_id: {payload.get('client_id')}, scopes: {payload.get('scopes')}")
        return payload

    def _get_system_variables(self, request: Request) -> Dict[str, Any]:
        """
        Generate system variables (e.g., timestamp, request ID).

        ``timestamp`` (epoch seconds) and ``datetime`` (local time) are derived from
        a single clock reading so they always describe the same instant.

        Returns:
            Dictionary of system variables.
        """
        now = time.time()
        return {
            "timestamp": now,
            "datetime": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(now)),
            "request_id": str(uuid4()),
            "method": request.method,
            "url": str(request.url),
            "client_host": request.client.host if request.client else None,
        }