From 514c197e2fcea99dac757033d3d227b098976077 Mon Sep 17 00:00:00 2001 From: Steven C Date: Thu, 30 Oct 2025 16:26:12 -0400 Subject: [PATCH 1/3] fixing safety identifier --- src/guardrails/_openai_utils.py | 25 ----- src/guardrails/agents.py | 5 +- src/guardrails/checks/text/llm_base.py | 49 +++++++-- src/guardrails/checks/text/moderation.py | 4 +- src/guardrails/client.py | 9 +- src/guardrails/evals/guardrail_evals.py | 5 +- src/guardrails/resources/chat/chat.py | 62 +++++++++-- .../resources/responses/responses.py | 100 ++++++++++++++---- src/guardrails/runtime.py | 3 +- src/guardrails/utils/create_vector_store.py | 4 +- tests/unit/test_agents.py | 9 +- tests/unit/test_openai_utils.py | 33 ------ tests/unit/test_safety_identifier.py | 73 +++++++++++++ 13 files changed, 259 insertions(+), 122 deletions(-) delete mode 100644 src/guardrails/_openai_utils.py delete mode 100644 tests/unit/test_openai_utils.py create mode 100644 tests/unit/test_safety_identifier.py diff --git a/src/guardrails/_openai_utils.py b/src/guardrails/_openai_utils.py deleted file mode 100644 index 6121d03..0000000 --- a/src/guardrails/_openai_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Utilities for configuring OpenAI clients used by guardrails.""" - -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any - -SAFETY_IDENTIFIER_HEADER = "OpenAI-Safety-Identifier" -SAFETY_IDENTIFIER_VALUE = "oai_guardrails" - - -def ensure_safety_identifier_header(default_headers: Mapping[str, str] | None) -> dict[str, str]: - """Return headers with the Guardrails safety identifier applied.""" - headers = dict(default_headers or {}) - headers[SAFETY_IDENTIFIER_HEADER] = SAFETY_IDENTIFIER_VALUE - return headers - - -def prepare_openai_kwargs(openai_kwargs: dict[str, Any]) -> dict[str, Any]: - """Return OpenAI constructor kwargs that include the safety identifier header.""" - prepared_kwargs = dict(openai_kwargs) - default_headers = prepared_kwargs.get("default_headers") - headers = ensure_safety_identifier_header(default_headers if isinstance(default_headers, Mapping) else None) - prepared_kwargs["default_headers"] = headers - return prepared_kwargs diff --git a/src/guardrails/agents.py b/src/guardrails/agents.py index 0dbe077..4f3202c 100644 --- a/src/guardrails/agents.py +++ b/src/guardrails/agents.py @@ -18,7 +18,6 @@ from pathlib import Path from typing import Any -from ._openai_utils import prepare_openai_kwargs from .utils.conversation import merge_conversation_with_items, normalize_conversation logger = logging.getLogger(__name__) @@ -167,7 +166,7 @@ def _create_default_tool_context() -> Any: class DefaultContext: guardrail_llm: AsyncOpenAI - return DefaultContext(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({}))) + return DefaultContext(guardrail_llm=AsyncOpenAI()) def _create_conversation_context( @@ -393,7 +392,7 @@ def _create_agents_guardrails_from_config( class DefaultContext: guardrail_llm: AsyncOpenAI - context = DefaultContext(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({}))) + context = DefaultContext(guardrail_llm=AsyncOpenAI()) def _create_stage_guardrail(stage_name: str): async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput: diff --git a/src/guardrails/checks/text/llm_base.py b/src/guardrails/checks/text/llm_base.py index 5a843b2..bfa4ed2 100644 --- a/src/guardrails/checks/text/llm_base.py +++ b/src/guardrails/checks/text/llm_base.py @@ -48,6 +48,37 @@ class MyLLMOutput(LLMOutput): from guardrails.types 
import CheckFn, GuardrailLLMContextProto, GuardrailResult from guardrails.utils.output import OutputSchema +# OpenAI safety identifier for tracking guardrails library usage +# Only supported by official OpenAI API (not Azure or local/alternative providers) +_SAFETY_IDENTIFIER = "oai_guardrails" + + +def _supports_safety_identifier(client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI) -> bool: + """Check if the client supports the safety_identifier parameter. + + Only the official OpenAI API supports this parameter. + Azure OpenAI and local/alternative providers do not. + + Args: + client: The OpenAI client instance. + + Returns: + True if safety_identifier should be included, False otherwise. + """ + # Azure clients don't support it + if isinstance(client, AsyncAzureOpenAI | AzureOpenAI): + return False + + # Check if using a custom base_url (local or alternative provider) + base_url = getattr(client, "base_url", None) + if base_url is not None: + base_url_str = str(base_url) + # Only official OpenAI API endpoints support safety_identifier + return "api.openai.com" in base_url_str + + # Default OpenAI client (no custom base_url) supports it + return True + if TYPE_CHECKING: from openai import AsyncAzureOpenAI, AzureOpenAI # type: ignore[unused-import] else: @@ -247,12 +278,18 @@ async def _request_chat_completion( response_format: dict[str, Any], ) -> Any: """Invoke chat.completions.create on sync or async OpenAI clients.""" - return await _invoke_openai_callable( - client.chat.completions.create, - messages=messages, - model=model, - response_format=response_format, - ) + # Only include safety_identifier for official OpenAI API + kwargs: dict[str, Any] = { + "messages": messages, + "model": model, + "response_format": response_format, + } + + # Only official OpenAI API supports safety_identifier (not Azure or local models) + if _supports_safety_identifier(client): + kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + + return await _invoke_openai_callable(client.chat.completions.create, **kwargs) async def run_llm( diff --git a/src/guardrails/checks/text/moderation.py b/src/guardrails/checks/text/moderation.py index 66f8bb1..cd0a353 100644 --- a/src/guardrails/checks/text/moderation.py +++ b/src/guardrails/checks/text/moderation.py @@ -39,8 +39,6 @@ from guardrails.spec import GuardrailSpecMetadata from guardrails.types import GuardrailResult -from ..._openai_utils import prepare_openai_kwargs - logger = logging.getLogger(__name__) __all__ = ["moderation", "Category", "ModerationCfg"] @@ -129,7 +127,7 @@ def _get_moderation_client() -> AsyncOpenAI: Returns: AsyncOpenAI: Cached OpenAI API client for moderation checks. """ - return AsyncOpenAI(**prepare_openai_kwargs({})) + return AsyncOpenAI() async def moderation( diff --git a/src/guardrails/client.py b/src/guardrails/client.py index 01bcb9d..208b2d4 100644 --- a/src/guardrails/client.py +++ b/src/guardrails/client.py @@ -26,7 +26,6 @@ GuardrailsResponse, OpenAIResponseType, ) -from ._openai_utils import prepare_openai_kwargs from ._streaming import StreamingMixin from .exceptions import GuardrailTripwireTriggered from .runtime import run_guardrails @@ -167,7 +166,6 @@ def __init__( by this parameter. **openai_kwargs: Additional arguments passed to AsyncOpenAI constructor. 
""" - openai_kwargs = prepare_openai_kwargs(openai_kwargs) # Initialize OpenAI client first super().__init__(**openai_kwargs) @@ -205,7 +203,7 @@ class DefaultContext: default_headers = getattr(self, "default_headers", None) if default_headers is not None: guardrail_kwargs["default_headers"] = default_headers - guardrail_client = AsyncOpenAI(**prepare_openai_kwargs(guardrail_kwargs)) + guardrail_client = AsyncOpenAI(**guardrail_kwargs) return DefaultContext(guardrail_llm=guardrail_client) @@ -335,7 +333,6 @@ def __init__( by this parameter. **openai_kwargs: Additional arguments passed to OpenAI constructor. """ - openai_kwargs = prepare_openai_kwargs(openai_kwargs) # Initialize OpenAI client first super().__init__(**openai_kwargs) @@ -373,7 +370,7 @@ class DefaultContext: default_headers = getattr(self, "default_headers", None) if default_headers is not None: guardrail_kwargs["default_headers"] = default_headers - guardrail_client = OpenAI(**prepare_openai_kwargs(guardrail_kwargs)) + guardrail_client = OpenAI(**guardrail_kwargs) return DefaultContext(guardrail_llm=guardrail_client) @@ -516,7 +513,6 @@ def __init__( by this parameter. **azure_kwargs: Additional arguments passed to AsyncAzureOpenAI constructor. """ - azure_kwargs = prepare_openai_kwargs(azure_kwargs) # Initialize Azure client first super().__init__(**azure_kwargs) @@ -671,7 +667,6 @@ def __init__( by this parameter. **azure_kwargs: Additional arguments passed to AzureOpenAI constructor. """ - azure_kwargs = prepare_openai_kwargs(azure_kwargs) super().__init__(**azure_kwargs) # Store the error handling preference diff --git a/src/guardrails/evals/guardrail_evals.py b/src/guardrails/evals/guardrail_evals.py index e86aee4..a2c2ae6 100644 --- a/src/guardrails/evals/guardrail_evals.py +++ b/src/guardrails/evals/guardrail_evals.py @@ -23,7 +23,6 @@ from guardrails import instantiate_guardrails, load_pipeline_bundles -from guardrails._openai_utils import prepare_openai_kwargs from guardrails.evals.core import ( AsyncRunEngine, BenchmarkMetricsCalculator, @@ -281,7 +280,7 @@ def _create_context(self) -> Context: if self.api_key: azure_kwargs["api_key"] = self.api_key - guardrail_llm = AsyncAzureOpenAI(**prepare_openai_kwargs(azure_kwargs)) + guardrail_llm = AsyncAzureOpenAI(**azure_kwargs) logger.info("Created Azure OpenAI client for endpoint: %s", self.azure_endpoint) # OpenAI or OpenAI-compatible API else: @@ -292,7 +291,7 @@ def _create_context(self) -> Context: openai_kwargs["base_url"] = self.base_url logger.info("Created OpenAI-compatible client for base_url: %s", self.base_url) - guardrail_llm = AsyncOpenAI(**prepare_openai_kwargs(openai_kwargs)) + guardrail_llm = AsyncOpenAI(**openai_kwargs) return Context(guardrail_llm=guardrail_llm) diff --git a/src/guardrails/resources/chat/chat.py b/src/guardrails/resources/chat/chat.py index aa0382e..c311287 100644 --- a/src/guardrails/resources/chat/chat.py +++ b/src/guardrails/resources/chat/chat.py @@ -7,6 +7,38 @@ from ..._base_client import GuardrailsBaseClient +# OpenAI safety identifier for tracking guardrails library usage +# Only supported by official OpenAI API (not Azure or local/alternative providers) +_SAFETY_IDENTIFIER = "oai_guardrails" + + +def _supports_safety_identifier(client: Any) -> bool: + """Check if the client supports the safety_identifier parameter. + + Only the official OpenAI API supports this parameter. + Azure OpenAI and local/alternative providers do not. + + Args: + client: The OpenAI client instance. 
+ + Returns: + True if safety_identifier should be included, False otherwise. + """ + # Azure clients don't support it + client_type = type(client).__name__ + if "Azure" in client_type: + return False + + # Check if using a custom base_url (local or alternative provider) + base_url = getattr(client, "base_url", None) + if base_url is not None: + base_url_str = str(base_url) + # Only official OpenAI API endpoints support safety_identifier + return "api.openai.com" in base_url_str + + # Default OpenAI client (no custom base_url) supports it + return True + class Chat: """Chat completions with guardrails (sync).""" @@ -82,12 +114,19 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals # Run input guardrails and LLM call concurrently using a thread for the LLM with ThreadPoolExecutor(max_workers=1) as executor: + # Only include safety_identifier for OpenAI clients (not Azure) + llm_kwargs = { + "messages": modified_messages, + "model": model, + "stream": stream, + **kwargs, + } + if _supports_safety_identifier(self._client._resource_client): + llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + llm_future = executor.submit( self._client._resource_client.chat.completions.create, - messages=modified_messages, # Use messages with any preflight modifications - model=model, - stream=stream, - **kwargs, + **llm_kwargs, ) input_results = self._client._run_stage_guardrails( "input", @@ -152,12 +191,17 @@ async def create( conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) - llm_call = self._client._resource_client.chat.completions.create( - messages=modified_messages, # Use messages with any preflight modifications - model=model, - stream=stream, + # Only include safety_identifier for OpenAI clients (not Azure) + llm_kwargs = { + "messages": modified_messages, + "model": model, + "stream": stream, **kwargs, - ) + } + if _supports_safety_identifier(self._client._resource_client): + llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + + llm_call = self._client._resource_client.chat.completions.create(**llm_kwargs) input_results, llm_response = await asyncio.gather(input_check, llm_call) diff --git a/src/guardrails/resources/responses/responses.py b/src/guardrails/resources/responses/responses.py index 89a84f7..f56920b 100644 --- a/src/guardrails/resources/responses/responses.py +++ b/src/guardrails/resources/responses/responses.py @@ -9,6 +9,38 @@ from ..._base_client import GuardrailsBaseClient +# OpenAI safety identifier for tracking guardrails library usage +# Only supported by official OpenAI API (not Azure or local/alternative providers) +_SAFETY_IDENTIFIER = "oai_guardrails" + + +def _supports_safety_identifier(client: Any) -> bool: + """Check if the client supports the safety_identifier parameter. + + Only the official OpenAI API supports this parameter. + Azure OpenAI and local/alternative providers do not. + + Args: + client: The OpenAI client instance. + + Returns: + True if safety_identifier should be included, False otherwise. 
+ """ + # Azure clients don't support it + client_type = type(client).__name__ + if "Azure" in client_type: + return False + + # Check if using a custom base_url (local or alternative provider) + base_url = getattr(client, "base_url", None) + if base_url is not None: + base_url_str = str(base_url) + # Only official OpenAI API endpoints support safety_identifier + return "api.openai.com" in base_url_str + + # Default OpenAI client (no custom base_url) supports it + return True + class Responses: """Responses API with guardrails (sync).""" @@ -63,13 +95,20 @@ def create( # Input guardrails and LLM call concurrently with ThreadPoolExecutor(max_workers=1) as executor: + # Only include safety_identifier for OpenAI clients (not Azure or local models) + llm_kwargs = { + "input": modified_input, + "model": model, + "stream": stream, + "tools": tools, + **kwargs, + } + if _supports_safety_identifier(self._client._resource_client): + llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + llm_future = executor.submit( self._client._resource_client.responses.create, - input=modified_input, # Use preflight-modified input - model=model, - stream=stream, - tools=tools, - **kwargs, + **llm_kwargs, ) input_results = self._client._run_stage_guardrails( "input", @@ -123,12 +162,19 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM # Input guardrails and LLM call concurrently with ThreadPoolExecutor(max_workers=1) as executor: + # Only include safety_identifier for OpenAI clients (not Azure or local models) + llm_kwargs = { + "input": modified_input, + "model": model, + "text_format": text_format, + **kwargs, + } + if _supports_safety_identifier(self._client._resource_client): + llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + llm_future = executor.submit( self._client._resource_client.responses.parse, - input=modified_input, # Use modified input with preflight changes - model=model, - text_format=text_format, - **kwargs, + **llm_kwargs, ) input_results = self._client._run_stage_guardrails( "input", @@ -218,13 +264,19 @@ async def create( conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) - llm_call = self._client._resource_client.responses.create( - input=modified_input, # Use preflight-modified input - model=model, - stream=stream, - tools=tools, + + # Only include safety_identifier for OpenAI clients (not Azure or local models) + llm_kwargs = { + "input": modified_input, + "model": model, + "stream": stream, + "tools": tools, **kwargs, - ) + } + if _supports_safety_identifier(self._client._resource_client): + llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + + llm_call = self._client._resource_client.responses.create(**llm_kwargs) input_results, llm_response = await asyncio.gather(input_check, llm_call) @@ -278,13 +330,19 @@ async def parse( conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) - llm_call = self._client._resource_client.responses.parse( - input=modified_input, # Use modified input with preflight changes - model=model, - text_format=text_format, - stream=stream, + + # Only include safety_identifier for OpenAI clients (not Azure or local models) + llm_kwargs = { + "input": modified_input, + "model": model, + "text_format": text_format, + "stream": stream, **kwargs, - ) + } + if _supports_safety_identifier(self._client._resource_client): + llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + + llm_call = self._client._resource_client.responses.parse(**llm_kwargs) input_results, 
llm_response = await asyncio.gather(input_check, llm_call) diff --git a/src/guardrails/runtime.py b/src/guardrails/runtime.py index 68948a5..dea412c 100644 --- a/src/guardrails/runtime.py +++ b/src/guardrails/runtime.py @@ -21,7 +21,6 @@ from openai import AsyncOpenAI from pydantic import BaseModel, ConfigDict -from ._openai_utils import prepare_openai_kwargs from .exceptions import ConfigError, GuardrailTripwireTriggered from .registry import GuardrailRegistry, default_spec_registry from .spec import GuardrailSpec @@ -495,7 +494,7 @@ def _get_default_ctx(): class DefaultCtx: guardrail_llm: AsyncOpenAI - return DefaultCtx(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({}))) + return DefaultCtx(guardrail_llm=AsyncOpenAI()) async def check_plain_text( diff --git a/src/guardrails/utils/create_vector_store.py b/src/guardrails/utils/create_vector_store.py index a27d0b8..3add976 100644 --- a/src/guardrails/utils/create_vector_store.py +++ b/src/guardrails/utils/create_vector_store.py @@ -15,8 +15,6 @@ from openai import AsyncOpenAI -from .._openai_utils import prepare_openai_kwargs - # Configure logging # logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') logger = logging.getLogger(__name__) @@ -136,7 +134,7 @@ async def main(): path = sys.argv[1] try: - client = AsyncOpenAI(**prepare_openai_kwargs({})) + client = AsyncOpenAI() vector_store_id = await create_vector_store_from_path(path, client) print("\n✅ Vector store created successfully!") diff --git a/tests/unit/test_agents.py b/tests/unit/test_agents.py index d9202fe..ea96c33 100644 --- a/tests/unit/test_agents.py +++ b/tests/unit/test_agents.py @@ -11,7 +11,6 @@ import pytest -from guardrails._openai_utils import SAFETY_IDENTIFIER_HEADER, SAFETY_IDENTIFIER_VALUE from guardrails.types import GuardrailResult # --------------------------------------------------------------------------- @@ -231,14 +230,12 @@ def test_create_conversation_context_exposes_history() -> None: def test_create_default_tool_context_provides_async_client(monkeypatch: pytest.MonkeyPatch) -> None: - """Default tool context should return AsyncOpenAI with safety identifier header.""" - captured_kwargs: dict[str, Any] = {} - + """Default tool context should return AsyncOpenAI client.""" openai_mod = types.ModuleType("openai") class StubAsyncOpenAI: def __init__(self, **kwargs: Any) -> None: - captured_kwargs.update(kwargs) + pass openai_mod.AsyncOpenAI = StubAsyncOpenAI monkeypatch.setitem(sys.modules, "openai", openai_mod) @@ -246,8 +243,6 @@ def __init__(self, **kwargs: Any) -> None: context = agents._create_default_tool_context() assert isinstance(context.guardrail_llm, StubAsyncOpenAI) # noqa: S101 - headers = captured_kwargs.get("default_headers", {}) - assert headers.get(SAFETY_IDENTIFIER_HEADER) == SAFETY_IDENTIFIER_VALUE # noqa: S101 def test_attach_guardrail_to_tools_initializes_lists() -> None: diff --git a/tests/unit/test_openai_utils.py b/tests/unit/test_openai_utils.py deleted file mode 100644 index a3c402c..0000000 --- a/tests/unit/test_openai_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Tests for OpenAI client helper utilities.""" - -from guardrails._openai_utils import SAFETY_IDENTIFIER_HEADER, SAFETY_IDENTIFIER_VALUE, ensure_safety_identifier_header, prepare_openai_kwargs - - -def test_prepare_openai_kwargs_adds_safety_identifier() -> None: - """Default kwargs gain the Guardrails safety identifier.""" - result = prepare_openai_kwargs({}) - headers = result["default_headers"] - assert headers[SAFETY_IDENTIFIER_HEADER] == 
SAFETY_IDENTIFIER_VALUE # noqa: S101 - - -def test_prepare_openai_kwargs_overrides_existing_identifier() -> None: - """Existing identifier value is overwritten with Guardrails tag.""" - kwargs = {"default_headers": {SAFETY_IDENTIFIER_HEADER: "custom", "X-Test": "value"}} - result = prepare_openai_kwargs(kwargs) - headers = result["default_headers"] - assert headers["X-Test"] == "value" # noqa: S101 - assert headers[SAFETY_IDENTIFIER_HEADER] == SAFETY_IDENTIFIER_VALUE # noqa: S101 - - -def test_prepare_openai_kwargs_handles_non_mapping_as_none() -> None: - """Non-mapping default headers fall back to an empty mapping.""" - result = prepare_openai_kwargs({"default_headers": object()}) - headers = result["default_headers"] - assert headers == {SAFETY_IDENTIFIER_HEADER: SAFETY_IDENTIFIER_VALUE} # noqa: S101 - - -def test_ensure_safety_identifier_header_adds_identifier() -> None: - """ensure_safety_identifier_header augments mappings.""" - headers = ensure_safety_identifier_header({"X-Test": "value"}) - assert headers["X-Test"] == "value" # noqa: S101 - assert headers[SAFETY_IDENTIFIER_HEADER] == SAFETY_IDENTIFIER_VALUE # noqa: S101 diff --git a/tests/unit/test_safety_identifier.py b/tests/unit/test_safety_identifier.py new file mode 100644 index 0000000..ef94a0b --- /dev/null +++ b/tests/unit/test_safety_identifier.py @@ -0,0 +1,73 @@ +"""Tests for safety_identifier parameter handling across different client types.""" + +from unittest.mock import Mock + +import pytest + + +def test_supports_safety_identifier_for_openai_client() -> None: + """Official OpenAI client with default base_url should support safety_identifier.""" + from guardrails.checks.text.llm_base import _supports_safety_identifier + + mock_client = Mock() + mock_client.base_url = None + mock_client.__class__.__name__ = "AsyncOpenAI" + + assert _supports_safety_identifier(mock_client) is True # noqa: S101 + + +def test_supports_safety_identifier_for_openai_with_official_url() -> None: + """OpenAI client with explicit api.openai.com base_url should support safety_identifier.""" + from guardrails.checks.text.llm_base import _supports_safety_identifier + + mock_client = Mock() + mock_client.base_url = "https://api.openai.com/v1" + mock_client.__class__.__name__ = "AsyncOpenAI" + + assert _supports_safety_identifier(mock_client) is True # noqa: S101 + + +def test_does_not_support_safety_identifier_for_azure() -> None: + """Azure OpenAI client should not support safety_identifier.""" + from guardrails.checks.text.llm_base import _supports_safety_identifier + + mock_client = Mock() + mock_client.base_url = "https://example.openai.azure.com/v1" + mock_client.__class__.__name__ = "AsyncAzureOpenAI" + + # Azure detection happens via isinstance check, but we can test with class name + from openai import AsyncAzureOpenAI + + try: + azure_client = AsyncAzureOpenAI( + api_key="test", + azure_endpoint="https://example.openai.azure.com", + api_version="2024-02-01", + ) + assert _supports_safety_identifier(azure_client) is False # noqa: S101 + except Exception: + # If we can't create a real Azure client in tests, that's okay + pytest.skip("Could not create Azure client for testing") + + +def test_does_not_support_safety_identifier_for_local_model() -> None: + """Local model with custom base_url should not support safety_identifier.""" + from guardrails.checks.text.llm_base import _supports_safety_identifier + + mock_client = Mock() + mock_client.base_url = "http://localhost:11434/v1" # Ollama + mock_client.__class__.__name__ = "AsyncOpenAI" + + 
assert _supports_safety_identifier(mock_client) is False # noqa: S101 + + +def test_does_not_support_safety_identifier_for_alternative_provider() -> None: + """Alternative OpenAI-compatible provider should not support safety_identifier.""" + from guardrails.checks.text.llm_base import _supports_safety_identifier + + mock_client = Mock() + mock_client.base_url = "https://api.together.xyz/v1" + mock_client.__class__.__name__ = "AsyncOpenAI" + + assert _supports_safety_identifier(mock_client) is False # noqa: S101 + From 62b91cfb8baa2adf5a12afac1a6a9142067b5787 Mon Sep 17 00:00:00 2001 From: Steven C Date: Thu, 30 Oct 2025 16:37:11 -0400 Subject: [PATCH 2/3] extract common logic --- src/guardrails/checks/text/llm_base.py | 39 ++--------- src/guardrails/resources/chat/chat.py | 41 ++--------- .../resources/responses/responses.py | 49 +++---------- src/guardrails/utils/safety_identifier.py | 68 +++++++++++++++++++ tests/unit/test_safety_identifier.py | 20 +++--- 5 files changed, 97 insertions(+), 120 deletions(-) create mode 100644 src/guardrails/utils/safety_identifier.py diff --git a/src/guardrails/checks/text/llm_base.py b/src/guardrails/checks/text/llm_base.py index bfa4ed2..ed6a71f 100644 --- a/src/guardrails/checks/text/llm_base.py +++ b/src/guardrails/checks/text/llm_base.py @@ -48,36 +48,7 @@ class MyLLMOutput(LLMOutput): from guardrails.types import CheckFn, GuardrailLLMContextProto, GuardrailResult from guardrails.utils.output import OutputSchema -# OpenAI safety identifier for tracking guardrails library usage -# Only supported by official OpenAI API (not Azure or local/alternative providers) -_SAFETY_IDENTIFIER = "oai_guardrails" - - -def _supports_safety_identifier(client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI) -> bool: - """Check if the client supports the safety_identifier parameter. - - Only the official OpenAI API supports this parameter. - Azure OpenAI and local/alternative providers do not. - - Args: - client: The OpenAI client instance. - - Returns: - True if safety_identifier should be included, False otherwise. 
- """ - # Azure clients don't support it - if isinstance(client, AsyncAzureOpenAI | AzureOpenAI): - return False - - # Check if using a custom base_url (local or alternative provider) - base_url = getattr(client, "base_url", None) - if base_url is not None: - base_url_str = str(base_url) - # Only official OpenAI API endpoints support safety_identifier - return "api.openai.com" in base_url_str - - # Default OpenAI client (no custom base_url) supports it - return True +from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier if TYPE_CHECKING: from openai import AsyncAzureOpenAI, AzureOpenAI # type: ignore[unused-import] @@ -93,10 +64,10 @@ def _supports_safety_identifier(client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI __all__ = [ "LLMConfig", - "LLMOutput", "LLMErrorOutput", - "create_llm_check_fn", + "LLMOutput", "create_error_result", + "create_llm_check_fn", ] @@ -286,8 +257,8 @@ async def _request_chat_completion( } # Only official OpenAI API supports safety_identifier (not Azure or local models) - if _supports_safety_identifier(client): - kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + if supports_safety_identifier(client): + kwargs["safety_identifier"] = SAFETY_IDENTIFIER return await _invoke_openai_callable(client.chat.completions.create, **kwargs) diff --git a/src/guardrails/resources/chat/chat.py b/src/guardrails/resources/chat/chat.py index c311287..a76d9b7 100644 --- a/src/guardrails/resources/chat/chat.py +++ b/src/guardrails/resources/chat/chat.py @@ -6,38 +6,7 @@ from typing import Any from ..._base_client import GuardrailsBaseClient - -# OpenAI safety identifier for tracking guardrails library usage -# Only supported by official OpenAI API (not Azure or local/alternative providers) -_SAFETY_IDENTIFIER = "oai_guardrails" - - -def _supports_safety_identifier(client: Any) -> bool: - """Check if the client supports the safety_identifier parameter. - - Only the official OpenAI API supports this parameter. - Azure OpenAI and local/alternative providers do not. - - Args: - client: The OpenAI client instance. - - Returns: - True if safety_identifier should be included, False otherwise. 
- """ - # Azure clients don't support it - client_type = type(client).__name__ - if "Azure" in client_type: - return False - - # Check if using a custom base_url (local or alternative provider) - base_url = getattr(client, "base_url", None) - if base_url is not None: - base_url_str = str(base_url) - # Only official OpenAI API endpoints support safety_identifier - return "api.openai.com" in base_url_str - - # Default OpenAI client (no custom base_url) supports it - return True +from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier class Chat: @@ -121,8 +90,8 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals "stream": stream, **kwargs, } - if _supports_safety_identifier(self._client._resource_client): - llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + if supports_safety_identifier(self._client._resource_client): + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER llm_future = executor.submit( self._client._resource_client.chat.completions.create, @@ -198,8 +167,8 @@ async def create( "stream": stream, **kwargs, } - if _supports_safety_identifier(self._client._resource_client): - llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + if supports_safety_identifier(self._client._resource_client): + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER llm_call = self._client._resource_client.chat.completions.create(**llm_kwargs) diff --git a/src/guardrails/resources/responses/responses.py b/src/guardrails/resources/responses/responses.py index f56920b..4df5f46 100644 --- a/src/guardrails/resources/responses/responses.py +++ b/src/guardrails/resources/responses/responses.py @@ -8,38 +8,7 @@ from pydantic import BaseModel from ..._base_client import GuardrailsBaseClient - -# OpenAI safety identifier for tracking guardrails library usage -# Only supported by official OpenAI API (not Azure or local/alternative providers) -_SAFETY_IDENTIFIER = "oai_guardrails" - - -def _supports_safety_identifier(client: Any) -> bool: - """Check if the client supports the safety_identifier parameter. - - Only the official OpenAI API supports this parameter. - Azure OpenAI and local/alternative providers do not. - - Args: - client: The OpenAI client instance. - - Returns: - True if safety_identifier should be included, False otherwise. 
- """ - # Azure clients don't support it - client_type = type(client).__name__ - if "Azure" in client_type: - return False - - # Check if using a custom base_url (local or alternative provider) - base_url = getattr(client, "base_url", None) - if base_url is not None: - base_url_str = str(base_url) - # Only official OpenAI API endpoints support safety_identifier - return "api.openai.com" in base_url_str - - # Default OpenAI client (no custom base_url) supports it - return True +from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier class Responses: @@ -103,8 +72,8 @@ def create( "tools": tools, **kwargs, } - if _supports_safety_identifier(self._client._resource_client): - llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + if supports_safety_identifier(self._client._resource_client): + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER llm_future = executor.submit( self._client._resource_client.responses.create, @@ -169,8 +138,8 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM "text_format": text_format, **kwargs, } - if _supports_safety_identifier(self._client._resource_client): - llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + if supports_safety_identifier(self._client._resource_client): + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER llm_future = executor.submit( self._client._resource_client.responses.parse, @@ -273,8 +242,8 @@ async def create( "tools": tools, **kwargs, } - if _supports_safety_identifier(self._client._resource_client): - llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + if supports_safety_identifier(self._client._resource_client): + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER llm_call = self._client._resource_client.responses.create(**llm_kwargs) @@ -339,8 +308,8 @@ async def parse( "stream": stream, **kwargs, } - if _supports_safety_identifier(self._client._resource_client): - llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER + if supports_safety_identifier(self._client._resource_client): + llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER llm_call = self._client._resource_client.responses.parse(**llm_kwargs) diff --git a/src/guardrails/utils/safety_identifier.py b/src/guardrails/utils/safety_identifier.py new file mode 100644 index 0000000..07ecbd4 --- /dev/null +++ b/src/guardrails/utils/safety_identifier.py @@ -0,0 +1,68 @@ +"""OpenAI safety identifier utilities. + +This module provides utilities for handling the OpenAI safety_identifier parameter, +which is used to track guardrails library usage for monitoring and abuse detection. + +The safety identifier is only supported by the official OpenAI API and should not +be sent to Azure OpenAI or other OpenAI-compatible providers. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI +else: + try: + from openai import AsyncAzureOpenAI, AzureOpenAI + except ImportError: + AsyncAzureOpenAI = None # type: ignore[assignment, misc] + AzureOpenAI = None # type: ignore[assignment, misc] + +__all__ = ["SAFETY_IDENTIFIER", "supports_safety_identifier"] + +# OpenAI safety identifier for tracking guardrails library usage +SAFETY_IDENTIFIER = "oai_guardrails" + + +def supports_safety_identifier( + client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI | Any, +) -> bool: + """Check if the client supports the safety_identifier parameter. + + Only the official OpenAI API supports this parameter. 
+ Azure OpenAI and local/alternative providers (Ollama, vLLM, etc.) do not. + + Args: + client: The OpenAI client instance to check. + + Returns: + True if safety_identifier should be included in API calls, False otherwise. + + Examples: + >>> from openai import AsyncOpenAI + >>> client = AsyncOpenAI() + >>> supports_safety_identifier(client) + True + + >>> from openai import AsyncOpenAI + >>> local_client = AsyncOpenAI(base_url="http://localhost:11434") + >>> supports_safety_identifier(local_client) + False + """ + # Azure clients don't support it + if AsyncAzureOpenAI is not None and AzureOpenAI is not None: + if isinstance(client, AsyncAzureOpenAI | AzureOpenAI): + return False + + # Check if using a custom base_url (local or alternative provider) + base_url = getattr(client, "base_url", None) + if base_url is not None: + base_url_str = str(base_url) + # Only official OpenAI API endpoints support safety_identifier + return "api.openai.com" in base_url_str + + # Default OpenAI client (no custom base_url) supports it + return True + diff --git a/tests/unit/test_safety_identifier.py b/tests/unit/test_safety_identifier.py index ef94a0b..b427098 100644 --- a/tests/unit/test_safety_identifier.py +++ b/tests/unit/test_safety_identifier.py @@ -7,29 +7,29 @@ def test_supports_safety_identifier_for_openai_client() -> None: """Official OpenAI client with default base_url should support safety_identifier.""" - from guardrails.checks.text.llm_base import _supports_safety_identifier + from guardrails.utils.safety_identifier import supports_safety_identifier mock_client = Mock() mock_client.base_url = None mock_client.__class__.__name__ = "AsyncOpenAI" - assert _supports_safety_identifier(mock_client) is True # noqa: S101 + assert supports_safety_identifier(mock_client) is True # noqa: S101 def test_supports_safety_identifier_for_openai_with_official_url() -> None: """OpenAI client with explicit api.openai.com base_url should support safety_identifier.""" - from guardrails.checks.text.llm_base import _supports_safety_identifier + from guardrails.utils.safety_identifier import supports_safety_identifier mock_client = Mock() mock_client.base_url = "https://api.openai.com/v1" mock_client.__class__.__name__ = "AsyncOpenAI" - assert _supports_safety_identifier(mock_client) is True # noqa: S101 + assert supports_safety_identifier(mock_client) is True # noqa: S101 def test_does_not_support_safety_identifier_for_azure() -> None: """Azure OpenAI client should not support safety_identifier.""" - from guardrails.checks.text.llm_base import _supports_safety_identifier + from guardrails.utils.safety_identifier import supports_safety_identifier mock_client = Mock() mock_client.base_url = "https://example.openai.azure.com/v1" @@ -44,7 +44,7 @@ def test_does_not_support_safety_identifier_for_azure() -> None: azure_endpoint="https://example.openai.azure.com", api_version="2024-02-01", ) - assert _supports_safety_identifier(azure_client) is False # noqa: S101 + assert supports_safety_identifier(azure_client) is False # noqa: S101 except Exception: # If we can't create a real Azure client in tests, that's okay pytest.skip("Could not create Azure client for testing") @@ -52,22 +52,22 @@ def test_does_not_support_safety_identifier_for_azure() -> None: def test_does_not_support_safety_identifier_for_local_model() -> None: """Local model with custom base_url should not support safety_identifier.""" - from guardrails.checks.text.llm_base import _supports_safety_identifier + from guardrails.utils.safety_identifier import 
supports_safety_identifier mock_client = Mock() mock_client.base_url = "http://localhost:11434/v1" # Ollama mock_client.__class__.__name__ = "AsyncOpenAI" - assert _supports_safety_identifier(mock_client) is False # noqa: S101 + assert supports_safety_identifier(mock_client) is False # noqa: S101 def test_does_not_support_safety_identifier_for_alternative_provider() -> None: """Alternative OpenAI-compatible provider should not support safety_identifier.""" - from guardrails.checks.text.llm_base import _supports_safety_identifier + from guardrails.utils.safety_identifier import supports_safety_identifier mock_client = Mock() mock_client.base_url = "https://api.together.xyz/v1" mock_client.__class__.__name__ = "AsyncOpenAI" - assert _supports_safety_identifier(mock_client) is False # noqa: S101 + assert supports_safety_identifier(mock_client) is False # noqa: S101 From 7523c8eb22e95a1028f165980818a9bbca0779a9 Mon Sep 17 00:00:00 2001 From: Steven C Date: Thu, 30 Oct 2025 16:44:14 -0400 Subject: [PATCH 3/3] change id value --- src/guardrails/utils/safety_identifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/guardrails/utils/safety_identifier.py b/src/guardrails/utils/safety_identifier.py index 07ecbd4..50a87ff 100644 --- a/src/guardrails/utils/safety_identifier.py +++ b/src/guardrails/utils/safety_identifier.py @@ -23,7 +23,7 @@ __all__ = ["SAFETY_IDENTIFIER", "supports_safety_identifier"] # OpenAI safety identifier for tracking guardrails library usage -SAFETY_IDENTIFIER = "oai_guardrails" +SAFETY_IDENTIFIER = "openai-guardrails-python" def supports_safety_identifier(
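
Note for reviewers: the end state of this series can be exercised without touching the network. Below is a minimal sketch assuming only what the patches introduce (guardrails.utils.safety_identifier from patch 2, the "openai-guardrails-python" value from patch 3); the api_key and model strings are placeholders for illustration, not values from this series.

from openai import AsyncOpenAI

from guardrails.utils.safety_identifier import (
    SAFETY_IDENTIFIER,
    supports_safety_identifier,
)

# Hosted OpenAI client: the default base_url resolves to api.openai.com,
# so the identifier is attached.
hosted = AsyncOpenAI(api_key="sk-example")  # placeholder key, no call is made
assert supports_safety_identifier(hosted)

# Ollama-style local endpoint: the identifier is withheld, since only the
# official OpenAI API accepts the safety_identifier request parameter.
local = AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="unused")
assert not supports_safety_identifier(local)

# The resource wrappers build call kwargs the same way before each request:
llm_kwargs = {"model": "gpt-4.1", "messages": [{"role": "user", "content": "hi"}]}
if supports_safety_identifier(hosted):
    llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER  # "openai-guardrails-python"

This mirrors the llm_kwargs pattern each wrapper now uses: the parameter is included exactly when the target endpoint is api.openai.com and silently omitted for Azure and OpenAI-compatible providers, replacing the old always-on default_headers approach that patch 1 removes.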