Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 64 additions & 5 deletions cpp/src/arrow/compute/kernels/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -1127,6 +1128,62 @@ class ZeroCopyCast : public CastKernelBase {
}
};

class ExtensionCastKernel : public CastKernelBase {
public:
static Status Make(const DataType& in_type, std::shared_ptr<DataType> out_type,
const CastOptions& options,
std::unique_ptr<CastKernelBase>* kernel) {
const auto storage_type = checked_cast<const ExtensionType&>(in_type).storage_type();

std::unique_ptr<UnaryKernel> storage_caster;
RETURN_NOT_OK(GetCastFunction(*storage_type, out_type, options, &storage_caster));
kernel->reset(
new ExtensionCastKernel(std::move(storage_caster), std::move(out_type)));

return Status::OK();
}

Status Init(const DataType& in_type) override {
auto& type = checked_cast<const ExtensionType&>(in_type);
storage_type_ = type.storage_type();
extension_name_ = type.extension_name();
return Status::OK();
}

Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override {
DCHECK_EQ(input.kind(), Datum::ARRAY);

// validate: type is the same as the type the kernel was constructed with
const auto& input_type = checked_cast<const ExtensionType&>(*input.type());
if (input_type.extension_name() != extension_name_) {
return Status::TypeError(
"The cast kernel was constructed to cast from the extension type named '",
extension_name_, "' but input has extension type named '",
input_type.extension_name(), "'");
}
if (!input_type.storage_type()->Equals(storage_type_)) {
return Status::TypeError("The cast kernel was constructed with a storage type: ",
storage_type_->ToString(),
", but it is called with a different storage type:",
input_type.storage_type()->ToString());
}

// construct an ArrayData object with the underlying storage type
auto new_input = input.array()->Copy();
new_input->type = storage_type_;
return InvokeWithAllocation(ctx, storage_caster_.get(), new_input, out);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it allocate if the out_type and storage_type are the same?

Copy link
Member Author

@kszucs kszucs Mar 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added tests on both C++ and Python side, seems like no allocation happens.

}

protected:
ExtensionCastKernel(std::unique_ptr<UnaryKernel> storage_caster,
std::shared_ptr<DataType> out_type)
: CastKernelBase(std::move(out_type)), storage_caster_(std::move(storage_caster)) {}

std::string extension_name_;
std::shared_ptr<DataType> storage_type_;
std::unique_ptr<UnaryKernel> storage_caster_;
};

class CastKernel : public CastKernelBase {
public:
CastKernel(const CastOptions& options, const CastFunction& func,
Expand Down Expand Up @@ -1275,11 +1332,6 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> out_ty
return Status::OK();
}

if (in_type.id() == Type::NA) {
kernel->reset(new FromNullCastKernel(std::move(out_type)));
return Status::OK();
}

std::unique_ptr<CastKernelBase> cast_kernel;
switch (in_type.id()) {
CAST_FUNCTION_CASE(BooleanType);
Expand All @@ -1304,6 +1356,9 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> out_ty
CAST_FUNCTION_CASE(LargeBinaryType);
CAST_FUNCTION_CASE(LargeStringType);
CAST_FUNCTION_CASE(DictionaryType);
case Type::NA:
cast_kernel.reset(new FromNullCastKernel(std::move(out_type)));
break;
case Type::LIST:
RETURN_NOT_OK(
GetListCastFunc<ListType>(in_type, std::move(out_type), options, &cast_kernel));
Expand All @@ -1312,6 +1367,10 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> out_ty
RETURN_NOT_OK(GetListCastFunc<LargeListType>(in_type, std::move(out_type), options,
&cast_kernel));
break;
case Type::EXTENSION:
RETURN_NOT_OK(ExtensionCastKernel::Make(std::move(in_type), std::move(out_type),
options, &cast_kernel));
break;
default:
break;
}
Expand Down
51 changes: 51 additions & 0 deletions cpp/src/arrow/compute/kernels/cast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@

#include "arrow/array.h"
#include "arrow/buffer.h"
#include "arrow/extension_type.h"
#include "arrow/memory_pool.h"
#include "arrow/status.h"
#include "arrow/table.h"
#include "arrow/testing/extension_type.h"
#include "arrow/testing/gtest_common.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
Expand Down Expand Up @@ -1480,5 +1482,54 @@ TYPED_TEST(TestDictionaryCast, OutTypeError) {
this->CheckPass(*plain_array, *dict_array, dict_array->type(), options);
}*/

std::shared_ptr<Array> SmallintArrayFromJSON(const std::string& json_data) {
auto arr = ArrayFromJSON(int16(), json_data);
auto ext_data = arr->data()->Copy();
ext_data->type = smallint();
return MakeArray(ext_data);
}

TEST_F(TestCast, ExtensionTypeToIntDowncast) {
auto smallint = std::make_shared<SmallintType>();
ASSERT_OK(RegisterExtensionType(smallint));

CastOptions options;
options.allow_int_overflow = false;

std::shared_ptr<Array> result;
std::vector<bool> is_valid = {true, false, true, true, true};

// Smallint(int16) to int16
auto v0 = SmallintArrayFromJSON("[0, 100, 200, 1, 2]");
CheckZeroCopy(*v0, int16());

// Smallint(int16) to uint8, no overflow/underrun
auto v1 = SmallintArrayFromJSON("[0, 100, 200, 1, 2]");
auto e1 = ArrayFromJSON(uint8(), "[0, 100, 200, 1, 2]");
CheckPass(*v1, *e1, uint8(), options);

// Smallint(int16) to uint8, with overflow
auto v2 = SmallintArrayFromJSON("[0, null, 256, 1, 3]");
auto e2 = ArrayFromJSON(uint8(), "[0, null, 0, 1, 3]");
// allow overflow
options.allow_int_overflow = true;
CheckPass(*v2, *e2, uint8(), options);
// disallow overflow
options.allow_int_overflow = false;
ASSERT_RAISES(Invalid, Cast(&ctx_, *v2, uint8(), options, &result));

// Smallint(int16) to uint8, with underflow
auto v3 = SmallintArrayFromJSON("[0, null, -1, 1, 0]");
auto e3 = ArrayFromJSON(uint8(), "[0, null, 255, 1, 0]");
// allow overflow
options.allow_int_overflow = true;
CheckPass(*v3, *e3, uint8(), options);
// disallow overflow
options.allow_int_overflow = false;
ASSERT_RAISES(Invalid, Cast(&ctx_, *v3, uint8(), options, &result));

ASSERT_OK(UnregisterExtensionType("smallint"));
}

} // namespace compute
} // namespace arrow
28 changes: 28 additions & 0 deletions cpp/src/arrow/testing/extension_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,38 @@ class ARROW_EXPORT UUIDType : public ExtensionType {
std::string Serialize() const override { return "uuid-type-unique-code"; }
};

class ARROW_EXPORT SmallintArray : public ExtensionArray {
public:
using ExtensionArray::ExtensionArray;
};

class ARROW_EXPORT SmallintType : public ExtensionType {
public:
SmallintType() : ExtensionType(int16()) {}

std::string extension_name() const override { return "smallint"; }

bool ExtensionEquals(const ExtensionType& other) const override;

std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

Status Deserialize(std::shared_ptr<DataType> storage_type,
const std::string& serialized,
std::shared_ptr<DataType>* out) const override;

std::string Serialize() const override { return "smallint"; }
};

ARROW_EXPORT
std::shared_ptr<DataType> uuid();

ARROW_EXPORT
std::shared_ptr<DataType> smallint();

ARROW_EXPORT
std::shared_ptr<Array> ExampleUUID();

ARROW_EXPORT
std::shared_ptr<Array> ExampleSmallint();

} // namespace arrow
40 changes: 35 additions & 5 deletions cpp/src/arrow/testing/gtest_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,11 +382,7 @@ void SleepFor(double seconds) {
// Extension types

bool UUIDType::ExtensionEquals(const ExtensionType& other) const {
const auto& other_ext = static_cast<const ExtensionType&>(other);
if (other_ext.extension_name() != this->extension_name()) {
return false;
}
return true;
return (other.extension_name() == this->extension_name());
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this should be the default implementation of ExtensionEquals

Copy link
Member Author

@kszucs kszucs Mar 18, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name and the storage type should be equal, created a follow-up JIRA https://issues.apache.org/jira/browse/ARROW-8143


std::shared_ptr<Array> UUIDType::MakeArray(std::shared_ptr<ArrayData> data) const {
Expand Down Expand Up @@ -423,4 +419,38 @@ std::shared_ptr<Array> ExampleUUID() {
return MakeArray(ext_data);
}

bool SmallintType::ExtensionEquals(const ExtensionType& other) const {
return (other.extension_name() == this->extension_name());
}

std::shared_ptr<Array> SmallintType::MakeArray(std::shared_ptr<ArrayData> data) const {
DCHECK_EQ(data->type->id(), Type::EXTENSION);
DCHECK_EQ("smallint", static_cast<const ExtensionType&>(*data->type).extension_name());
return std::make_shared<SmallintArray>(data);
}

Status SmallintType::Deserialize(std::shared_ptr<DataType> storage_type,
const std::string& serialized,
std::shared_ptr<DataType>* out) const {
if (serialized != "smallint") {
return Status::Invalid("Type identifier did not match");
}
if (!storage_type->Equals(*int16())) {
return Status::Invalid("Invalid storage type for SmallintType");
}
*out = std::make_shared<SmallintType>();
return Status::OK();
}

std::shared_ptr<DataType> smallint() { return std::make_shared<SmallintType>(); }

std::shared_ptr<Array> ExampleSmallint() {
auto storage_type = int16();
auto ext_type = smallint();
auto arr = ArrayFromJSON(storage_type, "[-32768, null, 1, 2, 3, 4, 32767]");
auto ext_data = arr->data()->Copy();
ext_data->type = ext_type;
return MakeArray(ext_data);
}

} // namespace arrow
45 changes: 45 additions & 0 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@
import pytest


class IntegerType(pa.PyExtensionType):

def __init__(self):
pa.PyExtensionType.__init__(self, pa.int64())

def __reduce__(self):
return IntegerType, ()


class UuidType(pa.PyExtensionType):

def __init__(self):
Expand Down Expand Up @@ -168,6 +177,42 @@ def test_ext_array_pickling():
assert arr.storage.to_pylist() == [b"foo", b"bar"]


def test_cast_kernel_on_extension_arrays():
# test array casting
storage = pa.array([1, 2, 3, 4], pa.int64())
arr = pa.ExtensionArray.from_storage(IntegerType(), storage)

# test that no allocation happens during identity cast
allocated_before_cast = pa.total_allocated_bytes()
casted = arr.cast(pa.int64())
assert pa.total_allocated_bytes() == allocated_before_cast

cases = [
(pa.int64(), pa.Int64Array),
(pa.int32(), pa.Int32Array),
(pa.int16(), pa.Int16Array),
(pa.uint64(), pa.UInt64Array),
(pa.uint32(), pa.UInt32Array),
(pa.uint16(), pa.UInt16Array)
]
for typ, klass in cases:
casted = arr.cast(typ)
assert casted.type == typ
assert isinstance(casted, klass)

# test chunked array casting
arr = pa.chunked_array([arr, arr])
casted = arr.cast(pa.int16())
assert casted.type == pa.int16()
assert isinstance(casted, pa.ChunkedArray)


def test_casting_to_extension_type_raises():
arr = pa.array([1, 2, 3, 4], pa.int64())
with pytest.raises(pa.ArrowNotImplementedError):
arr.cast(IntegerType())


def example_batch():
ty = ParamExtType(3)
storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
Expand Down