Skip to content

Commit

Permalink
Add produces to lightweight component (#829)
Browse files Browse the repository at this point in the history
PR that adds the `produces` argument to the lightweight component

---------

Co-authored-by: Robbe Sneyders <robbe.sneyders@ml6.eu>
Co-authored-by: Robbe Sneyders <robbe.sneyders@gmail.com>
Co-authored-by: Georges Lorré <35808396+GeorgesLorre@users.noreply.github.com>
  • Loading branch information
4 people authored Jan 31, 2024
1 parent 8e0ec82 commit c009879
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 23 deletions.
24 changes: 23 additions & 1 deletion src/fondant/pipeline/lightweight_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,15 @@ def image(cls) -> Image:
def consumes(cls) -> t.Optional[t.Dict[str, t.Any]]:
pass

@classmethod
def produces(cls) -> t.Optional[t.Dict[str, t.Any]]:
pass

@classmethod
def modify_spec_consumes(
cls,
spec_consumes: t.Dict[str, t.Any],
apply_consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]],
apply_consumes: t.Optional[t.Dict[str, pa.DataType]],
):
"""Modify fields based on the consumes argument in the 'apply' method."""
if apply_consumes:
Expand Down Expand Up @@ -135,12 +139,26 @@ def get_spec_consumes(

return spec_consumes

@classmethod
def get_spec_produces(cls):
"""Get the produces spec for the component."""
produces = cls.produces()

if produces is None:
return {"additionalProperties": True}

return {
k: (Type(v).to_dict() if k != "additionalProperties" else v)
for k, v in produces.items()
}


def lightweight_component(
*args,
extra_requires: t.Optional[t.List[str]] = None,
base_image: t.Optional[str] = None,
consumes: t.Optional[t.Dict[str, t.Any]] = None,
produces: t.Optional[t.Dict[str, t.Any]] = None,
):
"""Decorator to enable a lightweight component."""

Expand Down Expand Up @@ -229,6 +247,10 @@ class LightweightComponentOp(cls, LightweightComponent):
def image(cls) -> Image:
return image

@classmethod
def produces(cls) -> t.Optional[t.Dict[str, pa.DataType]]:
return produces

@classmethod
def consumes(cls) -> t.Optional[t.Dict[str, t.Dict[t.Any, t.Any]]]:
return consumes
Expand Down
7 changes: 4 additions & 3 deletions src/fondant/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,9 @@ def from_ref(
name = ref.__name__
image = ref.image()
description = ref.__doc__ or "lightweight component"
spec_produces = ref.get_spec_produces()

consumes_spec = (
spec_consumes = (
ref.get_spec_consumes(fields, kwargs["consumes"])
if fields
else {"additionalProperties": True}
Expand All @@ -237,8 +238,8 @@ def from_ref(
name,
image.base_image,
description=description,
consumes=consumes_spec,
produces={"additionalProperties": True},
consumes=spec_consumes,
produces=spec_produces,
args={
name: arg.to_spec()
for name, arg in infer_arguments(ref).items()
Expand Down
2 changes: 1 addition & 1 deletion tests/pipeline/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ def test_kubeflow_component_spec_from_lightweight_component(
@lightweight_component(
base_image="python:3.8-slim-buster",
extra_requires=["pandas", "dask"],
produces={"x": pa.int32(), "y": pa.int32()},
)
class CreateData(DaskLoadComponent):
def load(self) -> dd.DataFrame:
Expand All @@ -427,7 +428,6 @@ def load(self) -> dd.DataFrame:

_ = pipeline.read(
ref=CreateData,
produces={"x": pa.int32(), "y": pa.int32()},
)

compiler = KubeFlowCompiler()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def load_pipeline(caplog):
@lightweight_component(
base_image="python:3.8-slim-buster",
extra_requires=["pandas", "dask"],
produces={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()},
)
class CreateData(DaskLoadComponent):
def load(self) -> dd.DataFrame:
Expand All @@ -54,7 +55,6 @@ def load(self) -> dd.DataFrame:

dataset = pipeline.read(
ref=CreateData,
produces={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()},
)

caplog_records = caplog.records
Expand Down Expand Up @@ -103,22 +103,22 @@ def test_lightweight_component_sdk(default_fondant_image, load_pipeline):
"image": "python:3.8-slim-buster",
"description": "lightweight component",
"consumes": {"additionalProperties": True},
"produces": {"additionalProperties": True},
"produces": {
"x": {"type": "int32"},
"y": {"type": "int32"},
"z": {"type": "int32"},
},
},
"consumes": {},
"produces": {
"x": {"type": "int32"},
"y": {"type": "int32"},
"z": {"type": "int32"},
},
"produces": {},
}

# check warning: fondant is not part of the requirements
msg = "You are not using a Fondant default base image"

assert any(msg in record.message for record in caplog_records)

@lightweight_component
@lightweight_component(produces={"x": pa.int32()})
class AddN(PandasTransformComponent):
def __init__(self, n: int):
self.n = n
Expand All @@ -129,7 +129,6 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:

_ = dataset.apply(
ref=AddN,
produces={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()},
arguments={"n": 1},
)
assert len(pipeline._graph.keys()) == 1 + 1
Expand All @@ -147,15 +146,11 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
"y": {"type": "int32"},
"z": {"type": "int32"},
},
"produces": {"additionalProperties": True},
"produces": {"x": {"type": "int32"}},
"args": {"n": {"type": "int"}},
},
"consumes": {},
"produces": {
"x": {"type": "int32"},
"y": {"type": "int32"},
"z": {"type": "int32"},
},
"produces": {},
}
pipeline._validate_pipeline_definition(run_id="dummy-run-id")

Expand All @@ -168,6 +163,7 @@ def test_consumes_mapping_all_fields(tmp_path_factory, load_pipeline):
extra_requires=[
"fondant[component]@git+https://github.com/ml6team/fondant@main",
],
produces={"a": pa.int32()},
)
class AddN(PandasTransformComponent):
def __init__(self, n: int):
Expand All @@ -182,7 +178,6 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
_ = dataset.apply(
ref=AddN,
consumes={"a": "x"},
produces={"a": pa.int32()},
arguments={"n": 1},
)

Expand All @@ -204,6 +199,7 @@ def test_consumes_mapping_specific_fields(tmp_path_factory, load_pipeline):
"fondant[component]@git+https://github.com/ml6team/fondant@main",
],
consumes={"a": pa.int32()},
produces={"a": pa.int32()},
)
class AddN(PandasTransformComponent):
def __init__(self, n: int):
Expand All @@ -218,7 +214,6 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
_ = dataset.apply(
ref=AddN,
consumes={"a": "x"},
produces={"a": pa.int32()},
arguments={"n": 1},
)

Expand All @@ -241,6 +236,7 @@ def test_consumes_mapping_additional_fields(tmp_path_factory, load_pipeline):
"fondant[component]@git+https://github.com/ml6team/fondant@main",
],
consumes={"additionalProperties": True},
produces={"a": pa.int32()},
)
class AddN(PandasTransformComponent):
def __init__(self, n: int):
Expand All @@ -255,7 +251,6 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
_ = dataset.apply(
ref=AddN,
consumes={"x": pa.int32()},
produces={"a": pa.int32()},
arguments={"n": 1},
)

Expand All @@ -271,6 +266,43 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
assert "z" not in operation_spec.inner_consumes


def test_produces_mapping_additional_fields(tmp_path_factory, load_pipeline):
@lightweight_component(
base_image="python:3.8",
extra_requires=[
"fondant[component]@git+https://github.com/ml6team/fondant@main",
],
consumes={"additionalProperties": True},
)
class AddN(PandasTransformComponent):
def __init__(self, n: int):
self.n = n

def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
dataframe["a"] = dataframe["x"].map(lambda x: x + self.n)
dataframe["b"] = dataframe["x"].map(lambda x: x + self.n)
dataframe["c"] = dataframe["x"].map(lambda x: x + self.n)
return dataframe

pipeline, dataset, _, _ = load_pipeline

_ = dataset.apply(
ref=AddN,
consumes={"x": pa.int32()},
produces={"a": pa.int32(), "b": pa.int32(), "c": pa.int32()},
arguments={"n": 1},
)

with tmp_path_factory.mktemp("temp") as fn:
output_path = str(fn / "kubeflow_pipeline.yml")
DockerCompiler().compile(pipeline=pipeline, output_path=output_path)
pipeline_configs = DockerComposeConfigs.from_spec(output_path)
operation_spec = OperationSpec.from_json(
pipeline_configs.component_configs["addn"].arguments["operation_spec"],
)
assert all(k in ["a", "b", "c"] for k in operation_spec.inner_produces)


def test_lightweight_component_missing_decorator():
pipeline = Pipeline(
name="dummy-pipeline",
Expand Down

0 comments on commit c009879

Please sign in to comment.