diff --git a/tests/test_enterprise/api/test_filters.py b/tests/test_enterprise/api/test_filters.py index 17f63127c0..32c975b9db 100644 --- a/tests/test_enterprise/api/test_filters.py +++ b/tests/test_enterprise/api/test_filters.py @@ -91,42 +91,40 @@ class TestEnterpriseCourseEnrollmentFilterBackend(APITest): def setUp(self): super().setUp() - self.other_user= factories.UserFactory() - self.enterprise_customer_1 = factories.EnterpriseCustomerFactory(uuid=FAKE_UUIDS[0]) - self.enterprise_customer_2 = factories.EnterpriseCustomerFactory(uuid=FAKE_UUIDS[1]) - self.enterprise_customer_user_1 = factories.EnterpriseCustomerUserFactory( - enterprise_customer=self.enterprise_customer_1, - user_id=self.user.id - ) - self.enterprise_customer_user_2 = factories.EnterpriseCustomerUserFactory( - enterprise_customer=self.enterprise_customer_1, - user_id=self.other_user.id - ) - self.course_enrollment_1 = factories.EnterpriseCourseEnrollmentFactory( - enterprise_customer_user=self.enterprise_customer_user_1 + self._setup_enterprise_customer_and_enrollments( + uuid=FAKE_UUIDS[0], + users=[self.user, factories.UserFactory()] ) - self.course_enrollment_2 = factories.EnterpriseCourseEnrollmentFactory( - enterprise_customer_user=self.enterprise_customer_user_2 - ) - - self.enterprise_customer_user_3 = factories.EnterpriseCustomerUserFactory( - enterprise_customer=self.enterprise_customer_2, - user_id=factories.UserFactory().id - ) - self.enterprise_customer_user_4 = factories.EnterpriseCustomerUserFactory( - enterprise_customer=self.enterprise_customer_2, - user_id=factories.UserFactory().id - ) - self.course_enrollment_3 = factories.EnterpriseCourseEnrollmentFactory( - enterprise_customer_user=self.enterprise_customer_user_3 - ) - self.course_enrollment_4 = factories.EnterpriseCourseEnrollmentFactory( - enterprise_customer_user=self.enterprise_customer_user_4 + self._setup_enterprise_customer_and_enrollments( + uuid=FAKE_UUIDS[1], + users=[factories.UserFactory(), factories.UserFactory()] ) self.url = settings.TEST_SERVER + reverse('enterprise-course-enrollment-list') + def _setup_enterprise_customer_and_enrollments(self, uuid, users): + enterprise_customer = factories.EnterpriseCustomerFactory(uuid=uuid) + + for user in users: + enterprise_customer_user = factories.EnterpriseCustomerUserFactory( + enterprise_customer=enterprise_customer, + user_id=user.id + ) + course_enrollment = factories.EnterpriseCourseEnrollmentFactory( + enterprise_customer_user=enterprise_customer_user + ) + + def _setup_user_privileges_by_role(self, user, role): + if role == "staff": + user.is_staff = True + user.save() + elif role == "enrollment_api_admin": + factories.EnterpriseFeatureUserRoleAssignmentFactory( + user=user, + role=EnterpriseFeatureRole.objects.get(name=ENTERPRISE_ENROLLMENT_API_ADMIN_ROLE) + ) + @ddt.data( ("regular", 1), ("enrollment_api_admin", 2), @@ -137,14 +135,7 @@ def test_filter_for_list(self, user_role, expected_course_enrollment_count): """ Filter objects based off whether the user is a staff, enterprise enrollment api admin, or neither. """ - if user_role == "staff": - self.user.is_staff = True - self.user.save() - elif user_role == "enrollment_api_admin": - factories.EnterpriseFeatureUserRoleAssignmentFactory( - user=self.user, - role=EnterpriseFeatureRole.objects.get(name=ENTERPRISE_ENROLLMENT_API_ADMIN_ROLE) - ) + self._setup_user_privileges_by_role(self.user, user_role) response = self.client.get(self.url) assert response.status_code == status.HTTP_200_OK