@@ -331,12 +331,9 @@ def create_predefined_instance_roles(self) -> None:
331
331
connection = None
332
332
try :
333
333
for database in ["postgres" , "template1" ]:
334
- with (
335
- self ._connect_to_database (
336
- database = database ,
337
- ) as connection ,
338
- connection .cursor () as cursor ,
339
- ):
334
+ with self ._connect_to_database (
335
+ database = database ,
336
+ ) as connection , connection .cursor () as cursor :
340
337
cursor .execute (SQL ("CREATE EXTENSION IF NOT EXISTS set_user;" ))
341
338
finally :
342
339
if connection is not None :
@@ -421,10 +418,9 @@ def delete_user(self, user: str) -> None:
421
418
# Existing objects need to be reassigned in each database
422
419
# before the user can be deleted.
423
420
for database in databases :
424
- with (
425
- self ._connect_to_database (database ) as connection ,
426
- connection .cursor () as cursor ,
427
- ):
421
+ with self ._connect_to_database (
422
+ database
423
+ ) as connection , connection .cursor () as cursor :
428
424
cursor .execute (
429
425
SQL ("REASSIGN OWNED BY {} TO {};" ).format (
430
426
Identifier (user ), Identifier (self .user )
@@ -515,10 +511,9 @@ def enable_disable_extensions(
515
511
516
512
# Enable/disabled the extension in each database.
517
513
for database in databases :
518
- with (
519
- self ._connect_to_database (database = database ) as connection ,
520
- connection .cursor () as cursor ,
521
- ):
514
+ with self ._connect_to_database (
515
+ database = database
516
+ ) as connection , connection .cursor () as cursor :
522
517
for extension , enable in ordered_extensions .items ():
523
518
cursor .execute (
524
519
f"CREATE EXTENSION IF NOT EXISTS { extension } ;"
@@ -562,10 +557,9 @@ def get_postgresql_text_search_configs(self) -> Set[str]:
562
557
Returns:
563
558
Set of PostgreSQL text search configs.
564
559
"""
565
- with (
566
- self ._connect_to_database (database_host = self .current_host ) as connection ,
567
- connection .cursor () as cursor ,
568
- ):
560
+ with self ._connect_to_database (
561
+ database_host = self .current_host
562
+ ) as connection , connection .cursor () as cursor :
569
563
cursor .execute ("SELECT CONCAT('pg_catalog.', cfgname) FROM pg_ts_config;" )
570
564
text_search_configs = cursor .fetchall ()
571
565
return {text_search_config [0 ] for text_search_config in text_search_configs }
@@ -576,10 +570,9 @@ def get_postgresql_timezones(self) -> Set[str]:
576
570
Returns:
577
571
Set of PostgreSQL timezones.
578
572
"""
579
- with (
580
- self ._connect_to_database (database_host = self .current_host ) as connection ,
581
- connection .cursor () as cursor ,
582
- ):
573
+ with self ._connect_to_database (
574
+ database_host = self .current_host
575
+ ) as connection , connection .cursor () as cursor :
583
576
cursor .execute ("SELECT name FROM pg_timezone_names;" )
584
577
timezones = cursor .fetchall ()
585
578
return {timezone [0 ] for timezone in timezones }
@@ -590,10 +583,9 @@ def get_postgresql_default_table_access_methods(self) -> Set[str]:
590
583
Returns:
591
584
Set of PostgreSQL table access methods.
592
585
"""
593
- with (
594
- self ._connect_to_database (database_host = self .current_host ) as connection ,
595
- connection .cursor () as cursor ,
596
- ):
586
+ with self ._connect_to_database (
587
+ database_host = self .current_host
588
+ ) as connection , connection .cursor () as cursor :
597
589
cursor .execute ("SELECT amname FROM pg_am WHERE amtype = 't';" )
598
590
access_methods = cursor .fetchall ()
599
591
return {access_method [0 ] for access_method in access_methods }
@@ -606,10 +598,9 @@ def get_postgresql_version(self, current_host=True) -> str:
606
598
"""
607
599
host = self .current_host if current_host else None
608
600
try :
609
- with (
610
- self ._connect_to_database (database_host = host ) as connection ,
611
- connection .cursor () as cursor ,
612
- ):
601
+ with self ._connect_to_database (
602
+ database_host = host
603
+ ) as connection , connection .cursor () as cursor :
613
604
cursor .execute ("SELECT version();" )
614
605
# Split to get only the version number.
615
606
return cursor .fetchone ()[0 ].split (" " )[1 ]
@@ -628,12 +619,9 @@ def is_tls_enabled(self, check_current_host: bool = False) -> bool:
628
619
whether TLS is enabled.
629
620
"""
630
621
try :
631
- with (
632
- self ._connect_to_database (
633
- database_host = self .current_host if check_current_host else None
634
- ) as connection ,
635
- connection .cursor () as cursor ,
636
- ):
622
+ with self ._connect_to_database (
623
+ database_host = self .current_host if check_current_host else None
624
+ ) as connection , connection .cursor () as cursor :
637
625
cursor .execute ("SHOW ssl;" )
638
626
return "on" in cursor .fetchone ()[0 ]
639
627
except psycopg2 .Error :
@@ -653,10 +641,9 @@ def list_access_groups(self, current_host=False) -> Set[str]:
653
641
connection = None
654
642
host = self .current_host if current_host else None
655
643
try :
656
- with (
657
- self ._connect_to_database (database_host = host ) as connection ,
658
- connection .cursor () as cursor ,
659
- ):
644
+ with self ._connect_to_database (
645
+ database_host = host
646
+ ) as connection , connection .cursor () as cursor :
660
647
cursor .execute (
661
648
"SELECT groname FROM pg_catalog.pg_group WHERE groname LIKE '%_access';"
662
649
)
@@ -684,10 +671,9 @@ def list_accessible_databases_for_user(self, user: str, current_host=False) -> S
684
671
connection = None
685
672
host = self .current_host if current_host else None
686
673
try :
687
- with (
688
- self ._connect_to_database (database_host = host ) as connection ,
689
- connection .cursor () as cursor ,
690
- ):
674
+ with self ._connect_to_database (
675
+ database_host = host
676
+ ) as connection , connection .cursor () as cursor :
691
677
cursor .execute (
692
678
SQL (
693
679
"SELECT TRUE FROM pg_catalog.pg_user WHERE usename = {} AND usesuper;"
@@ -723,10 +709,9 @@ def list_users(self, group: Optional[str] = None, current_host=False) -> Set[str
723
709
connection = None
724
710
host = self .current_host if current_host else None
725
711
try :
726
- with (
727
- self ._connect_to_database (database_host = host ) as connection ,
728
- connection .cursor () as cursor ,
729
- ):
712
+ with self ._connect_to_database (
713
+ database_host = host
714
+ ) as connection , connection .cursor () as cursor :
730
715
if group :
731
716
query = SQL (
732
717
"SELECT usename FROM (SELECT UNNEST(grolist) AS user_id FROM pg_catalog.pg_group WHERE groname = {}) AS g JOIN pg_catalog.pg_user AS u ON g.user_id = u.usesysid;"
@@ -756,10 +741,9 @@ def list_users_from_relation(self, current_host=False) -> Set[str]:
756
741
connection = None
757
742
host = self .current_host if current_host else None
758
743
try :
759
- with (
760
- self ._connect_to_database (database_host = host ) as connection ,
761
- connection .cursor () as cursor ,
762
- ):
744
+ with self ._connect_to_database (
745
+ database_host = host
746
+ ) as connection , connection .cursor () as cursor :
763
747
cursor .execute (
764
748
"SELECT usename "
765
749
"FROM pg_catalog.pg_user "
@@ -795,18 +779,23 @@ def set_up_database(self, temp_location: Optional[str] = None) -> None:
795
779
connection = None
796
780
cursor = None
797
781
try :
798
- with (
799
- self ._connect_to_database (database = "template1" ) as connection ,
800
- connection .cursor () as cursor ,
801
- ):
802
- if temp_location is not None :
803
- cursor .execute ("SELECT TRUE FROM pg_tablespace WHERE spcname='temp';" )
804
- if cursor .fetchone () is None :
805
- cursor .execute (f"CREATE TABLESPACE temp LOCATION '{ temp_location } ';" )
806
- cursor .execute ("GRANT CREATE ON TABLESPACE temp TO public;" )
807
- cursor .execute (
808
- "SELECT TRUE FROM pg_roles WHERE rolname='charmed_databases_owner';"
809
- )
782
+ connection = self ._connect_to_database ()
783
+ cursor = connection .cursor ()
784
+
785
+ if temp_location is not None :
786
+ cursor .execute ("SELECT TRUE FROM pg_tablespace WHERE spcname='temp';" )
787
+ if cursor .fetchone () is None :
788
+ cursor .execute (f"CREATE TABLESPACE temp LOCATION '{ temp_location } ';" )
789
+ cursor .execute ("GRANT CREATE ON TABLESPACE temp TO public;" )
790
+
791
+ cursor .close ()
792
+ cursor = None
793
+ connection .close ()
794
+ connection = None
795
+
796
+ with self ._connect_to_database (
797
+ database = "template1"
798
+ ) as connection , connection .cursor () as cursor :
810
799
if cursor .fetchone () is None :
811
800
self .create_user (
812
801
"charmed_databases_owner" ,
@@ -890,6 +879,10 @@ def set_up_database(self, temp_location: Optional[str] = None) -> None:
890
879
WHEN TAG IN ('DROP SCHEMA')
891
880
EXECUTE FUNCTION update_pg_hba();
892
881
""" )
882
+
883
+ connection .close ()
884
+ connection = None
885
+
893
886
with self ._connect_to_database () as connection , connection .cursor () as cursor :
894
887
cursor .execute ("REVOKE ALL PRIVILEGES ON DATABASE postgres FROM PUBLIC;" )
895
888
cursor .execute ("REVOKE CREATE ON SCHEMA public FROM PUBLIC;" )
@@ -956,12 +949,9 @@ def set_up_login_hook_function(self) -> None:
956
949
$$ LANGUAGE plpgsql;"""
957
950
try :
958
951
for database in ["postgres" , "template1" ]:
959
- with (
960
- self ._connect_to_database (
961
- database = database ,
962
- ) as connection ,
963
- connection .cursor () as cursor ,
964
- ):
952
+ with self ._connect_to_database (
953
+ database = database ,
954
+ ) as connection , connection .cursor () as cursor :
965
955
cursor .execute (SQL ("CREATE EXTENSION IF NOT EXISTS login_hook;" ))
966
956
cursor .execute (SQL ("CREATE SCHEMA IF NOT EXISTS login_hook;" ))
967
957
cursor .execute (SQL (function_creation_statement ))
@@ -1041,10 +1031,9 @@ def set_up_predefined_catalog_roles_function(self) -> None:
1041
1031
$$ LANGUAGE plpgsql security definer;"""
1042
1032
try :
1043
1033
for database in ["postgres" , "template1" ]:
1044
- with (
1045
- self ._connect_to_database (database = database ) as connection ,
1046
- connection .cursor () as cursor ,
1047
- ):
1034
+ with self ._connect_to_database (
1035
+ database = database
1036
+ ) as connection , connection .cursor () as cursor :
1048
1037
cursor .execute (SQL (function_creation_statement ))
1049
1038
cursor .execute (
1050
1039
SQL ("ALTER FUNCTION set_up_predefined_catalog_roles OWNER TO operator;" )
@@ -1073,10 +1062,9 @@ def update_user_password(
1073
1062
"""
1074
1063
connection = None
1075
1064
try :
1076
- with (
1077
- self ._connect_to_database (database_host = database_host ) as connection ,
1078
- connection .cursor () as cursor ,
1079
- ):
1065
+ with self ._connect_to_database (
1066
+ database_host = database_host
1067
+ ) as connection , connection .cursor () as cursor :
1080
1068
cursor .execute (SQL ("BEGIN;" ))
1081
1069
cursor .execute (SQL ("SET LOCAL log_statement = 'none';" ))
1082
1070
cursor .execute (
@@ -1213,10 +1201,9 @@ def validate_date_style(self, date_style: str) -> bool:
1213
1201
Whether the date style is valid.
1214
1202
"""
1215
1203
try :
1216
- with (
1217
- self ._connect_to_database (database_host = self .current_host ) as connection ,
1218
- connection .cursor () as cursor ,
1219
- ):
1204
+ with self ._connect_to_database (
1205
+ database_host = self .current_host
1206
+ ) as connection , connection .cursor () as cursor :
1220
1207
cursor .execute (
1221
1208
SQL (
1222
1209
"SET DateStyle to {};" ,
0 commit comments