Commit 515bd41

Handle sync guardrail calls to avoid awaitable error (#21)
1 parent a251298 commit 515bd41

5 files changed: +110 -15 lines

src/guardrails/checks/text/hallucination_detection.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -52,10 +52,7 @@
 from guardrails.spec import GuardrailSpecMetadata
 from guardrails.types import GuardrailLLMContextProto, GuardrailResult

-from .llm_base import (
-    LLMConfig,
-    LLMOutput,
-)
+from .llm_base import LLMConfig, LLMOutput, _invoke_openai_callable

 logger = logging.getLogger(__name__)

@@ -210,9 +207,10 @@ async def hallucination_detection(
     validation_query = f"{VALIDATION_PROMPT}\n\nText to validate:\n{candidate}"

     # Use the Responses API with file search and structured output
-    response = await ctx.guardrail_llm.responses.parse(
-        model=config.model,
+    response = await _invoke_openai_callable(
+        ctx.guardrail_llm.responses.parse,
         input=validation_query,
+        model=config.model,
         text_format=HallucinationDetectionOutput,
         tools=[{"type": "file_search", "vector_store_ids": [config.knowledge_source]}],
     )
```
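The failure this commit addresses: when `ctx.guardrail_llm` is a synchronous client, `responses.parse` returns a plain response object rather than a coroutine, and `await`-ing it raises a `TypeError`. A minimal reproduction of that awaitable error, independent of the SDK (the `parse` stub below is hypothetical):

```python
import asyncio
from types import SimpleNamespace


def parse(**kwargs: object) -> SimpleNamespace:
    """Stand-in for a sync client's responses.parse: returns a plain object."""
    return SimpleNamespace(output_parsed="ok")


async def main() -> None:
    try:
        await parse(input="text")  # what the old code effectively did
    except TypeError as exc:
        # TypeError: object SimpleNamespace can't be used in 'await' expression
        print(exc)


asyncio.run(main())
```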

src/guardrails/checks/text/llm_base.py

Lines changed: 53 additions & 6 deletions
```diff
@@ -31,12 +31,16 @@ class MyLLMOutput(LLMOutput):

 from __future__ import annotations

+import asyncio
+import functools
+import inspect
 import json
 import logging
 import textwrap
-from typing import TYPE_CHECKING, TypeVar
+from collections.abc import Callable
+from typing import TYPE_CHECKING, Any, TypeVar

-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, OpenAI
 from pydantic import BaseModel, ConfigDict, Field

 from guardrails.registry import default_spec_registry
@@ -45,7 +49,13 @@ class MyLLMOutput(LLMOutput):
 from guardrails.utils.output import OutputSchema

 if TYPE_CHECKING:
-    from openai import AsyncOpenAI
+    from openai import AsyncAzureOpenAI, AzureOpenAI  # type: ignore[unused-import]
+else:
+    try:
+        from openai import AsyncAzureOpenAI, AzureOpenAI  # type: ignore
+    except Exception:  # pragma: no cover - optional dependency
+        AsyncAzureOpenAI = object  # type: ignore[assignment]
+        AzureOpenAI = object  # type: ignore[assignment]

 logger = logging.getLogger(__name__)

@@ -165,10 +175,46 @@ def _strip_json_code_fence(text: str) -> str:
     return candidate


+async def _invoke_openai_callable(
+    method: Callable[..., Any],
+    /,
+    *args: Any,
+    **kwargs: Any,
+) -> Any:
+    """Invoke OpenAI SDK methods that may be sync or async."""
+    if inspect.iscoroutinefunction(method):
+        return await method(*args, **kwargs)
+
+    loop = asyncio.get_running_loop()
+    result = await loop.run_in_executor(
+        None,
+        functools.partial(method, *args, **kwargs),
+    )
+    if inspect.isawaitable(result):
+        return await result
+    return result
+
+
+async def _request_chat_completion(
+    client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI,
+    *,
+    messages: list[dict[str, str]],
+    model: str,
+    response_format: dict[str, Any],
+) -> Any:
+    """Invoke chat.completions.create on sync or async OpenAI clients."""
+    return await _invoke_openai_callable(
+        client.chat.completions.create,
+        messages=messages,
+        model=model,
+        response_format=response_format,
+    )
+
+
 async def run_llm(
     text: str,
     system_prompt: str,
-    client: AsyncOpenAI,
+    client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI,
     model: str,
     output_model: type[LLMOutput],
 ) -> LLMOutput:
@@ -180,7 +226,7 @@ async def run_llm(
     Args:
         text (str): Text to analyze.
         system_prompt (str): Prompt instructions for the LLM.
-        client (AsyncOpenAI): OpenAI client for LLM inference.
+        client (AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI): OpenAI client used for guardrails.
         model (str): Identifier for which LLM model to use.
         output_model (type[LLMOutput]): Model for parsing and validating the LLM's response.

@@ -190,7 +236,8 @@ async def run_llm(
     full_prompt = _build_full_prompt(system_prompt)

     try:
-        response = await client.chat.completions.create(
+        response = await _request_chat_completion(
+            client=client,
             messages=[
                 {"role": "system", "content": full_prompt},
                 {"role": "user", "content": f"# Text\n\n{text}"},
```

src/guardrails/checks/text/prompt_injection_detection.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -36,7 +36,7 @@
 from guardrails.spec import GuardrailSpecMetadata
 from guardrails.types import GuardrailLLMContextProto, GuardrailResult

-from .llm_base import LLMConfig, LLMOutput
+from .llm_base import LLMConfig, LLMOutput, _invoke_openai_callable

 __all__ = ["prompt_injection_detection", "PromptInjectionDetectionOutput"]

@@ -373,9 +373,10 @@ def _create_skip_result(

 async def _call_prompt_injection_detection_llm(ctx: GuardrailLLMContextProto, prompt: str, config: LLMConfig) -> PromptInjectionDetectionOutput:
     """Call LLM for prompt injection detection analysis."""
-    parsed_response = await ctx.guardrail_llm.responses.parse(
-        model=config.model,
+    parsed_response = await _invoke_openai_callable(
+        ctx.guardrail_llm.responses.parse,
         input=prompt,
+        model=config.model,
         text_format=PromptInjectionDetectionOutput,
     )
     return parsed_response.output_parsed
```

tests/unit/checks/test_llm_base.py

Lines changed: 31 additions & 0 deletions
```diff
@@ -34,6 +34,20 @@ def __init__(self, content: str | None) -> None:
         self.chat = SimpleNamespace(completions=_FakeCompletions(content))


+class _FakeSyncCompletions:
+    def __init__(self, content: str | None) -> None:
+        self._content = content
+
+    def create(self, **kwargs: Any) -> Any:
+        _ = kwargs
+        return SimpleNamespace(choices=[SimpleNamespace(message=SimpleNamespace(content=self._content))])
+
+
+class _FakeSyncClient:
+    def __init__(self, content: str | None) -> None:
+        self.chat = SimpleNamespace(completions=_FakeSyncCompletions(content))
+
+
 def test_strip_json_code_fence_removes_wrapping() -> None:
     """Valid JSON code fences should be removed."""
     fenced = """```json
@@ -64,6 +78,23 @@ async def test_run_llm_returns_valid_output() -> None:
     assert result.flagged is True and result.confidence == 0.9  # noqa: S101


+@pytest.mark.asyncio
+async def test_run_llm_supports_sync_clients() -> None:
+    """run_llm should invoke synchronous clients without awaiting them."""
+    client = _FakeSyncClient('{"flagged": false, "confidence": 0.25}')
+
+    result = await run_llm(
+        text="General text",
+        system_prompt="Assess text.",
+        client=client,  # type: ignore[arg-type]
+        model="gpt-test",
+        output_model=LLMOutput,
+    )
+
+    assert isinstance(result, LLMOutput)  # noqa: S101
+    assert result.flagged is False and result.confidence == 0.25  # noqa: S101
+
+
 @pytest.mark.asyncio
 async def test_run_llm_handles_content_filter_error(monkeypatch: pytest.MonkeyPatch) -> None:
     """Content filter errors should return LLMErrorOutput with flagged=True."""
```

tests/unit/checks/test_prompt_injection_detection.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -147,3 +147,21 @@ async def failing_llm(*_args: Any, **_kwargs: Any) -> PromptInjectionDetectionOutput:

     assert result.tripwire_triggered is False  # noqa: S101
     assert "Error during prompt injection detection check" in result.info["observation"]  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_prompt_injection_detection_llm_supports_sync_responses() -> None:
+    """Underlying responses.parse may be synchronous for some clients."""
+    analysis = PromptInjectionDetectionOutput(flagged=True, confidence=0.4, observation="Action summary")
+
+    class _SyncResponses:
+        def parse(self, **kwargs: Any) -> Any:
+            _ = kwargs
+            return SimpleNamespace(output_parsed=analysis)
+
+    context = SimpleNamespace(guardrail_llm=SimpleNamespace(responses=_SyncResponses()))
+    config = LLMConfig(model="gpt-test", confidence_threshold=0.5)
+
+    parsed = await pid_module._call_prompt_injection_detection_llm(context, "prompt", config)
+
+    assert parsed is analysis  # noqa: S101
```
