diff --git a/poetry.lock b/poetry.lock
index 4cc08e3f1d..b38c58fd6f 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1712,6 +1712,7 @@ files = [
     {file = "psycopg2-2.9.10-cp311-cp311-win_amd64.whl", hash = "sha256:0435034157049f6846e95103bd8f5a668788dd913a7c30162ca9503fdf542cb4"},
     {file = "psycopg2-2.9.10-cp312-cp312-win32.whl", hash = "sha256:65a63d7ab0e067e2cdb3cf266de39663203d38d6a8ed97f5ca0cb315c73fe067"},
     {file = "psycopg2-2.9.10-cp312-cp312-win_amd64.whl", hash = "sha256:4a579d6243da40a7b3182e0430493dbd55950c493d8c68f4eec0b302f6bbf20e"},
+    {file = "psycopg2-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:91fd603a2155da8d0cfcdbf8ab24a2d54bca72795b90d2a3ed2b6da8d979dee2"},
     {file = "psycopg2-2.9.10-cp39-cp39-win32.whl", hash = "sha256:9d5b3b94b79a844a986d029eee38998232451119ad653aea42bb9220a8c5066b"},
     {file = "psycopg2-2.9.10-cp39-cp39-win_amd64.whl", hash = "sha256:88138c8dedcbfa96408023ea2b0c369eda40fe5d75002c0964c78f46f11fa442"},
     {file = "psycopg2-2.9.10.tar.gz", hash = "sha256:12ec0b40b0273f95296233e8750441339298e6a572f7039da5b260e3c8b60e11"},
@@ -1772,6 +1773,7 @@ files = [
     {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"},
     {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"},
     {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"},
+    {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"},
     {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"},
     {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"},
     {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"},
@@ -2913,4 +2915,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.1"
 python-versions = "^3.12"
-content-hash = "deabc8b622eca4ea83ed6dc3262d3a069c68ce8ec6ec6a2d30a0419ab98c59a4"
+content-hash = "bb8005256018de79c03521150add72c590afdb6b7e5ba1c3dfac529968af0440"
diff --git a/pyproject.toml b/pyproject.toml
index 5199a4a62c..55a57aa4b5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ jinja2 = "^3.1.6"
 pysyncobj = "^0.3.14"
 psutil = "^7.0.0"
 charm-refresh = "^3.0.0.3"
+httpx = "^0.28.1"
 
 [tool.poetry.group.charm-libs.dependencies]
 # data_platform_libs/v0/data_interfaces.py
diff --git a/scripts/cluster_topology_observer.py b/scripts/cluster_topology_observer.py
index e2fcb1c3e8..f1a38135fc 100644
--- a/scripts/cluster_topology_observer.py
+++ b/scripts/cluster_topology_observer.py
@@ -6,8 +6,10 @@
 import json
 import subprocess
 import sys
+from asyncio import as_completed, get_running_loop, run, wait
+from contextlib import suppress
 from os import environ
-from ssl import CERT_NONE, create_default_context
+from ssl import create_default_context
 from time import sleep
 from urllib.parse import urljoin
 from urllib.request import urlopen
@@ -16,6 +18,10 @@
 
 API_REQUEST_TIMEOUT = 5
 PATRONI_CLUSTER_STATUS_ENDPOINT = "cluster"
+TLS_CA_BUNDLE_FILE = "peer_ca_bundle.pem"
+SNAP_CURRENT_PATH = "/var/snap/charmed-postgresql/current"
"/var/snap/charmed-postgresql/current" +SNAP_CONF_PATH = f"{SNAP_CURRENT_PATH}/etc" +PATRONI_CONF_PATH = f"{SNAP_CONF_PATH}/patroni" # File path for the spawned cluster topology observer process to write logs. LOG_FILE_PATH = "/var/log/cluster_topology_observer.log" @@ -25,6 +31,20 @@ class UnreachableUnitsError(Exception): """Cannot reach any known cluster member.""" +def call_url(url, context): + """Task handler for calling an url.""" + try: + # Scheme is generated by the charm + resp = urlopen( # noqa: S310 + url, + timeout=API_REQUEST_TIMEOUT, + context=context, + ) + return json.loads(resp.read()) + except Exception as e: + print(f"Failed to contact {url} with {e}") + + def check_for_authorisation_rules_changes(run_cmd, unit, charm_dir, previous_authorisation_rules): """Check for changes in the authorisation rules. @@ -120,7 +140,7 @@ def dispatch(run_cmd, unit, charm_dir, custom_event): subprocess.run([run_cmd, "-u", unit, dispatch_sub_cmd.format(custom_event, charm_dir)]) # noqa: S603 -def main(): +async def main(): """Main watch and dispatch loop. Watch the Patroni API cluster info. When changes are detected, dispatch the change event. @@ -135,23 +155,19 @@ def main(): while True: # Disable TLS chain verification context = create_default_context() - context.check_hostname = False - context.verify_mode = CERT_NONE + with suppress(FileNotFoundError): + context.load_verify_locations(cafile=f"{PATRONI_CONF_PATH}/{TLS_CA_BUNDLE_FILE}") cluster_status = None - for url in urls: - try: - # Scheme is generated by the charm - resp = urlopen( # noqa: S310 - url, - timeout=API_REQUEST_TIMEOUT, - context=context, - ) - cluster_status = json.loads(resp.read()) + loop = get_running_loop() + tasks = [loop.run_in_executor(None, call_url, url, context) for url in urls] + for task in as_completed(tasks): + if result := await task: + for task in tasks: + task.cancel() + await wait(tasks) + cluster_status = result break - except Exception as e: - print(f"Failed to contact {url} with {e}") - continue if not cluster_status: raise UnreachableUnitsError("Unable to reach cluster members") current_cluster_topology = {} @@ -186,4 +202,4 @@ def main(): if __name__ == "__main__": - main() + run(main()) diff --git a/src/charm.py b/src/charm.py index d85822413c..8848151e04 100755 --- a/src/charm.py +++ b/src/charm.py @@ -130,6 +130,7 @@ logger = logging.getLogger(__name__) logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) +logging.getLogger("asyncio").setLevel(logging.WARNING) PRIMARY_NOT_REACHABLE_MESSAGE = "waiting for primary to be reachable from this unit" EXTENSIONS_DEPENDENCY_MESSAGE = "Unsatisfied plugin dependencies. 
Please check the logs" diff --git a/src/cluster.py b/src/cluster.py index 5160c0ebaa..bba9191ac2 100644 --- a/src/cluster.py +++ b/src/cluster.py @@ -11,18 +11,21 @@ import re import shutil import subprocess +from asyncio import as_completed, create_task, run, wait +from contextlib import suppress from pathlib import Path +from ssl import CERT_NONE, create_default_context from typing import TYPE_CHECKING, Any, TypedDict import charm_refresh import psutil import requests from charms.operator_libs_linux.v2 import snap +from httpx import AsyncClient, BasicAuth, HTTPError from jinja2 import Template from ops import BlockedStatus from pysyncobj.utility import TcpUtility, UtilityException from tenacity import ( - AttemptManager, RetryError, Retrying, retry, @@ -172,6 +175,10 @@ def __init__( def _patroni_auth(self) -> requests.auth.HTTPBasicAuth: return requests.auth.HTTPBasicAuth("patroni", self.patroni_password) + @property + def _patroni_async_auth(self) -> BasicAuth: + return BasicAuth("patroni", password=self.patroni_password) + @property def _patroni_url(self) -> str: """Patroni REST API URL.""" @@ -249,28 +256,14 @@ def get_postgresql_version(self) -> str: if snp["name"] == charm_refresh.snap_name(): return snp["version"] - def cluster_status( - self, alternative_endpoints: list | None = None - ) -> list[ClusterMember] | None: + def cluster_status(self, alternative_endpoints: list | None = None) -> list[ClusterMember]: """Query the cluster status.""" # Request info from cluster endpoint (which returns all members of the cluster). - # TODO we don't know the other cluster's ca - verify = self.verify if not alternative_endpoints else False - for attempt in Retrying( - stop=stop_after_attempt( - len(alternative_endpoints) if alternative_endpoints else len(self.peers_ips) - ) + if response := self.parallel_patroni_get_request( + f"/{PATRONI_CLUSTER_STATUS_ENDPOINT}", alternative_endpoints ): - with attempt: - request_url = self._get_alternative_patroni_url(attempt, alternative_endpoints) - - cluster_status = requests.get( - f"{request_url}/{PATRONI_CLUSTER_STATUS_ENDPOINT}", - verify=verify, - timeout=API_REQUEST_TIMEOUT, - auth=self._patroni_auth, - ) - return cluster_status.json()["members"] + return response["members"] + raise RetryError(last_attempt=Exception("Unable to reach any units")) def get_member_ip(self, member_name: str) -> str | None: """Get cluster member IP address. @@ -281,13 +274,14 @@ def get_member_ip(self, member_name: str) -> str | None: Returns: IP address of the cluster member. """ - cluster_status = self.cluster_status() - if not cluster_status: - return + try: + cluster_status = self.cluster_status() - for member in cluster_status: - if member["name"] == member_name: - return member["host"] + for member in cluster_status: + if member["name"] == member_name: + return member["host"] + except RetryError: + logger.debug("Unable to get IP. Cluster status unreachable") def get_member_status(self, member_name: str) -> str: """Get cluster member status. 
@@ -307,6 +301,44 @@ def get_member_status(self, member_name: str) -> str:
                 return member["state"]
         return ""
 
+    async def _httpx_get_request(self, url: str, verify: bool = True):
+        ssl_ctx = create_default_context()
+        if verify:
+            with suppress(FileNotFoundError):
+                ssl_ctx.load_verify_locations(cafile=f"{PATRONI_CONF_PATH}/{TLS_CA_BUNDLE_FILE}")
+        else:
+            ssl_ctx.check_hostname = False
+            ssl_ctx.verify_mode = CERT_NONE
+        async with AsyncClient(
+            auth=self._patroni_async_auth, timeout=API_REQUEST_TIMEOUT, verify=ssl_ctx
+        ) as client:
+            try:
+                return (await client.get(url)).json()
+            except (HTTPError, ValueError):
+                return None
+
+    async def _async_get_request(self, uri: str, endpoints: list[str], verify: bool = True):
+        tasks = [
+            create_task(self._httpx_get_request(f"https://{ip}:8008{uri}", verify))
+            for ip in endpoints
+        ]
+        for task in as_completed(tasks):
+            if result := await task:
+                for task in tasks:
+                    task.cancel()
+                await wait(tasks)
+                return result
+
+    def parallel_patroni_get_request(self, uri: str, endpoints: list[str] | None = None) -> dict:
+        """Call all possible patroni endpoints in parallel."""
+        if not endpoints:
+            endpoints = (self.unit_ip, *self.peers_ips)
+            verify = True
+        else:
+            # TODO we don't know the other cluster's ca
+            verify = False
+        return run(self._async_get_request(uri, endpoints, verify))
+
     def get_primary(
         self, unit_name_pattern=False, alternative_endpoints: list[str] | None = None
     ) -> str | None:
@@ -320,7 +352,8 @@ def get_primary(
             primary pod or unit name.
         """
         # Request info from cluster endpoint (which returns all members of the cluster).
-        if cluster_status := self.cluster_status(alternative_endpoints):
+        try:
+            cluster_status = self.cluster_status(alternative_endpoints)
             for member in cluster_status:
                 if member["role"] == "leader":
                     primary = member["name"]
@@ -328,6 +361,8 @@ def get_primary(
                         # Change the last dash to / in order to match unit name pattern.
                         primary = label2name(primary)
                     return primary
+        except RetryError:
+            logger.debug("Unable to get primary. Cluster status unreachable")
 
     def get_standby_leader(
         self, unit_name_pattern=False, check_whether_is_running: bool = False
@@ -366,31 +401,6 @@ def get_sync_standby_names(self) -> list[str]:
                 sync_standbys.append(label2name(member["name"]))
         return sync_standbys
 
-    def _get_alternative_patroni_url(
-        self, attempt: AttemptManager, alternative_endpoints: list[str] | None = None
-    ) -> str:
-        """Get an alternative REST API URL from another member each time.
-
-        When the Patroni process is not running in the current unit it's needed
-        to use a URL from another cluster member REST API to do some operations.
-        """
-        if alternative_endpoints is not None:
-            return self._patroni_url.replace(
-                self.unit_ip, alternative_endpoints[attempt.retry_state.attempt_number - 1]
-            )
-        attempt_number = attempt.retry_state.attempt_number
-        if attempt_number > 1:
-            url = self._patroni_url
-            if (attempt_number - 1) <= len(self.peers_ips):
-                unit_number = attempt_number - 2
-            else:
-                unit_number = attempt_number - 2 - len(self.peers_ips)
-            other_unit_ip = list(self.peers_ips)[unit_number]
-            url = url.replace(self.unit_ip, other_unit_ip)
-        else:
-            url = self._patroni_url
-        return url
-
     def are_all_members_ready(self) -> bool:
         """Check if all members are correctly running Patroni and PostgreSQL.
diff --git a/tests/unit/test_charm.py b/tests/unit/test_charm.py
index 287a591891..caec405a04 100644
--- a/tests/unit/test_charm.py
+++ b/tests/unit/test_charm.py
@@ -602,7 +602,7 @@ def test_on_start(harness):
         patch(
             "charm.PostgresqlOperatorCharm._is_storage_attached",
             side_effect=[False, True, True, True, True, True],
-        ) as _is_storage_attached,
+        ),
         patch(
             "charm.PostgresqlOperatorCharm._can_connect_to_postgresql",
             new_callable=PropertyMock,
diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py
index 1d4afe7612..965b200db2 100644
--- a/tests/unit/test_cluster.py
+++ b/tests/unit/test_cluster.py
@@ -11,8 +11,6 @@
 from ops.testing import Harness
 from pysyncobj.utility import UtilityException
 from tenacity import (
-    AttemptManager,
-    RetryCallState,
     RetryError,
     Retrying,
     stop_after_delay,
@@ -93,47 +91,28 @@ def patroni(harness, peers_ips):
         yield patroni
 
 
-def test_get_alternative_patroni_url(peers_ips, patroni):
-    # Mock tenacity attempt.
-    retry = Retrying()
-    retry_state = RetryCallState(retry, None, None, None)
-    attempt = AttemptManager(retry_state)
-
-    # Test the first URL that is returned (it should have the current unit IP).
-    url = patroni._get_alternative_patroni_url(attempt)
-    assert url == f"https://{patroni.unit_ip}:8008"
-
-    # Test returning the other servers URLs.
-    for attempt_number in range(attempt.retry_state.attempt_number + 1, len(peers_ips) + 2):
-        attempt.retry_state.attempt_number = attempt_number
-        url = patroni._get_alternative_patroni_url(attempt)
-        assert url.split("https://")[1].split(":8008")[0] in peers_ips
-
-
 def test_get_member_ip(peers_ips, patroni):
     with (
-        patch("requests.get", side_effect=mocked_requests_get),
-        patch("charm.Patroni._patroni_url", new_callable=PropertyMock) as _patroni_url,
+        patch(
+            "charm.Patroni.parallel_patroni_get_request", return_value=None
+        ) as _parallel_patroni_get_request,
     ):
-        # Test error on trying to get the member IP.
-        _patroni_url.return_value = "http://server2"
-        with pytest.raises(RetryError):
-            patroni.get_member_ip(patroni.member_name)
-            assert False
-
-        # Test using an alternative Patroni URL.
-        _patroni_url.return_value = "http://server1"
-
-        ip = patroni.get_member_ip(patroni.member_name)
-        assert ip == "1.1.1.1"
-
-        # Test using the current Patroni URL.
-        ip = patroni.get_member_ip(patroni.member_name)
-        assert ip == "1.1.1.1"
+        # No IP if no members
+        assert patroni.get_member_ip(patroni.member_name) is None
 
-        # Test when not having that specific member in the cluster.
-        ip = patroni.get_member_ip("other-member-name")
-        assert ip is None
+        _parallel_patroni_get_request.return_value = {
+            "members": [
+                {
+                    "name": "postgresql-1",
+                    "host": "2.2.2.2",
+                },
+                {
+                    "name": "postgresql-0",
+                    "host": "1.1.1.1",
+                },
+            ]
+        }
+        assert patroni.get_member_ip(patroni.member_name) == "1.1.1.1"
 
 
 def test_get_patroni_health(peers_ips, patroni):
@@ -195,24 +174,30 @@ def test_dict_to_hba_string(harness, patroni):
 
 def test_get_primary(peers_ips, patroni):
     with (
-        patch("requests.get", side_effect=mocked_requests_get),
-        patch("charm.Patroni._patroni_url", new_callable=PropertyMock) as _patroni_url,
+        patch(
+            "charm.Patroni.parallel_patroni_get_request", return_value=None
+        ) as _parallel_patroni_get_request,
    ):
-        # Test error on trying to get the member IP.
-        _patroni_url.return_value = "http://server2"
-        with pytest.raises(RetryError):
-            patroni.get_primary(patroni.member_name)
-            assert False
+        # No primary if no members
+        assert patroni.get_primary() is None
 
+        _parallel_patroni_get_request.return_value = {
+            "members": [
+                {
+                    "name": "postgresql-1",
+                    "role": "replica",
+                },
+                {
+                    "name": "postgresql-0",
+                    "role": "leader",
+                },
+            ]
+        }
         # Test using the current Patroni URL.
-        _patroni_url.return_value = "http://server1"
-        primary = patroni.get_primary()
-        assert primary == "postgresql-0"
+        assert patroni.get_primary() == "postgresql-0"
 
         # Test requesting the primary in the unit name pattern.
-        _patroni_url.return_value = "http://server1"
-        primary = patroni.get_primary(unit_name_pattern=True)
-        assert primary == "postgresql/0"
+        assert patroni.get_primary(unit_name_pattern=True) == "postgresql/0"
 
 
 def test_is_creating_backup(peers_ips, patroni):
@@ -238,6 +223,7 @@ def test_is_replication_healthy(peers_ips, patroni):
     with (
         patch("requests.get") as _get,
         patch("charm.Patroni.get_primary"),
+        patch("charm.Patroni.get_member_ip"),
         patch("cluster.stop_after_delay", return_value=stop_after_delay(0)),
     ):
         # Test when replication is healthy.
diff --git a/tests/unit/test_cluster_topology_observer.py b/tests/unit/test_cluster_topology_observer.py
index 97fe87d6b9..75ca5c5de3 100644
--- a/tests/unit/test_cluster_topology_observer.py
+++ b/tests/unit/test_cluster_topology_observer.py
@@ -3,7 +3,7 @@
 import signal
 import sys
 from json import dumps
-from unittest.mock import Mock, PropertyMock, call, mock_open, patch, sentinel
+from unittest.mock import Mock, PropertyMock, mock_open, patch
 
 import pytest
 from ops.charm import CharmBase
@@ -145,7 +145,7 @@ def test_dispatch(harness):
     ])
 
 
-def test_main():
+async def test_main():
     with (
         patch("scripts.cluster_topology_observer.check_for_database_changes"),
         patch("scripts.cluster_topology_observer.check_for_authorisation_rules_changes"),
@@ -157,10 +157,7 @@ def test_main():
         patch("scripts.cluster_topology_observer.sleep", return_value=None),
         patch("scripts.cluster_topology_observer.urlopen") as _urlopen,
         patch("scripts.cluster_topology_observer.subprocess") as _subprocess,
-        patch(
-            "scripts.cluster_topology_observer.create_default_context",
-            return_value=sentinel.sslcontext,
-        ),
+        patch("scripts.cluster_topology_observer.create_default_context") as _context,
     ):
         response1 = {
             "members": [
@@ -179,16 +176,13 @@ def test_main():
         mock2.read.return_value = dumps(response2)
         _urlopen.side_effect = [mock1, Exception, mock2]
         with pytest.raises(UnreachableUnitsError):
-            main()
-        assert _urlopen.call_args_list == [
-            # Iteration 1. server2 is not called
-            call("http://server1:8008/cluster", timeout=5, context=sentinel.sslcontext),
-            # Iteration 2 local unit server1 is called first
-            call("http://server1:8008/cluster", timeout=5, context=sentinel.sslcontext),
-            call("http://server3:8008/cluster", timeout=5, context=sentinel.sslcontext),
-            # Iteration 3 Last known member is server3
-            call("https://server3:8008/cluster", timeout=5, context=sentinel.sslcontext),
-        ]
+            await main()
+        _urlopen.assert_any_call(
+            "http://server1:8008/cluster", timeout=5, context=_context.return_value
+        )
+        _urlopen.assert_any_call(
+            "http://server3:8008/cluster", timeout=5, context=_context.return_value
+        )
 
         _subprocess.run.assert_called_once_with([
             "run_cmd",