Skip to content

Commit

Permalink
Add types
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas committed Mar 6, 2023
1 parent 8aa0467 commit ec7b1f0
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 54 deletions.
116 changes: 70 additions & 46 deletions lib/charms/hydra/v0/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,12 @@ def _set_client_config(self):
import logging
import re
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional

import jsonschema
from ops.framework import EventBase, EventSource, Object, ObjectEvents
from ops.charm import CharmBase, RelationChangedEvent, RelationCreatedEvent
from ops.framework import EventBase, EventSource, Handle, Object, ObjectEvents
from ops.model import Relation, Secret

# The unique Charmhub library identifier, never change it
LIBID = "a3a301e325e34aac80a2d633ef61fe97"
Expand Down Expand Up @@ -168,21 +171,21 @@ class DataValidationError(RuntimeError):
"""Raised when data validation fails on relation data."""


def _load_data(data, schema=None):
def _load_data(data: Dict, schema: Optional[Dict] = None) -> Dict:
"""Parses nested fields and checks whether `data` matches `schema`."""
ret = {}
for k, v in data.items():
try:
ret[k] = json.loads(v)
except json.JSONDecodeError as e:
except json.JSONDecodeError:
ret[k] = v

if schema:
_validate_data(ret, schema)
return ret


def _dump_data(data, schema=None):
def _dump_data(data: Dict, schema: Optional[Dict] = None) -> Dict:
if schema:
_validate_data(data, schema)

Expand All @@ -198,7 +201,7 @@ def _dump_data(data, schema=None):
return ret


def _validate_data(data, schema):
def _validate_data(data: Dict, schema: Dict) -> None:
"""Checks whether `data` matches `schema`.
Will raise DataValidationError if the data is not valid, else return None.
Expand All @@ -219,7 +222,7 @@ class ClientConfig:
audience: list[str] = field(default_factory=lambda: [])
token_endpoint_auth_method: str = "client_secret_basic"

def validate(self):
def validate(self) -> None:
"""Validate the client configuration."""
# Validate redirect_uri
if not re.match(url_regex, self.redirect_uri):
Expand All @@ -243,19 +246,19 @@ def validate(self):
class ClientCredentialsChangedEvent(EventBase):
"""Event to notify the charm that the client credentials changed."""

def __init__(self, handle, client_id, client_secret_id):
def __init__(self, handle: Handle, client_id: str, client_secret_id: str):
super().__init__(handle)
self.client_id = client_id
self.client_secret_id = client_secret_id

def snapshot(self):
def snapshot(self) -> Dict:
"""Save event."""
return {
"client_id": self.client_id,
"client_secret_id": self.client_secret_id,
}

def restore(self, snapshot):
def restore(self, snapshot: Dict) -> None:
"""Restore event."""
self.client_id = snapshot["client_id"]
self.client_secret_id = snapshot["client_secret_id"]
Expand All @@ -277,7 +280,12 @@ class OAuthRequirer(Object):

on = OAuthRequirerEvents()

def __init__(self, charm, client_config=None, relation_name=DEFAULT_RELATION_NAME):
def __init__(
self,
charm: CharmBase,
client_config: Optional[ClientConfig] = None,
relation_name: str = DEFAULT_RELATION_NAME,
) -> None:
super().__init__(charm, relation_name)
self._charm = charm
self._relation_name = relation_name
Expand All @@ -286,13 +294,13 @@ def __init__(self, charm, client_config=None, relation_name=DEFAULT_RELATION_NAM
self.framework.observe(events.relation_created, self._on_relation_created_event)
self.framework.observe(events.relation_changed, self._on_relation_changed_event)

def _on_relation_created_event(self, event):
def _on_relation_created_event(self, event: RelationCreatedEvent) -> None:
try:
self._update_relation_data(self._client_config, event.relation.id)
except Exception:
pass

def _on_relation_changed_event(self, event):
def _on_relation_changed_event(self, event: RelationChangedEvent) -> None:
if not self.model.unit.is_leader():
return

Expand All @@ -315,10 +323,15 @@ def _on_relation_changed_event(self, event):
# TODO: log some error?
pass

def _update_relation_data(self, client_config, relation_id):
if not self.model.unit.is_leader():
def _update_relation_data(
self, client_config: Optional[ClientConfig], relation_id: int
) -> None:
if not self.model.unit.is_leader() or not client_config:
return

if not isinstance(client_config, ClientConfig):
raise ValueError(f"Unexpected client_config type: {type(client_config)}")

try:
client_config.validate()
except ClientConfigError as e:
Expand All @@ -330,24 +343,31 @@ def _update_relation_data(self, client_config, relation_id):
relation_name=self._relation_name, relation_id=relation_id
)

if not relation:
return

data = _dump_data(asdict(client_config), OAUTH_REQUIRER_JSON_SCHEMA)
relation.data[self.model.app].update(data)

def get_provider_info(self):
def get_provider_info(self) -> Optional[Dict]:
"""Get the provider information from the databag."""
if len(self.model.relations) == 0:
return
return None
relation = self.model.get_relation(self._relation_name)
if not relation:
return None

data = _load_data(relation.data[relation.app], OAUTH_PROVIDER_JSON_SCHEMA)
data.pop("client_id", None)
data.pop("client_secret_id", None)
return data

def get_client_secret(self, client_secret_id):
def get_client_secret(self, client_secret_id: str) -> Secret:
"""Get the client_secret."""
client_secret = self.model.get_secret(id=client_secret_id)
return client_secret

def update_client_config(self, client_config):
def update_client_config(self, client_config: ClientConfig) -> None:
"""Update the client config stored in the object."""
self._client_config = client_config

Expand All @@ -357,14 +377,14 @@ class ClientCreateEvent(EventBase):

def __init__(
self,
handle,
redirect_uri,
scope,
grant_types,
audience,
token_endpoint_auth_method,
relation_id,
):
handle: Handle,
redirect_uri: str,
scope: str,
grant_types: List[str],
audience: List,
token_endpoint_auth_method: str,
relation_id: str,
) -> None:
super().__init__(handle)
self.redirect_uri = redirect_uri
self.scope = scope
Expand All @@ -373,7 +393,7 @@ def __init__(
self.token_endpoint_auth_method = token_endpoint_auth_method
self.relation_id = relation_id

def snapshot(self):
def snapshot(self) -> Dict:
"""Save event."""
return {
"redirect_uri": self.redirect_uri,
Expand All @@ -384,7 +404,7 @@ def snapshot(self):
"relation_id": self.relation_id,
}

def restore(self, snapshot):
def restore(self, snapshot: Dict) -> None:
"""Restore event."""
self.redirect_uri = snapshot["redirect_uri"]
self.scope = snapshot["scope"]
Expand All @@ -393,7 +413,7 @@ def restore(self, snapshot):
self.token_endpoint_auth_method = snapshot["token_endpoint_auth_method"]
self.relation_id = snapshot["relation_id"]

def to_client_config(self):
def to_client_config(self) -> ClientConfig:
"""Convert the event information to a ClientConfig object."""
return ClientConfig(
self.redirect_uri,
Expand All @@ -409,15 +429,15 @@ class ClientConfigChangedEvent(EventBase):

def __init__(
self,
handle,
redirect_uri,
scope,
grant_types,
audience,
token_endpoint_auth_method,
relation_id,
client_id,
):
handle: Handle,
redirect_uri: str,
scope: str,
grant_types: List,
audience: List,
token_endpoint_auth_method: str,
relation_id: str,
client_id: str,
) -> None:
super().__init__(handle)
self.redirect_uri = redirect_uri
self.scope = scope
Expand All @@ -427,7 +447,7 @@ def __init__(
self.relation_id = relation_id
self.client_id = client_id

def snapshot(self):
def snapshot(self) -> Dict:
"""Save event."""
return {
"redirect_uri": self.redirect_uri,
Expand All @@ -439,7 +459,7 @@ def snapshot(self):
"client_id": self.client_id,
}

def restore(self, snapshot):
def restore(self, snapshot: Dict) -> None:
"""Restore event."""
self.redirect_uri = snapshot["redirect_uri"]
self.scope = snapshot["scope"]
Expand All @@ -449,7 +469,7 @@ def restore(self, snapshot):
self.relation_id = snapshot["relation_id"]
self.client_id = snapshot["client_id"]

def to_client_config(self):
def to_client_config(self) -> ClientConfig:
"""Convert the event information to a ClientConfig object."""
return ClientConfig(
self.redirect_uri,
Expand All @@ -472,7 +492,7 @@ class OAuthProvider(Object):

on = OAuthProviderEvents()

def __init__(self, charm, relation_name=DEFAULT_RELATION_NAME):
def __init__(self, charm: CharmBase, relation_name: str = DEFAULT_RELATION_NAME) -> None:
super().__init__(charm, relation_name)
self._charm = charm
self._relation_name = relation_name
Expand All @@ -482,7 +502,7 @@ def __init__(self, charm, relation_name=DEFAULT_RELATION_NAME):
self._get_client_config_from_relation_data,
)

def _get_client_config_from_relation_data(self, event):
def _get_client_config_from_relation_data(self, event: RelationChangedEvent) -> None:
if not self.model.unit.is_leader():
return

Expand Down Expand Up @@ -517,27 +537,31 @@ def _get_client_config_from_relation_data(self, event):
redirect_uri, scope, grant_types, audience, token_endpoint_auth_method, relation_id
)

def _create_juju_secret(self, client_secret, relation):
def _create_juju_secret(self, client_secret: str, relation: Relation) -> Secret:
"""Create a juju secret and grant it to a relation."""
secret = {CLIENT_SECRET_FIELD: client_secret}
juju_secret = self.model.app.add_secret(secret, label="client_secret")
juju_secret.grant(relation)
return juju_secret

def set_provider_info_in_relation_data(self, data):
def set_provider_info_in_relation_data(self, data: Dict) -> None:
"""Put the provider information in the the databag."""
if not self.model.unit.is_leader():
return

for relation in self.model.relations[self._relation_name]:
relation.data[self.model.app].update(_dump_data(data))

def set_client_credentials_in_relation_data(self, relation_id, client_id, client_secret):
def set_client_credentials_in_relation_data(
self, relation_id: int, client_id: str, client_secret: str
) -> None:
"""Put the client credentials in the the databag."""
if not self.model.unit.is_leader():
return

relation = self.model.get_relation(self._relation_name, relation_id)
if not relation:
return None
# TODO: What if we are refreshing the client_secret? We need to add a
# new revision for that
secret = self._create_juju_secret(client_secret, relation)
Expand Down
24 changes: 18 additions & 6 deletions src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@
import json
import logging
from os.path import join
from typing import Dict, Optional, Union

from charms.data_platform_libs.v0.data_interfaces import (
DatabaseCreatedEvent,
DatabaseEndpointsChangedEvent,
DatabaseRequires,
)
from charms.hydra.v0.oauth import OAuthProvider
from charms.hydra.v0.oauth import (
ClientConfig,
ClientConfigChangedEvent,
ClientCreateEvent,
OAuthProvider,
)
from charms.observability_libs.v0.kubernetes_service_patch import KubernetesServicePatch
from charms.traefik_k8s.v1.ingress import (
IngressPerAppReadyEvent,
Expand Down Expand Up @@ -282,7 +288,7 @@ def _on_ingress_revoked(self, event: IngressPerAppRevokedEvent) -> None:

self._update_endpoint_info()

def _on_client_create(self, event):
def _on_client_create(self, event: ClientCreateEvent) -> None:
if not self._container.can_connect():
event.defer()
return
Expand All @@ -302,12 +308,14 @@ def _on_client_create(self, event):
event.relation_id, client["client_id"], client["client_secret"]
)

def _on_client_config_changed(self, event):
def _on_client_config_changed(self, event: ClientConfigChangedEvent) -> None:
...
# client_config = event.to_client_config()
# self._create_client(client_config, metadata=f"{{\"relation_id\": {event.relation_id}}}")

def _create_client(self, client_config, metadata=None):
def _create_client(
self, client_config: ClientConfig, metadata: Optional[Union[Dict, str]] = None
):
cmd = [
"hydra",
"create",
Expand Down Expand Up @@ -340,8 +348,12 @@ def _create_client(self, client_config, metadata=None):

logger.info(cmd)

process = self._container.exec(cmd)
stdout, _ = process.wait_output()
try:
process = self._container.exec(cmd)
stdout, _ = process.wait_output()
except ExecError as err:
logger.error(f"Failed to create client: {err.exit_code}. Stderr: {err.stderr}")
return
logger.info(f"Created client: {stdout}")

return json.loads(stdout)
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/test_oauth_requirer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Copyright 2023 Canonical Ltd.
# See LICENSE file for licensing details.

from os.path import join

import pytest
from charms.hydra.v0.oauth import (
CLIENT_SECRET_FIELD,
Expand Down

0 comments on commit ec7b1f0

Please sign in to comment.