diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 88d72b11832..2396a5e3a1e 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -425,6 +425,7 @@ if(ARROW_COMPUTE) compute/kernels/scalar_boolean.cc compute/kernels/scalar_cast_boolean.cc compute/kernels/scalar_cast_dictionary.cc + compute/kernels/scalar_cast_extension.cc compute/kernels/scalar_cast_internal.cc compute/kernels/scalar_cast_nested.cc compute/kernels/scalar_cast_numeric.cc diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 52aecf3e45a..2bfc963b082 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -62,6 +62,7 @@ void InitCastTable() { AddCastFunctions(GetNumericCasts()); AddCastFunctions(GetTemporalCasts()); AddCastFunctions(GetDictionaryCasts()); + AddCastFunctions(GetExtensionCasts()); } void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); } diff --git a/cpp/src/arrow/compute/cast_internal.h b/cpp/src/arrow/compute/cast_internal.h index f00a6cdbf4d..423b791e6a7 100644 --- a/cpp/src/arrow/compute/cast_internal.h +++ b/cpp/src/arrow/compute/cast_internal.h @@ -63,6 +63,7 @@ std::vector> GetTemporalCasts(); std::vector> GetBinaryLikeCasts(); std::vector> GetNestedCasts(); std::vector> GetDictionaryCasts(); +std::vector> GetExtensionCasts(); ARROW_EXPORT Result> GetCastFunction(const DataType& to_type); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc new file mode 100644 index 00000000000..d2e2ab72f00 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Implementation of casting to extension types +#include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/scalar_cast_internal.h" +#include "arrow/scalar.h" + +namespace arrow { +namespace compute { +namespace internal { + +namespace { +Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + const CastOptions& options = checked_cast(ctx->state())->options; + const auto& ext_ty = checked_cast(*options.to_type.type); + auto out_ty = ext_ty.storage_type(); + + DCHECK(batch[0].is_array()); + std::shared_ptr array = batch[0].array.ToArray(); + + // An extension input is accepted only when its type is exactly the target's + // storage type; otherwise fail and point the user to the two-step cast. + std::shared_ptr result; + if (array->type()->id() == Type::EXTENSION) { + if (!array->type()->Equals(out_ty)) { + return Status::TypeError("Casting from '" + array->type()->ToString() + + "' to different extension type '" + ext_ty.ToString() + + "' not permitted. 
One can first cast to the storage " + "type, then to the extension type."); + } + result = array; + } else { + ARROW_ASSIGN_OR_RAISE(result, Cast(*array, out_ty, options, ctx->exec_context())); + } + + ExtensionArray extension(options.to_type.GetSharedPtr(), result); + out->value = extension.data(); + return Status::OK(); +} + +std::shared_ptr GetCastToExtension(std::string name) { + auto func = std::make_shared(std::move(name), Type::EXTENSION); + for (Type::type in_ty : AllTypeIds()) { + DCHECK_OK( + func->AddKernel(in_ty, {InputType(in_ty)}, kOutputTargetType, CastToExtension)); + } + return func; +} + +} // namespace + +std::vector> GetExtensionCasts() { + auto func = GetCastToExtension("cast_extension"); + return {func}; +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 963748c9f97..6b172eaa140 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -225,7 +225,8 @@ TEST(Cast, CanCast) { ExpectCanCast(smallint(), {int16()}); // cast storage ExpectCanCast(smallint(), kNumericTypes); // any cast which is valid for storage is supported - ExpectCannotCast(null(), {smallint()}); // FIXME missing common cast from null + ExpectCanCast(null(), {smallint()}); + ExpectCanCast(tinyint(), {smallint()}); // cast between compatible storage types ExpectCanCast(date32(), {utf8(), large_utf8()}); ExpectCanCast(date64(), {utf8(), large_utf8()}); @@ -2728,6 +2729,13 @@ std::shared_ptr SmallintArrayFromJSON(const std::string& json_data) { return MakeArray(ext_data); } +std::shared_ptr TinyintArrayFromJSON(const std::string& json_data) { + auto arr = ArrayFromJSON(int8(), json_data); + auto ext_data = arr->data()->Copy(); + ext_data->type = tinyint(); + return MakeArray(ext_data); +} + TEST(Cast, ExtensionTypeToIntDowncast) { auto smallint = 
std::make_shared(); ExtensionTypeGuard smallint_guard(smallint); @@ -2765,6 +2773,68 @@ TEST(Cast, ExtensionTypeToIntDowncast) { } } +TEST(Cast, PrimitiveToExtension) { + { + auto primitive_array = ArrayFromJSON(uint8(), "[0, 1, 3]"); + auto extension_array = SmallintArrayFromJSON("[0, 1, 3]"); + CastOptions options; + options.to_type = smallint(); + CheckCast(primitive_array, extension_array, options); + } + { + CastOptions options; + options.to_type = smallint(); + CheckCastFails(ArrayFromJSON(utf8(), "[\"hello\"]"), options); + } +} + +TEST(Cast, ExtensionDictToExtension) { + auto extension_array = SmallintArrayFromJSON("[1, 2, 1]"); + auto indices_array = ArrayFromJSON(int32(), "[0, 1, 0]"); + + ASSERT_OK_AND_ASSIGN(auto dict_array, + DictionaryArray::FromArrays(indices_array, extension_array)); + + CastOptions options; + options.to_type = smallint(); + CheckCast(dict_array, extension_array, options); +} + +TEST(Cast, IntToExtensionTypeDowncast) { + CheckCast(ArrayFromJSON(uint8(), "[0, 100, 200, 1, 2]"), + SmallintArrayFromJSON("[0, 100, 200, 1, 2]")); + + // int32 to Smallint(int16): 32768 overflows; wraps to -32768 if overflow is allowed + { + CastOptions options; + options.to_type = smallint(); + CheckCastFails(ArrayFromJSON(int32(), "[0, null, 32768, 1, 3]"), options); + + options.allow_int_overflow = true; + CheckCast(ArrayFromJSON(int32(), "[0, null, 32768, 1, 3]"), + SmallintArrayFromJSON("[0, null, -32768, 1, 3]"), options); + } + + // int32 to Smallint(int16): -32769 underflows; wraps to 32767 if overflow is allowed + { + CastOptions options; + options.to_type = smallint(); + CheckCastFails(ArrayFromJSON(int32(), "[0, null, -32769, 1, 3]"), options); + + options.allow_int_overflow = true; + CheckCast(ArrayFromJSON(int32(), "[0, null, -32769, 1, 3]"), + SmallintArrayFromJSON("[0, null, 32767, 1, 3]"), options); + } + + // Casting a tinyint (int8 storage) extension array to smallint must fail + { + CastOptions options; + options.to_type = smallint(); + auto tiny_array = TinyintArrayFromJSON("[0, 1, 3]"); + ASSERT_NOT_OK(Cast(tiny_array, 
smallint(), options)); + } +} + TEST(Cast, DictTypeToAnotherDict) { auto check_cast = [&](const std::shared_ptr& in_type, const std::shared_ptr& out_type, diff --git a/cpp/src/arrow/testing/extension_type.h b/cpp/src/arrow/testing/extension_type.h index 338b4cb4da0..846e3c7a165 100644 --- a/cpp/src/arrow/testing/extension_type.h +++ b/cpp/src/arrow/testing/extension_type.h @@ -54,6 +54,11 @@ class ARROW_TESTING_EXPORT SmallintArray : public ExtensionArray { using ExtensionArray::ExtensionArray; }; +class ARROW_TESTING_EXPORT TinyintArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + class ARROW_TESTING_EXPORT ListExtensionArray : public ExtensionArray { public: using ExtensionArray::ExtensionArray; @@ -76,6 +81,23 @@ class ARROW_TESTING_EXPORT SmallintType : public ExtensionType { std::string Serialize() const override { return "smallint"; } }; +class ARROW_TESTING_EXPORT TinyintType : public ExtensionType { + public: + TinyintType() : ExtensionType(int8()) {} + + std::string extension_name() const override { return "tinyint"; } + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized) const override; + + std::string Serialize() const override { return "tinyint"; } +}; + class ARROW_TESTING_EXPORT ListExtensionType : public ExtensionType { public: ListExtensionType() : ExtensionType(list(int32())) {} @@ -140,6 +162,9 @@ std::shared_ptr uuid(); ARROW_TESTING_EXPORT std::shared_ptr smallint(); +ARROW_TESTING_EXPORT +std::shared_ptr tinyint(); + ARROW_TESTING_EXPORT std::shared_ptr list_extension_type(); @@ -155,6 +180,9 @@ std::shared_ptr ExampleUuid(); ARROW_TESTING_EXPORT std::shared_ptr ExampleSmallint(); +ARROW_TESTING_EXPORT +std::shared_ptr ExampleTinyint(); + ARROW_TESTING_EXPORT std::shared_ptr ExampleDictExtension(); diff --git 
a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index a4d86708800..18f43da72a3 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -64,47 +64,6 @@ namespace arrow { using internal::checked_cast; using internal::checked_pointer_cast; -std::vector AllTypeIds() { - return {Type::NA, - Type::BOOL, - Type::INT8, - Type::INT16, - Type::INT32, - Type::INT64, - Type::UINT8, - Type::UINT16, - Type::UINT32, - Type::UINT64, - Type::HALF_FLOAT, - Type::FLOAT, - Type::DOUBLE, - Type::DECIMAL128, - Type::DECIMAL256, - Type::DATE32, - Type::DATE64, - Type::TIME32, - Type::TIME64, - Type::TIMESTAMP, - Type::INTERVAL_DAY_TIME, - Type::INTERVAL_MONTHS, - Type::DURATION, - Type::STRING, - Type::BINARY, - Type::LARGE_STRING, - Type::LARGE_BINARY, - Type::FIXED_SIZE_BINARY, - Type::STRUCT, - Type::LIST, - Type::LARGE_LIST, - Type::FIXED_SIZE_LIST, - Type::MAP, - Type::DENSE_UNION, - Type::SPARSE_UNION, - Type::DICTIONARY, - Type::EXTENSION, - Type::INTERVAL_MONTH_DAY_NANO}; -} - template void AssertTsSame(const T& expected, const T& actual, CompareFunctor&& compare) { if (!compare(actual, expected)) { @@ -832,6 +791,28 @@ Result> SmallintType::Deserialize( return std::make_shared(); } +bool TinyintType::ExtensionEquals(const ExtensionType& other) const { + return (other.extension_name() == this->extension_name()); +} + +std::shared_ptr TinyintType::MakeArray(std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("tinyint", static_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +Result> TinyintType::Deserialize( + std::shared_ptr storage_type, const std::string& serialized) const { + if (serialized != "tinyint") { + return Status::Invalid("Type identifier did not match: '", serialized, "'"); + } + if (!storage_type->Equals(*int8())) { + return Status::Invalid("Invalid storage type for TinyintType: ", + storage_type->ToString()); + } + return 
std::make_shared(); +} + bool ListExtensionType::ExtensionEquals(const ExtensionType& other) const { return (other.extension_name() == this->extension_name()); } @@ -905,6 +886,8 @@ std::shared_ptr uuid() { return std::make_shared(); } std::shared_ptr smallint() { return std::make_shared(); } +std::shared_ptr tinyint() { return std::make_shared(); } + std::shared_ptr list_extension_type() { return std::make_shared(); } @@ -936,6 +919,11 @@ std::shared_ptr ExampleSmallint() { return ExtensionType::WrapArray(smallint(), arr); } +std::shared_ptr ExampleTinyint() { + auto arr = ArrayFromJSON(int8(), "[-128, null, 1, 2, 3, 4, 127]"); + return ExtensionType::WrapArray(tinyint(), arr); +} + std::shared_ptr ExampleDictExtension() { auto arr = DictArrayFromJSON(dictionary(int8(), utf8()), "[0, 1, null, 1]", R"(["foo", "bar"])"); diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 8ce5049452a..1408042d994 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -190,9 +190,6 @@ class RecordBatch; class Table; struct Datum; -ARROW_TESTING_EXPORT -std::vector AllTypeIds(); - #define ASSERT_ARRAYS_EQUAL(lhs, rhs) AssertArraysEqual((lhs), (rhs)) #define ASSERT_BATCHES_EQUAL(lhs, rhs) AssertBatchesEqual((lhs), (rhs)) #define ASSERT_BATCHES_APPROX_EQUAL(lhs, rhs) AssertBatchesApproxEqual((lhs), (rhs)) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index efff07db667..a3285cf92f5 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -95,6 +95,47 @@ constexpr Type::type DurationType::type_id; constexpr Type::type DictionaryType::type_id; +std::vector AllTypeIds() { + return {Type::NA, + Type::BOOL, + Type::INT8, + Type::INT16, + Type::INT32, + Type::INT64, + Type::UINT8, + Type::UINT16, + Type::UINT32, + Type::UINT64, + Type::HALF_FLOAT, + Type::FLOAT, + Type::DOUBLE, + Type::DECIMAL128, + Type::DECIMAL256, + Type::DATE32, + Type::DATE64, + Type::TIME32, + Type::TIME64, + 
Type::TIMESTAMP, + Type::INTERVAL_DAY_TIME, + Type::INTERVAL_MONTHS, + Type::DURATION, + Type::STRING, + Type::BINARY, + Type::LARGE_STRING, + Type::LARGE_BINARY, + Type::FIXED_SIZE_BINARY, + Type::STRUCT, + Type::LIST, + Type::LARGE_LIST, + Type::FIXED_SIZE_LIST, + Type::MAP, + Type::DENSE_UNION, + Type::SPARSE_UNION, + Type::DICTIONARY, + Type::EXTENSION, + Type::INTERVAL_MONTH_DAY_NANO}; +} + namespace internal { struct TypeIdToTypeNameVisitor { diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 84a50a12eb3..e2bace974e2 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -410,6 +410,9 @@ struct Type { }; }; +/// \brief Get a vector of all type ids +ARROW_EXPORT std::vector AllTypeIds(); + /// \defgroup type-factories Factory functions for creating data types /// /// Factory functions for creating data types diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index a354f42a4b1..2e3c6ab9de8 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1377,6 +1377,8 @@ null input value is converted into a null output value. +-----------------------------+------------------------------------+---------+ | Null | Any | | +-----------------------------+------------------------------------+---------+ +| Any | Extension | \(3) | ++-----------------------------+------------------------------------+---------+ * \(1) The dictionary indices are unchanged, the dictionary values are cast from the input value type to the output value type (if a conversion @@ -1386,6 +1388,9 @@ null input value is converted into a null output value. input value type to the output value type (if a conversion is available). +* \(3) Any input type that can be cast to the resulting extension's storage type. + This excludes extension types, unless being cast to the same extension type. 
+ Temporal component extraction ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 9c5a394f895..926790cfe06 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -25,6 +25,15 @@ import pytest +class TinyIntType(pa.PyExtensionType): + + def __init__(self): + pa.PyExtensionType.__init__(self, pa.int8()) + + def __reduce__(self): + return TinyIntType, () + + class IntegerType(pa.PyExtensionType): def __init__(self): @@ -34,6 +43,15 @@ def __reduce__(self): return IntegerType, () +class IntegerEmbeddedType(pa.PyExtensionType): + + def __init__(self): + pa.PyExtensionType.__init__(self, IntegerType()) + + def __reduce__(self): + return IntegerEmbeddedType, () + + class UuidScalarType(pa.ExtensionScalar): def as_py(self): return None if self.value is None else UUID(bytes=self.value.as_py()) @@ -57,7 +75,7 @@ def __init__(self): pa.PyExtensionType.__init__(self, pa.binary(16)) def __reduce__(self): - return UuidType, () + return UuidType2, () class ParamExtType(pa.PyExtensionType): @@ -517,10 +535,85 @@ def test_cast_kernel_on_extension_arrays(): assert isinstance(casted, pa.ChunkedArray) -def test_casting_to_extension_type_raises(): - arr = pa.array([1, 2, 3, 4], pa.int64()) - with pytest.raises(pa.ArrowNotImplementedError): - arr.cast(IntegerType()) +@pytest.mark.parametrize("data,ty", ( + ([1, 2], pa.int32), + ([1, 2], pa.int64), + (["1", "2"], pa.string), + ([b"1", b"2"], pa.binary), + ([1.0, 2.0], pa.float32), + ([1.0, 2.0], pa.float64) +)) +def test_casting_to_extension_type(data, ty): + arr = pa.array(data, ty()) + out = arr.cast(IntegerType()) + assert isinstance(out, pa.ExtensionArray) + assert out.type == IntegerType() + assert out.to_pylist() == [1, 2] + + +def test_cast_between_extension_types(): + array = pa.array([1, 2, 3], pa.int8()) + + tiny_int_arr = array.cast(TinyIntType()) + assert tiny_int_arr.type == 
TinyIntType() + + # Casting between extension types w/ different storage types not okay. + msg = ("Casting from 'extension>' " + "to different extension type " + "'extension>' not permitted. " + "One can first cast to the storage type, " + "then to the extension type." + ) + with pytest.raises(TypeError, match=msg): + tiny_int_arr.cast(IntegerType()) + tiny_int_arr.cast(pa.int64()).cast(IntegerType()) + + # Between the same extension types is okay + array = pa.array([b'1' * 16, b'2' * 16], pa.binary(16)).cast(UuidType()) + out = array.cast(UuidType()) + assert out.type == UuidType() + + # Will still fail casting between extensions who share storage type, + # can only cast between exactly the same extension types. + with pytest.raises(TypeError, match='Casting from '): + array.cast(UuidType2()) + + +def test_cast_to_extension_with_extension_storage(): + # Test casting directly, and IntegerType -> IntegerEmbeddedType + array = pa.array([1, 2, 3], pa.int64()) + array.cast(IntegerEmbeddedType()) + array.cast(IntegerType()).cast(IntegerEmbeddedType()) + + +@pytest.mark.parametrize("data,type_factory", ( + # list + ([[1, 2, 3]], lambda: pa.list_(IntegerType())), + # struct + ([{"foo": 1}], lambda: pa.struct([("foo", IntegerType())])), + # list> + ([[{"foo": 1}]], lambda: pa.list_(pa.struct([("foo", IntegerType())]))), + # struct> + ([{"foo": [1, 2, 3]}], lambda: pa.struct( + [("foo", pa.list_(IntegerType()))])), +)) +def test_cast_nested_extension_types(data, type_factory): + ty = type_factory() + a = pa.array(data) + b = a.cast(ty) + assert b.type == ty # casted to target extension + assert b.cast(a.type) # and can cast back + + +def test_casting_dict_array_to_extension_type(): + storage = pa.array([b"0123456789abcdef"], type=pa.binary(16)) + arr = pa.ExtensionArray.from_storage(UuidType(), storage) + dict_arr = pa.DictionaryArray.from_arrays(pa.array([0, 0], pa.int32()), + arr) + out = dict_arr.cast(UuidType()) + assert isinstance(out, pa.ExtensionArray) + assert 
out.to_pylist() == [UUID('30313233-3435-3637-3839-616263646566'), + UUID('30313233-3435-3637-3839-616263646566')] def test_null_storage_type():