From 23c37dc498487b250b2bd9ca6cddb482c24a926e Mon Sep 17 00:00:00 2001 From: Leon <82407168+sed-i@users.noreply.github.com> Date: Fri, 26 Apr 2024 02:31:18 -0400 Subject: [PATCH] Fetch-lib. Fix for #315 (#327) * chore: update charm libraries * fetch-lib * Attempt at fixing potential startup errors * Attempt at fixing potential startup errors --------- Co-authored-by: Noctua --- .../observability_libs/v0/cert_handler.py | 50 +- .../prometheus_k8s/v0/prometheus_scrape.py | 17 +- lib/charms/tempo_k8s/v2/tracing.py | 74 +- .../{v2 => v3}/tls_certificates.py | 703 ++++++++---------- src/charm.py | 16 +- tests/integration/test_tls_web.py | 1 + 6 files changed, 428 insertions(+), 433 deletions(-) rename lib/charms/tls_certificates_interface/{v2 => v3}/tls_certificates.py (79%) diff --git a/lib/charms/observability_libs/v0/cert_handler.py b/lib/charms/observability_libs/v0/cert_handler.py index db14e00f..0fc610ff 100644 --- a/lib/charms/observability_libs/v0/cert_handler.py +++ b/lib/charms/observability_libs/v0/cert_handler.py @@ -37,22 +37,25 @@ import json import socket from itertools import filterfalse -from typing import List, Optional, Union +from typing import List, Optional, Union, cast try: - from charms.tls_certificates_interface.v2.tls_certificates import ( # type: ignore + from charms.tls_certificates_interface.v3.tls_certificates import ( # type: ignore AllCertificatesInvalidatedEvent, CertificateAvailableEvent, CertificateExpiringEvent, CertificateInvalidatedEvent, - TLSCertificatesRequiresV2, + TLSCertificatesRequiresV3, generate_csr, generate_private_key, ) -except ImportError: +except ImportError as e: raise ImportError( - "charms.tls_certificates_interface.v2.tls_certificates is missing; please get it through charmcraft fetch-lib" - ) + "failed to import charms.tls_certificates_interface.v3.tls_certificates; " + "Either the library itself is missing (please get it through charmcraft fetch-lib) " + "or one of its dependencies is unmet." + ) from e + import logging from ops.charm import CharmBase, RelationBrokenEvent @@ -64,7 +67,7 @@ LIBID = "b5cd5cd580f3428fa5f59a8876dcbe6a" LIBAPI = 0 -LIBPATCH = 9 +LIBPATCH = 12 def is_ip_address(value: str) -> bool: @@ -129,7 +132,7 @@ def __init__( self.peer_relation_name = peer_relation_name self.certificates_relation_name = certificates_relation_name - self.certificates = TLSCertificatesRequiresV2(self.charm, self.certificates_relation_name) + self.certificates = TLSCertificatesRequiresV3(self.charm, self.certificates_relation_name) self.framework.observe( self.charm.on.config_changed, @@ -237,6 +240,13 @@ def _generate_csr( This method intentionally does not emit any events, leave it for caller's responsibility. """ + # if we are in a relation-broken hook, we might not have a relation to publish the csr to. + if not self.charm.model.get_relation(self.certificates_relation_name): + logger.warning( + f"No {self.certificates_relation_name!r} relation found. " f"Cannot generate csr." + ) + return + # At this point, assuming "peer joined" and "certificates joined" have already fired # (caller must guard) so we must have a private_key entry in relation data at our disposal. # Otherwise, traceback -> debug. @@ -279,7 +289,7 @@ def _generate_csr( if clear_cert: self._ca_cert = "" self._server_cert = "" - self._chain = [] + self._chain = "" def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: """Get the certificate from the event and store it in a peer relation. @@ -301,7 +311,7 @@ def _on_certificate_available(self, event: CertificateAvailableEvent) -> None: if event_csr == self._csr: self._ca_cert = event.ca self._server_cert = event.certificate - self._chain = event.chain + self._chain = event.chain_as_pem() self.on.cert_changed.emit() # pyright: ignore @property @@ -372,21 +382,29 @@ def _server_cert(self, value: str): rel.data[self.charm.unit].update({"certificate": value}) @property - def _chain(self) -> List[str]: + def _chain(self) -> str: if self._peer_relation: - if chain := self._peer_relation.data[self.charm.unit].get("chain", []): - return json.loads(chain) - return [] + if chain := self._peer_relation.data[self.charm.unit].get("chain", ""): + chain = json.loads(chain) + + # In a previous version of this lib, chain used to be a list. + # Convert the List[str] to str, per + # https://github.com/canonical/tls-certificates-interface/pull/141 + if isinstance(chain, list): + chain = "\n\n".join(reversed(chain)) + + return cast(str, chain) + return "" @_chain.setter - def _chain(self, value: List[str]): + def _chain(self, value: str): # Caller must guard. We want the setter to fail loudly. Failure must have a side effect. rel = self._peer_relation assert rel is not None # For type checker rel.data[self.charm.unit].update({"chain": json.dumps(value)}) @property - def chain(self) -> List[str]: + def chain(self) -> str: """Return the ca chain.""" return self._chain diff --git a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py index 665af886..be967686 100644 --- a/lib/charms/prometheus_k8s/v0/prometheus_scrape.py +++ b/lib/charms/prometheus_k8s/v0/prometheus_scrape.py @@ -362,7 +362,7 @@ def _on_scrape_targets_changed(self, event): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 44 +LIBPATCH = 46 PYDEPS = ["cosl"] @@ -521,8 +521,8 @@ def expand_wildcard_targets_into_individual_jobs( # for such a target. Therefore labeling with Juju topology, excluding the # unit name. non_wildcard_static_config["labels"] = { - **non_wildcard_static_config.get("labels", {}), **topology.label_matcher_dict, + **non_wildcard_static_config.get("labels", {}), } non_wildcard_static_configs.append(non_wildcard_static_config) @@ -547,9 +547,9 @@ def expand_wildcard_targets_into_individual_jobs( if topology: # Add topology labels modified_static_config["labels"] = { - **modified_static_config.get("labels", {}), **topology.label_matcher_dict, **{"juju_unit": unit_name}, + **modified_static_config.get("labels", {}), } # Instance relabeling for topology should be last in order. @@ -1537,12 +1537,11 @@ def set_scrape_job_spec(self, _=None): relation.data[self._charm.app]["scrape_metadata"] = json.dumps(self._scrape_metadata) relation.data[self._charm.app]["scrape_jobs"] = json.dumps(self._scrape_jobs) - if alert_rules_as_dict: - # Update relation data with the string representation of the rule file. - # Juju topology is already included in the "scrape_metadata" field above. - # The consumer side of the relation uses this information to name the rules file - # that is written to the filesystem. - relation.data[self._charm.app]["alert_rules"] = json.dumps(alert_rules_as_dict) + # Update relation data with the string representation of the rule file. + # Juju topology is already included in the "scrape_metadata" field above. + # The consumer side of the relation uses this information to name the rules file + # that is written to the filesystem. + relation.data[self._charm.app]["alert_rules"] = json.dumps(alert_rules_as_dict) def _set_unit_ip(self, _=None): """Set unit host address. diff --git a/lib/charms/tempo_k8s/v2/tracing.py b/lib/charms/tempo_k8s/v2/tracing.py index d466531e..3cf54b1c 100644 --- a/lib/charms/tempo_k8s/v2/tracing.py +++ b/lib/charms/tempo_k8s/v2/tracing.py @@ -104,7 +104,7 @@ def __init__(self, *args): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 2 +LIBPATCH = 3 PYDEPS = ["pydantic"] @@ -117,15 +117,13 @@ def __init__(self, *args): "zipkin", "kafka", "opencensus", - "tempo", # legacy, renamed to tempo_http "tempo_http", "tempo_grpc", "otlp_grpc", "otlp_http", - "jaeger_grpc", + # "jaeger_grpc", "jaeger_thrift_compact", "jaeger_thrift_http", - "jaeger_http_thrift", # legacy, renamed to jaeger_thrift_http "jaeger_thrift_binary", ] @@ -302,11 +300,14 @@ class TracingProviderAppData(DatabagModel): # noqa: D101 """Application databag model for the tracing provider.""" host: str - """Server hostname.""" + """Server hostname (local fqdn).""" receivers: List[Receiver] """Enabled receivers and ports at which they are listening.""" + external_url: Optional[str] = None + """Server url. If an ingress is present, it will be the ingress address.""" + class TracingRequirerAppData(DatabagModel): # noqa: D101 """Application databag model for the tracing requirer.""" @@ -492,6 +493,7 @@ def __init__( self, charm: CharmBase, host: str, + external_url: Optional[str] = None, relation_name: str = DEFAULT_RELATION_NAME, ): """Initialize. @@ -499,6 +501,8 @@ def __init__( Args: charm: a `CharmBase` instance that manages this instance of the Tempo service. host: address of the node hosting the tempo server. + external_url: external address of the node hosting the tempo server, + if an ingress is present. relation_name: an optional string name of the relation between `charm` and the Tempo charmed service. The default is "tracing". @@ -519,6 +523,7 @@ def __init__( super().__init__(charm, relation_name + "tracing-provider-v2") self._charm = charm self._host = host + self._external_url = external_url self._relation_name = relation_name self.framework.observe( self._charm.on[relation_name].relation_joined, self._on_relation_event @@ -585,6 +590,7 @@ def publish_receivers(self, receivers: Sequence[RawReceiver]): try: TracingProviderAppData( host=self._host, + external_url=self._external_url, receivers=[ Receiver(port=port, protocol=protocol) for protocol, port in receivers ], @@ -612,16 +618,17 @@ class EndpointRemovedEvent(RelationBrokenEvent): class EndpointChangedEvent(_AutoSnapshotEvent): """Event representing a change in one of the receiver endpoints.""" - __args__ = ("host", "_ingesters") + __args__ = ("host", "external_url", "_receivers") if TYPE_CHECKING: host = "" # type: str - _ingesters = [] # type: List[dict] + external_url = "" # type: str + _receivers = [] # type: List[dict] @property def receivers(self) -> List[Receiver]: """Cast receivers back from dict.""" - return [Receiver(**i) for i in self._ingesters] + return [Receiver(**i) for i in self._receivers] class TracingEndpointRequirerEvents(CharmEvents): @@ -776,7 +783,9 @@ def _on_tracing_relation_changed(self, event): return data = TracingProviderAppData.load(relation.data[relation.app]) - self.on.endpoint_changed.emit(relation, data.host, [i.dict() for i in data.receivers]) # type: ignore + self.on.endpoint_changed.emit( # type: ignore + relation, data.host, data.external_url, [i.dict() for i in data.receivers] + ) def _on_tracing_relation_broken(self, event: RelationBrokenEvent): """Notify the providers that the endpoint is broken.""" @@ -787,28 +796,43 @@ def get_all_endpoints( self, relation: Optional[Relation] = None ) -> Optional[TracingProviderAppData]: """Unmarshalled relation data.""" - if not self.is_ready(relation or self._relation): + relation = relation or self._relation + if not self.is_ready(relation): return return TracingProviderAppData.load(relation.data[relation.app]) # type: ignore def _get_endpoint( - self, relation: Optional[Relation], protocol: ReceiverProtocol, ssl: bool = False - ): - ep = self.get_all_endpoints(relation) - if not ep: + self, relation: Optional[Relation], protocol: ReceiverProtocol + ) -> Optional[str]: + app_data = self.get_all_endpoints(relation) + if not app_data: return None - try: - receiver: Receiver = next(filter(lambda i: i.protocol == protocol, ep.receivers)) - if receiver.protocol in ["otlp_grpc", "jaeger_grpc"]: - if ssl: - logger.warning("unused ssl argument - was the right protocol called?") - return f"{ep.host}:{receiver.port}" - if ssl: - return f"https://{ep.host}:{receiver.port}" - return f"http://{ep.host}:{receiver.port}" - except StopIteration: + receivers: List[Receiver] = list( + filter(lambda i: i.protocol == protocol, app_data.receivers) + ) + if not receivers: logger.error(f"no receiver found with protocol={protocol!r}") - return None + return + if len(receivers) > 1: + logger.error( + f"too many receivers with protocol={protocol!r}; using first one. Found: {receivers}" + ) + return + + receiver = receivers[0] + # if there's an external_url argument (v2.5+), use that. Otherwise, we use the tempo local fqdn + if app_data.external_url: + url = app_data.external_url + else: + # FIXME: if we don't get an external url but only a + # hostname, we don't know what scheme we need to be using. ASSUME HTTP + url = f"http://{app_data.host}:{receiver.port}" + + if receiver.protocol.endswith("grpc"): + # TCP protocols don't want an http/https scheme prefix + url = url.split("://")[1] + + return url def get_endpoint( self, protocol: ReceiverProtocol, relation: Optional[Relation] = None diff --git a/lib/charms/tls_certificates_interface/v2/tls_certificates.py b/lib/charms/tls_certificates_interface/v3/tls_certificates.py similarity index 79% rename from lib/charms/tls_certificates_interface/v2/tls_certificates.py rename to lib/charms/tls_certificates_interface/v3/tls_certificates.py index 9f67833b..cbdd80d1 100644 --- a/lib/charms/tls_certificates_interface/v2/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v3/tls_certificates.py @@ -1,4 +1,4 @@ -# Copyright 2021 Canonical Ltd. +# Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. @@ -7,16 +7,19 @@ This library contains the Requires and Provides classes for handling the tls-certificates interface. +Pre-requisites: + - Juju >= 3.0 + ## Getting Started From a charm directory, fetch the library using `charmcraft`: ```shell -charmcraft fetch-lib charms.tls_certificates_interface.v2.tls_certificates +charmcraft fetch-lib charms.tls_certificates_interface.v3.tls_certificates ``` Add the following libraries to the charm's `requirements.txt` file: - jsonschema -- cryptography +- cryptography >= 42.0.0 Add the following section to the charm's `charmcraft.yaml` file: ```yaml @@ -36,10 +39,10 @@ Example: ```python -from charms.tls_certificates_interface.v2.tls_certificates import ( +from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateCreationRequestEvent, CertificateRevocationRequestEvent, - TLSCertificatesProvidesV2, + TLSCertificatesProvidesV3, generate_private_key, ) from ops.charm import CharmBase, InstallEvent @@ -59,7 +62,7 @@ class ExampleProviderCharm(CharmBase): def __init__(self, *args): super().__init__(*args) - self.certificates = TLSCertificatesProvidesV2(self, "certificates") + self.certificates = TLSCertificatesProvidesV3(self, "certificates") self.framework.observe( self.certificates.on.certificate_request, self._on_certificate_request @@ -126,15 +129,15 @@ def _on_certificate_revocation_request(self, event: CertificateRevocationRequest Example: ```python -from charms.tls_certificates_interface.v2.tls_certificates import ( +from charms.tls_certificates_interface.v3.tls_certificates import ( CertificateAvailableEvent, CertificateExpiringEvent, CertificateRevokedEvent, - TLSCertificatesRequiresV2, + TLSCertificatesRequiresV3, generate_csr, generate_private_key, ) -from ops.charm import CharmBase, RelationJoinedEvent +from ops.charm import CharmBase, RelationCreatedEvent from ops.main import main from ops.model import ActiveStatus, WaitingStatus from typing import Union @@ -145,10 +148,10 @@ class ExampleRequirerCharm(CharmBase): def __init__(self, *args): super().__init__(*args) self.cert_subject = "whatever" - self.certificates = TLSCertificatesRequiresV2(self, "certificates") + self.certificates = TLSCertificatesRequiresV3(self, "certificates") self.framework.observe(self.on.install, self._on_install) self.framework.observe( - self.on.certificates_relation_joined, self._on_certificates_relation_joined + self.on.certificates_relation_created, self._on_certificates_relation_created ) self.framework.observe( self.certificates.on.certificate_available, self._on_certificate_available @@ -176,7 +179,7 @@ def _on_install(self, event) -> None: {"private_key_password": "banana", "private_key": private_key.decode()} ) - def _on_certificates_relation_joined(self, event: RelationJoinedEvent) -> None: + def _on_certificates_relation_created(self, event: RelationCreatedEvent) -> None: replicas_relation = self.model.get_relation("replicas") if not replicas_relation: self.unit.status = WaitingStatus("Waiting for peer relation to be created") @@ -277,15 +280,15 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven import logging import uuid from contextlib import suppress +from dataclasses import dataclass from datetime import datetime, timedelta, timezone from ipaddress import IPv4Address -from typing import Any, Dict, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union from cryptography import x509 from cryptography.hazmat._oid import ExtensionOID from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.serialization import pkcs12 from jsonschema import exceptions, validate from ops.charm import ( CharmBase, @@ -293,21 +296,27 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven RelationBrokenEvent, RelationChangedEvent, SecretExpiredEvent, - UpdateStatusEvent, ) from ops.framework import EventBase, EventSource, Handle, Object from ops.jujuversion import JujuVersion -from ops.model import ModelError, Relation, RelationDataContent, SecretNotFoundError +from ops.model import ( + Application, + ModelError, + Relation, + RelationDataContent, + SecretNotFoundError, + Unit, +) # The unique Charmhub library identifier, never change it LIBID = "afd8c2bccf834997afce12c2706d2ede" # Increment this major API version when introducing breaking changes -LIBAPI = 2 +LIBAPI = 3 # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 28 +LIBPATCH = 10 PYDEPS = ["cryptography", "jsonschema"] @@ -422,6 +431,34 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven logger = logging.getLogger(__name__) +@dataclass +class RequirerCSR: + """This class represents a certificate signing request from an interface Requirer.""" + + relation_id: int + application_name: str + unit_name: str + csr: str + is_ca: bool + + +@dataclass +class ProviderCertificate: + """This class represents a certificate from an interface Provider.""" + + relation_id: int + application_name: str + csr: str + certificate: str + ca: str + chain: List[str] + revoked: bool + + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join(reversed(self.chain)) + + class CertificateAvailableEvent(EventBase): """Charm Event triggered when a TLS certificate is available.""" @@ -455,6 +492,10 @@ def restore(self, snapshot: dict): self.ca = snapshot["ca"] self.chain = snapshot["chain"] + def chain_as_pem(self) -> str: + """Return full certificate chain as a PEM string.""" + return "\n\n".join(reversed(self.chain)) + class CertificateExpiringEvent(EventBase): """Charm Event triggered when a TLS certificate is almost expired.""" @@ -886,38 +927,6 @@ def generate_certificate( return cert.public_bytes(serialization.Encoding.PEM) -def generate_pfx_package( - certificate: bytes, - private_key: bytes, - package_password: str, - private_key_password: Optional[bytes] = None, -) -> bytes: - """Generate a PFX package to contain the TLS certificate and private key. - - Args: - certificate (bytes): TLS certificate - private_key (bytes): Private key - package_password (str): Password to open the PFX package - private_key_password (bytes): Private key password - - Returns: - bytes: - """ - private_key_object = serialization.load_pem_private_key( - private_key, password=private_key_password - ) - certificate_object = x509.load_pem_x509_certificate(certificate) - name = certificate_object.subject.rfc4514_string() - pfx_bytes = pkcs12.serialize_key_and_certificates( - name=name.encode(), - cert=certificate_object, - key=private_key_object, # type: ignore[arg-type] - cas=None, - encryption_algorithm=serialization.BestAvailableEncryption(package_password.encode()), - ) - return pfx_bytes - - def generate_private_key( password: Optional[bytes] = None, key_size: int = 2048, @@ -1053,6 +1062,27 @@ def csr_matches_certificate(csr: str, cert: str) -> bool: return True +def _relation_data_is_valid( + relation: Relation, app_or_unit: Union[Application, Unit], json_schema: dict +) -> bool: + """Check whether relation data is valid based on json schema. + + Args: + relation (Relation): Relation object + app_or_unit (Union[Application, Unit]): Application or unit object + json_schema (dict): Json schema + + Returns: + bool: Whether relation data is valid. + """ + relation_data = _load_relation_data(relation.data[app_or_unit]) + try: + validate(instance=relation_data, schema=json_schema) + return True + except exceptions.ValidationError: + return False + + class CertificatesProviderCharmEvents(CharmEvents): """List of events that the TLS Certificates provider charm can leverage.""" @@ -1069,7 +1099,7 @@ class CertificatesRequirerCharmEvents(CharmEvents): all_certificates_invalidated = EventSource(AllCertificatesInvalidatedEvent) -class TLSCertificatesProvidesV2(Object): +class TLSCertificatesProvidesV3(Object): """TLS certificates provider class to be instantiated by TLS certificates providers.""" on = CertificatesProviderCharmEvents() # type: ignore[reportAssignmentType] @@ -1178,22 +1208,6 @@ def _remove_certificate( certificates.remove(certificate_dict) relation.data[self.model.app]["certificates"] = json.dumps(certificates) - @staticmethod - def _relation_data_is_valid(certificates_data: dict) -> bool: - """Use JSON schema validator to validate relation data content. - - Args: - certificates_data (dict): Certificate data dictionary as retrieved from relation data. - - Returns: - bool: True/False depending on whether the relation data follows the json schema. - """ - try: - validate(instance=certificates_data, schema=REQUIRER_JSON_SCHEMA) - return True - except exceptions.ValidationError: - return False - def revoke_all_certificates(self) -> None: """Revoke all certificates of this provider. @@ -1262,16 +1276,24 @@ def remove_certificate(self, certificate: str) -> None: def get_issued_certificates( self, relation_id: Optional[int] = None - ) -> Dict[str, List[Dict[str, str]]]: - """Return a dictionary of issued certificates. + ) -> List[ProviderCertificate]: + """Return a List of issued (non revoked) certificates. + + Returns: + List: List of ProviderCertificate objects + """ + provider_certificates = self.get_provider_certificates(relation_id=relation_id) + return [certificate for certificate in provider_certificates if not certificate.revoked] - It returns certificates from all relations if relation_id is not specified. - Certificates are returned per application name and CSR. + def get_provider_certificates( + self, relation_id: Optional[int] = None + ) -> List[ProviderCertificate]: + """Return a List of issued certificates. Returns: - dict: Certificates per application name. + List: List of ProviderCertificate objects """ - certificates: Dict[str, List[Dict[str, str]]] = {} + certificates: List[ProviderCertificate] = [] relations = ( [ relation @@ -1282,19 +1304,22 @@ def get_issued_certificates( else self.model.relations.get(self.relationship_name, []) ) for relation in relations: + if not relation.app: + logger.warning("Relation %s does not have an application", relation.id) + continue provider_relation_data = self._load_app_relation_data(relation) provider_certificates = provider_relation_data.get("certificates", []) - - certificates[relation.app.name] = [] # type: ignore[union-attr] for certificate in provider_certificates: - if not certificate.get("revoked", False): - certificates[relation.app.name].append( # type: ignore[union-attr] - { - "csr": certificate["certificate_signing_request"], - "certificate": certificate["certificate"], - } - ) - + provider_certificate = ProviderCertificate( + relation_id=relation.id, + application_name=relation.app.name, + csr=certificate["certificate_signing_request"], + certificate=certificate["certificate"], + ca=certificate["ca"], + chain=certificate["chain"], + revoked=certificate.get("revoked", False), + ) + certificates.append(provider_certificate) return certificates def _on_relation_changed(self, event: RelationChangedEvent) -> None: @@ -1317,124 +1342,77 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: return if not self.model.unit.is_leader(): return - requirer_relation_data = _load_relation_data(event.relation.data[event.unit]) - provider_relation_data = self._load_app_relation_data(event.relation) - if not self._relation_data_is_valid(requirer_relation_data): + if not _relation_data_is_valid(event.relation, event.unit, REQUIRER_JSON_SCHEMA): logger.debug("Relation data did not pass JSON Schema validation") return - provider_certificates = provider_relation_data.get("certificates", []) - requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) + provider_certificates = self.get_provider_certificates(relation_id=event.relation.id) + requirer_csrs = self.get_requirer_csrs(relation_id=event.relation.id) provider_csrs = [ - certificate_creation_request["certificate_signing_request"] + certificate_creation_request.csr for certificate_creation_request in provider_certificates ] - requirer_unit_certificate_requests = [ - { - "csr": certificate_creation_request["certificate_signing_request"], - "is_ca": certificate_creation_request.get("ca", False), - } - for certificate_creation_request in requirer_csrs - ] - for certificate_request in requirer_unit_certificate_requests: - if certificate_request["csr"] not in provider_csrs: + for certificate_request in requirer_csrs: + if certificate_request.csr not in provider_csrs: self.on.certificate_creation_request.emit( - certificate_signing_request=certificate_request["csr"], - relation_id=event.relation.id, - is_ca=certificate_request["is_ca"], + certificate_signing_request=certificate_request.csr, + relation_id=certificate_request.relation_id, + is_ca=certificate_request.is_ca, ) self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None: """Revoke certificates for which no unit has a CSR. - Goes through all generated certificates and compare against the list of CSRs for all units - of a given relationship. - - Args: - relation_id (int): Relation id + Goes through all generated certificates and compare against the list of CSRs for all units. Returns: None """ - certificates_relation = self.model.get_relation( - relation_name=self.relationship_name, relation_id=relation_id - ) - if not certificates_relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") - provider_relation_data = self._load_app_relation_data(certificates_relation) - list_of_csrs: List[str] = [] - for unit in certificates_relation.units: - requirer_relation_data = _load_relation_data(certificates_relation.data[unit]) - requirer_csrs = requirer_relation_data.get("certificate_signing_requests", []) - list_of_csrs.extend(csr["certificate_signing_request"] for csr in requirer_csrs) - provider_certificates = provider_relation_data.get("certificates", []) + provider_certificates = self.get_provider_certificates(relation_id) + requirer_csrs = self.get_requirer_csrs(relation_id) + list_of_csrs = [csr.csr for csr in requirer_csrs] for certificate in provider_certificates: - if certificate["certificate_signing_request"] not in list_of_csrs: + if certificate.csr not in list_of_csrs: self.on.certificate_revocation_request.emit( - certificate=certificate["certificate"], - certificate_signing_request=certificate["certificate_signing_request"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate=certificate.certificate, + certificate_signing_request=certificate.csr, + ca=certificate.ca, + chain=certificate.chain, ) - self.remove_certificate(certificate=certificate["certificate"]) + self.remove_certificate(certificate=certificate.certificate) def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None - ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: + ) -> List[RequirerCSR]: """Return CSR's for which no certificate has been issued. - Example return: [ - { - "relation_id": 0, - "application_name": "tls-certificates-requirer", - "unit_name": "tls-certificates-requirer/0", - "unit_csrs": [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - "is_ca": false - } - ] - } - ] - Args: relation_id (int): Relation id Returns: - list: List of dictionaries that contain the unit's csrs - that don't have a certificate issued. + list: List of RequirerCSR objects. """ - all_unit_csr_mappings = copy.deepcopy(self.get_requirer_csrs(relation_id=relation_id)) - filtered_all_unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] - for unit_csr_mapping in all_unit_csr_mappings: - csrs_without_certs = [] - for csr in unit_csr_mapping["unit_csrs"]: # type: ignore[union-attr] - if not self.certificate_issued_for_csr( - app_name=unit_csr_mapping["application_name"], # type: ignore[arg-type] - csr=csr["certificate_signing_request"], # type: ignore[index] - relation_id=relation_id, - ): - csrs_without_certs.append(csr) - if csrs_without_certs: - unit_csr_mapping["unit_csrs"] = csrs_without_certs # type: ignore[assignment] - filtered_all_unit_csr_mappings.append(unit_csr_mapping) - return filtered_all_unit_csr_mappings - - def get_requirer_csrs( - self, relation_id: Optional[int] = None - ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Return a list of requirers' CSRs grouped by unit. + requirer_csrs = self.get_requirer_csrs(relation_id=relation_id) + outstanding_csrs: List[RequirerCSR] = [] + for relation_csr in requirer_csrs: + if not self.certificate_issued_for_csr( + app_name=relation_csr.application_name, + csr=relation_csr.csr, + relation_id=relation_id, + ): + outstanding_csrs.append(relation_csr) + return outstanding_csrs + + def get_requirer_csrs(self, relation_id: Optional[int] = None) -> List[RequirerCSR]: + """Return a list of requirers' CSRs. It returns CSRs from all relations if relation_id is not specified. CSRs are returned per relation id, application name and unit name. Returns: - list: List of dictionaries that contain the unit's csrs - with the following information - relation_id, application_name and unit_name. + list: List[RequirerCSR] """ - unit_csr_mappings: List[Dict[str, Union[int, str, List[Dict[str, str]]]]] = [] - + relation_csrs: List[RequirerCSR] = [] relations = ( [ relation @@ -1449,15 +1427,24 @@ def get_requirer_csrs( for unit in relation.units: requirer_relation_data = _load_relation_data(relation.data[unit]) unit_csrs_list = requirer_relation_data.get("certificate_signing_requests", []) - unit_csr_mappings.append( - { - "relation_id": relation.id, - "application_name": relation.app.name, # type: ignore[union-attr] - "unit_name": unit.name, - "unit_csrs": unit_csrs_list, - } - ) - return unit_csr_mappings + for unit_csr in unit_csrs_list: + csr = unit_csr.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + ca = unit_csr.get("ca", False) + if not relation.app: + logger.warning("No remote app in relation - Skipping") + continue + relation_csr = RequirerCSR( + relation_id=relation.id, + application_name=relation.app.name, + unit_name=unit.name, + csr=csr, + is_ca=ca, + ) + relation_csrs.append(relation_csr) + return relation_csrs def certificate_issued_for_csr( self, app_name: str, csr: str, relation_id: Optional[int] @@ -1468,19 +1455,18 @@ def certificate_issued_for_csr( app_name (str): Application name that the CSR belongs to. csr (str): Certificate Signing Request. relation_id (Optional[int]): Relation ID + Returns: bool: True/False depending on whether a certificate has been issued for the given CSR. """ - issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id)[ - app_name - ] - for issued_pair in issued_certificates_per_csr: - if "csr" in issued_pair and issued_pair["csr"] == csr: - return csr_matches_certificate(csr, issued_pair["certificate"]) + issued_certificates_per_csr = self.get_issued_certificates(relation_id=relation_id) + for issued_certificate in issued_certificates_per_csr: + if issued_certificate.csr == csr and issued_certificate.application_name == app_name: + return csr_matches_certificate(csr, issued_certificate.certificate) return False -class TLSCertificatesRequiresV2(Object): +class TLSCertificatesRequiresV3(Object): """TLS certificates requirer class to be instantiated by TLS certificates requirers.""" on = CertificatesRequirerCharmEvents() # type: ignore[reportAssignmentType] @@ -1500,6 +1486,8 @@ def __init__( Used to trigger the CertificateExpiring event. Default: 7 days. """ super().__init__(charm, relationship_name) + if not JujuVersion.from_environ().has_secrets: + logger.warning("This version of the TLS library requires Juju secrets (Juju >= 3.0)") self.relationship_name = relationship_name self.charm = charm self.expiry_notification_time = expiry_notification_time @@ -1509,32 +1497,39 @@ def __init__( self.framework.observe( charm.on[relationship_name].relation_broken, self._on_relation_broken ) - if JujuVersion.from_environ().has_secrets: - self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - else: - self.framework.observe(charm.on.update_status, self._on_update_status) + self.framework.observe(charm.on.secret_expired, self._on_secret_expired) - @property - def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: + def get_requirer_csrs(self) -> List[RequirerCSR]: """Return list of requirer's CSRs from relation unit data. - Example: - [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - "ca": false - } - ] + Returns: + list: List of RequirerCSR objects. """ relation = self.model.get_relation(self.relationship_name) if not relation: - raise RuntimeError(f"Relation {self.relationship_name} does not exist") + return [] + requirer_csrs = [] requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) - return requirer_relation_data.get("certificate_signing_requests", []) + requirer_csrs_dict = requirer_relation_data.get("certificate_signing_requests", []) + for requirer_csr_dict in requirer_csrs_dict: + csr = requirer_csr_dict.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + ca = requirer_csr_dict.get("ca", False) + relation_csr = RequirerCSR( + relation_id=relation.id, + application_name=self.model.app.name, + unit_name=self.model.unit.name, + csr=csr, + is_ca=ca, + ) + requirer_csrs.append(relation_csr) + return requirer_csrs - @property - def _provider_certificates(self) -> List[Dict[str, str]]: + def get_provider_certificates(self) -> List[ProviderCertificate]: """Return list of certificates from the provider's relation data.""" + provider_certificates: List[ProviderCertificate] = [] relation = self.model.get_relation(self.relationship_name) if not relation: logger.debug("No relation: %s", self.relationship_name) @@ -1543,12 +1538,32 @@ def _provider_certificates(self) -> List[Dict[str, str]]: logger.debug("No remote app in relation: %s", self.relationship_name) return [] provider_relation_data = _load_relation_data(relation.data[relation.app]) - if not self._relation_data_is_valid(provider_relation_data): - logger.warning("Provider relation data did not pass JSON Schema validation") - return [] - return provider_relation_data.get("certificates", []) + provider_certificate_dicts = provider_relation_data.get("certificates", []) + for provider_certificate_dict in provider_certificate_dicts: + certificate = provider_certificate_dict.get("certificate") + if not certificate: + logger.warning("No certificate found in relation data - Skipping") + continue + ca = provider_certificate_dict.get("ca") + chain = provider_certificate_dict.get("chain", []) + csr = provider_certificate_dict.get("certificate_signing_request") + if not csr: + logger.warning("No CSR found in relation data - Skipping") + continue + revoked = provider_certificate_dict.get("revoked", False) + provider_certificate = ProviderCertificate( + relation_id=relation.id, + application_name=relation.app.name, + csr=csr, + certificate=certificate, + ca=ca, + chain=chain, + revoked=revoked, + ) + provider_certificates.append(provider_certificate) + return provider_certificates - def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: + def _add_requirer_csr_to_relation_data(self, csr: str, is_ca: bool) -> None: """Add CSR to relation data. Args: @@ -1564,18 +1579,23 @@ def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - new_csr_dict: Dict[str, Union[bool, str]] = { + for requirer_csr in self.get_requirer_csrs(): + if requirer_csr.csr == csr and requirer_csr.is_ca == is_ca: + logger.info("CSR already in relation data - Doing nothing") + return + new_csr_dict = { "certificate_signing_request": csr, "ca": is_ca, } - if new_csr_dict in self._requirer_csrs: - logger.info("CSR already in relation data - Doing nothing") - return - requirer_csrs = copy.deepcopy(self._requirer_csrs) - requirer_csrs.append(new_csr_dict) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) + requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) + existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) + new_relation_data = copy.deepcopy(existing_relation_data) + new_relation_data.append(new_csr_dict) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( + new_relation_data + ) - def _remove_requirer_csr(self, csr: str) -> None: + def _remove_requirer_csr_from_relation_data(self, csr: str) -> None: """Remove CSR from relation data. Args: @@ -1590,14 +1610,18 @@ def _remove_requirer_csr(self, csr: str) -> None: f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - requirer_csrs = copy.deepcopy(self._requirer_csrs) - if not requirer_csrs: + if not self.get_requirer_csrs(): logger.info("No CSRs in relation data - Doing nothing") return - for requirer_csr in requirer_csrs: + requirer_relation_data = _load_relation_data(relation.data[self.model.unit]) + existing_relation_data = requirer_relation_data.get("certificate_signing_requests", []) + new_relation_data = copy.deepcopy(existing_relation_data) + for requirer_csr in new_relation_data: if requirer_csr["certificate_signing_request"] == csr: - requirer_csrs.remove(requirer_csr) - relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) + new_relation_data.remove(requirer_csr) + relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps( + new_relation_data + ) def request_certificate_creation( self, certificate_signing_request: bytes, is_ca: bool = False @@ -1617,7 +1641,9 @@ def request_certificate_creation( f"Relation {self.relationship_name} does not exist - " f"The certificate request can't be completed" ) - self._add_requirer_csr(certificate_signing_request.decode().strip(), is_ca=is_ca) + self._add_requirer_csr_to_relation_data( + certificate_signing_request.decode().strip(), is_ca=is_ca + ) logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: @@ -1633,7 +1659,7 @@ def request_certificate_revocation(self, certificate_signing_request: bytes) -> Returns: None """ - self._remove_requirer_csr(certificate_signing_request.decode().strip()) + self._remove_requirer_csr_from_relation_data(certificate_signing_request.decode().strip()) logger.info("Certificate revocation sent to provider") def request_certificate_renewal( @@ -1661,107 +1687,62 @@ def request_certificate_renewal( ) logger.info("Certificate renewal request completed.") - def get_assigned_certificates(self) -> List[Dict[str, str]]: + def get_assigned_certificates(self) -> List[ProviderCertificate]: """Get a list of certificates that were assigned to this unit. Returns: - List of certificates. For example: - [ - { - "ca": "-----BEGIN CERTIFICATE-----...", - "chain": [ - "-----BEGIN CERTIFICATE-----..." - ], - "certificate": "-----BEGIN CERTIFICATE-----...", - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - } - ] + List: List[ProviderCertificate] """ - final_list = [] - for csr in self.get_certificate_signing_requests(fulfilled_only=True): - assert isinstance(csr["certificate_signing_request"], str) - if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): - final_list.append(cert) - return final_list - - def get_expiring_certificates(self) -> List[Dict[str, str]]: + assigned_certificates = [] + for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): + if cert := self._find_certificate_in_relation_data(requirer_csr.csr): + assigned_certificates.append(cert) + return assigned_certificates + + def get_expiring_certificates(self) -> List[ProviderCertificate]: """Get a list of certificates that were assigned to this unit that are expiring or expired. Returns: - List of certificates. For example: - [ - { - "ca": "-----BEGIN CERTIFICATE-----...", - "chain": [ - "-----BEGIN CERTIFICATE-----..." - ], - "certificate": "-----BEGIN CERTIFICATE-----...", - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - } - ] + List: List[ProviderCertificate] """ - final_list = [] - for csr in self.get_certificate_signing_requests(fulfilled_only=True): - assert isinstance(csr["certificate_signing_request"], str) - if cert := self._find_certificate_in_relation_data(csr["certificate_signing_request"]): - expiry_time = _get_certificate_expiry_time(cert["certificate"]) + expiring_certificates: List[ProviderCertificate] = [] + for requirer_csr in self.get_certificate_signing_requests(fulfilled_only=True): + if cert := self._find_certificate_in_relation_data(requirer_csr.csr): + expiry_time = _get_certificate_expiry_time(cert.certificate) if not expiry_time: continue expiry_notification_time = expiry_time - timedelta( hours=self.expiry_notification_time ) if datetime.now(timezone.utc) > expiry_notification_time: - final_list.append(cert) - return final_list + expiring_certificates.append(cert) + return expiring_certificates def get_certificate_signing_requests( self, fulfilled_only: bool = False, unfulfilled_only: bool = False, - ) -> List[Dict[str, Union[bool, str]]]: + ) -> List[RequirerCSR]: """Get the list of CSR's that were sent to the provider. You can choose to get only the CSR's that have a certificate assigned or only the CSR's - that don't. + that don't. Args: fulfilled_only (bool): This option will discard CSRs that don't have certificates yet. unfulfilled_only (bool): This option will discard CSRs that have certificates signed. Returns: - List of CSR dictionaries. For example: - [ - { - "certificate_signing_request": "-----BEGIN CERTIFICATE REQUEST-----...", - "ca": false - } - ] + List of RequirerCSR objects. """ - final_list = [] - for csr in self._requirer_csrs: - assert isinstance(csr["certificate_signing_request"], str) - cert = self._find_certificate_in_relation_data(csr["certificate_signing_request"]) + csrs = [] + for requirer_csr in self.get_requirer_csrs(): + cert = self._find_certificate_in_relation_data(requirer_csr.csr) if (unfulfilled_only and cert) or (fulfilled_only and not cert): continue - final_list.append(csr) - - return final_list + csrs.append(requirer_csr) - @staticmethod - def _relation_data_is_valid(certificates_data: dict) -> bool: - """Check whether relation data is valid based on json schema. - - Args: - certificates_data: Certificate data in dict format. - - Returns: - bool: Whether relation data is valid. - """ - try: - validate(instance=certificates_data, schema=PROVIDER_JSON_SCHEMA) - return True - except exceptions.ValidationError: - return False + return csrs def _on_relation_changed(self, event: RelationChangedEvent) -> None: """Handle relation changed event. @@ -1771,9 +1752,8 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: If the provider certificate is revoked, emit a CertificateInvalidateEvent, otherwise emit a CertificateAvailableEvent. - When Juju secrets are available, remove the secret for revoked certificate, - or add a secret with the correct expiry time for new certificates. - + Remove the secret for revoked certificate, or add a secret with the correct expiry + time for new certificates. Args: event: Juju event @@ -1781,51 +1761,48 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: Returns: None """ + if not event.app: + logger.warning("No remote app in relation - Skipping") + return + if not _relation_data_is_valid(event.relation, event.app, PROVIDER_JSON_SCHEMA): + logger.debug("Relation data did not pass JSON Schema validation") + return + provider_certificates = self.get_provider_certificates() requirer_csrs = [ - certificate_creation_request["certificate_signing_request"] - for certificate_creation_request in self._requirer_csrs + certificate_creation_request.csr + for certificate_creation_request in self.get_requirer_csrs() ] - for certificate in self._provider_certificates: - if certificate["certificate_signing_request"] in requirer_csrs: - if certificate.get("revoked", False): - if JujuVersion.from_environ().has_secrets: - with suppress(SecretNotFoundError): - secret = self.model.get_secret( - label=f"{LIBID}-{certificate['certificate_signing_request']}" - ) - secret.remove_all_revisions() + for certificate in provider_certificates: + if certificate.csr in requirer_csrs: + if certificate.revoked: + with suppress(SecretNotFoundError): + secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") + secret.remove_all_revisions() self.on.certificate_invalidated.emit( reason="revoked", - certificate=certificate["certificate"], - certificate_signing_request=certificate["certificate_signing_request"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate=certificate.certificate, + certificate_signing_request=certificate.csr, + ca=certificate.ca, + chain=certificate.chain, ) else: - if JujuVersion.from_environ().has_secrets: - try: - secret = self.model.get_secret( - label=f"{LIBID}-{certificate['certificate_signing_request']}" - ) - secret.set_content({"certificate": certificate["certificate"]}) - secret.set_info( - expire=self._get_next_secret_expiry_time( - certificate["certificate"] - ), - ) - except SecretNotFoundError: - secret = self.charm.unit.add_secret( - {"certificate": certificate["certificate"]}, - label=f"{LIBID}-{certificate['certificate_signing_request']}", - expire=self._get_next_secret_expiry_time( - certificate["certificate"] - ), - ) + try: + secret = self.model.get_secret(label=f"{LIBID}-{certificate.csr}") + secret.set_content({"certificate": certificate.certificate}) + secret.set_info( + expire=self._get_next_secret_expiry_time(certificate.certificate), + ) + except SecretNotFoundError: + secret = self.charm.unit.add_secret( + {"certificate": certificate.certificate}, + label=f"{LIBID}-{certificate.csr}", + expire=self._get_next_secret_expiry_time(certificate.certificate), + ) self.on.certificate_available.emit( - certificate_signing_request=certificate["certificate_signing_request"], - certificate=certificate["certificate"], - ca=certificate["ca"], - chain=certificate["chain"], + certificate_signing_request=certificate.csr, + certificate=certificate.certificate, + ca=certificate.ca, + chain=certificate.chain, ) def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: @@ -1849,7 +1826,7 @@ def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: return _get_closest_future_time(expiry_notification_time, expiry_time) def _on_relation_broken(self, event: RelationBrokenEvent) -> None: - """Handle relation broken event. + """Handle Relation Broken Event. Emitting `all_certificates_invalidated` from `relation-broken` rather than `relation-departed` since certs are stored in app data. @@ -1863,7 +1840,7 @@ def _on_relation_broken(self, event: RelationBrokenEvent) -> None: self.on.all_certificates_invalidated.emit() def _on_secret_expired(self, event: SecretExpiredEvent) -> None: - """Handle secret expired event. + """Handle Secret Expired Event. Loads the certificate from the secret, and will emit 1 of 2 events. @@ -1881,13 +1858,13 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: if not event.secret.label or not event.secret.label.startswith(f"{LIBID}-"): return csr = event.secret.label[len(f"{LIBID}-") :] - certificate_dict = self._find_certificate_in_relation_data(csr) - if not certificate_dict: + provider_certificate = self._find_certificate_in_relation_data(csr) + if not provider_certificate: # A secret expired but we did not find matching certificate. Cleaning up event.secret.remove_all_revisions() return - expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) + expiry_time = _get_certificate_expiry_time(provider_certificate.certificate) if not expiry_time: # A secret expired but matching certificate is invalid. Cleaning up event.secret.remove_all_revisions() @@ -1896,64 +1873,28 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: if datetime.now(timezone.utc) < expiry_time: logger.warning("Certificate almost expired") self.on.certificate_expiring.emit( - certificate=certificate_dict["certificate"], + certificate=provider_certificate.certificate, expiry=expiry_time.isoformat(), ) event.secret.set_info( - expire=_get_certificate_expiry_time(certificate_dict["certificate"]), + expire=_get_certificate_expiry_time(provider_certificate.certificate), ) else: logger.warning("Certificate is expired") self.on.certificate_invalidated.emit( reason="expired", - certificate=certificate_dict["certificate"], - certificate_signing_request=certificate_dict["certificate_signing_request"], - ca=certificate_dict["ca"], - chain=certificate_dict["chain"], + certificate=provider_certificate.certificate, + certificate_signing_request=provider_certificate.csr, + ca=provider_certificate.ca, + chain=provider_certificate.chain, ) - self.request_certificate_revocation(certificate_dict["certificate"].encode()) + self.request_certificate_revocation(provider_certificate.certificate.encode()) event.secret.remove_all_revisions() - def _find_certificate_in_relation_data(self, csr: str) -> Optional[Dict[str, Any]]: + def _find_certificate_in_relation_data(self, csr: str) -> Optional[ProviderCertificate]: """Return the certificate that match the given CSR.""" - for certificate_dict in self._provider_certificates: - if certificate_dict["certificate_signing_request"] != csr: + for provider_certificate in self.get_provider_certificates(): + if provider_certificate.csr != csr: continue - return certificate_dict + return provider_certificate return None - - def _on_update_status(self, event: UpdateStatusEvent) -> None: - """Handle update status event. - - Goes through each certificate in the "certificates" relation and checks their expiry date. - If they are close to expire (<7 days), emits a CertificateExpiringEvent event and if - they are expired, emits a CertificateExpiredEvent. - - Args: - event (UpdateStatusEvent): Juju event - - Returns: - None - """ - for certificate_dict in self._provider_certificates: - expiry_time = _get_certificate_expiry_time(certificate_dict["certificate"]) - if not expiry_time: - continue - time_difference = expiry_time - datetime.now(timezone.utc) - if time_difference.total_seconds() < 0: - logger.warning("Certificate is expired") - self.on.certificate_invalidated.emit( - reason="expired", - certificate=certificate_dict["certificate"], - certificate_signing_request=certificate_dict["certificate_signing_request"], - ca=certificate_dict["ca"], - chain=certificate_dict["chain"], - ) - self.request_certificate_revocation(certificate_dict["certificate"].encode()) - continue - if time_difference.total_seconds() < (self.expiry_notification_time * 60 * 60): - logger.warning("Certificate almost expired") - self.on.certificate_expiring.emit( - certificate=certificate_dict["certificate"], - expiry=expiry_time.isoformat(), - ) diff --git a/src/charm.py b/src/charm.py index e0c6465b..a6c8796c 100755 --- a/src/charm.py +++ b/src/charm.py @@ -848,7 +848,13 @@ def _on_pebble_ready(self, event) -> None: self.dashboard_consumer.update_dashboards() self._update_dashboards(event) - # In case of a restart caused by an error, we collect all trusted certs from relation receive-ca-cert + # Create provisioning subfolders to avoid errors on startup + workload = self.containers["workload"] + for d in ("plugins", "notifiers", "alerting"): + workload.make_dir(Path(PROVISIONING_PATH) / d, make_parents=True) + + # In case of a restart caused by an error, we collect all trusted certs from relation + # receive-ca-cert self._update_trusted_ca_certs() version = self.grafana_version @@ -992,7 +998,13 @@ def _build_layer(self) -> Layer: } ) - if self._cert_ready(): + # For consistency, set cert entries on the same condition as scheme is set to https. + # NOTE: On one hand, we want to tell grafana to use TLS as soon as the tls relation is in + # place; on the other hand, the certs may not be written to disk yet (they need to be + # returned over relation data, go to peer data, and eventually be written to disk). When + # grafana is restarted in HTTPS mode but without certs in place, we'll see a brief error: + # "error: cert_file cannot be empty when using HTTPS". + if self._scheme == "https": extra_info.update( { "GF_SERVER_CERT_KEY": GRAFANA_KEY_PATH, diff --git a/tests/integration/test_tls_web.py b/tests/integration/test_tls_web.py index c469a0aa..064b84b1 100644 --- a/tests/integration/test_tls_web.py +++ b/tests/integration/test_tls_web.py @@ -22,6 +22,7 @@ } +@pytest.mark.abort_on_fail async def test_deploy(ops_test, grafana_charm): await asyncio.gather( ops_test.model.deploy(