Fix datasets export to JSON #7181

Closed
298 changes: 290 additions & 8 deletions src/datasets/io/json.py
@@ -1,3 +1,4 @@
import json
import multiprocessing
import os
from typing import BinaryIO, Optional, Union
@@ -119,46 +120,305 @@ def write(self) -> int:
return written

def _batch_json(self, args):
offset, orient, lines, column_names, to_json_kwargs = args

batch = query_table(
table=self.dataset.data,
key=slice(offset, offset + self.batch_size),
indices=self.dataset._indices,
)

if orient in ["columns"]:
batch = batch.select([column_names])

df = batch.to_pandas()

if orient in ["index", "columns"]:
# Adjust the index to reflect the batch offset
df.index = range(offset, offset + len(df))

if orient in ["split"]:
# Index has already been taken care of
# to_json_kwargs["index"] = False
json_str = df.to_json(path_or_buf=None, orient="values", lines=lines, **to_json_kwargs)
else:
json_str = df.to_json(path_or_buf=None, orient=orient, lines=lines, **to_json_kwargs)

if orient in ["columns"]:
# Read the dict pointed by column_names key
json_str = json.dumps(json.loads(json_str)[column_names], ensure_ascii=False, separators=(",", ":"))

if not json_str.endswith("\n"):
json_str += "\n"

return json_str.encode(self.encoding)
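
For reference, the batch shapes `_batch_json` produces come straight from pandas; a minimal standalone sketch (toy frame, outputs shown in comments) of why the writers below can strip one outer bracket pair per batch and re-join the pieces with commas:

import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
print(df.to_json(orient="records"))  # [{"a":1,"b":"x"},{"a":2,"b":"y"}]
print(df.to_json(orient="values"))   # [[1,"x"],[2,"y"]]
print(df.to_json(orient="index"))    # {"0":{"a":1,"b":"x"},"1":{"a":2,"b":"y"}}
print(df.to_json(orient="columns"))  # {"a":{"0":1,"1":2},"b":{"0":"x","1":"y"}}
print(df.to_json(orient="split"))    # {"columns":["a","b"],"index":[0,1],"data":[[1,"x"],[2,"y"]]}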

def _write_orient_split(
self,
file_obj: BinaryIO,
orient,
lines,
**to_json_kwargs,
):
"""Writes the dataset in 'split' orientation to the specified file object."""

written = 0
column_names = self.dataset.column_names
written += file_obj.write(
f'{{"columns":{json.dumps(column_names, separators=(",", ":"))},'.encode(self.encoding)
)

# Pop "index" so it is not forwarded to pandas, which rejects an explicit index for orient="values"
if to_json_kwargs.pop("index", False):
written += file_obj.write(
f'"index":{json.dumps(list(range(self.dataset.num_rows)), separators=(",", ":"))},'.encode(
self.encoding
)
)

written += file_obj.write('"data":['.encode(self.encoding))
first_batch = True

if self.num_proc is None or self.num_proc == 1:
for offset in hf_tqdm(
range(0, len(self.dataset), self.batch_size), unit="ba", desc=f"Writing column {column_names}"
):
json_str = self._batch_json((offset, orient, lines, column_names, to_json_kwargs))
json_str = json_str.decode(self.encoding).strip()

# Remove the square brackets from the batch string
if json_str.startswith("[") and json_str.endswith("]"):
json_str = json_str[1:-1]

if not first_batch:
written += file_obj.write(",".encode(self.encoding))
else:
first_batch = False

json_str = json_str.encode(self.encoding)
written += file_obj.write(json_str)

else:
num_rows, batch_size = len(self.dataset), self.batch_size
with multiprocessing.Pool(self.num_proc) as pool:
for json_str in hf_tqdm(
pool.imap(
self._batch_json,
[
(offset, orient, lines, column_names, to_json_kwargs)
for offset in range(0, num_rows, batch_size)
],
),
total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
units="ba",
desc="Writing JSON lines",
):
json_str = json_str.decode(self.encoding).strip()

# Remove the square brackets from the batch string
if json_str.startswith("[") and json_str.endswith("]"):
json_str = json_str[1:-1]

if not first_batch:
written += file_obj.write(",".encode(self.encoding))
else:
first_batch = False

json_str = json_str.encode(self.encoding)
written += file_obj.write(json_str)

written += file_obj.write("]}".encode(self.encoding))

return written
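
A quick sanity check for the split writer's concatenation logic; a sketch, not part of the PR, assuming a hypothetical output file out_split.json:

import json

with open("out_split.json", encoding="utf-8") as f:
    doc = json.load(f)  # the whole file must parse as one JSON object

assert set(doc) >= {"columns", "data"}
assert all(len(row) == len(doc["columns"]) for row in doc["data"])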

def _write_orient_columns(
self,
file_obj: BinaryIO,
orient,
lines,
**to_json_kwargs,
):
"""Handles writing to file when orient in ['columns']"""
written = 0
first_column = True

written += file_obj.write("{".encode(self.encoding))

for column_name in self.dataset.column_names:
if not first_column:
written += file_obj.write(",".encode(self.encoding))
else:
first_column = False

written += file_obj.write(f'"{column_name}":{{'.encode(self.encoding))
first_batch = True

if self.num_proc is None or self.num_proc == 1:
for offset in hf_tqdm(
range(0, len(self.dataset), self.batch_size), unit="ba", desc=f"Writing column {column_name}"
):
json_str = self._batch_json((offset, orient, lines, column_name, to_json_kwargs))

json_str = json_str.decode(self.encoding).strip()

# Remove the curly brackets from the batch string
if json_str.startswith("{") and json_str.endswith("}"):
json_str = json_str[1:-1]

if not first_batch:
written += file_obj.write(",".encode(self.encoding))
else:
first_batch = False

json_str = json_str.encode(self.encoding)
written += file_obj.write(json_str)
else:
num_rows, batch_size = len(self.dataset), self.batch_size
with multiprocessing.Pool(self.num_proc) as pool:
for json_str in hf_tqdm(
pool.imap(
self._batch_json,
[
(offset, orient, lines, column_name, to_json_kwargs)
for offset in range(0, num_rows, batch_size)
],
),
total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
unit="ba",
desc="Writing JSON lines",
):
json_str = json_str.decode(self.encoding).strip()

# Remove the curly brackets from the batch string
if json_str.startswith("{") and json_str.endswith("}"):
json_str = json_str[1:-1]

if not first_batch:
written += file_obj.write(",".encode(self.encoding))
else:
first_batch = False

json_str = json_str.encode(self.encoding)
written += file_obj.write(json_str)

written += file_obj.write("}".encode(self.encoding))

written += file_obj.write("}".encode(self.encoding))

return written
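
Likewise for the columns writer: each top-level value should map the global (offset-adjusted) row index to a cell. A hedged check over a hypothetical out_columns.json:

import json

with open("out_columns.json", encoding="utf-8") as f:
    doc = json.load(f)

for column, cells in doc.items():
    # keys are stringified global row indices, contiguous from 0
    assert sorted(cells, key=int) == [str(i) for i in range(len(cells))]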

def _write_orient_list_like(
self,
file_obj: BinaryIO,
orient,
lines,
**to_json_kwargs,
):
"""Handles writing to file when orient in ['records', 'values', 'index']"""

written = 0
first_batch = True
column_names = self.dataset.column_names

if not lines:
if orient in ["records", "values"]:
written += file_obj.write("[".encode(self.encoding))
elif orient in ["index"]:
written += file_obj.write("{".encode(self.encoding))

if self.num_proc is None or self.num_proc == 1:
for offset in hf_tqdm(range(0, len(self.dataset), self.batch_size), unit="ba", desc="Writing JSON lines"):
json_str = self._batch_json((offset, orient, lines, column_names, to_json_kwargs))

if not lines:
json_str = json_str.decode(self.encoding).strip()

# Remove the square or curly brackets from the batch string
if (json_str.startswith("[") and json_str.endswith("]")) or (
json_str.startswith("{") and json_str.endswith("}")
):
json_str = json_str[1:-1]

if not first_batch:
# Add a comma between batches
written += file_obj.write(",".encode(self.encoding))
else:
first_batch = False

json_str = json_str.encode(self.encoding)

written += file_obj.write(json_str)
else:
num_rows, batch_size = len(self.dataset), self.batch_size
with multiprocessing.Pool(self.num_proc) as pool:
for json_str in hf_tqdm(
pool.imap(
self._batch_json,
[
(offset, orient, lines, column_names, to_json_kwargs)
for offset in range(0, num_rows, batch_size)
],
),
total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
unit="ba",
desc="Writing JSON lines",
):
if not lines:
json_str = json_str.decode(self.encoding).strip()

# Remove the square or curly brackets from the batch string
if (json_str.startswith("[") and json_str.endswith("]")) or (
json_str.startswith("{") and json_str.endswith("}")
):
json_str = json_str[1:-1]

if not first_batch:
# Add a comma between batches
written += file_obj.write(",".encode(self.encoding))
else:
first_batch = False

json_str = json_str.encode(self.encoding)

written += file_obj.write(json_str)

if not lines:
if orient in ["records", "values"]:
written += file_obj.write("]".encode(self.encoding))
elif orient in ["index"]:
written += file_obj.write("}".encode(self.encoding))

return written
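
The lines flag is what separates the two record-like outputs: with lines=True each batch is already newline-delimited and batches concatenate verbatim, while with lines=False the writer strips each batch's outer brackets and re-joins with commas. A toy illustration of the two shapes pandas hands it:

import pandas as pd

df = pd.DataFrame({"a": [1, 2]})
print(df.to_json(orient="records", lines=True))   # {"a":1}\n{"a":2}
print(df.to_json(orient="records", lines=False))  # [{"a":1},{"a":2}]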

def _write_legacy(
self,
file_obj: BinaryIO,
orient,
lines,
**to_json_kwargs,
) -> int:
"""Writes the pyarrow table as JSON lines to a binary file handle.
"""Handles writing to file when orient in ['table']"""

Caller is responsible for opening and closing the handle.
"""
written = 0
column_names = self.dataset.column_names

if self.num_proc is None or self.num_proc == 1:
for offset in hf_tqdm(
range(0, len(self.dataset), self.batch_size),
unit="ba",
desc="Creating json from Arrow format",
):
json_str = self._batch_json((offset, orient, lines, column_names, to_json_kwargs))
written += file_obj.write(json_str)
else:
num_rows, batch_size = len(self.dataset), self.batch_size
with multiprocessing.Pool(self.num_proc) as pool:
for json_str in hf_tqdm(
pool.imap(
self._batch_json,
[
(offset, orient, lines, column_names, to_json_kwargs)
for offset in range(0, num_rows, batch_size)
],
),
total=(num_rows // batch_size) + 1 if num_rows % batch_size else num_rows // batch_size,
unit="ba",
@@ -167,3 +427,25 @@ def _write(
written += file_obj.write(json_str)

return written

def _write(
self,
file_obj: BinaryIO,
orient,
lines,
**to_json_kwargs,
) -> int:
"""Writes the pyarrow table as JSON, dispatching based on the orient and lines."""

written = 0

if orient in ["records", "values", "index"]:
written = self._write_orient_list_like(file_obj=file_obj, orient=orient, lines=lines, **to_json_kwargs)
elif orient in ["columns"]:
written = self._write_orient_columns(file_obj=file_obj, orient=orient, lines=lines, **to_json_kwargs)
elif orient in ["split"]:
written = self._write_orient_split(file_obj=file_obj, orient=orient, lines=lines, **to_json_kwargs)
else:
written = self._write_legacy(file_obj=file_obj, orient=orient, lines=lines, **to_json_kwargs)

return written
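
End to end, all four paths are reached through Dataset.to_json; a minimal usage sketch against this branch (toy in-memory dataset, file names illustrative):

from datasets import Dataset

ds = Dataset.from_dict({"a": [1, 2, 3], "b": ["x", "y", "z"]})
ds.to_json("out.jsonl")                                        # default orient="records", lines=True
ds.to_json("out_records.json", lines=False)                    # -> _write_orient_list_like
ds.to_json("out_columns.json", orient="columns", lines=False)  # -> _write_orient_columns
ds.to_json("out_split.json", orient="split", lines=False)      # -> _write_orient_split
ds.to_json("out_table.json", orient="table", lines=False)      # -> _write_legacy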