Adding JSON / JSON Lines Export Support (#538)
dtulga authored Oct 26, 2024
1 parent 34e7c2b commit cfe3d9c
Showing 5 changed files with 316 additions and 1 deletion.
75 changes: 74 additions & 1 deletion src/datachain/lib/dc.py
@@ -55,7 +55,7 @@
from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
from datachain.sql.functions import path as pathfunc
from datachain.telemetry import telemetry
-from datachain.utils import batched_it, inside_notebook
+from datachain.utils import batched_it, inside_notebook, row_to_nested_dict

if TYPE_CHECKING:
from pyarrow import DataType as ArrowDataType
@@ -2051,6 +2051,79 @@ def to_csv(
for row in results_iter:
writer.writerow(row)

def to_json(
self,
path: Union[str, os.PathLike[str]],
fs_kwargs: Optional[dict[str, Any]] = None,
include_outer_list: bool = True,
) -> None:
"""Save chain to a JSON file.
Parameters:
path : Path to save the file. This supports local paths as well as
remote paths, such as s3:// or hf:// with fsspec.
fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
write, for fsspec-type URLs, such as s3:// or hf:// when
provided as the destination path.
        include_outer_list : Whether to wrap all rows in an outer list.
            Setting this to True makes the file valid JSON, while False
            writes the JSON Lines format (one object per line) instead.
"""
opener = open

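        # Remote destinations such as s3:// or hf:// are written through
        # fsspec; plain local paths fall back to the builtin open() above.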
if isinstance(path, str) and "://" in path:
from datachain.client.fsspec import Client

fs_kwargs = {
**self._query.catalog.client_config,
**(fs_kwargs or {}),
}

client = Client.get_implementation(path)

fsspec_fs = client.create_fs(**fs_kwargs)

opener = fsspec_fs.open

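        # Each header is a sequence of signal-name components (nested signals
        # yield multi-part paths); drop any empty components so they do not
        # become dict keys.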
headers, _ = self._effective_signals_schema.get_headers_with_length()
headers = [list(filter(None, header)) for header in headers]

is_first = True

with opener(path, "wb") as f:
if include_outer_list:
# This makes the file JSON instead of JSON lines.
f.write(b"[\n")
for row in self.collect_flatten():
if not is_first:
if include_outer_list:
# This makes the file JSON instead of JSON lines.
f.write(b",\n")
else:
f.write(b"\n")
else:
is_first = False
f.write(orjson.dumps(row_to_nested_dict(headers, row)))
if include_outer_list:
# This makes the file JSON instead of JSON lines.
f.write(b"\n]\n")

def to_jsonl(
self,
path: Union[str, os.PathLike[str]],
fs_kwargs: Optional[dict[str, Any]] = None,
) -> None:
"""Save chain to a JSON lines file.
Parameters:
path : Path to save the file. This supports local paths as well as
remote paths, such as s3:// or hf:// with fsspec.
fs_kwargs : Optional kwargs to pass to the fsspec filesystem, used only for
write, for fsspec-type URLs, such as s3:// or hf:// when
provided as the destination path.
"""
self.to_json(path, fs_kwargs, include_outer_list=False)

@classmethod
def from_records(
cls,
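A minimal usage sketch for the two new exporters (the chain, bucket name, and fs_kwargs values below are hypothetical, shown only to illustrate the API):

# `chain` stands in for any existing DataChain instance.
chain.to_json("out.json")    # valid JSON: an outer list of row objects
chain.to_jsonl("out.jsonl")  # JSON Lines: one row object per line

# Remote writes go through fsspec; "anon" is an assumed s3fs option here.
chain.to_json("s3://my-bucket/out.json", fs_kwargs={"anon": False})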
24 changes: 24 additions & 0 deletions src/datachain/utils.py
@@ -455,3 +455,27 @@ def env2bool(var, undefined=False):
if var is None:
return undefined
return bool(re.search("1|y|yes|true", var, flags=re.IGNORECASE))


def nested_dict_path_set(
data: dict[str, Any], path: Sequence[str], value: Any
) -> dict[str, Any]:
"""Sets a value inside a nested dict based on the list of dict keys as a path,
and will create sub-dicts as needed to set the value."""
sub_data = data
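    # Walk down to the parent of the final key, creating intermediate dicts
    # as needed along the way.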
for element in path[:-1]:
if element not in sub_data:
sub_data[element] = {}
sub_data = sub_data[element]
    sub_data[path[-1]] = value
return data


def row_to_nested_dict(
headers: Iterable[Sequence[str]], row: Iterable[Any]
) -> dict[str, Any]:
"""Converts a row to a nested dict based on the provided headers."""
result: dict[str, Any] = {}
for h, v in zip(headers, row):
nested_dict_path_set(result, h, v)
return result
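A quick sketch of how these two helpers compose (illustrative values only):

row_to_nested_dict([["f1", "nnn"], ["f1", "count"], ["num"]], ("abc", 2, 0))
# -> {"f1": {"nnn": "abc", "count": 2}, "num": 0}

nested_dict_path_set({}, ["a", "b", "c"], 1)
# -> {"a": {"b": {"c": 1}}}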
32 changes: 32 additions & 0 deletions tests/func/test_datachain.py
@@ -1505,3 +1505,35 @@ def test_to_from_parquet_partitioned_remote(cloud_test_catalog_upload, chunk_size
df1 = dc_from.select("first_name", "age", "city").to_pandas()
df1 = df1.sort_values("first_name").reset_index(drop=True)
assert df1.equals(df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_json_remote(cloud_test_catalog_upload):
ctc = cloud_test_catalog_upload
path = f"{ctc.src_uri}/test.json"

df = pd.DataFrame(DF_DATA)
dc_to = DataChain.from_pandas(df, session=ctc.session)
dc_to.to_json(path)

dc_from = DataChain.from_json(path, session=ctc.session)
df1 = dc_from.select("json.first_name", "json.age", "json.city").to_pandas()
df1 = df1["json"]
assert df1.equals(df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_jsonl_remote(cloud_test_catalog_upload):
ctc = cloud_test_catalog_upload
path = f"{ctc.src_uri}/test.jsonl"

df = pd.DataFrame(DF_DATA)
dc_to = DataChain.from_pandas(df, session=ctc.session)
dc_to.to_jsonl(path)

dc_from = DataChain.from_jsonl(path, session=ctc.session)
df1 = dc_from.select("jsonl.first_name", "jsonl.age", "jsonl.city").to_pandas()
df1 = df1["jsonl"]
assert df1.equals(df)
139 changes: 139 additions & 0 deletions tests/unit/lib/test_datachain.py
@@ -1,4 +1,5 @@
import datetime
import json
import math
import os
import re
@@ -1275,6 +1276,144 @@ def test_to_csv_features_nested(tmp_dir, test_session):
]


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_json(tmp_dir, test_session):
df = pd.DataFrame(DF_DATA)
dc_to = DataChain.from_pandas(df, session=test_session)
path = tmp_dir / "test.json"
dc_to.to_json(path)

with open(path) as f:
values = json.load(f)
assert values == [
{"first_name": n, "age": a, "city": c}
for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
]

dc_from = DataChain.from_json(path.as_uri(), session=test_session)
df1 = dc_from.select("json.first_name", "json.age", "json.city").to_pandas()
df1 = df1["json"]
assert df1.equals(df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_from_json_jmespath(tmp_dir, test_session):
df = pd.DataFrame(DF_DATA)
values = [
{"first_name": n, "age": a, "city": c}
for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
]
path = tmp_dir / "test.json"
with open(path, "w") as f:
json.dump({"author": "Test User", "version": 5, "values": values}, f)

dc_from = DataChain.from_json(
path.as_uri(), jmespath="values", session=test_session
)
df1 = dc_from.select("values.first_name", "values.age", "values.city").to_pandas()
df1 = df1["values"]
assert df1.equals(df)


def test_to_json_features(tmp_dir, test_session):
dc_to = DataChain.from_values(
f1=features, num=range(len(features)), session=test_session
)
path = tmp_dir / "test.json"
dc_to.to_json(path)
with open(path) as f:
values = json.load(f)
assert values == [
{"f1": {"nnn": f.nnn, "count": f.count}, "num": n}
for n, f in enumerate(features)
]


def test_to_json_features_nested(tmp_dir, test_session):
dc_to = DataChain.from_values(sign1=features_nested, session=test_session)
path = tmp_dir / "test.json"
dc_to.to_json(path)
with open(path) as f:
values = json.load(f)
assert values == [
{"sign1": {"label": f"label_{n}", "fr": {"nnn": f.nnn, "count": f.count}}}
for n, f in enumerate(features)
]


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_jsonl(tmp_dir, test_session):
df = pd.DataFrame(DF_DATA)
dc_to = DataChain.from_pandas(df, session=test_session)
path = tmp_dir / "test.jsonl"
dc_to.to_jsonl(path)

with open(path) as f:
values = [json.loads(line) for line in f.read().split("\n")]
assert values == [
{"first_name": n, "age": a, "city": c}
for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
]

dc_from = DataChain.from_jsonl(path.as_uri(), session=test_session)
df1 = dc_from.select("jsonl.first_name", "jsonl.age", "jsonl.city").to_pandas()
df1 = df1["jsonl"]
assert df1.equals(df)


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_from_jsonl_jmespath(tmp_dir, test_session):
df = pd.DataFrame(DF_DATA)
values = [
{"first_name": n, "age": a, "city": c}
for n, a, c in zip(DF_DATA["first_name"], DF_DATA["age"], DF_DATA["city"])
]
path = tmp_dir / "test.jsonl"
with open(path, "w") as f:
for v in values:
f.write(
json.dumps({"data": "Contained Within", "row_version": 5, "value": v})
)
f.write("\n")

dc_from = DataChain.from_jsonl(
path.as_uri(), jmespath="value", session=test_session
)
df1 = dc_from.select("value.first_name", "value.age", "value.city").to_pandas()
df1 = df1["value"]
assert df1.equals(df)


def test_to_jsonl_features(tmp_dir, test_session):
dc_to = DataChain.from_values(
f1=features, num=range(len(features)), session=test_session
)
path = tmp_dir / "test.json"
dc_to.to_jsonl(path)
with open(path) as f:
values = [json.loads(line) for line in f.read().split("\n")]
assert values == [
{"f1": {"nnn": f.nnn, "count": f.count}, "num": n}
for n, f in enumerate(features)
]


def test_to_jsonl_features_nested(tmp_dir, test_session):
dc_to = DataChain.from_values(sign1=features_nested, session=test_session)
path = tmp_dir / "test.json"
dc_to.to_jsonl(path)
with open(path) as f:
values = [json.loads(line) for line in f.read().split("\n")]
assert values == [
{"sign1": {"label": f"label_{n}", "fr": {"nnn": f.nnn, "count": f.count}}}
for n, f in enumerate(features)
]


def test_from_parquet(tmp_dir, test_session):
df = pd.DataFrame(DF_DATA)
path = tmp_dir / "test.parquet"
47 changes: 47 additions & 0 deletions tests/unit/test_utils.py
@@ -5,7 +5,9 @@
from datachain.utils import (
datachain_paths_join,
determine_processes,
nested_dict_path_set,
retry_with_backoff,
row_to_nested_dict,
sizeof_fmt,
sql_escape_like,
suffix_to_number,
@@ -170,3 +172,48 @@ def test_determine_processes(parallel, settings, expected):
)
def test_uses_glob(path, expected):
assert uses_glob(path) is expected


@pytest.mark.parametrize(
"data,path,value,expected",
(
({}, ["test"], True, {"test": True}),
({"extra": False}, ["test"], True, {"extra": False, "test": True}),
(
{"extra": False},
["test", "nested"],
True,
{"extra": False, "test": {"nested": True}},
),
(
{"extra": False},
["test", "nested", "deep"],
True,
{"extra": False, "test": {"nested": {"deep": True}}},
),
(
{"extra": False, "test": {"test2": 5, "nested": {}}},
["test", "nested", "deep"],
True,
{"extra": False, "test": {"test2": 5, "nested": {"deep": True}}},
),
),
)
def test_nested_dict_path_set(data, path, value, expected):
assert nested_dict_path_set(data, path, value) == expected


@pytest.mark.parametrize(
"headers,row,expected",
(
([["a"], ["b"]], (3, 7), {"a": 3, "b": 7}),
([["a"], ["b", "c"]], (3, 7), {"a": 3, "b": {"c": 7}}),
(
[["a", "b"], ["a", "c"], ["d"], ["a", "e", "f"]],
(1, 5, "test", 11),
{"a": {"b": 1, "c": 5, "e": {"f": 11}}, "d": "test"},
),
),
)
def test_row_to_nested_dict(headers, row, expected):
assert row_to_nested_dict(headers, row) == expected
