diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a81a1b32..6b4166141 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to ### Added - ✨(teams) allow team management for team admins/owners #509 +- ✨(backend) add ServiceProvider #522 ## [1.5.0] - 2024-11-14 diff --git a/Dockerfile b/Dockerfile index 79c0c5bea..07236aa45 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ # Django People # ---- base image to inherit from ---- -FROM python:3.12.6-alpine3.20 as base +FROM python:3.12.6-alpine3.20 AS base # Upgrade pip to its latest release to speed up dependencies installation RUN python -m pip install --upgrade pip setuptools @@ -11,7 +11,7 @@ RUN apk update && \ apk upgrade ### ---- Front-end dependencies image ---- -FROM node:20 as frontend-deps +FROM node:20 AS frontend-deps WORKDIR /deps @@ -24,7 +24,7 @@ COPY ./src/frontend/packages/eslint-config-people/package.json ./packages/eslint RUN yarn --frozen-lockfile ### ---- Front-end builder dev image ---- -FROM node:20 as frontend-builder-dev +FROM node:20 AS frontend-builder-dev WORKDIR /builder @@ -34,12 +34,12 @@ COPY ./src/frontend . WORKDIR ./apps/desk ### ---- Front-end builder image ---- -FROM frontend-builder-dev as frontend-builder +FROM frontend-builder-dev AS frontend-builder RUN yarn build # ---- Front-end image ---- -FROM nginxinc/nginx-unprivileged:1.26-alpine as frontend-production +FROM nginxinc/nginx-unprivileged:1.26-alpine AS frontend-production # Un-privileged user running the application ARG DOCKER_USER @@ -60,7 +60,7 @@ CMD ["nginx", "-g", "daemon off;"] # ---- Back-end builder image ---- -FROM base as back-builder +FROM base AS back-builder WORKDIR /builder @@ -72,7 +72,7 @@ RUN mkdir /install && \ # ---- mails ---- -FROM node:20 as mail-builder +FROM node:20 AS mail-builder COPY ./src/mail /mail/app @@ -83,7 +83,7 @@ RUN yarn install --frozen-lockfile && \ # ---- static link collector ---- -FROM base as link-collector +FROM base AS link-collector ARG PEOPLE_STATIC_ROOT=/data/static # Install libpangocairo & rdfind @@ -108,7 +108,7 @@ RUN DJANGO_CONFIGURATION=Build DJANGO_JWT_PRIVATE_SIGNING_KEY=Dummy \ RUN rdfind -makesymlinks true -followsymlinks true -makeresultsfile false ${PEOPLE_STATIC_ROOT} # ---- Core application image ---- -FROM base as core +FROM base AS core ENV PYTHONUNBUFFERED=1 @@ -143,7 +143,7 @@ WORKDIR /app ENTRYPOINT [ "/usr/local/bin/entrypoint" ] # ---- Development image ---- -FROM core as backend-development +FROM core AS backend-development # Switch back to the root user to install development dependencies USER root:root @@ -169,7 +169,7 @@ ENV DB_HOST=postgresql \ CMD ["python", "manage.py", "runserver", "0.0.0.0:8000"] # ---- Production image ---- -FROM core as backend-production +FROM core AS backend-production ARG PEOPLE_STATIC_ROOT=/data/static diff --git a/src/backend/core/admin.py b/src/backend/core/admin.py index c0b3d4a95..adfe05d64 100644 --- a/src/backend/core/admin.py +++ b/src/backend/core/admin.py @@ -108,11 +108,20 @@ def get_user(self, obj): get_user.short_description = _("User") +class TeamServiceProviderInline(admin.TabularInline): + """Inline admin class for service providers.""" + + can_delete = False + model = models.Team.service_providers.through + extra = 0 + + @admin.register(models.Team) class TeamAdmin(admin.ModelAdmin): """Team admin interface declaration.""" - inlines = (TeamAccessInline, TeamWebhookInline) + inlines = (TeamAccessInline, TeamWebhookInline, TeamServiceProviderInline) + exclude = ("service_providers",) # Handled by the inline list_display = ( "name", "created_at", @@ -188,6 +197,14 @@ class ContactAdmin(admin.ModelAdmin): ) +class OrganizationServiceProviderInline(admin.TabularInline): + """Inline admin class for service providers.""" + + can_delete = False + model = models.Organization.service_providers.through + extra = 0 + + @admin.register(models.Organization) class OrganizationAdmin(admin.ModelAdmin): """Admin interface for organizations.""" @@ -198,7 +215,8 @@ class OrganizationAdmin(admin.ModelAdmin): "updated_at", ) search_fields = ("name",) - inlines = (OrganizationAccessInline,) + inlines = (OrganizationAccessInline, OrganizationServiceProviderInline) + exclude = ("service_providers",) # Handled by the inline @admin.register(models.OrganizationAccess) @@ -213,3 +231,17 @@ class OrganizationAccessAdmin(admin.ModelAdmin): "created_at", "updated_at", ) + + +@admin.register(models.ServiceProvider) +class ServiceProviderAdmin(admin.ModelAdmin): + """Admin interface for service providers.""" + + list_display = ( + "name", + "audience_id", + "created_at", + "updated_at", + ) + search_fields = ("name", "audience_id") + readonly_fields = ("created_at", "updated_at") diff --git a/src/backend/core/api/serializers.py b/src/backend/core/api/serializers.py index 4e73651f5..feccd7f2c 100644 --- a/src/backend/core/api/serializers.py +++ b/src/backend/core/api/serializers.py @@ -4,6 +4,7 @@ from timezone_field.rest_framework import TimeZoneSerializerField from core import models +from core.models import ServiceProvider class ContactSerializer(serializers.ModelSerializer): @@ -205,6 +206,9 @@ class TeamSerializer(serializers.ModelSerializer): """Serialize teams.""" abilities = serializers.SerializerMethodField(read_only=True) + service_providers = serializers.PrimaryKeyRelatedField( + queryset=ServiceProvider.objects.all(), many=True, required=False + ) class Meta: model = models.Team @@ -215,6 +219,7 @@ class Meta: "abilities", "created_at", "updated_at", + "service_providers", ] read_only_fields = [ "id", @@ -226,6 +231,13 @@ class Meta: def create(self, validated_data): """Create a new team with organization enforcement.""" + # When called as a resource server, we enforce the team service provider + if sp_audience := self.context.get("from_service_provider_audience", None): + service_providers, _created = models.ServiceProvider.objects.get_or_create( + audience_id=sp_audience + ) + validated_data["service_providers"] = [service_providers] + # Note: this is not the purpose of this API to check the user has an organization return super().create( validated_data=validated_data @@ -273,3 +285,12 @@ def validate(self, attrs): attrs["team_id"] = team_id attrs["issuer"] = user return attrs + + +class ServiceProviderSerializer(serializers.ModelSerializer): + """Serialize service providers.""" + + class Meta: + model = models.ServiceProvider + fields = ["id", "audience_id", "name"] + read_only_fields = ["id", "audience_id"] diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 5fbd4ce43..c03e92b4c 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -2,7 +2,7 @@ from django.conf import settings from django.contrib.postgres.search import TrigramSimilarity -from django.db.models import Func, Max, OuterRef, Q, Subquery, Value +from django.db.models import Func, Max, OuterRef, Prefetch, Q, Subquery, Value from django.db.models.functions import Coalesce from rest_framework import ( @@ -20,6 +20,7 @@ from core import models +from ..resource_server.mixins import ResourceServerMixin from . import permissions, serializers SIMILARITY_THRESHOLD = 0.04 @@ -240,6 +241,7 @@ def get_me(self, request): class TeamViewSet( + ResourceServerMixin, mixins.CreateModelMixin, mixins.DestroyModelMixin, mixins.ListModelMixin, @@ -262,8 +264,36 @@ def get_queryset(self): user_role_query = models.TeamAccess.objects.filter( user=self.request.user, team=OuterRef("pk") ).values("role")[:1] - return models.Team.objects.filter(accesses__user=self.request.user).annotate( - user_role=Subquery(user_role_query) + + service_provider_audience = self._get_service_provider_audience() + if service_provider_audience: + # Restrict displayed service providers when used as a resource server + service_provider_prefetch = Prefetch( + "service_providers", + queryset=models.ServiceProvider.objects.filter( + audience_id=self._get_service_provider_audience() + ), + ) + + # Restrict results to the Service Provider's teams when used as a resource server + service_provider_filters = { + "service_providers__audience_id": service_provider_audience + } + + else: + service_provider_prefetch = Prefetch( + "service_providers", + queryset=models.ServiceProvider.objects.all(), + ) + service_provider_filters = {} + + return ( + models.Team.objects.prefetch_related("accesses", service_provider_prefetch) + .filter( + accesses__user=self.request.user, + **service_provider_filters, + ) + .annotate(user_role=Subquery(user_role_query)) ) def perform_create(self, serializer): @@ -510,3 +540,50 @@ def get(self, request): dict_settings[setting] = getattr(settings, setting) return response.Response(dict_settings) + + +class ServiceProviderFilter(filters.BaseFilterBackend): + """ + Filter service providers by audience. + """ + + def filter_queryset(self, request, queryset, view): + """ + Filter service providers by audience. + """ + if name := request.GET.get("name"): + queryset = queryset.filter(name__icontains=name) + if audience_id := request.GET.get("audience_id"): + queryset = queryset.filter(audience_id=audience_id) + return queryset + + +class ServiceProviderViewSet( + mixins.ListModelMixin, + mixins.RetrieveModelMixin, + viewsets.GenericViewSet, +): + """ + API ViewSet for all interactions with service providers. + + GET /api/v1.0/service-providers/ + Return a list of service providers. + + GET /api/v1.0/service-providers// + Return a service provider. + """ + + permission_classes = [permissions.IsAuthenticated] + queryset = models.ServiceProvider.objects.all() + serializer_class = serializers.ServiceProviderSerializer + throttle_classes = [BurstRateThrottle, SustainedRateThrottle] + pagination_class = Pagination + filter_backends = [filters.OrderingFilter, ServiceProviderFilter] + ordering = ["name"] + ordering_fields = ["name", "created_at"] + + def get_queryset(self): + """Filter the queryset to limit results to user's organization.""" + queryset = super().get_queryset() + queryset = queryset.filter(organizations__id=self.request.user.organization_id) + return queryset diff --git a/src/backend/core/factories.py b/src/backend/core/factories.py index 4f8f19d4d..ca715bd9f 100644 --- a/src/backend/core/factories.py +++ b/src/backend/core/factories.py @@ -180,6 +180,13 @@ def users(self, create, extracted, **kwargs): else: TeamAccessFactory(team=self, user=user_entry[0], role=user_entry[1]) + @factory.post_generation + def service_providers(self, create, extracted, **kwargs): + """Add service providers to team from a given list of service providers.""" + if not create or not extracted: + return + self.service_providers.set(extracted) + class TeamAccessFactory(factory.django.DjangoModelFactory): """Create fake team user accesses for testing.""" @@ -212,3 +219,27 @@ class Meta: email = factory.Faker("email") role = factory.fuzzy.FuzzyChoice([role[0] for role in models.RoleChoices.choices]) issuer = factory.SubFactory(UserFactory) + + +class ServiceProviderFactory(factory.django.DjangoModelFactory): + """A factory to create service providers for testing purposes.""" + + class Meta: + model = models.ServiceProvider + skip_postgeneration_save = True + + audience_id = factory.Faker("uuid4") + + @factory.post_generation + def teams(self, create, extracted, **kwargs): + """Add teams to service provider from a given list.""" + if not create or not extracted: + return + self.teams.set(extracted) + + @factory.post_generation + def organizations(self, create, extracted, **kwargs): + """Add organization to service provider from a given list.""" + if not create or not extracted: + return + self.organizations.set(extracted) diff --git a/src/backend/core/migrations/0002_add_organization_and_more.py b/src/backend/core/migrations/0002_add_organization_and_more.py index bb7b54b6c..c0d84c389 100644 --- a/src/backend/core/migrations/0002_add_organization_and_more.py +++ b/src/backend/core/migrations/0002_add_organization_and_more.py @@ -22,8 +22,8 @@ class Migration(migrations.Migration): ('created_at', models.DateTimeField(auto_now_add=True, help_text='date and time at which a record was created', verbose_name='created at')), ('updated_at', models.DateTimeField(auto_now=True, help_text='date and time at which a record was last updated', verbose_name='updated at')), ('name', models.CharField(max_length=100, verbose_name='name')), - ('registration_id_list', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=128), blank=True, default=list, size=None, validators=[core.models.validate_unique_registration_id], verbose_name='registration ID list')), - ('domain_list', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=256), blank=True, default=list, size=None, validators=[core.models.validate_unique_domain], verbose_name='domain list')), + ('registration_id_list', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=128), blank=True, default=list, size=None, verbose_name='registration ID list')), + ('domain_list', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=256), blank=True, default=list, size=None, verbose_name='domain list')), ], options={ 'verbose_name': 'organization', diff --git a/src/backend/core/migrations/0004_add_serviceprovider.py b/src/backend/core/migrations/0004_add_serviceprovider.py new file mode 100644 index 000000000..ba4701cf7 --- /dev/null +++ b/src/backend/core/migrations/0004_add_serviceprovider.py @@ -0,0 +1,39 @@ +# Generated by Django 5.1.2 on 2024-11-07 16:24 + +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('core', '0003_team_slug_nullable'), + ] + + operations = [ + migrations.CreateModel( + name='ServiceProvider', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, help_text='primary key for the record as UUID', primary_key=True, serialize=False, verbose_name='id')), + ('created_at', models.DateTimeField(auto_now_add=True, help_text='date and time at which a record was created', verbose_name='created at')), + ('updated_at', models.DateTimeField(auto_now=True, help_text='date and time at which a record was last updated', verbose_name='updated at')), + ('name', models.CharField(max_length=256, unique=True, verbose_name='name')), + ('audience_id', models.CharField(db_index=True, max_length=256, unique=True, verbose_name='audience id')), + ], + options={ + 'verbose_name': 'service provider', + 'verbose_name_plural': 'service providers', + 'db_table': 'people_service_provider', + }, + ), + migrations.AddField( + model_name='organization', + name='service_providers', + field=models.ManyToManyField(blank=True, related_name='organizations', to='core.serviceprovider'), + ), + migrations.AddField( + model_name='team', + name='service_providers', + field=models.ManyToManyField(blank=True, related_name='teams', to='core.serviceprovider'), + ), + ] diff --git a/src/backend/core/models.py b/src/backend/core/models.py index 547c2dd12..55c98f387 100644 --- a/src/backend/core/models.py +++ b/src/backend/core/models.py @@ -173,6 +173,34 @@ def clean(self): raise exceptions.ValidationError({"data": [error_message]}) from e +class ServiceProvider(BaseModel): + """ + Represents a service provider that will consume our information. + + Organization uses this model to define the list of SP available to their users. + Team uses this model to define their visibility to the various SP. + """ + + name = models.CharField(_("name"), max_length=256, unique=True) + audience_id = models.CharField( + _("audience id"), max_length=256, unique=True, db_index=True + ) + + class Meta: + db_table = "people_service_provider" + verbose_name = _("service provider") + verbose_name_plural = _("service providers") + + def __str__(self): + return self.name + + def save(self, *args, **kwargs): + """Enforce name (even if ugly) from the `audience_id` field.""" + if not self.name: + self.name = self.audience_id # ok, same length + return super().save(*args, **kwargs) + + class OrganizationManager(models.Manager): """ Custom manager for the Organization model, to manage complexity/automation. @@ -223,24 +251,6 @@ def get_or_create_from_user_claims( raise ValueError("Should never reach this point.") -def validate_unique_registration_id(value): - """ - Validate that the registration ID values in an array field are unique across all instances. - """ - if Organization.objects.filter(registration_id_list__overlap=value).exists(): - raise ValidationError( - "registration_id_list value must be unique across all instances." - ) - - -def validate_unique_domain(value): - """ - Validate that the domain values in an array field are unique across all instances. - """ - if Organization.objects.filter(domain_list__overlap=value).exists(): - raise ValidationError("domain_list value must be unique across all instances.") - - class Organization(BaseModel): """ Organization model used to regroup Teams. @@ -270,16 +280,20 @@ class Organization(BaseModel): verbose_name=_("registration ID list"), default=list, blank=True, - validators=[ - validate_unique_registration_id, - ], + # list overlap validation is done in the validate_unique method ) domain_list = ArrayField( models.CharField(max_length=256), verbose_name=_("domain list"), default=list, blank=True, - validators=[validate_unique_domain], + # list overlap validation is done in the validate_unique method + ) + + service_providers = models.ManyToManyField( + ServiceProvider, + related_name="organizations", + blank=True, ) objects = OrganizationManager() @@ -306,6 +320,41 @@ class Meta: def __str__(self): return f"{self.name} (# {self.pk})" + def validate_unique(self, exclude=None): + """ + Validate Registration/Domain values in an array field are unique + across all instances. + + This can't be done in the field validator because we need to + exclude the current object if already in database. + """ + super().validate_unique(exclude=exclude) + + if self.pk: + organization_qs = Organization.objects.exclude(pk=self.pk) + else: + organization_qs = Organization.objects.all() + + # Check a registration ID can only be present in one organization + if ( + self.registration_id_list + and organization_qs.filter( + registration_id_list__overlap=self.registration_id_list + ).exists() + ): + raise ValidationError( + "registration_id_list value must be unique across all instances." + ) + + # Check a domain can only be present in one organization + if ( + self.domain_list + and organization_qs.filter(domain_list__overlap=self.domain_list).exists() + ): + raise ValidationError( + "domain_list value must be unique across all instances." + ) + class User(AbstractBaseUser, BaseModel, auth_models.PermissionsMixin): """User model to work with OIDC only authentication.""" @@ -539,6 +588,11 @@ def __str__(self): class Team(BaseModel): """ Represents the link between teams and users, specifying the role a user has in a team. + + When a team is created from here, the user have to choose which Service Providers + can see it. + When a team is created from a Service Provider this one is automatically set in the + Team `service_providers`. """ name = models.CharField(max_length=100) @@ -556,6 +610,11 @@ class Team(BaseModel): null=True, # Need to be set to False when everything is migrated blank=True, # Need to be set to False when everything is migrated ) + service_providers = models.ManyToManyField( + ServiceProvider, + related_name="teams", + blank=True, + ) class Meta: db_table = "people_team" diff --git a/src/backend/core/resource_server/authentication.py b/src/backend/core/resource_server/authentication.py index 145521499..f6ecb2a03 100644 --- a/src/backend/core/resource_server/authentication.py +++ b/src/backend/core/resource_server/authentication.py @@ -53,3 +53,20 @@ def get_access_token(self, request): pass return access_token + + def authenticate(self, request): + """ + Authenticate the request and return a tuple of (user, token) or None. + + We override the 'authenticate' method from the parent class to store + the introspected token audience inside the request. + """ + result = super().authenticate(request) # Might raise AuthenticationFailed + + if result is None: # Case when there is no access token + return None + + # Note: at this stage, the request is a "drf_request" object + request.resource_server_token_audience = self.backend.token_origin_audience + + return result diff --git a/src/backend/core/resource_server/backend.py b/src/backend/core/resource_server/backend.py index b7ce7732e..1f71017c0 100644 --- a/src/backend/core/resource_server/backend.py +++ b/src/backend/core/resource_server/backend.py @@ -61,6 +61,10 @@ def __init__(self, authorization_server_client): token_introspection={"essential": True}, ) + # Declare the token origin audience: to know where the token comes from + # and store it for further use in the application + self.token_origin_audience = None + # pylint: disable=unused-argument def get_or_create_user(self, access_token, id_token, payload): """Maintain API compatibility with OIDCAuthentication class from mozilla-django-oidc @@ -85,6 +89,8 @@ def get_user(self, access_token): that extends RFC 7662 by returning a signed and encrypted JWT for stronger assurance that the authorization server issued the token introspection response. """ + self.token_origin_audience = None # Reset the token origin audience + jwt = self._introspect(access_token) claims = self._verify_claims(jwt) user_info = self._verify_user_info(claims["token_introspection"]) @@ -100,6 +106,8 @@ def get_user(self, access_token): logger.debug("Login failed: No user with %s found", sub) return None + self.token_origin_audience = str(user_info["aud"]) + return user def _verify_user_info(self, introspection_response): @@ -127,6 +135,12 @@ def _verify_user_info(self, introspection_response): logger.debug(message) raise SuspiciousOperation(message) + audience = introspection_response.get("aud", None) + if not audience: + raise SuspiciousOperation( + "Introspection response does not provide source audience." + ) + return introspection_response def _introspect(self, token): @@ -219,6 +233,8 @@ def _verify_claims(self, token): class ResourceServerImproperlyConfiguredBackend: """Fallback backend for improperly configured Resource Servers.""" + token_origin_audience = None + def get_or_create_user(self, access_token, id_token, payload): """Indicate that the Resource Server is improperly configured.""" raise AuthenticationFailed("Resource Server is improperly configured") diff --git a/src/backend/core/resource_server/mixins.py b/src/backend/core/resource_server/mixins.py new file mode 100644 index 000000000..ed37378fd --- /dev/null +++ b/src/backend/core/resource_server/mixins.py @@ -0,0 +1,58 @@ +""" +Mixins for resource server views. +""" + +from rest_framework import exceptions as drf_exceptions + +from .authentication import ResourceServerAuthentication + + +class ResourceServerMixin: + """ + Mixin for resource server views: + - Adds the ResourceServerAuthentication to the list of authenticators. + - Adds the Service Provider ID to the serializer context. + - Fetch the Service Provider ID from the OIDC introspected token stored + in the request. + + This Mixin *must* be used in every view that should act as a resource server. + """ + + def get_authenticators(self): + """ + Return the list of authenticators that this view uses + including the Resource Server auth. + """ + return [ResourceServerAuthentication()] + super().get_authenticators() + + def get_serializer_context(self): + """Extra context provided to the serializer class.""" + context = super().get_serializer_context() + + # When used as a resource server, we need to know the audience to automatically: + # - add the Service Provider to the team "scope" on creation + context["from_service_provider_audience"] = ( + self._get_service_provider_audience() + ) + + return context + + def _get_service_provider_audience(self): + """Return the audience of the Service Provider from the OIDC introspected token.""" + if not isinstance( + self.request.successful_authenticator, ResourceServerAuthentication + ): + # We could check request.resource_server_token_audience here, but it's + # more explicit to check the authenticator type and assert the attribute + # existence. + return None + + # When used as a resource server, the request has a token audience + service_provider_audience = self.request.resource_server_token_audience + + if not service_provider_audience: # should not happen + raise drf_exceptions.AuthenticationFailed( + "Resource server token audience not found in request" + ) + + return service_provider_audience diff --git a/src/backend/core/tests/conftest.py b/src/backend/core/tests/conftest.py new file mode 100644 index 000000000..3b35e020f --- /dev/null +++ b/src/backend/core/tests/conftest.py @@ -0,0 +1,148 @@ +"""Defines fixtures for the tests.""" + +import base64 +import json +from unittest import mock + +import pytest +import responses +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from joserfc import jwe as jose_jwe +from joserfc import jwt as jose_jwt +from joserfc.rfc7518.rsa_key import RSAKey +from jwt.utils import to_base64url_uint + +from core.factories import ServiceProviderFactory, UserFactory +from core.resource_server.authentication import ResourceServerAuthentication + + +@pytest.fixture(name="resource_server_settings") +def resource_server_settings_fixture(settings): + """ + Defines the settings for the resource server + for a full authentication with introspection process. + + This is more for integration tests, I believe it is nice to have + at least few tests that cover the full process of authentication. + This is more useful when we want to check corner cases around + the data provided by the introspection endpoint. + + For unit tests, we can mock the introspection process using the + `authenticated_user_with_service_provider` fixture. + """ + + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + unencrypted_pem_private_key = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + + pem_public_key = private_key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + + settings.OIDC_RS_PRIVATE_KEY_STR = unencrypted_pem_private_key.decode("utf-8") + settings.OIDC_RS_ENCRYPTION_KEY_TYPE = "RSA" + settings.OIDC_RS_ENCRYPTION_ENCODING = "A256GCM" + settings.OIDC_RS_ENCRYPTION_ALGO = "RSA-OAEP" + settings.OIDC_RS_SIGNING_ALGO = "RS256" + settings.OIDC_RS_CLIENT_ID = "some_client_id" + settings.OIDC_RS_CLIENT_SECRET = "some_client_secret" + + settings.OIDC_OP_URL = "https://oidc.example.com" + settings.OIDC_VERIFY_SSL = False + settings.OIDC_TIMEOUT = 5 + settings.OIDC_PROXY = None + settings.OIDC_OP_JWKS_ENDPOINT = "https://oidc.example.com/jwks" + settings.OIDC_OP_INTROSPECTION_ENDPOINT = "https://oidc.example.com/introspect" + + # Mock the JWKS endpoint + public_numbers = private_key.public_key().public_numbers() + responses.add( + responses.GET, + settings.OIDC_OP_JWKS_ENDPOINT, + body=json.dumps( + { + "keys": [ + { + "kty": settings.OIDC_RS_ENCRYPTION_KEY_TYPE, + "alg": settings.OIDC_RS_SIGNING_ALGO, + "use": "sig", + "kid": "1234567890", + "n": to_base64url_uint(public_numbers.n).decode("ascii"), + "e": to_base64url_uint(public_numbers.e).decode("ascii"), + } + ] + } + ), + ) + + def encrypt_jwt(json_data): + token = jose_jwt.encode( + { + "kid": "1234567890", + "alg": settings.OIDC_RS_SIGNING_ALGO, + }, + json_data, + RSAKey.import_key(unencrypted_pem_private_key), + algorithms=[settings.OIDC_RS_SIGNING_ALGO], + ) + + return jose_jwe.encrypt_compact( + protected={ + "alg": settings.OIDC_RS_ENCRYPTION_ALGO, + "enc": settings.OIDC_RS_ENCRYPTION_ENCODING, + }, + plaintext=token, + public_key=RSAKey.import_key(pem_public_key), + algorithms=[ + settings.OIDC_RS_ENCRYPTION_ALGO, + settings.OIDC_RS_ENCRYPTION_ENCODING, + ], + ) + + yield { + "encrypt_jwt": encrypt_jwt, + } + + +@pytest.fixture +def authenticated_user_with_service_provider(client): + """ + Fixture to authenticate a user with a service provider via a resource server call. + + This fixture allows to authenticate a user with a service provider without doing + all the introspection process. + """ + user = UserFactory() + service_provider = ServiceProviderFactory() + + def mock_authenticate(self, request): # pylint: disable=unused-argument + request.resource_server_token_audience = service_provider.audience_id + return user, "1234" + + with mock.patch.object( + ResourceServerAuthentication, "authenticate", mock_authenticate + ): + client.force_login( + user, + backend="core.resource_server.authentication.ResourceServerAuthentication", + ) + yield user, service_provider + + +def build_authorization_bearer(token): + """ + Build an Authorization Bearer header value from a token. + + This can be used like this: + client.post( + ... + HTTP_AUTHORIZATION=f"Bearer {build_authorization_bearer('some_token')}", + ) + """ + return base64.b64encode(token.encode("utf-8")).decode("utf-8") diff --git a/src/backend/core/tests/resource_server/test_backend.py b/src/backend/core/tests/resource_server/test_backend.py index 73b808cf7..1acd82a4a 100644 --- a/src/backend/core/tests/resource_server/test_backend.py +++ b/src/backend/core/tests/resource_server/test_backend.py @@ -296,7 +296,7 @@ def test_introspect_public_key_import_failure( def test_verify_user_info_success(resource_server_backend): """Test '_verify_user_info' with a successful response.""" - introspection_response = {"active": True, "scope": "groups"} + introspection_response = {"active": True, "scope": "groups", "aud": "123"} result = resource_server_backend._verify_user_info(introspection_response) assert result == introspection_response @@ -333,7 +333,7 @@ def test_get_user_success(resource_server_backend): access_token = "valid_access_token" mock_jwt = Mock() - mock_claims = {"token_introspection": {"sub": "user123"}} + mock_claims = {"token_introspection": {"sub": "user123", "aud": "123"}} mock_user = Mock() resource_server_backend._introspect = Mock(return_value=mock_jwt) diff --git a/src/backend/core/tests/service_providers/__init__.py b/src/backend/core/tests/service_providers/__init__.py new file mode 100644 index 000000000..a1d1f693b --- /dev/null +++ b/src/backend/core/tests/service_providers/__init__.py @@ -0,0 +1,3 @@ +""" +Test for the service providers viewset. +""" diff --git a/src/backend/core/tests/service_providers/test_core_api_service_providers_list.py b/src/backend/core/tests/service_providers/test_core_api_service_providers_list.py new file mode 100644 index 000000000..9eb179aad --- /dev/null +++ b/src/backend/core/tests/service_providers/test_core_api_service_providers_list.py @@ -0,0 +1,91 @@ +""" +Tests for Service Provider API endpoint in People's core app: list +""" + +import pytest +from rest_framework.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED + +from core import factories + +pytestmark = pytest.mark.django_db + + +def test_api_service_providers_list_anonymous(client): + """Anonymous users should not be allowed to list service providers.""" + factories.ServiceProviderFactory.create_batch(2) + + response = client.get("/api/v1.0/service-providers/") + + assert response.status_code == HTTP_401_UNAUTHORIZED + assert response.json() == { + "detail": "Authentication credentials were not provided." + } + + +def test_api_service_providers_list_authenticated(client): + """ + Authenticated users should be able to list service providers + of their organization. + """ + user = factories.UserFactory(with_organization=True) + client.force_login(user) + + service_provider_1 = factories.ServiceProviderFactory( + name="A", organizations=[user.organization] + ) + service_provider_2 = factories.ServiceProviderFactory( + name="B", organizations=[user.organization] + ) + + # Generate some not fetched data + factories.ServiceProviderFactory.create_batch( + 2, organizations=[factories.OrganizationFactory(with_registration_id=True)] + ) # Other service providers + factories.TeamFactory( + users=[user], service_providers=[factories.ServiceProviderFactory()] + ) + + response = client.get( + "/api/v1.0/service-providers/", + ) + + assert response.status_code == HTTP_200_OK + assert response.json() == { + "count": 2, + "next": None, + "previous": None, + "results": [ + { + "audience_id": str(service_provider_1.audience_id), + "id": str(service_provider_1.pk), + "name": "A", + }, + { + "audience_id": str(service_provider_2.audience_id), + "id": str(service_provider_2.pk), + "name": "B", + }, + ], + } + + +def test_api_service_providers_order(client): + """Test that the service providers are sorted as requested.""" + user = factories.UserFactory(with_organization=True) + factories.ServiceProviderFactory(name="A", organizations=[user.organization]) + factories.ServiceProviderFactory(name="B", organizations=[user.organization]) + + client.force_login(user) + + # Test ordering by name descending + response = client.get("/api/v1.0/service-providers/?ordering=-name") + assert response.status_code == 200 + response_data = response.json()["results"] + assert response_data[0]["name"] == "B" + assert response_data[1]["name"] == "A" + + # Test ordering by creation date ascending + response = client.get("/api/v1.0/service-providers/?ordering=created_at") + response_data = response.json()["results"] + assert response_data[0]["name"] == "A" + assert response_data[1]["name"] == "B" diff --git a/src/backend/core/tests/service_providers/test_core_api_service_providers_retrieve.py b/src/backend/core/tests/service_providers/test_core_api_service_providers_retrieve.py new file mode 100644 index 000000000..5951a9d12 --- /dev/null +++ b/src/backend/core/tests/service_providers/test_core_api_service_providers_retrieve.py @@ -0,0 +1,84 @@ +""" +Tests for Service Provider API endpoint in People's core app: retrieve +""" + +import pytest +from rest_framework.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND + +from core import factories + +pytestmark = pytest.mark.django_db + + +def test_api_service_providers_retrieve_anonymous(client): + """Anonymous users should not be allowed to retrieve service providers.""" + service_provider = factories.ServiceProviderFactory() + + response = client.get(f"/api/v1.0/service-providers/{service_provider.pk}/") + + assert response.status_code == HTTP_401_UNAUTHORIZED + assert response.json() == { + "detail": "Authentication credentials were not provided." + } + + +def test_api_service_providers_retrieve_authenticated_allowed(client): + """ + Authenticated users should be able to retrieve service providers + of their organization. + """ + user = factories.UserFactory(with_organization=True) + client.force_login(user) + + service_provider = factories.ServiceProviderFactory( + organizations=[user.organization] + ) + + response = client.get(f"/api/v1.0/service-providers/{service_provider.pk}/") + + assert response.status_code == HTTP_200_OK + assert response.json() == { + "audience_id": str(service_provider.audience_id), + "id": str(service_provider.pk), + "name": service_provider.name, + } + + +def test_api_service_providers_retrieve_authenticated_other_organization(client): + """ + Authenticated users should not be able to retrieve service providers + of other organization. + """ + user = factories.UserFactory(with_organization=True) + client.force_login(user) + + service_provider = factories.ServiceProviderFactory( + organizations=[factories.OrganizationFactory(with_registration_id=True)] + ) + + response = client.get(f"/api/v1.0/service-providers/{service_provider.pk}/") + + assert response.status_code == HTTP_404_NOT_FOUND + assert response.json() == {"detail": "No ServiceProvider matches the given query."} + + +def test_api_service_providers_retrieve_authenticated_on_teams(client): + """ + Authenticated users should not be able to retrieve service providers + of because of their teams (might change later if needed). + """ + user = factories.UserFactory(with_organization=True) + client.force_login(user) + + other_organization = factories.OrganizationFactory(with_registration_id=True) + service_provider = factories.ServiceProviderFactory() + factories.TeamFactory( + users=[user], + organization=other_organization, + service_providers=[service_provider], + ) + + response = client.get(f"/api/v1.0/service-providers/{service_provider.pk}/") + + assert response.status_code == HTTP_404_NOT_FOUND + assert response.json() == {"detail": "No ServiceProvider matches the given query."} diff --git a/src/backend/core/tests/team_accesses/__init__.py b/src/backend/core/tests/team_accesses/__init__.py new file mode 100644 index 000000000..bff1c1814 --- /dev/null +++ b/src/backend/core/tests/team_accesses/__init__.py @@ -0,0 +1 @@ +"""Team accesses tests package.""" diff --git a/src/backend/core/tests/teams/__init__.py b/src/backend/core/tests/teams/__init__.py new file mode 100644 index 000000000..5630f312f --- /dev/null +++ b/src/backend/core/tests/teams/__init__.py @@ -0,0 +1 @@ +"""Teams tests package.""" diff --git a/src/backend/core/tests/teams/test_core_api_teams_create.py b/src/backend/core/tests/teams/test_core_api_teams_create.py index 54c40e03d..6653aa290 100644 --- a/src/backend/core/tests/teams/test_core_api_teams_create.py +++ b/src/backend/core/tests/teams/test_core_api_teams_create.py @@ -3,6 +3,7 @@ """ import pytest +import responses from rest_framework.status import ( HTTP_201_CREATED, HTTP_401_UNAUTHORIZED, @@ -10,7 +11,8 @@ from rest_framework.test import APIClient from core.factories import OrganizationFactory, UserFactory -from core.models import Team +from core.models import ServiceProvider, Team +from core.tests.conftest import build_authorization_bearer pytestmark = pytest.mark.django_db @@ -81,3 +83,105 @@ def test_api_teams_create_cannot_override_organization(): assert team.name == "my team" assert team.organization == organization assert team.accesses.filter(role="owner", user=user).exists() + + +@responses.activate +def test_api_teams_create_authenticated_resource_server( + client, resource_server_settings +): + """ + Authenticated users should be able to create teams and should automatically be declared + as the owner of the newly created team. + """ + organization = OrganizationFactory(with_registration_id=True) + user = UserFactory(organization=organization) + + # Mock the introspection endpoint + encrypt_jwt = resource_server_settings["encrypt_jwt"] + + responses.add( + responses.POST, + "https://oidc.example.com/introspect", + body=encrypt_jwt( + { + "iss": "https://oidc.example.com", + "aud": "some_client_id", # settings.OIDC_RS_CLIENT_ID + "token_introspection": { + "sub": user.sub, + "iss": "https://oidc.example.com", + "aud": "some_service_provider", + "scope": "openid groups", + "active": True, + }, + } + ), + ) + + # Authenticate using the resource server, ie via the Authorization header + response = client.post( + "/api/v1.0/teams/", + { + "name": "my team", + }, + format="json", + HTTP_AUTHORIZATION=f"Bearer {build_authorization_bearer('some_token')}", + ) + + assert response.status_code == HTTP_201_CREATED + + team = Team.objects.get() + team_access = team.accesses.get() + service_provider = ServiceProvider.objects.get() + + assert response.json() == { + "abilities": { + "delete": True, + "get": True, + "manage_accesses": True, + "patch": True, + "put": True, + }, + "accesses": [str(team_access.pk)], + "created_at": team.created_at.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "id": str(team.pk), + "name": "my team", + "service_providers": [str(service_provider.pk)], + "updated_at": team.updated_at.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + } + + # check team data + assert team.name == "my team" + assert team.organization == organization + + # check team access data + assert team_access.role == "owner" + assert team_access.user == user + + # check service provider data + assert service_provider.audience_id == "some_service_provider" + + +def test_api_teams_create_authenticated_resource_server_existing_service_provider( + client, authenticated_user_with_service_provider +): + """ + Authenticated user should be able to create a team via + resource server with an existing service provider (then not created again). + """ + _user, service_provider = authenticated_user_with_service_provider + + response = client.post( + "/api/v1.0/teams/", + { + "name": "my team", + }, + format="json", + HTTP_AUTHORIZATION="Bearer b64untestedbearertoken", + ) + + assert response.status_code == HTTP_201_CREATED + + assert ServiceProvider.objects.count() == 1 # no object created + team = Team.objects.get() # team created + assert team.service_providers.get().audience_id == service_provider.audience_id + assert team.name == "my team" diff --git a/src/backend/core/tests/teams/test_core_api_teams_delete.py b/src/backend/core/tests/teams/test_core_api_teams_delete.py index 4e3de4fa3..94a4070ba 100644 --- a/src/backend/core/tests/teams/test_core_api_teams_delete.py +++ b/src/backend/core/tests/teams/test_core_api_teams_delete.py @@ -3,6 +3,7 @@ """ import pytest +import responses from rest_framework.status import ( HTTP_204_NO_CONTENT, HTTP_401_UNAUTHORIZED, @@ -12,6 +13,7 @@ from rest_framework.test import APIClient from core import factories, models +from core.tests.conftest import build_authorization_bearer pytestmark = pytest.mark.django_db @@ -113,3 +115,79 @@ def test_api_teams_delete_authenticated_owner(): assert response.status_code == HTTP_204_NO_CONTENT assert models.Team.objects.exists() is False + + +@responses.activate +def test_api_teams_delete_authenticated_owner_resource_server( + client, resource_server_settings +): + """ + Authenticated users should be able to delete a team for which they are directly + owner, even if the request is authenticated via a resource server. + """ + user = factories.UserFactory() + service_provider = factories.ServiceProviderFactory( + audience_id="some_service_provider" + ) + team = factories.TeamFactory( + users=[(user, "owner")], service_providers=[service_provider] + ) + + # Mock the introspection endpoint + encrypt_jwt = resource_server_settings["encrypt_jwt"] + + responses.add( + responses.POST, + "https://oidc.example.com/introspect", + body=encrypt_jwt( + { + "iss": "https://oidc.example.com", + "aud": "some_client_id", # settings.OIDC_RS_CLIENT_ID + "token_introspection": { + "sub": user.sub, + "iss": "https://oidc.example.com", + "aud": "some_service_provider", + "scope": "openid groups", + "active": True, + }, + } + ), + ) + + # Authenticate using the resource server, ie via the Authorization header + response = client.delete( + f"/api/v1.0/teams/{team.pk}/", + HTTP_AUTHORIZATION=f"Bearer {build_authorization_bearer('some_token')}", + ) + + assert response.status_code == HTTP_204_NO_CONTENT + assert models.Team.objects.exists() is False + + +def test_api_teams_delete_authenticated_other_resource_server( + client, authenticated_user_with_service_provider +): + """ + Authenticated users should not be able to delete a team for which they are directly + owner, if the request is authenticated via a different resource server. + """ + user, _service_provider = authenticated_user_with_service_provider + + other_service_provider = factories.ServiceProviderFactory( + audience_id="some_service_provider" + ) + team = factories.TeamFactory( + users=[(user, "owner")], service_providers=[other_service_provider] + ) + + response = client.delete( + f"/api/v1.0/teams/{team.pk}/", + HTTP_AUTHORIZATION="Bearer b64untestedbearertoken", + ) + + assert response.status_code == HTTP_404_NOT_FOUND + + team.refresh_from_db() # should not fail + assert ( + team.service_providers.get().audience_id == other_service_provider.audience_id + ) diff --git a/src/backend/core/tests/teams/test_core_api_teams_list.py b/src/backend/core/tests/teams/test_core_api_teams_list.py index 420dd59e7..80e37421b 100644 --- a/src/backend/core/tests/teams/test_core_api_teams_list.py +++ b/src/backend/core/tests/teams/test_core_api_teams_list.py @@ -2,14 +2,13 @@ Tests for Teams API endpoint in People's core app: list """ -from unittest import mock - import pytest -from rest_framework.pagination import PageNumberPagination +import responses from rest_framework.status import HTTP_200_OK, HTTP_401_UNAUTHORIZED from rest_framework.test import APIClient from core import factories +from core.tests.conftest import build_authorization_bearer pytestmark = pytest.mark.django_db @@ -123,9 +122,107 @@ def test_api_teams_order_param(): assert response.status_code == 200 response_data = response.json() - response_team_ids = [team["id"] for team in response_data] assert ( response_team_ids == team_ids ), "created_at values are not sorted from oldest to newest" + + +@responses.activate +def test_api_teams_list_authenticated_resource_server( + client, django_assert_num_queries, resource_server_settings +): + """ + Authenticated users should be able to list teams + they are an owner/administrator/member of, and + the service provider should be filtered to the one + using the resource server. + """ + user = factories.UserFactory() + service_provider = factories.ServiceProviderFactory() + hidden_service_provider = factories.ServiceProviderFactory() + + # Mock the introspection endpoint + encrypt_jwt = resource_server_settings["encrypt_jwt"] + + responses.add( + responses.POST, + "https://oidc.example.com/introspect", + body=encrypt_jwt( + { + "iss": "https://oidc.example.com", + "aud": "some_client_id", # settings.OIDC_RS_CLIENT_ID + "token_introspection": { + "sub": user.sub, + "iss": "https://oidc.example.com", + "aud": str(service_provider.audience_id), + "scope": "openid groups", + "active": True, + }, + } + ), + ) + + team_access_1 = factories.TeamAccessFactory( + user=user, team__service_providers=[service_provider], role="member" + ) + team_1 = team_access_1.team + + team_access_2 = factories.TeamAccessFactory( + user=user, + team__service_providers=[service_provider, hidden_service_provider], + role="member", + ) + team_2 = team_access_2.team + + # Team filtered out because of the service provider + _not_included_team_access = factories.TeamAccessFactory( + user=user, team__service_providers=[hidden_service_provider] + ) + + # Authenticate using the resource server, ie via the Authorization header + with django_assert_num_queries(4): # User, Team, ServiceProvider, TeamAccess + response = client.get( + "/api/v1.0/teams/?ordering=created_at", + format="json", + HTTP_AUTHORIZATION=f"Bearer {build_authorization_bearer('some_token')}", + ) + + assert response.status_code == HTTP_200_OK + results = response.json() + + assert results == [ + { + "abilities": { + "delete": False, + "get": True, + "manage_accesses": False, + "patch": False, + "put": False, + }, + "accesses": [str(team_access_1.pk)], + "created_at": team_1.created_at.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "id": str(team_1.pk), + "name": team_1.name, + "service_providers": [str(service_provider.pk)], + "updated_at": team_1.updated_at.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + }, + { + "abilities": { + "delete": False, + "get": True, + "manage_accesses": False, + "patch": False, + "put": False, + }, + "accesses": [str(team_access_2.pk)], + "created_at": team_2.created_at.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "id": str(team_2.pk), + "name": team_2.name, + "service_providers": [ + str(service_provider.pk) + ], # Only the relevant service provider + "updated_at": team_2.updated_at.strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + }, + ] diff --git a/src/backend/core/tests/teams/test_core_api_teams_retrieve.py b/src/backend/core/tests/teams/test_core_api_teams_retrieve.py index 2edd1fa3d..fa774cd4a 100644 --- a/src/backend/core/tests/teams/test_core_api_teams_retrieve.py +++ b/src/backend/core/tests/teams/test_core_api_teams_retrieve.py @@ -3,10 +3,13 @@ """ import pytest +import responses from rest_framework import status +from rest_framework.status import HTTP_404_NOT_FOUND from rest_framework.test import APIClient from core import factories +from core.tests.conftest import build_authorization_bearer pytestmark = pytest.mark.django_db @@ -72,4 +75,90 @@ def test_api_teams_retrieve_authenticated_related(): "abilities": team.get_abilities(user), "created_at": team.created_at.isoformat().replace("+00:00", "Z"), "updated_at": team.updated_at.isoformat().replace("+00:00", "Z"), + "service_providers": [], } + + +@responses.activate +def test_api_teams_retrieve_authenticated_owner_resource_server( + client, resource_server_settings +): + """ + Authenticated users should be allowed to retrieve a team to which they + are related whatever the role even if the request is authenticated via + a resource server. + """ + service_provider = factories.ServiceProviderFactory( + audience_id="some_service_provider" + ) + user = factories.UserFactory() + team = factories.TeamFactory(service_providers=[service_provider]) + access1 = factories.TeamAccessFactory(team=team, user=user) + access2 = factories.TeamAccessFactory(team=team) + + # Mock the introspection endpoint + encrypt_jwt = resource_server_settings["encrypt_jwt"] + + responses.add( + responses.POST, + "https://oidc.example.com/introspect", + body=encrypt_jwt( + { + "iss": "https://oidc.example.com", + "aud": "some_client_id", # settings.OIDC_RS_CLIENT_ID + "token_introspection": { + "sub": user.sub, + "iss": "https://oidc.example.com", + "aud": "some_service_provider", + "scope": "openid groups", + "active": True, + }, + } + ), + ) + + # Authenticate using the resource server, ie via the Authorization header + response = client.get( + f"/api/v1.0/teams/{team.id!s}/", + HTTP_AUTHORIZATION=f"Bearer {build_authorization_bearer('some_token')}", + ) + + assert response.status_code == status.HTTP_200_OK + assert sorted(response.json().pop("accesses")) == sorted( + [ + str(access1.id), + str(access2.id), + ] + ) + assert response.json() == { + "id": str(team.id), + "name": team.name, + "abilities": team.get_abilities(user), + "created_at": team.created_at.isoformat().replace("+00:00", "Z"), + "updated_at": team.updated_at.isoformat().replace("+00:00", "Z"), + "service_providers": [str(service_provider.pk)], + } + + +def test_api_teams_retrieve_authenticated_other_resource_server( + client, authenticated_user_with_service_provider +): + """ + Authenticated users should not be able to delete a team for which they are directly + owner, if the request is authenticated via a different resource server. + """ + user, _service_provider = authenticated_user_with_service_provider + + other_service_provider = factories.ServiceProviderFactory( + audience_id="some_service_provider" + ) + team = factories.TeamFactory( + users=[user], service_providers=[other_service_provider] + ) + + response = client.get( + f"/api/v1.0/teams/{team.id!s}/", + HTTP_AUTHORIZATION="Bearer b64untestedbearertoken", + ) + + assert response.status_code == HTTP_404_NOT_FOUND diff --git a/src/backend/core/tests/teams/test_core_api_teams_update.py b/src/backend/core/tests/teams/test_core_api_teams_update.py index ace2b5339..b892aa00f 100644 --- a/src/backend/core/tests/teams/test_core_api_teams_update.py +++ b/src/backend/core/tests/teams/test_core_api_teams_update.py @@ -5,6 +5,7 @@ import random import pytest +import responses from rest_framework.status import ( HTTP_200_OK, HTTP_401_UNAUTHORIZED, @@ -15,6 +16,7 @@ from core import factories from core.api import serializers +from core.tests.conftest import build_authorization_bearer pytestmark = pytest.mark.django_db @@ -188,3 +190,95 @@ def test_api_teams_update_administrator_or_owner_of_another(): team.refresh_from_db() team_values = serializers.TeamSerializer(instance=team).data assert team_values == old_team_values + + +@responses.activate +def test_api_teams_update_authenticated_owner_resource_server( + client, resource_server_settings +): + """ + Authenticated users should be allowed to update a team to which they + are related whatever the role even if the request is authenticated via + a resource server. + """ + service_provider = factories.ServiceProviderFactory( + audience_id="some_service_provider" + ) + user = factories.UserFactory() + team = factories.TeamFactory( + name="Old name", + users=[(user, "owner")], + service_providers=[service_provider], + ) + + # Mock the introspection endpoint + encrypt_jwt = resource_server_settings["encrypt_jwt"] + + responses.add( + responses.POST, + "https://oidc.example.com/introspect", + body=encrypt_jwt( + { + "iss": "https://oidc.example.com", + "aud": "some_client_id", # settings.OIDC_RS_CLIENT_ID + "token_introspection": { + "sub": user.sub, + "iss": "https://oidc.example.com", + "aud": "some_service_provider", + "scope": "openid groups", + "active": True, + }, + } + ), + ) + + # Authenticate using the resource server, ie via the Authorization header + response = client.put( + f"/api/v1.0/teams/{team.id!s}/", + data=serializers.TeamSerializer(instance=team).data + | { + "name": "New name", + }, + content_type="application/json", + HTTP_AUTHORIZATION=f"Bearer {build_authorization_bearer('some_token')}", + ) + + assert response.status_code == HTTP_200_OK + + team.refresh_from_db() + assert team.name == "New name" + + +def test_api_teams_update_authenticated_other_resource_server( + client, authenticated_user_with_service_provider +): + """ + Authenticated users should not be able to update a team for which they are directly + owner, if the request is authenticated via a different resource server. + """ + user, _service_provider = authenticated_user_with_service_provider + + other_service_provider = factories.ServiceProviderFactory( + audience_id="some_service_provider" + ) + team = factories.TeamFactory( + name="Old name", + users=[(user, "owner")], + service_providers=[other_service_provider], + ) + + response = client.put( + f"/api/v1.0/teams/{team.id!s}/", + data=serializers.TeamSerializer(instance=team).data + | { + "name": "New name", + }, + content_type="application/json", + HTTP_AUTHORIZATION="Bearer b64untestedbearertoken", + ) + + assert response.status_code == HTTP_404_NOT_FOUND + assert response.json() == {"detail": "No Team matches the given query."} + + team.refresh_from_db() + assert team.name == "Old name" diff --git a/src/backend/people/api_urls.py b/src/backend/people/api_urls.py index 059b66264..451e3c0eb 100644 --- a/src/backend/people/api_urls.py +++ b/src/backend/people/api_urls.py @@ -14,6 +14,9 @@ router.register("contacts", viewsets.ContactViewSet, basename="contacts") router.register("teams", viewsets.TeamViewSet, basename="teams") router.register("users", viewsets.UserViewSet, basename="users") +router.register( + "service-providers", viewsets.ServiceProviderViewSet, basename="service-providers" +) # - Routes nested under a team team_related_router = DefaultRouter() diff --git a/src/backend/people/settings.py b/src/backend/people/settings.py index 8c99afa6b..8dfd58477 100755 --- a/src/backend/people/settings.py +++ b/src/backend/people/settings.py @@ -219,7 +219,10 @@ class Base(Configuration): REST_FRAMEWORK = { "DEFAULT_AUTHENTICATION_CLASSES": ( - "core.resource_server.authentication.ResourceServerAuthentication", + # "core.resource_server.authentication.ResourceServerAuthentication", + # The resource server authentication is added on a per-view basis + # to enforce the filtering adapted from the introspected token. + # See ResourceServerMixin usage for more details. "mozilla_django_oidc.contrib.drf.OIDCAuthentication", "rest_framework.authentication.SessionAuthentication", ),