Skip to content

Commit bdba7ce

Browse files
committed
Use loop.start_tls() to upgrade connections to SSL
The old way of TLS upgrade (openining a connection, asking postgres to do TLS and then duping the underlying socket) seems not to work anymore on Windows with Python 3.8.
1 parent d655a39 commit bdba7ce

File tree

2 files changed

+117
-80
lines changed

2 files changed

+117
-80
lines changed

‎asyncpg/connect_utils.py‎

Lines changed: 112 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,95 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
504504
returnaddrs, params, config
505505

506506

507+
classTLSUpgradeProto(asyncio.Protocol):
508+
def__init__(self, loop, host, port, ssl_context, ssl_is_advisory):
509+
self.on_data=_create_future(loop)
510+
self.host=host
511+
self.port=port
512+
self.ssl_context=ssl_context
513+
self.ssl_is_advisory=ssl_is_advisory
514+
515+
defdata_received(self, data):
516+
ifdata==b'S':
517+
self.on_data.set_result(True)
518+
elif (self.ssl_is_advisoryand
519+
self.ssl_context.verify_mode==ssl_module.CERT_NONEand
520+
data==b'N'):
521+
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522+
# since the only way to get ssl_is_advisory is from
523+
# sslmode=prefer (or sslmode=allow). But be extra sure to
524+
# disallow insecure connections when the ssl context asks for
525+
# real security.
526+
self.on_data.set_result(False)
527+
else:
528+
self.on_data.set_exception(
529+
ConnectionError(
530+
'PostgreSQL server at "{host}:{port}" '
531+
'rejected SSL upgrade'.format(
532+
host=self.host, port=self.port)))
533+
534+
defconnection_lost(self, exc):
535+
ifnotself.on_data.done():
536+
ifexcisNone:
537+
exc=ConnectionError('unexpected connection_lost() call')
538+
self.on_data.set_exception(exc)
539+
540+
541+
asyncdef_create_ssl_connection(protocol_factory, host, port, *,
542+
loop, ssl_context, ssl_is_advisory=False):
543+
544+
ifssl_contextisTrue:
545+
ssl_context=ssl_module.create_default_context()
546+
547+
tr, pr=awaitloop.create_connection(
548+
lambda: TLSUpgradeProto(loop, host, port,
549+
ssl_context, ssl_is_advisory),
550+
host, port)
551+
552+
tr.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.
553+
554+
try:
555+
do_ssl_upgrade=awaitpr.on_data
556+
except (Exception, asyncio.CancelledError):
557+
tr.close()
558+
raise
559+
560+
ifhasattr(loop, 'start_tls'):
561+
ifdo_ssl_upgrade:
562+
try:
563+
new_tr=awaitloop.start_tls(
564+
tr, pr, ssl_context, server_hostname=host)
565+
except (Exception, asyncio.CancelledError):
566+
tr.close()
567+
raise
568+
else:
569+
new_tr=tr
570+
571+
pg_proto=protocol_factory()
572+
pg_proto.connection_made(new_tr)
573+
new_tr.set_protocol(pg_proto)
574+
575+
returnnew_tr, pg_proto
576+
else:
577+
conn_factory=functools.partial(
578+
loop.create_connection, protocol_factory)
579+
580+
ifdo_ssl_upgrade:
581+
conn_factory=functools.partial(
582+
conn_factory, ssl=ssl_context, server_hostname=host)
583+
584+
sock=_get_socket(tr)
585+
sock=sock.dup()
586+
_set_nodelay(sock)
587+
tr.close()
588+
589+
try:
590+
returnawaitconn_factory(sock=sock)
591+
except (Exception, asyncio.CancelledError):
592+
sock.close()
593+
raise
594+
595+
507596
asyncdef_connect_addr(*, addr, loop, timeout, params, config,
508597
connection_class):
509598
assertloopisnotNone
@@ -526,8 +615,6 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
526615
else:
527616
connector=loop.create_connection(proto_factory, *addr)
528617

529-
connector=asyncio.ensure_future(connector)
530-
531618
before=time.monotonic()
532619
try:
533620
tr, pr=awaitasyncio.wait_for(
@@ -575,79 +662,41 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
575662
raiselast_error
576663

577664

578-
asyncdef_negotiate_ssl_connection(host, port, conn_factory, *, loop, ssl,
579-
server_hostname, ssl_is_advisory=False):
580-
# Note: ssl_is_advisory only affects behavior when the server does not
581-
# accept SSLRequests. If the SSLRequest is accepted but either the SSL
582-
# negotiation fails or the PostgreSQL user isn't permitted to use SSL,
583-
# there's nothing that would attempt to reconnect with a non-SSL socket.
584-
reader, writer=awaitasyncio.open_connection(host, port)
585-
586-
tr=writer.transport
587-
try:
588-
sock=_get_socket(tr)
589-
_set_nodelay(sock)
590-
591-
writer.write(struct.pack('!ll', 8, 80877103)) # SSLRequest message.
592-
awaitwriter.drain()
593-
resp=awaitreader.readexactly(1)
594-
595-
ifresp==b'S':
596-
conn_factory=functools.partial(
597-
conn_factory, ssl=ssl, server_hostname=server_hostname)
598-
elif (ssl_is_advisoryand
599-
ssl.verify_mode==ssl_module.CERT_NONEand
600-
resp==b'N'):
601-
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
602-
# since the only way to get ssl_is_advisory is from sslmode=prefer
603-
# (or sslmode=allow). But be extra sure to disallow insecure
604-
# connections when the ssl context asks for real security.
605-
pass
606-
else:
607-
raiseConnectionError(
608-
'PostgreSQL server at "{}:{}" rejected SSL upgrade'.format(
609-
host, port))
610-
611-
sock=sock.dup() # Must come before tr.close()
612-
finally:
613-
writer.close()
614-
awaitcompat.wait_closed(writer)
615-
616-
try:
617-
returnawaitconn_factory(sock=sock) # Must come after tr.close()
618-
except (Exception, asyncio.CancelledError):
619-
sock.close()
620-
raise
665+
asyncdef_cancel(*, loop, addr, params: _ConnectionParameters,
666+
backend_pid, backend_secret):
621667

668+
classCancelProto(asyncio.Protocol):
622669

623-
asyncdef_create_ssl_connection(protocol_factory, host, port, *,
624-
loop, ssl_context, ssl_is_advisory=False):
625-
returnawait_negotiate_ssl_connection(
626-
host, port,
627-
functools.partial(loop.create_connection, protocol_factory),
628-
loop=loop,
629-
ssl=ssl_context,
630-
server_hostname=host,
631-
ssl_is_advisory=ssl_is_advisory)
670+
def__init__(self):
671+
self.on_disconnect=_create_future(loop)
632672

673+
defconnection_lost(self, exc):
674+
ifnotself.on_disconnect.done():
675+
self.on_disconnect.set_result(True)
633676

634-
asyncdef_open_connection(*, loop, addr, params: _ConnectionParameters):
635677
ifisinstance(addr, str):
636-
r, w=awaitasyncio.open_unix_connection(addr)
678+
tr, pr=awaitloop.create_unix_connection(CancelProto, addr)
637679
else:
638680
ifparams.ssl:
639-
r, w=await_negotiate_ssl_connection(
681+
tr, pr=await_create_ssl_connection(
682+
CancelProto,
640683
*addr,
641-
asyncio.open_connection,
642684
loop=loop,
643-
ssl=params.ssl,
644-
server_hostname=addr[0],
685+
ssl_context=params.ssl,
645686
ssl_is_advisory=params.ssl_is_advisory)
646687
else:
647-
r, w=awaitasyncio.open_connection(*addr)
648-
_set_nodelay(_get_socket(w.transport))
688+
tr, pr=awaitloop.create_connection(
689+
CancelProto, *addr)
690+
_set_nodelay(_get_socket(tr))
691+
692+
# Pack a CancelRequest message
693+
msg=struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)
649694

650-
returnr, w
695+
try:
696+
tr.write(msg)
697+
awaitpr.on_disconnect
698+
finally:
699+
tr.close()
651700

652701

653702
def_get_socket(transport):

‎asyncpg/connection.py‎

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
importcollections
1111
importcollections.abc
1212
importitertools
13-
importstruct
1413
importsys
1514
importtime
1615
importtraceback
@@ -1186,24 +1185,16 @@ async def _cleanup_stmts(self):
11861185
awaitself._protocol.close_statement(stmt, protocol.NO_TIMEOUT)
11871186

11881187
asyncdef_cancel(self, waiter):
1189-
r=w=None
1190-
11911188
try:
11921189
# Open new connection to the server
1193-
r, w=awaitconnect_utils._open_connection(
1194-
loop=self._loop, addr=self._addr, params=self._params)
1195-
1196-
# Pack CancelRequest message
1197-
msg=struct.pack('!llll', 16, 80877102,
1198-
self._protocol.backend_pid,
1199-
self._protocol.backend_secret)
1200-
1201-
w.write(msg)
1202-
awaitr.read() # Wait until EOF
1190+
awaitconnect_utils._cancel(
1191+
loop=self._loop, addr=self._addr, params=self._params,
1192+
backend_pid=self._protocol.backend_pid,
1193+
backend_secret=self._protocol.backend_secret)
12031194
exceptConnectionResetErrorasex:
12041195
# On some systems Postgres will reset the connection
12051196
# after processing the cancellation command.
1206-
ifrisNoneandnotwaiter.done():
1197+
ifnotwaiter.done():
12071198
waiter.set_exception(ex)
12081199
exceptasyncio.CancelledError:
12091200
# There are two scenarios in which the cancellation
@@ -1221,9 +1212,6 @@ async def _cancel(self, waiter):
12211212
compat.current_asyncio_task(self._loop))
12221213
ifnotwaiter.done():
12231214
waiter.set_result(None)
1224-
ifwisnotNone:
1225-
w.close()
1226-
awaitcompat.wait_closed(w)
12271215

12281216
def_cancel_current_command(self, waiter):
12291217
self._cancellations.add(self._loop.create_task(self._cancel(waiter)))

0 commit comments

Comments
(0)