Fix array cast/embed with null values #6283

Merged · 26 commits · Feb 6, 2024
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -64,7 +64,7 @@ jobs:
run: pip install --upgrade pyarrow huggingface-hub dill
- name: Install dependencies (minimum versions)
if: ${{ matrix.deps_versions != 'deps-latest' }}
- run: pip install pyarrow==8.0.0 huggingface-hub==0.19.4 transformers dill==0.3.1.1
+ run: pip install pyarrow==12.0.0 huggingface-hub==0.19.4 transformers dill==0.3.1.1
- name: Test with pytest
run: |
python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/
8 changes: 4 additions & 4 deletions setup.py
@@ -113,8 +113,8 @@
# We use numpy>=1.17 to have np.random.Generator (Dataset shuffling)
"numpy>=1.17",
# Backend and serialization.
- # Minimum 8.0.0 to be able to use .to_reader()
- "pyarrow>=8.0.0",
+ # Minimum 12.0.0 to be able to concatenate extension arrays
+ "pyarrow>=12.0.0",
# As long as we allow pyarrow < 14.0.1, to fix vulnerability CVE-2023-47248
"pyarrow-hotfix",
# For smart caching dataset processing
@@ -166,7 +166,7 @@
"pytest-datadir",
"pytest-xdist",
# optional dependencies
"apache-beam>=2.26.0,<2.44.0;python_version<'3.10'", # doesn't support recent dill versions for recent python versions
"apache-beam>=2.26.0; sys_platform != 'win32' and python_version<'3.10'", # doesn't support recent dill versions for recent python versions and on windows requires pyarrow<12.0.0
"elasticsearch<8.0.0", # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch()
"faiss-cpu>=1.6.4",
"jax>=0.3.14; sys_platform != 'win32'",
@@ -233,7 +233,7 @@
EXTRAS_REQUIRE = {
"audio": AUDIO_REQUIRE,
"vision": VISION_REQUIRE,
"apache-beam": ["apache-beam>=2.26.0,<2.44.0"],
"apache-beam": ["apache-beam>=2.26.0"],
"tensorflow": [
"tensorflow>=2.2.0,!=2.6.0,!=2.6.1; sys_platform != 'darwin' or platform_machine != 'arm64'",
"tensorflow-macos; sys_platform == 'darwin' and platform_machine == 'arm64'",
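Note: the `pyarrow>=12.0.0` bump above follows the comment in the diff — pyarrow only gained support for concatenating extension arrays in 12.0.0, which `datasets` relies on for its extension array types. A minimal sketch of what that support enables (the `UuidType` extension type below is purely illustrative, not part of `datasets`):

```python
import pyarrow as pa


class UuidType(pa.ExtensionType):
    """Toy extension type, used only to illustrate extension-array concatenation."""

    def __init__(self):
        super().__init__(pa.binary(16), "example.uuid")

    def __arrow_ext_serialize__(self):
        return b""

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        return cls()


pa.register_extension_type(UuidType())

storage = pa.array([b"0" * 16, None], type=pa.binary(16))
ext_array = pa.ExtensionArray.from_storage(UuidType(), storage)

# Works on pyarrow>=12.0.0; older releases cannot concatenate extension arrays.
combined = pa.concat_arrays([ext_array, ext_array])
print(len(combined), combined.type)
```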
36 changes: 16 additions & 20 deletions src/datasets/arrow_writer.py
@@ -39,7 +39,7 @@
from .filesystems import is_remote_filesystem
from .info import DatasetInfo
from .keyhash import DuplicatedKeysError, KeyHasher
- from .table import array_cast, array_concat, cast_array_to_feature, embed_table_storage, table_cast
+ from .table import array_cast, cast_array_to_feature, embed_table_storage, table_cast
from .utils import logging
from .utils import tqdm as hf_tqdm
from .utils.file_utils import hash_url_to_filename
@@ -441,7 +441,12 @@ def write_examples_on_file(self):
# This can happen in `.map()` when we want to re-write the same Arrow data
if all(isinstance(row[0][col], (pa.Array, pa.ChunkedArray)) for row in self.current_examples):
arrays = [row[0][col] for row in self.current_examples]
- batch_examples[col] = array_concat(arrays)
+ arrays = [
+     chunk
+     for array in arrays
+     for chunk in (array.chunks if isinstance(array, pa.ChunkedArray) else [array])
+ ]
+ batch_examples[col] = pa.concat_arrays(arrays)
else:
batch_examples[col] = [
row[0][col].to_pylist()[0] if isinstance(row[0][col], (pa.Array, pa.ChunkedArray)) else row[0][col]
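The hunk above drops the custom `array_concat` helper in favor of plain `pa.concat_arrays`, after flattening any `ChunkedArray` inputs into their chunks (since `pa.concat_arrays` only accepts `Array` objects). Roughly the same pattern in isolation, with made-up sample data:

```python
import pyarrow as pa

# A mix of plain Arrays and ChunkedArrays, as columns re-written by `.map()` can be either
arrays = [
    pa.array([1, 2, None]),
    pa.chunked_array([[4, 5], [None, 7]]),
]

# Flatten ChunkedArrays into their chunks so every element is a pa.Array
flat = [
    chunk
    for array in arrays
    for chunk in (array.chunks if isinstance(array, pa.ChunkedArray) else [array])
]

print(pa.concat_arrays(flat))  # one Int64Array: 1, 2, null, 4, 5, null, 7
```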
@@ -669,33 +674,23 @@ def finalize(self, metrics_query_result: dict):
metrics_query_result: `dict` obtained from pipeline_results.metrics().query(m_filter). Make sure
that the filter keeps only the metrics for the considered split, under the namespace `split_name`.
"""
- import apache_beam as beam

- from .utils import beam_utils

# Beam FileSystems require the system's path separator in the older versions
fs, _, [parquet_path] = fsspec.get_fs_token_paths(self._parquet_path)
parquet_path = str(Path(parquet_path)) if not is_remote_filesystem(fs) else fs.unstrip_protocol(parquet_path)

- shards_metadata = list(beam.io.filesystems.FileSystems.match([parquet_path + "*.parquet"])[0].metadata_list)
- shards = [metadata.path for metadata in shards_metadata]
- num_bytes = sum([metadata.size_in_bytes for metadata in shards_metadata])
+ shards = fs.glob(parquet_path + "*.parquet")
+ num_bytes = sum(fs.sizes(shards))
shard_lengths = get_parquet_lengths(shards)

# Convert to arrow
if self._path.endswith(".arrow"):
logger.info(f"Converting parquet files {self._parquet_path} to arrow {self._path}")
- shards = [
-     metadata.path
-     for metadata in beam.io.filesystems.FileSystems.match([parquet_path + "*.parquet"])[0].metadata_list
- ]
try: # stream conversion
num_bytes = 0
for shard in hf_tqdm(shards, unit="shards"):
- with beam.io.filesystems.FileSystems.open(shard) as source:
-     with beam.io.filesystems.FileSystems.create(
-         shard.replace(".parquet", ".arrow")
-     ) as destination:
+ with fs.open(shard, "rb") as source:
+     with fs.open(shard.replace(".parquet", ".arrow"), "wb") as destination:
shard_num_bytes, _ = parquet_to_arrow(source, destination)
num_bytes += shard_num_bytes
except OSError as e: # broken pipe can happen if the connection is unstable, do local conversion instead
@@ -709,12 +704,12 @@
num_bytes = 0
for shard in hf_tqdm(shards, unit="shards"):
local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet")
- beam_utils.download_remote_to_local(shard, local_parquet_path)
+ fs.download(shard, local_parquet_path)
local_arrow_path = local_parquet_path.replace(".parquet", ".arrow")
shard_num_bytes, _ = parquet_to_arrow(local_parquet_path, local_arrow_path)
num_bytes += shard_num_bytes
remote_arrow_path = shard.replace(".parquet", ".arrow")
- beam_utils.upload_local_to_remote(local_arrow_path, remote_arrow_path)
+ fs.upload(local_arrow_path, remote_arrow_path)

# Save metrics
counters_dict = {metric.key.metric.name: metric.result for metric in metrics_query_result["counters"]}
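The two `finalize` hunks above replace Apache Beam's `FileSystems` helpers with the `fsspec` filesystem already resolved via `fsspec.get_fs_token_paths`, so shard listing, size accounting, streaming reads/writes, and up/downloads all go through one API. A rough sketch of the same calls (the local path is made up for illustration, and the raw byte copy stands in for the real `parquet_to_arrow` conversion):

```python
import fsspec

# Resolve a filesystem and base path the same way finalize() does
fs, _, [parquet_path] = fsspec.get_fs_token_paths("/tmp/beam_output/shard")

# Shard discovery and size accounting, previously done via beam.io.filesystems.FileSystems.match
shards = fs.glob(parquet_path + "*.parquet")
num_bytes = sum(fs.sizes(shards))

for shard in shards:
    # Streaming source/destination pairs, mirroring the fs.open(...) calls in the diff
    with fs.open(shard, "rb") as source, fs.open(shard.replace(".parquet", ".arrow"), "wb") as destination:
        destination.write(source.read())

# Remote <-> local transfers use the generic fsspec helpers instead of beam_utils:
# fs.download(remote_path, local_path)
# fs.upload(local_path, remote_path)
```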
@@ -735,8 +730,9 @@ def get_parquet_lengths(sources) -> List[int]:
def parquet_to_arrow(source, destination) -> List[int]:
"""Convert parquet file to arrow file. Inputs can be str paths or file-like objects"""
stream = None if isinstance(destination, str) else destination
- with ArrowWriter(path=destination, stream=stream) as writer:
-     parquet_file = pa.parquet.ParquetFile(source)
+ parquet_file = pa.parquet.ParquetFile(source)
+ # Beam can create empty Parquet files, so we need to pass the source Parquet file's schema
+ with ArrowWriter(schema=parquet_file.schema_arrow, path=destination, stream=stream) as writer:
for record_batch in parquet_file.iter_batches():
pa_table = pa.Table.from_batches([record_batch])
writer.write_table(pa_table)
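The `parquet_to_arrow` change reads the Parquet schema up front and passes it to `ArrowWriter`, so an empty shard (which Beam can produce) still yields a valid Arrow file with the correct schema. The same idea shown with plain pyarrow IPC, outside `ArrowWriter` (file names are illustrative):

```python
import pyarrow as pa
import pyarrow.parquet as pq


def parquet_to_arrow_sketch(source: str, destination: str) -> None:
    """Convert a (possibly empty) Parquet file to an Arrow IPC stream file."""
    parquet_file = pq.ParquetFile(source)
    # Opening the writer with the Parquet schema guarantees a valid, typed output
    # even when iter_batches() yields nothing (e.g. an empty shard written by Beam).
    with pa.OSFile(destination, "wb") as sink:
        with pa.ipc.new_stream(sink, parquet_file.schema_arrow) as writer:
            for record_batch in parquet_file.iter_batches():
                writer.write_batch(record_batch)


# Example call, assuming the shard exists locally:
# parquet_to_arrow_sketch("shard-00000.parquet", "shard-00000.arrow")
```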