From c0b497e577eea5d2805601f2d78d4f4013e40f34 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Fri, 23 Jun 2023 13:14:40 +0200 Subject: [PATCH] Don't validate returned Pandas dataframe strictly (#226) This PR filters out data from the pandas dataframe returned by the user that is not defined in the component spec. Previously, returning additional columns would raise an error. (see https://github.com/ml6team/fondant/pull/223#issuecomment-1602233930) --- components/download_images/src/main.py | 2 +- .../components/filter_text_complexity/src/main.py | 6 +----- fondant/component.py | 14 +++++++++----- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/components/download_images/src/main.py b/components/download_images/src/main.py index 307ca052b..36c92a546 100644 --- a/components/download_images/src/main.py +++ b/components/download_images/src/main.py @@ -126,7 +126,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: result_type="expand", ) - return dataframe[[("images", "data"), ("images", "width"), ("images", "height")]] + return dataframe if __name__ == "__main__": diff --git a/examples/pipelines/datacomp/components/filter_text_complexity/src/main.py b/examples/pipelines/datacomp/components/filter_text_complexity/src/main.py index 865597d0b..2e7f6616a 100644 --- a/examples/pipelines/datacomp/components/filter_text_complexity/src/main.py +++ b/examples/pipelines/datacomp/components/filter_text_complexity/src/main.py @@ -68,11 +68,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: ) mask = mask.to_numpy() - dataframe = dataframe[mask] - - dataframe = dataframe.drop(("text", "data"), axis=1) - - return dataframe + return dataframe[mask] if __name__ == "__main__": diff --git a/fondant/component.py b/fondant/component.py index 26c68468a..d10514ebb 100644 --- a/fondant/component.py +++ b/fondant/component.py @@ -173,9 +173,6 @@ def optional_fondant_arguments() -> t.List[str]: return ["input_manifest_path"] def 
_load_or_create_manifest(self) -> Manifest: - # create initial manifest - # TODO ideally get rid of args.metadata by including them in the storage args - component_id = self.spec.name.lower().replace(" ", "_") manifest = Manifest.create( base_path=self.metadata["base_path"], @@ -277,6 +274,15 @@ def wrapped_transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: tuple(column.split("_")) for column in dataframe.columns ) dataframe = self.transform(dataframe) + # Drop columns not in the produces section of the component spec + dataframe = dataframe.drop( + columns=[ + (subset, field) + for (subset, field) in dataframe.columns + if subset not in self.spec.produces + or field not in self.spec.produces[subset].fields + ] + ) dataframe.columns = [ "_".join(column) for column in dataframe.columns.to_flat_index() ] @@ -300,9 +306,7 @@ def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: meta_dict = {"id": pd.Series(dtype="object")} for subset_name, subset in self.spec.produces.items(): for field_name, field in subset.fields.items(): - print(field.type.value) meta_dict[f"{subset_name}_{field_name}"] = pd.Series( - # dtype=f"{field.type.value}[pyarrow]" dtype=pd.ArrowDtype(field.type.value) ) meta_df = pd.DataFrame(meta_dict).set_index("id")