Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup and add more tests #792

Merged
merged 1 commit into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/fondant/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ class InvalidTypeSchema(ValidationError, FondantException):

class UnsupportedTypeAnnotation(FondantException):
    """Raised when an unsupported type annotation is encountered during type inference."""


class InvalidPythonComponent(FondantException):
    """Raised when a component is not a valid Python component.

    ``ComponentOp.from_ref`` raises this for a class that subclasses
    ``BaseComponent`` but not ``PythonComponent``.
    """
192 changes: 76 additions & 116 deletions src/fondant/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

import pyarrow as pa

from fondant.component import BaseComponent
from fondant.core.component_spec import ComponentSpec, OperationSpec
from fondant.core.exceptions import InvalidPipelineDefinition
from fondant.core.exceptions import InvalidPipelineDefinition, InvalidPythonComponent
from fondant.core.manifest import Manifest
from fondant.core.schema import Field
from fondant.pipeline import Image, PythonComponent
Expand Down Expand Up @@ -182,7 +183,7 @@ def __init__(
self.resources = resources or Resources()

@classmethod
def from_component_yaml(cls, path, **kwargs):
def from_component_yaml(cls, path, **kwargs) -> "ComponentOp":
if cls._is_custom_component(path):
component_dir = Path(path)
else:
Expand All @@ -203,6 +204,48 @@ def from_component_yaml(cls, path, **kwargs):
**kwargs,
)

@classmethod
def from_ref(cls, ref: t.Any, **kwargs) -> "ComponentOp":
    """Build a ComponentOp from any supported kind of component reference.

    Args:
        ref: A reusable component name or a path to a custom component
            (given as a string or ``Path``; both are handed to
            ``from_component_yaml``), or a Python component class.
        **kwargs: Forwarded unchanged to the ``ComponentOp`` constructor or
            to ``from_component_yaml``.

    Returns:
        The operation created for the reference.

    Raises:
        InvalidPythonComponent: If ``ref`` is a ``BaseComponent`` subclass
            that is not a ``PythonComponent`` subclass.
        ValueError: If ``ref`` is neither a string, a ``Path`` nor a
            ``BaseComponent`` subclass.
    """
    is_component_class = inspect.isclass(ref) and issubclass(ref, BaseComponent)

    if not is_component_class:
        if isinstance(ref, (str, Path)):
            return cls.from_component_yaml(ref, **kwargs)

        # The literal's second line is flush-left on purpose: any leading
        # whitespace there would become part of the error message.
        msg = f"""Invalid reference type: {type(ref)}.
Expected a string, Path, or a Python component class."""
        raise ValueError(msg)

    if not issubclass(ref, PythonComponent):
        # Same flush-left continuation as above, for the same reason.
        msg = """Reference is not a valid Python component.
Make sure the component is decorated properly."""
        raise InvalidPythonComponent(msg)

    name = ref.__name__
    image = ref.image()
    # The class docstring doubles as the description; fall back to a
    # generic label when the class has none.
    spec = ComponentSpec(
        name,
        image.base_image,
        description=ref.__doc__ or "python component",
        consumes={"additionalProperties": True},
        produces={"additionalProperties": True},
    )
    return cls(name, image, spec, **kwargs)

def _configure_caching_from_image_tag(
self,
cache: t.Optional[bool],
Expand Down Expand Up @@ -386,44 +429,16 @@ def read(
msg,
)

if inspect.isclass(ref) and issubclass(ref, PythonComponent):
name = ref.__name__
image = ref.image()
description = ref.__doc__ or "python component"

component_spec = ComponentSpec(
name,
image.base_image, # TODO: revisit
description=description,
consumes={"additionalProperties": True},
produces={"additionalProperties": True},
)

operation = ComponentOp(
name,
image,
component_spec,
produces=produces,
arguments=arguments,
input_partition_rows=input_partition_rows,
resources=resources,
cache=cache,
cluster_type=cluster_type,
client_kwargs=client_kwargs,
)

else:
operation = ComponentOp.from_component_yaml(
ref,
produces=produces,
arguments=arguments,
input_partition_rows=input_partition_rows,
resources=resources,
cache=cache,
cluster_type=cluster_type,
client_kwargs=client_kwargs,
)

operation = ComponentOp.from_ref(
ref,
produces=produces,
arguments=arguments,
input_partition_rows=input_partition_rows,
resources=resources,
cache=cache,
cluster_type=cluster_type,
client_kwargs=client_kwargs,
)
manifest = Manifest.create(
pipeline_name=self.name,
base_path=self.base_path,
Expand Down Expand Up @@ -697,45 +712,17 @@ def apply(
Returns:
An intermediate dataset.
"""
if inspect.isclass(ref) and issubclass(ref, PythonComponent):
name = ref.__name__
image = ref.image()
description = ref.__doc__ or "python component"

component_spec = ComponentSpec(
name,
image.base_image, # TODO: revisit
description=description,
consumes={"additionalProperties": True},
produces={"additionalProperties": True},
)

operation = ComponentOp(
name,
image,
component_spec,
consumes=consumes,
produces=produces,
arguments=arguments,
input_partition_rows=input_partition_rows,
resources=resources,
cache=cache,
cluster_type=cluster_type,
client_kwargs=client_kwargs,
)

else:
operation = ComponentOp.from_component_yaml(
ref,
consumes=consumes,
produces=produces,
arguments=arguments,
input_partition_rows=input_partition_rows,
resources=resources,
cache=cache,
cluster_type=cluster_type,
client_kwargs=client_kwargs,
)
operation = ComponentOp.from_ref(
ref,
produces=produces,
consumes=consumes,
arguments=arguments,
input_partition_rows=input_partition_rows,
resources=resources,
cache=cache,
cluster_type=cluster_type,
client_kwargs=client_kwargs,
)

return self._apply(operation)

Expand Down Expand Up @@ -772,41 +759,14 @@ def write(
Returns:
An intermediate dataset.
"""
if inspect.isclass(ref) and issubclass(ref, PythonComponent):
name = ref.__name__
image = ref.image()
description = ref.__doc__ or "python component"

component_spec = ComponentSpec(
name,
image.base_image, # TODO: revisit
description=description,
consumes={"additionalProperties": True},
produces={"additionalProperties": True},
)

operation = ComponentOp(
name,
image,
component_spec,
consumes=consumes,
arguments=arguments,
input_partition_rows=input_partition_rows,
resources=resources,
cache=cache,
cluster_type=cluster_type,
client_kwargs=client_kwargs,
)

else:
operation = ComponentOp.from_component_yaml(
ref,
consumes=consumes,
arguments=arguments,
input_partition_rows=input_partition_rows,
resources=resources,
cache=cache,
cluster_type=cluster_type,
client_kwargs=client_kwargs,
)
operation = ComponentOp.from_ref(
ref,
consumes=consumes,
arguments=arguments,
input_partition_rows=input_partition_rows,
resources=resources,
cache=cache,
cluster_type=cluster_type,
client_kwargs=client_kwargs,
)
self._apply(operation)
28 changes: 27 additions & 1 deletion tests/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import pyarrow as pa
import pytest
import yaml
from fondant.component import DaskLoadComponent
from fondant.core.component_spec import ComponentSpec
from fondant.core.exceptions import InvalidPipelineDefinition
from fondant.core.schema import Field, Type
from fondant.pipeline import ComponentOp, Pipeline, Resources
from fondant.pipeline import ComponentOp, Pipeline, Resources, lightweight_component

valid_pipeline_path = Path(__file__).parent / "examples/pipelines/valid_pipeline"
invalid_pipeline_path = Path(__file__).parent / "examples/pipelines/invalid_pipeline"
Expand Down Expand Up @@ -67,6 +68,31 @@ def test_component_op(
)


def test_component_op_python_component(default_pipeline_args):
    """A decorated Python component class yields a permissive component spec.

    The expected specification uses the class name as the component name and
    the generic "python component" description.
    """

    # NOTE: Foo deliberately has no docstring; `from_ref` would use it as the
    # description instead of the "python component" fallback asserted below.
    @lightweight_component()
    class Foo(DaskLoadComponent):
        def load(self) -> str:
            # Return a value matching the `-> str` annotation; `load` is
            # never invoked by this test.
            return "bar"

    component = ComponentOp.from_ref(Foo, produces={"bar": pa.string()})
    assert component.component_spec._specification == {
        "name": "Foo",
        "image": "fondant:latest",
        "description": "python component",
        "consumes": {"additionalProperties": True},
        "produces": {"additionalProperties": True},
    }


def test_component_op_bad_ref():
    """``from_ref`` raises ValueError for a reference of an unsupported type."""
    # The pattern's second line is flush-left so it carries no extra
    # whitespace into the match.
    expected_message = """Invalid reference type: <class 'int'>.
Expected a string, Path, or a Python component class."""

    with pytest.raises(ValueError, match=expected_message):
        ComponentOp.from_ref(123)


@pytest.mark.parametrize(
"valid_pipeline_example",
[
Expand Down
51 changes: 50 additions & 1 deletion tests/pipeline/test_python_component.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
import textwrap

import dask.dataframe as dd
import pandas as pd
import pyarrow as pa
import pytest
from fondant.component import DaskLoadComponent, PandasTransformComponent
from fondant.core.exceptions import InvalidPythonComponent
from fondant.pipeline import Pipeline, lightweight_component


Expand Down Expand Up @@ -46,7 +49,7 @@ def load(self) -> dd.DataFrame:
)


def test_lightweight_component(tmp_path_factory):
def test_lightweight_component_sdk():
pipeline = Pipeline(
name="dummy-pipeline",
base_path="./data",
Expand All @@ -72,6 +75,20 @@ def load(self) -> dd.DataFrame:
produces={"x": pa.int32(), "y": pa.int32()},
)

assert len(pipeline._graph.keys()) == 1
operation_spec = pipeline._graph["CreateData"]["operation"].operation_spec.to_json()
assert json.loads(operation_spec) == {
"specification": {
"name": "CreateData",
"image": "python:3.8-slim-buster",
"description": "python component",
"consumes": {"additionalProperties": True},
"produces": {"additionalProperties": True},
},
"consumes": {},
"produces": {"x": {"type": "int32"}, "y": {"type": "int32"}},
}

@lightweight_component()
class AddN(PandasTransformComponent):
def __init__(self, n: int):
Expand All @@ -86,3 +103,35 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
produces={"x": pa.int32(), "y": pa.int32()},
consumes={"x": pa.int32(), "y": pa.int32()},
)
assert len(pipeline._graph.keys()) == 1 + 1
assert pipeline._graph["AddN"]["dependencies"] == ["CreateData"]
operation_spec = pipeline._graph["AddN"]["operation"].operation_spec.to_json()
assert json.loads(operation_spec) == {
"specification": {
"name": "AddN",
"image": "fondant:latest",
"description": "python component",
"consumes": {"additionalProperties": True},
"produces": {"additionalProperties": True},
},
"consumes": {"x": {"type": "int32"}, "y": {"type": "int32"}},
"produces": {"x": {"type": "int32"}, "y": {"type": "int32"}},
}
pipeline._validate_pipeline_definition(run_id="dummy-run-id")


def test_lightweight_component_missing_decorator():
    """Reading from an undecorated component class raises InvalidPythonComponent."""

    # No @lightweight_component decorator here: that omission is what the
    # test exercises.
    class Foo(DaskLoadComponent):
        def load(self) -> str:
            return "bar"

    pipeline = Pipeline(name="dummy-pipeline", base_path="./data")
    schema = {"x": pa.int32(), "y": pa.int32()}

    with pytest.raises(InvalidPythonComponent):
        pipeline.read(ref=Foo, produces=schema)
Loading