173 lines
No EOL
5.7 KiB
Python
173 lines
No EOL
5.7 KiB
Python
"""
|
||
Unit tests for AuthorizationCodeStore.
|
||
"""
|
||
import asyncio
|
||
import pytest
|
||
from datetime import datetime, timedelta
|
||
from app.modules.oauth2.auth_code_store import AuthorizationCodeStore
|
||
|
||
|
||
@pytest.fixture
def store():
    """Provide each test with its own, empty AuthorizationCodeStore.

    The instance is configured with a one-second ``default_expiration``.
    """
    one_second = timedelta(seconds=1)
    return AuthorizationCodeStore(default_expiration=one_second)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_store_and_retrieve_code(store):
    """A stored code can be fetched back, payload intact, before it expires."""
    auth_code = "test_code_123"
    payload = {
        "client_id": "test_client",
        "redirect_uri": "https://example.com/callback",
        "scopes": ["read", "write"],
        "user_id": 42,
    }

    await store.store_code(auth_code, payload)
    record = await store.get_code(auth_code)

    assert record is not None
    # Every field supplied at storage time must round-trip unchanged.
    for field in ("client_id", "redirect_uri", "scopes", "user_id"):
        assert record[field] == payload[field]
    # The store attaches an expiration timestamp to the record.
    assert "expires_at" in record
    assert isinstance(record["expires_at"], datetime)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_store_without_expires_at_gets_default(store):
    """When expires_at is omitted, the store adds a default expiration.

    The ``store`` fixture uses a one-second default expiration, so the
    assigned ``expires_at`` must lie between (time just before storing + 1s)
    and (time just after storing + 1s). Timestamps are taken around the
    ``store_code`` call so the check does not depend on test speed. A small
    slack absorbs any timestamp-precision differences inside the store.
    """
    default_expiration = timedelta(seconds=1)  # must match the ``store`` fixture
    slack = timedelta(milliseconds=100)
    code = "test_code_no_exp"
    data = {
        "client_id": "client1",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
    }

    # Naive UTC timestamps, consistent with the naive ``datetime.utcnow()``
    # values the other tests in this module pass as ``expires_at``.
    before = datetime.utcnow()
    await store.store_code(code, data)
    after = datetime.utcnow()

    retrieved = await store.get_code(code)
    assert retrieved is not None
    assert "expires_at" in retrieved

    lower = before + default_expiration - slack
    upper = after + default_expiration + slack
    assert lower <= retrieved["expires_at"] <= upper
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_get_expired_code_returns_none_and_deletes(store):
    """Expired codes are automatically removed on get_code.

    The code's ``expires_at`` is five minutes in the past, so it is already
    expired when stored and no sleep is needed.
    """
    code = "expired_code"
    data = {
        "client_id": "client",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
        "expires_at": datetime.utcnow() - timedelta(minutes=5),  # already expired
    }
    await store.store_code(code, data)
    # store_code keeps already-expired entries, so the entry must still be
    # present here. This precondition proves the removal below is done by
    # get_code rather than by store_code.
    assert store.get_store_size() == 1

    retrieved = await store.get_code(code)
    assert retrieved is None
    # get_code must have evicted the expired entry, not merely hidden it.
    assert store.get_store_size() == 0
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_delete_code(store):
    """delete_code drops a previously stored code from the store."""
    doomed = "to_delete"
    await store.store_code(
        doomed,
        {
            "client_id": "client",
            "redirect_uri": "https://example.com/cb",
            "scopes": [],
        },
    )
    assert store.get_store_size() == 1

    await store.delete_code(doomed)

    # Both the size counter and a direct lookup must reflect the removal.
    assert store.get_store_size() == 0
    assert await store.get_code(doomed) is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_delete_nonexistent_code_is_idempotent(store):
    """Deleting a non-existent code does not raise an error.

    The delete is issued twice to exercise idempotence, not just a single
    call. The test also verifies that deleting an unknown code leaves other
    stored codes untouched.
    """
    await store.store_code(
        "keep_me",
        {
            "client_id": "client",
            "redirect_uri": "https://example.com/cb",
            "scopes": [],
        },
    )

    await store.delete_code("does_not_exist")
    await store.delete_code("does_not_exist")  # repeat: must still be a no-op

    assert store.get_store_size() == 1
    assert await store.get_code("keep_me") is not None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_prune_expired(store):
    """prune_expired evicts only codes whose expiration has already passed."""
    thirty_seconds = timedelta(seconds=30)

    def _record(client_id, expires_at):
        # Minimal valid payload with an explicit expiration.
        return {
            "client_id": client_id,
            "redirect_uri": "https://example.com/cb",
            "scopes": [],
            "expires_at": expires_at,
        }

    # One entry expired 30s ago; the other expires 30s from now.
    stale = _record("client1", datetime.utcnow() - thirty_seconds)
    fresh = _record("client2", datetime.utcnow() + thirty_seconds)
    await store.store_code("expired", stale)
    await store.store_code("valid", fresh)
    assert store.get_store_size() == 2

    pruned_count = await store.prune_expired()

    assert pruned_count == 1
    assert store.get_store_size() == 1
    assert await store.get_code("valid") is not None
    assert await store.get_code("expired") is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_missing_required_fields_raises_error(store):
    """store_code rejects a payload lacking required keys with a ValueError."""
    # Only client_id is supplied; redirect_uri and scopes are absent.
    partial_payload = {"client_id": "client"}

    # ``match`` performs a regex search on str(exception); the pattern has no
    # metacharacters, so this is a plain substring check.
    with pytest.raises(ValueError, match="Missing required fields"):
        await store.store_code("bad_code", partial_payload)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_thread_safety_simulation(store):
    """Concurrent store/get/delete calls complete cleanly (basic safety check).

    The concurrency here comes from asyncio tasks interleaved by
    ``asyncio.gather``, not from OS threads.
    """
    payload = {
        "client_id": "client",
        "redirect_uri": "https://example.com/cb",
        "scopes": [],
    }
    code_names = list(map("code_{}".format, range(10)))

    # Phase 1: store every code concurrently.
    await asyncio.gather(*(store.store_code(name, payload) for name in code_names))
    assert store.get_store_size() == 10

    # Phase 2: look every code up concurrently.
    lookups = await asyncio.gather(*(store.get_code(name) for name in code_names))
    assert not any(rec is None for rec in lookups)

    # Phase 3: delete every code concurrently.
    await asyncio.gather(*(store.delete_code(name) for name in code_names))
    assert store.get_store_size() == 0
|
||
|
||
|
||
def test_singleton_global_instance():
    """The global instance authorization_code_store is a shared store object.

    Two imports of the same name always yield one object, because Python
    caches modules in ``sys.modules``. The identity check therefore only
    guards against accidental re-creation. The isinstance check confirms the
    shared object really is an AuthorizationCodeStore. Nothing here is
    awaited, so the test is synchronous.
    """
    from app.modules.oauth2.auth_code_store import authorization_code_store
    from app.modules.oauth2.auth_code_store import (
        authorization_code_store as same_instance,
    )

    assert authorization_code_store is same_instance
    assert isinstance(authorization_code_store, AuthorizationCodeStore)
|
||
|
||
|
||
if __name__ == "__main__":
    # Allow running this file directly (``python <this file>``) by delegating
    # to pytest. Propagate pytest's exit code so test failures produce a
    # non-zero process status instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))