272 lines
No EOL
10 KiB
Python
272 lines
No EOL
10 KiB
Python
import json
|
|
import re
|
|
from typing import Optional, List
|
|
from datetime import datetime
|
|
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
|
from urllib.parse import urlparse
|
|
|
|
|
|
# Grant types accepted by the ``grant_types`` validators in the client schemas.
ALLOWED_GRANT_TYPES = {
    "authorization_code",
    "client_credentials",
    "password",
    "refresh_token",
}
|
|
|
|
|
|
class OAuthClientBase(BaseModel):
    """Base schema for OAuthClient with common fields.

    Subclassed by the create and response schemas, so the validators below
    run both on incoming payloads and on objects built for output.
    """

    client_id: str = Field(..., description="Unique client identifier", max_length=100)
    client_secret: str = Field(..., description="Client secret (plaintext for input, will be hashed)", max_length=255)
    name: str = Field(..., description="Human-readable client name", max_length=200)
    redirect_uris: List[str] = Field(default_factory=list, description="Allowed redirect URIs")
    grant_types: List[str] = Field(default_factory=list, description="Allowed grant types")
    scopes: List[str] = Field(default_factory=list, description="Available scopes")
    is_active: bool = Field(True, description="Whether client is active")

    @field_validator("redirect_uris")
    @classmethod
    def validate_redirect_uris(cls, v: List[str]) -> List[str]:
        """Require every redirect URI to be an absolute http(s) URL.

        Raises:
            ValueError: if a URI cannot be parsed, has no scheme or network
                location, or uses a scheme other than http/https.
        """
        for uri in v:
            # Only the parse itself is wrapped (urlparse raises ValueError for
            # e.g. a malformed bracketed IPv6 host). The explicit checks below
            # raise their final message directly; previously they were caught
            # and re-wrapped, duplicating the "Invalid URI: ..." prefix.
            try:
                parsed = urlparse(uri)
            except ValueError as exc:
                raise ValueError(f"Invalid URI: {uri}. {exc}") from exc
            if not parsed.scheme or not parsed.netloc:
                raise ValueError(f"Invalid URI: {uri}. Must have scheme and network location.")
            if parsed.scheme not in ("http", "https"):
                raise ValueError(
                    f"Invalid URI: {uri}. Invalid scheme: {parsed.scheme}. Only http/https allowed."
                )
        return v

    @field_validator("grant_types")
    @classmethod
    def validate_grant_types(cls, v: List[str]) -> List[str]:
        """Reject any grant type that is not in ALLOWED_GRANT_TYPES.

        Raises:
            ValueError: naming the first unsupported grant type.
        """
        for grant in v:
            if grant not in ALLOWED_GRANT_TYPES:
                # Sorted so the message is stable: set iteration order varies
                # between interpreter runs due to string hash randomization.
                allowed = ", ".join(sorted(ALLOWED_GRANT_TYPES))
                raise ValueError(f"Invalid grant type: {grant}. Must be one of: {allowed}")
        return v

    @field_validator("scopes")
    @classmethod
    def validate_scopes(cls, v: List[str]) -> List[str]:
        """Reject empty scope strings.

        Raises:
            ValueError: if any scope is empty or not a string.
        """
        for scope in v:
            if not scope or not isinstance(scope, str):
                raise ValueError("Scope must be a non-empty string")
        return v

    @field_validator("client_secret")
    @classmethod
    def validate_client_secret(cls, v: str) -> str:
        """Enforce a minimum client-secret length of 8 characters.

        Raises:
            ValueError: if the secret is shorter than 8 characters.
        """
        if len(v) < 8:
            raise ValueError("Client secret must be at least 8 characters long")
        return v
|
|
|
|
|
|
class OAuthClientCreate(OAuthClientBase):
    """Schema for creating a new OAuth client.

    Declares nothing of its own: every field and validator is inherited
    unchanged from OAuthClientBase.
    """
|
|
|
|
|
|
class OAuthClientUpdate(BaseModel):
    """Schema for updating an existing OAuth client (all fields optional).

    Every field defaults to None. Each validator returns None unchanged and
    only checks values that are actually present.
    """

    client_id: Optional[str] = Field(None, description="Unique client identifier", max_length=100)
    client_secret: Optional[str] = Field(None, description="Client secret (plaintext for input)", max_length=255)
    name: Optional[str] = Field(None, description="Human-readable client name", max_length=200)
    redirect_uris: Optional[List[str]] = Field(None, description="Allowed redirect URIs")
    grant_types: Optional[List[str]] = Field(None, description="Allowed grant types")
    scopes: Optional[List[str]] = Field(None, description="Available scopes")
    is_active: Optional[bool] = Field(None, description="Whether client is active")

    @field_validator("redirect_uris")
    @classmethod
    def validate_redirect_uris(cls, v: Optional[List[str]]) -> Optional[List[str]]:
        """Require every supplied redirect URI to be an absolute http(s) URL.

        Raises:
            ValueError: if a URI cannot be parsed, has no scheme or network
                location, or uses a scheme other than http/https.
        """
        if v is None:
            return v
        for uri in v:
            # Only the parse itself is wrapped (urlparse raises ValueError for
            # e.g. a malformed bracketed IPv6 host). The explicit checks below
            # raise their final message directly; previously they were caught
            # and re-wrapped, duplicating the "Invalid URI: ..." prefix.
            try:
                parsed = urlparse(uri)
            except ValueError as exc:
                raise ValueError(f"Invalid URI: {uri}. {exc}") from exc
            if not parsed.scheme or not parsed.netloc:
                raise ValueError(f"Invalid URI: {uri}. Must have scheme and network location.")
            if parsed.scheme not in ("http", "https"):
                raise ValueError(
                    f"Invalid URI: {uri}. Invalid scheme: {parsed.scheme}. Only http/https allowed."
                )
        return v

    @field_validator("grant_types")
    @classmethod
    def validate_grant_types(cls, v: Optional[List[str]]) -> Optional[List[str]]:
        """Reject any supplied grant type not in ALLOWED_GRANT_TYPES.

        Raises:
            ValueError: naming the first unsupported grant type.
        """
        if v is None:
            return v
        for grant in v:
            if grant not in ALLOWED_GRANT_TYPES:
                # Sorted so the message is stable: set iteration order varies
                # between interpreter runs due to string hash randomization.
                allowed = ", ".join(sorted(ALLOWED_GRANT_TYPES))
                raise ValueError(f"Invalid grant type: {grant}. Must be one of: {allowed}")
        return v

    @field_validator("scopes")
    @classmethod
    def validate_scopes(cls, v: Optional[List[str]]) -> Optional[List[str]]:
        """Reject empty scope strings when scopes are supplied.

        Raises:
            ValueError: if any scope is empty or not a string.
        """
        if v is None:
            return v
        for scope in v:
            if not scope or not isinstance(scope, str):
                raise ValueError("Scope must be a non-empty string")
        return v

    @field_validator("client_secret")
    @classmethod
    def validate_client_secret(cls, v: Optional[str]) -> Optional[str]:
        """Enforce a minimum length of 8 characters on a supplied secret.

        Raises:
            ValueError: if the secret is shorter than 8 characters.
        """
        if v is None:
            return v
        if len(v) < 8:
            raise ValueError("Client secret must be at least 8 characters long")
        return v
|
|
|
|
|
|
class OAuthClientResponse(OAuthClientBase):
    """Schema for returning an OAuth client (includes ID and timestamps).

    ``from_attributes=True`` lets instances be built directly from objects
    exposing matching attributes (e.g. ORM rows) via ``model_validate``.
    """

    # NOTE(review): client_secret is inherited from OAuthClientBase and is
    # therefore included in model_dump()/JSON output. If the stored value is a
    # hash, or the secret should only be shown once at creation, consider
    # redeclaring it with Field(..., exclude=True) here.
    model_config = ConfigDict(from_attributes=True)

    id: int  # primary key
    created_at: datetime
    updated_at: datetime
|
|
|
|
|
|
class OAuthTokenBase(BaseModel):
    """Base schema for OAuthToken with common fields.

    ``token_type`` is matched case-insensitively against bearer/mac/jwt and
    stored in title case (e.g. "Bearer").
    """

    access_token: str = Field(..., description="Access token value", max_length=1000)
    refresh_token: Optional[str] = Field(None, description="Refresh token value", max_length=1000)
    token_type: str = Field("Bearer", description="Token type", max_length=50)
    expires_at: datetime = Field(..., description="Token expiration timestamp")
    scopes: List[str] = Field(default_factory=list, description="Granted scopes")
    client_id: str = Field(..., description="Client identifier", max_length=100)
    user_id: Optional[int] = Field(None, description="User identifier")

    @field_validator("token_type")
    @classmethod
    def validate_token_type(cls, v):
        """Accept bearer/mac/jwt in any case and return the title-cased value.

        Raises:
            ValueError: for any other token type.
        """
        if v.lower() in ("bearer", "mac", "jwt"):
            return v.title()
        raise ValueError("Token type must be 'bearer', 'mac', or 'jwt'")

    @field_validator("scopes")
    @classmethod
    def validate_scopes(cls, v):
        """Reject the list if any scope is empty or not a string.

        Raises:
            ValueError: on the first invalid scope.
        """
        if any(not scope or not isinstance(scope, str) for scope in v):
            raise ValueError("Scope must be a non-empty string")
        return v
|
|
|
|
|
|
class OAuthTokenCreate(OAuthTokenBase):
    """Schema for creating a new OAuth token.

    Declares nothing of its own: every field and validator is inherited
    unchanged from OAuthTokenBase.
    """
|
|
|
|
|
|
class OAuthTokenUpdate(BaseModel):
    """Schema for updating an existing OAuth token (all fields optional).

    Every field defaults to None. The validators pass None through untouched
    and only check values that are present.
    """

    access_token: Optional[str] = Field(None, description="Access token value", max_length=1000)
    refresh_token: Optional[str] = Field(None, description="Refresh token value", max_length=1000)
    token_type: Optional[str] = Field(None, description="Token type", max_length=50)
    expires_at: Optional[datetime] = Field(None, description="Token expiration timestamp")
    scopes: Optional[List[str]] = Field(None, description="Granted scopes")
    client_id: Optional[str] = Field(None, description="Client identifier", max_length=100)
    user_id: Optional[int] = Field(None, description="User identifier")

    @field_validator("token_type")
    @classmethod
    def validate_token_type(cls, v):
        """Title-case a supplied bearer/mac/jwt token type; pass None through.

        Raises:
            ValueError: for any other supplied token type.
        """
        if v is None:
            return None
        if v.lower() in ("bearer", "mac", "jwt"):
            return v.title()
        raise ValueError("Token type must be 'bearer', 'mac', or 'jwt'")

    @field_validator("scopes")
    @classmethod
    def validate_scopes(cls, v):
        """Reject a supplied scope list containing an empty or non-string entry.

        Raises:
            ValueError: on the first invalid scope.
        """
        if v is not None and any(not scope or not isinstance(scope, str) for scope in v):
            raise ValueError("Scope must be a non-empty string")
        return v
|
|
|
|
|
|
class OAuthTokenResponse(OAuthTokenBase):
    """Schema for returning an OAuth token (includes ID, timestamps, and computed fields)."""

    id: int
    created_at: datetime
    updated_at: datetime

    @property
    def is_expired(self) -> bool:
        """Return True if ``expires_at`` is strictly earlier than the current time.

        A naive ``expires_at`` is treated as UTC, matching the previous
        ``utcnow()`` comparison. An aware ``expires_at`` (e.g. parsed from an
        ISO string with "Z"/offset) is compared against an aware "now". The
        previous code raised TypeError on such mixed naive/aware comparisons,
        and ``datetime.utcnow()`` is deprecated since Python 3.12.
        """
        from datetime import timezone  # file-level import only brings in the datetime class

        now = datetime.now(timezone.utc)
        if self.expires_at.utcoffset() is None:
            # Naive expiry: drop tzinfo so both operands are naive UTC.
            now = now.replace(tzinfo=None)
        return self.expires_at < now

    model_config = ConfigDict(from_attributes=True)
|
|
|
|
|
|
class OAuthUserBase(BaseModel):
    """Base schema for OAuthUser with common fields."""

    username: str = Field(..., description="Unique username", max_length=100)
    password_hash: str = Field(..., description="Password hash (plaintext for input, will be hashed)", max_length=255)
    email: Optional[str] = Field(None, description="User email address", max_length=255)
    is_active: bool = Field(True, description="Whether user account is active")

    @field_validator("password_hash")
    @classmethod
    def validate_password_hash(cls, v: str) -> str:
        """Enforce a minimum length of 8 characters.

        The same field carries plaintext on input and a hash elsewhere; only
        length is checked, not which form the value is in.

        Raises:
            ValueError: if the value is shorter than 8 characters.
        """
        if len(v) < 8:
            raise ValueError("Password must be at least 8 characters long")
        return v

    @field_validator("email")
    @classmethod
    def validate_email(cls, v: Optional[str]) -> Optional[str]:
        """Apply a simple (not RFC-exhaustive) format check to a supplied email.

        Raises:
            ValueError: if the address does not match the pattern.
        """
        if v is None:
            return v
        # fullmatch instead of match + "$": "$" also matches just before a
        # trailing newline, so "user@example.com\n" was previously accepted.
        email_pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
        if not re.fullmatch(email_pattern, v):
            raise ValueError("Invalid email address format")
        return v
|
|
|
|
|
|
class OAuthUserCreate(OAuthUserBase):
    """Schema for creating a new OAuth user.

    Declares nothing of its own: every field and validator is inherited
    unchanged from OAuthUserBase.
    """
|
|
|
|
|
|
class OAuthUserUpdate(BaseModel):
    """Schema for updating an existing OAuth user (all fields optional).

    Every field defaults to None. The validators return None unchanged and
    only check values that are present.
    """

    username: Optional[str] = Field(None, description="Unique username", max_length=100)
    password_hash: Optional[str] = Field(None, description="Password hash (plaintext for input)", max_length=255)
    email: Optional[str] = Field(None, description="User email address", max_length=255)
    is_active: Optional[bool] = Field(None, description="Whether user account is active")

    @field_validator("password_hash")
    @classmethod
    def validate_password_hash(cls, v: Optional[str]) -> Optional[str]:
        """Enforce a minimum length of 8 characters on a supplied password.

        Raises:
            ValueError: if the value is shorter than 8 characters.
        """
        if v is None:
            return v
        if len(v) < 8:
            raise ValueError("Password must be at least 8 characters long")
        return v

    @field_validator("email")
    @classmethod
    def validate_email(cls, v: Optional[str]) -> Optional[str]:
        """Apply a simple (not RFC-exhaustive) format check to a supplied email.

        Raises:
            ValueError: if the address does not match the pattern.
        """
        if v is None:
            return v
        # fullmatch instead of match + "$": "$" also matches just before a
        # trailing newline, so "user@example.com\n" was previously accepted.
        email_pattern = r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
        if not re.fullmatch(email_pattern, v):
            raise ValueError("Invalid email address format")
        return v
|
|
|
|
|
|
class OAuthUserResponse(OAuthUserBase):
    """Schema for returning an OAuth user (includes ID and timestamps).

    ``password_hash`` is still accepted, so the model can be populated from
    objects with matching attributes via ``from_attributes``. It is excluded
    from ``model_dump()`` / ``model_dump_json()`` so credential material is
    never echoed back in responses.
    """

    # Redeclared only to add exclude=True. Type, length limit and description
    # match OAuthUserBase, and the inherited password_hash validator still applies.
    password_hash: str = Field(
        ...,
        description="Password hash (plaintext for input, will be hashed)",
        max_length=255,
        exclude=True,
    )
    id: int
    created_at: datetime
    updated_at: datetime

    model_config = ConfigDict(from_attributes=True)