-
Notifications
You must be signed in to change notification settings - Fork 97
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Arrow nrows fix #221
Arrow nrows fix #221
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import re | ||
from collections.abc import Sequence | ||
from tempfile import NamedTemporaryFile | ||
from typing import TYPE_CHECKING, Optional | ||
|
||
import pyarrow as pa | ||
|
@@ -43,13 +44,17 @@ def __init__( | |
self.kwargs = kwargs | ||
|
||
def process(self, file: File): | ||
path = file.get_path() | ||
ds = dataset( | ||
path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs | ||
) | ||
if self.nrows: | ||
path = _nrows_file(file, self.nrows) | ||
ds = dataset(path, schema=self.input_schema, **self.kwargs) | ||
else: | ||
path = file.get_path() | ||
ds = dataset( | ||
path, filesystem=file.get_fs(), schema=self.input_schema, **self.kwargs | ||
) | ||
index = 0 | ||
with tqdm(desc="Parsed by pyarrow", unit=" rows") as pbar: | ||
for record_batch in ds.to_batches(use_threads=False): | ||
for record_batch in ds.to_batches(): | ||
for record in record_batch.to_pylist(): | ||
vals = list(record.values()) | ||
if self.output_schema: | ||
|
@@ -60,8 +65,6 @@ def process(self, file: File): | |
else: | ||
yield vals | ||
index += 1 | ||
if self.nrows and index >= self.nrows: | ||
return | ||
pbar.update(len(record_batch)) | ||
|
||
|
||
|
@@ -125,3 +128,15 @@ def _arrow_type_mapper(col_type: pa.DataType) -> type: # noqa: PLR0911 | |
if isinstance(col_type, pa.lib.DictionaryType): | ||
return _arrow_type_mapper(col_type.value_type) # type: ignore[return-value] | ||
raise TypeError(f"{col_type!r} datatypes not supported") | ||
|
||
|
||
def _nrows_file(file: File, nrows: int) -> str: | ||
tf = NamedTemporaryFile(delete=False) | ||
with file.open(mode="r") as reader: | ||
with open(tf.name, "a") as writer: | ||
for row, line in enumerate(reader): | ||
writer.write(line) | ||
writer.write("\n") | ||
if row >= nrows: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just to double check: will it support headers? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good catch. I forgot to support no-header use cases. Fixed in the latest commit by adding a row if there is a header. |
||
break | ||
return tf.name |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1263,8 +1263,20 @@ def parse_tabular( | |
dc = dc.parse_tabular(format="json") | ||
``` | ||
""" | ||
from pyarrow.dataset import CsvFileFormat, JsonFileFormat | ||
|
||
from datachain.lib.arrow import ArrowGenerator, infer_schema, schema_to_output | ||
|
||
if nrows: | ||
format = kwargs.get("format") | ||
if format not in ["csv", "json"] and not isinstance( | ||
format, (CsvFileFormat, JsonFileFormat) | ||
): | ||
raise ValueError( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Minor: raise NotImplemented? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Changed to |
||
"Error in `parse_tabular` - " | ||
"`nrows` only supported for csv and json formats." | ||
) | ||
|
||
schema = None | ||
col_names = output if isinstance(output, Sequence) else None | ||
if col_names or not output: | ||
|
@@ -1384,7 +1396,6 @@ def from_parquet( | |
object_name: str = "", | ||
model_name: str = "", | ||
source: bool = True, | ||
nrows=None, | ||
**kwargs, | ||
) -> "DataChain": | ||
"""Generate chain from parquet files. | ||
|
@@ -1397,7 +1408,6 @@ def from_parquet( | |
object_name : Created object column name. | ||
model_name : Generated model name. | ||
source : Whether to include info about the source file. | ||
nrows : Optional row limit. | ||
|
||
Example: | ||
Reading a single file: | ||
|
@@ -1416,7 +1426,6 @@ def from_parquet( | |
object_name=object_name, | ||
model_name=model_name, | ||
source=source, | ||
nrows=None, | ||
format="parquet", | ||
partitioning=partitioning, | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Why are we removing `use_threads`?
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
`use_threads=False` was a hack to make `nrows` not hang forever. Previously, `nrows` broke the batch iteration and hung forever without `use_threads=False`. We are no longer messing with the batch iteration, so this isn't needed.