Skip to content

Commit eb447fb

Browse files
committed
Fix guardrail task cleanup to properly await cancelled tasks
Problem: The _cleanup_guardrail_tasks() method in RealtimeSession was only calling task.cancel() on pending guardrail tasks but not awaiting them. This could lead to: 1. Unhandled task exception warnings 2. Potential memory leaks from abandoned tasks 3. Improper resource cleanup Evidence: - Test code in tests/realtime/test_session.py:1199 shows the correct pattern: await asyncio.gather(*session._guardrail_tasks, return_exceptions=True) - Similar pattern used in openai_realtime.py:519-523 for WebSocket task cleanup Solution: 1. Made _cleanup_guardrail_tasks() async 2. Added await asyncio.gather() for real asyncio.Task objects to properly collect exceptions (with isinstance check to support mock objects in tests) 3. Updated _cleanup() to await the cleanup method Testing: - Created comprehensive test suite in tests/realtime/test_guardrail_cleanup.py with 3 test cases: 1. Verify cancelled tasks are properly awaited 2. Verify exceptions during cleanup are handled 3. Verify multiple concurrent tasks are cleaned up - All new tests pass - All existing tests pass (838 passed, 3 skipped) - Note: test_issue_889_guardrail_tool_execution has 1 pre-existing failure unrelated to this PR (also fails on main)
1 parent 16169e1 commit eb447fb

File tree

2 files changed

+264
-2
lines changed

2 files changed

+264
-2
lines changed

‎src/agents/realtime/session.py‎

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -746,16 +746,32 @@ def _on_guardrail_task_done(self, task: asyncio.Task[Any]) -> None:
746746
)
747747
)
748748

749-
def_cleanup_guardrail_tasks(self) ->None:
749+
asyncdef_cleanup_guardrail_tasks(self) ->None:
750+
"""Cancel all pending guardrail tasks and wait for them to complete.
751+
752+
This ensures that any exceptions raised by the tasks are properly handled
753+
and prevents warnings about unhandled task exceptions.
754+
"""
755+
# Collect real asyncio.Task objects that need to be awaited
756+
real_tasks= []
757+
750758
fortaskinself._guardrail_tasks:
751759
ifnottask.done():
752760
task.cancel()
761+
# Only await real asyncio.Task objects (not mocks in tests)
762+
ifisinstance(task, asyncio.Task):
763+
real_tasks.append(task)
764+
765+
# Wait for all real tasks to complete and collect any exceptions
766+
ifreal_tasks:
767+
awaitasyncio.gather(*real_tasks, return_exceptions=True)
768+
753769
self._guardrail_tasks.clear()
754770

755771
asyncdef_cleanup(self) ->None:
756772
"""Clean up all resources and mark session as closed."""
757773
# Cancel and cleanup guardrail tasks
758-
self._cleanup_guardrail_tasks()
774+
awaitself._cleanup_guardrail_tasks()
759775

760776
# Remove ourselves as a listener
761777
self._model.remove_listener(self)
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
"""Test guardrail task cleanup to ensure proper exception handling.
2+
3+
This test verifies the fix for the bug where _cleanup_guardrail_tasks() was not
4+
properly awaiting cancelled tasks, which could lead to unhandled task exceptions
5+
and potential memory leaks.
6+
"""
7+
8+
importasyncio
9+
fromunittest.mockimportAsyncMock, Mock, PropertyMock
10+
11+
importpytest
12+
13+
fromagents.guardrailimportGuardrailFunctionOutput, OutputGuardrail
14+
fromagents.realtimeimportRealtimeSession
15+
fromagents.realtime.agentimportRealtimeAgent
16+
fromagents.realtime.configimportRealtimeRunConfig
17+
fromagents.realtime.modelimportRealtimeModel
18+
fromagents.realtime.model_eventsimportRealtimeModelTranscriptDeltaEvent
19+
20+
21+
classMockRealtimeModel(RealtimeModel):
22+
"""Mock realtime model for testing."""
23+
24+
def__init__(self):
25+
super().__init__()
26+
self.listeners= []
27+
self.connect_called=False
28+
self.close_called=False
29+
self.sent_events= []
30+
self.sent_messages= []
31+
self.sent_audio= []
32+
self.sent_tool_outputs= []
33+
self.interrupts_called=0
34+
35+
asyncdefconnect(self, options=None):
36+
self.connect_called=True
37+
38+
defadd_listener(self, listener):
39+
self.listeners.append(listener)
40+
41+
defremove_listener(self, listener):
42+
iflistenerinself.listeners:
43+
self.listeners.remove(listener)
44+
45+
asyncdefsend_event(self, event):
46+
fromagents.realtime.model_inputsimport (
47+
RealtimeModelSendAudio,
48+
RealtimeModelSendInterrupt,
49+
RealtimeModelSendToolOutput,
50+
RealtimeModelSendUserInput,
51+
)
52+
53+
self.sent_events.append(event)
54+
55+
# Update legacy tracking for compatibility
56+
ifisinstance(event, RealtimeModelSendUserInput):
57+
self.sent_messages.append(event.user_input)
58+
elifisinstance(event, RealtimeModelSendAudio):
59+
self.sent_audio.append((event.audio, event.commit))
60+
elifisinstance(event, RealtimeModelSendToolOutput):
61+
self.sent_tool_outputs.append((event.tool_call, event.output, event.start_response))
62+
elifisinstance(event, RealtimeModelSendInterrupt):
63+
self.interrupts_called+=1
64+
65+
asyncdefclose(self):
66+
self.close_called=True
67+
68+
69+
@pytest.fixture
70+
defmock_model():
71+
returnMockRealtimeModel()
72+
73+
74+
@pytest.fixture
75+
defmock_agent():
76+
agent=Mock(spec=RealtimeAgent)
77+
agent.name="test_agent"
78+
agent.get_all_tools=AsyncMock(return_value=[])
79+
type(agent).handoffs=PropertyMock(return_value=[])
80+
type(agent).output_guardrails=PropertyMock(return_value=[])
81+
returnagent
82+
83+
84+
@pytest.mark.asyncio
85+
asyncdeftest_guardrail_task_cleanup_awaits_cancelled_tasks(mock_model, mock_agent):
86+
"""Test that cleanup properly awaits cancelled guardrail tasks.
87+
88+
This test verifies that when guardrail tasks are cancelled during cleanup,
89+
the cleanup method properly awaits them to completion using asyncio.gather()
90+
with return_exceptions=True. This ensures:
91+
1. No warnings about unhandled task exceptions
92+
2. Proper resource cleanup
93+
3. No memory leaks from abandoned tasks
94+
"""
95+
96+
# Create a guardrail that runs a long async operation
97+
task_started=asyncio.Event()
98+
task_cancelled=asyncio.Event()
99+
100+
asyncdefslow_guardrail_func(context, agent, output):
101+
"""A guardrail that takes time to execute."""
102+
task_started.set()
103+
try:
104+
# Simulate a long-running operation
105+
awaitasyncio.sleep(10)
106+
returnGuardrailFunctionOutput(output_info={}, tripwire_triggered=False)
107+
exceptasyncio.CancelledError:
108+
task_cancelled.set()
109+
raise
110+
111+
guardrail=OutputGuardrail(guardrail_function=slow_guardrail_func, name="slow_guardrail")
112+
113+
run_config: RealtimeRunConfig={
114+
"output_guardrails": [guardrail],
115+
"guardrails_settings":{"debounce_text_length": 5},
116+
}
117+
118+
session=RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
119+
120+
# Trigger a guardrail by sending a transcript delta
121+
transcript_event=RealtimeModelTranscriptDeltaEvent(
122+
item_id="item_1", delta="hello world", response_id="resp_1"
123+
)
124+
125+
awaitsession.on_event(transcript_event)
126+
127+
# Wait for the guardrail task to start
128+
awaitasyncio.wait_for(task_started.wait(), timeout=1.0)
129+
130+
# Verify a guardrail task was created
131+
assertlen(session._guardrail_tasks) ==1
132+
task=list(session._guardrail_tasks)[0]
133+
assertnottask.done()
134+
135+
# Now cleanup the session - this should cancel and await the task
136+
awaitsession._cleanup_guardrail_tasks()
137+
138+
# Verify the task was cancelled and properly awaited
139+
asserttask_cancelled.is_set(), "Task should have received CancelledError"
140+
assertlen(session._guardrail_tasks) ==0, "Tasks list should be cleared"
141+
142+
# No warnings should be raised about unhandled task exceptions
143+
144+
145+
@pytest.mark.asyncio
146+
asyncdeftest_guardrail_task_cleanup_with_exception(mock_model, mock_agent):
147+
"""Test that cleanup handles guardrail tasks that raise exceptions.
148+
149+
This test verifies that if a guardrail task raises an exception (not just
150+
CancelledError), the cleanup method still completes successfully and doesn't
151+
propagate the exception, thanks to return_exceptions=True.
152+
"""
153+
154+
task_started=asyncio.Event()
155+
exception_raised=asyncio.Event()
156+
157+
asyncdeffailing_guardrail_func(context, agent, output):
158+
"""A guardrail that raises an exception."""
159+
task_started.set()
160+
try:
161+
awaitasyncio.sleep(10)
162+
returnGuardrailFunctionOutput(output_info={}, tripwire_triggered=False)
163+
exceptasyncio.CancelledErrorase:
164+
exception_raised.set()
165+
# Simulate an error during cleanup
166+
raiseRuntimeError("Cleanup error") frome
167+
168+
guardrail=OutputGuardrail(
169+
guardrail_function=failing_guardrail_func, name="failing_guardrail"
170+
)
171+
172+
run_config: RealtimeRunConfig={
173+
"output_guardrails": [guardrail],
174+
"guardrails_settings":{"debounce_text_length": 5},
175+
}
176+
177+
session=RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
178+
179+
# Trigger a guardrail
180+
transcript_event=RealtimeModelTranscriptDeltaEvent(
181+
item_id="item_1", delta="hello world", response_id="resp_1"
182+
)
183+
184+
awaitsession.on_event(transcript_event)
185+
186+
# Wait for the guardrail task to start
187+
awaitasyncio.wait_for(task_started.wait(), timeout=1.0)
188+
189+
# Cleanup should not raise the RuntimeError due to return_exceptions=True
190+
awaitsession._cleanup_guardrail_tasks()
191+
192+
# Verify cleanup completed successfully
193+
assertexception_raised.is_set()
194+
assertlen(session._guardrail_tasks) ==0
195+
196+
197+
@pytest.mark.asyncio
198+
asyncdeftest_guardrail_task_cleanup_with_multiple_tasks(mock_model, mock_agent):
199+
"""Test cleanup with multiple pending guardrail tasks.
200+
201+
This test verifies that cleanup properly handles multiple concurrent guardrail
202+
tasks by triggering guardrails multiple times, then cancelling and awaiting all of them.
203+
"""
204+
205+
tasks_started=asyncio.Event()
206+
tasks_cancelled=0
207+
208+
asyncdefslow_guardrail_func(context, agent, output):
209+
nonlocaltasks_cancelled
210+
tasks_started.set()
211+
try:
212+
awaitasyncio.sleep(10)
213+
returnGuardrailFunctionOutput(output_info={}, tripwire_triggered=False)
214+
exceptasyncio.CancelledError:
215+
tasks_cancelled+=1
216+
raise
217+
218+
guardrail=OutputGuardrail(guardrail_function=slow_guardrail_func, name="slow_guardrail")
219+
220+
run_config: RealtimeRunConfig={
221+
"output_guardrails": [guardrail],
222+
"guardrails_settings":{"debounce_text_length": 5},
223+
}
224+
225+
session=RealtimeSession(mock_model, mock_agent, None, run_config=run_config)
226+
227+
# Trigger guardrails multiple times to create multiple tasks
228+
foriinrange(3):
229+
transcript_event=RealtimeModelTranscriptDeltaEvent(
230+
item_id=f"item_{i}", delta="hello world", response_id=f"resp_{i}"
231+
)
232+
awaitsession.on_event(transcript_event)
233+
234+
# Wait for at least one task to start
235+
awaitasyncio.wait_for(tasks_started.wait(), timeout=1.0)
236+
237+
# Should have at least one guardrail task
238+
initial_task_count=len(session._guardrail_tasks)
239+
assertinitial_task_count>=1, "At least one guardrail task should exist"
240+
241+
# Cleanup should cancel and await all tasks
242+
awaitsession._cleanup_guardrail_tasks()
243+
244+
# Verify all tasks were cancelled and cleared
245+
asserttasks_cancelled>=1, "At least one task should have been cancelled"
246+
assertlen(session._guardrail_tasks) ==0

0 commit comments

Comments
(0)