diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 51e2890bd5..ae6419f8eb 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -16,6 +16,18 @@ Change Log Unreleased ---------- +[4.10.0] +-------- + +feat: enrollment API enhancements + +- Allows Enrollment API Admin to see all enrollments. +- Makes the endpoint return more fields, such as: enrollment_track, + enrollment_date, user_email, course_start and course_end. +- Changes EnterpriseCourseEnrollment's default ordering from 'created' + to 'id', which equivalent, but faster in some cases (due to the + existing indes on 'id'). + [4.9.5] -------- diff --git a/enterprise/__init__.py b/enterprise/__init__.py index 897c6f5e55..b3a0ce8824 100644 --- a/enterprise/__init__.py +++ b/enterprise/__init__.py @@ -2,4 +2,4 @@ Your project description goes here. """ -__version__ = "4.9.5" +__version__ = "4.10.0" diff --git a/enterprise/api/filters.py b/enterprise/api/filters.py index 0ea4115085..92c1bb4551 100644 --- a/enterprise/api/filters.py +++ b/enterprise/api/filters.py @@ -7,7 +7,7 @@ from django.contrib import auth -from enterprise.models import EnterpriseCustomerUser, SystemWideEnterpriseUserRoleAssignment +from enterprise.models import EnterpriseCustomer, EnterpriseCustomerUser, SystemWideEnterpriseUserRoleAssignment User = auth.get_user_model() @@ -33,6 +33,36 @@ def filter_queryset(self, request, queryset, view): return queryset +class EnterpriseCourseEnrollmentFilterBackend(filters.BaseFilterBackend): + """ + Filter backend to return enrollments under the user's enterprise(s) only. + + * Staff users will bypass this filter. + * Non-staff users will receive enrollments under their linked enterprises, + only if they have the `enterprise.can_enroll_learners` permission. + * Non-staff users without the `enterprise.can_enroll_learners` permission + will receive only their own enrollments. + """ + + def filter_queryset(self, request, queryset, view): + """ + Filter out enrollments if learner is not linked + """ + + if request.user.is_staff: + return queryset + + if request.user.has_perm('enterprise.can_enroll_learners'): + enterprise_customers = EnterpriseCustomer.objects.filter(enterprise_customer_users__user_id=request.user.id) + return queryset.filter(enterprise_customer_user__enterprise_customer__in=enterprise_customers) + + filter_kwargs = { + view.USER_ID_FILTER: request.user.id, + } + + return queryset.filter(**filter_kwargs) + + class EnterpriseCustomerUserFilterBackend(filters.BaseFilterBackend): """ Allow filtering on the enterprise customer user api endpoint. diff --git a/enterprise/api/v1/serializers.py b/enterprise/api/v1/serializers.py index 3010e52805..d991cc445b 100644 --- a/enterprise/api/v1/serializers.py +++ b/enterprise/api/v1/serializers.py @@ -356,6 +356,32 @@ class Meta: ) +class EnterpriseCourseEnrollmentWithAdditionalFieldsReadOnlySerializer(EnterpriseCourseEnrollmentReadOnlySerializer): + """ + Serializer for EnterpriseCourseEnrollment model with additional fields. + """ + + class Meta: + model = models.EnterpriseCourseEnrollment + fields = ( + 'enterprise_customer_user', + 'course_id', + 'created', + 'unenrolled_at', + 'enrollment_date', + 'enrollment_track', + 'user_email', + 'course_start', + 'course_end', + ) + + enrollment_track = serializers.CharField() + enrollment_date = serializers.DateTimeField() + user_email = serializers.EmailField() + course_start = serializers.DateTimeField() + course_end = serializers.DateTimeField() + + class EnterpriseCourseEnrollmentWriteSerializer(serializers.ModelSerializer): """ Serializer for writing to the EnterpriseCourseEnrollment model. diff --git a/enterprise/api/v1/views/enterprise_course_enrollment.py b/enterprise/api/v1/views/enterprise_course_enrollment.py index c7aef5ffc0..59ebf75c01 100644 --- a/enterprise/api/v1/views/enterprise_course_enrollment.py +++ b/enterprise/api/v1/views/enterprise_course_enrollment.py @@ -1,17 +1,68 @@ """ Views for the ``enterprise-course-enrollment`` API endpoint. """ +from django_filters.rest_framework import DjangoFilterBackend +from edx_rest_framework_extensions.paginators import DefaultPagination +from rest_framework import filters + +from django.core.paginator import Paginator +from django.utils.functional import cached_property + from enterprise import models +from enterprise.api.filters import EnterpriseCourseEnrollmentFilterBackend from enterprise.api.v1 import serializers from enterprise.api.v1.views.base_views import EnterpriseReadWriteModelViewSet +try: + from common.djangoapps.util.query import read_replica_or_default +except ImportError: + def read_replica_or_default(): + return None + + +class PaginatorWithOptimizedCount(Paginator): + """ + Django < 4.2 ORM doesn't strip unused annotations from count queries. + + For example, if we execute this query: + + Book.objects.annotate(Count('chapters')).count() + + it will generate SQL like this: + + SELECT COUNT(*) FROM (SELECT COUNT(...), ... FROM ...) subquery + + This isn't optimal on its own, but it's not a big deal. However, this + becomes problematic when annotations use subqueries, because it's terribly + inefficient to execute the subquery for every row in the outer query. + + This class overrides the count() method of Django's Paginator to strip + unused annotations from the query by requesting only the primary key + instead of all fields. + + This is a temporary workaround until Django is updated to 4.2, which will + include a fix for this issue. + + See https://code.djangoproject.com/ticket/32169 for more details. + + TODO: remove this class once Django is updated to 4.2 or higher. + """ + @cached_property + def count(self): + return self.object_list.values("pk").count() + + +class EnterpriseCourseEnrollmentPagination(DefaultPagination): + django_paginator_class = PaginatorWithOptimizedCount + class EnterpriseCourseEnrollmentViewSet(EnterpriseReadWriteModelViewSet): """ API views for the ``enterprise-course-enrollment`` API endpoint. """ - queryset = models.EnterpriseCourseEnrollment.objects.all() + queryset = models.EnterpriseCourseEnrollment.with_additional_fields.all() + filter_backends = (filters.OrderingFilter, DjangoFilterBackend, EnterpriseCourseEnrollmentFilterBackend) USER_ID_FILTER = 'enterprise_customer_user__user_id' FIELDS = ( @@ -20,10 +71,18 @@ class EnterpriseCourseEnrollmentViewSet(EnterpriseReadWriteModelViewSet): filterset_fields = FIELDS ordering_fields = FIELDS + pagination_class = EnterpriseCourseEnrollmentPagination + + def get_queryset(self): + queryset = super().get_queryset() + if self.request.method == 'GET': + queryset = queryset.using(read_replica_or_default()) + return queryset + def get_serializer_class(self): """ Use a special serializer for any requests that aren't read-only. """ if self.request.method in ('GET',): - return serializers.EnterpriseCourseEnrollmentReadOnlySerializer + return serializers.EnterpriseCourseEnrollmentWithAdditionalFieldsReadOnlySerializer return serializers.EnterpriseCourseEnrollmentWriteSerializer diff --git a/enterprise/migrations/0198_alter_enterprisecourseenrollment_options.py b/enterprise/migrations/0198_alter_enterprisecourseenrollment_options.py new file mode 100644 index 0000000000..1c3982cd8f --- /dev/null +++ b/enterprise/migrations/0198_alter_enterprisecourseenrollment_options.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2 on 2023-12-29 17:03 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('enterprise', '0197_auto_20231130_2239'), + ] + + operations = [ + migrations.AlterModelOptions( + name='enterprisecourseenrollment', + options={'ordering': ['id']}, + ), + ] diff --git a/enterprise/models.py b/enterprise/models.py index e3c47734cf..6c2bfec1ca 100644 --- a/enterprise/models.py +++ b/enterprise/models.py @@ -93,6 +93,11 @@ except ImportError: CourseEntitlement = None +try: + from openedx.core.djangoapps.content.course_overviews.models import CourseOverview +except ImportError: + CourseOverview = None + LOGGER = getEnterpriseLogger(__name__) User = auth.get_user_model() mark_safe_lazy = lazy(mark_safe, str) @@ -1857,11 +1862,61 @@ def get_queryset(self): """ Override to return only those enrollment records for which learner is linked to an enterprise. """ + return super().get_queryset().select_related('enterprise_customer_user').filter( enterprise_customer_user__linked=True ) +class EnterpriseCourseEnrollmentWithAdditionalFieldsManager(models.Manager): + """ + Model manager for `EnterpriseCourseEnrollment`. + """ + + def get_queryset(self): + """ + Override to return only those enrollment records for which learner is linked to an enterprise. + """ + + return super().get_queryset().select_related('enterprise_customer_user').filter( + enterprise_customer_user__linked=True + ).annotate(**self._get_additional_data_annotations()) + + def _get_additional_data_annotations(self): + """ + Return annotations with additional data for the queryset. + Additional fields are None in the test environment, where platform models are not available. + """ + + if not CourseEnrollment or not CourseOverview: + return { + 'enrollment_track': models.Value(None, output_field=models.CharField()), + 'enrollment_date': models.Value(None, output_field=models.DateTimeField()), + 'user_email': models.Value(None, output_field=models.EmailField()), + 'course_start': models.Value(None, output_field=models.DateTimeField()), + 'course_end': models.Value(None, output_field=models.DateTimeField()), + } + + enrollment_subquery = CourseEnrollment.objects.filter( + user=models.OuterRef('enterprise_customer_user__user_id'), + course_id=models.OuterRef('course_id'), + ) + user_subquery = auth.get_user_model().objects.filter( + id=models.OuterRef('enterprise_customer_user__user_id'), + ).values('email')[:1] + course_subquery = CourseOverview.objects.filter( + id=models.OuterRef('course_id'), + ) + + return { + 'enrollment_track': models.Subquery(enrollment_subquery.values('mode')[:1]), + 'enrollment_date': models.Subquery(enrollment_subquery.values('created')[:1]), + 'user_email': models.Subquery(user_subquery), + 'course_start': models.Subquery(course_subquery.values('start')[:1]), + 'course_end': models.Subquery(course_subquery.values('end')[:1]), + } + + class EnterpriseCourseEnrollment(TimeStampedModel): """ Store information about the enrollment of a user in a course. @@ -1881,11 +1936,15 @@ class EnterpriseCourseEnrollment(TimeStampedModel): """ objects = EnterpriseCourseEnrollmentManager() + with_additional_fields = EnterpriseCourseEnrollmentWithAdditionalFieldsManager() class Meta: unique_together = (('enterprise_customer_user', 'course_id',),) app_label = 'enterprise' - ordering = ['created'] + # Originally, we were ordering by 'created', but there was never an index on that column. To avoid creating + # an index on that column, we are ordering by 'id' instead, which is indexed by default and is equivalent to + # ordering by 'created' in this case. + ordering = ['id'] enterprise_customer_user = models.ForeignKey( EnterpriseCustomerUser, diff --git a/test_utils/factories.py b/test_utils/factories.py index 3d87059754..811cb5c27a 100644 --- a/test_utils/factories.py +++ b/test_utils/factories.py @@ -27,6 +27,8 @@ EnterpriseCustomerReportingConfiguration, EnterpriseCustomerSsoConfiguration, EnterpriseCustomerUser, + EnterpriseFeatureRole, + EnterpriseFeatureUserRoleAssignment, LearnerCreditEnterpriseCourseEnrollment, LicensedEnterpriseCourseEnrollment, PendingEnrollment, @@ -272,6 +274,39 @@ class Meta: invite_key = None +class EnterpriseFeatureRoleFactory(factory.django.DjangoModelFactory): + """ + EnterpriseFeatureRole factory. + Creates an instance of EnterpriseFeatureRole with minimal boilerplate. + """ + + class Meta: + """ + Meta for EnterpriseFeatureRoleFactory. + """ + + model = EnterpriseFeatureRole + + name = factory.LazyAttribute(lambda x: FAKER.word()) + + +class EnterpriseFeatureUserRoleAssignmentFactory(factory.django.DjangoModelFactory): + """ + EnterpriseFeatureUserRoleAssignment factory. + Creates an instance of EnterpriseFeatureUserRoleAssignment with minimal boilerplate. + """ + + class Meta: + """ + Meta for EnterpriseFeatureUserRoleAssignmentFactory. + """ + + model = EnterpriseFeatureUserRoleAssignment + + role = factory.SubFactory(EnterpriseFeatureRoleFactory) + user = factory.SubFactory(UserFactory) + + class AnonymousUserFactory(factory.Factory): """ Anonymous User factory. diff --git a/tests/test_enterprise/api/test_filters.py b/tests/test_enterprise/api/test_filters.py index b37b90790c..1076c58f4b 100644 --- a/tests/test_enterprise/api/test_filters.py +++ b/tests/test_enterprise/api/test_filters.py @@ -10,7 +10,8 @@ from django.conf import settings -from enterprise.constants import ENTERPRISE_ADMIN_ROLE +from enterprise.constants import ENTERPRISE_ADMIN_ROLE, ENTERPRISE_ENROLLMENT_API_ADMIN_ROLE +from enterprise.models import EnterpriseFeatureRole from test_utils import FAKE_UUIDS, TEST_EMAIL, TEST_USERNAME, APITest, factories ENTERPRISE_CUSTOMER_LIST_ENDPOINT = reverse('enterprise-customer-list') @@ -80,6 +81,73 @@ def test_filter_for_detail(self, is_staff, is_linked, expected_content_in_respon assert data[key] == value +@ddt.ddt +@mark.django_db +class TestEnterpriseCourseEnrollmentFilterBackend(APITest): + """ + Test suite for the ``EnterpriseCourseEnrollmentFilterBackend`` filter. + """ + + def setUp(self): + super().setUp() + + self._setup_enterprise_customer_and_enrollments( + uuid=FAKE_UUIDS[0], + users=[self.user, factories.UserFactory()] + ) + self._setup_enterprise_customer_and_enrollments( + uuid=FAKE_UUIDS[1], + users=[factories.UserFactory(), factories.UserFactory()] + ) + + self.url = settings.TEST_SERVER + reverse('enterprise-course-enrollment-list') + + def _setup_enterprise_customer_and_enrollments(self, uuid, users): + """ + Creates an enterprise customer with the uuid and enrolls passed users. + """ + enterprise_customer = factories.EnterpriseCustomerFactory(uuid=uuid) + + for user in users: + enterprise_customer_user = factories.EnterpriseCustomerUserFactory( + enterprise_customer=enterprise_customer, + user_id=user.id + ) + factories.EnterpriseCourseEnrollmentFactory( + enterprise_customer_user=enterprise_customer_user + ) + + def _setup_user_privileges_by_role(self, user, role): + """ + Sets up privileges for the passed user based on the role. + """ + if role == "staff": + user.is_staff = True + user.save() + elif role == "enrollment_api_admin": + factories.EnterpriseFeatureUserRoleAssignmentFactory( + user=user, + role=EnterpriseFeatureRole.objects.get(name=ENTERPRISE_ENROLLMENT_API_ADMIN_ROLE) + ) + + @ddt.data( + ("regular", 1), + ("enrollment_api_admin", 2), + ("staff", 4), + ) + @ddt.unpack + def test_filter_for_list(self, user_role, expected_course_enrollment_count): + """ + Filter objects based off whether the user is a staff, enterprise enrollment api admin, or neither. + """ + self._setup_user_privileges_by_role(self.user, user_role) + + response = self.client.get(self.url) + assert response.status_code == status.HTTP_200_OK + data = self.load_json(response.content) + assert len(data['results']) == expected_course_enrollment_count + + @ddt.ddt @mark.django_db class TestEnterpriseCustomerUserFilterBackend(APITest): diff --git a/tests/test_enterprise/api/test_views.py b/tests/test_enterprise/api/test_views.py index aa41038a9a..6383f3dce1 100644 --- a/tests/test_enterprise/api/test_views.py +++ b/tests/test_enterprise/api/test_views.py @@ -1276,6 +1276,11 @@ class TestEnterpriseCustomerViewSet(BaseTestEnterpriseAPIViews): 'course_id': 'course-v1:edX+DemoX+DemoCourse', 'created': '2021-10-20T19:01:31Z', 'unenrolled_at': None, + 'enrollment_date': None, + 'enrollment_track': None, + 'user_email': None, + 'course_start': None, + 'course_end': None, }], ), (