Skip to content

Commit

Permalink
Review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Aug 22, 2024
1 parent 6fde98c commit 1b23c10
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 65 deletions.
2 changes: 2 additions & 0 deletions cpp/src/arrow/extension/fixed_shape_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "arrow/array/array_primitive.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/test_common.h"
#include "arrow/record_batch.h"
#include "arrow/tensor.h"
#include "arrow/testing/gtest_util.h"
Expand All @@ -32,6 +33,7 @@
namespace arrow {

using FixedShapeTensorType = extension::FixedShapeTensorType;
using arrow::ipc::test::RoundtripBatch;
using extension::fixed_shape_tensor;
using extension::FixedShapeTensorArray;

Expand Down
14 changes: 10 additions & 4 deletions cpp/src/arrow/extension/uuid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
// specific language governing permissions and limitations
// under the License.

#include <sstream>

#include "arrow/extension_type.h"
#include "arrow/util/logging.h"

#include "arrow/extension/uuid.h"

namespace arrow {
namespace extension {
namespace arrow::extension {

bool UuidType::ExtensionEquals(const ExtensionType& other) const {
return (other.extension_name() == this->extension_name());
Expand All @@ -46,7 +47,12 @@ Result<std::shared_ptr<DataType>> UuidType::Deserialize(
return std::make_shared<UuidType>();
}

// Render the type as "extension<NAME>", where NAME is the string returned by
// extension_name(). The `show_metadata` flag does not influence the output.
std::string UuidType::ToString(bool show_metadata) const {
  return "extension<" + this->extension_name() + ">";
}

// Factory helper: builds a new UuidType and returns it through a
// std::shared_ptr<DataType>.
std::shared_ptr<DataType> uuid() { return std::make_shared<UuidType>(); }

} // namespace extension
} // namespace arrow
} // namespace arrow::extension
21 changes: 13 additions & 8 deletions cpp/src/arrow/extension/uuid.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,30 @@

#include "arrow/extension_type.h"

namespace arrow {
namespace extension {
namespace arrow::extension {

/// \brief UuidArray stores an array of UUIDs. The underlying storage type is
/// FixedSizeBinary(16).
class ARROW_EXPORT UuidArray : public ExtensionArray {
public:
// Inherit every ExtensionArray constructor; UuidArray adds no state of its own.
using ExtensionArray::ExtensionArray;
};

/// \brief UuidType is a canonical Arrow extension type for UUIDs.
/// UUIDs are stored as FixedSizeBinary(16) in big-endian byte order; this type
/// does not interpret the bytes in any way. A specific UUID version is neither
/// required nor guaranteed.
class ARROW_EXPORT UuidType : public ExtensionType {
public:
/// \brief Construct a UuidType.
UuidType() : ExtensionType(fixed_size_binary(16)) {}

std::string extension_name() const override { return "arrow.uuid"; }

const std::shared_ptr<DataType> value_type() const { return fixed_size_binary(16); }
std::string ToString(bool show_metadata = false) const override;

bool ExtensionEquals(const ExtensionType& other) const override;

/// Create a UuidArray from ArrayData
std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

Result<std::shared_ptr<DataType>> Deserialize(
Expand All @@ -49,8 +55,7 @@ class ARROW_EXPORT UuidType : public ExtensionType {
static Result<std::shared_ptr<DataType>> Make() { return std::make_shared<UuidType>(); }
};

ARROW_EXPORT
std::shared_ptr<DataType> uuid();
/// \brief Return a UuidType instance.
ARROW_EXPORT std::shared_ptr<DataType> uuid();

} // namespace extension
} // namespace arrow
} // namespace arrow::extension
29 changes: 29 additions & 0 deletions cpp/src/arrow/extension/uuid_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@
#include "arrow/array/array_primitive.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/test_common.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/key_value_metadata.h"

#include "arrow/testing/extension_type.h"

namespace arrow {

using arrow::ipc::test::RoundtripBatch;

TEST(TestUuuidExtensionType, ExtensionTypeTest) {
auto type = uuid();
ASSERT_EQ(type->id(), Type::EXTENSION);
Expand All @@ -42,4 +46,29 @@ TEST(TestUuuidExtensionType, ExtensionTypeTest) {
ASSERT_FALSE(deserialized->Equals(*fixed_size_binary(16)));
}

// Checks that a UUID extension column survives an IPC stream round trip in two
// forms: written as an extension array, and written as the plain storage array
// whose field carries the extension name/metadata keys.
TEST(TestUuuidExtensionType, RoundtripBatch) {
  auto ext_type = extension::uuid();
  auto exact_ext_type = internal::checked_pointer_cast<extension::UuidType>(ext_type);
  auto arr = ArrayFromJSON(fixed_size_binary(16), R"(["abcdefghijklmnop", null])");
  auto ext_arr = ExtensionType::WrapArray(ext_type, arr);

  // Pass extension array, expect getting back extension array
  std::shared_ptr<RecordBatch> read_batch;
  auto ext_field = field(/*name=*/"f0", /*type=*/ext_type);
  auto batch = RecordBatch::Make(schema({ext_field}), ext_arr->length(), {ext_arr});
  // RoundtripBatch reports errors with fatal gtest assertions, which only abort
  // the helper itself. Propagate them here so that a failed round trip stops
  // the test instead of dereferencing a null batch below.
  ASSERT_NO_FATAL_FAILURE(RoundtripBatch(batch, &read_batch));
  ASSERT_NE(read_batch, nullptr);
  CompareBatch(*batch, *read_batch, /*compare_metadata=*/true);

  // Pass extension metadata and storage array, expect getting back extension array
  std::shared_ptr<RecordBatch> read_batch2;
  auto ext_metadata =
      key_value_metadata({{"ARROW:extension:name", exact_ext_type->extension_name()},
                          {"ARROW:extension:metadata", ""}});
  ext_field = field(/*name=*/"f0", /*type=*/exact_ext_type->storage_type(),
                    /*nullable=*/true, /*metadata=*/ext_metadata);
  auto batch2 = RecordBatch::Make(schema({ext_field}), arr->length(), {arr});
  ASSERT_NO_FATAL_FAILURE(RoundtripBatch(batch2, &read_batch2));
  ASSERT_NE(read_batch2, nullptr);
  // The expected result is the extension-typed batch built above.
  CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true);
}

} // namespace arrow
6 changes: 2 additions & 4 deletions cpp/src/arrow/extension_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,9 @@ namespace internal {
static void CreateGlobalRegistry() {
g_registry = std::make_shared<ExtensionTypeRegistryImpl>();

#ifdef ARROW_JSON
std::vector<std::shared_ptr<DataType>> ext_types{
extension::bool8(), extension::fixed_shape_tensor(int64(), {}), extension::uuid()};
#else
std::vector<std::shared_ptr<DataType>> ext_types{extension::bool8(), extension::uuid()};
#ifdef ARROW_JSON
ext_types.push_back(extension::fixed_shape_tensor(int64(), {}));
#endif

// Register canonical extension types
Expand Down
31 changes: 24 additions & 7 deletions cpp/src/arrow/ipc/test_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
#include "arrow/array.h"
#include "arrow/array/builder_binary.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/array/builder_time.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/test_common.h"
#include "arrow/ipc/writer.h"
#include "arrow/pretty_print.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"
Expand Down Expand Up @@ -242,11 +244,11 @@ Status MakeRandomBooleanArray(const int length, bool include_nulls,
std::shared_ptr<Array>* out) {
std::vector<uint8_t> values(length);
random_null_bytes(length, 0.5, values.data());
ARROW_ASSIGN_OR_RAISE(auto data, internal::BytesToBits(values));
ARROW_ASSIGN_OR_RAISE(auto data, arrow::internal::BytesToBits(values));

if (include_nulls) {
std::vector<uint8_t> valid_bytes(length);
ARROW_ASSIGN_OR_RAISE(auto null_bitmap, internal::BytesToBits(valid_bytes));
ARROW_ASSIGN_OR_RAISE(auto null_bitmap, arrow::internal::BytesToBits(valid_bytes));
random_null_bytes(length, 0.1, valid_bytes.data());
*out = std::make_shared<BooleanArray>(length, data, null_bitmap, -1);
} else {
Expand Down Expand Up @@ -596,7 +598,7 @@ Status MakeStruct(std::shared_ptr<RecordBatch>* out) {
std::shared_ptr<Array> no_nulls(new StructArray(type, list_batch->num_rows(), columns));
std::vector<uint8_t> null_bytes(list_batch->num_rows(), 1);
null_bytes[0] = 0;
ARROW_ASSIGN_OR_RAISE(auto null_bitmap, internal::BytesToBits(null_bytes));
ARROW_ASSIGN_OR_RAISE(auto null_bitmap, arrow::internal::BytesToBits(null_bytes));
std::shared_ptr<Array> with_nulls(
new StructArray(type, list_batch->num_rows(), columns, null_bitmap, 1));

Expand Down Expand Up @@ -1176,12 +1178,13 @@ enable_if_t<std::is_floating_point<CValueType>::value, void> FillRandomData(
Status MakeRandomTensor(const std::shared_ptr<DataType>& type,
const std::vector<int64_t>& shape, bool row_major_p,
std::shared_ptr<Tensor>* out, uint32_t seed) {
const auto& element_type = internal::checked_cast<const FixedWidthType&>(*type);
const auto& element_type = arrow::internal::checked_cast<const FixedWidthType&>(*type);
std::vector<int64_t> strides;
if (row_major_p) {
RETURN_NOT_OK(internal::ComputeRowMajorStrides(element_type, shape, &strides));
RETURN_NOT_OK(arrow::internal::ComputeRowMajorStrides(element_type, shape, &strides));
} else {
RETURN_NOT_OK(internal::ComputeColumnMajorStrides(element_type, shape, &strides));
RETURN_NOT_OK(
arrow::internal::ComputeColumnMajorStrides(element_type, shape, &strides));
}

const int64_t element_size = element_type.bit_width() / CHAR_BIT;
Expand Down Expand Up @@ -1233,6 +1236,20 @@ Status MakeRandomTensor(const std::shared_ptr<DataType>& type,
return Tensor::Make(type, buf, shape, strides).Value(out);
}

void RoundtripBatch(const std::shared_ptr<RecordBatch>& batch,
std::shared_ptr<RecordBatch>* out) {
ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
out_stream.get()));

ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());

io::BufferReader reader(complete_ipc_stream);
std::shared_ptr<RecordBatchReader> batch_reader;
ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
ASSERT_OK(batch_reader->ReadNext(out));
}

} // namespace test
} // namespace ipc
} // namespace arrow
3 changes: 3 additions & 0 deletions cpp/src/arrow/ipc/test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ Status MakeRandomTensor(const std::shared_ptr<DataType>& type,
const std::vector<int64_t>& shape, bool row_major_p,
std::shared_ptr<Tensor>* out, uint32_t seed = 0);

/// \brief Write `batch` to an IPC record batch stream and read the first batch
/// of that stream back into `out`.
///
/// Errors are reported through gtest assertions rather than a return value.
ARROW_TESTING_EXPORT void RoundtripBatch(const std::shared_ptr<RecordBatch>& batch,
std::shared_ptr<RecordBatch>* out);

} // namespace test
} // namespace ipc
} // namespace arrow
14 changes: 0 additions & 14 deletions cpp/src/arrow/testing/gtest_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -590,20 +590,6 @@ void ApproxCompareBatch(const RecordBatch& left, const RecordBatch& right,
});
}

// Writes `batch` as an IPC record batch stream into a BufferOutputStream, then
// reads the first record batch of the finished stream back into `*out`.
// Each step is guarded by a fatal gtest assertion.
void RoundtripBatch(const std::shared_ptr<RecordBatch>& batch,
std::shared_ptr<RecordBatch>* out) {
ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
// Write a stream containing just this one batch, using default IPC options.
ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
out_stream.get()));

// Finish() yields the buffer holding everything written to the stream.
ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());

// Reopen the buffer for reading and pull the first batch into `out`.
io::BufferReader reader(complete_ipc_stream);
std::shared_ptr<RecordBatchReader> batch_reader;
ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
ASSERT_OK(batch_reader->ReadNext(out));
}

std::shared_ptr<Array> TweakValidityBit(const std::shared_ptr<Array>& array,
int64_t index, bool validity) {
auto data = array->data()->Copy();
Expand Down
3 changes: 0 additions & 3 deletions cpp/src/arrow/testing/gtest_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,6 @@ ARROW_TESTING_EXPORT void ApproxCompareBatch(
const RecordBatch& left, const RecordBatch& right, bool compare_metadata = true,
const EqualOptions& options = TestingEqualOptions());

/// \brief Write `batch` to an IPC record batch stream and read the first batch
/// of that stream back into `out`.
///
/// Errors are reported through gtest assertions rather than a return value.
ARROW_TESTING_EXPORT void RoundtripBatch(const std::shared_ptr<RecordBatch>& batch,
std::shared_ptr<RecordBatch>* out);

// Check if the padding of the buffers of the array is zero.
// Also cause valgrind warnings if the padding bytes are uninitialized.
ARROW_TESTING_EXPORT void AssertZeroPadded(const Array& array);
Expand Down
6 changes: 0 additions & 6 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2871,12 +2871,6 @@ cdef extern from "arrow/extension/uuid.h" namespace "arrow::extension":
@staticmethod
CResult[shared_ptr[CDataType]] Make()

CResult[shared_ptr[CDataType]] Deserialize(const c_string& serialized_data) const

c_string Serialize() const

const shared_ptr[CDataType] value_type()


cdef extern from "arrow/extension/fixed_shape_tensor.h" namespace "arrow::extension" nogil:
cdef cppclass CFixedShapeTensorType \
Expand Down
9 changes: 5 additions & 4 deletions python/pyarrow/public-api.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,16 @@ cdef api object pyarrow_wrap_data_type(
elif type.get().id() == _Type_EXTENSION:
ext_type = <const CExtensionType*> type.get()
cpy_ext_type = dynamic_cast[_CPyExtensionTypePtr](ext_type)
extension_name = ext_type.extension_name()
if cpy_ext_type != nullptr:
return cpy_ext_type.GetInstance()
elif ext_type.extension_name() == b"arrow.bool8":
elif extension_name == b"arrow.bool8":
out = Bool8Type.__new__(Bool8Type)
elif ext_type.extension_name() == b"arrow.fixed_shape_tensor":
elif extension_name == b"arrow.fixed_shape_tensor":
out = FixedShapeTensorType.__new__(FixedShapeTensorType)
elif ext_type.extension_name() == b"arrow.opaque":
elif extension_name == b"arrow.opaque":
out = OpaqueType.__new__(OpaqueType)
elif ext_type.extension_name() == b"arrow.uuid":
elif extension_name == b"arrow.uuid":
out = UuidType.__new__(UuidType)
else:
out = BaseExtensionType.__new__(BaseExtensionType)
Expand Down
14 changes: 12 additions & 2 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,14 @@ def test_ext_type_repr():
assert repr(ty) == "IntegerType(DataType(int64))"


def test_ext_type__lifetime():
def test_ext_type_lifetime():
ty = ExampleUuidType()
wr = weakref.ref(ty)
del ty
assert wr() is None


def test_ext_type__storage_type():
def test_ext_type_storage_type():
ty = ExampleUuidType()
assert ty.storage_type == pa.binary(16)
assert ty.__class__ is ExampleUuidType
Expand Down Expand Up @@ -354,6 +354,16 @@ def test_uuid_type_pickle(pickle_module):
del ty
assert wr() is None

for proto in range(0, pickle_module.HIGHEST_PROTOCOL + 1):
ty = pa.uuid()
ser = pickle_module.dumps(ty, protocol=proto)
del ty
ty = pickle_module.loads(ser)
wr = weakref.ref(ty)
assert ty.extension_name == "arrow.uuid"
del ty
assert wr() is None


def test_ext_type_equality():
a = ParamExtType(5)
Expand Down
13 changes: 0 additions & 13 deletions python/pyarrow/types.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1774,19 +1774,6 @@ cdef class UuidType(BaseExtensionType):
BaseExtensionType.init(self, type)
self.uuid_ext_type = <const CUuidType*> type.get()

def __arrow_ext_serialize__(self):
"""
Serialized representation of metadata to reconstruct the type object.
"""
return self.uuid_ext_type.Serialize()

@classmethod
def __arrow_ext_deserialize__(self, storage_type, serialized):
"""
Return an UuidType instance from the storage type.
"""
return self.uuid_ext_type.Deserialize(storage_type, serialized)

def __arrow_ext_class__(self):
return UuidArray

Expand Down

0 comments on commit 1b23c10

Please sign in to comment.