diff --git a/components/download_images/src/main.py b/components/download_images/src/main.py index 307ca052b..36c92a546 100644 --- a/components/download_images/src/main.py +++ b/components/download_images/src/main.py @@ -126,7 +126,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: result_type="expand", ) - return dataframe[[("images", "data"), ("images", "width"), ("images", "height")]] + return dataframe if __name__ == "__main__": diff --git a/examples/pipelines/datacomp/components/filter_text_complexity/src/main.py b/examples/pipelines/datacomp/components/filter_text_complexity/src/main.py index 865597d0b..2e7f6616a 100644 --- a/examples/pipelines/datacomp/components/filter_text_complexity/src/main.py +++ b/examples/pipelines/datacomp/components/filter_text_complexity/src/main.py @@ -68,11 +68,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: ) mask = mask.to_numpy() - dataframe = dataframe[mask] - - dataframe = dataframe.drop(("text", "data"), axis=1) - - return dataframe + return dataframe[mask] if __name__ == "__main__": diff --git a/fondant/component.py b/fondant/component.py index 26c68468a..d10514ebb 100644 --- a/fondant/component.py +++ b/fondant/component.py @@ -173,9 +173,6 @@ def optional_fondant_arguments() -> t.List[str]: return ["input_manifest_path"] def _load_or_create_manifest(self) -> Manifest: - # create initial manifest - # TODO ideally get rid of args.metadata by including them in the storage args - component_id = self.spec.name.lower().replace(" ", "_") manifest = Manifest.create( base_path=self.metadata["base_path"], @@ -277,6 +274,15 @@ def wrapped_transform(self, dataframe: pd.DataFrame) -> pd.DataFrame: tuple(column.split("_")) for column in dataframe.columns ) dataframe = self.transform(dataframe) + # Drop columns not in the produces section of the component spec + dataframe.drop( + columns=[ + (subset, field) + for (subset, field) in dataframe.columns + if subset not in self.spec.produces + or field not in self.spec.produces[subset].fields + ] + ) dataframe.columns = [ "_".join(column) for column in dataframe.columns.to_flat_index() ] @@ -300,9 +306,7 @@ def _process_dataset(self, manifest: Manifest) -> dd.DataFrame: meta_dict = {"id": pd.Series(dtype="object")} for subset_name, subset in self.spec.produces.items(): for field_name, field in subset.fields.items(): - print(field.type.value) meta_dict[f"{subset_name}_{field_name}"] = pd.Series( - # dtype=f"{field.type.value}[pyarrow]" dtype=pd.ArrowDtype(field.type.value) ) meta_df = pd.DataFrame(meta_dict).set_index("id")