Commit f19095b
unpin pydantic, use python API for datamodel_codegen (#400)
skshetry authored Sep 6, 2024
1 parent f881a93 commit f19095b
Showing 3 changed files with 125 additions and 54 deletions.
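The core of the change, before the per-file diffs: the datamodel-codegen CLI invocation (subprocess plus stdout scraping) is replaced with an in-process call to the library's Python API, which writes the generated model to a file. A minimal sketch of the new call shape, condensed from the meta_formats.py diff below; the sample input string and class name are illustrative, not from the commit:

    import tempfile
    from pathlib import Path

    import datamodel_code_generator

    data_string = '{"id": "1", "split": "test"}'  # illustrative sample input
    with tempfile.TemporaryDirectory() as tmpdir:
        output = Path(tmpdir) / "model.py"
        datamodel_code_generator.generate(
            data_string,
            input_file_type=datamodel_code_generator.InputFileType.Json,
            output=output,  # the API writes to a path instead of stdout
            class_name="Example",
        )
        generated_source = output.read_text()  # read before the tempdir vanishes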
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -40,7 +40,7 @@ dependencies = [
"dill==0.3.8",
"cloudpickle",
"orjson>=3.10.5",
"pydantic>2,<2.9",
"pydantic>=2,<3",
"jmespath>=1.0",
"datamodel-code-generator>=0.25",
"Pillow>=10.0.0,<11",
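On the dependency side, the pin relaxes from pydantic>2,<2.9 (which excluded the then-new 2.9 series) to pydantic>=2,<3, i.e. any pydantic 2.x release. A hedged sanity check matching the new range, not part of the commit:

    from importlib.metadata import version

    assert int(version("pydantic").split(".")[0]) == 2  # any 2.x now satisfies the pin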
89 changes: 36 additions & 53 deletions src/datachain/lib/meta_formats.py
@@ -2,14 +2,14 @@
# pip install jmespath
#
import csv
-import io
import json
-import subprocess
-import sys
+import tempfile
import uuid
from collections.abc import Iterator
+from pathlib import Path
from typing import Any, Callable

+import datamodel_code_generator
import jmespath as jsp
from pydantic import BaseModel, ConfigDict, Field, ValidationError  # noqa: F401

@@ -47,9 +47,8 @@ def read_schema(source_file, data_type="csv", expr=None, model_name=None):
data_string = ""
# using uiid to get around issue #1617
if not model_name:
uid_str = str(generate_uuid()).replace(
"-", ""
) # comply with Python class names
# comply with Python class names
uid_str = str(generate_uuid()).replace("-", "")
model_name = f"Model{data_type}{uid_str}"
try:
with source_file.open() as fd: # CSV can be larger than memory
@@ -70,33 +69,27 @@
if data_type == "jsonl":
data_type = "json" # treat json line as plain JSON in auto-schema
data_string = json.dumps(json_object)
command = [
"datamodel-codegen",
"--input-file-type",
data_type,
"--class-name",
model_name,
"--base-class",
"datachain.lib.meta_formats.UserModel",
]
try:
result = subprocess.run( # noqa: S603
command,
input=data_string,
text=True,
capture_output=True,
check=True,

input_file_types = {i.value: i for i in datamodel_code_generator.InputFileType}
input_file_type = input_file_types[data_type]
with tempfile.TemporaryDirectory() as tmpdir:
output = Path(tmpdir) / "model.py"
datamodel_code_generator.generate(
data_string,
input_file_type=input_file_type,
output=output,
target_python_version=datamodel_code_generator.PythonVersion.PY_39,
base_class="datachain.lib.meta_formats.UserModel",
class_name=model_name,
additional_imports=["datachain.lib.data_model.DataModel"],
use_standard_collections=True,
)
model_output = (
result.stdout
) # This will contain the output from datamodel-codegen
except subprocess.CalledProcessError as e:
model_output = f"An error occurred in datamodel-codegen: {e.stderr}"
print(f"{model_output}")
print("from datachain.lib.data_model import DataModel")
print("\n" + f"DataModel.register({model_name})" + "\n")
print("\n" + f"spec={model_name}" + "\n")
return model_output
epilogue = f"""
{model_name}.model_rebuild()
DataModel.register({model_name})
spec = {model_name}
"""
return output.read_text() + epilogue


#
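With this hunk, read_schema returns the generated source plus an epilogue that rebuilds the model, registers it as a DataModel, and binds it to the name spec. The caller contract, sketched under the assumption that the returned string is exec'd the way read_meta does below:

    model_source = read_schema(file, data_type="json", model_name="Image")  # file: a TextFile from a DataChain row (assumed)
    local_vars: dict = {}
    exec(model_source, globals(), local_vars)  # runs the epilogue too
    Image = local_vars["spec"]  # the freshly generated pydantic model class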
@@ -113,34 +106,24 @@ def read_meta(  # noqa: C901
) -> Callable:
    from datachain.lib.dc import DataChain

-    # ugly hack: datachain is run redirecting printed outputs to a variable
    if schema_from:
-        captured_output = io.StringIO()
-        current_stdout = sys.stdout
-        sys.stdout = captured_output
-        try:
-            chain = (
-                DataChain.from_storage(schema_from, type="text")
-                .limit(1)
-                .map(  # dummy column created (#1615)
-                    meta_schema=lambda file: read_schema(
-                        file, data_type=meta_type, expr=jmespath, model_name=model_name
-                    ),
-                    output=str,
-                )
-            )
-            chain.exec()
-        finally:
-            sys.stdout = current_stdout
-        model_output = captured_output.getvalue()
-        captured_output.close()
+        chain = (
+            DataChain.from_storage(schema_from, type="text")
+            .limit(1)
+            .map(  # dummy column created (#1615)
+                meta_schema=lambda file: read_schema(
+                    file, data_type=meta_type, expr=jmespath, model_name=model_name
+                ),
+                output=str,
+            )
+        )
+        (model_output,) = chain.collect("meta_schema")
        if print_schema:
            print(f"{model_output}")
        # Below 'spec' should be a dynamically converted DataModel from Pydantic
        if not spec:
            local_vars: dict[str, Any] = {}
-            exec(model_output, globals(), local_vars)  # noqa: S102
+            exec(model_output, globals(), local_vars)  # type: ignore[arg-type] # noqa: S102
            spec = local_vars["spec"]

    if not (spec) and not (schema_from):
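Net effect of this hunk: the schema string used to travel through redirected stdout; it now travels through the dummy meta_schema column and comes back via collect, which executes the chain and yields that column's values. A condensed sketch of the pattern, with an illustrative storage path:

    chain = (
        DataChain.from_storage("valid.json", type="text")
        .limit(1)
        .map(meta_schema=lambda file: "<generated source>", output=str)
    )
    (model_output,) = chain.collect("meta_schema")  # one row, one value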
88 changes: 88 additions & 0 deletions tests/func/test_meta_formats.py
@@ -0,0 +1,88 @@
import json

import pytest

from datachain.lib.file import TextFile
from datachain.lib.meta_formats import read_meta, read_schema

example = {
    "id": "1",
    "split": "test",
    "image_id": {
        "author": "author",
        "title": "title",
        "size": 5090109,
        "md5": "md5",
        "url": "https://example.org/image.jpg",
        "rotation": 0.0,
    },
    "classifications": [
        {"Source": "verification", "LabelName": "label", "Confidence": 0}
    ],
}


@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_read_schema(tmp_dir, catalog):
    (tmp_dir / "valid.json").write_text(json.dumps(example), encoding="utf8")
    file = TextFile(path=tmp_dir / "valid.json")
    file._set_stream(catalog)

    expected = """\
from __future__ import annotations

from datachain.lib.data_model import DataModel
from datachain.lib.meta_formats import UserModel


class ImageId(UserModel):
    author: str
    title: str
    size: int
    md5: str
    url: str
    rotation: float


class Classification(UserModel):
    Source: str
    LabelName: str
    Confidence: int


class Image(UserModel):
    id: str
    split: str
    image_id: ImageId
    classifications: list[Classification]

Image.model_rebuild()
DataModel.register(Image)
spec = Image"""

    actual = read_schema(file, data_type="json", model_name="Image")
    actual = "\n".join(actual.splitlines()[4:])  # remove header
    assert actual == expected


@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_read_meta(tmp_dir, catalog):
    (tmp_dir / "valid.json").write_text(json.dumps(example), encoding="utf8")
    file = TextFile(path=tmp_dir / "valid.json")
    file._set_stream(catalog)

    parser = read_meta(
        schema_from=str(tmp_dir / "valid.json"),
        meta_type="jsonl",
        model_name="Image",
    )
    rows = list(parser(file))
    assert len(rows) == 1
    assert rows[0].model_dump() == example

    (tmp_dir / "invalid.json").write_text(
        json.dumps({"hello": "world"}), encoding="utf8"
    )
    invalid_file = TextFile(path=tmp_dir / "invalid.json")
    invalid_file._set_stream(catalog)
    assert not list(parser(invalid_file))
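A note on the splitlines()[4:] slice in test_read_schema: datamodel-codegen prefixes generated files with a run-dependent banner, roughly of this shape (exact wording and timestamp vary by version and run):

    # generated by datamodel-codegen:
    #   filename:  <stdin>
    #   timestamp: 2024-09-06T00:00:00+00:00

followed by a blank line; dropping those four lines keeps the comparison deterministic.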
