Support object_name in all from_ formats (#14)
* support object_name in all from_ formats

* refactor arrow infer schema

* use session instead of catalog in datachain

* fix tests

* Revert "fix tests"

This reverts commit 4ca848f.

* hide anon arg in datachain

* drop optional from object_name

* accept list of col names

* fix csv col names and add model_name
Dave Berenbaum authored Jul 15, 2024
1 parent 44cffc5 commit 3aa4031
Showing 11 changed files with 397 additions and 284 deletions.
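As context for the diffs below, a minimal sketch of the new calling convention — hedged: the CSV URI is a placeholder, ChatFeature stands in for the user-defined model from examples/json-csv-reader.py, and only the from_csv/from_parquet calls that appear in the diffs are confirmed:

from datachain.lib.dc import DataChain

# object_name (new in this commit) presumably names the object that parsed
# columns are nested under, so the schema exposes them as e.g. chat.* fields.
chats = DataChain.from_csv(
    "gs://example-bucket/chats.csv",  # placeholder URI
    output=ChatFeature,  # user-defined model, as in the example file
    object_name="chat",
)
chats.print_schema()

# Parquet sources are now read directly via from_parquet, globs included.
meta_pq = DataChain.from_parquet("gs://dvcx-datacomp-small/metadata/0020f*.parquet")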
6 changes: 4 additions & 2 deletions examples/json-csv-reader.py
@@ -103,15 +103,17 @@ def main():
print("========================================================================")
print("static CSV with header schema test parsing 3.5K objects")
print("========================================================================")
static_csv_ds = DataChain.from_csv(uri, spec=ChatFeature)
static_csv_ds = DataChain.from_csv(uri, output=ChatFeature, object_name="chat")
static_csv_ds.print_schema()
print(static_csv_ds.to_pandas())

uri = "gs://datachain-demo/laion-aesthetics-csv"
print()
print("========================================================================")
print("dynamic CSV with header schema test parsing 3M objects")
print("========================================================================")
dynamic_csv_ds = DataChain.from_csv(uri, object_name="laion", show_schema=True)
dynamic_csv_ds = DataChain.from_csv(uri, object_name="laion")
dynamic_csv_ds.print_schema()
print(dynamic_csv_ds.to_pandas())


225 changes: 114 additions & 111 deletions examples/multimodal/clip_fine_tuning.ipynb

Large diffs are not rendered by default.

9 changes: 3 additions & 6 deletions examples/wds.py
@@ -18,12 +18,9 @@
     .map(stem=lambda file: file.get_file_stem(), params=["emd.file"], output=str)
 )

-meta_pq = (
-    DataChain.from_storage("gs://dvcx-datacomp-small/metadata")
-    .filter(C.name.glob("0020f*.parquet"))
-    .parse_parquet()
-    .map(stem=lambda file: file.get_file_stem(), params=["source.file"], output=str)
-)
+meta_pq = DataChain.from_parquet(
+    "gs://dvcx-datacomp-small/metadata/0020f*.parquet"
+).map(stem=lambda file: file.get_file_stem(), params=["source.file"], output=str)

 meta = meta_emd.merge(
     meta_pq, on=["stem", "emd.index"], right_on=["stem", "source.index"]
33 changes: 25 additions & 8 deletions src/datachain/lib/arrow.py
@@ -1,13 +1,15 @@
 import re
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Optional

+import pyarrow as pa
+from pyarrow.dataset import dataset
+
 from datachain.lib.file import File, IndexedFile
 from datachain.lib.udf import Generator

 if TYPE_CHECKING:
-    import pyarrow as pa
     from datachain.lib.dc import DataChain


 class ArrowGenerator(Generator):
@@ -35,12 +37,29 @@ def process(self, file: File):
                 index += 1


-def schema_to_output(schema: "pa.Schema"):
+def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
+    schemas = []
+    for file in chain.iterate_one("file"):
+        ds = dataset(file.get_path(), filesystem=file.get_fs(), **kwargs)  # type: ignore[union-attr]
+        schemas.append(ds.schema)
+    return pa.unify_schemas(schemas)
+
+
+def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None):
     """Generate UDF output schema from pyarrow schema."""
+    if col_names and (len(schema) != len(col_names)):
+        raise ValueError(
+            "Error generating output from Arrow schema - "
+            f"Schema has {len(schema)} columns but got {len(col_names)} column names."
+        )
     default_column = 0
-    output = {"source": IndexedFile}
-    for field in schema:
-        column = field.name.lower()
+    output = {}
+    for i, field in enumerate(schema):
+        if col_names:
+            column = col_names[i]
+        else:
+            column = field.name
+        column = column.lower()
         column = re.sub("[^0-9a-z_]+", "", column)
         if not column:
             column = f"c{default_column}"
@@ -50,12 +69,10 @@ def schema_to_output(schema: "pa.Schema"):
     return output


-def _arrow_type_mapper(col_type: "pa.DataType") -> type:  # noqa: PLR0911
+def _arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
     """Convert pyarrow types to basic types."""
     from datetime import datetime

-    import pyarrow as pa
-
     if pa.types.is_timestamp(col_type):
         return datetime
     if pa.types.is_binary(col_type):
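
For concreteness, a small illustration of the new col_names handling in schema_to_output — my own sketch, not code from this commit; it relies only on the logic visible in the hunk above:

import pyarrow as pa

from datachain.lib.arrow import schema_to_output

schema = pa.schema([("User ID!", pa.int64()), ("Score", pa.float64())])

# Field names are lowercased and stripped of characters outside [0-9a-z_],
# so "User ID!" becomes "userid".
print(schema_to_output(schema))

# col_names overrides the field names (the same sanitization applies);
# passing a different number of names than schema columns raises ValueError.
print(schema_to_output(schema, col_names=["user_id", "score"]))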