Skip to content

Commit 77d4742

Browse files
committed
Add sslmode=allow support and fix =prefer retry
We didn't really retry the connection without SSL if the first SSL connection fails, that led to an issue when the server has SSL support but explicitly denies SSL connection through pg_hba.conf. This commit adds a retry in a new connection, which makes it easy to implement the sslmode=allow retry. Fixes#716
1 parent 53bea98 commit 77d4742

File tree

5 files changed

+242
-49
lines changed

5 files changed

+242
-49
lines changed

‎asyncpg/connect_utils.py‎

Lines changed: 100 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
'password',
3636
'database',
3737
'ssl',
38-
'ssl_is_advisory',
38+
'alt_retry_ssl_first',
3939
'connect_timeout',
4040
'server_settings',
4141
])
@@ -402,8 +402,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
402402
ifsslisNoneandhave_tcp_addrs:
403403
ssl='prefer'
404404

405-
# ssl_is_advisory is only allowed to come from the sslmode parameter.
406-
ssl_is_advisory=None
405+
# alt_retry_ssl_first is particularly for "allow" and "prefer"
406+
# to alternatively try SSL/non-SSL connections (once each if supported):
407+
# False - allow (try non-SSL first)
408+
# True - prefer (try SSL first)
409+
# None - other (don't retry, stick with the "ssl" parameter)
410+
alt_retry_ssl_first=None
411+
407412
ifisinstance(ssl, str):
408413
SSLMODES={
409414
'disable': 0,
@@ -420,26 +425,21 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
420425
raiseexceptions.InterfaceError(
421426
'`sslmode` parameter must be one of:{}'.format(modes))
422427

423-
# sslmode 'allow' is currently handled as 'prefer' because we're
424-
# missing the "retry with SSL" behavior for 'allow', but do have the
425-
# "retry without SSL" behavior for 'prefer'.
426-
# Not changing 'allow' to 'prefer' here would be effectively the same
427-
# as changing 'allow' to 'disable'.
428428
ifsslmode==SSLMODES['allow']:
429-
sslmode=SSLMODES['prefer']
429+
alt_retry_ssl_first=False
430+
elifsslmode==SSLMODES['prefer']:
431+
alt_retry_ssl_first=True
430432

431433
# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
432434
# Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
433-
ifsslmode<=SSLMODES['allow']:
435+
ifsslmode<SSLMODES['allow']:
434436
ssl=False
435-
ssl_is_advisory=sslmode>=SSLMODES['allow']
436437
else:
437438
ssl=ssl_module.create_default_context()
438439
ssl.check_hostname=sslmode>=SSLMODES['verify-full']
439440
ssl.verify_mode=ssl_module.CERT_REQUIRED
440441
ifsslmode<=SSLMODES['require']:
441442
ssl.verify_mode=ssl_module.CERT_NONE
442-
ssl_is_advisory=sslmode<=SSLMODES['prefer']
443443
elifsslisTrue:
444444
ssl=ssl_module.create_default_context()
445445

@@ -453,7 +453,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
453453

454454
params=_ConnectionParameters(
455455
user=user, password=password, database=database, ssl=ssl,
456-
ssl_is_advisory=ssl_is_advisory, connect_timeout=connect_timeout,
456+
alt_retry_ssl_first=alt_retry_ssl_first,
457+
connect_timeout=connect_timeout,
457458
server_settings=server_settings)
458459

459460
returnaddrs, params
@@ -520,9 +521,8 @@ def data_received(self, data):
520521
data==b'N'):
521522
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
522523
# since the only way to get ssl_is_advisory is from
523-
# sslmode=prefer (or sslmode=allow). But be extra sure to
524-
# disallow insecure connections when the ssl context asks for
525-
# real security.
524+
# sslmode=prefer. But be extra sure to disallow insecure
525+
# connections when the ssl context asks for real security.
526526
self.on_data.set_result(False)
527527
else:
528528
self.on_data.set_exception(
@@ -566,6 +566,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
566566
new_tr=tr
567567

568568
pg_proto=protocol_factory()
569+
pg_proto.is_ssl=do_ssl_upgrade
569570
pg_proto.connection_made(new_tr)
570571
new_tr.set_protocol(pg_proto)
571572

@@ -584,7 +585,9 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
584585
tr.close()
585586

586587
try:
587-
returnawaitconn_factory(sock=sock)
588+
new_tr, pg_proto=awaitconn_factory(sock=sock)
589+
pg_proto.is_ssl=do_ssl_upgrade
590+
returnnew_tr, pg_proto
588591
except (Exception, asyncio.CancelledError):
589592
sock.close()
590593
raise
@@ -605,8 +608,6 @@ async def _connect_addr(
605608
iftimeout<=0:
606609
raiseasyncio.TimeoutError
607610

608-
connected=_create_future(loop)
609-
610611
params_input=params
611612
ifcallable(params.password):
612613
ifinspect.iscoroutinefunction(params.password):
@@ -615,6 +616,44 @@ async def _connect_addr(
615616
password=params.password()
616617

617618
params=params._replace(password=password)
619+
args= (addr, loop, config, connection_class, record_class, params_input)
620+
621+
# skip retry if alt_retry is not enabled
622+
ifparams.alt_retry_ssl_firstisNone:
623+
returnawait__connect_addr(params, timeout, *args)
624+
625+
# prepare the params (which attempt has ssl) for the 2 attempts
626+
params_retry=params._replace(ssl=None)
627+
ifnotparams.alt_retry_ssl_first:
628+
params, params_retry=params_retry, params
629+
630+
# first attempt
631+
before=time.monotonic()
632+
try:
633+
returnawait__connect_addr(params, timeout, *args)
634+
exceptConnectionError:
635+
pass
636+
637+
# the second attempt with alt_retry_ssl_first=None
638+
timeout-=time.monotonic() -before
639+
iftimeout<=0:
640+
raiseasyncio.TimeoutError
641+
else:
642+
params_retry=params_retry._replace(alt_retry_ssl_first=None)
643+
returnawait__connect_addr(params_retry, timeout, *args)
644+
645+
646+
asyncdef__connect_addr(
647+
params,
648+
timeout,
649+
addr,
650+
loop,
651+
config,
652+
connection_class,
653+
record_class,
654+
params_input,
655+
):
656+
connected=_create_future(loop)
618657

619658
proto_factory=lambda: protocol.Protocol(
620659
addr, connected, params, record_class, loop)
@@ -625,7 +664,7 @@ async def _connect_addr(
625664
elifparams.ssl:
626665
connector=_create_ssl_connection(
627666
proto_factory, *addr, loop=loop, ssl_context=params.ssl,
628-
ssl_is_advisory=params.ssl_is_advisory)
667+
ssl_is_advisory=params.alt_retry_ssl_first)
629668
else:
630669
connector=loop.create_connection(proto_factory, *addr)
631670

@@ -638,6 +677,23 @@ async def _connect_addr(
638677
iftimeout<=0:
639678
raiseasyncio.TimeoutError
640679
awaitcompat.wait_for(connected, timeout=timeout)
680+
exceptexceptions.InvalidAuthorizationSpecificationError:
681+
tr.close()
682+
683+
# pr.is_ssl is a bool, so this equal test implies
684+
# alt_retry_ssl_first is not None (should do alt_retry)
685+
ifparams.alt_retry_ssl_first==pr.is_ssl:
686+
# Elevate the error to ConnectionError to trigger retry
687+
raiseConnectionError("Connection rejected trying{} SSL".format(
688+
'with'ifpr.is_sslelse'without'))
689+
690+
else:
691+
# Don't retry if alt_retry_ssl_first is None, or we don't need to
692+
# (alt_retry_ssl_first=True and pr.is_ssl=False means the server
693+
# doesn't support SSL, and we've already tried to Startup without
694+
# SSL but failed; The opposite case doesn't exist).
695+
raise
696+
641697
except (Exception, asyncio.CancelledError):
642698
tr.close()
643699
raise
@@ -684,6 +740,7 @@ class CancelProto(asyncio.Protocol):
684740

685741
def__init__(self):
686742
self.on_disconnect=_create_future(loop)
743+
self.is_ssl=False
687744

688745
defconnection_lost(self, exc):
689746
ifnotself.on_disconnect.done():
@@ -692,17 +749,30 @@ def connection_lost(self, exc):
692749
ifisinstance(addr, str):
693750
tr, pr=awaitloop.create_unix_connection(CancelProto, addr)
694751
else:
695-
ifparams.ssl:
696-
tr, pr=await_create_ssl_connection(
697-
CancelProto,
698-
*addr,
699-
loop=loop,
700-
ssl_context=params.ssl,
701-
ssl_is_advisory=params.ssl_is_advisory)
752+
asyncdef_connect(params_in, ssl_is_advisory):
753+
ifparams_in.ssl:
754+
returnawait_create_ssl_connection(
755+
CancelProto,
756+
*addr,
757+
loop=loop,
758+
ssl_context=params_in.ssl,
759+
ssl_is_advisory=ssl_is_advisory)
760+
else:
761+
returnawaitloop.create_connection(
762+
CancelProto, *addr)
763+
_set_nodelay(_get_socket(tr))
764+
765+
ifparams.alt_retry_ssl_firstisNone:
766+
tr, pr=await_connect(params, False)
702767
else:
703-
tr, pr=awaitloop.create_connection(
704-
CancelProto, *addr)
705-
_set_nodelay(_get_socket(tr))
768+
params_retry=params._replace(ssl=None)
769+
ifnotparams.alt_retry_ssl_first:
770+
params, params_retry=params_retry, params
771+
try:
772+
tr, pr=await_connect(params, True)
773+
exceptConnectionError:
774+
tr, pr=await_connect(
775+
params._replace(alt_retry_ssl_first=None), False)
706776

707777
# Pack a CancelRequest message
708778
msg=struct.pack('!llll', 16, 80877102, backend_pid, backend_secret)

‎asyncpg/connection.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1879,7 +1879,8 @@ async def connect(dsn=None, *,
18791879
- ``'disable'`` - SSL is disabled (equivalent to ``False``)
18801880
- ``'prefer'`` - try SSL first, fallback to non-SSL connection
18811881
if SSL connection fails
1882-
- ``'allow'`` - currently equivalent to ``'prefer'``
1882+
- ``'allow'`` - try without SSL first, then retry with SSL if the first
1883+
attempt fails.
18831884
- ``'require'`` - only try an SSL connection. Certificate
18841885
verification errors are ignored
18851886
- ``'verify-ca'`` - only try an SSL connection, and verify

‎asyncpg/protocol/protocol.pxd‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ cdef class BaseProtocol(CoreProtocol):
5252

5353
readonly uint64_t queries_count
5454

55+
bint _is_ssl
56+
5557
PreparedStatementState statement
5658

5759
cdef get_connection(self)

‎asyncpg/protocol/protocol.pyx‎

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ cdef class BaseProtocol(CoreProtocol):
103103

104104
self.queries_count =0
105105

106+
self._is_ssl =False
107+
106108
try:
107109
self.create_future = loop.create_future
108110
exceptAttributeError:
@@ -943,6 +945,14 @@ cdef class BaseProtocol(CoreProtocol):
943945
defresume_writing(self):
944946
self.writing_allowed.set()
945947

948+
@property
949+
defis_ssl(self):
950+
returnself._is_ssl
951+
952+
@is_ssl.setter
953+
defis_ssl(self, value):
954+
self._is_ssl = value
955+
946956

947957
classTimer:
948958
def__init__(self, budget):

0 commit comments

Comments
(0)