Arrow nrows fix #221

Merged · 6 commits · Aug 5, 2024

Changes from 3 commits
2 changes: 1 addition & 1 deletion examples/get_started/json-csv-reader.py
@@ -90,7 +90,7 @@ def main():
    static_csv_ds.print_schema()
    print(static_csv_ds.to_pandas())

-    uri = "gs://datachain-demo/laion-aesthetics-csv"
+    uri = "gs://datachain-demo/laion-aesthetics-csv/laion_aesthetics_1024_33M_1.csv"
    print()
    print("========================================================================")
    print("dynamic CSV with header schema test parsing 3/3M objects")
29 changes: 22 additions & 7 deletions src/datachain/lib/arrow.py
@@ -1,5 +1,6 @@
import re
from collections.abc import Sequence
+from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Optional

import pyarrow as pa
@@ -43,13 +44,17 @@ def __init__(
        self.kwargs = kwargs

    def process(self, file: File):
-        path = file.get_path()
-        ds = dataset(
-            path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
-        )
+        if self.nrows:
+            path = _nrows_file(file, self.nrows)
+            ds = dataset(path, schema=self.input_schema, **self.kwargs)
+        else:
+            path = file.get_path()
+            ds = dataset(
+                path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs
+            )
        index = 0
        with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar:
-            for record_batch in ds.to_batches(use_threads=False):
Member: why removing `use_threads`?

Contributor Author: `use_threads=False` was a hack to keep `nrows` from hanging forever: limiting rows used to break out of the batch iteration, which hung without it. We no longer interfere with the batch iteration, so it isn't needed.

+            for record_batch in ds.to_batches():
                for record in record_batch.to_pylist():
                    vals = list(record.values())
                    if self.output_schema:
@@ -60,8 +65,6 @@ def process(self, file: File):
                    else:
                        yield vals
                    index += 1
-                    if self.nrows and index >= self.nrows:
-                        return
                pbar.update(len(record_batch))


@@ -125,3 +128,15 @@ def _arrow_type_mapper(col_type: pa.DataType) -> type:  # noqa: PLR0911
    if isinstance(col_type, pa.lib.DictionaryType):
        return _arrow_type_mapper(col_type.value_type)  # type: ignore[return-value]
    raise TypeError(f"{col_type!r} datatypes not supported")


+def _nrows_file(file: File, nrows: int) -> str:
+    tf = NamedTemporaryFile(delete=False)
+    with file.open(mode="r") as reader:
+        with open(tf.name, "a") as writer:
+            for row, line in enumerate(reader):
+                writer.write(line)
+                writer.write("\n")
+                if row >= nrows:
Member: just to double-check: will it support headers?

Contributor Author: Good catch. I forgot to support no-header use cases. Fixed in the latest commit by adding a row if there is a header.

+                    break
+    return tf.name
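Per the review thread above, the final commit (not shown in this three-commit view) makes `_nrows_file` header-aware. A minimal sketch of how that variant might look, where the `header` flag and the helper's name are assumptions rather than the merged code:

```python
from tempfile import NamedTemporaryFile


def _nrows_file_sketch(file, nrows: int, header: bool = True) -> str:
    """Copy the first `nrows` data rows of `file` into a temp file.

    When the source has a header row, one extra line is copied so the
    caller still gets `nrows` rows of data after the header.
    """
    limit = nrows + 1 if header else nrows
    tf = NamedTemporaryFile(mode="w", delete=False)
    with file.open(mode="r") as reader:
        for row, line in enumerate(reader):
            if row >= limit:
                break
            tf.write(line)  # lines read in text mode keep their trailing newline
    tf.close()
    return tf.name
```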
15 changes: 12 additions & 3 deletions src/datachain/lib/dc.py
@@ -1263,8 +1263,20 @@ def parse_tabular(
dc = dc.parse_tabular(format="json")
```
"""
from pyarrow.dataset import CsvFileFormat, JsonFileFormat

from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output

if nrows:
format = kwargs.get("format")
if format not in ["csv", "json"] and not isinstance(
format, (CsvFileFormat, JsonFileFormat)
):
raise ValueError(
Member: minor: raise NotImplementedError?

Contributor Author: Changed to DatasetPrepareError, since that's used as the main type for datachain errors.

"Error in `parse_tabular` - "
"`nrows` only supported for csv and json formats."
)

        schema = None
        col_names = output if isinstance(output, Sequence) else None
        if col_names or not output:
@@ -1384,7 +1396,6 @@ def from_parquet(
        object_name: str = "",
        model_name: str = "",
        source: bool = True,
-        nrows=None,
        **kwargs,
    ) -> "DataChain":
        """Generate chain from parquet files.
@@ -1397,7 +1408,6 @@ def from_parquet(
            object_name : Created object column name.
            model_name : Generated model name.
            source : Whether to include info about the source file.
-            nrows : Optional row limit.

        Example:
            Reading a single file:
@@ -1416,7 +1426,6 @@ def from_parquet(
            object_name=object_name,
            model_name=model_name,
            source=source,
-            nrows=None,
            format="parquet",
            partitioning=partitioning,
        )
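Per the review thread in `parse_tabular`, the final commit swaps `ValueError` for `DatasetPrepareError`. A rough sketch of how the check might read after that change; the stand-in exception class below only marks where datachain's real error type would be imported:

```python
from pyarrow.dataset import CsvFileFormat, JsonFileFormat


class DatasetPrepareError(Exception):
    """Stand-in for datachain's error type referenced in the review."""


def check_nrows_supported(format, nrows) -> None:
    # nrows works by truncating the source file up front, which only makes
    # sense for line-oriented formats such as csv and json.
    if nrows and format not in ["csv", "json"] and not isinstance(
        format, (CsvFileFormat, JsonFileFormat)
    ):
        raise DatasetPrepareError(
            "Error in `parse_tabular` - "
            "`nrows` only supported for csv and json formats."
        )
```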
6 changes: 3 additions & 3 deletions src/datachain/lib/file.py
@@ -317,9 +317,9 @@ class TextFile(File):
"""`DataModel` for reading text files."""

@contextmanager
def open(self):
"""Open the file and return a file object in text mode."""
with super().open(mode="r") as stream:
def open(self, mode: Literal["rb", "r"] = "r"):
"""Open the file and return a file object (default to text mode)."""
with super().open(mode=mode) as stream:
yield stream

    def read_text(self):
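A quick usage sketch for the widened `open` signature, assuming a `TextFile` instance `text_file` already bound to a stream (as the tests below set up via `_set_stream`):

```python
# Default is unchanged: text mode.
with text_file.open() as f:
    content = f.read()  # str

# New in this PR: binary reads through the same method.
with text_file.open(mode="rb") as f:
    raw = f.read()  # bytes
```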
12 changes: 6 additions & 6 deletions tests/unit/lib/test_arrow.py
@@ -10,7 +10,7 @@
    _arrow_type_mapper,
    schema_to_output,
)
-from datachain.lib.file import File, IndexedFile
+from datachain.lib.file import File, IndexedFile, TextFile


def test_arrow_generator(tmp_path, catalog):
@@ -41,13 +41,13 @@ def test_arrow_generator_nrows(tmp_path, catalog):
    texts = ["28", "22", "we", "hello world"]
    df = pd.DataFrame({"id": ids, "text": texts})

-    name = "111.parquet"
-    pq_path = tmp_path / name
-    df.to_parquet(pq_path)
-    stream = File(name=name, parent=tmp_path.as_posix(), source="file:///")
+    name = "111.csv"
+    csv_path = tmp_path / name
+    df.to_csv(csv_path)
+    stream = TextFile(name=name, parent=tmp_path.as_posix(), source="file:///")
    stream._set_stream(catalog, caching_enabled=False)

-    func = ArrowGenerator(nrows=2)
+    func = ArrowGenerator(nrows=2, format="csv")
    objs = list(func.process(stream))

    assert len(objs) == 2
17 changes: 17 additions & 0 deletions tests/unit/lib/test_datachain.py
@@ -825,6 +825,14 @@ def test_parse_tabular_output_list(tmp_dir, catalog):
    assert df1.equals(df)


+def test_parse_tabular_nrows_invalid(tmp_dir, catalog):
+    df = pd.DataFrame(DF_DATA)
+    path = tmp_dir / "test.parquet"
+    df.to_parquet(path)
+    with pytest.raises(ValueError):
+        DataChain.from_storage(path.as_uri()).parse_tabular(nrows=2)


def test_from_csv(tmp_dir, catalog):
    df = pd.DataFrame(DF_DATA)
    path = tmp_dir / "test.csv"
@@ -901,6 +909,15 @@ def test_from_csv_null_collect(tmp_dir, catalog):
        assert row[1].height == height[i]


+def test_from_csv_nrows(tmp_dir, catalog):
+    df = pd.DataFrame(DF_DATA)
+    path = tmp_dir / "test.csv"
+    df.to_csv(path, index=False)
+    dc = DataChain.from_csv(path.as_uri(), nrows=2)
+    df1 = dc.select("first_name", "age", "city").to_pandas()
+    assert df1.equals(df[:2])


def test_from_parquet(tmp_dir, catalog):
    df = pd.DataFrame(DF_DATA)
    path = tmp_dir / "test.parquet"
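Taken together, a hedged sketch of the post-PR behavior (paths are illustrative): `nrows` now truncates CSV/JSON sources into a temp file before pyarrow ever sees them, and unsupported formats fail fast:

```python
from datachain.lib.dc import DataChain

# CSV: nrows limits the chain to the first two data rows.
dc = DataChain.from_csv("file:///tmp/test.csv", nrows=2)
print(dc.select("first_name", "age", "city").to_pandas())

# Parquet: nrows is rejected up front (ValueError in this commit; the
# review thread says the final commit raises DatasetPrepareError).
DataChain.from_storage("file:///tmp/test.parquet").parse_tabular(nrows=2)
```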