diff --git a/api/features/feature_segments/serializers.py b/api/features/feature_segments/serializers.py index e7aa85afe5ce..9052eac8264d 100644 --- a/api/features/feature_segments/serializers.py +++ b/api/features/feature_segments/serializers.py @@ -1,8 +1,7 @@ -import typing - from common.features.serializers import ( CreateSegmentOverrideFeatureSegmentSerializer, ) +from django.db import transaction from rest_framework import serializers from rest_framework.exceptions import PermissionDenied @@ -49,17 +48,41 @@ class CustomCreateSegmentOverrideFeatureSegmentSerializer( # field here, and use it manually in the save method. priority = serializers.IntegerField(min_value=0, required=False) - def save(self, **kwargs: typing.Any) -> FeatureSegment: - priority: int | None = self.initial_data.pop("priority", None) + @transaction.atomic() + def save(self, **kwargs) -> FeatureSegment: + """ + Note that this method is marked as atomic since a lot of additional validation is + performed in the call to super. If that fails, we want to roll the changes made by + `collision.to` back. + """ + + priority: int | None = self.validated_data.get("priority", None) - feature_segment: FeatureSegment = super().save(**kwargs) + if kwargs["environment"].use_v2_feature_versioning: # pragma: no cover + assert ( + kwargs["environment_feature_version"] is not None + ), "Must provide environment_feature_version for environment using v2 versioning" - if priority: - feature_segment.to(priority) - else: - feature_segment.bottom(priority) + if ( + priority is not None + and ( + collision := FeatureSegment.objects.filter( + environment=kwargs["environment"], + feature=kwargs["feature"], + environment_feature_version=kwargs.get( + "environment_feature_version" + ), + priority=priority, + ).first() + ) + is not None + ): + # Since there is no unique clause on the priority field, if a priority + # is set, it will just save the feature segment and not move others + # down. This ensures that the incoming priority space is 'free'. + collision.to(priority + 1) - return feature_segment + return super().save(**kwargs) class FeatureSegmentQuerySerializer(serializers.Serializer): diff --git a/api/tests/unit/features/feature_segments/test_unit_feature_segments_serializers.py b/api/tests/unit/features/feature_segments/test_unit_feature_segments_serializers.py index e710a70197af..848783dcfb2e 100644 --- a/api/tests/unit/features/feature_segments/test_unit_feature_segments_serializers.py +++ b/api/tests/unit/features/feature_segments/test_unit_feature_segments_serializers.py @@ -1,5 +1,8 @@ from unittest.mock import MagicMock +import pytest +from pytest_mock import MockerFixture + from environments.models import Environment from features.feature_segments.serializers import ( CustomCreateSegmentOverrideFeatureSegmentSerializer, @@ -61,7 +64,7 @@ def test_feature_segment_change_priorities_serializer_validate_fails_if_non_uniq assert serializer.errors -def test_feature_segment_serializer_save_sets_lowest_priority_if_none_given( +def test_feature_segment_serializer_save_new_feature_segment_sets_lowest_priority_if_none_given( feature: Feature, segment_featurestate: FeatureState, feature_segment: FeatureSegment, @@ -78,5 +81,174 @@ def test_feature_segment_serializer_save_sets_lowest_priority_if_none_given( new_feature_segment = serializer.save(feature=feature, environment=environment) # Then + feature_segment.refresh_from_db() assert feature_segment.priority == 0 assert new_feature_segment.priority == 1 + + +def test_feature_segment_serializer_save_new_feature_segment_sets_highest_priority_if_0_given( + feature: Feature, + segment_featurestate: FeatureState, + feature_segment: FeatureSegment, + another_segment: Segment, + environment: Environment, +) -> None: + # Given + serializer = CustomCreateSegmentOverrideFeatureSegmentSerializer( + data={"segment": another_segment.id, "priority": 0} + ) + serializer.is_valid(raise_exception=True) + + # When + new_feature_segment = serializer.save(feature=feature, environment=environment) + + # Then + feature_segment.refresh_from_db() + assert feature_segment.priority == 1 + assert new_feature_segment.priority == 0 + + +def test_feature_segment_serializer_save_new_feature_segment_moves_others_if_needed( + feature: Feature, + segment_featurestate: FeatureState, + feature_segment: FeatureSegment, + another_segment: Segment, + environment: Environment, +) -> None: + # Given + feature_segment.to(1) + + serializer = CustomCreateSegmentOverrideFeatureSegmentSerializer( + data={"segment": another_segment.id, "priority": 1} + ) + serializer.is_valid(raise_exception=True) + + # When + new_feature_segment = serializer.save(feature=feature, environment=environment) + + # Then + feature_segment.refresh_from_db() + assert feature_segment.priority == 2 + assert new_feature_segment.priority == 1 + + +def test_feature_segment_serializer_save_existing_feature_segment_does_nothing_if_no_priority_given( + feature: Feature, + segment_featurestate: FeatureState, + feature_segment: FeatureSegment, + another_segment: Segment, + environment: Environment, +) -> None: + # Given + feature_segment_to_update = FeatureSegment.objects.create( + environment=environment, + feature=feature, + segment=another_segment, + priority=1, + ) + + serializer = CustomCreateSegmentOverrideFeatureSegmentSerializer( + instance=feature_segment_to_update, data={"segment": another_segment.id} + ) + serializer.is_valid(raise_exception=True) + + # When + updated_feature_segment = serializer.save(feature=feature, environment=environment) + + # Then + feature_segment.refresh_from_db() + assert feature_segment.priority == 0 + assert updated_feature_segment.priority == 1 + + +def test_feature_segment_serializer_save_existing_feature_segment_sets_highest_priority_if_0_given( + feature: Feature, + segment_featurestate: FeatureState, + feature_segment: FeatureSegment, + another_segment: Segment, + environment: Environment, +) -> None: + # Given + feature_segment_to_update = FeatureSegment.objects.create( + environment=environment, + feature=feature, + segment=another_segment, + priority=1, + ) + + serializer = CustomCreateSegmentOverrideFeatureSegmentSerializer( + instance=feature_segment_to_update, + data={"segment": another_segment.id, "priority": 0}, + ) + serializer.is_valid(raise_exception=True) + + # When + updated_feature_segment = serializer.save(feature=feature, environment=environment) + + # Then + feature_segment.refresh_from_db() + assert feature_segment.priority == 1 + assert updated_feature_segment.priority == 0 + + +def test_feature_segment_serializer_save_existing_feature_segment_moves_others_if_needed( + feature: Feature, + segment_featurestate: FeatureState, + feature_segment: FeatureSegment, + another_segment: Segment, + environment: Environment, +) -> None: + # Given + feature_segment.to(1) + + feature_segment_to_update = FeatureSegment.objects.create( + environment=environment, + feature=feature, + segment=another_segment, + ) + + serializer = CustomCreateSegmentOverrideFeatureSegmentSerializer( + instance=feature_segment_to_update, + data={"segment": another_segment.id, "priority": 1}, + ) + serializer.is_valid(raise_exception=True) + + # When + updated_feature_segment = serializer.save(feature=feature, environment=environment) + + # Then + feature_segment.refresh_from_db() + assert feature_segment.priority == 2 + assert updated_feature_segment.priority == 1 + + +def test_feature_segment_serializer_save_new_feature_segment_does_nothing_on_error( + feature: Feature, + segment_featurestate: FeatureState, + feature_segment: FeatureSegment, + another_segment: Segment, + environment: Environment, + mocker: MockerFixture, +) -> None: + # Given + feature_segment.to(1) + serializer = CustomCreateSegmentOverrideFeatureSegmentSerializer( + data={"segment": another_segment.id, "priority": 1}, + ) + serializer.is_valid(raise_exception=True) + + class MockException(Exception): + pass + + mocker.patch( + "features.feature_segments.serializers.CustomCreateSegmentOverrideFeatureSegmentSerializer.create", + side_effect=MockException, + ) + + # When + with pytest.raises(MockException): + serializer.save(feature=feature, environment=environment) + + # Then + feature_segment.refresh_from_db() + assert feature_segment.priority == 1