From baec8e0347ac01f342c70c2cb90c5d31368d7431 Mon Sep 17 00:00:00 2001 From: Daniel K Date: Sat, 20 Jul 2024 14:57:35 -0700 Subject: [PATCH] fixing regressions from switching to ModelStore.add() (#109) --- examples/json-csv-reader.py | 12 +++++++----- src/datachain/lib/dc.py | 9 ++++++--- src/datachain/lib/meta_formats.py | 16 ++++++++++++++-- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/examples/json-csv-reader.py b/examples/json-csv-reader.py index faca24406..dce4ecb53 100644 --- a/examples/json-csv-reader.py +++ b/examples/json-csv-reader.py @@ -23,8 +23,8 @@ from pydantic import BaseModel +from datachain.lib.data_model import ModelStore from datachain.lib.dc import C, DataChain -from datachain.lib.feature_utils import pydantic_to_feature # Sample model for static JSON model @@ -34,7 +34,7 @@ class LicenseModel(BaseModel): name: str -LicenseFeature = pydantic_to_feature(LicenseModel) +LicenseFeature = ModelStore.add(LicenseModel) # Sample model for static CSV model @@ -45,7 +45,7 @@ class ChatDialog(BaseModel): text: Optional[str] = None -ChatFeature = pydantic_to_feature(ChatDialog) +ChatFeature = ModelStore.add(ChatDialog) def main(): @@ -86,9 +86,11 @@ def main(): print() print("========================================================================") - print("static JSON schema test parsing 7 objects") + print("static JSON schema test parsing 3/7 objects") print("========================================================================") - static_json_ds = DataChain.from_json(uri, jmespath="licenses", spec=LicenseFeature) + static_json_ds = DataChain.from_json( + uri, jmespath="licenses", spec=LicenseFeature, nrows=3 + ) print(static_json_ds.to_pandas()) print() diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index cb8295c89..4702f3245 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -346,6 +346,7 @@ def from_json( model_name: Optional[str] = None, show_schema: Optional[bool] = False, meta_type: Optional[str] = "json", + nrows=None, **kwargs, ) -> "DataChain": """Get data from JSON. It returns the chain itself. @@ -355,11 +356,12 @@ def from_json( as `s3://`, `gs://`, `az://` or "file:///" type : read file as "binary", "text", or "image" data. Default is "binary". spec : optional Data Model - schema_from : path to sample to infer spec from + schema_from : path to sample to infer spec (if schema not provided) object_name : generated object column name - model_name : generated model name + model_name : optional generated model name show_schema : print auto-generated schema - jmespath : JMESPATH expression to reduce JSON + jmespath : optional JMESPATH expression to reduce JSON + nrows : optional row limit for jsonl and JSON arrays Example: infer JSON schema from data, reduce using JMESPATH, print schema @@ -392,6 +394,7 @@ def jmespath_to_name(s: str): model_name=model_name, show_schema=show_schema, jmespath=jmespath, + nrows=nrows, ) } return chain.gen(**signal_dict) # type: ignore[arg-type] diff --git a/src/datachain/lib/meta_formats.py b/src/datachain/lib/meta_formats.py index 4eeb903ce..af146e946 100644 --- a/src/datachain/lib/meta_formats.py +++ b/src/datachain/lib/meta_formats.py @@ -13,6 +13,7 @@ import jmespath as jsp from pydantic import ValidationError +from datachain.lib.data_model import ModelStore # noqa: F401 from datachain.lib.file import File @@ -86,6 +87,8 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None): except subprocess.CalledProcessError as e: model_output = f"An error occurred in datamodel-codegen: {e.stderr}" print(f"{model_output}") + print("\n" + f"ModelStore.add({model_name})" + "\n") + print("\n" + f"spec={model_name}" + "\n") return model_output @@ -99,6 +102,7 @@ def read_meta( # noqa: C901 jmespath=None, show_schema=False, model_name=None, + nrows=None, ) -> Callable: from datachain.lib.dc import DataChain @@ -118,8 +122,7 @@ def read_meta( # noqa: C901 output=str, ) ) - # dummy executor (#1616) - chain.save() + chain.exec() finally: sys.stdout = current_stdout model_output = captured_output.getvalue() @@ -147,6 +150,7 @@ def parse_data( DataModel=spec, # noqa: N803 meta_type=meta_type, jmespath=jmespath, + nrows=nrows, ) -> Iterator[spec]: def validator(json_object: dict) -> spec: json_string = json.dumps(json_object) @@ -175,14 +179,22 @@ def validator(json_object: dict) -> spec: yield from validator(json_object) else: + nrow = 0 for json_dict in json_object: + nrow = nrow + 1 + if nrows is not None and nrow > nrows: + return yield from validator(json_dict) if meta_type == "jsonl": try: + nrow = 0 with file.open() as fd: data_string = fd.readline().replace("\r", "") while data_string: + nrow = nrow + 1 + if nrows is not None and nrow > nrows: + return json_object = process_json(data_string, jmespath) data_string = fd.readline() yield from validator(json_object)