""" Authorization Code Store for OAuth2 authorization code flow. Provides temporary, thread‑safe storage of authorization codes with automatic expiration. """ import asyncio import logging from datetime import datetime, timedelta from typing import Dict, Optional from config import settings logger = logging.getLogger(__name__) class AuthorizationCodeStore: """ In‑memory store for OAuth2 authorization codes. This class provides a thread‑safe dictionary‑based store with automatic expiration of codes. Each stored code is associated with a dictionary of metadata (client_id, redirect_uri, scopes, user_id, expires_at) and is automatically removed when retrieved or when its expiration time passes. The store is designed as a singleton; use the global instance `authorization_code_store`. """ # Default expiration time for authorization codes (RFC 6749 §4.1.2 recommends ≤10 minutes) DEFAULT_EXPIRATION = timedelta(minutes=settings.oauth2_authorization_code_expire_minutes) def __init__(self, default_expiration: Optional[timedelta] = None): """ Initialize a new authorization code store. Args: default_expiration: Default lifetime for stored codes. If None, DEFAULT_EXPIRATION is used. """ self._store: Dict[str, dict] = {} self._lock = asyncio.Lock() self.default_expiration = default_expiration or self.DEFAULT_EXPIRATION logger.info(f"AuthorizationCodeStore initialized with default expiration {self.default_expiration}") async def store_code(self, code: str, data: dict) -> None: """ Store an authorization code with its associated data. Args: code: The authorization code string (generated securely). data: Dictionary containing at least: - client_id (str) - redirect_uri (str) - scopes (list of str) - user_id (optional int) - expires_at (datetime) – if not provided, defaults to now + default_expiration. Raises: ValueError: If required fields are missing. """ required = {"client_id", "redirect_uri", "scopes"} if not all(key in data for key in required): missing = required - set(data.keys()) raise ValueError(f"Missing required fields in data: {missing}") # Ensure expires_at is set expires_at = data.get("expires_at") if expires_at is None: expires_at = datetime.utcnow() + self.default_expiration data = {**data, "expires_at": expires_at} elif isinstance(expires_at, (int, float)): # If a timestamp is passed, convert to datetime expires_at = datetime.utcfromtimestamp(expires_at) data = {**data, "expires_at": expires_at} async with self._lock: self._store[code] = data logger.debug(f"Stored authorization code {code[:8]}... for client {data['client_id']}") logger.debug(f"Total codes stored: {len(self._store)}") async def get_code(self, code: str) -> Optional[dict]: """ Retrieve the data associated with an authorization code. This method performs automatic cleanup: if the code has expired, it is deleted and None is returned. If the code is valid, it is returned but NOT deleted (deletion is the responsibility of the caller, typically via delete_code after successful exchange). Args: code: The authorization code string. Returns: The stored data dict if the code exists and is not expired, otherwise None. """ async with self._lock: if code not in self._store: logger.debug(f"Authorization code {code[:8]}... not found") return None data = self._store[code] expires_at = data["expires_at"] if expires_at < datetime.utcnow(): del self._store[code] logger.debug(f"Authorization code {code[:8]}... expired and removed") return None logger.debug(f"Retrieved authorization code {code[:8]}... for client {data['client_id']}") return data async def delete_code(self, code: str) -> None: """ Delete an authorization code from the store. This method is idempotent; deleting a non‑existent code does nothing. Args: code: The authorization code string. """ async with self._lock: if code in self._store: del self._store[code] logger.debug(f"Deleted authorization code {code[:8]}...") else: logger.debug(f"Authorization code {code[:8]}... not found (nothing to delete)") async def prune_expired(self) -> int: """ Remove all expired codes from the store. Returns: Number of codes removed. """ now = datetime.utcnow() removed = 0 async with self._lock: expired_keys = [k for k, v in self._store.items() if v["expires_at"] < now] for key in expired_keys: del self._store[key] removed += 1 if removed: logger.debug(f"Pruned {removed} expired authorization codes") return removed def get_store_size(self) -> int: """ Return the current number of codes stored (including expired ones). Note: This method is not thread‑safe unless called from within a lock. """ return len(self._store) # Global singleton instance authorization_code_store = AuthorizationCodeStore() if __name__ == "__main__": """Simple demonstration of the AuthorizationCodeStore.""" import asyncio import sys async def demo() -> None: store = AuthorizationCodeStore(default_expiration=timedelta(seconds=2)) print("=== AuthorizationCodeStore Demo ===") # 1. Store a code code = "demo_auth_code_xyz" data = { "client_id": "demo_client", "redirect_uri": "https://demo.example.com/callback", "scopes": ["read", "write"], "user_id": 1001, } await store.store_code(code, data) print(f"1. Stored code: {code[:8]}...") # 2. Retrieve it (should succeed) retrieved = await store.get_code(code) if retrieved: print(f"2. Retrieved code for client: {retrieved['client_id']}") else: print("2. ERROR: Code not found") sys.exit(1) # 3. Wait for expiration print("3. Waiting 3 seconds for code to expire...") await asyncio.sleep(3) # 4. Retrieve again (should be None and automatically removed) retrieved = await store.get_code(code) if retrieved is None: print("4. Code correctly expired and removed") else: print("4. ERROR: Code still present after expiration") sys.exit(1) # 5. Prune expired (should be empty) removed = await store.prune_expired() print(f"5. Pruned {removed} expired codes") # 6. Thread‑safe concurrent operations codes = [f"concurrent_{i}" for i in range(5)] tasks = [store.store_code(c, data) for c in codes] await asyncio.gather(*tasks) print(f"6. Stored {len(codes)} codes concurrently") # 7. Delete all for c in codes: await store.delete_code(c) print("7. Deleted all concurrent codes") print("=== Demo completed successfully ===") asyncio.run(demo())