update caching arguments
PhilippeMoussalli committed Jan 19, 2024
1 parent 72d6822 commit 1ca4479
Showing 12 changed files with 46 additions and 36 deletions.
2 changes: 1 addition & 1 deletion data_explorer/app/df_helpers/fields.py
@@ -13,7 +13,7 @@ def get_fields_by_types(
     filtered_fields = []
 
     for field, f_type in fields.items():
-        if any(ftype in f_type.type.to_json()["type"] for ftype in field_types):
+        if any(ftype in f_type.type.to_dict()["type"] for ftype in field_types):
             filtered_fields.append(field)
 
     return filtered_fields
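For context, a minimal sketch of how the renamed accessor is used here. It assumes `Field` and `Type` from `fondant.core.schema` accept these constructor arguments (as they do elsewhere in this diff); the field names are illustrative:

    from fondant.core.schema import Field, Type

    # Hypothetical fields mapping; Type("string").to_dict() -> {"type": "string"}.
    fields = {
        "caption": Field(name="caption", type=Type("string")),
        "width": Field(name="width", type=Type("int32")),
    }

    field_types = ["string"]
    filtered_fields = [
        name
        for name, field in fields.items()
        if any(ftype in field.type.to_dict()["type"] for ftype in field_types)
    ]
    print(filtered_fields)  # ['caption']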
2 changes: 1 addition & 1 deletion data_explorer/app/pages/home.py
@@ -52,7 +52,7 @@ def create_component_table(manifest: Manifest) -> str:
     component_name = manifest.component_id
 
     fields_with_schema = [
-        (field_name, field_schema.type.to_json()["type"])
+        (field_name, field_schema.type.to_dict()["type"])
         for field_name, field_schema in fields.items()
     ]
 
Binary file modified examples/sample_pipeline/data/sample.parquet
1 change: 0 additions & 1 deletion examples/sample_pipeline/pipeline.py
@@ -26,7 +26,6 @@
     arguments={
         "dataset_uri": "/data/sample.parquet",
         "column_name_mapping": load_component_column_mapping,
-        "n_rows_to_load": 5,
     },
     produces={"text_data": pa.string()},
 )
18 changes: 10 additions & 8 deletions src/fondant/core/component_spec.py
@@ -94,8 +94,8 @@ def __init__(
         image: str,
         *,
         description: t.Optional[str] = None,
-        consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
-        produces: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
+        consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType, bool]]] = None,
+        produces: t.Optional[t.Dict[str, t.Union[str, pa.DataType, bool]]] = None,
         previous_index: t.Optional[str] = None,
         args: t.Optional[t.Dict[str, t.Any]] = None,
         tags: t.Optional[t.List[str]] = None,
@@ -223,7 +223,7 @@ def consumes(self) -> t.Mapping[str, Field]:
         """The fields consumed by the component as an immutable mapping."""
         return types.MappingProxyType(
             {
-                name: Field(name=name, type=Type.from_json(field))
+                name: Field(name=name, type=Type.from_dict(field))
                 for name, field in self._specification.get("consumes", {}).items()
                 if name != "additionalProperties"
             },
@@ -234,7 +234,7 @@ def produces(self) -> t.Mapping[str, Field]:
         """The fields produced by the component as an immutable mapping."""
         return types.MappingProxyType(
             {
-                name: Field(name=name, type=Type.from_json(field))
+                name: Field(name=name, type=Type.from_dict(field))
                 for name, field in self._specification.get("produces", {}).items()
                 if name != "additionalProperties"
             },
@@ -368,7 +368,7 @@ def __init__(
         self._inner_produces: t.Optional[t.Mapping[str, Field]] = None
         self._outer_produces: t.Optional[t.Mapping[str, Field]] = None
 
-    def to_json(self) -> str:
+    def to_dict(self) -> dict:
         def _dump_mapping(
             mapping: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]],
         ) -> dict:
@@ -378,15 +378,17 @@ def _dump_mapping(
             serialized_mapping: t.Dict[str, t.Any] = mapping.copy()
             for key, value in mapping.items():
                 if isinstance(value, pa.DataType):
-                    serialized_mapping[key] = Type(value).to_json()
+                    serialized_mapping[key] = Type(value).to_dict()
             return serialized_mapping
 
-        specification_dict = {
+        return {
             "specification": self._component_spec.specification,
             "consumes": _dump_mapping(self._mappings["consumes"]),
             "produces": _dump_mapping(self._mappings["produces"]),
         }
 
+    def to_json(self) -> str:
+        specification_dict = self.to_dict()
         return json.dumps(specification_dict)
 
     @classmethod
@@ -397,7 +399,7 @@ def _parse_mapping(
             """Parse a json mapping to a Python mapping with Fondant types."""
             for key, value in json_mapping.items():
                 if isinstance(value, dict):
-                    json_mapping[key] = Type.from_json(value).value
+                    json_mapping[key] = Type.from_dict(value).value
             return json_mapping
 
         return cls(
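The net effect of this hunk is that `to_json` becomes a thin wrapper around the new `to_dict`. A self-contained sketch of the pattern, using a stand-in class rather than the real `OperationSpec` (which needs a full component spec to construct):

    import json

    class SpecLike:
        """Stand-in for OperationSpec's new to_dict/to_json split."""

        def to_dict(self) -> dict:
            # The real method assembles specification/consumes/produces.
            return {"specification": {}, "consumes": {}, "produces": {}}

        def to_json(self) -> str:
            # to_json is now just a serialization of to_dict's output.
            return json.dumps(self.to_dict())

    spec = SpecLike()
    assert json.loads(spec.to_json()) == spec.to_dict()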
4 changes: 2 additions & 2 deletions src/fondant/core/manifest.py
@@ -188,7 +188,7 @@ def fields(self) -> t.Mapping[str, Field]:
             {
                 name: Field(
                     name=name,
-                    type=Type.from_json(field),
+                    type=Type.from_dict(field),
                     location=field["location"],
                 )
                 for name, field in self._specification["fields"].items()
@@ -208,7 +208,7 @@ def add_or_update_field(self, field: Field, overwrite: bool = False):
         else:
             self._specification["fields"][field.name] = {
                 "location": field.location,
-                **field.type.to_json(),
+                **field.type.to_dict(),
             }
 
     def _add_or_update_index(self, field: Field, overwrite: bool = True):
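Because `Type.to_dict` returns a flat mapping such as `{"type": "int32"}`, the `**` spread above yields a single flat field entry in the manifest. An illustrative sketch, assuming `Field` accepts the `location` keyword as shown in the hunk above:

    from fondant.core.schema import Field, Type

    field = Field(name="width", type=Type("int32"), location="/component_1")
    entry = {"location": field.location, **field.type.to_dict()}
    print(entry)  # {'location': '/component_1', 'type': 'int32'}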
8 changes: 4 additions & 4 deletions src/fondant/core/schema.py
@@ -136,7 +136,7 @@ def list(cls, data_type: t.Union[str, pa.DataType, "Type"]) -> "Type":
         )
 
     @classmethod
-    def from_json(cls, json_schema: dict):
+    def from_dict(cls, json_schema: dict):
         """
         Creates a new `Type` instance based on a dictionary representation of the json schema
         of a data type (https://swagger.io/docs/specification/data-models/data-types/).
@@ -150,12 +150,12 @@ def from_dict(cls, json_schema: dict):
         if json_schema["type"] == "array":
             items = json_schema["items"]
             if isinstance(items, dict):
-                return cls.list(cls.from_json(items))
+                return cls.list(cls.from_dict(items))
             return None
 
         return cls(json_schema["type"])
 
-    def to_json(self) -> dict:
+    def to_dict(self) -> dict:
         """
         Converts the `Type` instance to its JSON representation.
@@ -165,7 +165,7 @@ def to_dict(self) -> dict:
         if isinstance(self.value, pa.ListType):
             items = self.value.value_type
             if isinstance(items, pa.DataType):
-                return {"type": "array", "items": Type(items).to_json()}
+                return {"type": "array", "items": Type(items).to_dict()}
 
         type_ = None
         for type_name, data_type in _TYPES.items():
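The rename makes the dict round-trip explicit: `from_dict` parses a JSON-schema-style mapping into a `Type`, and `to_dict` reproduces it. A round-trip check grounded in the tests further down:

    import pyarrow as pa

    from fondant.core.schema import Type

    schema = {"type": "array", "items": {"type": "int8"}}
    t = Type.from_dict(schema)
    assert t.value == pa.list_(pa.int8())
    assert t.to_dict() == schema  # round-trips back to the same mapping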
5 changes: 4 additions & 1 deletion src/fondant/pipeline/lightweight_component.py
@@ -2,7 +2,7 @@
 import itertools
 import textwrap
 import typing as t
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from functools import wraps
 
 from fondant.component import BaseComponent, Component
@@ -19,6 +19,9 @@ def __post_init__(self):
         # TODO: link to Fondant version
         self.base_image = "fondant:latest"
 
+    def to_dict(self):
+        return asdict(self)
+
 
 class PythonComponent(BaseComponent):
     @classmethod
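`dataclasses.asdict` recursively converts a dataclass into plain dicts, which is what the caching code in pipeline.py below needs for hashing. A standalone sketch of the pattern; only `base_image` is taken from the diff, and any other fields of the real `Image` dataclass are omitted:

    from dataclasses import asdict, dataclass

    @dataclass
    class Image:
        base_image: str = "fondant:latest"

        def to_dict(self):
            # asdict handles nested dataclasses too, yielding json.dumps-able output.
            return asdict(self)

    print(Image().to_dict())  # {'base_image': 'fondant:latest'}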
9 changes: 7 additions & 2 deletions src/fondant/pipeline/pipeline.py
@@ -332,18 +332,23 @@ def get_nested_dict_hash(input_dict):
             hash_object = hashlib.md5(sorted_json_string.encode())  # nosec
             return hash_object.hexdigest()
 
-        component_spec_dict = self.component_spec.specification
+        operation_spec_dict = self.operation_spec.to_json()
+        image_dict = self.image.to_dict()
 
         arguments = (
             get_nested_dict_hash(self.arguments) if self.arguments is not None else None
        )
 
         component_op_uid_dict = {
-            "component_spec_hash": get_nested_dict_hash(component_spec_dict),
+            "operation_spec_hash": get_nested_dict_hash(operation_spec_dict),
+            "image": get_nested_dict_hash(image_dict),
             "arguments": arguments,
             "input_partition_rows": self.input_partition_rows,
             "number_of_accelerators": self.resources.accelerator_number,
             "accelerator_name": self.resources.accelerator_name,
             "node_pool_name": self.resources.node_pool_name,
+            "cluster_type": self.cluster_type,
+            "client_kwargs": self.client_kwargs,
         }
 
         if previous_component_cache is not None:
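The cache key now hashes the operation spec and the image instead of the bare component spec, so changing consumes/produces mappings or the base image invalidates the cache. A sketch of the helper's behavior, assuming `get_nested_dict_hash` serializes with `sort_keys=True` (consistent with the `sorted_json_string` name shown above):

    import hashlib
    import json

    def get_nested_dict_hash(input_dict):
        # Sorting keys makes logically-equal dicts hash identically.
        sorted_json_string = json.dumps(input_dict, sort_keys=True)
        return hashlib.md5(sorted_json_string.encode()).hexdigest()  # nosec

    a = get_nested_dict_hash({"x": 1, "y": {"z": 2}})
    b = get_nested_dict_hash({"y": {"z": 2}, "x": 1})
    assert a == b  # key order does not affect the cache key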
20 changes: 10 additions & 10 deletions tests/core/test_schema.py
@@ -10,20 +10,20 @@ def test_valid_type():
     assert Type("int8").value == pa.int8()
     assert Type.list(Type("int8")).value == pa.list_(pa.int8())
     assert Type.list(Type.list(Type("string"))).value == pa.list_(pa.list_(pa.string()))
-    assert Type("int8").to_json() == {"type": "int8"}
-    assert Type.list("float32").to_json() == {
+    assert Type("int8").to_dict() == {"type": "int8"}
+    assert Type.list("float32").to_dict() == {
         "type": "array",
         "items": {"type": "float32"},
     }
 
 
 def test_valid_json_schema():
     """Test that Type class initialized with a json schema matches the expected pyarrow schema."""
-    assert Type.from_json({"type": "string"}).value == pa.string()
-    assert Type.from_json(
+    assert Type.from_dict({"type": "string"}).value == pa.string()
+    assert Type.from_dict(
         {"type": "array", "items": {"type": "int8"}},
     ).value == pa.list_(pa.int8())
-    assert Type.from_json(
+    assert Type.from_dict(
         {"type": "array", "items": {"type": "array", "items": {"type": "int8"}}},
     ).value == pa.list_(pa.list_(pa.int8()))
@@ -32,12 +32,12 @@
     "statement",
     [
         'Type("invalid_type")',
-        'Type("invalid_type").to_json()',
+        'Type("invalid_type").to_dict()',
         'Type.list(Type("invalid_type"))',
-        'Type.list(Type("invalid_type")).to_json()',
-        'Type.from_json({"type": "invalid_value"})',
-        'Type.from_json({"type": "invalid_value", "items": {"type": "int8"}})',
-        'Type.from_json({"type": "array", "items": {"type": "invalid_type"}})',
+        'Type.list(Type("invalid_type")).to_dict()',
+        'Type.from_dict({"type": "invalid_value"})',
+        'Type.from_dict({"type": "invalid_value", "items": {"type": "int8"}})',
+        'Type.from_dict({"type": "array", "items": {"type": "invalid_type"}})',
     ],
 )
 def test_invalid_json_schema(statement):
2 changes: 1 addition & 1 deletion tests/pipeline/test_compiler.py
@@ -531,7 +531,7 @@ def test_invalid_vertex_configuration(tmp_path_factory):
 
 
 def test_caching_dependency_docker(tmp_path_factory):
-    """Test that the component cache key changes when a depending component cache key change for
+    """Test that the component cache key changes when a dependant component cache key change for
     the docker compiler.
     """
     arg_list = ["dummy_arg_1", "dummy_arg_2"]
11 changes: 6 additions & 5 deletions tests/pipeline/test_python_component.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import textwrap

import dask.dataframe as dd
Expand Down Expand Up @@ -77,8 +76,10 @@ def load(self) -> dd.DataFrame:
)

assert len(pipeline._graph.keys()) == 1
operation_spec = pipeline._graph["CreateData"]["operation"].operation_spec.to_json()
assert json.loads(operation_spec) == {
operation_spec_dict = pipeline._graph["CreateData"][
"operation"
].operation_spec.to_dict()
assert operation_spec_dict == {
"specification": {
"name": "CreateData",
"image": "python:3.8-slim-buster",
Expand Down Expand Up @@ -107,8 +108,8 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
)
assert len(pipeline._graph.keys()) == 1 + 1
assert pipeline._graph["AddN"]["dependencies"] == ["CreateData"]
operation_spec = pipeline._graph["AddN"]["operation"].operation_spec.to_json()
assert json.loads(operation_spec) == {
operation_spec_dict = pipeline._graph["AddN"]["operation"].operation_spec.to_dict()
assert operation_spec_dict == {
"specification": {
"name": "AddN",
"image": "fondant:latest",
Expand Down
