Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,233 changes: 917 additions & 316 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jinja2 = "^3.1.6"
pysyncobj = "^0.3.14"
psutil = "^7.0.0"
charm-refresh = "^3.0.0.1"
aiohttp = "^3.12.1"

[tool.poetry.group.charm-libs.dependencies]
# data_platform_libs/v0/data_interfaces.py
Expand Down
46 changes: 32 additions & 14 deletions scripts/cluster_topology_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@
import json
import subprocess
import sys
from asyncio import as_completed, get_running_loop, run, wait
from contextlib import suppress
from ssl import CERT_NONE, create_default_context
from time import sleep
from urllib.parse import urljoin
from urllib.request import urlopen

API_REQUEST_TIMEOUT = 5
PATRONI_CLUSTER_STATUS_ENDPOINT = "cluster"
TLS_CA_FILE = "ca.pem"
SNAP_CURRENT_PATH = "/var/snap/charmed-postgresql/current"
SNAP_CONF_PATH = f"{SNAP_CURRENT_PATH}/etc"
PATRONI_CONF_PATH = f"{SNAP_CONF_PATH}/patroni"

# File path for the spawned cluster topology observer process to write logs.
LOG_FILE_PATH = "/var/log/cluster_topology_observer.log"
Expand All @@ -22,14 +28,28 @@ class UnreachableUnitsError(Exception):
"""Cannot reach any known cluster member."""


def call_url(url, context):
    """Request ``url`` and return its body decoded as JSON.

    Used as an executor task handler, so a failure is reported on stdout and
    signalled with ``None`` instead of an exception.

    Args:
        url: address to request; the scheme is generated by the charm.
        context: SSL context passed through to ``urlopen``.

    Returns:
        The decoded JSON document, or ``None`` if the request or the decoding failed.
    """
    try:
        # Scheme is generated by the charm
        with urlopen(  # noqa: S310
            url,
            timeout=API_REQUEST_TIMEOUT,
            context=context,
        ) as resp:
            # Leaving the ``with`` block closes the connection, even if decoding fails.
            return json.loads(resp.read())
    except Exception as e:
        print(f"Failed to contact {url} with {e}")
        return None


def dispatch(run_cmd, unit, charm_dir):
    """Dispatch a :class:`ClusterTopologyChangeEvent` through the given juju-run command.

    Args:
        run_cmd: juju-run executable to invoke.
        unit: unit name passed to the command's ``-u`` option.
        charm_dir: charm directory containing the ``dispatch`` entry point.
    """
    hook_template = "JUJU_DISPATCH_PATH=hooks/cluster_topology_change {}/dispatch"
    command = [run_cmd, "-u", unit, hook_template.format(charm_dir)]
    # Input is generated by the charm
    subprocess.run(command)  # noqa: S603


def main():
async def main():
"""Main watch and dispatch loop.

Watch the Patroni API cluster info. When changes are detected, dispatch the change event.
Expand All @@ -42,23 +62,21 @@ def main():
while True:
# Disable TLS chain verification
context = create_default_context()
with suppress(FileNotFoundError):
context.load_verify_locations(cafile=f"{PATRONI_CONF_PATH}/{TLS_CA_FILE}")
context.check_hostname = False
context.verify_mode = CERT_NONE

cluster_status = None
for url in urls:
try:
# Scheme is generated by the charm
resp = urlopen( # noqa: S310
url,
timeout=API_REQUEST_TIMEOUT,
context=context,
)
cluster_status = json.loads(resp.read())
loop = get_running_loop()
tasks = [loop.run_in_executor(None, call_url, url, context) for url in urls]
for task in as_completed(tasks):
if result := await task:
for task in tasks:
task.cancel()
await wait(tasks)
cluster_status = result
break
except Exception as e:
print(f"Failed to contact {url} with {e}")
continue
if not cluster_status:
raise UnreachableUnitsError("Unable to reach cluster members")
current_cluster_topology = {}
Expand Down Expand Up @@ -86,4 +104,4 @@ def main():


if __name__ == "__main__":
main()
run(main())
141 changes: 63 additions & 78 deletions src/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,22 @@
import pwd
import re
import shutil
import ssl
import subprocess
from asyncio import as_completed, create_task, run, wait
from contextlib import suppress
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, TypedDict

import charm_refresh
import psutil
import requests
from aiohttp import BasicAuth, ClientError, ClientSession, ClientTimeout
from charms.operator_libs_linux.v2 import snap
from jinja2 import Template
from ops import BlockedStatus
from pysyncobj.utility import TcpUtility, UtilityException
from tenacity import (
AttemptManager,
RetryError,
Retrying,
retry,
Expand Down Expand Up @@ -174,6 +177,10 @@ def __init__(
def _patroni_auth(self) -> requests.auth.HTTPBasicAuth:
return requests.auth.HTTPBasicAuth("patroni", self.patroni_password)

@property
def _patroni_async_auth(self) -> BasicAuth:
    """Basic-auth credentials for the ``patroni`` user, in aiohttp's format."""
    credentials = BasicAuth(login="patroni", password=self.patroni_password)
    return credentials

@property
def _patroni_url(self) -> str:
"""Patroni REST API URL."""
Expand Down Expand Up @@ -251,25 +258,13 @@ def get_postgresql_version(self) -> str:
if snp["name"] == charm_refresh.snap_name():
return snp["version"]

def cluster_status(
self, alternative_endpoints: Optional[list] = None
) -> Optional[list[ClusterMember]]:
def cluster_status(self, alternative_endpoints: Optional[list] = None) -> list[ClusterMember]:
"""Query the cluster status."""
# Request info from cluster endpoint (which returns all members of the cluster).
for attempt in Retrying(stop=stop_after_attempt(2 * len(self.peers_ips) + 1)):
with attempt:
if alternative_endpoints:
request_url = self._get_alternative_patroni_url(attempt, alternative_endpoints)
else:
request_url = self._patroni_url

cluster_status = requests.get(
f"{request_url}/{PATRONI_CLUSTER_STATUS_ENDPOINT}",
verify=self.verify,
timeout=API_REQUEST_TIMEOUT,
auth=self._patroni_auth,
)
return cluster_status.json()["members"]
if response := self.parallel_patroni_get_request(
f"/{PATRONI_CLUSTER_STATUS_ENDPOINT}", alternative_endpoints
):
return response["members"]
return []

def get_member_ip(self, member_name: str) -> Optional[str]:
"""Get cluster member IP address.
Expand Down Expand Up @@ -306,6 +301,42 @@ def get_member_status(self, member_name: str) -> str:
return member["state"]
return ""

async def _aiohttp_get_request(self, url):
    """GET ``url`` with the Patroni credentials and return the decoded JSON body.

    Args:
        url: full URL of the endpoint to query.

    Returns:
        The parsed JSON payload, or ``None`` when the request fails, times out,
        answers with a status code above 299, or the body is not valid JSON.
    """
    # Imported locally: before Python 3.11 asyncio.TimeoutError is not the builtin
    # TimeoutError, and aiohttp raises it when the total timeout expires.
    from asyncio import TimeoutError as AsyncTimeoutError

    ssl_ctx = ssl.create_default_context()
    # A missing CA file is tolerated; the default trust store is then used alone.
    with suppress(FileNotFoundError):
        ssl_ctx.load_verify_locations(cafile=f"{PATRONI_CONF_PATH}/{TLS_CA_FILE}")
    async with ClientSession(
        auth=self._patroni_async_auth,
        timeout=ClientTimeout(total=API_REQUEST_TIMEOUT),
    ) as session:
        try:
            async with session.get(url, ssl=ssl_ctx) as response:
                if response.status > 299:
                    # response.text() is a coroutine: it must be awaited to log the body.
                    logger.debug(
                        "Call failed with status code %s: %s",
                        response.status,
                        await response.text(),
                    )
                    return None
                return await response.json()
        except (ClientError, AsyncTimeoutError, ValueError):
            # An unreachable, slow or non-JSON endpoint counts as "no answer".
            return None

async def _async_get_request(self, uri, endpoints):
    """Query ``uri`` on every endpoint concurrently and return the first truthy answer.

    Each endpoint is tried on port 8008 over both http and https. Once one
    request yields a truthy result, the remaining requests are cancelled and
    awaited before that result is returned.

    Args:
        uri: path appended to each endpoint URL.
        endpoints: iterable of host addresses to query.

    Returns:
        The first truthy result, or ``None`` if no request produced one.
    """
    urls = [f"http://{ip}:8008{uri}" for ip in endpoints] + [
        f"https://{ip}:8008{uri}" for ip in endpoints
    ]
    requests_in_flight = [create_task(self._aiohttp_get_request(target)) for target in urls]
    for finished in as_completed(requests_in_flight):
        answer = await finished
        if not answer:
            continue
        # First usable answer wins: stop the other requests and let them settle.
        for leftover in requests_in_flight:
            leftover.cancel()
        await wait(requests_in_flight)
        return answer
    return None

def parallel_patroni_get_request(self, uri: str, endpoints: list[str] | None = None) -> dict:
"""Call all possible patroni endpoints in parallel."""
if not endpoints:
endpoints = (self.unit_ip, *self.peers_ips)
return run(self._async_get_request(uri, endpoints))

def get_primary(
self, unit_name_pattern=False, alternative_endpoints: list[str] | None = None
) -> Optional[str]:
Expand Down Expand Up @@ -341,72 +372,26 @@ def get_standby_leader(
standby leader pod or unit name.
"""
# Request info from cluster endpoint (which returns all members of the cluster).
for attempt in Retrying(stop=stop_after_attempt(2 * len(self.peers_ips) + 1)):
with attempt:
url = self._get_alternative_patroni_url(attempt)
cluster_status = requests.get(
f"{url}/{PATRONI_CLUSTER_STATUS_ENDPOINT}",
verify=self.verify,
timeout=API_REQUEST_TIMEOUT,
auth=self._patroni_auth,
)
for member in cluster_status.json()["members"]:
if member["role"] == "standby_leader":
if check_whether_is_running and member["state"] not in RUNNING_STATES:
logger.warning(f"standby leader {member['name']} is not running")
continue
standby_leader = member["name"]
if unit_name_pattern:
# Change the last dash to / in order to match unit name pattern.
standby_leader = label2name(standby_leader)
return standby_leader
for member in self.cluster_status():
if member["role"] == "standby_leader":
if check_whether_is_running and member["state"] not in RUNNING_STATES:
logger.warning(f"standby leader {member['name']} is not running")
continue
standby_leader = member["name"]
if unit_name_pattern:
# Change the last dash to / in order to match unit name pattern.
standby_leader = label2name(standby_leader)
return standby_leader

def get_sync_standby_names(self) -> list[str]:
"""Get the list of sync standby unit names."""
sync_standbys = []
# Request info from cluster endpoint (which returns all members of the cluster).
for attempt in Retrying(stop=stop_after_attempt(2 * len(self.peers_ips) + 1)):
with attempt:
url = self._get_alternative_patroni_url(attempt)
r = requests.get(
f"{url}/cluster",
verify=self.verify,
auth=self._patroni_auth,
timeout=PATRONI_TIMEOUT,
)
for member in r.json()["members"]:
if member["role"] == "sync_standby":
sync_standbys.append(label2name(member["name"]))
for member in self.cluster_status():
if member["role"] == "sync_standby":
sync_standbys.append(label2name(member["name"]))
return sync_standbys

def _get_alternative_patroni_url(
    self, attempt: AttemptManager, alternative_endpoints: list[str] | None = None
) -> str:
    """Pick the Patroni REST API URL to use for the given retry attempt.

    When the Patroni process is not running in the current unit, another
    cluster member's REST API has to be used, so each attempt targets a
    different member.

    With ``alternative_endpoints`` the unit IP in the URL is replaced by the
    entry at index ``attempt number - 1``. Otherwise the first attempt uses
    this unit's own URL, the following attempts go through each peer over
    http, and the attempts after that go through each peer over https.
    """
    attempt_number = attempt.retry_state.attempt_number
    if alternative_endpoints is not None:
        return self._patroni_url.replace(
            self.unit_ip, alternative_endpoints[attempt_number - 1]
        )

    own_url = self._patroni_url
    if attempt_number <= 1:
        return own_url

    peers = list(self.peers_ips)
    # Build the URL using http and later using https for each peer.
    if (attempt_number - 1) <= len(peers):
        rescheme = own_url.replace("https://", "http://")
        peer_index = attempt_number - 2
    else:
        rescheme = own_url.replace("http://", "https://")
        peer_index = attempt_number - 2 - len(peers)
    return rescheme.replace(self.unit_ip, peers[peer_index])

def are_all_members_ready(self) -> bool:
"""Check if all members are correctly running Patroni and PostgreSQL.

Expand Down
3 changes: 0 additions & 3 deletions tests/integration/ha_tests/test_self_healing_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pytest_operator.plugin import OpsTest
from tenacity import Retrying, stop_after_delay, wait_fixed

from .. import markers
from ..helpers import (
CHARM_BASE,
db_connect,
Expand Down Expand Up @@ -161,7 +160,6 @@ async def test_forceful_restart_without_data_and_transaction_logs(


@pytest.mark.abort_on_fail
@markers.amd64_only
async def test_network_cut(ops_test: OpsTest, continuous_writes, primary_start_timeout):
"""Completely cut and restore network."""
# Locate primary unit.
Expand Down Expand Up @@ -250,7 +248,6 @@ async def test_network_cut(ops_test: OpsTest, continuous_writes, primary_start_t


@pytest.mark.abort_on_fail
@markers.amd64_only
async def test_network_cut_without_ip_change(
ops_test: OpsTest, continuous_writes, primary_start_timeout
):
Expand Down
Loading
Loading