337 lines
No EOL
12 KiB
Python
337 lines
No EOL
12 KiB
Python
"""
Unit tests for OAuth2 controller endpoints.
"""
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from fastapi.testclient import TestClient
|
|
from fastapi import FastAPI, status
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from oauth2.controller import router as oauth_router
|
|
|
|
|
|
def create_test_app(override_dependency=None) -> FastAPI:
    """Build a FastAPI application that mounts the OAuth2 router.

    Args:
        override_dependency: Optional mapping of dependency callables to the
            callables that should replace them. It is merged into
            ``app.dependency_overrides``. A falsy value (``None`` or an empty
            mapping) leaves the overrides untouched.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI()
    app.include_router(oauth_router)

    # Nothing to override: hand back the plain app.
    if not override_dependency:
        return app

    app.dependency_overrides.update(override_dependency)
    return app
|
|
|
|
|
|
@pytest.fixture
def mock_db_session():
    """Provide an ``AsyncMock`` constrained to the ``AsyncSession`` interface.

    Passing ``spec=AsyncSession`` makes access to attributes that
    ``AsyncSession`` does not define raise ``AttributeError``. Tests therefore
    cannot silently depend on a misspelled session method.
    """
    return AsyncMock(spec=AsyncSession)
|
|
|
|
|
|
@pytest.fixture
def client(mock_db_session):
    """Return a ``TestClient`` whose ``get_db`` dependency yields the mock session."""
    # Imported lazily, as in the original fixture, so the module is only
    # loaded when this fixture is actually requested.
    from database import get_db

    def _fake_get_db():
        # Every request that depends on get_db receives the same mock object.
        yield mock_db_session

    overrides = {get_db: _fake_get_db}
    return TestClient(create_test_app(overrides))
|
|
|
|
|
|
# ---------- Authorization Endpoint Tests ----------
|
|
def test_authorize_missing_parameters(client):
    """GET /oauth/authorize with no query parameters fails request validation.

    None of the required query parameters are sent, so FastAPI's request
    validation should reject the call with 422. The default error body
    carries a ``detail`` entry describing the problems.
    """
    response = client.get("/oauth/authorize")

    assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
    # FastAPI's default validation handler reports the errors under "detail".
    assert "detail" in response.json()
|
|
|
|
|
|
def test_authorize_unsupported_response_type(client):
    """Any response_type other than 'code' is rejected as unsupported."""
    query = {
        "response_type": "token",  # unsupported
        "client_id": "test_client",
        "redirect_uri": "https://example.com/callback",
    }
    resp = client.get("/oauth/authorize", params=query)

    assert resp.status_code == status.HTTP_400_BAD_REQUEST
    body = resp.json()
    assert "detail" in body
    assert body["detail"]["error"] == "unsupported_response_type"
|
|
|
|
|
|
def test_authorize_success(client, mock_db_session):
    """A valid authorization request redirects to redirect_uri with code and state.

    ``OAuthService`` is patched inside ``oauth2.controller``, and its
    ``authorize_code_flow`` mock returns a fixed code/state pair. The test
    checks both the redirect target and the arguments the controller passed
    to the service.
    """
    from urllib.parse import parse_qs, urlsplit

    with patch('oauth2.controller.OAuthService') as MockOAuthService:
        mock_service = AsyncMock()
        mock_service.authorize_code_flow.return_value = {
            "code": "auth_code_123",
            "state": "xyz",
        }
        MockOAuthService.return_value = mock_service

        response = client.get(
            "/oauth/authorize",
            params={
                "response_type": "code",
                "client_id": "test_client",
                "redirect_uri": "https://example.com/callback",
                "scope": "read write",
                "state": "xyz",
            },
            follow_redirects=False,
        )

        assert response.status_code == status.HTTP_302_FOUND
        assert "location" in response.headers
        location = response.headers["location"]
        assert location.startswith("https://example.com/callback?")
        # Parse the query string rather than substring-matching. Otherwise
        # values such as "auth_code_1234" or "xyz2" would also pass.
        query = parse_qs(urlsplit(location).query)
        assert query["code"] == ["auth_code_123"]
        assert query["state"] == ["xyz"]

        # assert_awaited_* (unlike assert_called_*) also fails if the
        # controller created the coroutine but never awaited it.
        mock_service.authorize_code_flow.assert_awaited_once_with(
            client_id="test_client",
            redirect_uri="https://example.com/callback",
            scope=["read", "write"],  # space-separated "scope" split into a list
            state="xyz",
            user_id=1,  # placeholder
        )
|
|
|
|
|
|
# ---------- Token Endpoint Tests ----------
|
|
def test_token_missing_grant_type(client):
    """POST /oauth/token with an empty form and no credentials returns 401.

    The request omits both ``grant_type`` and any client credentials. The
    expected 401 shows that the missing client authentication is the error
    the endpoint reports for this request, not the missing grant type.
    """
    resp = client.post("/oauth/token", data={})
    assert resp.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
def test_token_unsupported_grant_type(client):
    """An authenticated request using grant_type=password is rejected as unsupported."""
    resp = client.post(
        "/oauth/token",
        auth=("test_client", "secret"),  # client credentials via HTTP Basic
        data={"grant_type": "password"},  # not supported
    )

    assert resp.status_code == status.HTTP_400_BAD_REQUEST
    payload = resp.json()
    assert "detail" in payload
    assert payload["detail"]["error"] == "unsupported_grant_type"
|
|
|
|
|
|
def test_token_authorization_code_missing_code(client):
    """The authorization_code grant without a ``code`` field is an invalid_request."""
    form = {
        "grant_type": "authorization_code",
        "client_id": "test_client",
        "client_secret": "secret",
        "redirect_uri": "https://example.com/callback",
        # "code" is deliberately omitted.
    }
    resp = client.post("/oauth/token", data=form)

    assert resp.status_code == status.HTTP_400_BAD_REQUEST
    payload = resp.json()
    assert "detail" in payload
    assert payload["detail"]["error"] == "invalid_request"
|
|
|
|
|
|
def test_token_authorization_code_success(client, mock_db_session):
    """Exchanging a valid authorization code returns the service's token response.

    ``OAuthService`` is patched in ``oauth2.controller``. Its
    ``exchange_code_for_tokens`` mock returns a fixed token payload, and the
    test checks that the payload is returned to the client.
    """
    with patch('oauth2.controller.OAuthService') as MockOAuthService:
        mock_service = AsyncMock()
        mock_service.exchange_code_for_tokens.return_value = {
            "access_token": "access_token_123",
            "token_type": "Bearer",
            "expires_in": 1800,
            "refresh_token": "refresh_token_456",
            "scope": "read write",
        }
        MockOAuthService.return_value = mock_service

        response = client.post(
            "/oauth/token",
            data={
                "grant_type": "authorization_code",
                "code": "auth_code_xyz",
                "redirect_uri": "https://example.com/callback",
                "client_id": "test_client",
                "client_secret": "secret",
            },
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["access_token"] == "access_token_123"
        assert data["token_type"] == "Bearer"
        assert data["refresh_token"] == "refresh_token_456"

        # assert_awaited_* (unlike assert_called_*) also fails if the
        # controller created the coroutine but never awaited it.
        mock_service.exchange_code_for_tokens.assert_awaited_once_with(
            code="auth_code_xyz",
            client_id="test_client",
            redirect_uri="https://example.com/callback",
        )
|
|
|
|
|
|
def test_token_client_credentials_success(client, mock_db_session):
    """The client_credentials grant returns the access token produced by the service.

    ``OAuthService`` is patched in ``oauth2.controller``, and its
    ``client_credentials_flow`` mock returns a fixed token payload.
    """
    with patch('oauth2.controller.OAuthService') as MockOAuthService:
        mock_service = AsyncMock()
        mock_service.client_credentials_flow.return_value = {
            "access_token": "client_token",
            "token_type": "Bearer",
            "expires_in": 1800,
            "scope": "read",
        }
        MockOAuthService.return_value = mock_service

        response = client.post(
            "/oauth/token",
            data={
                "grant_type": "client_credentials",
                "client_id": "test_client",
                "client_secret": "secret",
                "scope": "read",
            },
        )

        assert response.status_code == status.HTTP_200_OK
        data = response.json()
        assert data["access_token"] == "client_token"

        # assert_awaited_* (unlike assert_called_*) also fails if the
        # controller created the coroutine but never awaited it.
        mock_service.client_credentials_flow.assert_awaited_once_with(
            client_id="test_client",
            client_secret="secret",
            scope=["read"],  # the posted "scope" string as a list
        )
|
|
|
|
|
|
def test_token_refresh_token_success(client, mock_db_session):
    """The refresh_token grant returns the new tokens produced by the service.

    ``OAuthService`` is patched in ``oauth2.controller``, and its
    ``refresh_token_flow`` mock returns a fixed token payload. As in the
    other token-grant tests, the response body is checked as well as the
    service call.
    """
    with patch('oauth2.controller.OAuthService') as MockOAuthService:
        mock_service = AsyncMock()
        mock_service.refresh_token_flow.return_value = {
            "access_token": "new_access_token",
            "token_type": "Bearer",
            "expires_in": 1800,
            "refresh_token": "new_refresh_token",
            "scope": "read",
        }
        MockOAuthService.return_value = mock_service

        response = client.post(
            "/oauth/token",
            data={
                "grant_type": "refresh_token",
                "refresh_token": "old_refresh_token",
                "client_id": "test_client",
                "client_secret": "secret",
            },
        )

        assert response.status_code == status.HTTP_200_OK
        # Previously only the status code was checked; make sure the rotated
        # tokens from the service actually reach the client.
        data = response.json()
        assert data["access_token"] == "new_access_token"
        assert data["refresh_token"] == "new_refresh_token"

        # assert_awaited_* (unlike assert_called_*) also fails if the
        # controller created the coroutine but never awaited it.
        mock_service.refresh_token_flow.assert_awaited_once_with(
            refresh_token="old_refresh_token",
            client_id="test_client",
            client_secret="secret",
            scope=[],  # no "scope" field was posted
        )
|
|
|
|
|
|
# ---------- UserInfo Endpoint Tests ----------
|
|
def test_userinfo_missing_token(client):
    """GET /oauth/userinfo with no Bearer token is rejected with 401."""
    resp = client.get("/oauth/userinfo")
    assert resp.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
def test_userinfo_with_valid_token(client, mock_db_session):
    """UserInfo returns claims taken from the (overridden) token payload.

    A dedicated app is built because the token-payload dependency has to be
    overridden. ``get_db`` is overridden too, with the same mock session the
    ``client`` fixture uses. Otherwise, if the userinfo route or anything it
    depends on uses ``get_db``, the test would reach a real database.
    """
    from database import get_db
    from oauth2.dependencies import get_current_token_payload

    async def mock_payload():
        return {"sub": "user1", "client_id": "client1", "scopes": ["profile"]}

    def override_get_db():
        yield mock_db_session

    app = create_test_app({
        get_current_token_payload: mock_payload,
        get_db: override_get_db,
    })
    client_with_auth = TestClient(app)

    response = client_with_auth.get("/oauth/userinfo")
    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    assert data["sub"] == "user1"
    assert data["client_id"] == "client1"
    # NOTE(review): the payload carries "scopes" (a list) but the response is
    # expected to expose "scope"; only its presence is checked here.
    assert "scope" in data
|
|
|
|
|
|
# ---------- Introspection Endpoint Tests ----------
|
|
def test_introspect_missing_authentication(client):
    """POST /oauth/introspect without client credentials is rejected with 401."""
    form = {"token": "some_token"}
    resp = client.post("/oauth/introspect", data=form)
    assert resp.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
def test_introspect_success(client, mock_db_session):
    """Introspecting a token with valid client credentials reports it as active.

    ``ClientService`` and ``TokenService`` are patched in ``oauth2.controller``.
    Client validation succeeds, and token verification returns a fixed set
    of claims.
    """
    client_service = AsyncMock()
    client_service.validate_client.return_value = True

    token_service = AsyncMock()
    token_service.verify_token.return_value = {
        "sub": "user1",
        "client_id": "client1",
        "scopes": ["read"],
        "token_type": "Bearer",
        "exp": 1234567890,
        "iat": 1234567800,
        "jti": "jti_123",
    }

    with patch('oauth2.controller.ClientService') as MockClientService:
        with patch('oauth2.controller.TokenService') as MockTokenService:
            MockClientService.return_value = client_service
            MockTokenService.return_value = token_service

            resp = client.post(
                "/oauth/introspect",
                data={"token": "valid_token"},
                auth=("test_client", "secret"),
            )

    assert resp.status_code == status.HTTP_200_OK
    body = resp.json()
    assert body["active"] is True
    assert body["sub"] == "user1"
    assert body["client_id"] == "client1"
|
|
|
|
|
|
# ---------- Revocation Endpoint Tests ----------
|
|
def test_revoke_missing_authentication(client):
    """POST /oauth/revoke without client credentials is rejected with 401."""
    form = {"token": "some_token"}
    resp = client.post("/oauth/revoke", data=form)
    assert resp.status_code == status.HTTP_401_UNAUTHORIZED
|
|
|
|
|
|
def test_revoke_success(client, mock_db_session):
    """An authenticated revocation request returns 200 with an empty body.

    ``ClientService`` and ``TokenService`` are patched in ``oauth2.controller``.
    Client validation and token revocation both report success.
    """
    client_service = AsyncMock()
    client_service.validate_client.return_value = True

    token_service = AsyncMock()
    token_service.revoke_token.return_value = True

    with patch('oauth2.controller.ClientService') as MockClientService:
        with patch('oauth2.controller.TokenService') as MockTokenService:
            MockClientService.return_value = client_service
            MockTokenService.return_value = token_service

            resp = client.post(
                "/oauth/revoke",
                data={"token": "token_to_revoke"},
                auth=("test_client", "secret"),
            )

    assert resp.status_code == status.HTTP_200_OK
    assert resp.content == b""
|
|
|
|
|
|
# ---------- OpenID Configuration Endpoint ----------
|
|
def test_openid_configuration(client):
    """The discovery document exposes the core provider metadata fields."""
    resp = client.get("/oauth/.well-known/openid-configuration")
    assert resp.status_code == status.HTTP_200_OK

    metadata = resp.json()
    required_fields = (
        "issuer",
        "authorization_endpoint",
        "token_endpoint",
        "userinfo_endpoint",
    )
    for field in required_fields:
        assert field in metadata