From 1f4d798f3f837a3f168cc1cfe4c505cc511f23df Mon Sep 17 00:00:00 2001 From: Dave Berenbaum Date: Mon, 15 Jul 2024 12:18:25 -0400 Subject: [PATCH] fix csv col names and add model_name --- src/datachain/lib/dc.py | 42 +++++++++++++++++++++----------- tests/unit/lib/test_datachain.py | 32 +++++++++++++----------- 2 files changed, 46 insertions(+), 28 deletions(-) diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 56f2bd4e7..a49a6965f 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -705,6 +705,7 @@ def parse_tabular( None, type[Feature], Sequence[str], dict[str, FeatureType] ] = None, object_name: str = "", + model_name: str = "", **kwargs, ) -> "DataChain": """Generate chain from list of tabular files. @@ -714,6 +715,7 @@ def parse_tabular( corresponding types. List of column names is also accepted, in which case types will be inferred. object_name : Generated object column name. + model_name : Generated model name. kwargs : Parameters to pass to pyarrow.dataset.dataset. Examples: @@ -740,10 +742,14 @@ def parse_tabular( if object_name: if isinstance(output, dict): - output = dict_to_feature(object_name, output) + model_name = model_name or object_name + output = dict_to_feature(model_name, output) output = {object_name: output} # type: ignore[dict-item] elif isinstance(output, type(Feature)): - output = {output._prefix(): output} + output = { + name: info.annotation # type: ignore[misc] + for name, info in output.model_fields.items() + } output = {"source": IndexedFile} | output # type: ignore[assignment,operator] return self.gen(ArrowGenerator(schema, **kwargs), output=output) @@ -754,8 +760,11 @@ def from_csv( delimiter: str = ",", header: bool = True, column_names: Optional[list[str]] = None, - output: Optional[dict[str, FeatureType]] = None, + output: Union[ + None, type[Feature], Sequence[str], dict[str, FeatureType] + ] = None, object_name: str = "", + model_name: str = "", **kwargs, ) -> "DataChain": """Generate chain from csv files. @@ -765,9 +774,11 @@ def from_csv( as `s3://`, `gs://`, `az://` or "file:///". delimiter : Character for delimiting columns. header : Whether the files include a header row. - column_names : Column names if no header. Implies `header = False`. - output : Dictionary defining column names and their corresponding types. + output : Dictionary or feature class defining column names and their + corresponding types. List of column names is also accepted, in which + case types will be inferred. object_name : Created object column name. + model_name : Generated model name. Examples: Reading a csv file: @@ -781,22 +792,22 @@ def from_csv( chain = DataChain.from_storage(path, **kwargs) - if column_names and output: - msg = "error parsing csv - only one of column_names or output is allowed" - raise DatasetPrepareError(chain.name, msg) - - if not header and not column_names: - if output: + if not header: + if not output: + msg = "error parsing csv - provide output if no header" + raise DatasetPrepareError(chain.name, msg) + if isinstance(output, Sequence): + column_names = output # type: ignore[assignment] + elif isinstance(output, dict): column_names = list(output.keys()) else: - msg = "error parsing csv - provide column_names or output if no header" - raise DatasetPrepareError(chain.name, msg) + column_names = list(output.model_fields.keys()) parse_options = ParseOptions(delimiter=delimiter) read_options = ReadOptions(column_names=column_names) format = CsvFileFormat(parse_options=parse_options, read_options=read_options) return chain.parse_tabular( - output=output, object_name=object_name, format=format + output=output, object_name=object_name, model_name=model_name, format=format ) @classmethod @@ -806,6 +817,7 @@ def from_parquet( partitioning: Any = "hive", output: Optional[dict[str, FeatureType]] = None, object_name: str = "", + model_name: str = "", **kwargs, ) -> "DataChain": """Generate chain from parquet files. @@ -816,6 +828,7 @@ def from_parquet( partitioning : Any pyarrow partitioning schema. output : Dictionary defining column names and their corresponding types. object_name : Created object column name. + model_name : Generated model name. Examples: Reading a single file: @@ -828,6 +841,7 @@ def from_parquet( return chain.parse_tabular( output=output, object_name=object_name, + model_name=model_name, format="parquet", partitioning=partitioning, ) diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 524208e65..57c1459e4 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -690,8 +690,8 @@ class Output(Feature): dc = DataChain.from_storage(path.as_uri()).parse_tabular( format="json", output=Output ) - df1 = dc.select("output.fname", "output.age", "output.loc").to_pandas() - df.columns = ["output.fname", "output.age", "output.loc"] + df1 = dc.select("fname", "age", "loc").to_pandas() + df.columns = ["fname", "age", "loc"] assert df1.equals(df) @@ -725,7 +725,7 @@ def test_from_csv_no_header_error(tmp_dir, catalog): DataChain.from_csv(path.as_uri(), header=False) -def test_from_csv_no_header_output(tmp_dir, catalog): +def test_from_csv_no_header_output_dict(tmp_dir, catalog): df = pd.DataFrame(DF_DATA.values()).transpose() path = tmp_dir / "test.csv" df.to_csv(path, header=False, index=False) @@ -736,25 +736,29 @@ def test_from_csv_no_header_output(tmp_dir, catalog): assert (df1.values != df.values).sum() == 0 -def test_from_csv_no_header_column_names(tmp_dir, catalog): +def test_from_csv_no_header_output_feature(tmp_dir, catalog): + class Output(Feature): + first_name: str + age: int + city: str + df = pd.DataFrame(DF_DATA.values()).transpose() path = tmp_dir / "test.csv" df.to_csv(path, header=False, index=False) - dc = DataChain.from_csv( - path.as_uri(), header=False, column_names=["first_name", "age", "city"] - ) + dc = DataChain.from_csv(path.as_uri(), header=False, output=Output) df1 = dc.select("first_name", "age", "city").to_pandas() assert (df1.values != df.values).sum() == 0 -def test_from_csv_column_names_and_output(tmp_dir, catalog): - df = pd.DataFrame(DF_DATA) +def test_from_csv_no_header_output_list(tmp_dir, catalog): + df = pd.DataFrame(DF_DATA.values()).transpose() path = tmp_dir / "test.csv" - df.to_csv(path) - column_names = ["fname", "age", "loc"] - output = {"fname": str, "age": int, "loc": str} - with pytest.raises(DataChainParamsError): - DataChain.from_csv(path.as_uri(), column_names=column_names, output=output) + df.to_csv(path, header=False, index=False) + dc = DataChain.from_csv( + path.as_uri(), header=False, output=["first_name", "age", "city"] + ) + df1 = dc.select("first_name", "age", "city").to_pandas() + assert (df1.values != df.values).sum() == 0 def test_from_csv_tab_delimited(tmp_dir, catalog):