Skip to content

add np.datetime64 serialization and tests #1036

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Jul 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
## Release notes

### 0.13.7 -- Jul 13, 2022
* Bugfix - Fix networkx incompatible change by version pinning to 2.6.3 PR #1036 (#1035)
* Add - Support for serializing numpy datetime64 types PR #1036 (#1022)
* Update - Add traceback to default logging PR #1036

### 0.13.6 -- Jun 13, 2022
* Add - Config option to set threshold for when to stop using checksums for filepath stores. PR #1025
* Add - Unified package level logger for package (#667) PR #1031
Expand Down
100 changes: 61 additions & 39 deletions datajoint/blob.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,44 @@
from .settings import config


# MATLAB mxClassID names mapped to their numpy dtype equivalents.
# Entries with no direct numpy counterpart are None.
# see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html
mxClassID = {
    "mxUNKNOWN_CLASS": None,
    "mxCELL_CLASS": None,
    "mxSTRUCT_CLASS": None,
    "mxLOGICAL_CLASS": np.dtype("bool"),
    "mxCHAR_CLASS": np.dtype("c"),
    "mxVOID_CLASS": np.dtype("O"),
    "mxDOUBLE_CLASS": np.dtype("float64"),
    "mxSINGLE_CLASS": np.dtype("float32"),
    "mxINT8_CLASS": np.dtype("int8"),
    "mxUINT8_CLASS": np.dtype("uint8"),
    "mxINT16_CLASS": np.dtype("int16"),
    "mxUINT16_CLASS": np.dtype("uint16"),
    "mxINT32_CLASS": np.dtype("int32"),
    "mxUINT32_CLASS": np.dtype("uint32"),
    "mxINT64_CLASS": np.dtype("int64"),
    "mxUINT64_CLASS": np.dtype("uint64"),
    "mxFUNCTION_CLASS": None,
}

# dtype -> class-id index (duplicate None keys collapse to the last one, 16)
rev_class_id = {dtype: i for i, dtype in enumerate(mxClassID.values())}
# class-id index -> dtype
dtype_list = list(mxClassID.values())
# class-id index -> mxClassID name
type_names = list(mxClassID)
# Scalar type-id table used when deserializing blobs. Ids 0..16 mirror
# MATLAB's mxClassID enumeration; ids from 65_536 (0x10000) upward are
# DataJoint extensions, one per numpy datetime64 time unit.
_mx_scalar_types = (
    (None, "UNKNOWN"),
    (None, "CELL"),
    (None, "STRUCT"),
    ("bool", "LOGICAL"),
    ("c", "CHAR"),
    ("O", "VOID"),
    ("float64", "DOUBLE"),
    ("float32", "SINGLE"),
    ("int8", "INT8"),
    ("uint8", "UINT8"),
    ("int16", "INT16"),
    ("uint16", "UINT16"),
    ("int32", "INT32"),
    ("uint32", "UINT32"),
    ("int64", "INT64"),
    ("uint64", "UINT64"),
    (None, "FUNCTION"),
)
deserialize_lookup = {
    type_id: {"dtype": None if spec is None else np.dtype(spec), "scalar_type": name}
    for type_id, (spec, name) in enumerate(_mx_scalar_types)
}
deserialize_lookup.update(
    (
        65_536 + offset,
        {
            "dtype": np.dtype(f"datetime64[{unit}]"),
            "scalar_type": f"DATETIME64[{unit}]",
        },
    )
    for offset, unit in enumerate("Y M W D h m s ms us ns ps fs as".split())
)
# Inverse table for serialization: dtype -> type id and scalar-type name.
# Entries without a numpy dtype cannot be serialized and are omitted.
serialize_lookup = {
    entry["dtype"]: {"type_id": type_id, "scalar_type": entry["scalar_type"]}
    for type_id, entry in deserialize_lookup.items()
    if entry["dtype"] is not None
}


compression = {b"ZL123\0": zlib.decompress}

Expand Down Expand Up @@ -176,7 +188,7 @@ def pack_blob(self, obj):
return self.pack_float(obj)
if isinstance(obj, np.ndarray) and obj.dtype.fields:
return self.pack_recarray(np.array(obj))
if isinstance(obj, np.number):
if isinstance(obj, (np.number, np.datetime64)):
return self.pack_array(np.array(obj))
if isinstance(obj, (bool, np.bool_)):
return self.pack_array(np.array(obj))
Expand Down Expand Up @@ -211,14 +223,18 @@ def read_array(self):
shape = self.read_value(count=n_dims)
n_elem = np.prod(shape, dtype=int)
dtype_id, is_complex = self.read_value("uint32", 2)
dtype = dtype_list[dtype_id]

if type_names[dtype_id] == "mxVOID_CLASS":
# Get dtype from type id
dtype = deserialize_lookup[dtype_id]["dtype"]

# Check if name is void
if deserialize_lookup[dtype_id]["scalar_type"] == "VOID":
data = np.array(
list(self.read_blob(self.read_value()) for _ in range(n_elem)),
dtype=np.dtype("O"),
)
elif type_names[dtype_id] == "mxCHAR_CLASS":
# Check if name is char
elif deserialize_lookup[dtype_id]["scalar_type"] == "CHAR":
# compensate for MATLAB packing of char arrays
data = self.read_value(dtype, count=2 * n_elem)
data = data[::2].astype("U1")
Expand All @@ -240,6 +256,8 @@ def pack_array(self, array):
"""
Serialize an np.ndarray into bytes. Scalars are encoded with ndim=0.
"""
if "datetime64" in array.dtype.name:
self.set_dj0()
blob = (
b"A"
+ np.uint64(array.ndim).tobytes()
Expand All @@ -248,22 +266,26 @@ def pack_array(self, array):
is_complex = np.iscomplexobj(array)
if is_complex:
array, imaginary = np.real(array), np.imag(array)
type_id = (
rev_class_id[array.dtype]
if array.dtype.char != "U"
else rev_class_id[np.dtype("O")]
)
if dtype_list[type_id] is None:
raise DataJointError("Type %s is ambiguous or unknown" % array.dtype)
try:
type_id = serialize_lookup[array.dtype]["type_id"]
except KeyError:
# U is for unicode string
if array.dtype.char == "U":
type_id = serialize_lookup[np.dtype("O")]["type_id"]
else:
raise DataJointError(f"Type {array.dtype} is ambiguous or unknown")

blob += np.array([type_id, is_complex], dtype=np.uint32).tobytes()
if type_names[type_id] == "mxVOID_CLASS": # array of dtype('O')
if (
array.dtype.char == "U"
or serialize_lookup[array.dtype]["scalar_type"] == "VOID"
):
blob += b"".join(
len_u64(it) + it
for it in (self.pack_blob(e) for e in array.flatten(order="F"))
)
self.set_dj0() # not supported by original mym
elif type_names[type_id] == "mxCHAR_CLASS": # array of dtype('c')
elif serialize_lookup[array.dtype]["scalar_type"] == "CHAR":
blob += (
array.view(np.uint8).astype(np.uint16).tobytes()
) # convert to 16-bit chars for MATLAB
Expand Down
7 changes: 1 addition & 6 deletions datajoint/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@ def excepthook(exc_type, exc_value, exc_traceback):
sys.__excepthook__(exc_type, exc_value, exc_traceback)
return

if logger.getEffectiveLevel() == 10:
logger.debug(
"Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)
)
else:
logger.error(f"Uncaught exception: {exc_value}")
logger.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback))


sys.excepthook = excepthook
2 changes: 1 addition & 1 deletion datajoint/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__version__ = "0.13.6"
__version__ = "0.13.7"

assert len(__version__) <= 10 # The log table limits version to the 10 characters
6 changes: 6 additions & 0 deletions docs-parts/intro/Releases_lang1.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
0.13.7 -- Jul 13, 2022
----------------------
* Bugfix - Fix networkx incompatible change by version pinning to 2.6.3 PR #1036 (#1035)
* Add - Support for serializing numpy datetime64 types PR #1036 (#1022)
* Update - Add traceback to default logging PR #1036

0.13.6 -- Jun 13, 2022
----------------------
* Add - Config option to set threshold for when to stop using checksums for filepath stores. PR #1025
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pyparsing
ipython
pandas
tqdm
networkx
networkx<=2.6.3
pydot
minio>=7.0.0
matplotlib
Expand Down
26 changes: 26 additions & 0 deletions tests/test_blob.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datajoint as dj
import timeit
import numpy as np
import uuid
from . import schema
Expand Down Expand Up @@ -149,6 +150,9 @@ def test_pack():
x == unpack(pack(x)), "Numpy string array object did not pack/unpack correctly"
)

x = np.datetime64("1998").astype("datetime64[us]")
assert_true(x == unpack(pack(x)))


def test_recarrays():
x = np.array([(1.0, 2), (3.0, 4)], dtype=[("x", float), ("y", int)])
Expand Down Expand Up @@ -222,3 +226,25 @@ def test_insert_longblob():
}
(schema.Longblob & "id=1").delete()
dj.blob.use_32bit_dims = False


def test_datetime_serialization_speed():
    # Guard against a performance regression: round-tripping an array of
    # np.datetime64 values must remain far faster than an equal-sized array
    # of Python datetime objects (which pack element-by-element as dtype 'O').
    # NOTE(review): the setup strings rely on pack/unpack/np/datetime being
    # present in this module's globals — verify the module-level imports.
    np_exe_time = timeit.timeit(
        setup="myarr=pack(np.array([np.datetime64('2022-10-13 03:03:13') for _ in range(0, 10000)]))",
        stmt="unpack(myarr)",
        number=10,
        globals=globals(),
    )
    print(f"np time {np_exe_time}")
    py_exe_time = timeit.timeit(
        setup="myarr2=pack(np.array([datetime(2022,10,13,3,3,13) for _ in range (0, 10000)]))",
        stmt="unpack(myarr2)",
        number=10,
        globals=globals(),
    )
    print(f"python time {py_exe_time}")

    # 1000x margin — presumably chosen empirically for bulk vs. per-element
    # deserialization; TODO confirm the factor is intended.
    assert np_exe_time * 1000 < py_exe_time