Skip to content

Commit

Permalink
Run serialization only once and do not validate non-large attribute k…
Browse files Browse the repository at this point in the history
…inds
  • Loading branch information
LucasG0 committed Oct 9, 2024
1 parent 292bb88 commit 302f474
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 52 deletions.
74 changes: 24 additions & 50 deletions backend/infrahub/core/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from infrahub.exceptions import ValidationError
from infrahub.helpers import hash_password

from ..types import ATTRIBUTE_TYPES, Password, Text
from ..types import ATTRIBUTE_TYPES, LARGE_ATTRIBUTE_TYPES
from .constants.relationship_label import RELATIONSHIP_TO_NODE_LABEL, RELATIONSHIP_TO_VALUE_LABEL

if TYPE_CHECKING:
Expand Down Expand Up @@ -269,7 +269,12 @@ def to_db(self) -> dict[str, Any]:
if self.value is None:
data["value"] = NULL_VALUE
else:
data["value"] = self.serialize_value(self.value)
serialized_value = self.serialize_value()
if isinstance(serialized_value, str) and ATTRIBUTE_TYPES[self.schema.kind] not in LARGE_ATTRIBUTE_TYPES:
# Perform validation here to avoid an extra serialization during validation step.
# Standard non-str attributes (integer, boolean) do not exceed limit size related to neo4j indexing.
validate_string_length(serialized_value)
data["value"] = serialized_value

return data

Expand Down Expand Up @@ -297,13 +302,11 @@ def value_from_db(self, data: AttributeFromDB) -> Any:
return None
return self.deserialize_value(data=data)

@classmethod
def serialize_value(cls, value: Any) -> Any:
"""Serialize the value before storing it in the database.
This method is a class method because `validate_content` is a class method and relies on it."""
if isinstance(value, Enum):
return value.value
return value
def serialize_value(self) -> Any:
"""Serialize the value before storing it in the database."""
if isinstance(self.value, Enum):
return self.value.value
return self.value

def deserialize_value(self, data: AttributeFromDB) -> Any:
"""Deserialize the value coming from the database."""
Expand Down Expand Up @@ -628,12 +631,6 @@ class String(BaseAttribute):
type = str
value: str

@classmethod
def validate_content(cls, value: Any, name: str, schema: AttributeSchema) -> None:
super().validate_content(value, name, schema)
if ATTRIBUTE_TYPES[schema.kind] in [Text, Password]:
validate_string_length(cls.serialize_value(value))


class StringOptional(String):
value: Optional[str]
Expand All @@ -643,10 +640,9 @@ class HashedPassword(BaseAttribute):
type = str
value: str

@classmethod
def serialize_value(cls, value: str) -> str:
def serialize_value(self) -> str:
"""Serialize the value before storing it in the database."""
return hash_password(value)
return hash_password(self.value)


class HashedPasswordOptional(HashedPassword):
Expand All @@ -660,9 +656,7 @@ class Integer(BaseAttribute):


class IntegerOptional(Integer):
type = int
value: Optional[int]
from_pool: Optional[str] = None


class Boolean(BaseAttribute):
Expand Down Expand Up @@ -753,11 +747,6 @@ def validate_format(cls, value: Any, name: str, schema: AttributeSchema) -> None
if not is_valid_url(value):
raise ValidationError({name: f"{value} is not a valid {schema.kind}"})

@classmethod
def validate_content(cls, value: Any, name: str, schema: AttributeSchema) -> None:
super().validate_content(value=value, name=name, schema=schema)
validate_string_length(cls.serialize_value(value))


class URLOptional(URL):
value: Optional[str]
Expand Down Expand Up @@ -870,11 +859,10 @@ def validate_format(cls, value: Any, name: str, schema: AttributeSchema) -> None
except ValueError as exc:
raise ValidationError({name: f"{value} is not a valid {schema.kind}"}) from exc

@classmethod
def serialize_value(cls, value: str) -> str:
def serialize_value(self) -> str:
"""Serialize the value before storing it in the database."""

return ipaddress.ip_network(value).with_prefixlen
return ipaddress.ip_network(self.value).with_prefixlen

def get_db_node_type(self) -> AttributeDBNodeType:
if self.value is not None:
Expand All @@ -892,11 +880,6 @@ def to_db(self) -> dict[str, Any]:

return data

@classmethod
def validate_content(cls, value: Any, name: str, schema: AttributeSchema) -> None:
super().validate_content(value=value, name=name, schema=schema)
validate_string_length(cls.serialize_value(value))


class IPNetworkOptional(IPNetwork):
value: Optional[str]
Expand Down Expand Up @@ -1002,11 +985,10 @@ def validate_format(cls, value: Any, name: str, schema: AttributeSchema) -> None
except ValueError as exc:
raise ValidationError({name: f"{value} is not a valid {schema.kind}"}) from exc

@classmethod
def serialize_value(cls, value: str) -> str:
def serialize_value(self) -> str:
"""Serialize the value before storing it in the database."""

return ipaddress.ip_interface(value).with_prefixlen
return ipaddress.ip_interface(self.value).with_prefixlen

def get_db_node_type(self) -> AttributeDBNodeType:
if self.value is not None:
Expand All @@ -1023,11 +1005,6 @@ def to_db(self) -> dict[str, Any]:

return data

@classmethod
def validate_content(cls, value: Any, name: str, schema: AttributeSchema) -> None:
super().validate_content(value=value, name=name, schema=schema)
validate_string_length(cls.serialize_value(value))


class IPHostOptional(IPHost):
value: Optional[str]
Expand Down Expand Up @@ -1130,10 +1107,9 @@ def validate_format(cls, value: Any, name: str, schema: AttributeSchema) -> None
if not netaddr.valid_mac(addr=str(value)):
raise ValidationError({name: f"{value} is not a valid {schema.kind}"})

@classmethod
def serialize_value(cls, value: str) -> str:
def serialize_value(self) -> str:
"""Serialize the value as standard EUI-48 or EUI-64 before storing it in the database."""
return str(netaddr.EUI(addr=value))
return str(netaddr.EUI(addr=self.value))


class MacAddressOptional(MacAddress):
Expand All @@ -1150,10 +1126,9 @@ def deserialize_from_string(cls, value_as_string: str) -> Any:
return ujson.loads(value_as_string)
return []

@classmethod
def serialize_value(cls, value: list[Any]) -> str:
def serialize_value(self) -> str:
"""Serialize the value before storing it in the database."""
return ujson.dumps(value)
return ujson.dumps(self.value)

def deserialize_value(self, data: AttributeFromDB) -> Any:
"""Deserialize the value (potentially) coming from the database."""
Expand All @@ -1176,10 +1151,9 @@ def deserialize_from_string(cls, value_as_string: str) -> Any:
return ujson.loads(value_as_string)
return {}

@classmethod
def serialize_value(cls, value: Union[dict, list]) -> str:
def serialize_value(self) -> str:
"""Serialize the value before storing it in the database."""
return ujson.dumps(value)
return ujson.dumps(self.value)

def deserialize_value(self, data: AttributeFromDB) -> Any:
"""Deserialize the value (potentially) coming from the database."""
Expand Down
3 changes: 3 additions & 0 deletions backend/infrahub/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,9 @@ class Any(InfrahubDataType):

ATTRIBUTE_KIND_LABELS = list(ATTRIBUTE_TYPES.keys())

# Data types supporting large values, which can therefore not be indexed in neo4j.
LARGE_ATTRIBUTE_TYPES = [TextArea, JSON]


def get_attribute_type(kind: str = "Default") -> type[InfrahubDataType]:
"""Return an InfrahubDataType object for a given kind
Expand Down
7 changes: 5 additions & 2 deletions backend/tests/unit/core/test_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,14 @@ async def test_attribute_size(db: InfrahubDatabase, default_branch: Branch, all_

large_string = "a" * 5_000

await obj.new(db=db, name="obj1", mystring=large_string)

# Text field
with pytest.raises(
ValidationError, match=f"Text attribute length should be less than {MAX_STRING_LENGTH} characters."
):
await obj.new(db=db, name="obj1", mystring=large_string)
await obj.save(db=db)

# TextArea field should have no size limitation
await obj.new(db=db, name="obj1", mytextarea=large_string)
await obj.new(db=db, name="obj2", mytextarea=large_string)
await obj.save(db=db)

0 comments on commit 302f474

Please sign in to comment.