diff --git a/lib/charms/data_platform_libs/v0/data_interfaces.py b/lib/charms/data_platform_libs/v0/data_interfaces.py index 7fff3c4751..14c7cadca4 100644 --- a/lib/charms/data_platform_libs/v0/data_interfaces.py +++ b/lib/charms/data_platform_libs/v0/data_interfaces.py @@ -331,7 +331,7 @@ def _on_topic_requested(self, event: TopicRequestedEvent): # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 42 +LIBPATCH = 45 PYDEPS = ["ops>=2.0.0"] @@ -351,6 +351,7 @@ def _on_topic_requested(self, event: TopicRequestedEvent): PROV_SECRET_PREFIX = "secret-" +PROV_SECRET_FIELDS = "provided-secrets" REQ_SECRET_FIELDS = "requested-secrets" GROUP_MAPPING_FIELD = "secret_group_mapping" GROUP_SEPARATOR = "@" @@ -585,6 +586,7 @@ class SecretGroupsAggregate(str): def __init__(self): self.USER = SecretGroup("user") self.TLS = SecretGroup("tls") + self.MTLS = SecretGroup("mtls") self.EXTRA = SecretGroup("extra") def __setattr__(self, name, value): @@ -963,8 +965,11 @@ class Data(ABC): "read-only-uris": SECRET_GROUPS.USER, "tls": SECRET_GROUPS.TLS, "tls-ca": SECRET_GROUPS.TLS, + "mtls-cert": SECRET_GROUPS.MTLS, } + SECRET_FIELDS = [] + def __init__( self, model: Model, @@ -978,6 +983,8 @@ def __init__( self.component = self.local_app if self.SCOPE == Scope.APP else self.local_unit self.secrets = SecretCache(self._model, self.component) self.data_component = None + self._local_secret_fields = [] + self._remote_secret_fields = list(self.SECRET_FIELDS) @property def relations(self) -> List[Relation]: @@ -1000,38 +1007,250 @@ def secret_label_map(self): """Exposing secret-label map via a property -- could be overridden in descendants!""" return self.SECRET_LABEL_MAP + @property + def local_secret_fields(self) -> Optional[List[str]]: + """Local access to secrets field, in case they are being used.""" + if self.secrets_enabled: + return self._local_secret_fields + + @property + def remote_secret_fields(self) -> Optional[List[str]]: + """Local access to secrets field, in case they are being used.""" + if self.secrets_enabled: + return self._remote_secret_fields + + @property + def my_secret_groups(self) -> Optional[List[SecretGroup]]: + """Local access to secrets field, in case they are being used.""" + if self.secrets_enabled: + return [ + self.SECRET_LABEL_MAP[field] + for field in self._local_secret_fields + if field in self.SECRET_LABEL_MAP + ] + # Mandatory overrides for internal/helper methods - @abstractmethod + @juju_secrets_only def _get_relation_secret( self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None ) -> Optional[CachedSecret]: """Retrieve a Juju Secret that's been stored in the relation databag.""" - raise NotImplementedError + if not relation_name: + relation_name = self.relation_name + + label = self._generate_secret_label(relation_name, relation_id, group_mapping) + if secret := self.secrets.get(label): + return secret + + relation = self._model.get_relation(relation_name, relation_id) + if not relation: + return + + if secret_uri := self.get_secret_uri(relation, group_mapping): + return self.secrets.get(label, secret_uri) + # Mandatory overrides for requirer and peer, implemented for Provider + # Requirer uses local component and switched keys + # _local_secret_fields -> PROV_SECRET_FIELDS + # _remote_secret_fields -> REQ_SECRET_FIELDS + # provider uses remote component and + # _local_secret_fields -> REQ_SECRET_FIELDS + # _remote_secret_fields -> PROV_SECRET_FIELDS @abstractmethod + def _load_secrets_from_databag(self, relation: Relation) -> None: + """Load secrets from the databag.""" + raise NotImplementedError + def _fetch_specific_relation_data( self, relation: Relation, fields: Optional[List[str]] ) -> Dict[str, str]: - """Fetch data available (directily or indirectly -- i.e. secrets) from the relation.""" - raise NotImplementedError + """Fetch data available (directily or indirectly -- i.e. secrets) from the relation (remote app data).""" + if not relation.app: + return {} + self._load_secrets_from_databag(relation) + return self._fetch_relation_data_with_secrets( + relation.app, self.remote_secret_fields, relation, fields + ) - @abstractmethod def _fetch_my_specific_relation_data( self, relation: Relation, fields: Optional[List[str]] - ) -> Dict[str, str]: - """Fetch data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" - raise NotImplementedError + ) -> dict: + """Fetch our own relation data.""" + # load secrets + self._load_secrets_from_databag(relation) + return self._fetch_relation_data_with_secrets( + self.local_app, + self.local_secret_fields, + relation, + fields, + ) - @abstractmethod def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: - """Update data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" - raise NotImplementedError + """Set values for fields not caring whether it's a secret or not.""" + self._load_secrets_from_databag(relation) + + _, normal_fields = self._process_secret_fields( + relation, + self.local_secret_fields, + list(data), + self._add_or_update_relation_secrets, + data=data, + ) + + normal_content = {k: v for k, v in data.items() if k in normal_fields} + self._update_relation_data_without_secrets(self.local_app, relation, normal_content) + + def _add_or_update_relation_secrets( + self, + relation: Relation, + group: SecretGroup, + secret_fields: Set[str], + data: Dict[str, str], + uri_to_databag=True, + ) -> bool: + """Update contents for Secret group. If the Secret doesn't exist, create it.""" + if self._get_relation_secret(relation.id, group): + return self._update_relation_secret(relation, group, secret_fields, data) + + return self._add_relation_secret(relation, group, secret_fields, data, uri_to_databag) + + @juju_secrets_only + def _add_relation_secret( + self, + relation: Relation, + group_mapping: SecretGroup, + secret_fields: Set[str], + data: Dict[str, str], + uri_to_databag=True, + ) -> bool: + """Add a new Juju Secret that will be registered in the relation databag.""" + if uri_to_databag and self.get_secret_uri(relation, group_mapping): + logging.error("Secret for relation %s already exists, not adding again", relation.id) + return False + + content = self._content_for_secret_group(data, secret_fields, group_mapping) + + label = self._generate_secret_label(self.relation_name, relation.id, group_mapping) + secret = self.secrets.add(label, content, relation) + + if uri_to_databag: + # According to lint we may not have a Secret ID + if not secret.meta or not secret.meta.id: + logging.error("Secret is missing Secret ID") + raise SecretError("Secret added but is missing Secret ID") + + self.set_secret_uri(relation, group_mapping, secret.meta.id) + + # Return the content that was added + return True + + @juju_secrets_only + def _update_relation_secret( + self, + relation: Relation, + group_mapping: SecretGroup, + secret_fields: Set[str], + data: Dict[str, str], + ) -> bool: + """Update the contents of an existing Juju Secret, referred in the relation databag.""" + secret = self._get_relation_secret(relation.id, group_mapping) + + if not secret: + logging.error("Can't update secret for relation %s", relation.id) + return False + + content = self._content_for_secret_group(data, secret_fields, group_mapping) + + old_content = secret.get_content() + full_content = copy.deepcopy(old_content) + full_content.update(content) + secret.set_content(full_content) + + # Return True on success + return True + + @juju_secrets_only + def _delete_relation_secret( + self, relation: Relation, group: SecretGroup, secret_fields: List[str], fields: List[str] + ) -> bool: + """Update the contents of an existing Juju Secret, referred in the relation databag.""" + secret = self._get_relation_secret(relation.id, group) + + if not secret: + logging.error("Can't delete secret for relation %s", str(relation.id)) + return False + + old_content = secret.get_content() + new_content = copy.deepcopy(old_content) + for field in fields: + try: + new_content.pop(field) + except KeyError: + logging.debug( + "Non-existing secret was attempted to be removed %s, %s", + str(relation.id), + str(field), + ) + return False + + # Remove secret from the relation if it's fully gone + if not new_content: + field = self._generate_secret_field_name(group) + try: + relation.data[self.component].pop(field) + except KeyError: + pass + label = self._generate_secret_label(self.relation_name, relation.id, group) + self.secrets.remove(label) + else: + secret.set_content(new_content) + + # Return the content that was removed + return True - @abstractmethod def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: """Delete data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" - raise NotImplementedError + if relation.app: + self._load_secrets_from_databag(relation) + + _, normal_fields = self._process_secret_fields( + relation, self.local_secret_fields, fields, self._delete_relation_secret, fields=fields + ) + self._delete_relation_data_without_secrets(self.local_app, relation, list(normal_fields)) + + def _register_secret_to_relation( + self, relation_name: str, relation_id: int, secret_id: str, group: SecretGroup + ): + """Fetch secrets and apply local label on them. + + [MAGIC HERE] + If we fetch a secret using get_secret(id=, label=), + then will be "stuck" on the Secret object, whenever it may + appear (i.e. as an event attribute, or fetched manually) on future occasions. + + This will allow us to uniquely identify the secret on Provider side (typically on + 'secret-changed' events), and map it to the corresponding relation. + """ + label = self._generate_secret_label(relation_name, relation_id, group) + + # Fetching the Secret's meta information ensuring that it's locally getting registered with + CachedSecret(self._model, self.component, label, secret_id).meta + + def _register_secrets_to_relation(self, relation: Relation, params_name_list: List[str]): + """Make sure that secrets of the provided list are locally 'registered' from the databag. + + More on 'locally registered' magic is described in _register_secret_to_relation() method + """ + if not relation.app: + return + + for group in SECRET_GROUPS.groups(): + secret_field = self._generate_secret_field_name(group) + if secret_field in params_name_list and ( + secret_uri := self.get_secret_uri(relation, group) + ): + self._register_secret_to_relation(relation.name, relation.id, secret_uri, group) # Optional overrides @@ -1178,7 +1397,6 @@ def _process_secret_fields( and (self.local_unit == self._model.unit and self.local_unit.is_leader()) and set(req_secret_fields) & set(relation.data[self.component]) ) - normal_fields = set(impacted_rel_fields) if req_secret_fields and self.secrets_enabled and not fallback_to_databag: normal_fields = normal_fields - set(req_secret_fields) @@ -1305,7 +1523,14 @@ def get_relation(self, relation_name, relation_id) -> Relation: def get_secret_uri(self, relation: Relation, group: SecretGroup) -> Optional[str]: """Get the secret URI for the corresponding group.""" secret_field = self._generate_secret_field_name(group) - return relation.data[self.component].get(secret_field) + # if the secret is not managed by this component, + # we need to fetch it from the other side + + # Fix for the linter + if self.my_secret_groups is None: + raise DataInterfacesError("Secrets are not enabled for this component") + component = self.component if group in self.my_secret_groups else relation.app + return relation.data[component].get(secret_field) def set_secret_uri(self, relation: Relation, group: SecretGroup, secret_uri: str) -> None: """Set the secret URI for the corresponding group.""" @@ -1434,6 +1659,32 @@ def __init__(self, charm: CharmBase, relation_data: Data, unique_key: str = ""): self._on_relation_changed_event, ) + self.framework.observe( + self.charm.on[relation_data.relation_name].relation_created, + self._on_relation_created_event, + ) + + self.framework.observe( + charm.on.secret_changed, + self._on_secret_changed_event, + ) + + # Event handlers + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the relation is created.""" + pass + + @abstractmethod + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation data has changed.""" + raise NotImplementedError + + @abstractmethod + def _on_secret_changed_event(self, event: SecretChangedEvent) -> None: + """Event emitted when the relation data has changed.""" + raise NotImplementedError + def _diff(self, event: RelationChangedEvent) -> Diff: """Retrieves the diff of the data in the relation changed databag. @@ -1446,11 +1697,6 @@ def _diff(self, event: RelationChangedEvent) -> Diff: """ return diff(event, self.relation_data.data_component) - @abstractmethod - def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: - """Event emitted when the relation data has changed.""" - raise NotImplementedError - # Base ProviderData and RequiresData @@ -1467,165 +1713,11 @@ def __init__( ) -> None: super().__init__(model, relation_name) self.data_component = self.local_app - - # Private methods handling secrets - - @juju_secrets_only - def _add_relation_secret( - self, - relation: Relation, - group_mapping: SecretGroup, - secret_fields: Set[str], - data: Dict[str, str], - uri_to_databag=True, - ) -> bool: - """Add a new Juju Secret that will be registered in the relation databag.""" - if uri_to_databag and self.get_secret_uri(relation, group_mapping): - logging.error("Secret for relation %s already exists, not adding again", relation.id) - return False - - content = self._content_for_secret_group(data, secret_fields, group_mapping) - - label = self._generate_secret_label(self.relation_name, relation.id, group_mapping) - secret = self.secrets.add(label, content, relation) - - # According to lint we may not have a Secret ID - if uri_to_databag and secret.meta and secret.meta.id: - self.set_secret_uri(relation, group_mapping, secret.meta.id) - - # Return the content that was added - return True - - @juju_secrets_only - def _update_relation_secret( - self, - relation: Relation, - group_mapping: SecretGroup, - secret_fields: Set[str], - data: Dict[str, str], - ) -> bool: - """Update the contents of an existing Juju Secret, referred in the relation databag.""" - secret = self._get_relation_secret(relation.id, group_mapping) - - if not secret: - logging.error("Can't update secret for relation %s", relation.id) - return False - - content = self._content_for_secret_group(data, secret_fields, group_mapping) - - old_content = secret.get_content() - full_content = copy.deepcopy(old_content) - full_content.update(content) - secret.set_content(full_content) - - # Return True on success - return True - - def _add_or_update_relation_secrets( - self, - relation: Relation, - group: SecretGroup, - secret_fields: Set[str], - data: Dict[str, str], - uri_to_databag=True, - ) -> bool: - """Update contents for Secret group. If the Secret doesn't exist, create it.""" - if self._get_relation_secret(relation.id, group): - return self._update_relation_secret(relation, group, secret_fields, data) - else: - return self._add_relation_secret(relation, group, secret_fields, data, uri_to_databag) - - @juju_secrets_only - def _delete_relation_secret( - self, relation: Relation, group: SecretGroup, secret_fields: List[str], fields: List[str] - ) -> bool: - """Update the contents of an existing Juju Secret, referred in the relation databag.""" - secret = self._get_relation_secret(relation.id, group) - - if not secret: - logging.error("Can't delete secret for relation %s", str(relation.id)) - return False - - old_content = secret.get_content() - new_content = copy.deepcopy(old_content) - for field in fields: - try: - new_content.pop(field) - except KeyError: - logging.debug( - "Non-existing secret was attempted to be removed %s, %s", - str(relation.id), - str(field), - ) - return False - - # Remove secret from the relation if it's fully gone - if not new_content: - field = self._generate_secret_field_name(group) - try: - relation.data[self.component].pop(field) - except KeyError: - pass - label = self._generate_secret_label(self.relation_name, relation.id, group) - self.secrets.remove(label) - else: - secret.set_content(new_content) - - # Return the content that was removed - return True - - # Mandatory internal overrides - - @juju_secrets_only - def _get_relation_secret( - self, relation_id: int, group_mapping: SecretGroup, relation_name: Optional[str] = None - ) -> Optional[CachedSecret]: - """Retrieve a Juju Secret that's been stored in the relation databag.""" - if not relation_name: - relation_name = self.relation_name - - label = self._generate_secret_label(relation_name, relation_id, group_mapping) - if secret := self.secrets.get(label): - return secret - - relation = self._model.get_relation(relation_name, relation_id) - if not relation: - return - - if secret_uri := self.get_secret_uri(relation, group_mapping): - return self.secrets.get(label, secret_uri) - - def _fetch_specific_relation_data( - self, relation: Relation, fields: Optional[List[str]] - ) -> Dict[str, str]: - """Fetching relation data for Provider. - - NOTE: Since all secret fields are in the Provider side of the databag, we don't need to worry about that - """ - if not relation.app: - return {} - - return self._fetch_relation_data_without_secrets(relation.app, relation, fields) - - def _fetch_my_specific_relation_data( - self, relation: Relation, fields: Optional[List[str]] - ) -> dict: - """Fetching our own relation data.""" - secret_fields = None - if relation.app: - secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) - - return self._fetch_relation_data_with_secrets( - self.local_app, - secret_fields, - relation, - fields, - ) + self._local_secret_fields = [] + self._remote_secret_fields = list(self.SECRET_FIELDS) def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: """Set values for fields not caring whether it's a secret or not.""" - req_secret_fields = [] - keys = set(data.keys()) if self.fetch_relation_field(relation.id, self.RESOURCE_FIELD) is None and ( keys - {"endpoints", "read-only-endpoints", "replset"} @@ -1633,31 +1725,7 @@ def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> Non raise PrematureDataAccessError( "Premature access to relation data, update is forbidden before the connection is initialized." ) - - if relation.app: - req_secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) - - _, normal_fields = self._process_secret_fields( - relation, - req_secret_fields, - list(data), - self._add_or_update_relation_secrets, - data=data, - ) - - normal_content = {k: v for k, v in data.items() if k in normal_fields} - self._update_relation_data_without_secrets(self.local_app, relation, normal_content) - - def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: - """Delete fields from the Relation not caring whether it's a secret or not.""" - req_secret_fields = [] - if relation.app: - req_secret_fields = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) - - _, normal_fields = self._process_secret_fields( - relation, req_secret_fields, fields, self._delete_relation_secret, fields=fields - ) - self._delete_relation_data_without_secrets(self.local_app, relation, list(normal_fields)) + super()._update_relation_data(relation, data) # Public methods - "native" @@ -1697,6 +1765,16 @@ def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: fetch_my_relation_data = leader_only(Data.fetch_my_relation_data) fetch_my_relation_field = leader_only(Data.fetch_my_relation_field) + def _load_secrets_from_databag(self, relation: Relation) -> None: + """Load secrets from the databag.""" + requested_secrets = get_encoded_list(relation, relation.app, REQ_SECRET_FIELDS) + provided_secrets = get_encoded_list(relation, relation.app, PROV_SECRET_FIELDS) + if requested_secrets is not None: + self._local_secret_fields = requested_secrets + + if provided_secrets is not None: + self._remote_secret_fields = provided_secrets + class RequirerData(Data): """Requirer-side of the relation.""" @@ -1713,52 +1791,18 @@ def __init__( """Manager of base client relations.""" super().__init__(model, relation_name) self.extra_user_roles = extra_user_roles - self._secret_fields = list(self.SECRET_FIELDS) + self._remote_secret_fields = list(self.SECRET_FIELDS) + self._local_secret_fields = [ + field + for field in self.SECRET_LABEL_MAP.keys() + if field not in self._remote_secret_fields + ] if additional_secret_fields: - self._secret_fields += additional_secret_fields + self._remote_secret_fields += additional_secret_fields self.data_component = self.local_unit - @property - def secret_fields(self) -> Optional[List[str]]: - """Local access to secrets field, in case they are being used.""" - if self.secrets_enabled: - return self._secret_fields - # Internal helper functions - def _register_secret_to_relation( - self, relation_name: str, relation_id: int, secret_id: str, group: SecretGroup - ): - """Fetch secrets and apply local label on them. - - [MAGIC HERE] - If we fetch a secret using get_secret(id=, label=), - then will be "stuck" on the Secret object, whenever it may - appear (i.e. as an event attribute, or fetched manually) on future occasions. - - This will allow us to uniquely identify the secret on Provider side (typically on - 'secret-changed' events), and map it to the corresponding relation. - """ - label = self._generate_secret_label(relation_name, relation_id, group) - - # Fetching the Secret's meta information ensuring that it's locally getting registered with - CachedSecret(self._model, self.component, label, secret_id).meta - - def _register_secrets_to_relation(self, relation: Relation, params_name_list: List[str]): - """Make sure that secrets of the provided list are locally 'registered' from the databag. - - More on 'locally registered' magic is described in _register_secret_to_relation() method - """ - if not relation.app: - return - - for group in SECRET_GROUPS.groups(): - secret_field = self._generate_secret_field_name(group) - if secret_field in params_name_list and ( - secret_uri := self.get_secret_uri(relation, group) - ): - self._register_secret_to_relation(relation.name, relation.id, secret_uri, group) - def _is_resource_created_for_relation(self, relation: Relation) -> bool: if not relation.app: return False @@ -1769,16 +1813,6 @@ def _is_resource_created_for_relation(self, relation: Relation) -> bool: return bool(data.get("username")) and bool(data.get("password")) # Public functions - - def get_secret_uri(self, relation: Relation, group: SecretGroup) -> Optional[str]: - """Getting relation secret URI for the corresponding Secret Group.""" - secret_field = self._generate_secret_field_name(group) - return relation.data[relation.app].get(secret_field) - - def set_secret_uri(self, relation: Relation, group: SecretGroup, uri: str) -> None: - """Setting relation secret URI is not possible for a Requirer.""" - raise NotImplementedError("Requirer can not change the relation secret URI.") - def is_resource_created(self, relation_id: Optional[int] = None) -> bool: """Check if the resource has been created. @@ -1805,70 +1839,28 @@ def is_resource_created(self, relation_id: Optional[int] = None) -> bool: raise IndexError(f"relation id {relation_id} cannot be accessed") else: return ( - all( - self._is_resource_created_for_relation(relation) for relation in self.relations - ) - if self.relations - else False - ) - - # Mandatory internal overrides - - @juju_secrets_only - def _get_relation_secret( - self, relation_id: int, group: SecretGroup, relation_name: Optional[str] = None - ) -> Optional[CachedSecret]: - """Retrieve a Juju Secret that's been stored in the relation databag.""" - if not relation_name: - relation_name = self.relation_name - - label = self._generate_secret_label(relation_name, relation_id, group) - return self.secrets.get(label) - - def _fetch_specific_relation_data( - self, relation, fields: Optional[List[str]] = None - ) -> Dict[str, str]: - """Fetching Requirer data -- that may include secrets.""" - if not relation.app: - return {} - return self._fetch_relation_data_with_secrets( - relation.app, self.secret_fields, relation, fields - ) - - def _fetch_my_specific_relation_data(self, relation, fields: Optional[List[str]]) -> dict: - """Fetching our own relation data.""" - return self._fetch_relation_data_without_secrets(self.local_app, relation, fields) - - def _update_relation_data(self, relation: Relation, data: dict) -> None: - """Updates a set of key-value pairs in the relation. - - This function writes in the application data bag, therefore, - only the leader unit can call it. - - Args: - relation: the particular relation. - data: dict containing the key-value pairs - that should be updated in the relation. - """ - return self._update_relation_data_without_secrets(self.local_app, relation, data) - - def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: - """Deletes a set of fields from the relation. - - This function writes in the application data bag, therefore, - only the leader unit can call it. - - Args: - relation: the particular relation. - fields: list containing the field names that should be removed from the relation. - """ - return self._delete_relation_data_without_secrets(self.local_app, relation, fields) + all( + self._is_resource_created_for_relation(relation) for relation in self.relations + ) + if self.relations + else False + ) # Public functions -- inherited fetch_my_relation_data = leader_only(Data.fetch_my_relation_data) fetch_my_relation_field = leader_only(Data.fetch_my_relation_field) + def _load_secrets_from_databag(self, relation: Relation) -> None: + """Load secrets from the databag.""" + requested_secrets = get_encoded_list(relation, self.local_unit, REQ_SECRET_FIELDS) + provided_secrets = get_encoded_list(relation, self.local_unit, PROV_SECRET_FIELDS) + if requested_secrets: + self._remote_secret_fields = requested_secrets + + if provided_secrets: + self._local_secret_fields = provided_secrets + class RequirerEventHandlers(EventHandlers): """Requires-side of the relation.""" @@ -1877,15 +1869,6 @@ def __init__(self, charm: CharmBase, relation_data: RequirerData, unique_key: st """Manager of base client relations.""" super().__init__(charm, relation_data, unique_key) - self.framework.observe( - self.charm.on[relation_data.relation_name].relation_created, - self._on_relation_created_event, - ) - self.framework.observe( - charm.on.secret_changed, - self._on_secret_changed_event, - ) - # Event handlers def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: @@ -1893,18 +1876,56 @@ def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: if not self.relation_data.local_unit.is_leader(): return - if self.relation_data.secret_fields: # pyright: ignore [reportAttributeAccessIssue] + if self.relation_data.remote_secret_fields: + if self.relation_data.SCOPE == Scope.APP: + set_encoded_field( + event.relation, + self.relation_data.local_app, + REQ_SECRET_FIELDS, + self.relation_data.remote_secret_fields, + ) + set_encoded_field( event.relation, - self.relation_data.component, + self.relation_data.local_unit, REQ_SECRET_FIELDS, - self.relation_data.secret_fields, # pyright: ignore [reportAttributeAccessIssue] + self.relation_data.remote_secret_fields, ) - @abstractmethod - def _on_secret_changed_event(self, event: RelationChangedEvent) -> None: + if self.relation_data.local_secret_fields: + if self.relation_data.SCOPE == Scope.APP: + set_encoded_field( + event.relation, + self.relation_data.local_app, + PROV_SECRET_FIELDS, + self.relation_data.local_secret_fields, + ) + set_encoded_field( + event.relation, + self.relation_data.local_unit, + PROV_SECRET_FIELDS, + self.relation_data.local_secret_fields, + ) + + +class ProviderEventHandlers(EventHandlers): + """Provider-side of the relation.""" + + def __init__(self, charm: CharmBase, relation_data: ProviderData, unique_key: str = ""): + """Manager of base client relations.""" + super().__init__(charm, relation_data, unique_key) + + # Event handlers + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation data has changed.""" - raise NotImplementedError + requested_secrets = get_encoded_list(event.relation, event.relation.app, REQ_SECRET_FIELDS) + provided_secrets = get_encoded_list(event.relation, event.relation.app, PROV_SECRET_FIELDS) + if requested_secrets is not None: + self.relation_data._local_secret_fields = requested_secrets + + if provided_secrets is not None: + self.relation_data._remote_secret_fields = provided_secrets ################################################################################ @@ -1955,7 +1976,7 @@ def __init__( secret_group = SECRET_GROUPS.get_group(group) internal_field = self._field_to_internal_name(field, secret_group) self._secret_label_map.setdefault(group, []).append(internal_field) - self._secret_fields.append(internal_field) + self._remote_secret_fields.append(internal_field) @property def scope(self) -> Optional[Scope]: @@ -1973,10 +1994,10 @@ def secret_label_map(self) -> Dict[str, str]: @property def static_secret_fields(self) -> List[str]: """Re-definition of the property in a way that dynamically extended list is retrieved.""" - return self._secret_fields + return self._remote_secret_fields @property - def secret_fields(self) -> List[str]: + def local_secret_fields(self) -> List[str]: """Re-definition of the property in a way that dynamically extended list is retrieved.""" return ( self.static_secret_fields if self.static_secret_fields else self.current_secret_fields @@ -1994,7 +2015,11 @@ def current_secret_fields(self) -> List[str]: relation = self._model.relations[self.relation_name][0] fields = [] - ignores = [SECRET_GROUPS.get_group("user"), SECRET_GROUPS.get_group("tls")] + ignores = [ + SECRET_GROUPS.get_group("user"), + SECRET_GROUPS.get_group("tls"), + SECRET_GROUPS.get_group("mtls"), + ] for group in SECRET_GROUPS.groups(): if group in ignores: continue @@ -2103,11 +2128,11 @@ def _content_for_secret_group( ) -> Dict[str, str]: """Select : pairs from input, that belong to this particular Secret group.""" if group_mapping == SECRET_GROUPS.EXTRA: - return {k: v for k, v in content.items() if k in self.secret_fields} + return {k: v for k, v in content.items() if k in self.local_secret_fields} return { self._internal_name_to_field(k)[0]: v for k, v in content.items() - if k in self.secret_fields + if k in self.local_secret_fields } def valid_field_pattern(self, field: str, full_field: str) -> bool: @@ -2122,6 +2147,16 @@ def valid_field_pattern(self, field: str, full_field: str) -> bool: return False return True + def _load_secrets_from_databag(self, relation: Relation) -> None: + """Load secrets from the databag.""" + requested_secrets = get_encoded_list(relation, self.component, REQ_SECRET_FIELDS) + provided_secrets = get_encoded_list(relation, self.component, PROV_SECRET_FIELDS) + if requested_secrets: + self._remote_secret_fields = requested_secrets + + if provided_secrets: + self._local_secret_fields = provided_secrets + ########################################################################## # Backwards compatibility / Upgrades ########################################################################## @@ -2177,7 +2212,7 @@ def _legacy_compat_check_deleted_label(self, relation, fields) -> None: if current_data is not None: # Check if the secret we wanna delete actually exists # Given the "deleted label", here we can't rely on the default mechanism (i.e. 'key not found') - if non_existent := (set(fields) & set(self.secret_fields)) - set( + if non_existent := (set(fields) & set(self.local_secret_fields)) - set( current_data.get(relation.id, []) ): logger.debug( @@ -2227,10 +2262,10 @@ def _legacy_migration_remove_secret_from_databag(self, relation, fields: List[st Practically what happens here is to remove stuff from the databag that is to be stored in secrets. """ - if not self.secret_fields: + if not self.local_secret_fields: return - secret_fields_passed = set(self.secret_fields) & set(fields) + secret_fields_passed = set(self.local_secret_fields) & set(fields) for field in secret_fields_passed: if self._fetch_relation_data_without_secrets(self.component, relation, [field]): self._delete_relation_data_without_secrets(self.component, relation, [field]) @@ -2342,15 +2377,17 @@ def _fetch_my_specific_relation_data( ) -> Dict[str, str]: """Fetch data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" return self._fetch_relation_data_with_secrets( - self.component, self.secret_fields, relation, fields + self.component, self.local_secret_fields, relation, fields ) @either_static_or_dynamic_secrets def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> None: """Update data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" + self._load_secrets_from_databag(relation) + _, normal_fields = self._process_secret_fields( relation, - self.secret_fields, + self.local_secret_fields, list(data), self._add_or_update_relation_secrets, data=data, @@ -2363,17 +2400,22 @@ def _update_relation_data(self, relation: Relation, data: Dict[str, str]) -> Non @either_static_or_dynamic_secrets def _delete_relation_data(self, relation: Relation, fields: List[str]) -> None: """Delete data available (directily or indirectly -- i.e. secrets) from the relation for owner/this_app.""" - if self.secret_fields and self.deleted_label: + self._load_secrets_from_databag(relation) + if self.local_secret_fields and self.deleted_label: _, normal_fields = self._process_secret_fields( relation, - self.secret_fields, + self.local_secret_fields, fields, self._update_relation_secret, data=dict.fromkeys(fields, self.deleted_label), ) else: _, normal_fields = self._process_secret_fields( - relation, self.secret_fields, fields, self._delete_relation_secret, fields=fields + relation, + self.local_secret_fields, + fields, + self._delete_relation_secret, + fields=fields, ) self._delete_relation_data_without_secrets(self.component, relation, list(normal_fields)) @@ -2896,7 +2938,7 @@ def set_subordinated(self, relation_id: int) -> None: self.update_relation_data(relation_id, {"subordinated": "true"}) -class DatabaseProviderEventHandlers(EventHandlers): +class DatabaseProviderEventHandlers(ProviderEventHandlers): """Provider-side of the database relation handlers.""" on = DatabaseProvidesEvents() # pyright: ignore [reportAssignmentType] @@ -2911,6 +2953,7 @@ def __init__( def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation has changed.""" + super()._on_relation_changed_event(event) # Leader only if not self.relation_data.local_unit.is_leader(): return @@ -2924,6 +2967,10 @@ def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: event.relation, app=event.app, unit=event.unit ) + def _on_secret_changed_event(self, event: SecretChangedEvent) -> None: + """Event emitted when the secret has changed.""" + pass + class DatabaseProvides(DatabaseProviderData, DatabaseProviderEventHandlers): """Provider-side of the database relations.""" @@ -3369,7 +3416,7 @@ def set_zookeeper_uris(self, relation_id: int, zookeeper_uris: str) -> None: self.update_relation_data(relation_id, {"zookeeper-uris": zookeeper_uris}) -class KafkaProviderEventHandlers(EventHandlers): +class KafkaProviderEventHandlers(ProviderEventHandlers): """Provider-side of the Kafka relation.""" on = KafkaProvidesEvents() # pyright: ignore [reportAssignmentType] @@ -3381,6 +3428,7 @@ def __init__(self, charm: CharmBase, relation_data: KafkaProviderData) -> None: def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation has changed.""" + super()._on_relation_changed_event(event) # Leader only if not self.relation_data.local_unit.is_leader(): return @@ -3613,7 +3661,7 @@ def set_version(self, relation_id: int, version: str) -> None: self.update_relation_data(relation_id, {"version": version}) -class OpenSearchProvidesEventHandlers(EventHandlers): +class OpenSearchProvidesEventHandlers(ProviderEventHandlers): """Provider-side of the OpenSearch relation.""" on = OpenSearchProvidesEvents() # pyright: ignore[reportAssignmentType] @@ -3625,6 +3673,8 @@ def __init__(self, charm: CharmBase, relation_data: OpenSearchProvidesData) -> N def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: """Event emitted when the relation has changed.""" + super()._on_relation_changed_event(event) + # Leader only if not self.relation_data.local_unit.is_leader(): return @@ -3778,3 +3828,320 @@ def __init__( additional_secret_fields, ) OpenSearchRequiresEventHandlers.__init__(self, charm, self) + + +# Etcd related events + + +class EtcdProviderEvent(RelationEventWithSecret): + """Base class for Etcd events.""" + + @property + def prefix(self) -> Optional[str]: + """Returns the index that was requested.""" + if not self.relation.app: + return None + + return self.relation.data[self.relation.app].get("prefix") + + @property + def mtls_cert(self) -> Optional[str]: + """Returns TLS cert of the client.""" + if not self.relation.app: + return None + + if not self.secrets_enabled: + raise SecretsUnavailableError("Secrets unavailable on current Juju version") + + secret_field = f"{PROV_SECRET_PREFIX}{SECRET_GROUPS.MTLS}" + if secret_uri := self.relation.data[self.app].get(secret_field): + secret = self.framework.model.get_secret(id=secret_uri) + content = secret.get_content(refresh=True) + if content: + return content.get("mtls-cert") + + +class MTLSCertUpdatedEvent(EtcdProviderEvent): + """Event emitted when the mtls relation is updated.""" + + def __init__(self, handle, relation, old_mtls_cert: Optional[str] = None, app=None, unit=None): + super().__init__(handle, relation, app, unit) + + self.old_mtls_cert = old_mtls_cert + + def snapshot(self): + """Return a snapshot of the event.""" + return super().snapshot() | {"old_mtls_cert": self.old_mtls_cert} + + def restore(self, snapshot): + """Restore the event from a snapshot.""" + super().restore(snapshot) + self.old_mtls_cert = snapshot["old_mtls_cert"] + + +class EtcdProviderEvents(CharmEvents): + """Etcd events. + + This class defines the events that Etcd can emit. + """ + + mtls_cert_updated = EventSource(MTLSCertUpdatedEvent) + + +class EtcdReadyEvent(AuthenticationEvent, DatabaseRequiresEvent): + """Event emitted when the etcd relation is ready to be consumed.""" + + +class EtcdRequirerEvents(CharmEvents): + """Etcd events. + + This class defines the events that the etcd requirer can emit. + """ + + endpoints_changed = EventSource(DatabaseEndpointsChangedEvent) + etcd_ready = EventSource(EtcdReadyEvent) + + +# Etcd Provides and Requires Objects + + +class EtcdProviderData(ProviderData): + """Provider-side of the Etcd relation.""" + + RESOURCE_FIELD = "prefix" + + def __init__(self, model: Model, relation_name: str) -> None: + super().__init__(model, relation_name) + + def set_uris(self, relation_id: int, uris: str) -> None: + """Set the database connection URIs in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + uris: connection URIs. + """ + self.update_relation_data(relation_id, {"uris": uris}) + + def set_endpoints(self, relation_id: int, endpoints: str) -> None: + """Set the endpoints in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + endpoints: the endpoint addresses for etcd nodes "ip:port" format. + """ + self.update_relation_data(relation_id, {"endpoints": endpoints}) + + def set_version(self, relation_id: int, version: str) -> None: + """Set the etcd version in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + version: etcd API version. + """ + self.update_relation_data(relation_id, {"version": version}) + + def set_tls_ca(self, relation_id: int, tls_ca: str) -> None: + """Set the TLS CA in the application relation databag. + + Args: + relation_id: the identifier for a particular relation. + tls_ca: TLS certification authority. + """ + self.update_relation_data(relation_id, {"tls-ca": tls_ca, "tls": "True"}) + + +class EtcdProviderEventHandlers(ProviderEventHandlers): + """Provider-side of the Etcd relation.""" + + on = EtcdProviderEvents() # pyright: ignore[reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: EtcdProviderData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the relation has changed.""" + super()._on_relation_changed_event(event) + # register all new secrets with their labels + new_data_keys = list(event.relation.data[event.app].keys()) + if any(newval for newval in new_data_keys if self.relation_data._is_secret_field(newval)): + self.relation_data._register_secrets_to_relation(event.relation, new_data_keys) + + getattr(self.on, "mtls_cert_updated").emit(event.relation, app=event.app, unit=event.unit) + return + + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + if not event.secret.label: + return + + relation = self.relation_data._relation_from_secret_label(event.secret.label) + if not relation: + logging.info( + f"Received secret {event.secret.label} but couldn't parse, seems irrelevant" + ) + return + + if relation.app == self.charm.app: + logging.info("Secret changed event ignored for Secret Owner") + + remote_unit = None + for unit in relation.units: + if unit.app != self.charm.app: + remote_unit = unit + + old_mtls_cert = event.secret.get_content().get("mtls-cert") + # mtls-cert is the only secret that can be updated + logger.info("mtls-cert updated") + getattr(self.on, "mtls_cert_updated").emit( + relation, app=relation.app, unit=remote_unit, old_mtls_cert=old_mtls_cert + ) + + +class EtcdProvides(EtcdProviderData, EtcdProviderEventHandlers): + """Provider-side of the Etcd relation.""" + + def __init__(self, charm: CharmBase, relation_name: str) -> None: + EtcdProviderData.__init__(self, charm.model, relation_name) + EtcdProviderEventHandlers.__init__(self, charm, self) + if not self.secrets_enabled: + raise SecretsUnavailableError("Secrets unavailable on current Juju version") + + +class EtcdRequirerData(RequirerData): + """Requires data side of the Etcd relation.""" + + def __init__( + self, + model: Model, + relation_name: str, + prefix: str, + mtls_cert: Optional[str], + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ): + """Manager of Etcd client relations.""" + super().__init__(model, relation_name, extra_user_roles, additional_secret_fields) + self.prefix = prefix + self.mtls_cert = mtls_cert + + def set_mtls_cert(self, relation_id: int, mtls_cert: str) -> None: + """Set the mtls cert in the application relation databag / secret. + + Args: + relation_id: the identifier for a particular relation. + mtls_cert: mtls cert. + """ + self.update_relation_data(relation_id, {"mtls-cert": mtls_cert}) + + +class EtcdRequirerEventHandlers(RequirerEventHandlers): + """Requires events side of the Etcd relation.""" + + on = EtcdRequirerEvents() # pyright: ignore[reportAssignmentType] + + def __init__(self, charm: CharmBase, relation_data: EtcdRequirerData) -> None: + super().__init__(charm, relation_data) + # Just to keep lint quiet, can't resolve inheritance. The same happened in super().__init__() above + self.relation_data = relation_data + + def _on_relation_created_event(self, event: RelationCreatedEvent) -> None: + """Event emitted when the Etcd relation is created.""" + super()._on_relation_created_event(event) + + payload = { + "prefix": self.relation_data.prefix, + } + if self.relation_data.mtls_cert: + payload["mtls-cert"] = self.relation_data.mtls_cert + + self.relation_data.update_relation_data( + event.relation.id, + payload, + ) + + def _on_relation_changed_event(self, event: RelationChangedEvent) -> None: + """Event emitted when the Etcd relation has changed. + + This event triggers individual custom events depending on the changing relation. + """ + # Check which data has changed to emit customs events. + diff = self._diff(event) + # Register all new secrets with their labels + if any(newval for newval in diff.added if self.relation_data._is_secret_field(newval)): + self.relation_data._register_secrets_to_relation(event.relation, diff.added) + + secret_field_user = self.relation_data._generate_secret_field_name(SECRET_GROUPS.USER) + secret_field_tls = self.relation_data._generate_secret_field_name(SECRET_GROUPS.TLS) + + # Emit a endpoints changed event if the etcd application added or changed this info + # in the relation databag. + if "endpoints" in diff.added or "endpoints" in diff.changed: + # Emit the default event (the one without an alias). + logger.info("endpoints changed on %s", datetime.now()) + getattr(self.on, "endpoints_changed").emit( + event.relation, app=event.app, unit=event.unit + ) + + if ( + secret_field_tls in diff.added + or secret_field_tls in diff.changed + or secret_field_user in diff.added + or secret_field_user in diff.changed + or "username" in diff.added + or "username" in diff.changed + ): + # Emit the default event (the one without an alias). + logger.info("etcd ready on %s", datetime.now()) + getattr(self.on, "etcd_ready").emit(event.relation, app=event.app, unit=event.unit) + + def _on_secret_changed_event(self, event: SecretChangedEvent): + """Event notifying about a new value of a secret.""" + if not event.secret.label: + return + + relation = self.relation_data._relation_from_secret_label(event.secret.label) + if not relation: + logging.info( + f"Received secret {event.secret.label} but couldn't parse, seems irrelevant" + ) + return + + if relation.app == self.charm.app: + logging.info("Secret changed event ignored for Secret Owner") + + remote_unit = None + for unit in relation.units: + if unit.app != self.charm.app: + remote_unit = unit + + # secret-user or secret-tls updated + logger.info("etcd_ready updated") + getattr(self.on, "etcd_ready").emit(relation, app=relation.app, unit=remote_unit) + + +class EtcdRequires(EtcdRequirerData, EtcdRequirerEventHandlers): + """Requires-side of the Etcd relation.""" + + def __init__( + self, + charm: CharmBase, + relation_name: str, + prefix: str, + mtls_cert: Optional[str], + extra_user_roles: Optional[str] = None, + additional_secret_fields: Optional[List[str]] = [], + ) -> None: + EtcdRequirerData.__init__( + self, + charm.model, + relation_name, + prefix, + mtls_cert, + extra_user_roles, + additional_secret_fields, + ) + EtcdRequirerEventHandlers.__init__(self, charm, self) + if not self.secrets_enabled: + raise SecretsUnavailableError("Secrets unavailable on current Juju version") diff --git a/lib/charms/postgresql_k8s/v0/postgresql.py b/lib/charms/postgresql_k8s/v0/postgresql.py index 9fe1957e4f..7e6a9d7631 100644 --- a/lib/charms/postgresql_k8s/v0/postgresql.py +++ b/lib/charms/postgresql_k8s/v0/postgresql.py @@ -35,7 +35,19 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 46 +LIBPATCH = 51 + +# Groups to distinguish HBA access +ACCESS_GROUP_IDENTITY = "identity_access" +ACCESS_GROUP_INTERNAL = "internal_access" +ACCESS_GROUP_RELATION = "relation_access" + +# List of access groups to filter role assignments by +ACCESS_GROUPS = [ + ACCESS_GROUP_IDENTITY, + ACCESS_GROUP_INTERNAL, + ACCESS_GROUP_RELATION, +] # Groups to distinguish database permissions PERMISSIONS_GROUP_ADMIN = "admin" @@ -57,10 +69,18 @@ logger = logging.getLogger(__name__) +class PostgreSQLAssignGroupError(Exception): + """Exception raised when assigning to a group fails.""" + + class PostgreSQLCreateDatabaseError(Exception): """Exception raised when creating a database fails.""" +class PostgreSQLCreateGroupError(Exception): + """Exception raised when creating a group fails.""" + + class PostgreSQLCreateUserError(Exception): """Exception raised when creating a user fails.""" @@ -93,6 +113,10 @@ class PostgreSQLGetPostgreSQLVersionError(Exception): """Exception raised when retrieving PostgreSQL version fails.""" +class PostgreSQLListGroupsError(Exception): + """Exception raised when retrieving PostgreSQL groups list fails.""" + + class PostgreSQLListUsersError(Exception): """Exception raised when retrieving PostgreSQL users list fails.""" @@ -129,7 +153,7 @@ def _configure_pgaudit(self, enable: bool) -> None: if enable: cursor.execute("ALTER SYSTEM SET pgaudit.log = 'ROLE,DDL,MISC,MISC_SET';") cursor.execute("ALTER SYSTEM SET pgaudit.log_client TO off;") - cursor.execute("ALTER SYSTEM SET pgaudit.log_parameter TO off") + cursor.execute("ALTER SYSTEM SET pgaudit.log_parameter TO off;") else: cursor.execute("ALTER SYSTEM RESET pgaudit.log;") cursor.execute("ALTER SYSTEM RESET pgaudit.log_client;") @@ -160,6 +184,24 @@ def _connect_to_database( connection.autocommit = True return connection + def create_access_groups(self) -> None: + """Create access groups to distinguish HBA authentication methods.""" + connection = None + try: + with self._connect_to_database() as connection, connection.cursor() as cursor: + for group in ACCESS_GROUPS: + cursor.execute( + SQL("CREATE ROLE {} NOLOGIN;").format( + Identifier(group), + ) + ) + except psycopg2.Error as e: + logger.error(f"Failed to create access groups: {e}") + raise PostgreSQLCreateGroupError() from e + finally: + if connection is not None: + connection.close() + def create_database( self, database: str, @@ -216,7 +258,7 @@ def create_database( raise PostgreSQLCreateDatabaseError() from e # Enable preset extensions - self.enable_disable_extensions({plugin: True for plugin in plugins}, database) + self.enable_disable_extensions(dict.fromkeys(plugins, True), database) def create_user( self, @@ -321,6 +363,50 @@ def delete_user(self, user: str) -> None: logger.error(f"Failed to delete user: {e}") raise PostgreSQLDeleteUserError() from e + def grant_internal_access_group_memberships(self) -> None: + """Grant membership to the internal access-group to existing internal users.""" + connection = None + try: + with self._connect_to_database() as connection, connection.cursor() as cursor: + for user in self.system_users: + cursor.execute( + SQL("GRANT {} TO {};").format( + Identifier(ACCESS_GROUP_INTERNAL), + Identifier(user), + ) + ) + except psycopg2.Error as e: + logger.error(f"Failed to grant internal access group memberships: {e}") + raise PostgreSQLAssignGroupError() from e + finally: + if connection is not None: + connection.close() + + def grant_relation_access_group_memberships(self) -> None: + """Grant membership to the relation access-group to existing relation users.""" + rel_users = self.list_users_from_relation() + if not rel_users: + return + + connection = None + try: + with self._connect_to_database() as connection, connection.cursor() as cursor: + rel_groups = SQL(",").join(Identifier(group) for group in [ACCESS_GROUP_RELATION]) + rel_users = SQL(",").join(Identifier(user) for user in rel_users) + + cursor.execute( + SQL("GRANT {groups} TO {users};").format( + groups=rel_groups, + users=rel_users, + ) + ) + except psycopg2.Error as e: + logger.error(f"Failed to grant relation access group memberships: {e}") + raise PostgreSQLAssignGroupError() from e + finally: + if connection is not None: + connection.close() + def enable_disable_extensions( self, extensions: Dict[str, bool], database: Optional[str] = None ) -> None: @@ -349,6 +435,8 @@ def enable_disable_extensions( for extension, enable in extensions.items(): ordered_extensions[extension] = enable + self._configure_pgaudit(False) + # Enable/disabled the extension in each database. for database in databases: with self._connect_to_database( @@ -534,12 +622,34 @@ def is_tls_enabled(self, check_current_host: bool = False) -> bool: # Connection errors happen when PostgreSQL has not started yet. return False + def list_access_groups(self) -> Set[str]: + """Returns the list of PostgreSQL database access groups. + + Returns: + List of PostgreSQL database access groups. + """ + connection = None + try: + with self._connect_to_database() as connection, connection.cursor() as cursor: + cursor.execute( + "SELECT groname FROM pg_catalog.pg_group WHERE groname LIKE '%_access';" + ) + access_groups = cursor.fetchall() + return {group[0] for group in access_groups} + except psycopg2.Error as e: + logger.error(f"Failed to list PostgreSQL database access groups: {e}") + raise PostgreSQLListGroupsError() from e + finally: + if connection is not None: + connection.close() + def list_users(self) -> Set[str]: """Returns the list of PostgreSQL database users. Returns: List of PostgreSQL database users. """ + connection = None try: with self._connect_to_database() as connection, connection.cursor() as cursor: cursor.execute("SELECT usename FROM pg_catalog.pg_user;") @@ -548,6 +658,30 @@ def list_users(self) -> Set[str]: except psycopg2.Error as e: logger.error(f"Failed to list PostgreSQL database users: {e}") raise PostgreSQLListUsersError() from e + finally: + if connection is not None: + connection.close() + + def list_users_from_relation(self) -> Set[str]: + """Returns the list of PostgreSQL database users that were created by a relation. + + Returns: + List of PostgreSQL database users. + """ + connection = None + try: + with self._connect_to_database() as connection, connection.cursor() as cursor: + cursor.execute( + "SELECT usename FROM pg_catalog.pg_user WHERE usename LIKE 'relation_id_%';" + ) + usernames = cursor.fetchall() + return {username[0] for username in usernames} + except psycopg2.Error as e: + logger.error(f"Failed to list PostgreSQL database users: {e}") + raise PostgreSQLListUsersError() from e + finally: + if connection is not None: + connection.close() def list_valid_privileges_and_roles(self) -> Tuple[Set[str], Set[str]]: """Returns two sets with valid privileges and roles. @@ -644,6 +778,42 @@ def is_restart_pending(self) -> bool: if connection: connection.close() + @staticmethod + def build_postgresql_group_map(group_map: Optional[str]) -> List[Tuple]: + """Build the PostgreSQL authorization group-map. + + Args: + group_map: serialized group-map with the following format: + =, + =, + ... + + Returns: + List of LDAP group to PostgreSQL group tuples. + """ + if group_map is None: + return [] + + group_mappings = group_map.split(",") + group_mappings = (mapping.strip() for mapping in group_mappings) + group_map_list = [] + + for mapping in group_mappings: + mapping_parts = mapping.split("=") + if len(mapping_parts) != 2: + raise ValueError("The group-map must contain value pairs split by commas") + + ldap_group = mapping_parts[0] + psql_group = mapping_parts[1] + + if psql_group in [*ACCESS_GROUPS, PERMISSIONS_GROUP_ADMIN]: + logger.warning(f"Tried to assign LDAP users to forbidden group: {psql_group}") + continue + + group_map_list.append((ldap_group, psql_group)) + + return group_map_list + @staticmethod def build_postgresql_parameters( config_options: dict, available_memory: int, limit_memory: Optional[int] = None @@ -723,3 +893,34 @@ def validate_date_style(self, date_style: str) -> bool: return True except psycopg2.Error: return False + + def validate_group_map(self, group_map: Optional[str]) -> bool: + """Validate the PostgreSQL authorization group-map. + + Args: + group_map: serialized group-map with the following format: + =, + =, + ... + + Returns: + Whether the group-map is valid. + """ + if group_map is None: + return True + + try: + group_map = self.build_postgresql_group_map(group_map) + except ValueError: + return False + + for _, psql_group in group_map: + with self._connect_to_database() as connection, connection.cursor() as cursor: + query = SQL("SELECT TRUE FROM pg_roles WHERE rolname={};") + query = query.format(Literal(psql_group)) + cursor.execute(query) + + if cursor.fetchone() is None: + return False + + return True diff --git a/tests/integration/test_plugins.py b/tests/integration/test_plugins.py index 4dedc11a7e..2adfea7cdd 100644 --- a/tests/integration/test_plugins.py +++ b/tests/integration/test_plugins.py @@ -99,8 +99,7 @@ async def test_plugins(ops_test: OpsTest, charm) -> None: charm, num_units=2, base=CHARM_BASE, - # TODO Figure out how to deal with pgaudit - config={"profile": "testing", "plugin_audit_enable": "False"}, + config={"profile": "testing"}, ) await ops_test.model.wait_for_idle(apps=[DATABASE_APP_NAME], status="active", timeout=1500)