Skip to content

Commit

Permalink
fix csv col names and add model_name
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Berenbaum committed Jul 15, 2024
1 parent 4349b6e commit 1f4d798
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 28 deletions.
42 changes: 28 additions & 14 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,7 @@ def parse_tabular(
None, type[Feature], Sequence[str], dict[str, FeatureType]
] = None,
object_name: str = "",
model_name: str = "",
**kwargs,
) -> "DataChain":
"""Generate chain from list of tabular files.
Expand All @@ -714,6 +715,7 @@ def parse_tabular(
corresponding types. List of column names is also accepted, in which
case types will be inferred.
object_name : Generated object column name.
model_name : Generated model name.
kwargs : Parameters to pass to pyarrow.dataset.dataset.
Examples:
Expand All @@ -740,10 +742,14 @@ def parse_tabular(

if object_name:
if isinstance(output, dict):
output = dict_to_feature(object_name, output)
model_name = model_name or object_name
output = dict_to_feature(model_name, output)
output = {object_name: output} # type: ignore[dict-item]
elif isinstance(output, type(Feature)):
output = {output._prefix(): output}
output = {
name: info.annotation # type: ignore[misc]
for name, info in output.model_fields.items()
}
output = {"source": IndexedFile} | output # type: ignore[assignment,operator]
return self.gen(ArrowGenerator(schema, **kwargs), output=output)

Expand All @@ -754,8 +760,11 @@ def from_csv(
delimiter: str = ",",
header: bool = True,
column_names: Optional[list[str]] = None,
output: Optional[dict[str, FeatureType]] = None,
output: Union[
None, type[Feature], Sequence[str], dict[str, FeatureType]
] = None,
object_name: str = "",
model_name: str = "",
**kwargs,
) -> "DataChain":
"""Generate chain from csv files.
Expand All @@ -765,9 +774,11 @@ def from_csv(
as `s3://`, `gs://`, `az://` or "file:///".
delimiter : Character for delimiting columns.
header : Whether the files include a header row.
column_names : Column names if no header. Implies `header = False`.
output : Dictionary defining column names and their corresponding types.
output : Dictionary or feature class defining column names and their
corresponding types. List of column names is also accepted, in which
case types will be inferred.
object_name : Created object column name.
model_name : Generated model name.
Examples:
Reading a csv file:
Expand All @@ -781,22 +792,22 @@ def from_csv(

chain = DataChain.from_storage(path, **kwargs)

if column_names and output:
msg = "error parsing csv - only one of column_names or output is allowed"
raise DatasetPrepareError(chain.name, msg)

if not header and not column_names:
if output:
if not header:
if not output:
msg = "error parsing csv - provide output if no header"
raise DatasetPrepareError(chain.name, msg)
if isinstance(output, Sequence):
column_names = output # type: ignore[assignment]
elif isinstance(output, dict):
column_names = list(output.keys())
else:
msg = "error parsing csv - provide column_names or output if no header"
raise DatasetPrepareError(chain.name, msg)
column_names = list(output.model_fields.keys())

parse_options = ParseOptions(delimiter=delimiter)
read_options = ReadOptions(column_names=column_names)
format = CsvFileFormat(parse_options=parse_options, read_options=read_options)
return chain.parse_tabular(
output=output, object_name=object_name, format=format
output=output, object_name=object_name, model_name=model_name, format=format
)

@classmethod
Expand All @@ -806,6 +817,7 @@ def from_parquet(
partitioning: Any = "hive",
output: Optional[dict[str, FeatureType]] = None,
object_name: str = "",
model_name: str = "",
**kwargs,
) -> "DataChain":
"""Generate chain from parquet files.
Expand All @@ -816,6 +828,7 @@ def from_parquet(
partitioning : Any pyarrow partitioning schema.
output : Dictionary defining column names and their corresponding types.
object_name : Created object column name.
model_name : Generated model name.
Examples:
Reading a single file:
Expand All @@ -828,6 +841,7 @@ def from_parquet(
return chain.parse_tabular(
output=output,
object_name=object_name,
model_name=model_name,
format="parquet",
partitioning=partitioning,
)
Expand Down
32 changes: 18 additions & 14 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,8 +690,8 @@ class Output(Feature):
dc = DataChain.from_storage(path.as_uri()).parse_tabular(
format="json", output=Output
)
df1 = dc.select("output.fname", "output.age", "output.loc").to_pandas()
df.columns = ["output.fname", "output.age", "output.loc"]
df1 = dc.select("fname", "age", "loc").to_pandas()
df.columns = ["fname", "age", "loc"]
assert df1.equals(df)


Expand Down Expand Up @@ -725,7 +725,7 @@ def test_from_csv_no_header_error(tmp_dir, catalog):
DataChain.from_csv(path.as_uri(), header=False)


def test_from_csv_no_header_output(tmp_dir, catalog):
def test_from_csv_no_header_output_dict(tmp_dir, catalog):
df = pd.DataFrame(DF_DATA.values()).transpose()
path = tmp_dir / "test.csv"
df.to_csv(path, header=False, index=False)
Expand All @@ -736,25 +736,29 @@ def test_from_csv_no_header_output(tmp_dir, catalog):
assert (df1.values != df.values).sum() == 0


def test_from_csv_no_header_output_feature(tmp_dir, catalog):
    """Read a header-less CSV when `output` is a Feature class.

    The Feature's field names become the column names, so the columns
    can be selected by those names afterwards.
    """
    # Feature class supplying both column names and types for the CSV.
    class Output(Feature):
        first_name: str
        age: int
        city: str

    # Write a header-less CSV; transpose so each dict value becomes a column.
    df = pd.DataFrame(DF_DATA.values()).transpose()
    path = tmp_dir / "test.csv"
    df.to_csv(path, header=False, index=False)
    dc = DataChain.from_csv(path.as_uri(), header=False, output=Output)
    df1 = dc.select("first_name", "age", "city").to_pandas()
    # Element-wise comparison: zero mismatches means the round-trip is exact.
    assert (df1.values != df.values).sum() == 0


def test_from_csv_no_header_output_list(tmp_dir, catalog):
    """Read a header-less CSV when `output` is a plain list of names.

    Column names are taken from the list and the types are inferred.
    """
    # Write a header-less CSV; transpose so each dict value becomes a column.
    df = pd.DataFrame(DF_DATA.values()).transpose()
    path = tmp_dir / "test.csv"
    df.to_csv(path, header=False, index=False)
    dc = DataChain.from_csv(
        path.as_uri(), header=False, output=["first_name", "age", "city"]
    )
    df1 = dc.select("first_name", "age", "city").to_pandas()
    # Element-wise comparison: zero mismatches means the round-trip is exact.
    assert (df1.values != df.values).sum() == 0


def test_from_csv_tab_delimited(tmp_dir, catalog):
Expand Down

0 comments on commit 1f4d798

Please sign in to comment.