Skip to content

Commit d0730e2

Browse files
committed
Temp tblspace outside of transaction
1 parent 91436cf commit d0730e2

File tree

1 file changed

+69
-79
lines changed

1 file changed

+69
-79
lines changed

lib/charms/postgresql_k8s/v1/postgresql.py

Lines changed: 69 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -331,12 +331,9 @@ def create_predefined_instance_roles(self) -> None:
331331
connection = None
332332
try:
333333
for database in ["postgres", "template1"]:
334-
with (
335-
self._connect_to_database(
336-
database=database,
337-
) as connection,
338-
connection.cursor() as cursor,
339-
):
334+
with self._connect_to_database(
335+
database=database,
336+
) as connection, connection.cursor() as cursor:
340337
cursor.execute(SQL("CREATE EXTENSION IF NOT EXISTS set_user;"))
341338
finally:
342339
if connection is not None:
@@ -421,10 +418,9 @@ def delete_user(self, user: str) -> None:
421418
# Existing objects need to be reassigned in each database
422419
# before the user can be deleted.
423420
for database in databases:
424-
with (
425-
self._connect_to_database(database) as connection,
426-
connection.cursor() as cursor,
427-
):
421+
with self._connect_to_database(
422+
database
423+
) as connection, connection.cursor() as cursor:
428424
cursor.execute(
429425
SQL("REASSIGN OWNED BY {} TO {};").format(
430426
Identifier(user), Identifier(self.user)
@@ -515,10 +511,9 @@ def enable_disable_extensions(
515511

516512
# Enable/disabled the extension in each database.
517513
for database in databases:
518-
with (
519-
self._connect_to_database(database=database) as connection,
520-
connection.cursor() as cursor,
521-
):
514+
with self._connect_to_database(
515+
database=database
516+
) as connection, connection.cursor() as cursor:
522517
for extension, enable in ordered_extensions.items():
523518
cursor.execute(
524519
f"CREATE EXTENSION IF NOT EXISTS {extension};"
@@ -562,10 +557,9 @@ def get_postgresql_text_search_configs(self) -> Set[str]:
562557
Returns:
563558
Set of PostgreSQL text search configs.
564559
"""
565-
with (
566-
self._connect_to_database(database_host=self.current_host) as connection,
567-
connection.cursor() as cursor,
568-
):
560+
with self._connect_to_database(
561+
database_host=self.current_host
562+
) as connection, connection.cursor() as cursor:
569563
cursor.execute("SELECT CONCAT('pg_catalog.', cfgname) FROM pg_ts_config;")
570564
text_search_configs = cursor.fetchall()
571565
return {text_search_config[0] for text_search_config in text_search_configs}
@@ -576,10 +570,9 @@ def get_postgresql_timezones(self) -> Set[str]:
576570
Returns:
577571
Set of PostgreSQL timezones.
578572
"""
579-
with (
580-
self._connect_to_database(database_host=self.current_host) as connection,
581-
connection.cursor() as cursor,
582-
):
573+
with self._connect_to_database(
574+
database_host=self.current_host
575+
) as connection, connection.cursor() as cursor:
583576
cursor.execute("SELECT name FROM pg_timezone_names;")
584577
timezones = cursor.fetchall()
585578
return {timezone[0] for timezone in timezones}
@@ -590,10 +583,9 @@ def get_postgresql_default_table_access_methods(self) -> Set[str]:
590583
Returns:
591584
Set of PostgreSQL table access methods.
592585
"""
593-
with (
594-
self._connect_to_database(database_host=self.current_host) as connection,
595-
connection.cursor() as cursor,
596-
):
586+
with self._connect_to_database(
587+
database_host=self.current_host
588+
) as connection, connection.cursor() as cursor:
597589
cursor.execute("SELECT amname FROM pg_am WHERE amtype = 't';")
598590
access_methods = cursor.fetchall()
599591
return {access_method[0] for access_method in access_methods}
@@ -606,10 +598,9 @@ def get_postgresql_version(self, current_host=True) -> str:
606598
"""
607599
host = self.current_host if current_host else None
608600
try:
609-
with (
610-
self._connect_to_database(database_host=host) as connection,
611-
connection.cursor() as cursor,
612-
):
601+
with self._connect_to_database(
602+
database_host=host
603+
) as connection, connection.cursor() as cursor:
613604
cursor.execute("SELECT version();")
614605
# Split to get only the version number.
615606
return cursor.fetchone()[0].split(" ")[1]
@@ -628,12 +619,9 @@ def is_tls_enabled(self, check_current_host: bool = False) -> bool:
628619
whether TLS is enabled.
629620
"""
630621
try:
631-
with (
632-
self._connect_to_database(
633-
database_host=self.current_host if check_current_host else None
634-
) as connection,
635-
connection.cursor() as cursor,
636-
):
622+
with self._connect_to_database(
623+
database_host=self.current_host if check_current_host else None
624+
) as connection, connection.cursor() as cursor:
637625
cursor.execute("SHOW ssl;")
638626
return "on" in cursor.fetchone()[0]
639627
except psycopg2.Error:
@@ -653,10 +641,9 @@ def list_access_groups(self, current_host=False) -> Set[str]:
653641
connection = None
654642
host = self.current_host if current_host else None
655643
try:
656-
with (
657-
self._connect_to_database(database_host=host) as connection,
658-
connection.cursor() as cursor,
659-
):
644+
with self._connect_to_database(
645+
database_host=host
646+
) as connection, connection.cursor() as cursor:
660647
cursor.execute(
661648
"SELECT groname FROM pg_catalog.pg_group WHERE groname LIKE '%_access';"
662649
)
@@ -684,10 +671,9 @@ def list_accessible_databases_for_user(self, user: str, current_host=False) -> S
684671
connection = None
685672
host = self.current_host if current_host else None
686673
try:
687-
with (
688-
self._connect_to_database(database_host=host) as connection,
689-
connection.cursor() as cursor,
690-
):
674+
with self._connect_to_database(
675+
database_host=host
676+
) as connection, connection.cursor() as cursor:
691677
cursor.execute(
692678
SQL(
693679
"SELECT TRUE FROM pg_catalog.pg_user WHERE usename = {} AND usesuper;"
@@ -723,10 +709,9 @@ def list_users(self, group: Optional[str] = None, current_host=False) -> Set[str
723709
connection = None
724710
host = self.current_host if current_host else None
725711
try:
726-
with (
727-
self._connect_to_database(database_host=host) as connection,
728-
connection.cursor() as cursor,
729-
):
712+
with self._connect_to_database(
713+
database_host=host
714+
) as connection, connection.cursor() as cursor:
730715
if group:
731716
query = SQL(
732717
"SELECT usename FROM (SELECT UNNEST(grolist) AS user_id FROM pg_catalog.pg_group WHERE groname = {}) AS g JOIN pg_catalog.pg_user AS u ON g.user_id = u.usesysid;"
@@ -756,10 +741,9 @@ def list_users_from_relation(self, current_host=False) -> Set[str]:
756741
connection = None
757742
host = self.current_host if current_host else None
758743
try:
759-
with (
760-
self._connect_to_database(database_host=host) as connection,
761-
connection.cursor() as cursor,
762-
):
744+
with self._connect_to_database(
745+
database_host=host
746+
) as connection, connection.cursor() as cursor:
763747
cursor.execute(
764748
"SELECT usename "
765749
"FROM pg_catalog.pg_user "
@@ -795,15 +779,23 @@ def set_up_database(self, temp_location: Optional[str] = None) -> None:
795779
connection = None
796780
cursor = None
797781
try:
798-
with (
799-
self._connect_to_database(database="template1") as connection,
800-
connection.cursor() as cursor,
801-
):
802-
if temp_location is not None:
803-
cursor.execute("SELECT TRUE FROM pg_tablespace WHERE spcname='temp';")
804-
if cursor.fetchone() is None:
805-
cursor.execute(f"CREATE TABLESPACE temp LOCATION '{temp_location}';")
806-
cursor.execute("GRANT CREATE ON TABLESPACE temp TO public;")
782+
connection = self._connect_to_database()
783+
cursor = connection.cursor()
784+
785+
if temp_location is not None:
786+
cursor.execute("SELECT TRUE FROM pg_tablespace WHERE spcname='temp';")
787+
if cursor.fetchone() is None:
788+
cursor.execute(f"CREATE TABLESPACE temp LOCATION '{temp_location}';")
789+
cursor.execute("GRANT CREATE ON TABLESPACE temp TO public;")
790+
791+
cursor.close()
792+
cursor = None
793+
connection.close()
794+
connection = None
795+
796+
with self._connect_to_database(
797+
database="template1"
798+
) as connection, connection.cursor() as cursor:
807799
cursor.execute(
808800
"SELECT TRUE FROM pg_roles WHERE rolname='charmed_databases_owner';"
809801
)
@@ -890,6 +882,10 @@ def set_up_database(self, temp_location: Optional[str] = None) -> None:
890882
WHEN TAG IN ('DROP SCHEMA')
891883
EXECUTE FUNCTION update_pg_hba();
892884
""")
885+
886+
connection.close()
887+
connection = None
888+
893889
with self._connect_to_database() as connection, connection.cursor() as cursor:
894890
cursor.execute("REVOKE ALL PRIVILEGES ON DATABASE postgres FROM PUBLIC;")
895891
cursor.execute("REVOKE CREATE ON SCHEMA public FROM PUBLIC;")
@@ -956,12 +952,9 @@ def set_up_login_hook_function(self) -> None:
956952
$$ LANGUAGE plpgsql;"""
957953
try:
958954
for database in ["postgres", "template1"]:
959-
with (
960-
self._connect_to_database(
961-
database=database,
962-
) as connection,
963-
connection.cursor() as cursor,
964-
):
955+
with self._connect_to_database(
956+
database=database,
957+
) as connection, connection.cursor() as cursor:
965958
cursor.execute(SQL("CREATE EXTENSION IF NOT EXISTS login_hook;"))
966959
cursor.execute(SQL("CREATE SCHEMA IF NOT EXISTS login_hook;"))
967960
cursor.execute(SQL(function_creation_statement))
@@ -1041,10 +1034,9 @@ def set_up_predefined_catalog_roles_function(self) -> None:
10411034
$$ LANGUAGE plpgsql security definer;"""
10421035
try:
10431036
for database in ["postgres", "template1"]:
1044-
with (
1045-
self._connect_to_database(database=database) as connection,
1046-
connection.cursor() as cursor,
1047-
):
1037+
with self._connect_to_database(
1038+
database=database
1039+
) as connection, connection.cursor() as cursor:
10481040
cursor.execute(SQL(function_creation_statement))
10491041
cursor.execute(
10501042
SQL("ALTER FUNCTION set_up_predefined_catalog_roles OWNER TO operator;")
@@ -1073,10 +1065,9 @@ def update_user_password(
10731065
"""
10741066
connection = None
10751067
try:
1076-
with (
1077-
self._connect_to_database(database_host=database_host) as connection,
1078-
connection.cursor() as cursor,
1079-
):
1068+
with self._connect_to_database(
1069+
database_host=database_host
1070+
) as connection, connection.cursor() as cursor:
10801071
cursor.execute(SQL("BEGIN;"))
10811072
cursor.execute(SQL("SET LOCAL log_statement = 'none';"))
10821073
cursor.execute(
@@ -1213,10 +1204,9 @@ def validate_date_style(self, date_style: str) -> bool:
12131204
Whether the date style is valid.
12141205
"""
12151206
try:
1216-
with (
1217-
self._connect_to_database(database_host=self.current_host) as connection,
1218-
connection.cursor() as cursor,
1219-
):
1207+
with self._connect_to_database(
1208+
database_host=self.current_host
1209+
) as connection, connection.cursor() as cursor:
12201210
cursor.execute(
12211211
SQL(
12221212
"SET DateStyle to {};",

0 commit comments

Comments
 (0)