Skip to content

Commit 9f44bed

Browse files
committed
Switch to Python 3.12-style wait_for
`wait_for` has been a mess with respect to cancellations consistently in `asyncio`. Hopefully the approach taken in Python 3.12 solves the issues, so adopt that instead of trying to "fix" `wait_for` with wrappers on older Pythons. Use `async_timeout` as a polyfill on pre-3.11 Python. Closes: #1056Closes: #1052Fixes: #955
1 parent 4ddb039 commit 9f44bed

File tree

4 files changed

+104
-18
lines changed

4 files changed

+104
-18
lines changed

‎asyncpg/_asyncio_compat.py‎

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Backports from Python/Lib/asyncio for older Pythons
2+
#
3+
# Copyright (c) 2001-2023 Python Software Foundation; All Rights Reserved
4+
#
5+
# SPDX-License-Identifier: PSF-2.0
6+
7+
8+
importasyncio
9+
importfunctools
10+
importsys
11+
12+
from . importevents
13+
from . importexceptions
14+
15+
16+
ifsys.version_info< (3, 11):
17+
fromasync_timeoutimporttimeoutastimeout_ctx
18+
else:
19+
fromasyncioimporttimeoutastimeout_ctx
20+
21+
fromasync_timeoutimporttimeoutastimeout_ctx_2
22+
23+
24+
asyncdefwait_for(fut, timeout):
25+
"""Wait for the single Future or coroutine to complete, with timeout.
26+
27+
Coroutine will be wrapped in Task.
28+
29+
Returns result of the Future or coroutine. When a timeout occurs,
30+
it cancels the task and raises TimeoutError. To avoid the task
31+
cancellation, wrap it in shield().
32+
33+
If the wait is cancelled, the task is also cancelled.
34+
35+
If the task supresses the cancellation and returns a value instead,
36+
that value is returned.
37+
38+
This function is a coroutine.
39+
"""
40+
# The special case for timeout <= 0 is for the following case:
41+
#
42+
# async def test_waitfor():
43+
# func_started = False
44+
#
45+
# async def func():
46+
# nonlocal func_started
47+
# func_started = True
48+
#
49+
# try:
50+
# await asyncio.wait_for(func(), 0)
51+
# except asyncio.TimeoutError:
52+
# assert not func_started
53+
# else:
54+
# assert False
55+
#
56+
# asyncio.run(test_waitfor())
57+
58+
iftimeoutisnotNoneandtimeout<=0:
59+
fut=asyncio.ensure_future(fut)
60+
61+
iffut.done():
62+
returnfut.result()
63+
64+
await_cancel_and_wait(fut)
65+
try:
66+
returnfut.result()
67+
exceptexceptions.CancelledErrorasexc:
68+
raiseTimeoutErrorfromexc
69+
70+
asyncwithtimeout_ctx(timeout):
71+
returnawaitfut
72+
73+
74+
asyncdef_cancel_and_wait(fut):
75+
"""Cancel the *fut* future or task and wait until it completes."""
76+
77+
loop=events.get_running_loop()
78+
waiter=loop.create_future()
79+
cb=functools.partial(_release_waiter, waiter)
80+
fut.add_done_callback(cb)
81+
82+
try:
83+
fut.cancel()
84+
# We cannot wait on *fut* directly to make
85+
# sure _cancel_and_wait itself is reliably cancellable.
86+
awaitwaiter
87+
finally:
88+
fut.remove_done_callback(cb)
89+
90+
91+
def_release_waiter(waiter, *args):
92+
ifnotwaiter.done():
93+
waiter.set_result(None)

‎asyncpg/compat.py‎

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
66

77

8-
importasyncio
98
importpathlib
109
importplatform
1110
importtyping
11+
importsys
1212

1313

1414
SYSTEM=platform.uname().system
@@ -49,17 +49,7 @@ async def wait_closed(stream):
4949
pass
5050

5151

52-
# Workaround for https://bugs.python.org/issue37658
53-
asyncdefwait_for(fut, timeout):
54-
iftimeoutisNone:
55-
returnawaitfut
56-
57-
fut=asyncio.ensure_future(fut)
58-
59-
try:
60-
returnawaitasyncio.wait_for(fut, timeout)
61-
exceptasyncio.CancelledError:
62-
iffut.done():
63-
returnfut.result()
64-
else:
65-
raise
52+
ifsys.version_info< (3, 12):
53+
from ._asyncio_compatimportwait_foraswait_for# noqa: F401
54+
else:
55+
fromasyncioimportwait_foraswait_for# noqa: F401

‎asyncpg/protocol/protocol.pyx‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ cdef class BaseProtocol(CoreProtocol):
249249

250250
while more:
251251
with timer:
252-
await asyncio.wait_for(
252+
await compat.wait_for(
253253
self.writing_allowed.wait(),
254254
timeout=timer.get_remaining_budget())
255255
# On Windows the above event somehow won't allow context
@@ -383,7 +383,7 @@ cdef class BaseProtocol(CoreProtocol):
383383
ifbuffer:
384384
try:
385385
with timer:
386-
await asyncio.wait_for(
386+
await compat.wait_for(
387387
sink(buffer),
388388
timeout=timer.get_remaining_budget())
389389
except (Exception, asyncio.CancelledError) as ex:
@@ -511,7 +511,7 @@ cdef class BaseProtocol(CoreProtocol):
511511
with timer:
512512
await self.writing_allowed.wait()
513513
with timer:
514-
chunk = await asyncio.wait_for(
514+
chunk = await compat.wait_for(
515515
iterator.__anext__(),
516516
timeout=timer.get_remaining_budget())
517517
self._write_copy_data_msg(chunk)

‎pyproject.toml‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ classifiers = [
2727
"Programming Language :: Python :: Implementation :: CPython",
2828
"Topic :: Database :: Front-Ends",
2929
]
30+
dependencies = [
31+
'async_timeout>=4.0.3; python_version < "3.12.0"'
32+
]
3033

3134
[project.urls]
3235
github = "https://github.com/MagicStack/asyncpg"

0 commit comments

Comments
(0)