124 lines
No EOL
5 KiB
Python
124 lines
No EOL
5 KiB
Python
import json
|
|
from typing import Optional, Dict, Any
|
|
from datetime import datetime
|
|
from pydantic import BaseModel, Field, field_validator, ConfigDict, Json
|
|
|
|
|
|
# Accepted values for an endpoint's ``method`` field; the validators compare
# against this set after upper-casing the incoming value.
HTTP_METHODS = {
    "GET",
    "POST",
    "PUT",
    "DELETE",
    "PATCH",
    "HEAD",
    "OPTIONS",
    "TRACE",
}
|
|
|
|
|
|
class EndpointBase(BaseModel):
    """Base schema with the fields common to endpoint creation and responses.

    ``route`` is checked for an absolute, traversal-free form; ``method`` is
    upper-cased and checked against ``HTTP_METHODS``; ``variables`` and
    ``headers`` must be encodable with ``json.dumps``.
    """

    route: str = Field(..., description="Endpoint route (must start with '/')", max_length=500)
    method: str = Field(..., description="HTTP method", max_length=10)
    response_body: str = Field(..., description="Response body (supports Jinja2 templating)")
    response_code: int = Field(200, description="HTTP status code", ge=100, le=599)
    content_type: str = Field("application/json", description="Content-Type header", max_length=100)
    is_active: bool = Field(True, description="Whether endpoint is active")
    variables: Dict[str, Any] = Field(default_factory=dict, description="Default template variables")
    headers: Dict[str, str] = Field(default_factory=dict, description="Custom response headers")
    delay_ms: int = Field(0, description="Artificial delay in milliseconds", ge=0, le=30000)

    @field_validator("route")
    @classmethod
    def route_must_start_with_slash(cls, v: str) -> str:
        """Validate the endpoint route and return it unchanged.

        Raises:
            ValueError: if ``v`` does not start with '/', or contains '..',
                '//' or a backslash.
        """
        if not v.startswith("/"):
            raise ValueError("Route must start with '/'")
        # Prevent path traversal
        if ".." in v:
            raise ValueError("Route must not contain '..'")
        # Prevent consecutive slashes (simplifies routing)
        if "//" in v:
            raise ValueError("Route must not contain consecutive slashes '//'")
        # Prevent backslashes
        if "\\" in v:
            raise ValueError("Route must not contain backslashes")
        # A bare "/" and trailing slashes are both accepted.
        return v

    @field_validator("method")
    @classmethod
    def method_must_be_valid(cls, v: str) -> str:
        """Return ``v`` upper-cased.

        Raises:
            ValueError: if the upper-cased value is not in ``HTTP_METHODS``.
        """
        method = v.upper()
        if method not in HTTP_METHODS:
            # Sorted so the message is stable: iterating a set of str follows
            # hash order, which changes between interpreter runs.
            allowed = ", ".join(sorted(HTTP_METHODS))
            raise ValueError(f"Method must be one of {allowed}")
        return method

    @field_validator("variables", "headers")
    @classmethod
    def validate_json_serializable(cls, v: Dict[str, Any]) -> Dict[str, Any]:
        """Return ``v`` unchanged if ``json.dumps`` can encode it.

        Raises:
            ValueError: if ``json.dumps`` raises TypeError (unsupported type)
                or ValueError (e.g. a circular reference).
        """
        try:
            json.dumps(v)
        except (TypeError, ValueError) as e:
            raise ValueError(f"Value must be JSON serializable: {e}") from e
        return v
|
|
|
|
|
|
class EndpointCreate(EndpointBase):
    """Request schema used when a new endpoint is created.

    It adds nothing of its own: all fields, defaults and validators come
    from EndpointBase.
    """
|
|
|
|
|
|
class EndpointUpdate(BaseModel):
    """Schema for updating an existing endpoint (all fields optional).

    Each field mirrors EndpointBase but defaults to None. The validators
    apply the same checks as EndpointBase and pass None through unchanged.

    NOTE(review): an explicit JSON ``null`` is accepted for every field,
    including ones EndpointBase requires (route, method, response_body).
    Confirm that the code applying updates treats None as "not provided"
    rather than writing it through.
    """

    route: Optional[str] = Field(None, description="Endpoint route (must start with '/')", max_length=500)
    method: Optional[str] = Field(None, description="HTTP method", max_length=10)
    response_body: Optional[str] = Field(None, description="Response body (supports Jinja2 templating)")
    response_code: Optional[int] = Field(None, description="HTTP status code", ge=100, le=599)
    content_type: Optional[str] = Field(None, description="Content-Type header", max_length=100)
    is_active: Optional[bool] = Field(None, description="Whether endpoint is active")
    variables: Optional[Dict[str, Any]] = Field(None, description="Default template variables")
    headers: Optional[Dict[str, str]] = Field(None, description="Custom response headers")
    delay_ms: Optional[int] = Field(None, description="Artificial delay in milliseconds", ge=0, le=30000)

    @field_validator("route")
    @classmethod
    def route_must_start_with_slash(cls, v: Optional[str]) -> Optional[str]:
        """Validate the route if given; None is returned unchanged.

        Raises:
            ValueError: if ``v`` does not start with '/', or contains '..',
                '//' or a backslash.
        """
        if v is None:
            return v
        if not v.startswith("/"):
            raise ValueError("Route must start with '/'")
        # Prevent path traversal
        if ".." in v:
            raise ValueError("Route must not contain '..'")
        # Prevent consecutive slashes (simplifies routing)
        if "//" in v:
            raise ValueError("Route must not contain consecutive slashes '//'")
        # Prevent backslashes
        if "\\" in v:
            raise ValueError("Route must not contain backslashes")
        # A bare "/" and trailing slashes are both accepted.
        return v

    @field_validator("method")
    @classmethod
    def method_must_be_valid(cls, v: Optional[str]) -> Optional[str]:
        """Return ``v`` upper-cased, or None if ``v`` is None.

        Raises:
            ValueError: if the upper-cased value is not in ``HTTP_METHODS``.
        """
        if v is None:
            return v
        method = v.upper()
        if method not in HTTP_METHODS:
            # Sorted so the message is stable: iterating a set of str follows
            # hash order, which changes between interpreter runs.
            allowed = ", ".join(sorted(HTTP_METHODS))
            raise ValueError(f"Method must be one of {allowed}")
        return method

    @field_validator("variables", "headers")
    @classmethod
    def validate_json_serializable(
        cls, v: Optional[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Return ``v`` unchanged if it is None or ``json.dumps`` can encode it.

        Raises:
            ValueError: if ``json.dumps`` raises TypeError (unsupported type)
                or ValueError (e.g. a circular reference).
        """
        if v is None:
            return v
        try:
            json.dumps(v)
        except (TypeError, ValueError) as e:
            raise ValueError(f"Value must be JSON serializable: {e}") from e
        return v
|
|
|
|
|
|
class EndpointResponse(EndpointBase):
    """Endpoint as returned to clients: the base fields plus ID and timestamps."""

    # from_attributes lets the model be built from attribute-bearing objects
    # (e.g. ORM rows); it is the Pydantic v2 name for the old `orm_mode`.
    model_config = ConfigDict(from_attributes=True)

    id: int
    created_at: datetime
    updated_at: datetime