"""Unit tests for AuthorizationCodeStore."""

import asyncio
from datetime import datetime, timedelta, timezone

import pytest

from oauth2.auth_code_store import AuthorizationCodeStore

# Lifetime passed to the fixture's store; tests that check the default
# expiration read it from here so the two cannot drift apart.
DEFAULT_EXPIRATION = timedelta(seconds=1)


def _utcnow() -> datetime:
    """Return the current UTC time as a *naive* datetime.

    Drop-in replacement for ``datetime.utcnow()``, which is deprecated since
    Python 3.12. The value is kept naive because these tests compare it with
    the store's ``expires_at`` values.
    NOTE(review): assumes the store itself works in naive UTC, as the
    original tests did; confirm against AuthorizationCodeStore.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


@pytest.fixture
def store() -> AuthorizationCodeStore:
    """Return a fresh AuthorizationCodeStore instance for each test."""
    return AuthorizationCodeStore(default_expiration=DEFAULT_EXPIRATION)


@pytest.mark.asyncio
async def test_store_and_retrieve_code(store):
    """Store a code and retrieve it before expiration."""
    code = "test_code_123"
    data = {
        "client_id": "test_client",
        "redirect_uri": "https://example.com/callback",
        "scopes": ["read", "write"],
        "user_id": 42,
    }
    await store.store_code(code, data)

    retrieved = await store.get_code(code)
    assert retrieved is not None
    # Every field we stored must come back unchanged.
    for key, value in data.items():
        assert retrieved[key] == value
    assert "expires_at" in retrieved
    assert isinstance(retrieved["expires_at"], datetime)


@pytest.mark.asyncio
async def test_store_without_expires_at_gets_default(store):
    """When expires_at is omitted, the store adds a default expiration."""
    code = "test_code_no_exp"
    data = {
        "client_id": "client1",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
    }
    # Bracket the store_code call with timestamps rather than sampling the
    # clock afterwards: the expected window then does not shrink when the
    # test machine is slow. The small slack absorbs clock-resolution noise.
    slack = timedelta(milliseconds=50)
    before = _utcnow()
    await store.store_code(code, data)
    after = _utcnow()

    retrieved = await store.get_code(code)
    assert retrieved is not None
    assert "expires_at" in retrieved
    assert (
        before + DEFAULT_EXPIRATION - slack
        <= retrieved["expires_at"]
        <= after + DEFAULT_EXPIRATION + slack
    )


@pytest.mark.asyncio
async def test_get_expired_code_returns_none_and_deletes(store):
    """Expired codes are automatically removed on get_code."""
    code = "expired_code"
    data = {
        "client_id": "client",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
        # Five minutes in the past: already expired, so no sleep is needed.
        "expires_at": _utcnow() - timedelta(minutes=5),
    }
    await store.store_code(code, data)

    assert await store.get_code(code) is None
    # get_code must also evict the expired entry, not merely hide it.
    assert store.get_store_size() == 0


@pytest.mark.asyncio
async def test_delete_code(store):
    """Explicit deletion removes the code."""
    code = "to_delete"
    data = {
        "client_id": "client",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
    }
    await store.store_code(code, data)
    assert store.get_store_size() == 1

    await store.delete_code(code)
    assert store.get_store_size() == 0
    assert await store.get_code(code) is None


@pytest.mark.asyncio
async def test_delete_nonexistent_code_is_idempotent(store):
    """Deleting a missing code, or the same code twice, does not raise."""
    await store.delete_code("does_not_exist")
    assert store.get_store_size() == 0

    data = {
        "client_id": "client",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
    }
    await store.store_code("once", data)
    await store.delete_code("once")
    # The second delete targets a key that is already gone.
    await store.delete_code("once")
    assert store.get_store_size() == 0


@pytest.mark.asyncio
async def test_prune_expired(store):
    """prune_expired removes all expired codes and keeps valid ones."""
    now = _utcnow()
    expired_data = {
        "client_id": "client1",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
        "expires_at": now - timedelta(seconds=30),
    }
    valid_data = {
        "client_id": "client2",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
        "expires_at": now + timedelta(seconds=30),
    }
    await store.store_code("expired", expired_data)
    await store.store_code("valid", valid_data)
    assert store.get_store_size() == 2

    removed = await store.prune_expired()
    assert removed == 1
    assert store.get_store_size() == 1
    assert await store.get_code("valid") is not None
    assert await store.get_code("expired") is None


@pytest.mark.asyncio
async def test_missing_required_fields_raises_error(store):
    """store_code raises ValueError if required fields are missing."""
    code = "bad_code"
    incomplete_data = {
        "client_id": "client",
        # missing redirect_uri and scopes
    }
    with pytest.raises(ValueError, match="Missing required fields"):
        await store.store_code(code, incomplete_data)
    # A rejected code must not have been persisted.
    assert store.get_store_size() == 0
@pytest.mark.asyncio
async def test_thread_safety_simulation(store):
    """Concurrent coroutine access must not raise or lose entries.

    This interleaves ``asyncio`` tasks on a single event loop; despite the
    test's name, no OS threads are involved.
    """
    codes = [f"code_{i}" for i in range(10)]
    base_data = {
        "client_id": "client",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
    }

    # Hand each code its own (shallow) copy so that, if store_code mutates
    # its argument, the concurrent entries cannot alias each other's state.
    await asyncio.gather(
        *(store.store_code(code, dict(base_data)) for code in codes)
    )
    assert store.get_store_size() == len(codes)

    results = await asyncio.gather(*(store.get_code(code) for code in codes))
    assert all(result is not None for result in results)
    assert all(result["client_id"] == base_data["client_id"] for result in results)

    await asyncio.gather(*(store.delete_code(code) for code in codes))
    assert store.get_store_size() == 0


def test_singleton_global_instance():
    """The module-level ``authorization_code_store`` is one shared instance.

    Python caches imported modules, so importing the name twice must yield
    the very same object; also confirm it really is an
    AuthorizationCodeStore rather than some placeholder.
    """
    from oauth2.auth_code_store import (
        AuthorizationCodeStore,
        authorization_code_store,
    )
    from oauth2.auth_code_store import authorization_code_store as same_instance

    assert isinstance(authorization_code_store, AuthorizationCodeStore)
    assert authorization_code_store is same_instance


if __name__ == "__main__":
    # Allow running this file directly; propagate pytest's exit status so
    # failures are visible to the shell / CI (``python -m pytest`` also works).
    raise SystemExit(pytest.main([__file__, "-v"]))