Skip to content

Commit 50bb192

Browse files
committed
TLSUpgradeProto: don't set multiple results for an event
In the case of a misbehaving server, the client may receive more than one byte in separate data_received() invocations from the server. While we can't do much sane with this, we should handle it gracefully and not crash with asyncio.InvalidStateError when trying to set another result on the event. Fixes#729
1 parent c2c8d20 commit 50bb192

File tree

3 files changed

+100
-0
lines changed

3 files changed

+100
-0
lines changed

‎asyncpg/_testbase/__init__.py‎

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
importlogging
1414
importos
1515
importre
16+
importsocket
1617
importtextwrap
1718
importtime
1819
importtraceback
@@ -525,3 +526,42 @@ def connect_standby(cls, **kwargs):
525526
kwargs
526527
)
527528
returnpg_connection.connect(**conn_spec, loop=cls.loop)
529+
530+
531+
classInstrumentedServer:
532+
"""
533+
A socket server for testing.
534+
It will write each item from `data`, and wait for the corresponding event
535+
in `received_events` to notify that it was received before writing the next
536+
item from `data`.
537+
"""
538+
def__init__(self, data, received_events):
539+
assertlen(data) ==len(received_events)
540+
self._data=data
541+
self._server=None
542+
self._received_events=received_events
543+
544+
asyncdef_handle_client(self, _reader, writer):
545+
fordatum, received_eventinzip(self._data, self._received_events):
546+
writer.write(datum)
547+
awaitwriter.drain()
548+
awaitreceived_event.wait()
549+
550+
writer.close()
551+
awaitwriter.wait_closed()
552+
553+
asyncdefstart(self):
554+
"""Start the server."""
555+
self._server=awaitasyncio.start_server(self._handle_client, 'localhost', 0)
556+
assertself._server.sockets
557+
sock=self._server.sockets[0]
558+
# Account for IPv4 and IPv6
559+
addr, port=sock.getsockname()[:2]
560+
return{
561+
'host': addr,
562+
'port': port,
563+
}
564+
565+
defstop(self):
566+
"""Stop the server."""
567+
self._server.close()

‎asyncpg/connect_utils.py‎

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,11 @@ def __init__(self, loop, host, port, ssl_context, ssl_is_advisory):
714714
self.ssl_is_advisory=ssl_is_advisory
715715

716716
defdata_received(self, data):
717+
ifself.on_data.done():
718+
# Only expect to receive one byte here; ignore unsolicited further
719+
# data.
720+
return
721+
717722
ifdata==b'S':
718723
self.on_data.set_result(True)
719724
elif (self.ssl_is_advisoryand

‎tests/test_connect.py‎

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
importasyncio
99
importcontextlib
10+
importcopy
1011
importgc
1112
importipaddress
1213
importos
@@ -17,11 +18,13 @@
1718
importstat
1819
importtempfile
1920
importtextwrap
21+
importtime
2022
importunittest
2123
importunittest.mock
2224
importurllib.parse
2325
importwarnings
2426
importweakref
27+
fromunittestimportmock
2528

2629
importasyncpg
2730
fromasyncpgimport_testbaseastb
@@ -1989,6 +1992,58 @@ async def test_prefer_standby_picks_master_when_standby_is_down(self):
19891992
awaitcon.close()
19901993

19911994

1995+
classTestMisbehavingServer(tb.TestCase):
1996+
"""Tests for client connection behaviour given a misbehaving server."""
1997+
1998+
asyncdeftest_tls_upgrade_extra_data_received(self):
1999+
data= [
2000+
# First, the server writes b"S" to signal it is willing to perform
2001+
# SSL
2002+
b"S",
2003+
# Then, the server writes an unsolicted arbitrary byte afterwards
2004+
b"N",
2005+
]
2006+
data_received_events= [asyncio.Event() for_indata]
2007+
2008+
# Patch out the loop's create_connection so we can instrument the proto
2009+
# we return.
2010+
old_create_conn=self.loop.create_connection
2011+
2012+
asyncdef_mock_create_conn(*args, **kwargs):
2013+
transport, proto=awaitold_create_conn(*args, **kwargs)
2014+
old_data_received=proto.data_received
2015+
2016+
num_received=0
2017+
2018+
def_data_received(*args, **kwargs):
2019+
nonlocalnum_received
2020+
# Call the original data_received method
2021+
ret=old_data_received(*args, **kwargs)
2022+
# Fire the event to signal we've received this datum now.
2023+
data_received_events[num_received].set()
2024+
num_received+=1
2025+
returnret
2026+
2027+
proto.data_received=_data_received
2028+
2029+
# To deterministically provoke the race we're interested in for
2030+
# this regression test, wait for all data to be received before
2031+
# returning from create_connection().
2032+
awaitdata_received_events[-1].wait()
2033+
returntransport, proto
2034+
2035+
server=tb.InstrumentedServer(data, data_received_events)
2036+
conn_spec=awaitserver.start()
2037+
2038+
# The call to connect() should raise a ConnectionResetError as the
2039+
# server will close the connection after writing all the data.
2040+
with (mock.patch.object(self.loop, "create_connection", side_effect=_mock_create_conn),
2041+
self.assertRaises(ConnectionResetError)):
2042+
awaitpg_connection.connect(**conn_spec, ssl=True, loop=self.loop)
2043+
2044+
server.stop()
2045+
2046+
19922047
def_get_connected_host(con):
19932048
peername=con._transport.get_extra_info('peername')
19942049
ifisinstance(peername, tuple):

0 commit comments

Comments
(0)