215 lines
No EOL
7.7 KiB
Python
215 lines
No EOL
7.7 KiB
Python
"""
|
||
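Authorization code store for the OAuth2 authorization code flow.

Provides temporary, in-memory storage of authorization codes with automatic
expiration. Access is serialized with an asyncio.Lock, so the store is safe
for concurrent coroutines on one event loop, but not for use across OS threads.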
|
||
"""
|
||
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional

from config import settings
|
||
|
||
# Logger named after this module so its output can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class AuthorizationCodeStore:
    """
    In-memory store for OAuth2 authorization codes.

    Each stored code maps to a metadata dictionary (client_id, redirect_uri,
    scopes, optional user_id, and expires_at). Expired codes are removed
    lazily when looked up (get_code / consume_code) or in bulk via
    prune_expired.

    Concurrency: mutating operations are serialized with an asyncio.Lock.
    That prevents interleaving between coroutines on one event loop;
    asyncio.Lock is not a thread lock, so the store must not be shared
    across OS threads.

    Time handling: expires_at values are kept as naive datetimes expressed
    in UTC. Timezone-aware datetimes and POSIX timestamps passed to
    store_code are normalized to that form, so comparisons never mix naive
    and aware values.

    Single use: RFC 6749 requires an authorization code to be redeemed at
    most once. Prefer consume_code, which looks up and removes a code in one
    locked step. Calling get_code and later delete_code leaves a window in
    which two concurrent token requests can both see the same code.

    The store is designed as a singleton; use the global instance
    `authorization_code_store`.
    """

    # Default expiration time for authorization codes, taken from project
    # settings (RFC 6749 §4.1.2 recommends a maximum lifetime of 10 minutes).
    DEFAULT_EXPIRATION = timedelta(minutes=settings.oauth2_authorization_code_expire_minutes)

    # Keys every record passed to store_code must contain.
    _REQUIRED_FIELDS = frozenset({"client_id", "redirect_uri", "scopes"})

    def __init__(self, default_expiration: Optional[timedelta] = None):
        """
        Initialize a new authorization code store.

        Args:
            default_expiration: Default lifetime for stored codes that do not
                carry their own expires_at. If None, DEFAULT_EXPIRATION is
                used. Any other value, including timedelta(0), is honored.
        """
        self._store: Dict[str, dict] = {}
        self._lock = asyncio.Lock()
        # Explicit None check: timedelta(0) is falsy, so `or` would silently
        # replace an explicitly requested zero lifetime with the default.
        self.default_expiration = (
            self.DEFAULT_EXPIRATION if default_expiration is None else default_expiration
        )
        logger.info(
            "AuthorizationCodeStore initialized with default expiration %s",
            self.default_expiration,
        )

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime (the store's internal form)."""
        # Same value as the deprecated datetime.utcnow().
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _normalize_expires_at(value) -> datetime:
        """
        Convert an expires_at value to a naive UTC datetime.

        Accepts a naive datetime (taken to already be UTC, matching the
        store's convention), a timezone-aware datetime (converted to UTC), or
        an int/float POSIX timestamp.

        Raises:
            TypeError: If value is none of the accepted types.
        """
        if isinstance(value, datetime):
            if value.utcoffset() is None:
                return value
            return value.astimezone(timezone.utc).replace(tzinfo=None)
        if isinstance(value, (int, float)):
            return datetime.fromtimestamp(value, timezone.utc).replace(tzinfo=None)
        raise TypeError(
            f"expires_at must be a datetime or a POSIX timestamp, got {type(value).__name__}"
        )

    async def store_code(self, code: str, data: dict) -> None:
        """
        Store an authorization code with its associated data.

        The record is copied before it is stored, so later changes to the
        caller's dict do not affect the stored code. An existing entry for
        the same code is overwritten.

        Args:
            code: The authorization code string (generated securely).
            data: Dictionary containing at least:
                - client_id (str)
                - redirect_uri (str)
                - scopes (list of str)
                and optionally:
                - user_id (int)
                - expires_at (datetime or POSIX timestamp). If missing or
                  None, it defaults to now + default_expiration.

        Raises:
            ValueError: If required fields are missing.
            TypeError: If expires_at is given but is neither a datetime nor an
                int/float timestamp.
        """
        missing = {key for key in self._REQUIRED_FIELDS if key not in data}
        if missing:
            raise ValueError(f"Missing required fields in data: {missing}")

        expires_at = data.get("expires_at")
        if expires_at is None:
            expires_at = self._utcnow() + self.default_expiration
        else:
            # Validate now rather than failing later inside get_code/prune_expired.
            expires_at = self._normalize_expires_at(expires_at)

        # Always copy, so the stored record is never aliased with the caller's dict.
        record = {**data, "expires_at": expires_at}

        async with self._lock:
            self._store[code] = record
            size = len(self._store)
        logger.debug(
            "Stored authorization code %s... for client %s", code[:8], record["client_id"]
        )
        logger.debug("Total codes stored: %d", size)

    async def get_code(self, code: str) -> Optional[dict]:
        """
        Look up the data associated with an authorization code without consuming it.

        If the code has expired, it is deleted and None is returned. A valid
        code is left in the store; the caller is responsible for removing it
        (delete_code) after a successful exchange. For the token exchange
        itself, consume_code is safer because it does lookup and removal
        atomically.

        Args:
            code: The authorization code string.

        Returns:
            A shallow copy of the stored data if the code exists and has not
            expired, otherwise None. Top-level changes to the returned dict do
            not alter the stored record.
        """
        async with self._lock:
            data = self._store.get(code)
            if data is None:
                logger.debug("Authorization code %s... not found", code[:8])
                return None
            if data["expires_at"] < self._utcnow():
                del self._store[code]
                logger.debug("Authorization code %s... expired and removed", code[:8])
                return None
            logger.debug(
                "Retrieved authorization code %s... for client %s", code[:8], data["client_id"]
            )
            return dict(data)

    async def consume_code(self, code: str) -> Optional[dict]:
        """
        Atomically retrieve and remove an authorization code.

        The code is removed from the store whether or not it has expired, so a
        second call with the same code always returns None. Use this when
        redeeming a code at the token endpoint, so that each code is used at
        most once.

        Args:
            code: The authorization code string.

        Returns:
            The stored data if the code existed and had not expired,
            otherwise None.
        """
        async with self._lock:
            data = self._store.pop(code, None)
        if data is None:
            logger.debug("Authorization code %s... not found", code[:8])
            return None
        if data["expires_at"] < self._utcnow():
            logger.debug("Authorization code %s... expired and removed", code[:8])
            return None
        logger.debug(
            "Consumed authorization code %s... for client %s", code[:8], data["client_id"]
        )
        return data

    async def delete_code(self, code: str) -> None:
        """
        Delete an authorization code from the store.

        This method is idempotent; deleting a non-existent code does nothing.

        Args:
            code: The authorization code string.
        """
        async with self._lock:
            # Stored records are always dicts, so None unambiguously means "absent".
            removed = self._store.pop(code, None) is not None
        if removed:
            logger.debug("Deleted authorization code %s...", code[:8])
        else:
            logger.debug("Authorization code %s... not found (nothing to delete)", code[:8])

    async def prune_expired(self) -> int:
        """
        Remove all expired codes from the store.

        Returns:
            Number of codes removed.
        """
        now = self._utcnow()
        async with self._lock:
            expired_keys = [
                key for key, record in self._store.items() if record["expires_at"] < now
            ]
            for key in expired_keys:
                del self._store[key]
        if expired_keys:
            logger.debug("Pruned %d expired authorization codes", len(expired_keys))
        return len(expired_keys)

    def get_store_size(self) -> int:
        """
        Return the current number of codes stored, including expired ones not yet pruned.

        This method is synchronous and contains no await, so it does not take
        the asyncio lock and cannot interleave with the store's coroutines on
        the same event loop.
        """
        return len(self._store)
|
||
|
||
|
||
# Process-wide shared store; import this instance instead of creating new ones.
authorization_code_store: AuthorizationCodeStore = AuthorizationCodeStore()
|
||
|
||
|
||
if __name__ == "__main__":
    # Simple demonstration of the AuthorizationCodeStore. Importing this module
    # requires the project's `config.settings` (read when the class is defined).
    import sys

    async def demo() -> None:
        """Exercise store, lookup, expiry, pruning and deletion on a short-lived store."""

        def fail(message: str) -> None:
            # Report a failed check and abort with a non-zero exit status.
            print(message)
            sys.exit(1)

        # A 2-second lifetime keeps the expiry step short.
        store = AuthorizationCodeStore(default_expiration=timedelta(seconds=2))
        print("=== AuthorizationCodeStore Demo ===")

        # 1. Store a code; expires_at is omitted, so the default lifetime applies.
        code = "demo_auth_code_xyz"
        data = {
            "client_id": "demo_client",
            "redirect_uri": "https://demo.example.com/callback",
            "scopes": ["read", "write"],
            "user_id": 1001,
        }
        await store.store_code(code, data)
        print(f"1. Stored code: {code[:8]}...")

        # 2. Retrieve it before it expires (should succeed).
        retrieved = await store.get_code(code)
        if retrieved is None:
            fail("2. ERROR: Code not found")
        print(f"2. Retrieved code for client: {retrieved['client_id']}")

        # 3. Wait longer than the 2-second lifetime.
        print("3. Waiting 3 seconds for code to expire...")
        await asyncio.sleep(3)

        # 4. Retrieve again: get_code should return None and delete the expired code.
        if await store.get_code(code) is not None:
            fail("4. ERROR: Code still present after expiration")
        print("4. Code correctly expired and removed")

        # 5. Prune: nothing expected, since step 4 already removed the only code.
        removed = await store.prune_expired()
        print(f"5. Pruned {removed} expired codes")

        # 6. Concurrent stores as asyncio tasks on one event loop (not OS threads),
        #    then confirm that every code is retrievable.
        codes = [f"concurrent_{i}" for i in range(5)]
        await asyncio.gather(*(store.store_code(c, data) for c in codes))
        for c in codes:
            if await store.get_code(c) is None:
                fail(f"6. ERROR: Code {c} missing after concurrent store")
        print(f"6. Stored {len(codes)} codes concurrently")

        # 7. Delete them all and confirm the store is empty.
        for c in codes:
            await store.delete_code(c)
        if store.get_store_size() != 0:
            fail("7. ERROR: Store not empty after deleting all codes")
        print("7. Deleted all concurrent codes")

        print("=== Demo completed successfully ===")

    asyncio.run(demo())