Skip to content

Commit

Permalink
Auto load json cols (#444)
Browse files Browse the repository at this point in the history
* Auto load json cols

* Changing JSON column to dict + Fixing Tests

---------

Co-authored-by: David Tulga <3924980+dtulga@users.noreply.github.com>
  • Loading branch information
Dave Berenbaum and dtulga authored Sep 18, 2024
1 parent ee43fd1 commit 16c2729
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 15 deletions.
6 changes: 4 additions & 2 deletions src/datachain/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
( https://docs.python.org/3/library/sqlite3.html#sqlite3.register_converter )
"""

import json
from datetime import datetime
from types import MappingProxyType
from typing import Any, Union

import orjson
import sqlalchemy as sa
from sqlalchemy import TypeDecorator, types

Expand Down Expand Up @@ -312,7 +312,7 @@ def db_default_value(dialect):
def on_read_convert(self, value, dialect):
r = read_converter(dialect).array(value, self.item_type, dialect)
if isinstance(self.item_type, JSON):
r = [json.loads(item) if isinstance(item, str) else item for item in r]
r = [orjson.loads(item) if isinstance(item, str) else item for item in r]
return r


Expand Down Expand Up @@ -420,6 +420,8 @@ def array(self, value, item_type, dialect):
return [item_type.on_read_convert(x, dialect) for x in value]

def json(self, value):
if isinstance(value, str):
return orjson.loads(value)
return value

def datetime(self, value):
Expand Down
7 changes: 3 additions & 4 deletions tests/func/test_datachain.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import math
import os
import pickle
Expand Down Expand Up @@ -993,7 +992,7 @@ def test_types():
[0.5],
"s",
True,
json.dumps({"a": 1}),
{"a": 1},
pickle.dumps(obj),
)

Expand All @@ -1018,7 +1017,7 @@ def test_types():
"array_col_64": list[float],
"string_col": str,
"bool_col": bool,
"json_col": str,
"json_col": dict,
"binary_col": bytes,
},
)
Expand Down Expand Up @@ -1059,7 +1058,7 @@ def test_types():
[0.5],
"s",
True,
json.dumps({"a": 1}),
{"a": 1},
obj,
)
]
Expand Down
14 changes: 6 additions & 8 deletions tests/func/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import json
import posixpath
import uuid
from json import dumps
from unittest.mock import ANY

import pytest
Expand Down Expand Up @@ -659,7 +657,7 @@ def test_types():
[0.5],
"s",
True,
json.dumps({"a": 1}),
{"a": 1},
int_example.to_bytes(2, "big"),
)

Expand All @@ -683,7 +681,7 @@ def test_types():
"array_col_64": list[float],
"string_col": str,
"bool_col": bool,
"json_col": str,
"json_col": dict,
"binary_col": bytes,
},
)
Expand All @@ -703,7 +701,7 @@ def test_types():
assert r["array_col_64"] == [0.5]
assert r["string_col"] == "s"
assert r["bool_col"]
assert r["json_col"] == dumps({"a": 1})
assert r["json_col"] == {"a": 1}
assert r["binary_col"] == int_example.to_bytes(2, "big")


Expand All @@ -725,7 +723,7 @@ def test_types():
[0.5],
"s",
True,
json.dumps({"a": 1}),
{"a": 1},
int_example.to_bytes(2, "big"),
)

Expand All @@ -749,7 +747,7 @@ def test_types():
"array_col_64": list[float],
"string_col": str,
"bool_col": bool,
"json_col": str,
"json_col": dict,
"binary_col": bytes,
},
)
Expand All @@ -769,7 +767,7 @@ def test_types():
assert r["array_col_64"] == [0.5]
assert r["string_col"] == "s"
assert r["bool_col"]
assert r["json_col"] == '{"a": 1}'
assert r["json_col"] == {"a": 1}
assert r["binary_col"] == [0, 25]


Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_data_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def run_convert_type(value, sql_type):
[Float32(), lambda val: math.isnan(val)],
[Float64(), lambda val: math.isnan(val)],
[Array(Int), []],
[JSON(), "{}"],
[JSON(), {}],
[DateTime(), datetime(1970, 1, 1, 0, 0, tzinfo=timezone.utc)],
[Binary(), b""],
],
Expand Down

0 comments on commit 16c2729

Please sign in to comment.