diff --git a/src/guardrails/checks/text/moderation.py b/src/guardrails/checks/text/moderation.py index c536669..322f709 100644 --- a/src/guardrails/checks/text/moderation.py +++ b/src/guardrails/checks/text/moderation.py @@ -27,6 +27,7 @@ from __future__ import annotations +import asyncio import logging from enum import Enum from functools import cache @@ -130,11 +131,11 @@ def _get_moderation_client() -> AsyncOpenAI: return AsyncOpenAI() -async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any: - """Call the OpenAI moderation API. +async def _call_moderation_api_async(client: Any, data: str) -> Any: + """Call the OpenAI moderation API asynchronously. Args: - client: The OpenAI client to use. + client: The async OpenAI or Azure OpenAI client to use. data: The text to analyze. Returns: @@ -146,6 +147,22 @@ async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any: ) +def _call_moderation_api_sync(client: Any, data: str) -> Any: + """Call the OpenAI moderation API synchronously. + + Args: + client: The sync OpenAI or Azure OpenAI client to use. + data: The text to analyze. + + Returns: + The moderation API response. + """ + return client.moderations.create( + model="omni-moderation-latest", + input=data, + ) + + async def moderation( ctx: Any, data: str, @@ -165,29 +182,32 @@ async def moderation( Returns: GuardrailResult: Indicates if tripwire was triggered, and details of flagged categories. """ - client = None - if ctx is not None: - candidate = getattr(ctx, "guardrail_llm", None) - if isinstance(candidate, AsyncOpenAI): - client = candidate + # Try context client first (if provided), fall back on 404 + client = getattr(ctx, "guardrail_llm", None) if ctx is not None else None - # Try the context client first, fall back if moderation endpoint doesn't exist if client is not None: + # Determine if client is async or sync + is_async = isinstance(client, AsyncOpenAI) + try: - resp = await _call_moderation_api(client, data) + if is_async: + resp = await _call_moderation_api_async(client, data) + else: + # Sync client - run in thread pool to avoid blocking event loop + resp = await asyncio.to_thread(_call_moderation_api_sync, client, data) except NotFoundError as e: - # Moderation endpoint doesn't exist on this provider (e.g., third-party) - # Fall back to the OpenAI client + # Moderation endpoint doesn't exist (e.g., Azure, third-party) + # Fall back to OpenAI client with OPENAI_API_KEY env var logger.debug( "Moderation endpoint not available on context client, falling back to OpenAI: %s", e, ) client = _get_moderation_client() - resp = await _call_moderation_api(client, data) + resp = await _call_moderation_api_async(client, data) else: - # No context client, use fallback + # No context client - use fallback OpenAI client client = _get_moderation_client() - resp = await _call_moderation_api(client, data) + resp = await _call_moderation_api_async(client, data) results = resp.results or [] if not results: return GuardrailResult( diff --git a/tests/unit/checks/test_moderation.py b/tests/unit/checks/test_moderation.py index 11351ea..f3879fd 100644 --- a/tests/unit/checks/test_moderation.py +++ b/tests/unit/checks/test_moderation.py @@ -134,3 +134,89 @@ async def raise_not_found(**_: Any) -> Any: # Verify the fallback client was used (not the third-party one) assert fallback_used is True # noqa: S101 assert result.tripwire_triggered is False # noqa: S101 + + +@pytest.mark.asyncio +async def test_moderation_uses_sync_context_client() -> None: + """Moderation should support synchronous OpenAI clients from context.""" + from openai import OpenAI + + # Track whether sync context client was used + sync_client_used = False + + def track_sync_create(**_: Any) -> Any: + nonlocal sync_client_used + sync_client_used = True + + class _Result: + def model_dump(self) -> dict[str, Any]: + return {"categories": {"hate": False, "violence": False}} + + return SimpleNamespace(results=[_Result()]) + + # Create a sync context client + sync_client = OpenAI(api_key="test-sync-key", base_url="https://api.openai.com/v1") + sync_client.moderations = SimpleNamespace(create=track_sync_create) # type: ignore[assignment] + + ctx = SimpleNamespace(guardrail_llm=sync_client) + + cfg = ModerationCfg(categories=[Category.HATE, Category.VIOLENCE]) + result = await moderation(ctx, "test text", cfg) + + # Verify the sync context client was used (via asyncio.to_thread) + assert sync_client_used is True # noqa: S101 + assert result.tripwire_triggered is False # noqa: S101 + + +@pytest.mark.asyncio +async def test_moderation_falls_back_for_azure_clients(monkeypatch: pytest.MonkeyPatch) -> None: + """Moderation should fall back to OpenAI client for Azure clients (no moderation endpoint).""" + try: + from openai import AsyncAzureOpenAI, NotFoundError + except ImportError: + pytest.skip("Azure OpenAI not available") + + # Track whether fallback was used + fallback_used = False + + async def track_fallback_create(**_: Any) -> Any: + nonlocal fallback_used + fallback_used = True + + class _Result: + def model_dump(self) -> dict[str, Any]: + return {"categories": {"hate": False, "violence": False}} + + return SimpleNamespace(results=[_Result()]) + + # Mock the fallback client + fallback_client = SimpleNamespace(moderations=SimpleNamespace(create=track_fallback_create)) + monkeypatch.setattr("guardrails.checks.text.moderation._get_moderation_client", lambda: fallback_client) + + # Create a mock httpx.Response for NotFoundError + mock_response = SimpleNamespace( + status_code=404, + headers={}, + text="404 page not found", + json=lambda: {"error": {"message": "Not found", "type": "invalid_request_error"}}, + ) + + # Create an Azure context client that raises NotFoundError for moderation + async def raise_not_found(**_: Any) -> Any: + raise NotFoundError("404 page not found", response=mock_response, body=None) # type: ignore[arg-type] + + azure_client = AsyncAzureOpenAI( + api_key="test-azure-key", + api_version="2024-02-01", + azure_endpoint="https://test.openai.azure.com", + ) + azure_client.moderations = SimpleNamespace(create=raise_not_found) # type: ignore[assignment] + + ctx = SimpleNamespace(guardrail_llm=azure_client) + + cfg = ModerationCfg(categories=[Category.HATE, Category.VIOLENCE]) + result = await moderation(ctx, "test text", cfg) + + # Verify the fallback client was used (not the Azure one) + assert fallback_used is True # noqa: S101 + assert result.tripwire_triggered is False # noqa: S101