From bdb0de6c806c76721f2338c7e22cf63dc3cbfdf2 Mon Sep 17 00:00:00 2001 From: steven10a <158192461+steven10a@users.noreply.github.com> Date: Mon, 20 Oct 2025 14:00:34 -0400 Subject: [PATCH 1/3] Update conversation history handling (#24) * Update conversation history handling * support using previous_response_id --- .gitignore | 3 + examples/basic/agents_sdk.py | 5 + src/guardrails/_base_client.py | 13 + src/guardrails/_streaming.py | 47 ++- src/guardrails/agents.py | 265 +++++++------- .../checks/text/prompt_injection_detection.py | 33 +- src/guardrails/client.py | 214 ++++++++---- src/guardrails/resources/chat/chat.py | 16 +- .../resources/responses/responses.py | 64 +++- src/guardrails/types.py | 1 + src/guardrails/utils/__init__.py | 2 + src/guardrails/utils/conversation.py | 328 ++++++++++++++++++ tests/unit/test_agents.py | 205 +++++------ tests/unit/test_client_async.py | 5 +- tests/unit/test_client_sync.py | 7 +- tests/unit/test_resources_chat.py | 11 + tests/unit/test_resources_responses.py | 111 +++++- 17 files changed, 946 insertions(+), 384 deletions(-) create mode 100644 src/guardrails/utils/conversation.py diff --git a/.gitignore b/.gitignore index 1dfea27..5b2c301 100644 --- a/.gitignore +++ b/.gitignore @@ -147,3 +147,6 @@ env/ # Python package management uv.lock + +# Internal files +internal_examples/ \ No newline at end of file diff --git a/examples/basic/agents_sdk.py b/examples/basic/agents_sdk.py index 8c77d02..fe23222 100644 --- a/examples/basic/agents_sdk.py +++ b/examples/basic/agents_sdk.py @@ -7,6 +7,7 @@ InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered, Runner, + SQLiteSession, ) from agents.run import RunConfig @@ -50,6 +51,9 @@ async def main() -> None: """Main input loop for the customer support agent with input/output guardrails.""" + # Create a session for the agent to store the conversation history + session = SQLiteSession("guardrails-session") + # Create agent with guardrails automatically configured from pipeline configuration AGENT = GuardrailAgent( config=PIPELINE_CONFIG, @@ -65,6 +69,7 @@ async def main() -> None: AGENT, user_input, run_config=RunConfig(tracing_disabled=True), + session=session, ) print(f"Assistant: {result.final_output}") except EOFError: diff --git a/src/guardrails/_base_client.py b/src/guardrails/_base_client.py index 05925ef..a599a3d 100644 --- a/src/guardrails/_base_client.py +++ b/src/guardrails/_base_client.py @@ -19,6 +19,7 @@ from .runtime import load_pipeline_bundles from .types import GuardrailLLMContextProto, GuardrailResult from .utils.context import validate_guardrail_context +from .utils.conversation import append_assistant_response, normalize_conversation logger = logging.getLogger(__name__) @@ -257,6 +258,18 @@ def _instantiate_all_guardrails(self) -> dict[str, list]: guardrails[stage_name] = instantiate_guardrails(stage, default_spec_registry) if stage else [] return guardrails + def _normalize_conversation(self, payload: Any) -> list[dict[str, Any]]: + """Normalize arbitrary conversation payloads.""" + return normalize_conversation(payload) + + def _conversation_with_response( + self, + conversation: list[dict[str, Any]], + response: Any, + ) -> list[dict[str, Any]]: + """Append the assistant response to a normalized conversation.""" + return append_assistant_response(conversation, response) + def _validate_context(self, context: Any) -> None: """Validate context against all guardrails.""" for stage_guardrails in self.guardrails.values(): diff --git a/src/guardrails/_streaming.py 
b/src/guardrails/_streaming.py index 898bc0b..4e621c2 100644 --- a/src/guardrails/_streaming.py +++ b/src/guardrails/_streaming.py @@ -13,6 +13,7 @@ from ._base_client import GuardrailsResponse from .exceptions import GuardrailTripwireTriggered from .types import GuardrailResult +from .utils.conversation import merge_conversation_with_items logger = logging.getLogger(__name__) @@ -25,6 +26,7 @@ async def _stream_with_guardrails( llm_stream: Any, # coroutine or async iterator of OpenAI chunks preflight_results: list[GuardrailResult], input_results: list[GuardrailResult], + conversation_history: list[dict[str, Any]] | None = None, check_interval: int = 100, suppress_tripwire: bool = False, ) -> AsyncIterator[GuardrailsResponse]: @@ -46,7 +48,16 @@ async def _stream_with_guardrails( # Run output guardrails periodically if chunk_count % check_interval == 0: try: - await self._run_stage_guardrails("output", accumulated_text, suppress_tripwire=suppress_tripwire) + history = merge_conversation_with_items( + conversation_history or [], + [{"role": "assistant", "content": accumulated_text}], + ) + await self._run_stage_guardrails( + "output", + accumulated_text, + conversation_history=history, + suppress_tripwire=suppress_tripwire, + ) except GuardrailTripwireTriggered: # Clear accumulated output and re-raise accumulated_text = "" @@ -57,7 +68,16 @@ async def _stream_with_guardrails( # Final output check if accumulated_text: - await self._run_stage_guardrails("output", accumulated_text, suppress_tripwire=suppress_tripwire) + history = merge_conversation_with_items( + conversation_history or [], + [{"role": "assistant", "content": accumulated_text}], + ) + await self._run_stage_guardrails( + "output", + accumulated_text, + conversation_history=history, + suppress_tripwire=suppress_tripwire, + ) # Note: This final result won't be yielded since stream is complete # but the results are available in the last chunk @@ -66,6 +86,7 @@ def _stream_with_guardrails_sync( llm_stream: Any, # iterator of OpenAI chunks preflight_results: list[GuardrailResult], input_results: list[GuardrailResult], + conversation_history: list[dict[str, Any]] | None = None, check_interval: int = 100, suppress_tripwire: bool = False, ): @@ -83,7 +104,16 @@ def _stream_with_guardrails_sync( # Run output guardrails periodically if chunk_count % check_interval == 0: try: - self._run_stage_guardrails("output", accumulated_text, suppress_tripwire=suppress_tripwire) + history = merge_conversation_with_items( + conversation_history or [], + [{"role": "assistant", "content": accumulated_text}], + ) + self._run_stage_guardrails( + "output", + accumulated_text, + conversation_history=history, + suppress_tripwire=suppress_tripwire, + ) except GuardrailTripwireTriggered: # Clear accumulated output and re-raise accumulated_text = "" @@ -94,6 +124,15 @@ def _stream_with_guardrails_sync( # Final output check if accumulated_text: - self._run_stage_guardrails("output", accumulated_text, suppress_tripwire=suppress_tripwire) + history = merge_conversation_with_items( + conversation_history or [], + [{"role": "assistant", "content": accumulated_text}], + ) + self._run_stage_guardrails( + "output", + accumulated_text, + conversation_history=history, + suppress_tripwire=suppress_tripwire, + ) # Note: This final result won't be yielded since stream is complete # but the results are available in the last chunk diff --git a/src/guardrails/agents.py b/src/guardrails/agents.py index 5c849fd..0645081 100644 --- a/src/guardrails/agents.py +++ 
b/src/guardrails/agents.py @@ -19,102 +19,128 @@ from typing import Any from ._openai_utils import prepare_openai_kwargs +from .utils.conversation import merge_conversation_with_items, normalize_conversation logger = logging.getLogger(__name__) __all__ = ["GuardrailAgent"] -# Guardrails that require conversation history context -_NEEDS_CONVERSATION_HISTORY = ["Prompt Injection Detection"] - # Guardrails that should run at tool level (before/after each tool call) # instead of at agent level (before/after entire agent interaction) _TOOL_LEVEL_GUARDRAILS = ["Prompt Injection Detection"] -# Context variable for tracking user messages across conversation turns -# Only stores user messages - NOT full conversation history -# This persists across turns to maintain multi-turn context -# Only used when a guardrail in _NEEDS_CONVERSATION_HISTORY is configured -_user_messages: ContextVar[list[str]] = ContextVar("user_messages", default=[]) # noqa: B039 +# Context variables used to expose conversation information during guardrail checks. +_agent_session: ContextVar[Any | None] = ContextVar("guardrails_agent_session", default=None) +_agent_conversation: ContextVar[tuple[dict[str, Any], ...] | None] = ContextVar( + "guardrails_agent_conversation", + default=None, +) +_AGENT_RUNNER_PATCHED = False -def _get_user_messages() -> list[str]: - """Get user messages from context variable with proper error handling. +def _ensure_agent_runner_patch() -> None: + """Patch AgentRunner.run once so sessions are exposed via ContextVars.""" + global _AGENT_RUNNER_PATCHED + if _AGENT_RUNNER_PATCHED: + return - Returns: - List of user messages, or empty list if not yet initialized - """ try: - return _user_messages.get() - except LookupError: - user_msgs: list[str] = [] - _user_messages.set(user_msgs) - return user_msgs - + from agents.run import AgentRunner # type: ignore + except ImportError: + return -def _separate_tool_level_from_agent_level(guardrails: list[Any]) -> tuple[list[Any], list[Any]]: - """Separate tool-level guardrails from agent-level guardrails. + original_run = AgentRunner.run - Args: - guardrails: List of configured guardrails + async def _patched_run(self, starting_agent, input, **kwargs): # type: ignore[override] + session = kwargs.get("session") + fallback_history: list[dict[str, Any]] | None = None + if session is None: + fallback_history = normalize_conversation(input) - Returns: - Tuple of (tool_level_guardrails, agent_level_guardrails) - """ - tool_level = [] - agent_level = [] + session_token = _agent_session.set(session) + conversation_token = _agent_conversation.set(tuple(dict(item) for item in fallback_history) if fallback_history else None) - for guardrail in guardrails: - if guardrail.definition.name in _TOOL_LEVEL_GUARDRAILS: - tool_level.append(guardrail) - else: - agent_level.append(guardrail) + try: + return await original_run(self, starting_agent, input, **kwargs) + finally: + _agent_session.reset(session_token) + _agent_conversation.reset(conversation_token) - return tool_level, agent_level + AgentRunner.run = _patched_run # type: ignore[assignment] + _AGENT_RUNNER_PATCHED = True -def _needs_conversation_history(guardrail: Any) -> bool: - """Check if a guardrail needs conversation history context. 
+def _cache_conversation(conversation: list[dict[str, Any]]) -> None: + """Cache the normalized conversation for the current run.""" + _agent_conversation.set(tuple(dict(item) for item in conversation)) - Args: - guardrail: Configured guardrail to check - Returns: - True if guardrail needs conversation history, False otherwise - """ - return guardrail.definition.name in _NEEDS_CONVERSATION_HISTORY +async def _load_agent_conversation() -> list[dict[str, Any]]: + """Load the latest conversation snapshot from session or fallback storage.""" + cached = _agent_conversation.get() + if cached is not None: + return [dict(item) for item in cached] + session = _agent_session.get() + if session is not None: + items = await session.get_items() + conversation = normalize_conversation(items) + _cache_conversation(conversation) + return conversation -def _build_conversation_with_tool_call(data: Any) -> list: - """Build conversation history with user messages + tool call. + return [] - Args: - data: ToolInputGuardrailData containing tool call information - Returns: - List of conversation messages including user context and tool call - """ - user_msgs = _get_user_messages() - conversation = [{"role": "user", "content": msg} for msg in user_msgs] - conversation.append({"type": "function_call", "tool_name": data.context.tool_name, "arguments": data.context.tool_arguments}) +async def _conversation_with_items(items: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Return conversation history including additional items.""" + base_history = await _load_agent_conversation() + conversation = merge_conversation_with_items(base_history, items) + _cache_conversation(conversation) return conversation -def _build_conversation_with_tool_output(data: Any) -> list: - """Build conversation history with user messages + tool output. +async def _conversation_with_tool_call(data: Any) -> list[dict[str, Any]]: + """Build conversation history including the current tool call.""" + event = { + "type": "function_call", + "tool_name": data.context.tool_name, + "arguments": data.context.tool_arguments, + "call_id": getattr(data.context, "tool_call_id", None), + } + return await _conversation_with_items([event]) + + +async def _conversation_with_tool_output(data: Any) -> list[dict[str, Any]]: + """Build conversation history including the current tool output.""" + event = { + "type": "function_call_output", + "tool_name": data.context.tool_name, + "arguments": data.context.tool_arguments, + "output": str(data.output), + "call_id": getattr(data.context, "tool_call_id", None), + } + return await _conversation_with_items([event]) + + +def _separate_tool_level_from_agent_level(guardrails: list[Any]) -> tuple[list[Any], list[Any]]: + """Separate tool-level guardrails from agent-level guardrails. 
Args: - data: ToolOutputGuardrailData containing tool output information + guardrails: List of configured guardrails Returns: - List of conversation messages including user context and tool output + Tuple of (tool_level_guardrails, agent_level_guardrails) """ - user_msgs = _get_user_messages() - conversation = [{"role": "user", "content": msg} for msg in user_msgs] - conversation.append( - {"type": "function_call_output", "tool_name": data.context.tool_name, "arguments": data.context.tool_arguments, "output": str(data.output)} - ) - return conversation + tool_level = [] + agent_level = [] + + for guardrail in guardrails: + if guardrail.definition.name in _TOOL_LEVEL_GUARDRAILS: + tool_level.append(guardrail) + else: + agent_level.append(guardrail) + + return tool_level, agent_level def _attach_guardrail_to_tools(tools: list[Any], guardrail: Callable, guardrail_type: str) -> None: @@ -173,14 +199,17 @@ def get_conversation_history(self) -> list: def _create_tool_guardrail( - guardrail: Any, guardrail_type: str, needs_conv_history: bool, context: Any, raise_guardrail_errors: bool, block_on_violations: bool + guardrail: Any, + guardrail_type: str, + context: Any, + raise_guardrail_errors: bool, + block_on_violations: bool, ) -> Callable: """Create a generic tool-level guardrail wrapper. Args: guardrail: The configured guardrail guardrail_type: "input" (before tool execution) or "output" (after tool execution) - needs_conv_history: Whether this guardrail needs conversation history context context: Guardrail context for LLM client raise_guardrail_errors: Whether to raise on errors block_on_violations: If True, use raise_exception (halt). If False, use reject_content (continue). @@ -209,26 +238,18 @@ async def tool_input_gr(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOu guardrail_name = guardrail.definition.name try: - # Build context based on whether conversation history is needed - if needs_conv_history: - # Get user messages and check if available - user_msgs = _get_user_messages() - - if not user_msgs: - return ToolGuardrailFunctionOutput(output_info=f"Skipped: no user intent available for {guardrail_name}") - - # Build conversation history with user messages + tool call - conversation_history = _build_conversation_with_tool_call(data) - ctx = _create_conversation_context( - conversation_history=conversation_history, - base_context=context, - ) - check_data = "" # Unused for conversation-history-aware guardrails - else: - # Use simple context without conversation history - ctx = context - # Format tool call data for non-conversation-aware guardrails - check_data = json.dumps({"tool_name": data.context.tool_name, "arguments": data.context.tool_arguments}) + conversation_history = await _conversation_with_tool_call(data) + ctx = _create_conversation_context( + conversation_history=conversation_history, + base_context=context, + ) + check_data = json.dumps( + { + "tool_name": data.context.tool_name, + "arguments": data.context.tool_arguments, + "call_id": getattr(data.context, "tool_call_id", None), + } + ) # Run the guardrail results = await run_guardrails( @@ -271,28 +292,19 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction guardrail_name = guardrail.definition.name try: - # Build context based on whether conversation history is needed - if needs_conv_history: - # Get user messages and check if available - user_msgs = _get_user_messages() - - if not user_msgs: - return ToolGuardrailFunctionOutput(output_info=f"Skipped: no user intent available for 
{guardrail_name}") - - # Build conversation history with user messages + tool output - conversation_history = _build_conversation_with_tool_output(data) - ctx = _create_conversation_context( - conversation_history=conversation_history, - base_context=context, - ) - check_data = "" # Unused for conversation-history-aware guardrails - else: - # Use simple context without conversation history - ctx = context - # Format tool output data for non-conversation-aware guardrails - check_data = json.dumps( - {"tool_name": data.context.tool_name, "arguments": data.context.tool_arguments, "output": str(data.output)} - ) + conversation_history = await _conversation_with_tool_output(data) + ctx = _create_conversation_context( + conversation_history=conversation_history, + base_context=context, + ) + check_data = json.dumps( + { + "tool_name": data.context.tool_name, + "arguments": data.context.tool_arguments, + "output": str(data.output), + "call_id": getattr(data.context, "tool_call_id", None), + } + ) # Run the guardrail results = await run_guardrails( @@ -387,15 +399,6 @@ def _create_stage_guardrail(stage_name: str): async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput: """Guardrail function for a specific pipeline stage.""" try: - # If this is an input guardrail, capture user messages for tool-level alignment - if guardrail_type == "input": - # Parse input_data to extract user message - # input_data is typically a string containing the user's message - if input_data and input_data.strip(): - user_msgs = _get_user_messages() - if input_data not in user_msgs: - user_msgs.append(input_data) - # Get guardrails for this stage (already filtered to exclude prompt injection) guardrails = stage_guardrails.get(stage_name, []) if not guardrails: @@ -457,6 +460,11 @@ class GuardrailAgent: Prompt Injection Detection guardrails are applied at the tool level (before and after each tool call), while other guardrails run at the agent level. + When you supply an Agents Session via ``Runner.run(..., session=...)`` the + guardrails automatically read the persisted conversation history. Without a + session, guardrails operate on the conversation passed to ``Runner.run`` for + the current turn. 
+ Example: ```python from guardrails import GuardrailAgent @@ -527,6 +535,8 @@ def __new__( from .registry import default_spec_registry from .runtime import instantiate_guardrails, load_pipeline_bundles + _ensure_agent_runner_patch() + # Load and instantiate guardrails from config pipeline = load_pipeline_bundles(config) @@ -538,10 +548,6 @@ def __new__( else: stage_guardrails[stage_name] = [] - # Check if ANY guardrail in the entire pipeline needs conversation history - all_guardrails = stage_guardrails.get("pre_flight", []) + stage_guardrails.get("input", []) + stage_guardrails.get("output", []) - needs_user_tracking = any(gr.definition.name in _NEEDS_CONVERSATION_HISTORY for gr in all_guardrails) - # Separate tool-level from agent-level guardrails in each stage preflight_tool, preflight_agent = _separate_tool_level_from_agent_level(stage_guardrails.get("pre_flight", [])) input_tool, input_agent = _separate_tool_level_from_agent_level(stage_guardrails.get("input", [])) @@ -550,25 +556,6 @@ def __new__( # Create agent-level INPUT guardrails input_guardrails = [] - # ONLY create user message capture guardrail if needed - if needs_user_tracking: - try: - from agents import Agent as AgentType, GuardrailFunctionOutput, RunContextWrapper, input_guardrail - except ImportError as e: - raise ImportError("The 'agents' package is required. Please install it with: pip install openai-agents") from e - - @input_guardrail - async def capture_user_message(ctx: RunContextWrapper[None], agent: AgentType, input_data: str) -> GuardrailFunctionOutput: - """Capture user messages for conversation-history-aware guardrails.""" - if input_data and input_data.strip(): - user_msgs = _get_user_messages() - if input_data not in user_msgs: - user_msgs.append(input_data) - - return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False) - - input_guardrails.append(capture_user_message) - # Add agent-level guardrails from pre_flight and input stages agent_input_stages = [] if preflight_agent: @@ -610,7 +597,6 @@ async def capture_user_message(ctx: RunContextWrapper[None], agent: AgentType, i tool_input_gr = _create_tool_guardrail( guardrail=guardrail, guardrail_type="input", - needs_conv_history=_needs_conversation_history(guardrail), context=context, raise_guardrail_errors=raise_guardrail_errors, block_on_violations=block_on_tool_violations, @@ -622,7 +608,6 @@ async def capture_user_message(ctx: RunContextWrapper[None], agent: AgentType, i tool_output_gr = _create_tool_guardrail( guardrail=guardrail, guardrail_type="output", - needs_conv_history=_needs_conversation_history(guardrail), context=context, raise_guardrail_errors=raise_guardrail_errors, block_on_violations=block_on_tool_violations, diff --git a/src/guardrails/checks/text/prompt_injection_detection.py b/src/guardrails/checks/text/prompt_injection_detection.py index 631d243..a0e685b 100644 --- a/src/guardrails/checks/text/prompt_injection_detection.py +++ b/src/guardrails/checks/text/prompt_injection_detection.py @@ -35,6 +35,7 @@ from guardrails.registry import default_spec_registry from guardrails.spec import GuardrailSpecMetadata from guardrails.types import GuardrailLLMContextProto, GuardrailResult +from guardrails.utils.conversation import normalize_conversation from .llm_base import LLMConfig, LLMOutput, _invoke_openai_callable @@ -171,7 +172,7 @@ async def prompt_injection_detection( """ try: # Get conversation history for evaluating the latest exchange - conversation_history = ctx.get_conversation_history() + conversation_history = 
normalize_conversation(ctx.get_conversation_history()) if not conversation_history: return _create_skip_result( "No conversation history available", @@ -271,14 +272,7 @@ def _find_latest_user_index(conversation_history: list[Any]) -> int | None: def _is_user_message(message: Any) -> bool: """Check whether a message originates from the user role.""" - if isinstance(message, dict) and message.get("role") == "user": - return True - if hasattr(message, "role") and message.role == "user": - return True - embedded_message = message.message if hasattr(message, "message") else None - if embedded_message is not None: - return _is_user_message(embedded_message) - return False + return isinstance(message, dict) and message.get("role") == "user" def _coerce_content_to_text(content: Any) -> str: @@ -327,26 +321,17 @@ def _extract_user_intent_from_messages(messages: list) -> dict[str, str | list[s - "most_recent_message": The latest user message as a string - "previous_context": List of previous user messages for context """ - user_messages = [] + normalized_messages = normalize_conversation(messages) + user_texts = [entry["content"] for entry in normalized_messages if entry.get("role") == "user" and isinstance(entry.get("content"), str)] - # Extract all user messages in chronological order and track indices - for _i, msg in enumerate(messages): - if isinstance(msg, dict): - if msg.get("role") == "user": - user_messages.append(_extract_user_message_text(msg)) - elif hasattr(msg, "role") and msg.role == "user": - user_messages.append(_extract_user_message_text(msg)) - - if not user_messages: + if not user_texts: return {"most_recent_message": "", "previous_context": []} - user_intent_dict = { - "most_recent_message": user_messages[-1], - "previous_context": user_messages[:-1], + return { + "most_recent_message": user_texts[-1], + "previous_context": user_texts[:-1], } - return user_intent_dict - def _create_skip_result( observation: str, diff --git a/src/guardrails/client.py b/src/guardrails/client.py index 9f8f2bd..01bcb9d 100644 --- a/src/guardrails/client.py +++ b/src/guardrails/client.py @@ -50,6 +50,92 @@ OUTPUT_STAGE = "output" +def _collect_conversation_items_sync(resource_client: Any, previous_response_id: str) -> list[Any]: + """Return all conversation items for a previous response using sync client APIs.""" + try: + response = resource_client.responses.retrieve(previous_response_id) + except Exception: # pragma: no cover - upstream client/network errors + return [] + + conversation = getattr(response, "conversation", None) + conversation_id = getattr(conversation, "id", None) if conversation else None + + items: list[Any] = [] + + if conversation_id and hasattr(resource_client, "conversations"): + try: + page = resource_client.conversations.items.list( + conversation_id=conversation_id, + order="asc", + limit=100, + ) + for item in page: + items.append(item) + except Exception: # pragma: no cover - upstream client/network errors + items = [] + + if not items: + try: + page = resource_client.responses.input_items.list( + previous_response_id, + order="asc", + limit=100, + ) + for item in page: + items.append(item) + except Exception: # pragma: no cover - upstream client/network errors + items = [] + + output_items = getattr(response, "output", None) + if output_items: + items.extend(output_items) + + return items + + +async def _collect_conversation_items_async(resource_client: Any, previous_response_id: str) -> list[Any]: + """Return all conversation items for a previous response using async 
client APIs.""" + try: + response = await resource_client.responses.retrieve(previous_response_id) + except Exception: # pragma: no cover - upstream client/network errors + return [] + + conversation = getattr(response, "conversation", None) + conversation_id = getattr(conversation, "id", None) if conversation else None + + items: list[Any] = [] + + if conversation_id and hasattr(resource_client, "conversations"): + try: + page = await resource_client.conversations.items.list( + conversation_id=conversation_id, + order="asc", + limit=100, + ) + async for item in page: # type: ignore[attr-defined] + items.append(item) + except Exception: # pragma: no cover - upstream client/network errors + items = [] + + if not items: + try: + page = await resource_client.responses.input_items.list( + previous_response_id, + order="asc", + limit=100, + ) + async for item in page: # type: ignore[attr-defined] + items.append(item) + except Exception: # pragma: no cover - upstream client/network errors + items = [] + + output_items = getattr(response, "output", None) + if output_items: + items.extend(output_items) + + return items + + class GuardrailsAsyncOpenAI(AsyncOpenAI, GuardrailsBaseClient, StreamingMixin): """AsyncOpenAI subclass with automatic guardrail integration. @@ -142,24 +228,18 @@ def get_conversation_history(self) -> list: def _append_llm_response_to_conversation(self, conversation_history: list | str, llm_response: Any) -> list: """Append LLM response to conversation history as-is.""" - if conversation_history is None: - conversation_history = [] - - # Handle case where conversation_history is a string (from single input) - if isinstance(conversation_history, str): - conversation_history = [{"role": "user", "content": conversation_history}] + normalized_history = self._normalize_conversation(conversation_history) + return self._conversation_with_response(normalized_history, llm_response) - # Make a copy to avoid modifying the original - updated_history = conversation_history.copy() - - # For responses API: append the output directly - if hasattr(llm_response, "output") and llm_response.output: - updated_history.extend(llm_response.output) - # For chat completions: append the choice message directly (prompt injection detection check will parse) - elif hasattr(llm_response, "choices") and llm_response.choices: - updated_history.append(llm_response.choices[0]) + async def _load_conversation_history_from_previous_response(self, previous_response_id: str | None) -> list[dict[str, Any]]: + """Load full conversation history for a stored previous response.""" + if not previous_response_id: + return [] - return updated_history + items = await _collect_conversation_items_async(self._resource_client, previous_response_id) + if not items: + return [] + return self._normalize_conversation(items) def _override_resources(self): """Override chat and responses with our guardrail-enhanced versions.""" @@ -174,7 +254,7 @@ async def _run_stage_guardrails( self, stage_name: str, text: str, - conversation_history: list = None, + conversation_history: list | None = None, suppress_tripwire: bool = False, ) -> list[GuardrailResult]: """Run guardrails for a specific pipeline stage.""" @@ -182,15 +262,9 @@ async def _run_stage_guardrails( return [] try: - # Check if prompt injection detection guardrail is present and we have conversation history - has_injection_detection = any( - guardrail.definition.name.lower() == "prompt injection detection" for guardrail in self.guardrails[stage_name] - ) - - if 
has_injection_detection and conversation_history: + ctx = self.context + if conversation_history: ctx = self._create_context_with_conversation(conversation_history) - else: - ctx = self.context results = await run_guardrails( ctx=ctx, @@ -225,7 +299,8 @@ async def _handle_llm_response( ) -> GuardrailsResponse: """Handle non-streaming LLM response with output guardrails.""" # Create complete conversation history including the LLM response - complete_conversation = self._append_llm_response_to_conversation(conversation_history, llm_response) + normalized_history = conversation_history or [] + complete_conversation = self._conversation_with_response(normalized_history, llm_response) response_text = self._extract_response_text(llm_response) output_results = await self._run_stage_guardrails( @@ -321,24 +396,18 @@ def get_conversation_history(self) -> list: def _append_llm_response_to_conversation(self, conversation_history: list | str, llm_response: Any) -> list: """Append LLM response to conversation history as-is.""" - if conversation_history is None: - conversation_history = [] - - # Handle case where conversation_history is a string (from single input) - if isinstance(conversation_history, str): - conversation_history = [{"role": "user", "content": conversation_history}] + normalized_history = self._normalize_conversation(conversation_history) + return self._conversation_with_response(normalized_history, llm_response) - # Make a copy to avoid modifying the original - updated_history = conversation_history.copy() - - # For responses API: append the output directly - if hasattr(llm_response, "output") and llm_response.output: - updated_history.extend(llm_response.output) - # For chat completions: append the choice message directly (prompt injection detection check will parse) - elif hasattr(llm_response, "choices") and llm_response.choices: - updated_history.append(llm_response.choices[0]) + def _load_conversation_history_from_previous_response(self, previous_response_id: str | None) -> list[dict[str, Any]]: + """Load full conversation history for a stored previous response.""" + if not previous_response_id: + return [] - return updated_history + items = _collect_conversation_items_sync(self._resource_client, previous_response_id) + if not items: + return [] + return self._normalize_conversation(items) def _override_resources(self): """Override chat and responses with our guardrail-enhanced versions.""" @@ -371,14 +440,9 @@ def _run_stage_guardrails( async def _run_async(): # Check if prompt injection detection guardrail is present and we have conversation history - has_injection_detection = any( - guardrail.definition.name.lower() == "prompt injection detection" for guardrail in self.guardrails[stage_name] - ) - - if has_injection_detection and conversation_history: + ctx = self.context + if conversation_history: ctx = self._create_context_with_conversation(conversation_history) - else: - ctx = self.context results = await run_guardrails( ctx=ctx, @@ -415,7 +479,8 @@ def _handle_llm_response( ) -> GuardrailsResponse: """Handle LLM response with output guardrails.""" # Create complete conversation history including the LLM response - complete_conversation = self._append_llm_response_to_conversation(conversation_history, llm_response) + normalized_history = conversation_history or [] + complete_conversation = self._conversation_with_response(normalized_history, llm_response) response_text = self._extract_response_text(llm_response) output_results = self._run_stage_guardrails( @@ -502,24 
+567,18 @@ def get_conversation_history(self) -> list: def _append_llm_response_to_conversation(self, conversation_history: list | str, llm_response: Any) -> list: """Append LLM response to conversation history as-is.""" - if conversation_history is None: - conversation_history = [] - - # Handle case where conversation_history is a string (from single input) - if isinstance(conversation_history, str): - conversation_history = [{"role": "user", "content": conversation_history}] + normalized_history = self._normalize_conversation(conversation_history) + return self._conversation_with_response(normalized_history, llm_response) - # Make a copy to avoid modifying the original - updated_history = conversation_history.copy() - - # For responses API: append the output directly - if hasattr(llm_response, "output") and llm_response.output: - updated_history.extend(llm_response.output) - # For chat completions: append the choice message directly (prompt injection detection check will parse) - elif hasattr(llm_response, "choices") and llm_response.choices: - updated_history.append(llm_response.choices[0]) + async def _load_conversation_history_from_previous_response(self, previous_response_id: str | None) -> list[dict[str, Any]]: + """Load full conversation history for a stored previous response.""" + if not previous_response_id: + return [] - return updated_history + items = await _collect_conversation_items_async(self._resource_client, previous_response_id) + if not items: + return [] + return self._normalize_conversation(items) def _override_resources(self): from .resources.chat import AsyncChat @@ -540,15 +599,9 @@ async def _run_stage_guardrails( return [] try: - # Check if prompt injection detection guardrail is present and we have conversation history - has_injection_detection = any( - guardrail.definition.name.lower() == "prompt injection detection" for guardrail in self.guardrails[stage_name] - ) - - if has_injection_detection and conversation_history: + ctx = self.context + if conversation_history: ctx = self._create_context_with_conversation(conversation_history) - else: - ctx = self.context results = await run_guardrails( ctx=ctx, @@ -583,7 +636,8 @@ async def _handle_llm_response( ) -> GuardrailsResponse: """Handle non-streaming LLM response with output guardrails (async).""" # Create complete conversation history including the LLM response - complete_conversation = self._append_llm_response_to_conversation(conversation_history, llm_response) + normalized_history = conversation_history or [] + complete_conversation = self._conversation_with_response(normalized_history, llm_response) response_text = self._extract_response_text(llm_response) output_results = await self._run_stage_guardrails( @@ -682,6 +736,16 @@ def _append_llm_response_to_conversation(self, conversation_history: list | str, return updated_history + def _load_conversation_history_from_previous_response(self, previous_response_id: str | None) -> list[dict[str, Any]]: + """Load full conversation history for a stored previous response.""" + if not previous_response_id: + return [] + + items = _collect_conversation_items_sync(self._resource_client, previous_response_id) + if not items: + return [] + return self._normalize_conversation(items) + def _override_resources(self): from .resources.chat import Chat from .resources.responses import Responses diff --git a/src/guardrails/resources/chat/chat.py b/src/guardrails/resources/chat/chat.py index e2adb54..aa0382e 100644 --- a/src/guardrails/resources/chat/chat.py +++ 
b/src/guardrails/resources/chat/chat.py @@ -66,13 +66,14 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals Runs preflight first, then executes input guardrails concurrently with the LLM call. """ + normalized_conversation = self._client._normalize_conversation(messages) latest_message, _ = self._client._extract_latest_user_message(messages) # Preflight first (synchronous wrapper) preflight_results = self._client._run_stage_guardrails( "pre_flight", latest_message, - conversation_history=messages, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -91,7 +92,7 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals input_results = self._client._run_stage_guardrails( "input", latest_message, - conversation_history=messages, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) llm_response = llm_future.result() @@ -102,6 +103,7 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals llm_response, preflight_results, input_results, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) else: @@ -109,7 +111,7 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals llm_response, preflight_results, input_results, - conversation_history=messages, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -129,13 +131,14 @@ async def create( self, messages: list[dict[str, str]], model: str, stream: bool = False, suppress_tripwire: bool = False, **kwargs ) -> Any | AsyncIterator[Any]: """Create chat completion with guardrails.""" + normalized_conversation = self._client._normalize_conversation(messages) latest_message, _ = self._client._extract_latest_user_message(messages) # Run pre-flight guardrails preflight_results = await self._client._run_stage_guardrails( "pre_flight", latest_message, - conversation_history=messages, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -146,7 +149,7 @@ async def create( input_check = self._client._run_stage_guardrails( "input", latest_message, - conversation_history=messages, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) llm_call = self._client._resource_client.chat.completions.create( @@ -163,6 +166,7 @@ async def create( llm_response, preflight_results, input_results, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) else: @@ -170,6 +174,6 @@ async def create( llm_response, preflight_results, input_results, - conversation_history=messages, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) diff --git a/src/guardrails/resources/responses/responses.py b/src/guardrails/resources/responses/responses.py index 0d02b8a..89a84f7 100644 --- a/src/guardrails/resources/responses/responses.py +++ b/src/guardrails/resources/responses/responses.py @@ -34,6 +34,16 @@ def create( Runs preflight first, then executes input guardrails concurrently with the LLM call. 
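Beyond explicit message lists, the new `previous_response_id` support lets guardrails reconstruct history that the Responses API stored server-side. A hedged usage sketch follows; the model name and config path are placeholders, and `.llm_response` is assumed to be the attribute on the wrapped response that exposes the raw OpenAI response:

```python
from guardrails import GuardrailsOpenAI

client = GuardrailsOpenAI(config="guardrails_config.json")  # placeholder pipeline config

first = client.responses.create(
    model="gpt-4.1-mini",
    input="My name is Ada.",
)

# On the follow-up turn only the new input is sent; the wrapper retrieves the
# stored items for the previous response and hands the full normalized
# conversation to conversation-aware guardrails such as Prompt Injection Detection.
followup = client.responses.create(
    model="gpt-4.1-mini",
    input="What is my name?",
    previous_response_id=first.llm_response.id,  # assumed attribute name
)
```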
""" + previous_response_id = kwargs.get("previous_response_id") + prior_history = self._client._load_conversation_history_from_previous_response(previous_response_id) + + current_turn = self._client._normalize_conversation(input) + if prior_history: + normalized_conversation = [entry.copy() for entry in prior_history] + normalized_conversation.extend(current_turn) + else: + normalized_conversation = current_turn + # Determine latest user message text when a list of messages is provided if isinstance(input, list): latest_message, _ = self._client._extract_latest_user_message(input) @@ -44,7 +54,7 @@ def create( preflight_results = self._client._run_stage_guardrails( "pre_flight", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -64,7 +74,7 @@ def create( input_results = self._client._run_stage_guardrails( "input", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) llm_response = llm_future.result() @@ -75,6 +85,7 @@ def create( llm_response, preflight_results, input_results, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) else: @@ -82,19 +93,28 @@ def create( llm_response, preflight_results, input_results, - conversation_history=input, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseModel], suppress_tripwire: bool = False, **kwargs): """Parse response with structured output and guardrails (synchronous).""" + previous_response_id = kwargs.get("previous_response_id") + prior_history = self._client._load_conversation_history_from_previous_response(previous_response_id) + + current_turn = self._client._normalize_conversation(input) + if prior_history: + normalized_conversation = [entry.copy() for entry in prior_history] + normalized_conversation.extend(current_turn) + else: + normalized_conversation = current_turn latest_message, _ = self._client._extract_latest_user_message(input) # Preflight first preflight_results = self._client._run_stage_guardrails( "pre_flight", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -113,7 +133,7 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM input_results = self._client._run_stage_guardrails( "input", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) llm_response = llm_future.result() @@ -122,7 +142,7 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM llm_response, preflight_results, input_results, - conversation_data=input, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -165,6 +185,15 @@ async def create( **kwargs, ) -> Any | AsyncIterator[Any]: """Create response with guardrails.""" + previous_response_id = kwargs.get("previous_response_id") + prior_history = await self._client._load_conversation_history_from_previous_response(previous_response_id) + + current_turn = self._client._normalize_conversation(input) + if prior_history: + 
normalized_conversation = [entry.copy() for entry in prior_history] + normalized_conversation.extend(current_turn) + else: + normalized_conversation = current_turn # Determine latest user message text when a list of messages is provided if isinstance(input, list): latest_message, _ = self._client._extract_latest_user_message(input) @@ -175,7 +204,7 @@ async def create( preflight_results = await self._client._run_stage_guardrails( "pre_flight", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -186,7 +215,7 @@ async def create( input_check = self._client._run_stage_guardrails( "input", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) llm_call = self._client._resource_client.responses.create( @@ -204,6 +233,7 @@ async def create( llm_response, preflight_results, input_results, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) else: @@ -211,7 +241,7 @@ async def create( llm_response, preflight_results, input_results, - conversation_history=input, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -219,13 +249,22 @@ async def parse( self, input: list[dict[str, str]], model: str, text_format: type[BaseModel], stream: bool = False, suppress_tripwire: bool = False, **kwargs ) -> Any | AsyncIterator[Any]: """Parse response with structured output and guardrails.""" + previous_response_id = kwargs.get("previous_response_id") + prior_history = await self._client._load_conversation_history_from_previous_response(previous_response_id) + + current_turn = self._client._normalize_conversation(input) + if prior_history: + normalized_conversation = [entry.copy() for entry in prior_history] + normalized_conversation.extend(current_turn) + else: + normalized_conversation = current_turn latest_message, _ = self._client._extract_latest_user_message(input) # Run pre-flight guardrails preflight_results = await self._client._run_stage_guardrails( "pre_flight", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -236,7 +275,7 @@ async def parse( input_check = self._client._run_stage_guardrails( "input", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) llm_call = self._client._resource_client.responses.parse( @@ -254,6 +293,7 @@ async def parse( llm_response, preflight_results, input_results, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) else: @@ -261,7 +301,7 @@ async def parse( llm_response, preflight_results, input_results, - conversation_history=input, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) diff --git a/src/guardrails/types.py b/src/guardrails/types.py index 82f5e76..1f287e5 100644 --- a/src/guardrails/types.py +++ b/src/guardrails/types.py @@ -47,6 +47,7 @@ def get_conversation_history(self) -> list | None: """Get conversation history if available, None otherwise.""" return getattr(self, "conversation_history", None) + @dataclass(frozen=True, slots=True) class GuardrailResult: """Result returned 
from a guardrail check. diff --git a/src/guardrails/utils/__init__.py b/src/guardrails/utils/__init__.py index 622cb33..d961790 100644 --- a/src/guardrails/utils/__init__.py +++ b/src/guardrails/utils/__init__.py @@ -5,9 +5,11 @@ - response parsing - strict schema enforcement - context validation +- conversation history normalization Modules: schema: Utilities for enforcing strict JSON schema standards. parsing: Tools for parsing and formatting response items. context: Functions for validating guardrail contexts. + conversation: Helpers for normalizing conversation payloads across APIs. """ diff --git a/src/guardrails/utils/conversation.py b/src/guardrails/utils/conversation.py new file mode 100644 index 0000000..f3fa237 --- /dev/null +++ b/src/guardrails/utils/conversation.py @@ -0,0 +1,328 @@ +"""Utilities for normalizing conversation history across providers. + +The helpers in this module transform arbitrary chat/response payloads into a +consistent list of dictionaries that guardrails can consume. The structure is +intended to capture the semantic roles of user/assistant turns as well as tool +calls and outputs regardless of the originating API. +""" + +from __future__ import annotations + +import json +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class ConversationEntry: + """Normalized representation of a conversation item. + + Attributes: + role: Logical speaker role (user, assistant, system, tool, etc.). + type: Optional type discriminator for non-message items such as + ``function_call`` or ``function_call_output``. + content: Primary text payload for message-like items. + tool_name: Name of the tool/function associated with the entry. + arguments: Serialized tool/function arguments when available. + output: Serialized tool result payload when available. + call_id: Identifier that links tool calls and outputs. + """ + + role: str | None = None + type: str | None = None + content: str | None = None + tool_name: str | None = None + arguments: str | None = None + output: str | None = None + call_id: str | None = None + + def to_payload(self) -> dict[str, Any]: + """Convert entry to a plain dict, omitting null fields.""" + payload: dict[str, Any] = {} + if self.role is not None: + payload["role"] = self.role + if self.type is not None: + payload["type"] = self.type + if self.content is not None: + payload["content"] = self.content + if self.tool_name is not None: + payload["tool_name"] = self.tool_name + if self.arguments is not None: + payload["arguments"] = self.arguments + if self.output is not None: + payload["output"] = self.output + if self.call_id is not None: + payload["call_id"] = self.call_id + return payload + + +def normalize_conversation( + conversation: str | Mapping[str, Any] | Sequence[Any] | None, +) -> list[dict[str, Any]]: + """Normalize arbitrary conversation payloads to guardrail-friendly dicts. + + Args: + conversation: Conversation history expressed as a raw string (single + user turn), a mapping/object representing a message, or a sequence + of messages/items. + + Returns: + List of dictionaries describing the conversation in chronological order. 
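For instance, a chat-completions style turn that includes a tool call and a tool result normalizes to flat entries roughly like the following (fields that are `None` are omitted from each payload):

```python
from guardrails.utils.conversation import merge_conversation_with_items, normalize_conversation

history = normalize_conversation(
    [
        {"role": "user", "content": "What's the weather in Paris?"},
        {
            "role": "assistant",
            "tool_calls": [
                {"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}},
            ],
        },
        {"role": "tool", "tool_call_id": "call_1", "content": "18C and sunny"},
    ]
)
# history is roughly:
#   {"role": "user", "content": "What's the weather in Paris?"}
#   {"role": "assistant"}                                   # assistant turn with no text content
#   {"type": "function_call", "tool_name": "get_weather",
#    "arguments": '{"city": "Paris"}', "call_id": "call_1"}
#   {"type": "function_call_output", "output": "18C and sunny", "call_id": "call_1"}

# merge_conversation_with_items copies the base list, so the original is untouched.
updated = merge_conversation_with_items(history, [{"role": "assistant", "content": "It is 18C and sunny."}])
```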
+ """ + if conversation is None: + return [] + + if isinstance(conversation, str): + entry = ConversationEntry(role="user", content=conversation) + return [entry.to_payload()] + + if isinstance(conversation, Mapping): + entries = _normalize_item(conversation) + return [entry.to_payload() for entry in entries] + + if isinstance(conversation, Sequence): + normalized: list[ConversationEntry] = [] + for item in conversation: + normalized.extend(_normalize_item(item)) + return [entry.to_payload() for entry in normalized] + + # Fallback: treat the value as a message-like object. + entries = _normalize_item(conversation) + return [entry.to_payload() for entry in entries] + + +def append_assistant_response( + conversation: Iterable[dict[str, Any]], + llm_response: Any, +) -> list[dict[str, Any]]: + """Append the assistant response to a normalized conversation copy. + + Args: + conversation: Existing normalized conversation. + llm_response: Response object returned from the model call. + + Returns: + New conversation list containing the assistant's response entries. + """ + base = [entry.copy() for entry in conversation] + response_entries = _normalize_model_response(llm_response) + base.extend(entry.to_payload() for entry in response_entries) + return base + + +def merge_conversation_with_items( + conversation: Iterable[dict[str, Any]], + items: Sequence[Any], +) -> list[dict[str, Any]]: + """Return a new conversation list with additional items appended. + + Args: + conversation: Existing normalized conversation. + items: Additional items (tool calls, tool outputs, messages) to append. + + Returns: + List representing the combined conversation. + """ + base = [entry.copy() for entry in conversation] + for entry in _normalize_sequence(items): + base.append(entry.to_payload()) + return base + + +def _normalize_sequence(items: Sequence[Any]) -> list[ConversationEntry]: + entries: list[ConversationEntry] = [] + for item in items: + entries.extend(_normalize_item(item)) + return entries + + +def _normalize_item(item: Any) -> list[ConversationEntry]: + """Normalize a single message or tool record.""" + if item is None: + return [] + + if isinstance(item, Mapping): + return _normalize_mapping(item) + + if hasattr(item, "model_dump"): + return _normalize_mapping(item.model_dump(exclude_unset=True)) + + if hasattr(item, "__dict__"): + return _normalize_mapping(vars(item)) + + if isinstance(item, str): + return [ConversationEntry(role="user", content=item)] + + return [ConversationEntry(content=_stringify(item))] + + +def _normalize_mapping(item: Mapping[str, Any]) -> list[ConversationEntry]: + entries: list[ConversationEntry] = [] + item_type = item.get("type") + + if item_type in {"function_call", "tool_call"}: + entries.append( + ConversationEntry( + type="function_call", + tool_name=_extract_tool_name(item), + arguments=_stringify(item.get("arguments") or item.get("function", {}).get("arguments")), + call_id=_stringify(item.get("call_id") or item.get("id")), + ) + ) + return entries + + if item_type == "function_call_output": + entries.append( + ConversationEntry( + type="function_call_output", + tool_name=_extract_tool_name(item), + arguments=_stringify(item.get("arguments")), + output=_extract_text(item.get("output")), + call_id=_stringify(item.get("call_id")), + ) + ) + return entries + + role = item.get("role") + if role is not None: + entries.extend(_normalize_role_message(role, item)) + return entries + + # Fallback path for message-like objects without explicit role/type. 
+ entries.append( + ConversationEntry( + content=_extract_text(item.get("content") if "content" in item else item), + type=item_type if isinstance(item_type, str) else None, + ) + ) + return entries + + +def _normalize_role_message(role: str, item: Mapping[str, Any]) -> list[ConversationEntry]: + entries: list[ConversationEntry] = [] + text = _extract_text(item.get("content")) + if role != "tool": + entries.append(ConversationEntry(role=role, content=text)) + + # Normalize inline tool calls/functions. + tool_calls = item.get("tool_calls") + if isinstance(tool_calls, Sequence): + entries.extend(_normalize_tool_calls(tool_calls)) + + function_call = item.get("function_call") + if isinstance(function_call, Mapping): + entries.append( + ConversationEntry( + type="function_call", + tool_name=_stringify(function_call.get("name")), + arguments=_stringify(function_call.get("arguments")), + call_id=_stringify(function_call.get("call_id")), + ) + ) + + if role == "tool": + tool_output = ConversationEntry( + type="function_call_output", + tool_name=_extract_tool_name(item), + output=text, + arguments=_stringify(item.get("arguments")), + call_id=_stringify(item.get("tool_call_id") or item.get("call_id")), + ) + return [entry for entry in [tool_output] if any(entry.to_payload().values())] + + return [entry for entry in entries if any(entry.to_payload().values())] + + +def _normalize_tool_calls(tool_calls: Sequence[Any]) -> list[ConversationEntry]: + entries: list[ConversationEntry] = [] + for call in tool_calls: + if hasattr(call, "model_dump"): + call_mapping = call.model_dump(exclude_unset=True) + elif isinstance(call, Mapping): + call_mapping = call + else: + call_mapping = {} + + entries.append( + ConversationEntry( + type="function_call", + tool_name=_extract_tool_name(call_mapping), + arguments=_stringify(call_mapping.get("arguments") or call_mapping.get("function", {}).get("arguments")), + call_id=_stringify(call_mapping.get("id") or call_mapping.get("call_id")), + ) + ) + return entries + + +def _extract_tool_name(item: Mapping[str, Any]) -> str | None: + if "tool_name" in item and isinstance(item["tool_name"], str): + return item["tool_name"] + if "name" in item and isinstance(item["name"], str): + return item["name"] + function = item.get("function") + if isinstance(function, Mapping): + name = function.get("name") + if isinstance(name, str): + return name + return None + + +def _extract_text(content: Any) -> str | None: + if content is None: + return None + + if isinstance(content, str): + return content + + if isinstance(content, Mapping): + text = content.get("text") + if isinstance(text, str): + return text + return _extract_text(content.get("content")) + + if isinstance(content, Sequence) and not isinstance(content, bytes | bytearray): + parts: list[str] = [] + for item in content: + extracted = _extract_text(item) + if extracted: + parts.append(extracted) + joined = " ".join(part for part in parts if part) + return joined or None + + return _stringify(content) + + +def _normalize_model_response(response: Any) -> list[ConversationEntry]: + if response is None: + return [] + + if hasattr(response, "output"): + output = response.output + if isinstance(output, Sequence): + return _normalize_sequence(output) + + if hasattr(response, "choices"): + choices = response.choices + if isinstance(choices, Sequence) and choices: + choice = choices[0] + message = getattr(choice, "message", choice) + return _normalize_item(message) + + # Streaming deltas often expose ``delta`` with message fragments. 
+ delta = getattr(response, "delta", None) + if delta: + return _normalize_item(delta) + + return [] + + +def _stringify(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + try: + return json.dumps(value, ensure_ascii=False) + except (TypeError, ValueError): + return str(value) diff --git a/tests/unit/test_agents.py b/tests/unit/test_agents.py index 2bdcd4b..0aa1c4d 100644 --- a/tests/unit/test_agents.py +++ b/tests/unit/test_agents.py @@ -23,10 +23,11 @@ @dataclass class ToolContext: - """Stub tool context carrying name and arguments.""" + """Stub tool context carrying name, arguments, and optional call id.""" tool_name: str - tool_arguments: dict[str, Any] + tool_arguments: dict[str, Any] | str + tool_call_id: str | None = None @dataclass @@ -99,6 +100,14 @@ class Agent: tools: list[Any] | None = None +class AgentRunner: + """Minimal AgentRunner stub so guardrails patching succeeds.""" + + async def run(self, *args: Any, **kwargs: Any) -> Any: + """Return a sentinel result.""" + return SimpleNamespace() + + agents_module.ToolGuardrailFunctionOutput = ToolGuardrailFunctionOutput agents_module.ToolInputGuardrailData = ToolInputGuardrailData agents_module.ToolOutputGuardrailData = ToolOutputGuardrailData @@ -109,9 +118,15 @@ class Agent: agents_module.GuardrailFunctionOutput = GuardrailFunctionOutput agents_module.input_guardrail = _decorator_passthrough agents_module.output_guardrail = _decorator_passthrough +agents_module.AgentRunner = AgentRunner sys.modules.setdefault("agents", agents_module) +agents_run_module = types.ModuleType("agents.run") +agents_run_module.AgentRunner = AgentRunner +sys.modules.setdefault("agents.run", agents_run_module) +agents_module.run = agents_run_module + import guardrails.agents as agents # noqa: E402 (import after stubbing) import guardrails.runtime as runtime_module # noqa: E402 @@ -135,44 +150,75 @@ def model_validate(value: Any, **_: Any) -> Any: @pytest.fixture(autouse=True) -def reset_user_messages() -> None: - """Ensure user message context is reset for each test.""" - agents._user_messages.set([]) +def reset_agent_context() -> None: + """Ensure agent conversation context vars are reset for each test.""" + agents._agent_session.set(None) + agents._agent_conversation.set(None) + + +@pytest.mark.asyncio +async def test_conversation_with_tool_call_updates_fallback_history() -> None: + """Fallback conversation should include previous history and new tool call.""" + agents._agent_session.set(None) + agents._agent_conversation.set(({"role": "user", "content": "Hi there"},)) + data = SimpleNamespace(context=ToolContext(tool_name="math", tool_arguments={"x": 1}, tool_call_id="call-1")) + + conversation = await agents._conversation_with_tool_call(data) + + assert conversation[0]["content"] == "Hi there" # noqa: S101 + assert conversation[-1]["type"] == "function_call" # noqa: S101 + assert conversation[-1]["tool_name"] == "math" # noqa: S101 + stored = agents._agent_conversation.get() + assert stored is not None and stored[-1]["call_id"] == "call-1" # type: ignore[index] # noqa: S101 + + +@pytest.mark.asyncio +async def test_conversation_with_tool_call_uses_session_history() -> None: + """When session is available, its items form the conversation baseline.""" + + class StubSession: + def __init__(self) -> None: + self.items = [{"role": "user", "content": "Remember me"}] + + async def get_items(self, limit: int | None = None) -> list[dict[str, Any]]: + return self.items + async def add_items(self, items: 
list[Any]) -> None: + self.items.extend(items) -def test_get_user_messages_initializes_list() -> None: - """_get_user_messages should return the same list instance across calls.""" - msgs1 = agents._get_user_messages() - msgs1.append("hello") - msgs2 = agents._get_user_messages() + async def pop_item(self) -> Any | None: + return None - assert msgs2 == ["hello"] # noqa: S101 - assert msgs1 is msgs2 # noqa: S101 + async def clear_session(self) -> None: + self.items.clear() + session = StubSession() + agents._agent_session.set(session) + agents._agent_conversation.set(None) -def test_build_conversation_with_tool_call_includes_user_messages() -> None: - """Conversation builder should include stored user messages and tool call details.""" - agents._user_messages.set(["Hi there"]) - data = SimpleNamespace(context=ToolContext(tool_name="math", tool_arguments={"x": 1})) + data = SimpleNamespace(context=ToolContext(tool_name="lookup", tool_arguments={"zip": 12345}, tool_call_id="call-2")) - conversation = agents._build_conversation_with_tool_call(data) + conversation = await agents._conversation_with_tool_call(data) - assert conversation[0] == {"role": "user", "content": "Hi there"} # noqa: S101 - assert conversation[1]["tool_name"] == "math" # noqa: S101 - assert conversation[1]["arguments"] == {"x": 1} # noqa: S101 + assert conversation[0]["content"] == "Remember me" # noqa: S101 + assert conversation[-1]["call_id"] == "call-2" # noqa: S101 + cached = agents._agent_conversation.get() + assert cached is not None and cached[-1]["call_id"] == "call-2" # type: ignore[index] # noqa: S101 -def test_build_conversation_with_tool_output_includes_output() -> None: - """Tool output conversation should include function output payload.""" - agents._user_messages.set(["User request"]) +@pytest.mark.asyncio +async def test_conversation_with_tool_output_includes_output() -> None: + """Tool output conversation should include serialized output payload.""" + agents._agent_session.set(None) + agents._agent_conversation.set(({"role": "user", "content": "Compute"},)) data = SimpleNamespace( - context=ToolContext(tool_name="calc", tool_arguments={"y": 2}), + context=ToolContext(tool_name="calc", tool_arguments={"y": 2}, tool_call_id="call-3"), output={"result": 4}, ) - conversation = agents._build_conversation_with_tool_output(data) + conversation = await agents._conversation_with_tool_output(data) - assert conversation[1]["output"] == "{'result': 4}" # noqa: S101 + assert conversation[-1]["output"] == "{'result': 4}" # noqa: S101 def test_create_conversation_context_exposes_history() -> None: @@ -216,12 +262,6 @@ def test_attach_guardrail_to_tools_initializes_lists() -> None: assert tool.tool_output_guardrails == [fn] # type: ignore[attr-defined] # noqa: S101 -def test_needs_conversation_history() -> None: - """Guardrails requiring conversation history should be detected.""" - assert agents._needs_conversation_history(_make_guardrail("Prompt Injection Detection")) is True # noqa: S101 - assert agents._needs_conversation_history(_make_guardrail("Other Guard")) is False # noqa: S101 - - def test_separate_tool_level_from_agent_level() -> None: """Prompt injection guardrails should be classified as tool-level.""" tool, agent_level = agents._separate_tool_level_from_agent_level([_make_guardrail("Prompt Injection Detection"), _make_guardrail("Other Guard")]) @@ -235,9 +275,13 @@ async def test_create_tool_guardrail_rejects_on_tripwire(monkeypatch: pytest.Mon """Tool guardrail should reject content when run_guardrails flags 
a violation.""" guardrail = _make_guardrail("Test Guardrail") expected_info = {"observation": "violation"} + agents._agent_session.set(None) + agents._agent_conversation.set(({"role": "user", "content": "Original request"},)) async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: assert kwargs["stage_name"] == "tool_input_test_guardrail" # noqa: S101 + history = kwargs["ctx"].get_conversation_history() + assert history[-1]["tool_name"] == "weather" # noqa: S101 return [GuardrailResult(tripwire_triggered=True, info=expected_info)] monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) @@ -245,8 +289,7 @@ async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: tool_fn = agents._create_tool_guardrail( guardrail=guardrail, guardrail_type="input", - needs_conv_history=False, - context=SimpleNamespace(), + context=SimpleNamespace(guardrail_llm="client"), raise_guardrail_errors=False, block_on_violations=False, ) @@ -268,12 +311,13 @@ async def fake_run_guardrails(**_: Any) -> list[GuardrailResult]: return [GuardrailResult(tripwire_triggered=True, info={})] monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + agents._agent_session.set(None) + agents._agent_conversation.set(({"role": "user", "content": "Hi"},)) tool_fn = agents._create_tool_guardrail( guardrail=guardrail, guardrail_type="input", - needs_conv_history=False, - context=SimpleNamespace(), + context=SimpleNamespace(guardrail_llm="client"), raise_guardrail_errors=False, block_on_violations=True, ) @@ -293,12 +337,13 @@ async def failing_run_guardrails(**_: Any) -> list[GuardrailResult]: raise RuntimeError("guardrail failure") monkeypatch.setattr(runtime_module, "run_guardrails", failing_run_guardrails) + agents._agent_session.set(None) + agents._agent_conversation.set(({"role": "user", "content": "Hi"},)) tool_fn = agents._create_tool_guardrail( guardrail=guardrail, guardrail_type="input", - needs_conv_history=False, - context=SimpleNamespace(), + context=SimpleNamespace(guardrail_llm="client"), raise_guardrail_errors=True, block_on_violations=False, ) @@ -310,21 +355,23 @@ async def failing_run_guardrails(**_: Any) -> list[GuardrailResult]: @pytest.mark.asyncio -async def test_create_tool_guardrail_skips_without_user_messages(monkeypatch: pytest.MonkeyPatch) -> None: - """Conversation-aware tool guardrails should skip when no user intent is recorded.""" +async def test_create_tool_guardrail_handles_empty_conversation(monkeypatch: pytest.MonkeyPatch) -> None: + """Guardrail executes even when no prior conversation is present.""" guardrail = _make_guardrail("Prompt Injection Detection") - agents._user_messages.set([]) # Reset stored messages async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: - raise AssertionError("run_guardrails should not be called when skipping") + history = kwargs["ctx"].get_conversation_history() + assert history[-1]["output"] == "ok" # noqa: S101 + return [GuardrailResult(tripwire_triggered=False, info={})] - monkeypatch.setattr(agents, "run_guardrails", fake_run_guardrails, raising=False) + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + agents._agent_session.set(None) + agents._agent_conversation.set(None) tool_fn = agents._create_tool_guardrail( guardrail=guardrail, guardrail_type="output", - needs_conv_history=True, - context=SimpleNamespace(), + context=SimpleNamespace(guardrail_llm="client"), raise_guardrail_errors=False, block_on_violations=False, ) @@ -335,13 +382,12 @@ async def 
fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: ) result = await tool_fn(data) - assert "Skipped" in result.output_info # noqa: S101 assert result.tripwire_triggered is False # noqa: S101 @pytest.mark.asyncio async def test_create_agents_guardrails_from_config_success(monkeypatch: pytest.MonkeyPatch) -> None: - """Agent-level guardrail functions should execute run_guardrails and capture user messages.""" + """Agent-level guardrail functions should execute run_guardrails.""" pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None) monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline) monkeypatch.setattr( @@ -371,7 +417,7 @@ async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: assert output.tripwire_triggered is False # noqa: S101 assert captured["stage_name"] == "input" # noqa: S101 - assert agents._get_user_messages()[-1] == "hello" # noqa: S101 + assert captured["data"] == "hello" # noqa: S101 @pytest.mark.asyncio @@ -489,7 +535,6 @@ async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: result = await guardrails[0](agents_module.RunContextWrapper(None), Agent("n", "i"), "response") assert result.tripwire_triggered is False # noqa: S101 - assert agents._get_user_messages() == [] # noqa: S101 def test_guardrail_agent_attaches_tool_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: @@ -542,68 +587,6 @@ def fake_instantiate_guardrails(stage: Any, registry: Any | None = None) -> list assert len(agent_instance.input_guardrails or []) >= 1 # noqa: S101 -@pytest.mark.asyncio -async def test_guardrail_agent_captures_user_messages(monkeypatch: pytest.MonkeyPatch) -> None: - """GuardrailAgent should capture user messages and invoke tool guardrails.""" - prompt_guard = _make_guardrail("Prompt Injection Detection") - input_guard = _make_guardrail("Agent Guard") - - class FakePipeline: - def __init__(self) -> None: - self.pre_flight = SimpleNamespace() - self.input = SimpleNamespace() - self.output = None - - def stages(self) -> list[Any]: - return [self.pre_flight, self.input] - - pipeline = FakePipeline() - - def fake_load_pipeline_bundles(config: Any) -> FakePipeline: - return pipeline - - def fake_instantiate_guardrails(stage: Any, registry: Any | None = None) -> list[Any]: - if stage is pipeline.pre_flight: - return [prompt_guard] - if stage is pipeline.input: - return [input_guard] - return [] - - calls: list[str] = [] - - async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: - calls.append(kwargs["stage_name"]) - return [GuardrailResult(tripwire_triggered=False, info={})] - - monkeypatch.setattr(runtime_module, "load_pipeline_bundles", fake_load_pipeline_bundles, raising=False) - monkeypatch.setattr(runtime_module, "instantiate_guardrails", fake_instantiate_guardrails, raising=False) - monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) - - tool = SimpleNamespace() - agent_instance = agents.GuardrailAgent( - config={"version": 1}, - name="Test", - instructions="Help", - tools=[tool], - ) - - # Call the first input guardrail (capture function) - capture_fn = agent_instance.input_guardrails[0] - await capture_fn(agents_module.RunContextWrapper(None), agent_instance, "user question") - - assert agents._get_user_messages() == ["user question"] # noqa: S101 - - # Run actual agent guardrail - guard_fn = agent_instance.input_guardrails[1] - await guard_fn(agents_module.RunContextWrapper(None), agent_instance, "user question") - - # Tool guardrail should be 
attached and callable - data = agents_module.ToolInputGuardrailData(context=ToolContext("tool", {})) - await tool.tool_input_guardrails[0](data) # type: ignore[attr-defined] - - assert any(name.startswith("tool_input") for name in calls) # noqa: S101 - - def test_guardrail_agent_without_tools(monkeypatch: pytest.MonkeyPatch) -> None: """Agent with no tools should not attach tool guardrails.""" pipeline = SimpleNamespace(pre_flight=None, input=None, output=None) diff --git a/tests/unit/test_client_async.py b/tests/unit/test_client_async.py index 50d0c4c..624e56c 100644 --- a/tests/unit/test_client_async.py +++ b/tests/unit/test_client_async.py @@ -61,7 +61,8 @@ def test_append_llm_response_handles_string_history() -> None: updated_history = client._append_llm_response_to_conversation("hi there", response) assert updated_history[0]["content"] == "hi there" # noqa: S101 - assert updated_history[1].message.content == "assistant reply" # type: ignore[union-attr] # noqa: S101 + assert updated_history[0]["role"] == "user" # noqa: S101 + assert updated_history[1]["content"] == "assistant reply" # noqa: S101 def test_append_llm_response_handles_response_output() -> None: @@ -182,7 +183,7 @@ async def fake_run_stage( ) assert captured_text == ["LLM response"] # noqa: S101 - assert captured_history[-1][-1].message.content == "LLM response" # type: ignore[index] # noqa: S101 + assert captured_history[-1][-1]["content"] == "LLM response" # noqa: S101 assert response.guardrail_results.output == [output_result] # noqa: S101 diff --git a/tests/unit/test_client_sync.py b/tests/unit/test_client_sync.py index 2eb6a7c..b04c724 100644 --- a/tests/unit/test_client_sync.py +++ b/tests/unit/test_client_sync.py @@ -86,7 +86,8 @@ def test_append_llm_response_handles_string_history() -> None: updated_history = client._append_llm_response_to_conversation("hi there", response) assert updated_history[0]["content"] == "hi there" # noqa: S101 - assert updated_history[1].message.content == "assistant reply" # type: ignore[union-attr] # noqa: S101 + assert updated_history[0]["role"] == "user" # noqa: S101 + assert updated_history[1]["content"] == "assistant reply" # noqa: S101 def test_append_llm_response_handles_response_output() -> None: @@ -112,7 +113,7 @@ def test_append_llm_response_handles_none_history() -> None: history = client._append_llm_response_to_conversation(None, response) - assert history[-1].message.content == "assistant reply" # type: ignore[union-attr] # noqa: S101 + assert history[-1]["content"] == "assistant reply" # noqa: S101 def test_run_stage_guardrails_raises_on_tripwire(monkeypatch: pytest.MonkeyPatch) -> None: @@ -264,7 +265,7 @@ def fake_run_stage( ) assert captured_text == ["LLM response"] # noqa: S101 - assert captured_history[-1][-1].message.content == "LLM response" # type: ignore[index] # noqa: S101 + assert captured_history[-1][-1]["content"] == "LLM response" # noqa: S101 assert response.guardrail_results.output == [output_result] # noqa: S101 diff --git a/tests/unit/test_resources_chat.py b/tests/unit/test_resources_chat.py index 2a73ca3..fcff527 100644 --- a/tests/unit/test_resources_chat.py +++ b/tests/unit/test_resources_chat.py @@ -8,6 +8,7 @@ import pytest from guardrails.resources.chat.chat import AsyncChatCompletions, ChatCompletions +from guardrails.utils.conversation import normalize_conversation class _InlineExecutor: @@ -48,6 +49,7 @@ def __init__(self) -> None: completions=SimpleNamespace(create=self._llm_call), ) ) + self._normalize_conversation = normalize_conversation 
self._llm_response = SimpleNamespace(type="llm") self._stream_result = "stream" self._handle_result = "handled" @@ -106,6 +108,8 @@ def _stream_with_guardrails_sync( llm_stream: Any, preflight_results: list[Any], input_results: list[Any], + conversation_history: list[dict[str, Any]] | None = None, + check_interval: int = 100, suppress_tripwire: bool = False, ) -> Any: self.stream_calls.append( @@ -113,6 +117,8 @@ def _stream_with_guardrails_sync( "stream": llm_stream, "preflight": preflight_results, "input": input_results, + "history": conversation_history, + "interval": check_interval, "suppress": suppress_tripwire, } ) @@ -134,6 +140,7 @@ def __init__(self) -> None: completions=SimpleNamespace(create=self._llm_call), ) ) + self._normalize_conversation = normalize_conversation self._llm_response = SimpleNamespace(type="llm") self._stream_result = "async-stream" self._handle_result = "async-handled" @@ -192,6 +199,8 @@ def _stream_with_guardrails( llm_stream: Any, preflight_results: list[Any], input_results: list[Any], + conversation_history: list[dict[str, Any]] | None = None, + check_interval: int = 100, suppress_tripwire: bool = False, ) -> Any: self.stream_calls.append( @@ -199,6 +208,8 @@ def _stream_with_guardrails( "stream": llm_stream, "preflight": preflight_results, "input": input_results, + "history": conversation_history, + "interval": check_interval, "suppress": suppress_tripwire, } ) diff --git a/tests/unit/test_resources_responses.py b/tests/unit/test_resources_responses.py index 88adbe3..3726be2 100644 --- a/tests/unit/test_resources_responses.py +++ b/tests/unit/test_resources_responses.py @@ -9,6 +9,7 @@ from pydantic import BaseModel from guardrails.resources.responses.responses import AsyncResponses, Responses +from guardrails.utils.conversation import normalize_conversation class _SyncResponsesClient: @@ -24,6 +25,8 @@ def __init__(self) -> None: self.create_calls: list[dict[str, Any]] = [] self.parse_calls: list[dict[str, Any]] = [] self.retrieve_calls: list[dict[str, Any]] = [] + self.history_requests: list[str | None] = [] + self.history_lookup: dict[str, list[dict[str, Any]]] = {} self._llm_response = SimpleNamespace(output_text="result", type="llm") self._stream_result = "stream" self._handle_result = "handled" @@ -34,6 +37,7 @@ def __init__(self) -> None: retrieve=self._llm_retrieve, ) ) + self._normalize_conversation = normalize_conversation def _llm_create(self, **kwargs: Any) -> Any: self.create_calls.append(kwargs) @@ -78,6 +82,14 @@ def _apply_preflight_modifications(self, data: Any, results: list[Any]) -> Any: return [{"role": "user", "content": "modified"}] return "modified" + def _load_conversation_history_from_previous_response(self, previous_response_id: str | None) -> list[dict[str, Any]]: + self.history_requests.append(previous_response_id) + if not previous_response_id: + return [] + + history = self.history_lookup.get(previous_response_id, []) + return [entry.copy() for entry in history] + def _handle_llm_response( self, llm_response: Any, @@ -103,6 +115,8 @@ def _stream_with_guardrails_sync( llm_stream: Any, preflight_results: list[Any], input_results: list[Any], + conversation_history: list[dict[str, Any]] | None = None, + check_interval: int = 100, suppress_tripwire: bool = False, ) -> Any: self.stream_calls.append( @@ -110,6 +124,8 @@ def _stream_with_guardrails_sync( "stream": llm_stream, "preflight": preflight_results, "input": input_results, + "history": conversation_history, + "interval": check_interval, "suppress": suppress_tripwire, } ) @@ 
-130,7 +146,6 @@ def _create_guardrails_response( "output": output_results, } - class _AsyncResponsesClient: """Fake asynchronous guardrails client for AsyncResponses tests.""" @@ -142,14 +157,13 @@ def __init__(self) -> None: self.handle_calls: list[dict[str, Any]] = [] self.stream_calls: list[dict[str, Any]] = [] self.create_calls: list[dict[str, Any]] = [] + self.history_requests: list[str | None] = [] + self.history_lookup: dict[str, list[dict[str, Any]]] = {} self._llm_response = SimpleNamespace(output_text="async", type="llm") self._stream_result = "async-stream" self._handle_result = "async-handled" - self._resource_client = SimpleNamespace( - responses=SimpleNamespace( - create=self._llm_create, - ) - ) + self._resource_client = SimpleNamespace(responses=SimpleNamespace(create=self._llm_create)) + self._normalize_conversation = normalize_conversation async def _llm_create(self, **kwargs: Any) -> Any: self.create_calls.append(kwargs) @@ -186,6 +200,14 @@ def _apply_preflight_modifications(self, data: Any, results: list[Any]) -> Any: return [{"role": "user", "content": "modified"}] return "modified" + async def _load_conversation_history_from_previous_response(self, previous_response_id: str | None) -> list[dict[str, Any]]: + self.history_requests.append(previous_response_id) + if not previous_response_id: + return [] + + history = self.history_lookup.get(previous_response_id, []) + return [entry.copy() for entry in history] + async def _handle_llm_response( self, llm_response: Any, @@ -209,6 +231,8 @@ def _stream_with_guardrails( llm_stream: Any, preflight_results: list[Any], input_results: list[Any], + conversation_history: list[dict[str, Any]] | None = None, + check_interval: int = 100, suppress_tripwire: bool = False, ) -> Any: self.stream_calls.append( @@ -216,6 +240,8 @@ def _stream_with_guardrails( "stream": llm_stream, "preflight": preflight_results, "input": input_results, + "history": conversation_history, + "interval": check_interval, "suppress": suppress_tripwire, } ) @@ -279,6 +305,28 @@ def test_responses_create_stream_returns_stream(monkeypatch: pytest.MonkeyPatch) stream_call = client.stream_calls[0] assert stream_call["suppress"] is True # noqa: S101 assert stream_call["preflight"] == ["preflight"] # noqa: S101 + assert stream_call["history"] == normalize_conversation(_messages()) # noqa: S101 + + +def test_responses_create_merges_previous_history(monkeypatch: pytest.MonkeyPatch) -> None: + """Responses.create should merge stored conversation history when provided.""" + client = _SyncResponsesClient() + responses = Responses(client) + _inline_executor(monkeypatch) + + previous_turn = [ + {"role": "user", "content": "old question"}, + {"role": "assistant", "content": "old answer"}, + ] + client.history_lookup["resp-prev"] = normalize_conversation(previous_turn) + + messages = _messages() + responses.create(input=messages, model="gpt-test", previous_response_id="resp-prev") + + expected_history = client.history_lookup["resp-prev"] + normalize_conversation(messages) + assert client.preflight_calls[0]["history"] == expected_history # noqa: S101 + assert client.history_requests == ["resp-prev"] # noqa: S101 + assert client.create_calls[0]["previous_response_id"] == "resp-prev" # noqa: S101 def test_responses_parse_runs_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: @@ -295,7 +343,35 @@ class _Schema(BaseModel): assert result == "handled" # noqa: S101 assert client.parse_calls[0]["input"][0]["content"] == "modified" # noqa: S101 - assert 
client.handle_calls[0]["extra"]["conversation_data"] == messages # noqa: S101 + assert client.handle_calls[0]["history"] == normalize_conversation(messages) # noqa: S101 + + +def test_responses_parse_merges_previous_history(monkeypatch: pytest.MonkeyPatch) -> None: + """Responses.parse should include stored conversation history.""" + client = _SyncResponsesClient() + responses = Responses(client) + _inline_executor(monkeypatch) + + previous_turn = [ + {"role": "user", "content": "first step"}, + {"role": "assistant", "content": "ack"}, + ] + client.history_lookup["resp-prev"] = normalize_conversation(previous_turn) + + class _Schema(BaseModel): + text: str + + messages = _messages() + responses.parse( + input=messages, + model="gpt-test", + text_format=_Schema, + previous_response_id="resp-prev", + ) + + expected_history = client.history_lookup["resp-prev"] + normalize_conversation(messages) + assert client.preflight_calls[0]["history"] == expected_history # noqa: S101 + assert client.parse_calls[0]["previous_response_id"] == "resp-prev" # noqa: S101 def test_responses_retrieve_wraps_output() -> None: @@ -336,3 +412,24 @@ async def test_async_responses_stream_returns_wrapper() -> None: stream_call = client.stream_calls[0] assert stream_call["preflight"] == ["preflight"] # noqa: S101 assert stream_call["input"] == ["input"] # noqa: S101 + assert stream_call["history"] == normalize_conversation(_messages()) # noqa: S101 + + +@pytest.mark.asyncio +async def test_async_responses_create_merges_previous_history() -> None: + """AsyncResponses.create should merge stored conversation history.""" + client = _AsyncResponsesClient() + responses = AsyncResponses(client) + + previous_turn = [ + {"role": "user", "content": "old question"}, + {"role": "assistant", "content": "old answer"}, + ] + client.history_lookup["resp-prev"] = normalize_conversation(previous_turn) + + await responses.create(input=_messages(), model="gpt-test", previous_response_id="resp-prev") + + expected_history = client.history_lookup["resp-prev"] + normalize_conversation(_messages()) + assert client.preflight_calls[0]["history"] == expected_history # noqa: S101 + assert client.history_requests == ["resp-prev"] # noqa: S101 + assert client.create_calls[0]["previous_response_id"] == "resp-prev" # noqa: S101 From 77456ced49555c53f19c20e7742596af80698432 Mon Sep 17 00:00:00 2001 From: steven10a <158192461+steven10a@users.noreply.github.com> Date: Mon, 20 Oct 2025 14:04:33 -0400 Subject: [PATCH 2/3] Updating examples with conversation history (#25) * Updating examples with conversation history * fixing lint errors --- docs/quickstart.md | 32 +++++++++ examples/basic/azure_implementation.py | 28 ++++++-- .../basic/multiturn_chat_with_alignment.py | 70 +++++++++++-------- examples/basic/pii_mask_example.py | 27 ++++--- examples/basic/structured_outputs_example.py | 42 ++++++++--- .../run_hallucination_detection.py | 68 ++++++++++-------- .../blocking/blocking_completions.py | 25 +++++-- .../streaming/streaming_completions.py | 31 +++++--- .../custom_context.py | 17 ++++- src/guardrails/runtime.py | 8 +-- tests/unit/test_runtime.py | 17 +++++ 11 files changed, 261 insertions(+), 104 deletions(-) rename examples/{basic => internal_examples}/custom_context.py (73%) diff --git a/docs/quickstart.md b/docs/quickstart.md index e7551bf..fe91f01 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -81,6 +81,38 @@ asyncio.run(main()) **That's it!** Your existing OpenAI code now includes automatic guardrail validation based on your pipeline 
configuration. Just use `response.llm_response` instead of `response`. +## Multi-Turn Conversations + +When maintaining conversation history across multiple turns, **only append messages after guardrails pass**. This prevents blocked input messages from polluting your conversation context. + +```python +messages: list[dict] = [] + +while True: + user_input = input("You: ") + + try: + # ✅ Pass user input inline (don't mutate messages first) + response = await client.chat.completions.create( + messages=messages + [{"role": "user", "content": user_input}], + model="gpt-4o" + ) + + response_content = response.llm_response.choices[0].message.content + print(f"Assistant: {response_content}") + + # ✅ Only append AFTER guardrails pass + messages.append({"role": "user", "content": user_input}) + messages.append({"role": "assistant", "content": response_content}) + + except GuardrailTripwireTriggered: + # ❌ Guardrail blocked - message NOT added to history + print("Message blocked by guardrails") + continue +``` + +**Why this matters**: If you append the user message before the guardrail check, blocked messages remain in your conversation history and get sent on every subsequent turn, even though they violated your safety policies. + ## Guardrail Execution Error Handling Guardrails supports two error handling modes for guardrail execution failures: diff --git a/examples/basic/azure_implementation.py b/examples/basic/azure_implementation.py index 08c5a54..c475103 100644 --- a/examples/basic/azure_implementation.py +++ b/examples/basic/azure_implementation.py @@ -54,13 +54,24 @@ } -async def process_input(guardrails_client: GuardrailsAsyncAzureOpenAI, user_input: str) -> None: - """Process user input with complete response validation using GuardrailsClient.""" +async def process_input( + guardrails_client: GuardrailsAsyncAzureOpenAI, + user_input: str, + messages: list[dict], +) -> None: + """Process user input with complete response validation using GuardrailsClient. + + Args: + guardrails_client: GuardrailsAsyncAzureOpenAI instance. + user_input: User's input text. + messages: Conversation history (modified in place after guardrails pass). + """ try: - # Use GuardrailsClient to handle all guardrail checks and LLM calls + # Pass user input inline WITHOUT mutating messages first + # Only add to messages AFTER guardrails pass and LLM call succeeds response = await guardrails_client.chat.completions.create( model=AZURE_DEPLOYMENT, - messages=[{"role": "user", "content": user_input}], + messages=messages + [{"role": "user", "content": user_input}], ) # Extract the response content from the GuardrailsResponse @@ -69,11 +80,16 @@ async def process_input(guardrails_client: GuardrailsAsyncAzureOpenAI, user_inpu # Only show output if all guardrails pass print(f"\nAssistant: {response_text}") + # Guardrails passed - now safe to add to conversation history + messages.append({"role": "user", "content": user_input}) + messages.append({"role": "assistant", "content": response_text}) + except GuardrailTripwireTriggered as e: # Extract information from the triggered guardrail triggered_result = e.guardrail_result print(" Input blocked. 
Please try a different message.") print(f" Full result: {triggered_result}") + # Guardrail blocked - user message NOT added to history raise except BadRequestError as e: # Handle Azure's built-in content filter errors @@ -97,6 +113,8 @@ async def main(): api_version="2025-01-01-preview", ) + messages: list[dict] = [] + while True: try: prompt = input("\nEnter a message: ") @@ -105,7 +123,7 @@ async def main(): print("Goodbye!") break - await process_input(guardrails_client, prompt) + await process_input(guardrails_client, prompt, messages) except (EOFError, KeyboardInterrupt): break except (GuardrailTripwireTriggered, BadRequestError): diff --git a/examples/basic/multiturn_chat_with_alignment.py b/examples/basic/multiturn_chat_with_alignment.py index 48dc53e..84b6c03 100644 --- a/examples/basic/multiturn_chat_with_alignment.py +++ b/examples/basic/multiturn_chat_with_alignment.py @@ -226,14 +226,21 @@ async def main(malicious: bool = False) -> None: if not user_input: continue - messages.append({"role": "user", "content": user_input}) - + # Pass user input inline WITHOUT mutating messages first + # Only add to messages AFTER guardrails pass and LLM call succeeds try: - resp = await client.chat.completions.create(model="gpt-4.1-nano", messages=messages, tools=tools) + resp = await client.chat.completions.create( + model="gpt-4.1-nano", + messages=messages + [{"role": "user", "content": user_input}], + tools=tools, + ) print_guardrail_results("initial", resp) choice = resp.llm_response.choices[0] message = choice.message tool_calls = getattr(message, "tool_calls", []) or [] + + # Guardrails passed - now safe to add user message to conversation history + messages.append({"role": "user", "content": user_input}) except GuardrailTripwireTriggered as e: info = getattr(e, "guardrail_result", None) info = info.info if info else {} @@ -252,29 +259,29 @@ async def main(malicious: bool = False) -> None: border_style="red", ) ) + # Guardrail blocked - user message NOT added to history continue if tool_calls: - # Add assistant message with tool calls to conversation - messages.append( - { - "role": "assistant", - "content": message.content, - "tool_calls": [ - { - "id": call.id, - "type": "function", - "function": { - "name": call.function.name, - "arguments": call.function.arguments or "{}", - }, - } - for call in tool_calls - ], - } - ) - - # Execute tool calls + # Prepare assistant message with tool calls (don't append yet) + assistant_message = { + "role": "assistant", + "content": message.content, + "tool_calls": [ + { + "id": call.id, + "type": "function", + "function": { + "name": call.function.name, + "arguments": call.function.arguments or "{}", + }, + } + for call in tool_calls + ], + } + + # Execute tool calls and collect results (don't append yet) + tool_messages = [] for call in tool_calls: fname = call.function.name fargs = json.loads(call.function.arguments or "{}") @@ -293,7 +300,7 @@ async def main(malicious: bool = False) -> None: "ssn": "123-45-6789", "credit_card": "4111-1111-1111-1111", } - messages.append( + tool_messages.append( { "role": "tool", "tool_call_id": call.id, @@ -302,7 +309,7 @@ async def main(malicious: bool = False) -> None: } ) else: - messages.append( + tool_messages.append( { "role": "tool", "tool_call_id": call.id, @@ -311,9 +318,13 @@ async def main(malicious: bool = False) -> None: } ) - # Final call + # Final call with tool results (pass inline without mutating messages) try: - resp = await client.chat.completions.create(model="gpt-4.1-nano", 
messages=messages, tools=tools) + resp = await client.chat.completions.create( + model="gpt-4.1-nano", + messages=messages + [assistant_message] + tool_messages, + tools=tools, + ) print_guardrail_results("final", resp) final_message = resp.llm_response.choices[0].message @@ -325,7 +336,9 @@ async def main(malicious: bool = False) -> None: ) ) - # Add final assistant response to conversation + # Guardrails passed - now safe to add all messages to conversation history + messages.append(assistant_message) + messages.extend(tool_messages) messages.append({"role": "assistant", "content": final_message.content}) except GuardrailTripwireTriggered as e: info = getattr(e, "guardrail_result", None) @@ -345,6 +358,7 @@ async def main(malicious: bool = False) -> None: border_style="red", ) ) + # Guardrail blocked - tool results NOT added to history continue else: # No tool calls; just print assistant content and add to conversation diff --git a/examples/basic/pii_mask_example.py b/examples/basic/pii_mask_example.py index e3cd73a..58ca48d 100644 --- a/examples/basic/pii_mask_example.py +++ b/examples/basic/pii_mask_example.py @@ -69,23 +69,20 @@ async def process_input( guardrails_client: GuardrailsAsyncOpenAI, user_input: str, + messages: list[dict], ) -> None: """Process user input using GuardrailsClient with automatic PII masking. Args: guardrails_client: GuardrailsClient instance with PII masking configuration. user_input: User's input text. + messages: Conversation history (modified in place after guardrails pass). """ try: - # Use GuardrailsClient - it handles all PII masking automatically + # Pass user input inline WITHOUT mutating messages first + # Only add to messages AFTER guardrails pass and LLM call succeeds response = await guardrails_client.chat.completions.create( - messages=[ - { - "role": "system", - "content": "You are a helpful assistant. Comply with the user's request.", - }, - {"role": "user", "content": user_input}, - ], + messages=messages + [{"role": "user", "content": user_input}], model="gpt-4", ) @@ -125,11 +122,16 @@ async def process_input( ) ) + # Guardrails passed - now safe to add to conversation history + messages.append({"role": "user", "content": user_input}) + messages.append({"role": "assistant", "content": content}) + except GuardrailTripwireTriggered as exc: stage_name = exc.guardrail_result.info.get("stage_name", "unknown") guardrail_name = exc.guardrail_result.info.get("guardrail_name", "unknown") console.print(f"[bold red]Guardrail '{guardrail_name}' triggered in stage '{stage_name}'![/bold red]") console.print(Panel(str(exc.guardrail_result), title="Guardrail Result", border_style="red")) + # Guardrail blocked - user message NOT added to history raise @@ -138,6 +140,13 @@ async def main() -> None: # Initialize GuardrailsAsyncOpenAI with PII masking configuration guardrails_client = GuardrailsAsyncOpenAI(config=PIPELINE_CONFIG) + messages: list[dict] = [ + { + "role": "system", + "content": "You are a helpful assistant. 
Comply with the user's request.", + } + ] + with suppress(KeyboardInterrupt, asyncio.CancelledError): while True: try: @@ -145,7 +154,7 @@ async def main() -> None: if user_input.lower() == "exit": break - await process_input(guardrails_client, user_input) + await process_input(guardrails_client, user_input, messages) except EOFError: break diff --git a/examples/basic/structured_outputs_example.py b/examples/basic/structured_outputs_example.py index be88f6c..05e011f 100644 --- a/examples/basic/structured_outputs_example.py +++ b/examples/basic/structured_outputs_example.py @@ -23,39 +23,64 @@ class UserInfo(BaseModel): "version": 1, "guardrails": [ {"name": "Moderation", "config": {"categories": ["hate", "violence"]}}, + { + "name": "Custom Prompt Check", + "config": { + "model": "gpt-4.1-nano", + "confidence_threshold": 0.7, + "system_prompt_details": "Check if the text contains any math problems.", + }, + }, ], }, } -async def extract_user_info(guardrails_client: GuardrailsAsyncOpenAI, text: str) -> UserInfo: - """Extract user information using responses_parse with structured output.""" +async def extract_user_info( + guardrails_client: GuardrailsAsyncOpenAI, + text: str, + previous_response_id: str | None = None, +) -> tuple[UserInfo, str]: + """Extract user information using responses.parse with structured output.""" try: + # Use responses.parse() for structured outputs with guardrails + # Note: responses.parse() requires input as a list of message dicts response = await guardrails_client.responses.parse( - input=[{"role": "system", "content": "Extract user information from the provided text."}, {"role": "user", "content": text}], + input=[ + {"role": "system", "content": "Extract user information from the provided text."}, + {"role": "user", "content": text}, + ], model="gpt-4.1-nano", text_format=UserInfo, + previous_response_id=previous_response_id, ) # Access the parsed structured output user_info = response.llm_response.output_parsed print(f"✅ Successfully extracted: {user_info.name}, {user_info.age}, {user_info.email}") - return user_info + # Return user info and response ID (only returned if guardrails pass) + return user_info, response.llm_response.id - except GuardrailTripwireTriggered as exc: - print(f"❌ Guardrail triggered: {exc}") + except GuardrailTripwireTriggered: + # Guardrail blocked - no response ID returned, conversation history unchanged raise async def main() -> None: - """Interactive loop demonstrating structured outputs.""" + """Interactive loop demonstrating structured outputs with conversation history.""" # Initialize GuardrailsAsyncOpenAI guardrails_client = GuardrailsAsyncOpenAI(config=PIPELINE_CONFIG) + + # Use previous_response_id to maintain conversation history with responses API + response_id: str | None = None + while True: try: text = input("Enter text to extract user info. 
Include name, age, and email: ") - user_info = await extract_user_info(guardrails_client, text) + + # Extract user info - only updates response_id if guardrails pass + user_info, response_id = await extract_user_info(guardrails_client, text, response_id) # Demonstrate structured output clearly print("\n✅ Parsed structured output:") @@ -66,6 +91,7 @@ async def main() -> None: print("\nExiting.") break except GuardrailTripwireTriggered as exc: + # Guardrail blocked - response_id unchanged, so blocked message not in history print(f"🛑 Guardrail triggered: {exc}") continue except Exception as e: diff --git a/examples/hallucination_detection/run_hallucination_detection.py b/examples/hallucination_detection/run_hallucination_detection.py index 99765ec..f65ecb2 100644 --- a/examples/hallucination_detection/run_hallucination_detection.py +++ b/examples/hallucination_detection/run_hallucination_detection.py @@ -34,37 +34,47 @@ async def main(): # Initialize the guardrails client client = GuardrailsAsyncOpenAI(config=pipeline_config) - # Example hallucination - candidate = "Microsoft's annual revenue was $500 billion in 2023." - - # Example non-hallucination - # candidate = "Microsoft's annual revenue was $56.5 billion in 2023." - - try: - # Use the client to check the text with guardrails - response = await client.chat.completions.create( - messages=[{"role": "user", "content": candidate}], - model="gpt-4.1-mini", - ) - - console.print( - Panel( - f"[bold green]Tripwire not triggered[/bold green]\n\nResponse: {response.llm_response.choices[0].message.content}", - title="✅ Guardrail Check Passed", - border_style="green", + messages: list[dict[str, str]] = [] + + # Example inputs to test + test_cases = [ + "Microsoft's annual revenue was $500 billion in 2023.", # hallucination + "Microsoft's annual revenue was $56.5 billion in 2023.", # non-hallucination + ] + + for candidate in test_cases: + console.print(f"\n[bold cyan]Testing:[/bold cyan] {candidate}\n") + + try: + # Pass user input inline WITHOUT mutating messages first + response = await client.chat.completions.create( + messages=messages + [{"role": "user", "content": candidate}], + model="gpt-4.1-mini", ) - ) - - except GuardrailTripwireTriggered as exc: - # Make the guardrail triggered message stand out with Rich - console.print( - Panel( - f"[bold red]Guardrail triggered: {exc.guardrail_result.info.get('guardrail_name', 'unnamed')}[/bold red]", - title="⚠️ Guardrail Alert", - border_style="red", + + response_content = response.llm_response.choices[0].message.content + console.print( + Panel( + f"[bold green]Tripwire not triggered[/bold green]\n\nResponse: {response_content}", + title="✅ Guardrail Check Passed", + border_style="green", + ) + ) + + # Guardrails passed - now safe to add to conversation history + messages.append({"role": "user", "content": candidate}) + messages.append({"role": "assistant", "content": response_content}) + + except GuardrailTripwireTriggered as exc: + # Guardrail blocked - user message NOT added to history + console.print( + Panel( + f"[bold red]Guardrail triggered: {exc.guardrail_result.info.get('guardrail_name', 'unnamed')}[/bold red]", + title="⚠️ Guardrail Alert", + border_style="red", + ) ) - ) - print(f"Result details: {exc.guardrail_result.info}") + print(f"Result details: {exc.guardrail_result.info}") if __name__ == "__main__": diff --git a/examples/implementation_code/blocking/blocking_completions.py b/examples/implementation_code/blocking/blocking_completions.py index 82ea931..f05cf62 100644 --- 
a/examples/implementation_code/blocking/blocking_completions.py +++ b/examples/implementation_code/blocking/blocking_completions.py @@ -11,20 +11,29 @@ from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered -async def process_input(guardrails_client: GuardrailsAsyncOpenAI, user_input: str) -> None: +async def process_input( + guardrails_client: GuardrailsAsyncOpenAI, + user_input: str, + messages: list[dict[str, str]], +) -> None: """Process user input with complete response validation using the new GuardrailsClient.""" try: - # Use the GuardrailsClient - it handles all guardrail validation automatically - # including pre-flight, input, and output stages, plus the LLM call + # Pass user input inline WITHOUT mutating messages first + # Only add to messages AFTER guardrails pass and LLM call succeeds response = await guardrails_client.chat.completions.create( - messages=[{"role": "user", "content": user_input}], + messages=messages + [{"role": "user", "content": user_input}], model="gpt-4.1-nano", ) - print(f"\nAssistant: {response.llm_response.choices[0].message.content}") + response_content = response.llm_response.choices[0].message.content + print(f"\nAssistant: {response_content}") + + # Guardrails passed - now safe to add to conversation history + messages.append({"role": "user", "content": user_input}) + messages.append({"role": "assistant", "content": response_content}) except GuardrailTripwireTriggered: - # GuardrailsClient automatically handles tripwire exceptions + # Guardrail blocked - user message NOT added to history raise @@ -32,10 +41,12 @@ async def main(): # Initialize GuardrailsAsyncOpenAI with the config file guardrails_client = GuardrailsAsyncOpenAI(config=Path("guardrails_config.json")) + messages: list[dict[str, str]] = [] + while True: try: prompt = input("\nEnter a message: ") - await process_input(guardrails_client, prompt) + await process_input(guardrails_client, prompt, messages) except (EOFError, KeyboardInterrupt): break except GuardrailTripwireTriggered as e: diff --git a/examples/implementation_code/streaming/streaming_completions.py b/examples/implementation_code/streaming/streaming_completions.py index 5365cec..6aca50c 100644 --- a/examples/implementation_code/streaming/streaming_completions.py +++ b/examples/implementation_code/streaming/streaming_completions.py @@ -12,24 +12,37 @@ from guardrails import GuardrailsAsyncOpenAI, GuardrailTripwireTriggered -async def process_input(guardrails_client: GuardrailsAsyncOpenAI, user_input: str) -> str: +async def process_input( + guardrails_client: GuardrailsAsyncOpenAI, + user_input: str, + messages: list[dict[str, str]], +) -> str: """Process user input with streaming output and guardrails using the GuardrailsClient.""" try: - # Use the GuardrailsClient - it handles all guardrail validation automatically - # including pre-flight, input, and output stages, plus the LLM call + # Pass user input inline WITHOUT mutating messages first + # Only add to messages AFTER guardrails pass and streaming completes stream = await guardrails_client.chat.completions.create( - messages=[{"role": "user", "content": user_input}], + messages=messages + [{"role": "user", "content": user_input}], model="gpt-4.1-nano", stream=True, ) - # Stream with output guardrail checks + # Stream with output guardrail checks and accumulate response + response_content = "" async for chunk in stream: if chunk.llm_response.choices[0].delta.content: - print(chunk.llm_response.choices[0].delta.content, end="", flush=True) - return "Stream 
completed successfully" + delta = chunk.llm_response.choices[0].delta.content + print(delta, end="", flush=True) + response_content += delta + + print() # New line after streaming + + # Guardrails passed - now safe to add to conversation history + messages.append({"role": "user", "content": user_input}) + messages.append({"role": "assistant", "content": response_content}) except GuardrailTripwireTriggered: + # Guardrail blocked - user message NOT added to history raise @@ -37,10 +50,12 @@ async def main(): # Initialize GuardrailsAsyncOpenAI with the config file guardrails_client = GuardrailsAsyncOpenAI(config=Path("guardrails_config.json")) + messages: list[dict[str, str]] = [] + while True: try: prompt = input("\nEnter a message: ") - await process_input(guardrails_client, prompt) + await process_input(guardrails_client, prompt, messages) except (EOFError, KeyboardInterrupt): break except GuardrailTripwireTriggered as exc: diff --git a/examples/basic/custom_context.py b/examples/internal_examples/custom_context.py similarity index 73% rename from examples/basic/custom_context.py rename to examples/internal_examples/custom_context.py index 331189a..511d327 100644 --- a/examples/basic/custom_context.py +++ b/examples/internal_examples/custom_context.py @@ -47,16 +47,27 @@ async def main() -> None: # the default OpenAI for main LLM calls client = GuardrailsAsyncOpenAI(config=PIPELINE_CONFIG) + messages: list[dict[str, str]] = [] + with suppress(KeyboardInterrupt, asyncio.CancelledError): while True: try: user_input = input("Enter a message: ") - response = await client.chat.completions.create(model="gpt-4.1-nano", messages=[{"role": "user", "content": user_input}]) - print("Assistant:", response.llm_response.choices[0].message.content) + # Pass user input inline WITHOUT mutating messages first + response = await client.chat.completions.create( + model="gpt-4.1-nano", + messages=messages + [{"role": "user", "content": user_input}], + ) + response_content = response.llm_response.choices[0].message.content + print("Assistant:", response_content) + + # Guardrails passed - now safe to add to conversation history + messages.append({"role": "user", "content": user_input}) + messages.append({"role": "assistant", "content": response_content}) except EOFError: break except GuardrailTripwireTriggered as exc: - # Minimal handling; guardrail details available on exc.guardrail_result + # Guardrail blocked - user message NOT added to history print("🛑 Guardrail triggered.", str(exc)) continue diff --git a/src/guardrails/runtime.py b/src/guardrails/runtime.py index 8de1fda..68948a5 100644 --- a/src/guardrails/runtime.py +++ b/src/guardrails/runtime.py @@ -113,11 +113,6 @@ class ConfigBundle(BaseModel): Attributes: guardrails (list[GuardrailConfig]): The configured guardrails. version (int): Format version for forward/backward compatibility. - stage_name (str): User-defined name for the pipeline stage this bundle is for. - This can be any string that helps identify which part of your pipeline - triggered the guardrail (e.g., "user_input_validation", "content_generation", - "pre_processing", etc.). It will be included in GuardrailResult info for - easy identification. config (dict[str, Any]): Execution configuration for this bundle. 
Optional fields include: - concurrency (int): Maximum number of guardrails to run in parallel (default: 10) @@ -126,7 +121,6 @@ class ConfigBundle(BaseModel): guardrails: list[GuardrailConfig] version: int = 1 - stage_name: str = "unnamed" config: dict[str, Any] = {} model_config = ConfigDict(frozen=True, extra="forbid") @@ -563,4 +557,4 @@ async def check_plain_text( ctx = _get_default_ctx() bundle = load_config_bundle(bundle_path) guardrails: list[ConfiguredGuardrail[Any, str, Any]] = instantiate_guardrails(bundle, registry=registry) - return await run_guardrails(ctx, text, "text/plain", guardrails, stage_name=bundle.stage_name, **kwargs) + return await run_guardrails(ctx, text, "text/plain", guardrails, **kwargs) diff --git a/tests/unit/test_runtime.py b/tests/unit/test_runtime.py index f4c0241..cd91d6b 100644 --- a/tests/unit/test_runtime.py +++ b/tests/unit/test_runtime.py @@ -177,6 +177,23 @@ def test_load_pipeline_bundles_errors_on_invalid_dict() -> None: load_pipeline_bundles({"version": 1, "invalid": "field"}) +def test_config_bundle_rejects_stage_name_override() -> None: + """ConfigBundle forbids overriding stage names.""" + with pytest.raises(ValidationError): + ConfigBundle(guardrails=[], version=1, stage_name="custom") # type: ignore[call-arg] + + +def test_pipeline_bundles_reject_stage_name_override() -> None: + """Pipeline bundle stages disallow custom stage_name field.""" + with pytest.raises(ValidationError): + load_pipeline_bundles( + { + "version": 1, + "pre_flight": {"version": 1, "guardrails": [], "stage_name": "custom"}, + } + ) + + @given(st.text()) def test_load_pipeline_bundles_plain_string_invalid(text: str) -> None: """Plain strings are rejected.""" From c1d868b615a1bf4e72b542d87a5f515f19c32b62 Mon Sep 17 00:00:00 2001 From: steven10a <158192461+steven10a@users.noreply.github.com> Date: Tue, 21 Oct 2025 18:26:50 -0400 Subject: [PATCH 3/3] version bump to v0.1.2 (#26) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4a0154b..d73010b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai-guardrails" -version = "0.1.1" +version = "0.1.2" description = "OpenAI Guardrails: A framework for building safe and reliable AI systems." readme = "README.md" requires-python = ">=3.11"