change meta estimation (#409)
This PR fixes an issue with the load from hub component when setting an
index dynamically. The previous implementation relied on estimating the
`meta`/schema dynamically, but this led to errors because the wrong type
was sometimes inferred.

This PR fixes the issue by reading the schema from the component spec
instead.
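
The gist of the change is how Dask's `map_partitions` uses its `meta` argument. Below is a condensed sketch of the idea with made-up column names; the real component builds the meta from the subsets and fields of its spec rather than hard-coding it:

import dask.dataframe as dd
import pandas as pd

# Toy frame standing in for the dataset loaded from the hub (made-up columns).
pdf = pd.DataFrame({"id": ["0_1", "0_2"], "caption": ["a cat", "a dog"]})
dask_df = dd.from_pandas(pdf, npartitions=1)

def passthrough(dataframe: pd.DataFrame) -> pd.DataFrame:
    return dataframe

# Previous approach: let the output schema be estimated from a sample of the data.
estimated = dask_df.map_partitions(passthrough, meta=dask_df.head())

# New approach: declare the output schema explicitly with an empty DataFrame,
# so the dtypes are exactly the intended ones rather than whatever is inferred.
declared_meta = pd.DataFrame(
    {
        "id": pd.Series(dtype="object"),
        "caption": pd.Series(dtype="object"),
    },
)
explicit = dask_df.map_partitions(passthrough, meta=declared_meta)

print(explicit.dtypes)  # dtypes come from the declared meta, not from sampling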
PhilippeMoussalli authored Sep 13, 2023
1 parent 7ad42aa commit a079473
components/load_from_hf_hub/src/main.py: 28 additions, 10 deletions
@@ -2,24 +2,31 @@
import logging
import typing as t

import dask
import dask.dataframe as dd
import pandas as pd
from fondant.component import DaskLoadComponent
from fondant.component_spec import ComponentSpec

logger = logging.getLogger(__name__)

dask.config.set({"dataframe.convert-string": False})


class LoadFromHubComponent(DaskLoadComponent):

-    def __init__(self, *_,
-                 dataset_name: str,
-                 column_name_mapping: dict,
-                 image_column_names: t.Optional[list],
-                 n_rows_to_load: t.Optional[int],
-                 index_column:t.Optional[str],
+    def __init__(self,
+                 spec: ComponentSpec,
+                 *_,
+                 dataset_name: str,
+                 column_name_mapping: dict,
+                 image_column_names: t.Optional[list],
+                 n_rows_to_load: t.Optional[int],
+                 index_column: t.Optional[str],
                  ) -> None:
"""
Args:
spec: the component spec
dataset_name: name of the dataset to load.
column_name_mapping: Mapping of the consumed hub dataset to fondant column names
image_column_names: A list containing the original hub image column names. Used to
@@ -34,6 +41,7 @@ def __init__(self, *_,
        self.image_column_names = image_column_names
        self.n_rows_to_load = n_rows_to_load
        self.index_column = index_column
        self.spec = spec

    def load(self) -> dd.DataFrame:
        # 1) Load data, read as Dask dataframe
@@ -74,14 +82,24 @@ def _set_unique_index(dataframe: pd.DataFrame, partition_info=None):
"""Function that sets a unique index based on the partition and row number."""
dataframe["id"] = 1
dataframe["id"] = (
str(partition_info["number"])
+ "_"
+ (dataframe.id.cumsum()).astype(str)
str(partition_info["number"])
+ "_"
+ (dataframe.id.cumsum()).astype(str)
)
dataframe.index = dataframe.pop("id")
return dataframe

-            dask_df = dask_df.map_partitions(_set_unique_index, meta=dask_df.head())
+            def _get_meta_df() -> pd.DataFrame:
+                meta_dict = {"id": pd.Series(dtype="object")}
+                for subset_name, subset in self.spec.produces.items():
+                    for field_name, field in subset.fields.items():
+                        meta_dict[f"{subset_name}_{field_name}"] = pd.Series(
+                            dtype=pd.ArrowDtype(field.type.value),
+                        )
+                return pd.DataFrame(meta_dict).set_index("id")
+
+            meta = _get_meta_df()
+            dask_df = dask_df.map_partitions(_set_unique_index, meta=meta)
        else:
            logger.info(f"Setting `{self.index_column}` as index")
            dask_df = dask_df.set_index(self.index_column, drop=True)
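
For context on the dtype construction in `_get_meta_df` above: `pd.ArrowDtype` wraps a pyarrow type in a pandas extension dtype, so the declared meta carries the exact Arrow type from the spec (`field.type.value` presumably resolves to such a pyarrow type, since `pd.ArrowDtype` requires one). A tiny standalone illustration, assuming pyarrow is installed:

import pandas as pd
import pyarrow as pa

# An empty series typed with an Arrow-backed dtype, similar to what
# _get_meta_df creates for each produced field.
captions = pd.Series(dtype=pd.ArrowDtype(pa.string()))
print(captions.dtype)  # string[pyarrow]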