70 lines
No EOL
2.8 KiB
Python
70 lines
No EOL
2.8 KiB
Python
import logging
|
||
import bcrypt
|
||
from typing import Callable, Awaitable
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
from fastapi import status
|
||
from starlette.requests import Request
|
||
from fastapi.responses import Response, RedirectResponse
|
||
from app.core.config import settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plain-text password against a stored bcrypt hash.

    Args:
        plain_password: Candidate password. ``str`` is encoded as UTF-8;
            ``bytes`` are passed through unchanged.
        hashed_password: Stored bcrypt hash. ``str`` is encoded as UTF-8;
            ``bytes`` are passed through unchanged.

    Returns:
        ``True`` if the password matches the hash, ``False`` otherwise.
        ``False`` is also returned when bcrypt rejects its inputs with
        ``ValueError`` (e.g. a malformed / non-bcrypt stored hash), so a
        corrupt hash fails closed instead of raising into the caller.
    """
    if isinstance(hashed_password, str):
        hashed_password = hashed_password.encode('utf-8')
    if isinstance(plain_password, str):
        plain_password = plain_password.encode('utf-8')

    try:
        return bcrypt.checkpw(plain_password, hashed_password)
    except ValueError:
        # bcrypt raises ValueError ("Invalid salt") for hashes it cannot parse;
        # treat that as a failed match rather than an unhandled error.
        logger.warning("Password verification failed: invalid bcrypt hash or input")
        return False
|
||
|
||
|
||
def get_password_hash(password: str) -> str:
    """Return a freshly salted bcrypt hash of *password*, decoded as UTF-8 text.

    ``password`` may be a ``str`` (encoded as UTF-8 before hashing) or
    already-encoded ``bytes``. A new salt from ``bcrypt.gensalt()`` is used
    on every call, so the same password hashes to a different string each time.
    """
    raw = password.encode('utf-8') if isinstance(password, str) else password
    return bcrypt.hashpw(raw, bcrypt.gensalt()).decode('utf-8')
|
||
|
||
|
||
class AuthMiddleware(BaseHTTPMiddleware):
    """
    Middleware to protect admin routes.

    Every request whose path is the admin prefix itself or lies beneath it
    (``/admin`` and ``/admin/...`` by default) requires an authenticated
    session, except the login page ``<prefix>/login``. Unauthenticated
    requests are redirected to the login page with HTTP 302. All other
    routes pass through untouched.

    Reads ``request.session``, so a session middleware must be installed and
    run before this one (presumably Starlette's ``SessionMiddleware`` —
    confirm in the app setup).
    """

    def __init__(self, app, admin_path_prefix: str = "/admin"):
        super().__init__(app)
        self.admin_path_prefix = admin_path_prefix
        # Login page path; also the redirect target for unauthenticated users.
        self._login_path = f"{admin_path_prefix}/login"
        # Pre-compute admin credentials (hash the configured password).
        # NOTE(review): the hash is not used by this middleware itself;
        # presumably the login handler reads it — confirm.
        self.admin_username = settings.admin_username
        self.admin_password_hash = get_password_hash(settings.admin_password)
        logger.info("AuthMiddleware initialized for admin prefix: %s", admin_path_prefix)

    def _is_admin_path(self, path: str) -> bool:
        """Return True if *path* is the admin prefix or a sub-path of it.

        Matches on a path-segment boundary so unrelated routes such as
        ``/administrator`` or ``/admin-foo`` are not treated as admin routes.
        """
        root = self.admin_path_prefix.rstrip("/")
        return path == root or path.startswith(root + "/")

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        path = request.url.path

        # The login page must be reachable without a session. Every method is
        # let through (not only GET/POST): otherwise e.g. a HEAD request to
        # the login page would be redirected to itself forever. The route
        # handler decides which methods it actually accepts.
        if path == self._login_path:
            return await call_next(request)

        # Non-admin route, proceed.
        if not self._is_admin_path(path):
            return await call_next(request)

        # Require a non-empty session username equal to the configured admin.
        # The truthiness check stops an unset admin_username (None) from
        # matching an anonymous session, where .get() also returns None.
        username = request.session.get("username")
        if username and username == self.admin_username:
            return await call_next(request)

        logger.warning("Unauthorized access attempt to %s", path)
        return RedirectResponse(url=self._login_path, status_code=status.HTTP_302_FOUND)