From 23485a9d246e5d4c5c9a2e1cdfdaafeea14f8cbb Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Wed, 11 Dec 2024 17:32:17 +0100 Subject: [PATCH] more progress Signed-off-by: Jens Langhammer --- authentik/enterprise/providers/ssf/models.py | 56 +++++++++++++++++++ authentik/enterprise/providers/ssf/signals.py | 29 +--------- authentik/enterprise/providers/ssf/tasks.py | 34 +++++++---- .../enterprise/providers/ssf/views/stream.py | 12 +++- 4 files changed, 89 insertions(+), 42 deletions(-) diff --git a/authentik/enterprise/providers/ssf/models.py b/authentik/enterprise/providers/ssf/models.py index 501f158d7ce33..2005178440990 100644 --- a/authentik/enterprise/providers/ssf/models.py +++ b/authentik/enterprise/providers/ssf/models.py @@ -1,3 +1,4 @@ +from datetime import datetime from functools import cached_property from uuid import uuid4 @@ -6,7 +7,9 @@ from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes from django.contrib.postgres.fields import ArrayField from django.db import models +from django.http import HttpRequest from django.templatetags.static import static +from django.urls import reverse from django.utils.translation import gettext_lazy as _ from jwt import encode @@ -30,6 +33,13 @@ class DeliveryMethods(models.TextChoices): RISC_POLL = "https://schemas.openid.net/secevent/risc/delivery-method/poll" +class SSFEventStatus(models.TextChoices): + """SSF Event status""" + + PENDING = "pending" + SENT = "sent" + + class SSFProvider(BackchannelProvider): """Shared Signals Framework""" @@ -102,6 +112,35 @@ class Stream(models.Model): def __str__(self) -> str: return "SSF Stream" + def new_event( + self, type: EventTypes, request: HttpRequest, event_data: dict, **kwargs + ) -> "StreamEvent": + """Create a new SSF event""" + jti = uuid4() + evt = StreamEvent( + uuid=jti, + stream=self, + type=type, + payload={ + "jti": jti.hex, + "aud": self.provider.aud, + "iat": int(datetime.now().timestamp()), + "iss": self.request.build_absolute_uri( + reverse( + "authentik_providers_ssf:configuration", + kwargs={ + "application_slug": self.provider.application.slug, + "provider": self.provider.pk, + }, + ) + ), + "events": {EventTypes.SET_VERIFICATION: {event_data}}, + **kwargs, + }, + ) + evt.save() + return evt + def encode(self, data: dict) -> str: headers = {} if self.provider.signing_key: @@ -117,6 +156,23 @@ class UserStreamSubject(models.Model): def __str__(self) -> str: return f"Stream subject {self.stream_id} to {self.user_id}" + class StreamEvent(models.Model): + """Single stream event to be sent""" uuid = models.UUIDField(default=uuid4, primary_key=True, editable=False) + + stream = models.ForeignKey(Stream, on_delete=models.CASCADE) + status = models.TextField(choices=SSFEventStatus.choices) + + type = models.TextField(choices=EventTypes.choices) + payload = models.JSONField(default=dict) + + def __str__(self): + return f"Stream event {self.type}" + + def queue(self): + """Queue event to be sent""" + from authentik.enterprise.providers.ssf.tasks import send_single_ssf_event + + return send_single_ssf_event.delay(str(self.stream.uuid), str(self.uuid)) diff --git a/authentik/enterprise/providers/ssf/signals.py b/authentik/enterprise/providers/ssf/signals.py index 5132f7984cc71..73c054743953f 100644 --- a/authentik/enterprise/providers/ssf/signals.py +++ b/authentik/enterprise/providers/ssf/signals.py @@ -1,6 +1,3 @@ -from datetime import datetime -from uuid import uuid4 - from django.contrib.auth.signals import user_logged_out from django.db.models import Model from django.db.models.signals import post_save @@ -17,9 +14,8 @@ from authentik.enterprise.providers.ssf.models import ( EventTypes, SSFProvider, - Stream, ) -from authentik.enterprise.providers.ssf.tasks import send_single_ssf_event, send_ssf_event +from authentik.enterprise.providers.ssf.tasks import send_ssf_event from authentik.events.middleware import audit_ignore from authentik.events.utils import get_user @@ -53,29 +49,6 @@ def ssf_providers_post_save(sender: type[Model], instance: SSFProvider, created: instance.save() -@receiver(post_save, sender=Stream) -def ssf_stream_post_create(sender: type[Model], instance: Stream, created: bool, **_): - """Send a verification event when a stream is created""" - if not created: - return - send_single_ssf_event.delay( - str(instance.uuid), - { - "jti": uuid4().hex, - # TODO: Figure out how to get iss - "iss": "https://ak.beryju.dev/.well-known/ssf-configuration/abm-ssf/8", - "aud": instance.aud, - "iat": int(datetime.now().timestamp()), - "sub_id": {"format": "opaque", "id": str(instance.uuid)}, - "events": { - "https://schemas.openid.net/secevent/ssf/event-type/verification": { - "state": None, - } - }, - }, - ) - - @receiver(user_logged_out) def user_logged_out_session(sender, request: HttpRequest, user: User, **_): send_ssf_event.delay( diff --git a/authentik/enterprise/providers/ssf/tasks.py b/authentik/enterprise/providers/ssf/tasks.py index 07cdb4c06255a..30938eb4c00b6 100644 --- a/authentik/enterprise/providers/ssf/tasks.py +++ b/authentik/enterprise/providers/ssf/tasks.py @@ -1,7 +1,13 @@ from celery import group from requests.exceptions import RequestException -from authentik.enterprise.providers.ssf.models import DeliveryMethods, EventTypes, Stream +from authentik.enterprise.providers.ssf.models import ( + DeliveryMethods, + EventTypes, + SSFEventStatus, + Stream, + StreamEvent, +) from authentik.lib.utils.http import get_http_session from authentik.root.celery import CELERY_APP @@ -12,28 +18,34 @@ def send_ssf_event(event_type: EventTypes, data: dict): tasks = [] for stream in Stream.objects.filter(events_requested__in=[event_type]): - tasks.append(send_single_ssf_event.si(str(stream.uuid), data)) + event = stream.new_event( + type=event_type, + ) + tasks.append(send_single_ssf_event.si(str(stream.uuid), str(event.id))) main_task = group(*tasks) main_task() @CELERY_APP.task(bind=True, autoretry=True, autoretry_for=(RequestException,), retry_backoff=True) -def send_single_ssf_event(self, stream_id: str, data: dict): +def send_single_ssf_event(self, stream_id: str, evt_id: str): stream = Stream.objects.filter(pk=stream_id).first() if not stream: return + event = StreamEvent.objects.filter(pk=evt_id).first() + if not event: + return + if event.status == SSFEventStatus.SENT: + return if stream.delivery_method == DeliveryMethods.RISC_PUSH: - ssf_push_request.delay(stream_id, data) + ssf_push_request(stream_id, event) + event.status = SSFEventStatus.SENT + event.save() -@CELERY_APP.task(bind=True, autoretry=True, autoretry_for=(RequestException,), retry_backoff=True) -def ssf_push_request(self, stream_id: str, data: dict): - stream = Stream.objects.filter(pk=stream_id).first() - if not stream: - return +def ssf_push_request(event: StreamEvent): response = session.post( - stream.endpoint_url, - data=stream.encode(data), + event.stream.endpoint_url, + data=event.stream.encode(event.data), headers={"Content-Type": "application/secevent+jwt", "Accept": "application/json"}, ) response.raise_for_status() diff --git a/authentik/enterprise/providers/ssf/views/stream.py b/authentik/enterprise/providers/ssf/views/stream.py index c5aef348af1a5..cf292d1818ebf 100644 --- a/authentik/enterprise/providers/ssf/views/stream.py +++ b/authentik/enterprise/providers/ssf/views/stream.py @@ -18,7 +18,6 @@ class StreamDeliverySerializer(PassiveSerializer): class StreamSerializer(ModelSerializer): - delivery = StreamDeliverySerializer() events_requested = ListField( child=ChoiceField(choices=[(x.value, x.value) for x in EventTypes]) @@ -49,7 +48,6 @@ class Meta: class StreamResponseSerializer(PassiveSerializer): - stream_id = CharField(source="pk") iss = SerializerMethodField() aud = ListField(child=CharField()) @@ -88,7 +86,15 @@ class StreamView(SSFView): def post(self, request: Request, *args, **kwargs) -> Response: stream = StreamSerializer(data=request.data) stream.is_valid(raise_exception=True) - instance = stream.save(provider=self.provider) + instance: Stream = stream.save(provider=self.provider) + instance.new_event( + EventTypes.SET_VERIFICATION, + request, + { + "state": None, + }, + sub_id={"format": "opaque", "id": str(instance.uuid)}, + ).queue() response = StreamResponseSerializer(instance=instance, context={"request": request}).data return Response(response, status=201)