Skip to content

Commit 313b2b2

Browse files
authored
Use the timeout context manager in the connection path (#1087)
Drop timeout management gymnastics from the `connect()` path and use the `timeout` context manager instead.
1 parent 8b45beb commit 313b2b2

File tree

5 files changed

+58
-55
lines changed

5 files changed

+58
-55
lines changed

‎asyncpg/compat.py‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,9 @@ async def wait_closed(stream):
5353
from ._asyncio_compatimportwait_foraswait_for# noqa: F401
5454
else:
5555
fromasyncioimportwait_foraswait_for# noqa: F401
56+
57+
58+
ifsys.version_info< (3, 11):
59+
from ._asyncio_compatimporttimeout_ctxastimeout# noqa: F401
60+
else:
61+
fromasyncioimporttimeoutastimeout# noqa: F401

‎asyncpg/connect_utils.py‎

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
importstat
2121
importstruct
2222
importsys
23-
importtime
2423
importtyping
2524
importurllib.parse
2625
importwarnings
@@ -55,7 +54,6 @@ def parse(cls, sslmode):
5554
'ssl',
5655
'sslmode',
5756
'direct_tls',
58-
'connect_timeout',
5957
'server_settings',
6058
'target_session_attrs',
6159
])
@@ -262,7 +260,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
262260

263261
def_parse_connect_dsn_and_args(*, dsn, host, port, user,
264262
password, passfile, database, ssl,
265-
direct_tls, connect_timeout, server_settings,
263+
direct_tls, server_settings,
266264
target_session_attrs):
267265
# `auth_hosts` is the version of host information for the purposes
268266
# of reading the pgpass file.
@@ -655,14 +653,14 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
655653
params=_ConnectionParameters(
656654
user=user, password=password, database=database, ssl=ssl,
657655
sslmode=sslmode, direct_tls=direct_tls,
658-
connect_timeout=connect_timeout, server_settings=server_settings,
656+
server_settings=server_settings,
659657
target_session_attrs=target_session_attrs)
660658

661659
returnaddrs, params
662660

663661

664662
def_parse_connect_arguments(*, dsn, host, port, user, password, passfile,
665-
database, timeout, command_timeout,
663+
database, command_timeout,
666664
statement_cache_size,
667665
max_cached_statement_lifetime,
668666
max_cacheable_statement_size,
@@ -695,7 +693,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
695693
dsn=dsn, host=host, port=port, user=user,
696694
password=password, passfile=passfile, ssl=ssl,
697695
direct_tls=direct_tls, database=database,
698-
connect_timeout=timeout, server_settings=server_settings,
696+
server_settings=server_settings,
699697
target_session_attrs=target_session_attrs)
700698

701699
config=_ClientConfiguration(
@@ -799,17 +797,13 @@ async def _connect_addr(
799797
*,
800798
addr,
801799
loop,
802-
timeout,
803800
params,
804801
config,
805802
connection_class,
806803
record_class
807804
):
808805
assertloopisnotNone
809806

810-
iftimeout<=0:
811-
raiseasyncio.TimeoutError
812-
813807
params_input=params
814808
ifcallable(params.password):
815809
password=params.password()
@@ -827,21 +821,16 @@ async def _connect_addr(
827821
params_retry=params._replace(ssl=None)
828822
else:
829823
# skip retry if we don't have to
830-
returnawait__connect_addr(params, timeout, False, *args)
824+
returnawait__connect_addr(params, False, *args)
831825

832826
# first attempt
833-
before=time.monotonic()
834827
try:
835-
returnawait__connect_addr(params, timeout, True, *args)
828+
returnawait__connect_addr(params, True, *args)
836829
except_RetryConnectSignal:
837830
pass
838831

839832
# second attempt
840-
timeout-=time.monotonic() -before
841-
iftimeout<=0:
842-
raiseasyncio.TimeoutError
843-
else:
844-
returnawait__connect_addr(params_retry, timeout, False, *args)
833+
returnawait__connect_addr(params_retry, False, *args)
845834

846835

847836
class_RetryConnectSignal(Exception):
@@ -850,7 +839,6 @@ class _RetryConnectSignal(Exception):
850839

851840
asyncdef__connect_addr(
852841
params,
853-
timeout,
854842
retry,
855843
addr,
856844
loop,
@@ -882,15 +870,10 @@ async def __connect_addr(
882870
else:
883871
connector=loop.create_connection(proto_factory, *addr)
884872

885-
connector=asyncio.ensure_future(connector)
886-
before=time.monotonic()
887-
tr, pr=awaitcompat.wait_for(connector, timeout=timeout)
888-
timeout-=time.monotonic() -before
873+
tr, pr=awaitconnector
889874

890875
try:
891-
iftimeout<=0:
892-
raiseasyncio.TimeoutError
893-
awaitcompat.wait_for(connected, timeout=timeout)
876+
awaitconnected
894877
except (
895878
exceptions.InvalidAuthorizationSpecificationError,
896879
exceptions.ConnectionDoesNotExistError, # seen on Windows
@@ -993,23 +976,21 @@ async def _can_use_connection(connection, attr: SessionAttribute):
993976
returnawaitcan_use(connection)
994977

995978

996-
asyncdef_connect(*, loop, timeout, connection_class, record_class, **kwargs):
979+
asyncdef_connect(*, loop, connection_class, record_class, **kwargs):
997980
ifloopisNone:
998981
loop=asyncio.get_event_loop()
999982

1000-
addrs, params, config=_parse_connect_arguments(timeout=timeout, **kwargs)
983+
addrs, params, config=_parse_connect_arguments(**kwargs)
1001984
target_attr=params.target_session_attrs
1002985

1003986
candidates= []
1004987
chosen_connection=None
1005988
last_error=None
1006989
foraddrinaddrs:
1007-
before=time.monotonic()
1008990
try:
1009991
conn=await_connect_addr(
1010992
addr=addr,
1011993
loop=loop,
1012-
timeout=timeout,
1013994
params=params,
1014995
config=config,
1015996
connection_class=connection_class,
@@ -1019,10 +1000,8 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
10191000
ifawait_can_use_connection(conn, target_attr):
10201001
chosen_connection=conn
10211002
break
1022-
except(OSError, asyncio.TimeoutError, ConnectionError)asex:
1003+
exceptOSErrorasex:
10231004
last_error=ex
1024-
finally:
1025-
timeout-=time.monotonic() -before
10261005
else:
10271006
iftarget_attr==SessionAttribute.prefer_standbyandcandidates:
10281007
chosen_connection=random.choice(candidates)

‎asyncpg/connection.py‎

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
importwarnings
2121
importweakref
2222

23+
from . importcompat
2324
from . importconnect_utils
2425
from . importcursor
2526
from . importexceptions
@@ -2184,27 +2185,27 @@ async def connect(dsn=None, *,
21842185
ifloopisNone:
21852186
loop=asyncio.get_event_loop()
21862187

2187-
returnawaitconnect_utils._connect(
2188-
loop=loop,
2189-
timeout=timeout,
2190-
connection_class=connection_class,
2191-
record_class=record_class,
2192-
dsn=dsn,
2193-
host=host,
2194-
port=port,
2195-
user=user,
2196-
password=password,
2197-
passfile=passfile,
2198-
ssl=ssl,
2199-
direct_tls=direct_tls,
2200-
database=database,
2201-
server_settings=server_settings,
2202-
command_timeout=command_timeout,
2203-
statement_cache_size=statement_cache_size,
2204-
max_cached_statement_lifetime=max_cached_statement_lifetime,
2205-
max_cacheable_statement_size=max_cacheable_statement_size,
2206-
target_session_attrs=target_session_attrs
2207-
)
2188+
asyncwithcompat.timeout(timeout):
2189+
returnawaitconnect_utils._connect(
2190+
loop=loop,
2191+
connection_class=connection_class,
2192+
record_class=record_class,
2193+
dsn=dsn,
2194+
host=host,
2195+
port=port,
2196+
user=user,
2197+
password=password,
2198+
passfile=passfile,
2199+
ssl=ssl,
2200+
direct_tls=direct_tls,
2201+
database=database,
2202+
server_settings=server_settings,
2203+
command_timeout=command_timeout,
2204+
statement_cache_size=statement_cache_size,
2205+
max_cached_statement_lifetime=max_cached_statement_lifetime,
2206+
max_cacheable_statement_size=max_cacheable_statement_size,
2207+
target_session_attrs=target_session_attrs
2208+
)
22082209

22092210

22102211
class_StatementCacheEntry:

‎tests/test_adversity.py‎

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,23 @@ async def test_connection_close_timeout(self):
2626
withself.assertRaises(asyncio.TimeoutError):
2727
awaitcon.close(timeout=0.5)
2828

29+
@tb.with_timeout(30.0)
30+
asyncdeftest_pool_acquire_timeout(self):
31+
pool=awaitself.create_pool(
32+
database='postgres', min_size=2, max_size=2)
33+
try:
34+
self.proxy.trigger_connectivity_loss()
35+
for_inrange(2):
36+
withself.assertRaises(asyncio.TimeoutError):
37+
asyncwithpool.acquire(timeout=0.5):
38+
pass
39+
self.proxy.restore_connectivity()
40+
asyncwithpool.acquire(timeout=0.5):
41+
pass
42+
finally:
43+
self.proxy.restore_connectivity()
44+
pool.terminate()
45+
2946
@tb.with_timeout(30.0)
3047
asyncdeftest_pool_release_timeout(self):
3148
pool=awaitself.create_pool(

‎tests/test_connect.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ def run_testcase(self, testcase):
891891
addrs, params=connect_utils._parse_connect_dsn_and_args(
892892
dsn=dsn, host=host, port=port, user=user, password=password,
893893
passfile=passfile, database=database, ssl=sslmode,
894-
direct_tls=False,connect_timeout=None,
894+
direct_tls=False,
895895
server_settings=server_settings,
896896
target_session_attrs=target_session_attrs)
897897

0 commit comments

Comments
(0)