Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support transient identities and traits #4325

Merged
merged 11 commits into from
Jul 17, 2024
83 changes: 60 additions & 23 deletions api/environments/identities/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import typing
from itertools import chain

from django.db import models
from django.db.models import Prefetch, Q
from django.utils import timezone
from flag_engine.identities.traits.types import TraitValue
from flag_engine.segments.evaluator import evaluate_identity_in_segment

from environments.identities.managers import IdentityManager
Expand Down Expand Up @@ -204,28 +206,32 @@ def generate_traits(self, trait_data_items, persist=False):
:return: list of TraitModels
"""
trait_models = []
trait_models_to_persist = []

# Remove traits having Null(None) values
trait_data_items = filter(
lambda trait: trait["trait_value"] is not None, trait_data_items
)
for trait_data_item in trait_data_items:
# exclude traits with null values
if (trait_value := trait_data_item["trait_value"]) is None:
continue

trait_key = trait_data_item["trait_key"]
trait_value = trait_data_item["trait_value"]
trait_models.append(
Trait(
trait_key=trait_key,
identity=self,
**Trait.generate_trait_value_data(trait_value),
)
trait = Trait(
trait_key=trait_key,
identity=self,
**Trait.generate_trait_value_data(trait_value),
)
trait_models.append(trait)
if not trait_data_item.get("transient"):
trait_models_to_persist.append(trait)

if persist:
Trait.objects.bulk_create(trait_models)
Trait.objects.bulk_create(trait_models_to_persist)

return trait_models

def update_traits(self, trait_data_items):
def update_traits(
self,
trait_data_items: list[dict[str, TraitValue]],
) -> list[Trait]:
"""
Given a list of traits, update any that already exist and create any new ones.
Return the full list of traits for the given identity after these changes.
Expand All @@ -235,38 +241,59 @@ def update_traits(self, trait_data_items):
"""
current_traits = {t.trait_key: t for t in self.identity_traits.all()}

keys_to_delete = []
keys_to_delete = set()
new_traits = []
updated_traits = []
transient_traits = []

for trait_data_item in trait_data_items:
trait_key = trait_data_item["trait_key"]
trait_value = trait_data_item["trait_value"]
transient = trait_data_item.get("transient")

if transient:
transient_traits.append(
Trait(
**Trait.generate_trait_value_data(trait_value),
trait_key=trait_key,
identity=self,
)
)
continue

if trait_value is None:
# build a list of trait keys to delete having been nulled by the
# input data
keys_to_delete.append(trait_key)
keys_to_delete.add(trait_key)
matthewelwell marked this conversation as resolved.
Show resolved Hide resolved
continue

trait_value_data = Trait.generate_trait_value_data(trait_value)

if trait_key in current_traits:
current_trait = current_traits[trait_key]
# Don't update the trait if the value hasn't changed
if current_trait.trait_value == trait_value:
continue

for attr, value in trait_value_data.items():
for attr, value in Trait.generate_trait_value_data(trait_value).items():
setattr(current_trait, attr, value)
updated_traits.append(current_trait)
else:
new_traits.append(
Trait(**trait_value_data, trait_key=trait_key, identity=self)
continue

new_traits.append(
Trait(
**Trait.generate_trait_value_data(trait_value),
trait_key=trait_key,
identity=self,
)
)

# delete the traits that had their keys set to None
# (except the transient ones)
if keys_to_delete:
current_traits = {
trait_key: trait
for trait_key, trait in current_traits.items()
if trait_key not in keys_to_delete
}
self.identity_traits.filter(trait_key__in=keys_to_delete).delete()

Trait.objects.bulk_update(updated_traits, fields=Trait.BULK_UPDATE_FIELDS)
Expand All @@ -278,5 +305,15 @@ def update_traits(self, trait_data_items):
Trait.objects.bulk_create(new_traits, ignore_conflicts=True)

# return the full list of traits for this identity by refreshing from the db
# TODO: handle this in the above logic to avoid a second hit to the DB
return self.identity_traits.all()
# override persisted traits by transient traits in case of key collisions
khvn26 marked this conversation as resolved.
Show resolved Hide resolved
return [
*{
trait.trait_key: trait
for trait in chain(
current_traits.values(),
updated_traits,
new_traits,
transient_traits,
)
khvn26 marked this conversation as resolved.
Show resolved Hide resolved
}.values()
]
1 change: 1 addition & 0 deletions api/environments/identities/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class _TraitSerializer(serializers.Serializer):

class SDKIdentitiesQuerySerializer(serializers.Serializer):
identifier = serializers.CharField(required=True)
transient = serializers.BooleanField(default=False)


class IdentityAllFeatureStatesFeatureSerializer(serializers.Serializer):
Expand Down
3 changes: 2 additions & 1 deletion api/environments/identities/traits/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ def get_trait_value(obj):

class TraitSerializerBasic(serializers.ModelSerializer):
trait_value = TraitValueField(allow_null=True)
transient = serializers.BooleanField(default=False, write_only=True)

class Meta:
model = Trait
fields = ("id", "trait_key", "trait_value")
fields = ("id", "trait_key", "trait_value", "transient")
read_only_fields = ("id",)


Expand Down
18 changes: 13 additions & 5 deletions api/environments/identities/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from core.request_origin import RequestOrigin
from django.conf import settings
from django.db.models import Q
from django.utils import timezone
from django.utils.decorators import method_decorator
from django.views.decorators.cache import cache_page
from drf_yasg.utils import swagger_auto_schema
Expand Down Expand Up @@ -173,11 +174,18 @@ def get(self, request):
{"detail": "Missing identifier"}
) # TODO: add 400 status - will this break the clients?

identity, _ = Identity.objects.get_or_create_for_sdk(
identifier=identifier,
environment=request.environment,
integrations=IDENTITY_INTEGRATIONS,
)
if request.query_params.get("transient"):
identity = Identity(
created_date=timezone.now(),
identifier=identifier,
environment=request.environment,
)
else:
identity, _ = Identity.objects.get_or_create_for_sdk(
identifier=identifier,
environment=request.environment,
integrations=IDENTITY_INTEGRATIONS,
)
self.identity = identity

if settings.EDGE_API_URL and request.environment.project.enable_dynamo_db:
Expand Down
35 changes: 24 additions & 11 deletions api/environments/sdk/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict

from core.constants import BOOLEAN, FLOAT, INTEGER, STRING
from django.utils import timezone
from rest_framework import serializers

from environments.identities.models import Identity
Expand Down Expand Up @@ -125,6 +126,7 @@ class IdentifyWithTraitsSerializer(
HideSensitiveFieldsSerializerMixin, serializers.Serializer
):
identifier = serializers.CharField(write_only=True, required=True)
transient = serializers.BooleanField(write_only=True, default=False)
traits = TraitSerializerBasic(required=False, many=True)
flags = SDKFeatureStateSerializer(read_only=True, many=True)

Expand All @@ -136,23 +138,34 @@ def save(self, **kwargs):
(optionally store traits if flag set on org)
"""
environment = self.context["environment"]
identity, created = Identity.objects.get_or_create(
identifier=self.validated_data["identifier"], environment=environment
)

transient = self.validated_data["transient"]
trait_data_items = self.validated_data.get("traits", [])

if not created and environment.project.organisation.persist_trait_data:
# if this is an update and we're persisting traits, then we need to
# partially update any traits and return the full list
trait_models = identity.update_traits(trait_data_items)
if transient:
identity = Identity(
created_date=timezone.now(),
identifier=self.validated_data["identifier"],
environment=environment,
)
trait_models = identity.generate_traits(trait_data_items, persist=False)

else:
# generate traits for the identity and store them if configured to do so
trait_models = identity.generate_traits(
trait_data_items,
persist=environment.project.organisation.persist_trait_data,
identity, created = Identity.objects.get_or_create(
identifier=self.validated_data["identifier"], environment=environment
)

if not created and environment.project.organisation.persist_trait_data:
# if this is an update and we're persisting traits, then we need to
# partially update any traits and return the full list
trait_models = identity.update_traits(trait_data_items)
else:
# generate traits for the identity and store them if configured to do so
trait_models = identity.generate_traits(
trait_data_items,
persist=environment.project.organisation.persist_trait_data,
)

all_feature_states = identity.get_all_feature_states(
traits=trait_models,
additional_filters=self.context.get("feature_states_additional_filters"),
Expand Down
2 changes: 2 additions & 0 deletions api/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,8 @@ def segment_featurestate(
feature_segment: int,
) -> int:
data = {
"enabled": True,
"feature_state_value": {"type": "unicode", "string_value": "segment override"},
"feature": feature,
"environment": environment,
"feature_segment": feature_segment,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APIClient

from features.feature_types import MULTIVARIATE
from tests.integration.helpers import (
Expand Down Expand Up @@ -221,3 +222,103 @@ def test_get_feature_states_for_identity_only_makes_one_query_to_get_mv_feature_

second_identity_response_json = second_identity_response.json()
assert len(second_identity_response_json["flags"]) == 3


def test_get_feature_states_for_identity__transient_identity__segment_match_expected(
sdk_client: APIClient,
feature: int,
segment: int,
segment_condition_property: str,
segment_condition_value: str,
segment_featurestate: int,
) -> None:
# Given
url = reverse("api-v1:sdk-identities")

# When
# flags are requested for a new transient identity
# that matches the segment
response = sdk_client.post(
url,
data=json.dumps(
{
"identifier": "unseen",
"traits": [
{
"trait_key": segment_condition_property,
"trait_value": segment_condition_value,
}
],
"transient": True,
}
),
content_type="application/json",
)

# Then
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert (
flag_data := next(
(
flag
for flag in response_json["flags"]
if flag["feature"]["id"] == feature
),
None,
)
)
assert flag_data["enabled"] is True
assert flag_data["feature_state_value"] == "segment override"


def test_get_feature_states_for_identity__transient_trait__segment_match_expected(
sdk_client: APIClient,
feature: int,
segment: int,
segment_condition_property: str,
segment_condition_value: str,
segment_featurestate: int,
) -> None:
# Given
url = reverse("api-v1:sdk-identities")

# When
# flags are requested for a new transient identity
# that matches the segment
response = sdk_client.post(
url,
data=json.dumps(
{
"identifier": "unseen",
"traits": [
{
"trait_key": segment_condition_property,
"trait_value": segment_condition_value,
"transient": True,
},
{
"trait_key": "persistent",
"trait_value": "trait value",
},
],
}
),
content_type="application/json",
)

# Then
assert response.status_code == status.HTTP_200_OK
response_json = response.json()
assert (
flag_data := next(
(
flag
for flag in response_json["flags"]
if flag["feature"]["id"] == feature
),
None,
)
)
assert flag_data["enabled"] is True
assert flag_data["feature_state_value"] == "segment override"
10 changes: 8 additions & 2 deletions api/tests/unit/environments/identities/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@


def generate_trait_data_item(
trait_key: str = "trait_key", trait_value: typing.Any = "trait_value"
trait_key: str = "trait_key",
trait_value: typing.Any = "trait_value",
transient: bool = False,
):
return {"trait_key": trait_key, "trait_value": trait_value}
return {
"trait_key": trait_key,
"trait_value": trait_value,
"transient": transient,
}


def create_trait_for_identity(
Expand Down
Loading
Loading