diff --git a/python/python/lance/arrow.py b/python/python/lance/arrow.py index e8dcdf5a06b..10b0e91a491 100644 --- a/python/python/lance/arrow.py +++ b/python/python/lance/arrow.py @@ -14,6 +14,7 @@ import json from pathlib import Path from typing import Callable, Iterable, Optional, Union +from uuid import UUID import pyarrow as pa @@ -615,3 +616,37 @@ def __arrow_array__(self, type=None): def from_numpy(cls, array): inner = BFloat16Array.from_numpy(array) return cls(inner) + + +class UuidArray(pa.ExtensionArray): + def __repr__(self): + return "\n%s" % ( + id(self), + repr(self.to_pylist()), + ) + + +class UuidType(pa.ExtensionType): + def __init__(self): + pa.ExtensionType.__init__(self, pa.binary(16), "lance.uuid") + + def __arrow_ext_serialize__(self): + return b"" + + @classmethod + def __arrow_ext_deserialize__(self, storage_type, serialized): + return UuidType() + + def __arrow_ext_class__(self): + return UuidArray + + def __arrow_ext_scalar_class__(self): + return UuidScalarType + + +class UuidScalarType(pa.ExtensionScalar): + def as_py(self): + return None if self.value is None else UUID(bytes=self.value.as_py()) + + +pa.register_extension_type(UuidType()) diff --git a/python/python/tests/test_arrow.py b/python/python/tests/test_arrow.py index e7ca4f31869..3f7b355aa27 100644 --- a/python/python/tests/test_arrow.py +++ b/python/python/tests/test_arrow.py @@ -14,6 +14,7 @@ import re from pathlib import Path +from uuid import uuid4 import lance import numpy as np @@ -27,6 +28,7 @@ ImageArray, ImageURIArray, PandasBFloat16Array, + UuidType, bfloat16_array, ) @@ -252,3 +254,49 @@ def test_roundtrip_image_tensor(tmp_path: Path): tensor_image_array_2 = tbl2.take(indices).column(2) assert tensor_image_array_2.type == tensor_image_array.type + + +def test_uuid_type(): + ty = UuidType() + assert ty.storage_type == pa.binary(16) + assert ty.__class__ is UuidType + + ty = UuidType() + expected = uuid4() + scalar = pa.ExtensionScalar.from_storage(ty, expected.bytes) + assert scalar.as_py() == expected + + # test array + uuids = [uuid4() for _ in range(3)] + storage = pa.array([uuid.bytes for uuid in uuids], type=pa.binary(16)) + arr = pa.ExtensionArray.from_storage(ty, storage) + + # Works for __get_item__ + for i, expected in enumerate(uuids): + assert arr[i].as_py() == expected + + # Works for __iter__ + for result, expected in zip(arr, uuids): + assert result.as_py() == expected + + # test chunked array + data = [ + pa.ExtensionArray.from_storage(ty, storage), + pa.ExtensionArray.from_storage(ty, storage), + ] + carr = pa.chunked_array(data) + for i, expected in enumerate(uuids + uuids): + assert carr[i].as_py() == expected + + for result, expected in zip(carr, uuids + uuids): + assert result.as_py() == expected + + storage1 = pa.array([b"0123456789abcdef"], type=pa.binary(16)) + storage2 = pa.array([b"0123456789abcdef"], type=pa.binary(16)) + storage3 = pa.array([], type=pa.binary(16)) + + a = pa.ExtensionArray.from_storage(ty, storage1) + b = pa.ExtensionArray.from_storage(ty, storage2) + assert a.equals(b) + c = pa.ExtensionArray.from_storage(ty, storage3) + assert not a.equals(c)