Skip to content

Commit

Permalink
more progress
Browse files Browse the repository at this point in the history
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
  • Loading branch information
BeryJu committed Dec 11, 2024
1 parent 2813634 commit 23485a9
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 42 deletions.
56 changes: 56 additions & 0 deletions authentik/enterprise/providers/ssf/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime
from functools import cached_property
from uuid import uuid4

Expand All @@ -6,7 +7,9 @@
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
from django.contrib.postgres.fields import ArrayField
from django.db import models
from django.http import HttpRequest
from django.templatetags.static import static
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from jwt import encode

Expand All @@ -30,6 +33,13 @@ class DeliveryMethods(models.TextChoices):
RISC_POLL = "https://schemas.openid.net/secevent/risc/delivery-method/poll"


class SSFEventStatus(models.TextChoices):
"""SSF Event status"""

PENDING = "pending"
SENT = "sent"


class SSFProvider(BackchannelProvider):
"""Shared Signals Framework"""

Expand Down Expand Up @@ -102,6 +112,35 @@ class Stream(models.Model):
def __str__(self) -> str:
return "SSF Stream"

def new_event(
self, type: EventTypes, request: HttpRequest, event_data: dict, **kwargs
) -> "StreamEvent":
"""Create a new SSF event"""
jti = uuid4()
evt = StreamEvent(
uuid=jti,
stream=self,
type=type,
payload={
"jti": jti.hex,
"aud": self.provider.aud,
"iat": int(datetime.now().timestamp()),
"iss": self.request.build_absolute_uri(
reverse(
"authentik_providers_ssf:configuration",
kwargs={
"application_slug": self.provider.application.slug,
"provider": self.provider.pk,
},
)
),
"events": {EventTypes.SET_VERIFICATION: {event_data}},
**kwargs,
},
)
evt.save()
return evt

def encode(self, data: dict) -> str:
headers = {}
if self.provider.signing_key:
Expand All @@ -117,6 +156,23 @@ class UserStreamSubject(models.Model):
def __str__(self) -> str:
return f"Stream subject {self.stream_id} to {self.user_id}"


class StreamEvent(models.Model):
"""Single stream event to be sent"""

uuid = models.UUIDField(default=uuid4, primary_key=True, editable=False)

stream = models.ForeignKey(Stream, on_delete=models.CASCADE)
status = models.TextField(choices=SSFEventStatus.choices)

type = models.TextField(choices=EventTypes.choices)
payload = models.JSONField(default=dict)

def __str__(self):
return f"Stream event {self.type}"

def queue(self):
"""Queue event to be sent"""
from authentik.enterprise.providers.ssf.tasks import send_single_ssf_event

return send_single_ssf_event.delay(str(self.stream.uuid), str(self.uuid))
29 changes: 1 addition & 28 deletions authentik/enterprise/providers/ssf/signals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from datetime import datetime
from uuid import uuid4

from django.contrib.auth.signals import user_logged_out
from django.db.models import Model
from django.db.models.signals import post_save
Expand All @@ -17,9 +14,8 @@
from authentik.enterprise.providers.ssf.models import (
EventTypes,
SSFProvider,
Stream,
)
from authentik.enterprise.providers.ssf.tasks import send_single_ssf_event, send_ssf_event
from authentik.enterprise.providers.ssf.tasks import send_ssf_event
from authentik.events.middleware import audit_ignore
from authentik.events.utils import get_user

Expand Down Expand Up @@ -53,29 +49,6 @@ def ssf_providers_post_save(sender: type[Model], instance: SSFProvider, created:
instance.save()


@receiver(post_save, sender=Stream)
def ssf_stream_post_create(sender: type[Model], instance: Stream, created: bool, **_):
"""Send a verification event when a stream is created"""
if not created:
return
send_single_ssf_event.delay(
str(instance.uuid),
{
"jti": uuid4().hex,
# TODO: Figure out how to get iss
"iss": "https://ak.beryju.dev/.well-known/ssf-configuration/abm-ssf/8",
"aud": instance.aud,
"iat": int(datetime.now().timestamp()),
"sub_id": {"format": "opaque", "id": str(instance.uuid)},
"events": {
"https://schemas.openid.net/secevent/ssf/event-type/verification": {
"state": None,
}
},
},
)


@receiver(user_logged_out)
def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
send_ssf_event.delay(
Expand Down
34 changes: 23 additions & 11 deletions authentik/enterprise/providers/ssf/tasks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from celery import group
from requests.exceptions import RequestException

from authentik.enterprise.providers.ssf.models import DeliveryMethods, EventTypes, Stream
from authentik.enterprise.providers.ssf.models import (
DeliveryMethods,
EventTypes,
SSFEventStatus,
Stream,
StreamEvent,
)
from authentik.lib.utils.http import get_http_session
from authentik.root.celery import CELERY_APP

Expand All @@ -12,28 +18,34 @@
def send_ssf_event(event_type: EventTypes, data: dict):
tasks = []
for stream in Stream.objects.filter(events_requested__in=[event_type]):
tasks.append(send_single_ssf_event.si(str(stream.uuid), data))
event = stream.new_event(
type=event_type,
)
tasks.append(send_single_ssf_event.si(str(stream.uuid), str(event.id)))
main_task = group(*tasks)
main_task()


@CELERY_APP.task(bind=True, autoretry=True, autoretry_for=(RequestException,), retry_backoff=True)
def send_single_ssf_event(self, stream_id: str, data: dict):
def send_single_ssf_event(self, stream_id: str, evt_id: str):
stream = Stream.objects.filter(pk=stream_id).first()
if not stream:
return
event = StreamEvent.objects.filter(pk=evt_id).first()
if not event:
return
if event.status == SSFEventStatus.SENT:
return
if stream.delivery_method == DeliveryMethods.RISC_PUSH:
ssf_push_request.delay(stream_id, data)
ssf_push_request(stream_id, event)
event.status = SSFEventStatus.SENT
event.save()


@CELERY_APP.task(bind=True, autoretry=True, autoretry_for=(RequestException,), retry_backoff=True)
def ssf_push_request(self, stream_id: str, data: dict):
stream = Stream.objects.filter(pk=stream_id).first()
if not stream:
return
def ssf_push_request(event: StreamEvent):
response = session.post(
stream.endpoint_url,
data=stream.encode(data),
event.stream.endpoint_url,
data=event.stream.encode(event.data),
headers={"Content-Type": "application/secevent+jwt", "Accept": "application/json"},
)
response.raise_for_status()
12 changes: 9 additions & 3 deletions authentik/enterprise/providers/ssf/views/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ class StreamDeliverySerializer(PassiveSerializer):


class StreamSerializer(ModelSerializer):

delivery = StreamDeliverySerializer()
events_requested = ListField(
child=ChoiceField(choices=[(x.value, x.value) for x in EventTypes])
Expand Down Expand Up @@ -49,7 +48,6 @@ class Meta:


class StreamResponseSerializer(PassiveSerializer):

stream_id = CharField(source="pk")
iss = SerializerMethodField()
aud = ListField(child=CharField())
Expand Down Expand Up @@ -88,7 +86,15 @@ class StreamView(SSFView):
def post(self, request: Request, *args, **kwargs) -> Response:
stream = StreamSerializer(data=request.data)
stream.is_valid(raise_exception=True)
instance = stream.save(provider=self.provider)
instance: Stream = stream.save(provider=self.provider)
instance.new_event(
EventTypes.SET_VERIFICATION,
request,
{
"state": None,
},
sub_id={"format": "opaque", "id": str(instance.uuid)},
).queue()
response = StreamResponseSerializer(instance=instance, context={"request": request}).data
return Response(response, status=201)

Expand Down

0 comments on commit 23485a9

Please sign in to comment.