mockapi/tests/test_oauth2_controller.py
2026-03-16 10:49:01 +00:00

337 lines
No EOL
12 KiB
Python

"""
Unit tests for OAuth2 controller endpoints.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from fastapi import FastAPI, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.modules.oauth2.controller import router as oauth_router
def create_test_app(override_dependency=None) -> FastAPI:
    """Build a bare FastAPI app that mounts only the OAuth2 router.

    Args:
        override_dependency: Optional mapping of ``dependency -> replacement``
            that is merged into ``app.dependency_overrides``. ``None`` or an
            empty mapping leaves the app without overrides.

    Returns:
        The freshly constructed FastAPI application.
    """
    test_app = FastAPI()
    test_app.include_router(oauth_router)
    # Falsy input (None / empty mapping) degrades to a no-op merge.
    test_app.dependency_overrides.update(override_dependency or {})
    return test_app
@pytest.fixture
def mock_db_session():
    """Provide an ``AsyncMock`` restricted to the ``AsyncSession`` interface.

    Using ``spec=AsyncSession`` means attributes that do not exist on
    ``AsyncSession`` raise ``AttributeError`` instead of silently
    producing new child mocks.
    """
    return AsyncMock(spec=AsyncSession)
@pytest.fixture
def client(mock_db_session):
    """Return a TestClient whose ``get_db`` dependency yields ``mock_db_session``."""
    from app.core.database import get_db

    def _yield_mock_session():
        # Written as a generator so FastAPI treats it as a yield dependency.
        yield mock_db_session

    return TestClient(create_test_app({get_db: _yield_mock_session}))
# ---------- Authorization Endpoint Tests ----------
def test_authorize_missing_parameters(client):
    """GET /oauth/authorize without required parameters is a validation error.

    Expects HTTP 422 and FastAPI's default validation payload, which lists the
    individual errors under ``detail``.
    """
    response = client.get("/oauth/authorize")

    # Literal 422 rather than a ``status`` constant: newer Starlette releases
    # deprecate HTTP_422_UNPROCESSABLE_ENTITY (renamed *_UNPROCESSABLE_CONTENT),
    # and the bare number is valid on every version.
    assert response.status_code == 422

    # The test app is a bare FastAPI instance, so the default
    # RequestValidationError handler applies: {"detail": [error, ...]}.
    detail = response.json()["detail"]
    assert isinstance(detail, list)
    assert detail, "expected at least one validation error to be reported"
def test_authorize_unsupported_response_type(client):
    """Any response_type other than ``code`` is rejected with an OAuth2 error."""
    query = {
        "response_type": "token",  # implicit-grant style request: unsupported
        "client_id": "test_client",
        "redirect_uri": "https://example.com/callback",
    }
    response = client.get("/oauth/authorize", params=query)

    assert response.status_code == status.HTTP_400_BAD_REQUEST
    body = response.json()
    assert "detail" in body
    assert body["detail"]["error"] == "unsupported_response_type"
def test_authorize_success(client, mock_db_session):
    """A valid authorization request redirects back with ``code`` and ``state``.

    ``OAuthService`` is patched in the controller module, so only the
    controller's request parsing and redirect construction are exercised.
    """
    from urllib.parse import parse_qs, urlsplit

    with patch("app.modules.oauth2.controller.OAuthService") as MockOAuthService:
        mock_service = AsyncMock()
        mock_service.authorize_code_flow.return_value = {
            "code": "auth_code_123",
            "state": "xyz",
        }
        MockOAuthService.return_value = mock_service

        response = client.get(
            "/oauth/authorize",
            params={
                "response_type": "code",
                "client_id": "test_client",
                "redirect_uri": "https://example.com/callback",
                "scope": "read write",
                "state": "xyz",
            },
            follow_redirects=False,
        )

    assert response.status_code == status.HTTP_302_FOUND
    assert "location" in response.headers

    # Parse the redirect instead of substring-matching it, so that e.g.
    # "state=xyz" cannot accidentally match "state=xyzzy".
    location = urlsplit(response.headers["location"])
    assert (location.scheme, location.netloc, location.path) == (
        "https",
        "example.com",
        "/callback",
    )
    query = parse_qs(location.query)
    assert query.get("code") == ["auth_code_123"]
    assert query.get("state") == ["xyz"]

    # The space-separated scope string must reach the service as a list, and
    # the coroutine must actually have been awaited (not merely called).
    mock_service.authorize_code_flow.assert_awaited_once_with(
        client_id="test_client",
        redirect_uri="https://example.com/callback",
        scope=["read", "write"],
        state="xyz",
        user_id=1,  # placeholder user; presumably hard-coded in the controller — confirm
    )
# ---------- Token Endpoint Tests ----------
def test_token_missing_grant_type(client):
    """POST /oauth/token with an empty form and no client credentials fails.

    Despite the name, what this pins down is the missing client
    authentication: the expected outcome is 401, not a grant_type
    validation error.
    """
    empty_form = {}
    response = client.post("/oauth/token", data=empty_form)
    assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_token_unsupported_grant_type(client):
    """An authenticated request with an unsupported grant_type gets an OAuth2 error."""
    credentials = ("test_client", "secret")  # sent as HTTP Basic auth
    form = {"grant_type": "password"}  # resource-owner password grant: not supported
    response = client.post("/oauth/token", data=form, auth=credentials)

    assert response.status_code == status.HTTP_400_BAD_REQUEST
    body = response.json()
    assert "detail" in body
    assert body["detail"]["error"] == "unsupported_grant_type"
def test_token_authorization_code_missing_code(client):
    """The authorization_code grant without a ``code`` field is an invalid_request."""
    form_without_code = {
        "grant_type": "authorization_code",
        "client_id": "test_client",
        "client_secret": "secret",
        "redirect_uri": "https://example.com/callback",
    }
    response = client.post("/oauth/token", data=form_without_code)

    assert response.status_code == status.HTTP_400_BAD_REQUEST
    body = response.json()
    assert "detail" in body
    assert body["detail"]["error"] == "invalid_request"
def test_token_authorization_code_success(client, mock_db_session):
    """Exchanging a valid authorization code yields an access/refresh token pair."""
    issued_tokens = {
        "access_token": "access_token_123",
        "token_type": "Bearer",
        "expires_in": 1800,
        "refresh_token": "refresh_token_456",
        "scope": "read write",
    }
    form = {
        "grant_type": "authorization_code",
        "code": "auth_code_xyz",
        "redirect_uri": "https://example.com/callback",
        "client_id": "test_client",
        "client_secret": "secret",
    }

    with patch("app.modules.oauth2.controller.OAuthService") as service_cls:
        fake_service = AsyncMock()
        fake_service.exchange_code_for_tokens.return_value = issued_tokens
        service_cls.return_value = fake_service
        response = client.post("/oauth/token", data=form)

    assert response.status_code == status.HTTP_200_OK
    payload = response.json()
    assert payload["access_token"] == "access_token_123"
    assert payload["token_type"] == "Bearer"
    assert payload["refresh_token"] == "refresh_token_456"

    # The client_secret is not forwarded to the exchange call.
    fake_service.exchange_code_for_tokens.assert_called_once_with(
        code="auth_code_xyz",
        client_id="test_client",
        redirect_uri="https://example.com/callback",
    )
def test_token_client_credentials_success(client, mock_db_session):
    """The client_credentials grant returns an access token for the client itself."""
    form = {
        "grant_type": "client_credentials",
        "client_id": "test_client",
        "client_secret": "secret",
        "scope": "read",
    }

    with patch("app.modules.oauth2.controller.OAuthService") as service_cls:
        fake_service = AsyncMock()
        fake_service.client_credentials_flow.return_value = {
            "access_token": "client_token",
            "token_type": "Bearer",
            "expires_in": 1800,
            "scope": "read",
        }
        service_cls.return_value = fake_service
        response = client.post("/oauth/token", data=form)

    assert response.status_code == status.HTTP_200_OK
    assert response.json()["access_token"] == "client_token"

    # The single-word scope string is expected to arrive as a one-item list.
    fake_service.client_credentials_flow.assert_called_once_with(
        client_id="test_client",
        client_secret="secret",
        scope=["read"],
    )
def test_token_refresh_token_success(client, mock_db_session):
    """The refresh_token grant returns a new access/refresh token pair.

    Verifies both the response body (which the previous version never
    inspected) and the arguments forwarded to ``OAuthService.refresh_token_flow``.
    """
    with patch("app.modules.oauth2.controller.OAuthService") as MockOAuthService:
        mock_service = AsyncMock()
        mock_service.refresh_token_flow.return_value = {
            "access_token": "new_access_token",
            "token_type": "Bearer",
            "expires_in": 1800,
            "refresh_token": "new_refresh_token",
            "scope": "read",
        }
        MockOAuthService.return_value = mock_service

        response = client.post(
            "/oauth/token",
            data={
                "grant_type": "refresh_token",
                "refresh_token": "old_refresh_token",
                "client_id": "test_client",
                "client_secret": "secret",
            },
        )

    assert response.status_code == status.HTTP_200_OK

    # The endpoint must hand the rotated tokens back to the caller,
    # mirroring the checks in the authorization_code success test.
    data = response.json()
    assert data["access_token"] == "new_access_token"
    assert data["token_type"] == "Bearer"
    assert data["refresh_token"] == "new_refresh_token"

    # No ``scope`` field was posted, so an empty scope list is expected.
    mock_service.refresh_token_flow.assert_called_once_with(
        refresh_token="old_refresh_token",
        client_id="test_client",
        client_secret="secret",
        scope=[],
    )
# ---------- UserInfo Endpoint Tests ----------
def test_userinfo_missing_token(client):
    """GET /oauth/userinfo without a Bearer token is unauthorized."""
    unauthenticated = client.get("/oauth/userinfo")
    assert unauthenticated.status_code == status.HTTP_401_UNAUTHORIZED
def test_userinfo_with_valid_token(client, mock_db_session):
    """UserInfo returns claims taken from the (overridden) token payload.

    A dedicated app is built because token authentication must be stubbed.
    ``get_db`` is overridden as well, as the shared ``client`` fixture does,
    so the test never reaches a real database even if the userinfo path
    depends on a session.
    """
    from app.core.database import get_db
    from app.modules.oauth2.dependencies import get_current_token_payload

    def override_get_db():
        yield mock_db_session

    async def mock_payload():
        return {"sub": "user1", "client_id": "client1", "scopes": ["profile"]}

    app = create_test_app(
        {
            get_db: override_get_db,
            get_current_token_payload: mock_payload,
        }
    )
    client_with_auth = TestClient(app)

    response = client_with_auth.get("/oauth/userinfo")

    assert response.status_code == status.HTTP_200_OK
    data = response.json()
    assert data["sub"] == "user1"
    assert data["client_id"] == "client1"
    # The payload carries "scopes"; the response is expected to expose "scope".
    assert "scope" in data
# ---------- Introspection Endpoint Tests ----------
def test_introspect_missing_authentication(client):
    """Token introspection without client credentials is unauthorized."""
    form = {"token": "some_token"}
    reply = client.post("/oauth/introspect", data=form)
    assert reply.status_code == status.HTTP_401_UNAUTHORIZED
def test_introspect_success(client, mock_db_session):
    """Introspecting a verifiable token as an authenticated client reports it active."""
    token_claims = {
        "sub": "user1",
        "client_id": "client1",
        "scopes": ["read"],
        "token_type": "Bearer",
        "exp": 1234567890,
        "iat": 1234567800,
        "jti": "jti_123",
    }

    # Stub both collaborators: client validation always succeeds, and token
    # verification yields the fixed claim set above.
    with patch("app.modules.oauth2.controller.ClientService") as client_service_cls:
        with patch("app.modules.oauth2.controller.TokenService") as token_service_cls:
            fake_clients = AsyncMock()
            fake_clients.validate_client.return_value = True
            client_service_cls.return_value = fake_clients

            fake_tokens = AsyncMock()
            fake_tokens.verify_token.return_value = token_claims
            token_service_cls.return_value = fake_tokens

            response = client.post(
                "/oauth/introspect",
                data={"token": "valid_token"},
                auth=("test_client", "secret"),
            )

    assert response.status_code == status.HTTP_200_OK
    body = response.json()
    assert body["active"] is True
    assert body["sub"] == "user1"
    assert body["client_id"] == "client1"
# ---------- Revocation Endpoint Tests ----------
def test_revoke_missing_authentication(client):
    """Token revocation without client credentials is unauthorized."""
    form = {"token": "some_token"}
    reply = client.post("/oauth/revoke", data=form)
    assert reply.status_code == status.HTTP_401_UNAUTHORIZED
def test_revoke_success(client, mock_db_session):
    """A successful revocation answers 200 with an empty body."""
    with patch("app.modules.oauth2.controller.ClientService") as client_service_cls:
        with patch("app.modules.oauth2.controller.TokenService") as token_service_cls:
            fake_clients = AsyncMock()
            fake_clients.validate_client.return_value = True
            client_service_cls.return_value = fake_clients

            fake_tokens = AsyncMock()
            fake_tokens.revoke_token.return_value = True
            token_service_cls.return_value = fake_tokens

            response = client.post(
                "/oauth/revoke",
                data={"token": "token_to_revoke"},
                auth=("test_client", "secret"),
            )

    assert response.status_code == status.HTTP_200_OK
    assert response.content == b""
# ---------- OpenID Configuration Endpoint ----------
def test_openid_configuration(client):
    """The OIDC discovery document advertises the core provider endpoints."""
    response = client.get("/oauth/.well-known/openid-configuration")
    assert response.status_code == status.HTTP_200_OK

    metadata = response.json()
    required_keys = (
        "issuer",
        "authorization_endpoint",
        "token_endpoint",
        "userinfo_endpoint",
    )
    for key in required_keys:
        assert key in metadata