mockapi/tests/test_auth_code_store.py
2026-03-16 05:47:01 +00:00

173 lines
No EOL
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Unit tests for AuthorizationCodeStore.
"""
import asyncio
import pytest
from datetime import datetime, timedelta
from oauth2.auth_code_store import AuthorizationCodeStore
@pytest.fixture
def store():
    """Provide a brand-new AuthorizationCodeStore to every test.

    The default expiration is one second; the default-expiration test
    relies on that exact value.
    """
    one_second = timedelta(seconds=1)
    return AuthorizationCodeStore(default_expiration=one_second)


@pytest.mark.asyncio
async def test_store_and_retrieve_code(store):
    """A stored code can be read back intact before it expires."""
    code = "test_code_123"
    payload = {
        "client_id": "test_client",
        "redirect_uri": "https://example.com/callback",
        "scopes": ["read", "write"],
        "user_id": 42,
    }
    await store.store_code(code, payload)

    entry = await store.get_code(code)
    assert entry is not None
    # Every field we supplied must round-trip unchanged, checked in order.
    for field in ("client_id", "redirect_uri", "scopes", "user_id"):
        assert entry[field] == payload[field]
    # The store attaches its own expiry timestamp.
    assert "expires_at" in entry
    assert isinstance(entry["expires_at"], datetime)


@pytest.mark.asyncio
async def test_store_without_expires_at_gets_default(store):
    """When expires_at is omitted, the store adds a default expiration.

    The fixture configures a one-second default, so the stored expiry must
    be about one second after the moment ``store_code`` ran. The expected
    window is bracketed by timestamps taken immediately before and after
    that call, not after ``get_code``, so a slow test run cannot shift the
    window past the real value.
    """
    code = "test_code_no_exp"
    data = {
        "client_id": "client1",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
    }
    default_expiration = timedelta(seconds=1)  # matches the `store` fixture
    # Small slack in case the store rounds or reads the clock slightly
    # differently from this test.
    tolerance = timedelta(milliseconds=100)

    # NOTE(review): like the original test, this assumes the store records
    # naive UTC datetimes (datetime.utcnow()); confirm against the store.
    before = datetime.utcnow()
    await store.store_code(code, data)
    after = datetime.utcnow()

    retrieved = await store.get_code(code)
    assert retrieved is not None
    assert "expires_at" in retrieved
    expected_min = before + default_expiration - tolerance
    expected_max = after + default_expiration + tolerance
    assert expected_min <= retrieved["expires_at"] <= expected_max


@pytest.mark.asyncio
async def test_get_expired_code_returns_none_and_deletes(store):
    """Expired codes are automatically removed on get_code."""
    code = "expired_code"
    data = {
        "client_id": "client",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
        "expires_at": datetime.utcnow() - timedelta(minutes=5),  # already expired
    }
    await store.store_code(code, data)
    # store_code keeps expired codes (see test_prune_expired). Confirm the
    # code is present so the emptiness check below proves get_code deleted it.
    # No sleep is needed: the code expired five minutes ago.
    assert store.get_store_size() == 1

    retrieved = await store.get_code(code)
    assert retrieved is None
    # get_code must have evicted the expired entry.
    assert store.get_store_size() == 0


@pytest.mark.asyncio
async def test_delete_code(store):
    """delete_code drops a stored code so it can no longer be fetched."""
    code = "to_delete"
    await store.store_code(
        code,
        {
            "client_id": "client",
            "redirect_uri": "https://example.com/cb",
            "scopes": [],
        },
    )
    assert store.get_store_size() == 1

    await store.delete_code(code)

    # Both the size counter and a lookup must reflect the removal.
    assert store.get_store_size() == 0
    assert await store.get_code(code) is None


@pytest.mark.asyncio
async def test_delete_nonexistent_code_is_idempotent(store):
    """Deleting a code that is not present is a harmless no-op.

    Repeated deletes of the same missing code must neither raise nor
    change the (empty) store, which is what "idempotent" promises.
    """
    for _ in range(2):
        await store.delete_code("does_not_exist")
        assert store.get_store_size() == 0


@pytest.mark.asyncio
async def test_prune_expired(store):
    """prune_expired drops expired codes and keeps live ones."""

    def make_entry(client_id, offset):
        # Minimal valid payload whose expiry lies `offset` from now.
        return {
            "client_id": client_id,
            "redirect_uri": "https://example.com/cb",
            "scopes": [],
            "expires_at": datetime.utcnow() + offset,
        }

    # One entry already expired, one still valid.
    expired_entry = make_entry("client1", timedelta(seconds=-30))
    live_entry = make_entry("client2", timedelta(seconds=30))
    await store.store_code("expired", expired_entry)
    await store.store_code("valid", live_entry)
    assert store.get_store_size() == 2

    pruned = await store.prune_expired()
    assert pruned == 1
    assert store.get_store_size() == 1
    assert await store.get_code("valid") is not None
    assert await store.get_code("expired") is None


@pytest.mark.asyncio
async def test_missing_required_fields_raises_error(store):
    """store_code raises ValueError if required fields are missing."""
    code = "bad_code"
    incomplete_data = {
        "client_id": "client",
        # missing redirect_uri and scopes
    }
    # `match` regex-searches str(exception) as part of the raises check, so
    # the message assertion cannot end up as unreachable code inside the
    # `with` block.
    with pytest.raises(ValueError, match="Missing required fields"):
        await store.store_code(code, incomplete_data)
    # A rejected code must not have been stored.
    assert store.get_store_size() == 0


@pytest.mark.asyncio
async def test_thread_safety_simulation(store):
    """Interleaved coroutine access to the store must not raise.

    Everything runs on a single asyncio event loop, so this exercises
    concurrently scheduled coroutines rather than OS threads.
    """
    codes = [f"code_{i}" for i in range(10)]
    shared = {
        "client_id": "client",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
    }

    # Fan out the writes.
    await asyncio.gather(*(store.store_code(c, shared) for c in codes))
    assert store.get_store_size() == len(codes)

    # Fan out the reads; every code must still be retrievable.
    fetched = await asyncio.gather(*(store.get_code(c) for c in codes))
    assert not any(entry is None for entry in fetched)

    # Fan out the deletes.
    await asyncio.gather(*(store.delete_code(c) for c in codes))
    assert store.get_store_size() == 0


@pytest.mark.asyncio
async def test_singleton_global_instance():
    """The module-level authorization_code_store is one shared store instance.

    Importing the same name twice always yields the same object, because
    Python caches modules in sys.modules. The identity check alone is
    therefore close to tautological. The isinstance check additionally
    confirms the global really is an AuthorizationCodeStore.
    """
    from oauth2.auth_code_store import authorization_code_store
    # Import again to ensure it's the same object
    from oauth2.auth_code_store import authorization_code_store as same_instance

    assert authorization_code_store is same_instance
    assert isinstance(authorization_code_store, AuthorizationCodeStore)


if __name__ == "__main__":
    # Allow running this file directly (`python test_auth_code_store.py`) by
    # delegating to pytest. Propagate pytest's exit code so that shell and CI
    # callers see test failures instead of a silent status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))