Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 35 additions & 15 deletions src/guardrails/checks/text/moderation.py
Original file line numberDiff line numberDiff line change
Expand Up@@ -27,6 +27,7 @@

from __future__ import annotations

import asyncio
import logging
from enum import Enum
from functools import cache
Expand DownExpand Up@@ -130,11 +131,11 @@ def _get_moderation_client() -> AsyncOpenAI:
return AsyncOpenAI()


async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any:
"""Call the OpenAI moderation API.
async def _call_moderation_api_async(client: Any, data: str) -> Any:
"""Call the OpenAI moderation API asynchronously.

Args:
client: The OpenAI client to use.
client: The async OpenAI or Azure OpenAI client to use.
data: The text to analyze.

Returns:
Expand All@@ -146,6 +147,22 @@ async def _call_moderation_api(client: AsyncOpenAI, data: str) -> Any:
)


def _call_moderation_api_sync(client: Any, data: str) -> Any:
"""Call the OpenAI moderation API synchronously.

Args:
client: The sync OpenAI or Azure OpenAI client to use.
data: The text to analyze.

Returns:
The moderation API response.
"""
return client.moderations.create(
model="omni-moderation-latest",
input=data,
)


async def moderation(
ctx: Any,
data: str,
Expand All@@ -165,29 +182,32 @@ async def moderation(
Returns:
GuardrailResult: Indicates if tripwire was triggered, and details of flagged categories.
"""
client = None
if ctx is not None:
candidate = getattr(ctx, "guardrail_llm", None)
if isinstance(candidate, AsyncOpenAI):
client = candidate
# Try context client first (if provided), fall back on 404
client = getattr(ctx, "guardrail_llm", None) if ctx is not None else None

# Try the context client first, fall back if moderation endpoint doesn't exist
if client is not None:
# Determine if client is async or sync
is_async = isinstance(client, AsyncOpenAI)
Copy link

CopilotAINov 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The client type detection assumes any non-AsyncOpenAI client is synchronous, which may not be accurate for Azure clients or other variants. Consider using a more explicit check like hasattr(client.moderations.create, '__call__') and not inspect.iscoroutinefunction(client.moderations.create) similar to the pattern in llm_base.py (line 231), or import and check against specific sync client types (OpenAI, AzureOpenAI).

Copilot uses AI. Check for mistakes.
Copy link
CollaboratorAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WAI. Only OpenAI clients are valid for moderation so we will catch the other types and fallback to generating an OAI client.


try:
resp = await _call_moderation_api(client, data)
if is_async:
resp = await _call_moderation_api_async(client, data)
else:
# Sync client - run in thread pool to avoid blocking event loop
resp = await asyncio.to_thread(_call_moderation_api_sync, client, data)
except NotFoundError as e:
# Moderation endpoint doesn't exist on this provider (e.g., third-party)
# Fall back to the OpenAI client
# Moderation endpoint doesn't exist (e.g., Azure, third-party)
# Fall back to OpenAI client with OPENAI_API_KEY env var
logger.debug(
"Moderation endpoint not available on context client, falling back to OpenAI: %s",
e,
)
client = _get_moderation_client()
resp = await _call_moderation_api(client, data)
resp = await _call_moderation_api_async(client, data)
else:
# No context client, use fallback
# No context client - use fallback OpenAI client
client = _get_moderation_client()
resp = await _call_moderation_api(client, data)
resp = await _call_moderation_api_async(client, data)
results = resp.results or []
if not results:
return GuardrailResult(
Expand Down
86 changes: 86 additions & 0 deletions tests/unit/checks/test_moderation.py
Original file line numberDiff line numberDiff line change
Expand Up@@ -134,3 +134,89 @@ async def raise_not_found(**_: Any) -> Any:
# Verify the fallback client was used (not the third-party one)
assert fallback_used is True # noqa: S101
assert result.tripwire_triggered is False # noqa: S101


@pytest.mark.asyncio
async def test_moderation_uses_sync_context_client() -> None:
    """Moderation should support synchronous OpenAI clients from context."""
    from openai import OpenAI

    # Record whether the synchronous context client's endpoint was invoked.
    was_called = False

    def fake_create(**_: Any) -> Any:
        nonlocal was_called
        was_called = True

        class _Result:
            def model_dump(self) -> dict[str, Any]:
                return {"categories": {"hate": False, "violence": False}}

        return SimpleNamespace(results=[_Result()])

    # Build a real sync client, then stub out its moderations endpoint.
    client = OpenAI(api_key="test-sync-key", base_url="https://api.openai.com/v1")
    client.moderations = SimpleNamespace(create=fake_create)  # type: ignore[assignment]

    context = SimpleNamespace(guardrail_llm=client)
    config = ModerationCfg(categories=[Category.HATE, Category.VIOLENCE])

    result = await moderation(context, "test text", config)

    # The sync client must have been exercised (via asyncio.to_thread).
    assert was_called is True  # noqa: S101
    assert result.tripwire_triggered is False  # noqa: S101


@pytest.mark.asyncio
async def test_moderation_falls_back_for_azure_clients(monkeypatch: pytest.MonkeyPatch) -> None:
    """Moderation should fall back to OpenAI client for Azure clients (no moderation endpoint)."""
    try:
        from openai import AsyncAzureOpenAI, NotFoundError
    except ImportError:
        pytest.skip("Azure OpenAI not available")

    # Record whether the fallback client's endpoint was invoked.
    fallback_hit = False

    async def fallback_create(**_: Any) -> Any:
        nonlocal fallback_hit
        fallback_hit = True

        class _Result:
            def model_dump(self) -> dict[str, Any]:
                return {"categories": {"hate": False, "violence": False}}

        return SimpleNamespace(results=[_Result()])

    # Patch the module-level fallback factory so it hands back our stub client.
    stub_fallback = SimpleNamespace(moderations=SimpleNamespace(create=fallback_create))
    monkeypatch.setattr("guardrails.checks.text.moderation._get_moderation_client", lambda: stub_fallback)

    # Minimal stand-in for the httpx.Response that NotFoundError expects.
    fake_response = SimpleNamespace(
        status_code=404,
        headers={},
        text="404 page not found",
        json=lambda: {"error": {"message": "Not found", "type": "invalid_request_error"}},
    )

    # The Azure client's moderation endpoint raises 404, triggering fallback.
    async def raise_not_found(**_: Any) -> Any:
        raise NotFoundError("404 page not found", response=fake_response, body=None)  # type: ignore[arg-type]

    azure = AsyncAzureOpenAI(
        api_key="test-azure-key",
        api_version="2024-02-01",
        azure_endpoint="https://test.openai.azure.com",
    )
    azure.moderations = SimpleNamespace(create=raise_not_found)  # type: ignore[assignment]

    context = SimpleNamespace(guardrail_llm=azure)
    config = ModerationCfg(categories=[Category.HATE, Category.VIOLENCE])

    result = await moderation(context, "test text", config)

    # The Azure client must have been bypassed in favor of the fallback.
    assert fallback_hit is True  # noqa: S101
    assert result.tripwire_triggered is False  # noqa: S101
Loading