Skip to content

Commit

Permalink
more edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
ramonavic committed Jan 24, 2025
1 parent 215f812 commit 53be751
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
7 changes: 6 additions & 1 deletion app/signals/apps/api/serializers/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,12 @@ def update(self, instance, validated_data): # noqa
return signal


class PrivateSignalSerializerList(SignalReporterEmailValidationMixin, SignalParentValidationMixin, AddressValidationMixin, HALSerializer):
class PrivateSignalSerializerList(
SignalReporterEmailValidationMixin,
SignalParentValidationMixin,
AddressValidationMixin,
HALSerializer
):
"""
This serializer is used for the list endpoint and when creating a new instance
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,12 +380,15 @@ def test_create_initial_signal_invalid_source(self, validate_address):

@patch('signals.apps.api.validation.address.base.BaseAddressValidation.validate_address',
side_effect=AddressValidationUnavailableException) # Skip address validation
def test_create_initial_signal_missing_source_should_give_default_source(self, validate_address):
def test_create_initial_signal_missing_source_should_give_internal_default_source_if_reporter_of_certain_domain(self, validate_address):
signal_count = Signal.objects.count()

SourceFactory.create_batch(5)

initial_data = copy.deepcopy(self.initial_data_base)

initial_data['reporter']['email'] = 'test-email-1' \
f'{settings.API_TRANSFORM_SOURCE_BASED_ON_REPORTER_DOMAIN_EXTENSIONS}'
del initial_data['source']

response = self.client.post(self.list_endpoint, initial_data, format='json')
Expand All @@ -396,6 +399,26 @@ def test_create_initial_signal_missing_source_should_give_default_source(self, v
data = response.json()
self.assertEqual(data['source'], settings.API_TRANSFORM_SOURCE_BASED_ON_REPORTER_SOURCE)

@patch('signals.apps.api.validation.address.base.BaseAddressValidation.validate_address',
side_effect=AddressValidationUnavailableException) # Skip address validation
def test_create_initial_signal_missing_source_should_give_internal_default_source_if_reporter_unknown_domain(self,
validate_address):
signal_count = Signal.objects.count()

SourceFactory.create_batch(5)

initial_data = copy.deepcopy(self.initial_data_base)
del initial_data['source']

response = self.client.post(self.list_endpoint, initial_data, format='json')

self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(Signal.objects.count(), signal_count + 1)

data = response.json()
self.assertEqual(data['source'], Signal._meta.get_field('source').get_default())


@patch('signals.apps.api.validation.address.base.BaseAddressValidation.validate_address',
side_effect=AddressValidationUnavailableException) # Skip address validation
def test_create_initial_signal_valid_source(self, validate_address):
Expand Down
13 changes: 7 additions & 6 deletions app/signals/apps/api/validation/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@ def validate(self, attrs):
and attrs['reporter']['email']):
reporter_email = attrs['reporter']['email']

if self.__class__.__name__ == 'PrivateSignalSerializerList' and 'source' not in attrs:
attrs['source'] = settings.API_TRANSFORM_SOURCE_BASED_ON_REPORTER_SOURCE
if (reporter_email not in settings.API_TRANSFORM_SOURCE_BASED_ON_REPORTER_EXCEPTIONS
and reporter_email.endswith(settings.API_TRANSFORM_SOURCE_BASED_ON_REPORTER_DOMAIN_EXTENSIONS)):

if (self.__class__.__name__ == 'PublicSignalCreateSerializer'
and reporter_email not in settings.API_TRANSFORM_SOURCE_BASED_ON_REPORTER_EXCEPTIONS
and reporter_email.endswith(settings.API_TRANSFORM_SOURCE_BASED_ON_REPORTER_DOMAIN_EXTENSIONS)):
attrs['source'] = settings.API_TRANSFORM_SOURCE_BASED_ON_REPORTER_SOURCE
if self.__class__.__name__ == 'PrivateSignalSerializerList' and 'source' not in attrs:
attrs['source'] = settings.API_TRANSFORM_SOURCE_BASED_ON_REPORTER_SOURCE

if self.__class__.__name__ == 'PublicSignalCreateSerializer':
attrs['source'] = settings.API_TRANSFORM_SOURCE_BASED_ON_REPORTER_SOURCE

return super().validate(attrs)

Expand Down

0 comments on commit 53be751

Please sign in to comment.