Skip to content

Commit c337261

Browse files
committed
Handle environments without home dir
1 parent 247b1a5 commit c337261

File tree

3 files changed

+67
-20
lines changed

3 files changed

+67
-20
lines changed

‎asyncpg/compat.py‎

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
importasyncio
99
importpathlib
1010
importplatform
11+
importtyping
1112

1213

1314
SYSTEM=platform.uname().system
@@ -18,7 +19,7 @@
1819

1920
CSIDL_APPDATA=0x001a
2021

21-
defget_pg_home_directory() ->pathlib.Path:
22+
defget_pg_home_directory() ->typing.Optional[pathlib.Path]:
2223
# We cannot simply use expanduser() as that returns the user's
2324
# home directory, whereas Postgres stores its config in
2425
# %AppData% on Windows.
@@ -30,8 +31,11 @@ def get_pg_home_directory() -> pathlib.Path:
3031
returnpathlib.Path(buf.value) /'postgresql'
3132

3233
else:
33-
defget_pg_home_directory() ->pathlib.Path:
34-
returnpathlib.Path.home()
34+
defget_pg_home_directory() ->typing.Optional[pathlib.Path]:
35+
try:
36+
returnpathlib.Path.home()
37+
except (RuntimeError, KeyError):
38+
returnNone
3539

3640

3741
asyncdefwait_closed(stream):

‎asyncpg/connect_utils.py‎

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,12 @@ def _parse_tls_version(tls_version):
249249
)
250250

251251

252-
def_dot_postgresql_path(filename) ->pathlib.Path:
253-
return (pathlib.Path.home() /'.postgresql'/filename).resolve()
252+
def_dot_postgresql_path(filename) ->typing.Optional[pathlib.Path]:
253+
homedir=compat.get_pg_home_directory()
254+
ifhomedirisNone:
255+
returnNone
256+
257+
return (homedir/'.postgresql'/filename).resolve()
254258

255259

256260
def_parse_connect_dsn_and_args(*, dsn, host, port, user,
@@ -501,11 +505,16 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
501505
ssl.load_verify_locations(cafile=sslrootcert)
502506
ssl.verify_mode=ssl_module.CERT_REQUIRED
503507
else:
504-
sslrootcert=_dot_postgresql_path('root.crt')
505508
try:
509+
sslrootcert=_dot_postgresql_path('root.crt')
510+
assertsslrootcertisnotNone
506511
ssl.load_verify_locations(cafile=sslrootcert)
507-
exceptFileNotFoundError:
512+
except(AssertionError, FileNotFoundError):
508513
ifsslmode>SSLMode.require:
514+
ifsslrootcertisNone:
515+
raiseRuntimeError(
516+
'Cannot determine home directory'
517+
)
509518
raiseValueError(
510519
f'root certificate file "{sslrootcert}" does '
511520
f'not exist\nEither provide the file or '
@@ -526,18 +535,20 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
526535
ssl.verify_flags|=ssl_module.VERIFY_CRL_CHECK_CHAIN
527536
else:
528537
sslcrl=_dot_postgresql_path('root.crl')
529-
try:
530-
ssl.load_verify_locations(cafile=sslcrl)
531-
exceptFileNotFoundError:
532-
pass
533-
else:
534-
ssl.verify_flags|=ssl_module.VERIFY_CRL_CHECK_CHAIN
538+
ifsslcrlisnotNone:
539+
try:
540+
ssl.load_verify_locations(cafile=sslcrl)
541+
exceptFileNotFoundError:
542+
pass
543+
else:
544+
ssl.verify_flags|= \
545+
ssl_module.VERIFY_CRL_CHECK_CHAIN
535546

536547
ifsslkeyisNone:
537548
sslkey=os.getenv('PGSSLKEY')
538549
ifnotsslkey:
539550
sslkey=_dot_postgresql_path('postgresql.key')
540-
ifnotsslkey.exists():
551+
ifsslkeyisnotNoneandnotsslkey.exists():
541552
sslkey=None
542553
ifnotsslpassword:
543554
sslpassword=''
@@ -549,12 +560,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
549560
)
550561
else:
551562
sslcert=_dot_postgresql_path('postgresql.crt')
552-
try:
553-
ssl.load_cert_chain(
554-
sslcert, keyfile=sslkey, password=lambda: sslpassword
555-
)
556-
exceptFileNotFoundError:
557-
pass
563+
ifsslcertisnotNone:
564+
try:
565+
ssl.load_cert_chain(
566+
sslcert,
567+
keyfile=sslkey,
568+
password=lambda: sslpassword
569+
)
570+
exceptFileNotFoundError:
571+
pass
558572

559573
# OpenSSL 1.1.1 keylog file, copied from create_default_context()
560574
ifhasattr(ssl, 'keylog_filename'):

‎tests/test_connect.py‎

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ def mock_dot_postgresql(*, ca=True, crl=False, client=False, protected=False):
7171
yield
7272

7373

74+
@contextlib.contextmanager
75+
defmock_no_home_dir():
76+
withunittest.mock.patch(
77+
'pathlib.Path.home', unittest.mock.Mock(side_effect=RuntimeError)
78+
):
79+
yield
80+
81+
7482
classTestSettings(tb.ConnectedTestCase):
7583

7684
asyncdeftest_get_settings_01(self):
@@ -1257,6 +1265,27 @@ async def test_connection_implicit_host(self):
12571265
user=conn_spec.get('user'))
12581266
awaitcon.close()
12591267

1268+
@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
1269+
asyncdeftest_connection_no_home_dir(self):
1270+
withmock_no_home_dir():
1271+
con=awaitself.connect(
1272+
dsn='postgresql://foo/',
1273+
user='postgres',
1274+
database='postgres',
1275+
host='localhost')
1276+
awaitcon.fetchval('SELECT 42')
1277+
awaitcon.close()
1278+
1279+
withself.assertRaisesRegex(
1280+
RuntimeError,
1281+
'Cannot determine home directory'
1282+
):
1283+
withmock_no_home_dir():
1284+
awaitself.connect(
1285+
host='localhost',
1286+
user='ssl_user',
1287+
ssl='verify-full')
1288+
12601289

12611290
classBaseTestSSLConnection(tb.ConnectedTestCase):
12621291
@classmethod

0 commit comments

Comments
(0)