Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ jobs:
uses: charmed-kubernetes/actions-operator@main
with:
provider: lxd
# This is needed until https://bugs.launchpad.net/juju/+bug/1992833 is fixed.
bootstrap-options: "--agent-version 2.9.34"
- name: Run integration tests
run: tox -e database-relation-integration

Expand Down Expand Up @@ -105,6 +107,8 @@ jobs:
uses: charmed-kubernetes/actions-operator@main
with:
provider: lxd
# This is needed until https://bugs.launchpad.net/juju/+bug/1992833 is fixed.
bootstrap-options: "--agent-version 2.9.34"
- name: Run integration tests
run: tox -e ha-self-healing-integration

Expand Down
48 changes: 37 additions & 11 deletions lib/charms/postgresql_k8s/v0/postgresql_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""

import base64
import ipaddress
import logging
import re
import socket
Expand All @@ -34,7 +35,7 @@
from cryptography.x509.extensions import ExtensionType
from ops.charm import ActionEvent
from ops.framework import Object
from ops.pebble import PathError, ProtocolError
from ops.pebble import ConnectionError, PathError, ProtocolError

# The unique Charmhub library identifier, never change it
LIBID = "c27af44a92df4ef38d7ae06418b2800f"
Expand All @@ -44,7 +45,7 @@

# Increment this PATCH version before using `charmcraft publish-lib` or reset
# to 0 if you are raising the major API version.
LIBPATCH = 1
LIBPATCH = 4

logger = logging.getLogger(__name__)
SCOPE = "unit"
Expand All @@ -54,11 +55,14 @@
class PostgreSQLTLS(Object):
"""In this class we manage certificates relation."""

def __init__(self, charm, peer_relation):
def __init__(
self, charm, peer_relation: str, additional_dns_names: Optional[List[str]] = None
):
"""Manager of PostgreSQL relation with TLS Certificates Operator."""
super().__init__(charm, "client-relations")
self.charm = charm
self.peer_relation = peer_relation
self.additional_dns_names = additional_dns_names or []
self.certs = TLSCertificatesRequiresV1(self.charm, TLS_RELATION)
self.framework.observe(
self.charm.on.set_tls_private_key_action, self._on_set_tls_private_key
Expand Down Expand Up @@ -86,8 +90,8 @@ def _request_certificate(self, param: Optional[str]):
csr = generate_csr(
private_key=key,
subject=self.charm.get_hostname_by_unit(self.charm.unit.name),
sans=self._get_sans(),
additional_critical_extensions=self._get_tls_extensions(),
**self._get_sans(),
)

self.charm.set_secret(SCOPE, "key", key.decode("utf-8"))
Expand Down Expand Up @@ -133,7 +137,7 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None:

try:
self.charm.push_tls_files_to_workload()
except (PathError, ProtocolError) as e:
except (ConnectionError, PathError, ProtocolError) as e:
logger.error("Cannot push TLS certificates: %r", e)
event.defer()
return
Expand All @@ -149,35 +153,57 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None:
new_csr = generate_csr(
private_key=key,
subject=self.charm.get_hostname_by_unit(self.charm.unit.name),
sans=self._get_sans(),
additional_critical_extensions=self._get_tls_extensions(),
**self._get_sans(),
)
self.certs.request_certificate_renewal(
old_certificate_signing_request=old_csr,
new_certificate_signing_request=new_csr,
)
self.charm.set_secret(SCOPE, "csr", new_csr.decode("utf-8"))

def _get_sans(self) -> List[str]:
"""Create a list of DNS names for a PostgreSQL unit.
def _get_sans(self) -> dict:
"""Create a list of Subject Alternative Names for a PostgreSQL unit.

Returns:
A list representing the hostnames of the PostgreSQL unit.
A list representing the IP and hostnames of the PostgreSQL unit.
"""

def is_ip_address(address: str) -> bool:
"""Returns whether and address is an IP address."""
try:
ipaddress.ip_address(address)
return True
except (ipaddress.AddressValueError, ValueError):
return False

unit_id = self.charm.unit.name.split("/")[1]
return [

# Create a list of all the Subject Alternative Names.
sans = [
f"{self.charm.app.name}-{unit_id}",
self.charm.get_hostname_by_unit(self.charm.unit.name),
socket.getfqdn(),
str(self.charm.model.get_binding(self.peer_relation).network.bind_address),
]
sans.extend(self.additional_dns_names)

# Separate IP addresses and DNS names.
sans_ip = [san for san in sans if is_ip_address(san)]
sans_dns = [san for san in sans if not is_ip_address(san)]

return {
"sans_ip": sans_ip,
"sans_dns": sans_dns,
}

@staticmethod
def _get_tls_extensions() -> Optional[List[ExtensionType]]:
"""Return a list of TLS extensions for which certificate key can be used."""
basic_constraints = x509.BasicConstraints(ca=True, path_length=None)
return [basic_constraints]

def get_tls_files(self) -> (Optional[str], Optional[str]):
def get_tls_files(self) -> (Optional[str], Optional[str], Optional[str]):
"""Prepare TLS files in special PostgreSQL way.

PostgreSQL needs three files:
Expand Down
24 changes: 18 additions & 6 deletions lib/charms/tls_certificates_interface/v1/tls_certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def _on_certificate_expiring(self, event: CertificateExpiringEvent) -> None:
import logging
import uuid
from datetime import datetime, timedelta
from ipaddress import IPv4Address
from typing import Dict, List, Optional

from cryptography import x509
Expand Down Expand Up @@ -657,7 +658,9 @@ def generate_csr(
email_address: str = None,
country_name: str = None,
private_key_password: Optional[bytes] = None,
sans: Optional[List[str]] = None,
sans_oid: Optional[str] = None,
sans_ip: Optional[List[str]] = None,
sans_dns: Optional[List[str]] = None,
additional_critical_extensions: Optional[List] = None,
) -> bytes:
"""Generates a CSR using private key and subject.
Expand All @@ -672,7 +675,9 @@ def generate_csr(
email_address (str): Email address.
country_name (str): Country Name.
private_key_password (bytes): Private key password
sans (list): List of subject alternative names
sans_dns (list): List of DNS subject alternative names
sans_ip (list): List of IP subject alternative names
sans_oid (str): Additional OID
additional_critical_extensions (list): List if critical additional extension objects.
Object must be a x509 ExtensionType.

Expand All @@ -693,10 +698,17 @@ def generate_csr(
if country_name:
subject_name.append(x509.NameAttribute(x509.NameOID.COUNTRY_NAME, country_name))
csr = x509.CertificateSigningRequestBuilder(subject_name=x509.Name(subject_name))
if sans:
csr = csr.add_extension(
x509.SubjectAlternativeName([x509.DNSName(san) for san in sans]), critical=False
)

_sans = []
if sans_oid:
_sans.append(x509.RegisteredID(x509.ObjectIdentifier(sans_oid)))
if sans_ip:
_sans.extend([x509.IPAddress(IPv4Address(san)) for san in sans_ip])
if sans_dns:
_sans.extend([x509.DNSName(san) for san in sans_dns])
if _sans:
csr = csr.add_extension(x509.SubjectAlternativeName(_sans), critical=False)

if additional_critical_extensions:
for extension in additional_critical_extensions:
csr = csr.add_extension(extension, critical=True)
Expand Down
14 changes: 10 additions & 4 deletions src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,10 +821,12 @@ def push_tls_files_to_workload(self) -> None:
self.update_config()

def _restart(self, _) -> None:
"""Restart Patroni and PostgreSQL."""
if not self._patroni.restart_patroni():
logger.exception("failed to restart PostgreSQL")
self.unit.status = BlockedStatus("failed to restart Patroni and PostgreSQL")
"""Restart PostgreSQL."""
try:
self._patroni.restart_postgresql()
except RetryError as e:
logger.error("failed to restart PostgreSQL")
self.unit.status = BlockedStatus(f"failed to restart PostgreSQL with error {e}")

def update_config(self) -> None:
"""Updates Patroni config file based on the existence of the TLS files."""
Expand All @@ -833,6 +835,10 @@ def update_config(self) -> None:
# Update and reload configuration based on TLS files availability.
self._patroni.render_patroni_yml_file(enable_tls=enable_tls)
if not self._patroni.member_started:
# If Patroni/PostgreSQL has not started yet and TLS relations was initialised,
# then mark TLS as enabled. This commonly happens when the charm is deployed
# in a bundle together with the TLS certificates operator.
self.unit_peer_data.update({"tls": "enabled" if enable_tls else ""})
return

restart_postgresql = enable_tls != self.postgresql.is_tls_enabled()
Expand Down
6 changes: 2 additions & 4 deletions tests/integration/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ async def test_deploy_active(ops_test: OpsTest):
charm, resources={"patroni": "patroni.tar.gz"}, application_name=APP_NAME, num_units=3
)
await ops_test.juju("attach-resource", APP_NAME, "patroni=patroni.tar.gz")
await ops_test.model.wait_for_idle(apps=[APP_NAME], status="active", timeout=1000)
# No wait between deploying charms, since we can't guarantee users will wait. Furthermore,
# bundles don't wait between deploying charms.


@pytest.mark.tls_tests
Expand All @@ -36,9 +37,6 @@ async def test_tls_enabled(ops_test: OpsTest) -> None:
# Deploy TLS Certificates operator.
config = {"generate-self-signed-certificates": "true", "ca-common-name": "Test CA"}
await ops_test.model.deploy(TLS_CERTIFICATES_APP_NAME, channel="edge", config=config)
await ops_test.model.wait_for_idle(
apps=[TLS_CERTIFICATES_APP_NAME], status="active", timeout=1000
)

# Relate it to the PostgreSQL to enable TLS.
await ops_test.model.relate(DATABASE_APP_NAME, TLS_CERTIFICATES_APP_NAME)
Expand Down
72 changes: 72 additions & 0 deletions tests/unit/test_charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from ops.model import ActiveStatus, BlockedStatus, WaitingStatus
from ops.testing import Harness
from tenacity import RetryError

from charm import PostgresqlOperatorCharm
from constants import PEER
Expand Down Expand Up @@ -424,3 +425,74 @@ def test_set_secret(self, _):
self.harness.get_relation_data(self.rel_id, self.charm.unit.name)["password"]
== "test-password"
)

@patch_network_get(private_address="1.1.1.1")
@patch("charm.Patroni.restart_postgresql")
def test_restart(self, _restart_postgresql):
# Test a successful restart.
self.charm._restart(None)
self.assertFalse(isinstance(self.charm.unit.status, BlockedStatus))

# Test a failed restart.
_restart_postgresql.side_effect = RetryError(last_attempt=1)
self.charm._restart(None)
self.assertTrue(isinstance(self.charm.unit.status, BlockedStatus))

@patch_network_get(private_address="1.1.1.1")
@patch("charms.rolling_ops.v0.rollingops.RollingOpsManager._on_acquire_lock")
@patch("charm.Patroni.reload_patroni_configuration")
@patch("charm.Patroni.member_started", new_callable=PropertyMock)
@patch("charm.Patroni.render_patroni_yml_file")
@patch("charms.postgresql_k8s.v0.postgresql_tls.PostgreSQLTLS.get_tls_files")
def test_update_config(
self,
_get_tls_files,
_render_patroni_yml_file,
_member_started,
_reload_patroni_configuration,
_restart,
):
with patch.object(PostgresqlOperatorCharm, "postgresql", Mock()) as postgresql_mock:
# Mock some properties.
postgresql_mock.is_tls_enabled = PropertyMock(side_effect=[False, False, False])
_member_started.side_effect = [True, True, False]

# Test without TLS files available.
self.harness.update_relation_data(
self.rel_id, self.charm.unit.name, {"tls": "enabled"}
) # Mock some data in the relation to test that it change.
_get_tls_files.return_value = [None]
self.charm.update_config()
_render_patroni_yml_file.assert_called_once_with(enable_tls=False)
_reload_patroni_configuration.assert_called_once()
_restart.assert_not_called()
self.assertNotIn(
"tls", self.harness.get_relation_data(self.rel_id, self.charm.unit.name)
)

# Test with TLS files available.
self.harness.update_relation_data(
self.rel_id, self.charm.unit.name, {"tls": ""}
) # Mock some data in the relation to test that it change.
_get_tls_files.return_value = ["something"]
_render_patroni_yml_file.reset_mock()
_reload_patroni_configuration.reset_mock()
self.charm.update_config()
_render_patroni_yml_file.assert_called_once_with(enable_tls=True)
_reload_patroni_configuration.assert_called_once()
_restart.assert_called_once()
self.assertEqual(
self.harness.get_relation_data(self.rel_id, self.charm.unit.name)["tls"], "enabled"
)

# Test with member not started yet.
self.harness.update_relation_data(
self.rel_id, self.charm.unit.name, {"tls": ""}
) # Mock some data in the relation to test that it change.
_reload_patroni_configuration.reset_mock()
self.charm.update_config()
_reload_patroni_configuration.assert_not_called()
_restart.assert_called_once()
self.assertEqual(
self.harness.get_relation_data(self.rel_id, self.charm.unit.name)["tls"], "enabled"
)