mockapi/app/modules/oauth2/auth_code_store.py
2026-03-16 09:00:26 +00:00

215 lines
No EOL
7.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

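"""
Authorization Code Store for OAuth2 authorization code flow.
Provides temporary, in-memory storage of authorization codes with automatic
expiration. Access is serialised with an asyncio lock, so the store is safe to
share between coroutines on one event loop (it is not safe across OS threads).
"""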
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional

from app.core.config import settings
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
class AuthorizationCodeStore:
    """
    In-memory store for OAuth2 authorization codes.

    Codes live in a plain dict guarded by an ``asyncio.Lock``, which makes the
    store safe to share between coroutines running on a single event loop.
    It is NOT thread-safe: asyncio synchronisation primitives do not protect
    against concurrent access from OS threads.

    Each code maps to a metadata dict containing ``client_id``,
    ``redirect_uri``, ``scopes``, an optional ``user_id`` and ``expires_at``.
    ``expires_at`` is always stored as a naive datetime expressed in UTC.

    Expired entries are evicted lazily: when ``get_code``/``consume_code``
    encounters them, or in bulk via ``prune_expired``. ``get_code`` does NOT
    remove a valid code. To honour the single-use rule for authorization
    codes, prefer ``consume_code``, which reads and deletes atomically.
    ``get_code`` followed by ``delete_code`` leaves a window in which a second
    request can read the same code.

    The store is intended to be used as a singleton through the module-level
    ``authorization_code_store`` instance.
    """

    # Default lifetime for codes, taken from application settings
    # (RFC 6749 §4.1.2 recommends a maximum lifetime of 10 minutes).
    DEFAULT_EXPIRATION = timedelta(minutes=settings.oauth2_authorization_code_expire_minutes)

    def __init__(self, default_expiration: Optional[timedelta] = None):
        """
        Initialize a new authorization code store.

        Args:
            default_expiration: Lifetime applied to codes stored without an
                explicit ``expires_at``. If None, DEFAULT_EXPIRATION is used.
        """
        self._store: Dict[str, dict] = {}
        self._lock = asyncio.Lock()
        # Compare against None explicitly: timedelta(0) is falsy but is still
        # a deliberate value that must not be replaced by the default.
        self.default_expiration = (
            default_expiration if default_expiration is not None else self.DEFAULT_EXPIRATION
        )
        logger.info(
            "AuthorizationCodeStore initialized with default expiration %s",
            self.default_expiration,
        )

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime (non-deprecated ``utcnow``)."""
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _normalize_expires_at(value: object) -> datetime:
        """
        Convert a caller-supplied ``expires_at`` to a naive UTC datetime.

        Accepts a naive datetime (assumed to already be UTC), an aware datetime
        (converted to UTC), or an int/float POSIX timestamp.

        Raises:
            ValueError: If ``value`` is none of the accepted types. Rejecting it
                here matters: an incomparable value stored in the dict would make
                every later expiry comparison, including ``prune_expired``, fail.
        """
        if isinstance(value, datetime):
            if value.utcoffset() is not None:
                # Aware datetime: convert to UTC, then drop tzinfo so it can be
                # compared with the naive UTC "now" used throughout the store.
                return value.astimezone(timezone.utc).replace(tzinfo=None)
            return value
        if isinstance(value, (int, float)):
            return datetime.fromtimestamp(value, timezone.utc).replace(tzinfo=None)
        raise ValueError(
            f"expires_at must be a datetime or a POSIX timestamp, got {type(value).__name__}"
        )

    def _live_entry(self, code: str) -> Optional[dict]:
        """
        Return the unexpired entry for ``code``, evicting it if it has expired.

        The caller must hold ``self._lock``.
        """
        entry = self._store.get(code)
        if entry is None:
            logger.debug("Authorization code %s... not found", code[:8])
            return None
        if entry["expires_at"] < self._utcnow():
            del self._store[code]
            logger.debug("Authorization code %s... expired and removed", code[:8])
            return None
        return entry

    async def store_code(self, code: str, data: dict) -> None:
        """
        Store an authorization code with its associated data.

        A shallow copy of ``data`` is stored, so mutating the caller's dict
        afterwards does not affect the stored entry.

        Args:
            code: The authorization code string (generated securely).
            data: Dictionary containing at least:
                - client_id (str)
                - redirect_uri (str)
                - scopes (list of str)
                - user_id (optional int)
                - expires_at (optional): a datetime (naive = UTC, or aware) or a
                  POSIX timestamp. If absent or None, it defaults to
                  now + default_expiration.

        Raises:
            ValueError: If required fields are missing or ``expires_at`` has an
                unsupported type.
        """
        missing = {"client_id", "redirect_uri", "scopes"} - set(data)
        if missing:
            raise ValueError(f"Missing required fields in data: {missing}")

        expires_at = data.get("expires_at")
        if expires_at is None:
            expires_at = self._utcnow() + self.default_expiration
        else:
            expires_at = self._normalize_expires_at(expires_at)
        entry = {**data, "expires_at": expires_at}

        async with self._lock:
            self._store[code] = entry
            size = len(self._store)
        logger.debug("Stored authorization code %s... for client %s", code[:8], entry["client_id"])
        logger.debug("Total codes stored: %d", size)

    async def get_code(self, code: str) -> Optional[dict]:
        """
        Retrieve the data associated with an authorization code.

        An expired code is deleted and None is returned. A valid code is
        returned but NOT deleted. Use ``consume_code`` for an atomic
        read-and-delete, or call ``delete_code`` after a successful exchange.

        Args:
            code: The authorization code string.

        Returns:
            The stored data dict if the code exists and is not expired,
            otherwise None. The dict is the stored entry itself, so callers
            should not mutate it.
        """
        async with self._lock:
            entry = self._live_entry(code)
        if entry is not None:
            logger.debug(
                "Retrieved authorization code %s... for client %s", code[:8], entry["client_id"]
            )
        return entry

    async def consume_code(self, code: str) -> Optional[dict]:
        """
        Atomically retrieve and delete an authorization code.

        Lookup and removal happen under the same lock acquisition. Two
        concurrent exchanges of the same code therefore cannot both succeed,
        as RFC 6749 §4.1.2 requires for authorization codes.

        Args:
            code: The authorization code string.

        Returns:
            The stored data dict if the code existed and had not expired,
            otherwise None. In both cases the code is no longer in the store
            afterwards.
        """
        async with self._lock:
            entry = self._live_entry(code)
            if entry is not None:
                del self._store[code]
        if entry is not None:
            logger.debug(
                "Consumed authorization code %s... for client %s", code[:8], entry["client_id"]
            )
        return entry

    async def delete_code(self, code: str) -> None:
        """
        Delete an authorization code from the store.

        Idempotent: deleting a code that does not exist does nothing.

        Args:
            code: The authorization code string.
        """
        async with self._lock:
            removed = self._store.pop(code, None) is not None
        if removed:
            logger.debug("Deleted authorization code %s...", code[:8])
        else:
            logger.debug("Authorization code %s... not found (nothing to delete)", code[:8])

    async def prune_expired(self) -> int:
        """
        Remove all expired codes from the store.

        Returns:
            Number of codes removed.
        """
        async with self._lock:
            # Take "now" after acquiring the lock so time spent waiting for
            # the lock does not make the cutoff stale.
            now = self._utcnow()
            expired = [key for key, entry in self._store.items() if entry["expires_at"] < now]
            for key in expired:
                del self._store[key]
        if expired:
            logger.debug("Pruned %d expired authorization codes", len(expired))
        return len(expired)

    def get_store_size(self) -> int:
        """
        Return the current number of codes stored, including expired ones that
        have not been evicted yet.

        This is a single synchronous ``len()`` with no await point, so it cannot
        interleave with the async methods on the same event loop.
        """
        return len(self._store)
# Global singleton instance; importers should share this rather than creating
# their own store. No lifetime is passed, so the class-level
# DEFAULT_EXPIRATION (from settings) applies.
authorization_code_store = AuthorizationCodeStore()
async def _demo() -> None:
    """
    Demonstrate AuthorizationCodeStore end-to-end.

    Prints each step and exits with status 1 as soon as a check fails.
    Uses a 2-second code lifetime so the expiry path can be shown quickly.
    """
    import sys

    store = AuthorizationCodeStore(default_expiration=timedelta(seconds=2))
    print("=== AuthorizationCodeStore Demo ===")

    # 1. Store a code
    code = "demo_auth_code_xyz"
    data = {
        "client_id": "demo_client",
        "redirect_uri": "https://demo.example.com/callback",
        "scopes": ["read", "write"],
        "user_id": 1001,
    }
    await store.store_code(code, data)
    print(f"1. Stored code: {code[:8]}...")

    # 2. Retrieve it (should succeed)
    retrieved = await store.get_code(code)
    if retrieved:
        print(f"2. Retrieved code for client: {retrieved['client_id']}")
    else:
        print("2. ERROR: Code not found")
        sys.exit(1)

    # 3. Wait past the 2-second lifetime
    print("3. Waiting 3 seconds for code to expire...")
    await asyncio.sleep(3)

    # 4. Retrieve again (should be None and automatically removed)
    retrieved = await store.get_code(code)
    if retrieved is None:
        print("4. Code correctly expired and removed")
    else:
        print("4. ERROR: Code still present after expiration")
        sys.exit(1)

    # 5. Prune expired (the store should already be empty)
    removed = await store.prune_expired()
    print(f"5. Pruned {removed} expired codes")

    # 6. Concurrent stores from multiple asyncio tasks (not OS threads)
    codes = [f"concurrent_{i}" for i in range(5)]
    await asyncio.gather(*(store.store_code(c, data) for c in codes))
    if store.get_store_size() != len(codes):
        print("6. ERROR: Not all concurrently stored codes are present")
        sys.exit(1)
    print(f"6. Stored {len(codes)} codes concurrently")

    # 7. Delete all and confirm the store is empty
    for c in codes:
        await store.delete_code(c)
    if store.get_store_size() != 0:
        print("7. ERROR: Codes remain after deletion")
        sys.exit(1)
    print("7. Deleted all concurrent codes")
    print("=== Demo completed successfully ===")


if __name__ == "__main__":
    asyncio.run(_demo())