50 changes: 33 additions & 17 deletions src/cluster.py
@@ -19,6 +19,7 @@
)
from jinja2 import Template
from tenacity import (
AttemptManager,
RetryError,
Retrying,
retry,
@@ -161,14 +162,14 @@ def get_member_ip(self, member_name: str) -> str:
Returns:
IP address of the cluster member.
"""
ip = None
# Request info from cluster endpoint (which returns all members of the cluster).
cluster_status = requests.get(f"{self._patroni_url}/cluster", verify=self.verify)
for member in cluster_status.json()["members"]:
if member["name"] == member_name:
ip = member["host"]
break
return ip
for attempt in Retrying(stop=stop_after_attempt(len(self.peers_ips) + 1)):
with attempt:
url = self._get_alternative_patroni_url(attempt)
cluster_status = requests.get(f"{url}/cluster", verify=self.verify, timeout=10)
for member in cluster_status.json()["members"]:
if member["name"] == member_name:
return member["host"]

def get_primary(self, unit_name_pattern=False) -> str:
"""Get primary instance.
@@ -179,17 +180,32 @@ def get_primary(self, unit_name_pattern=False) -> str:
Returns:
primary pod or unit name.
"""
primary = None
# Request info from cluster endpoint (which returns all members of the cluster).
cluster_status = requests.get(f"{self._patroni_url}/cluster", verify=self.verify)
for member in cluster_status.json()["members"]:
if member["role"] == "leader":
primary = member["name"]
if unit_name_pattern:
# Change the last dash to / in order to match unit name pattern.
primary = "/".join(primary.rsplit("-", 1))
break
return primary
for attempt in Retrying(stop=stop_after_attempt(len(self.peers_ips) + 1)):
with attempt:
url = self._get_alternative_patroni_url(attempt)
cluster_status = requests.get(f"{url}/cluster", verify=self.verify, timeout=10)
for member in cluster_status.json()["members"]:
if member["role"] == "leader":
primary = member["name"]
if unit_name_pattern:
# Change the last dash to / in order to match unit name pattern.
primary = "/".join(primary.rsplit("-", 1))
return primary

def _get_alternative_patroni_url(self, attempt: AttemptManager) -> str:
"""Get an alternative REST API URL from another member each time.

When the Patroni process is not running in the current unit, the REST API
of another cluster member has to be used for some operations.
"""
if attempt.retry_state.attempt_number > 1:
url = self._patroni_url.replace(
self.unit_ip, list(self.peers_ips)[attempt.retry_state.attempt_number - 2]
)
else:
url = self._patroni_url
return url

def are_all_members_ready(self) -> bool:
"""Check if all members are correctly running Patroni and PostgreSQL.
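For orientation, the retry-across-peers pattern introduced above can be read in isolation. The sketch below is not the charm's actual class, only a minimal standalone illustration assuming a `PatroniClientSketch` that holds the same `unit_ip`, `peers_ips`, `verify`, and `_patroni_url` attributes that the diff relies on:

```python
import requests
from tenacity import AttemptManager, Retrying, stop_after_attempt


class PatroniClientSketch:
    """Minimal sketch: query the Patroni /cluster endpoint, falling back to peers."""

    def __init__(self, unit_ip: str, peers_ips: set, verify: bool = True):
        self.unit_ip = unit_ip
        self.peers_ips = peers_ips
        self.verify = verify
        self._patroni_url = f"http://{unit_ip}:8008"  # assumed default Patroni port

    def _get_alternative_patroni_url(self, attempt: AttemptManager) -> str:
        # First attempt targets the local unit; each later attempt swaps in a peer IP.
        if attempt.retry_state.attempt_number > 1:
            peer_ip = list(self.peers_ips)[attempt.retry_state.attempt_number - 2]
            return self._patroni_url.replace(self.unit_ip, peer_ip)
        return self._patroni_url

    def get_member_ip(self, member_name: str) -> str:
        # One attempt per reachable REST API: the local unit plus every peer.
        for attempt in Retrying(stop=stop_after_attempt(len(self.peers_ips) + 1)):
            with attempt:
                url = self._get_alternative_patroni_url(attempt)
                cluster_status = requests.get(f"{url}/cluster", verify=self.verify, timeout=10)
                for member in cluster_status.json()["members"]:
                    if member["name"] == member_name:
                        return member["host"]
```

If a request fails, the `with attempt:` block re-raises into tenacity, which retries with the next URL; once the local unit and every peer have been tried, a `RetryError` propagates to the caller.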
1 change: 0 additions & 1 deletion tests/integration/ha_tests/conftest.py
@@ -37,7 +37,6 @@ async def master_start_timeout(ops_test: OpsTest) -> None:
"""Temporary change the master start timeout configuration."""
# Change the parameter that makes the primary reelection faster.
initial_master_start_timeout = await get_master_start_timeout(ops_test)
await change_master_start_timeout(ops_test, 0)
yield
# Rollback to the initial configuration.
await change_master_start_timeout(ops_test, initial_master_start_timeout)
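For context on the fixture above: the `get_master_start_timeout` and `change_master_start_timeout` helpers it calls are only partially visible later in this diff. One common way to implement them is through Patroni's dynamic-configuration REST endpoint (`GET`/`PATCH /config`); the sketch below is an assumption about that approach, not the PR's code, and the host/port handling is purely illustrative:

```python
from typing import Optional

import requests

PATRONI_PORT = 8008  # assumed default Patroni REST API port


def get_master_start_timeout_sketch(host: str) -> Optional[int]:
    """Read master_start_timeout from Patroni's dynamic configuration."""
    config = requests.get(f"http://{host}:{PATRONI_PORT}/config", timeout=10).json()
    value = config.get("master_start_timeout")
    return int(value) if value is not None else None


def change_master_start_timeout_sketch(host: str, seconds: Optional[int]) -> None:
    """Patch Patroni's dynamic configuration; 0 makes fail-over start immediately."""
    requests.patch(
        f"http://{host}:{PATRONI_PORT}/config",
        json={"master_start_timeout": seconds},
        timeout=10,
    )
```

The fixture's yield-based structure then ensures the original value is restored after each test.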
79 changes: 73 additions & 6 deletions tests/integration/ha_tests/helpers.py
@@ -16,6 +16,14 @@
APP_NAME = METADATA["name"]


class MemberNotListedOnClusterError(Exception):
"""Raised when a member is not listed in the cluster."""


class MemberNotUpdatedOnClusterError(Exception):
"""Raised when a member is not yet updated in the cluster."""


class ProcessError(Exception):
pass

@@ -54,11 +62,14 @@ async def change_master_start_timeout(ops_test: OpsTest, seconds: Optional[int])
)


async def count_writes(ops_test: OpsTest) -> int:
async def count_writes(ops_test: OpsTest, down_unit: str = None) -> int:
"""Count the number of writes in the database."""
app = await app_name(ops_test)
password = await get_password(ops_test, app)
host = ops_test.model.applications[app].units[0].public_address
password = await get_password(ops_test, app, down_unit)
for unit in ops_test.model.applications[app].units:
if unit.name != down_unit:
host = unit.public_address
break
connection_string = (
f"dbname='application' user='operator'"
f" host='{host}' password='{password}' connect_timeout=10"
@@ -77,6 +88,27 @@ async def count_writes(ops_test: OpsTest) -> int:
return count


async def fetch_cluster_members(ops_test: OpsTest):
"""Fetches the IPs listed by Patroni as cluster members.

Args:
ops_test: OpsTest instance.
"""
app = await app_name(ops_test)
member_ips = {}
for unit in ops_test.model.applications[app].units:
cluster_info = requests.get(f"http://{unit.public_address}:8008/cluster")
if len(member_ips) > 0:
# If the list of members IPs was already fetched, also compare the
# list provided by other members.
assert member_ips == {
member["host"] for member in cluster_info.json()["members"]
}, "members report different lists of cluster members."
else:
member_ips = {member["host"] for member in cluster_info.json()["members"]}
return member_ips


async def get_master_start_timeout(ops_test: OpsTest) -> Optional[int]:
"""Get the master start timeout configuration.

@@ -96,19 +128,52 @@ async def get_master_start_timeout(ops_test: OpsTest) -> Optional[int]:
return int(master_start_timeout) if master_start_timeout is not None else None


async def get_password(ops_test: OpsTest, app) -> str:
async def get_password(ops_test: OpsTest, app: str, down_unit: str = None) -> str:
"""Use the charm action to retrieve the password from provided application.

Returns:
string with the password stored on the peer relation databag.
"""
# Can retrieve from any running unit, so we pick the first one that is not the down unit.
unit_name = ops_test.model.applications[app].units[0].name
for unit in ops_test.model.applications[app].units:
if unit.name != down_unit:
unit_name = unit.name
break
action = await ops_test.model.units.get(unit_name).run_action("get-password")
action = await action.wait()
return action.results["operator-password"]


def is_replica(ops_test: OpsTest, unit_name: str) -> bool:
"""Returns whether the unit a replica in the cluster."""
unit_ip = get_unit_address(ops_test, unit_name)
member_name = unit_name.replace("/", "-")

try:
for attempt in Retrying(stop=stop_after_delay(60 * 3), wait=wait_fixed(3)):
with attempt:
cluster_info = requests.get(f"http://{unit_ip}:8008/cluster")

# The unit may take some time to be listed on Patroni REST API cluster endpoint.
if member_name not in {
member["name"] for member in cluster_info.json()["members"]
}:
raise MemberNotListedOnClusterError()

for member in cluster_info.json()["members"]:
if member["name"] == member_name:
role = member["role"]

# A member that was restarted or had its DB process stopped may
# take some time to learn that a new primary was elected.
if role == "replica":
return True
else:
raise MemberNotUpdatedOnClusterError()
except RetryError:
return False
Contributor review comment: great function 🤩



async def get_primary(ops_test: OpsTest, app) -> str:
"""Use the charm action to retrieve the primary from provided application.

@@ -122,7 +187,9 @@ async def get_primary(ops_test: OpsTest, app) -> str:
return action.results["primary"]


async def kill_process(ops_test: OpsTest, unit_name: str, process: str, kill_code: str) -> None:
async def send_signal_to_process(
ops_test: OpsTest, unit_name: str, process: str, kill_code: str
) -> None:
"""Kills process on the unit according to the provided kill code."""
# Killing the only instance can be disastrous.
app = await app_name(ops_test)
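The body of `send_signal_to_process` is collapsed in this diff. Purely for orientation, here is one way such a helper is often written on top of `ops_test.juju`; the `juju ssh` invocation, the `pkill` flags, and the error handling below are assumptions for illustration, not the PR's implementation:

```python
from pytest_operator.plugin import OpsTest


async def send_signal_to_process_sketch(
    ops_test: OpsTest, unit_name: str, process: str, kill_code: str
) -> None:
    """Send a signal (e.g. SIGKILL, SIGSTOP, SIGCONT) to a named process on a unit."""
    # Assumed approach: run pkill on the unit over `juju ssh` so the requested
    # signal is delivered to the process matching the given name.
    return_code, _, stderr = await ops_test.juju(
        "ssh", unit_name, "sudo", "pkill", f"--signal={kill_code}", "-f", process
    )
    if return_code != 0:
        raise Exception(
            f"failed to send {kill_code} to {process} on {unit_name}: {stderr}"
        )
```

In the tests below this shows up as SIGKILL for the kill test and SIGSTOP/SIGCONT for the freeze test, which is why the helper was renamed from `kill_process` to the more general `send_signal_to_process`.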
89 changes: 87 additions & 2 deletions tests/integration/ha_tests/test_self_healing.py
@@ -9,11 +9,15 @@
from tests.integration.ha_tests.helpers import (
METADATA,
app_name,
change_master_start_timeout,
count_writes,
fetch_cluster_members,
get_master_start_timeout,
get_primary,
kill_process,
is_replica,
postgresql_ready,
secondary_up_to_date,
send_signal_to_process,
start_continuous_writes,
stop_continuous_writes,
)
@@ -52,8 +56,12 @@ async def test_kill_db_process(
# Start an application that continuously writes data to the database.
await start_continuous_writes(ops_test, app)

# Change the "master_start_timeout" parameter to speed up the fail-over.
original_master_start_timeout = await get_master_start_timeout(ops_test)
await change_master_start_timeout(ops_test, 0)

# Kill the database process.
await kill_process(ops_test, primary_name, process, kill_code="SIGKILL")
await send_signal_to_process(ops_test, primary_name, process, kill_code="SIGKILL")

async with ops_test.fast_forward():
# Verify new writes are continuing by counting the number of writes before and after a
@@ -72,6 +80,83 @@ async def test_kill_db_process(
new_primary_name = await get_primary(ops_test, app)
assert new_primary_name != primary_name

# Revert the "master_start_timeout" parameter to avoid fail-over again.
await change_master_start_timeout(ops_test, original_master_start_timeout)

# Verify that the old primary is now a replica.
assert is_replica(ops_test, primary_name), "there is more than one primary in the cluster."

# Verify that all units are part of the same cluster.
member_ips = await fetch_cluster_members(ops_test)
ip_addresses = [unit.public_address for unit in ops_test.model.applications[app].units]
assert set(member_ips) == set(ip_addresses), "not all units are part of the same cluster."

# Verify that no writes to the database were missed after stopping the writes.
total_expected_writes = await stop_continuous_writes(ops_test)
for attempt in Retrying(stop=stop_after_delay(60), wait=wait_fixed(3)):
with attempt:
actual_writes = await count_writes(ops_test)
assert total_expected_writes == actual_writes, "writes to the db were missed."

# Verify that old primary is up-to-date.
assert await secondary_up_to_date(
ops_test, primary_name, total_expected_writes
), "secondary not up to date with the cluster after restarting."


@pytest.mark.ha_self_healing_tests
@pytest.mark.parametrize("process", DB_PROCESSES)
async def test_freeze_db_process(
ops_test: OpsTest, process: str, continuous_writes, master_start_timeout
) -> None:
# Locate primary unit.
app = await app_name(ops_test)
primary_name = await get_primary(ops_test, app)

# Start an application that continuously writes data to the database.
await start_continuous_writes(ops_test, app)

# Change the "master_start_timeout" parameter to speed up the fail-over.
original_master_start_timeout = await get_master_start_timeout(ops_test)
await change_master_start_timeout(ops_test, 0)

# Freeze the database process.
await send_signal_to_process(ops_test, primary_name, process, "SIGSTOP")

async with ops_test.fast_forward():
# Verify new writes are continuing by counting the number of writes before and after a
# 3 minutes wait (this is a little more than the loop wait configuration, which is
# what triggers the fail-over after master_start_timeout is changed; also, when the
# DB process is frozen, it takes some more time for the fail-over to be triggered).
writes = await count_writes(ops_test, primary_name)
for attempt in Retrying(stop=stop_after_delay(60 * 3), wait=wait_fixed(3)):
with attempt:
more_writes = await count_writes(ops_test, primary_name)
assert more_writes > writes, "writes not continuing to DB"

# Verify that a new primary gets elected (ie old primary is secondary).
for attempt in Retrying(stop=stop_after_delay(60 * 3), wait=wait_fixed(3)):
with attempt:
new_primary_name = await get_primary(ops_test, app)
assert new_primary_name != primary_name

# Revert the "master_start_timeout" parameter to avoid fail-over again.
await change_master_start_timeout(ops_test, original_master_start_timeout)

# Un-freeze the old primary.
await send_signal_to_process(ops_test, primary_name, process, "SIGCONT")

# Verify that the database service got restarted and is ready in the old primary.
assert await postgresql_ready(ops_test, primary_name)
Contributor review comment: add checks that verify:

  • that the old primary is now the secondary
  • all units are in the same replica set

Member Author reply: Great suggestion. Added both on 6a1e39b. I also added those checks on the kill DB process test.


# Verify that the old primary is now a replica.
assert is_replica(ops_test, primary_name), "there is more than one primary in the cluster."

# Verify that all units are part of the same cluster.
member_ips = await fetch_cluster_members(ops_test)
ip_addresses = [unit.public_address for unit in ops_test.model.applications[app].units]
assert set(member_ips) == set(ip_addresses), "not all units are part of the same cluster."

# Verify that no writes to the database were missed after stopping the writes.
total_expected_writes = await stop_continuous_writes(ops_test)
for attempt in Retrying(stop=stop_after_delay(60), wait=wait_fixed(3)):