diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 193ba2837b9..f3a29be5e43 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -5284,12 +5284,9 @@ GArrowCastOptions * garrow_cast_options_new_raw(const arrow::compute::CastOptions *arrow_options) { GArrowDataType *to_data_type = NULL; - if (arrow_options->to_type) { - auto arrow_copied_options = arrow_options->Copy(); - auto arrow_copied_cast_options = - static_cast(arrow_copied_options.get()); - to_data_type = - garrow_data_type_new_raw(&(arrow_copied_cast_options->to_type)); + if (arrow_options->to_type.type) { + auto arrow_to_data_type = arrow_options->to_type.GetSharedPtr(); + to_data_type = garrow_data_type_new_raw(&arrow_to_data_type); } auto options = g_object_new(GARROW_TYPE_CAST_OPTIONS, diff --git a/c_glib/arrow-glib/scalar.cpp b/c_glib/arrow-glib/scalar.cpp index cef11578e1c..f8699f34eea 100644 --- a/c_glib/arrow-glib/scalar.cpp +++ b/c_glib/arrow-glib/scalar.cpp @@ -2401,9 +2401,31 @@ garrow_sparse_union_scalar_new(GArrowSparseUnionDataType *data_type, gint8 type_code, GArrowScalar *value) { - return GARROW_SPARSE_UNION_SCALAR( - garrow_union_scalar_new( - GARROW_DATA_TYPE(data_type), type_code, value)); + auto arrow_data_type = garrow_data_type_get_raw(GARROW_DATA_TYPE(data_type)); + const auto &arrow_type_codes = + std::dynamic_pointer_cast( + arrow_data_type)->type_codes(); + auto arrow_value = garrow_scalar_get_raw(value); + arrow::SparseUnionScalar::ValueType arrow_field_values; + for (int i = 0; i < arrow_data_type->num_fields(); ++i) { + if (arrow_type_codes[i] == type_code) { + arrow_field_values.emplace_back(arrow_value); + } else { + arrow_field_values.emplace_back( + arrow::MakeNullScalar(arrow_data_type->field(i)->type())); + } + } + auto arrow_scalar = + std::static_pointer_cast( + std::make_shared(arrow_field_values, + type_code, + arrow_data_type)); + auto scalar = garrow_scalar_new_raw(&arrow_scalar, + "scalar", &arrow_scalar, + "data-type", data_type, + "value", value, + NULL); + return GARROW_SPARSE_UNION_SCALAR(scalar); } @@ -2436,9 +2458,19 @@ garrow_dense_union_scalar_new(GArrowDenseUnionDataType *data_type, gint8 type_code, GArrowScalar *value) { - return GARROW_DENSE_UNION_SCALAR( - garrow_union_scalar_new( - GARROW_DATA_TYPE(data_type), type_code, value)); + auto arrow_data_type = garrow_data_type_get_raw(GARROW_DATA_TYPE(data_type)); + auto arrow_value = garrow_scalar_get_raw(value); + auto arrow_scalar = + std::static_pointer_cast( + std::make_shared(arrow_value, + type_code, + arrow_data_type)); + auto scalar = garrow_scalar_new_raw(&arrow_scalar, + "scalar", &arrow_scalar, + "data-type", data_type, + "value", value, + NULL); + return GARROW_DENSE_UNION_SCALAR(scalar); } diff --git a/c_glib/test/test-large-binary-scalar.rb b/c_glib/test/test-large-binary-scalar.rb index a6bc4addb10..d716e13f3ea 100644 --- a/c_glib/test/test-large-binary-scalar.rb +++ b/c_glib/test/test-large-binary-scalar.rb @@ -38,7 +38,11 @@ def test_equal end def test_to_s - assert_equal("...", @scalar.to_s) + assert_equal(<<-BINARY.strip, @scalar.to_s) +[ + 030102 +] + BINARY end def test_value diff --git a/c_glib/test/test-large-string-scalar.rb b/c_glib/test/test-large-string-scalar.rb index 13e28f647ac..42e24a601b4 100644 --- a/c_glib/test/test-large-string-scalar.rb +++ b/c_glib/test/test-large-string-scalar.rb @@ -38,7 +38,11 @@ def test_equal end def test_to_s - assert_equal("...", @scalar.to_s) + assert_equal(<<-STRING.strip, @scalar.to_s) +[ + "Hello" +] + STRING 
end def test_value diff --git a/c_glib/test/test-list-scalar.rb b/c_glib/test/test-list-scalar.rb index 3fda3f25bbb..0ddbf60bc05 100644 --- a/c_glib/test/test-list-scalar.rb +++ b/c_glib/test/test-list-scalar.rb @@ -41,7 +41,17 @@ def test_equal end def test_to_s - assert_equal("...", @scalar.to_s) + assert_equal(<<-LIST.strip, @scalar.to_s) +[ + [ + [ + 1, + 2, + 3 + ] + ] +] + LIST end def test_value diff --git a/c_glib/test/test-map-scalar.rb b/c_glib/test/test-map-scalar.rb index 9c6eb69e0a8..1e004569ef3 100644 --- a/c_glib/test/test-map-scalar.rb +++ b/c_glib/test/test-map-scalar.rb @@ -56,7 +56,20 @@ def test_equal end def test_to_s - assert_equal("...", @scalar.to_s) + assert_equal(<<-MAP.strip, @scalar.to_s) +[ + keys: + [ + "hello", + "world" + ] + values: + [ + 1, + 2 + ] +] + MAP end def test_value diff --git a/cpp/examples/arrow/compute_register_example.cc b/cpp/examples/arrow/compute_register_example.cc index 13d80b29631..113dfd0faf3 100644 --- a/cpp/examples/arrow/compute_register_example.cc +++ b/cpp/examples/arrow/compute_register_example.cc @@ -127,8 +127,7 @@ const cp::FunctionDoc func_doc{ int main(int argc, char** argv) { const std::string name = "compute_register_example"; auto func = std::make_shared(name, cp::Arity::Unary(), func_doc); - cp::ScalarKernel kernel({cp::InputType::Array(arrow::int64())}, arrow::int64(), - ExampleFunctionImpl); + cp::ScalarKernel kernel({arrow::int64()}, arrow::int64(), ExampleFunctionImpl); kernel.mem_allocation = cp::MemAllocation::NO_PREALLOCATE; ABORT_ON_FAILURE(func->AddKernel(std::move(kernel))); diff --git a/cpp/examples/arrow/udf_example.cc b/cpp/examples/arrow/udf_example.cc index 47c45411477..ccd804339a2 100644 --- a/cpp/examples/arrow/udf_example.cc +++ b/cpp/examples/arrow/udf_example.cc @@ -75,10 +75,8 @@ arrow::Status SampleFunction(cp::KernelContext* ctx, const cp::ExecSpan& batch, arrow::Status Execute() { const std::string name = "add_three"; auto func = std::make_shared(name, cp::Arity::Ternary(), func_doc); - cp::ScalarKernel kernel( - {cp::InputType::Array(arrow::int64()), cp::InputType::Array(arrow::int64()), - cp::InputType::Array(arrow::int64())}, - arrow::int64(), SampleFunction); + cp::ScalarKernel kernel({arrow::int64(), arrow::int64(), arrow::int64()}, + arrow::int64(), SampleFunction); kernel.mem_allocation = cp::MemAllocation::PREALLOCATE; kernel.null_handling = cp::NullHandling::INTERSECTION; diff --git a/cpp/gdb_arrow.py b/cpp/gdb_arrow.py index cd687ec8b2e..af3dad9c087 100644 --- a/cpp/gdb_arrow.py +++ b/cpp/gdb_arrow.py @@ -1406,13 +1406,12 @@ class FixedSizeBinaryScalarPrinter(BaseBinaryScalarPrinter): def to_string(self): size = self.type['byte_width_'] - if not self.is_valid: - return f"{self._format_type()} of size {size}, null value" bufptr = BufferPtr(SharedPtr(self.val['value']).get()) if bufptr.data is None: return f"{self._format_type()} of size {size}, " + nullness = '' if self.is_valid else 'null with ' return (f"{self._format_type()} of size {size}, " - f"value {self._format_buf(bufptr)}") + f"{nullness}value {self._format_buf(bufptr)}") class DictionaryScalarPrinter(ScalarPrinter): @@ -1450,6 +1449,8 @@ def display_hint(self): return 'map' def children(self): + if not self.is_valid: + return None eval_fields = StdVector(self.type['children_']) eval_values = StdVector(self.val['value']) for field, value in zip(eval_fields, eval_values): @@ -1463,7 +1464,24 @@ def to_string(self): return f"{self._format_type()}" -class UnionScalarPrinter(ScalarPrinter): +class 
SparseUnionScalarPrinter(ScalarPrinter): + """ + Pretty-printer for arrow::UnionScalar and subclasses. + """ + + def to_string(self): + type_code = self.val['type_code'].cast(gdb.lookup_type('int')) + if not self.is_valid: + return (f"{self._format_type()} of type {self.type}, " + f"type code {type_code}, null value") + eval_values = StdVector(self.val['value']) + child_id = self.val['child_id'].cast(gdb.lookup_type('int')) + return (f"{self._format_type()} of type code {type_code}, " + f"value {deref(eval_values[child_id])}") + + + +class DenseUnionScalarPrinter(ScalarPrinter): """ Pretty-printer for arrow::UnionScalar and subclasses. """ @@ -1968,10 +1986,16 @@ class StructTypeClass(DataTypeClass): scalar_printer = StructScalarPrinter -class UnionTypeClass(DataTypeClass): +class DenseUnionTypeClass(DataTypeClass): + is_parametric = True + type_printer = UnionTypePrinter + scalar_printer = DenseUnionScalarPrinter + + +class SparseUnionTypeClass(DataTypeClass): is_parametric = True type_printer = UnionTypePrinter - scalar_printer = UnionScalarPrinter + scalar_printer = SparseUnionScalarPrinter class DictionaryTypeClass(DataTypeClass): @@ -2037,8 +2061,8 @@ class ExtensionTypeClass(DataTypeClass): Type.MAP: DataTypeTraits(MapTypeClass, 'MapType'), Type.STRUCT: DataTypeTraits(StructTypeClass, 'StructType'), - Type.SPARSE_UNION: DataTypeTraits(UnionTypeClass, 'SparseUnionType'), - Type.DENSE_UNION: DataTypeTraits(UnionTypeClass, 'DenseUnionType'), + Type.SPARSE_UNION: DataTypeTraits(SparseUnionTypeClass, 'SparseUnionType'), + Type.DENSE_UNION: DataTypeTraits(DenseUnionTypeClass, 'DenseUnionType'), Type.DICTIONARY: DataTypeTraits(DictionaryTypeClass, 'DictionaryType'), Type.EXTENSION: DataTypeTraits(ExtensionTypeClass, 'ExtensionType'), diff --git a/cpp/src/arrow/array/array_base.cc b/cpp/src/arrow/array/array_base.cc index b36fb0fb94a..5d27b2aedfb 100644 --- a/cpp/src/arrow/array/array_base.cc +++ b/cpp/src/arrow/array/array_base.cc @@ -104,16 +104,15 @@ struct ScalarFromArraySlotImpl { } Status Visit(const SparseUnionArray& a) { - const auto type_code = a.type_code(index_); - // child array which stores the actual value - const auto arr = a.field(a.child_id(index_)); - // no need to adjust the index - ARROW_ASSIGN_OR_RAISE(auto value, arr->GetScalar(index_)); - if (value->is_valid) { - out_ = std::shared_ptr(new SparseUnionScalar(value, type_code, a.type())); - } else { - out_ = std::shared_ptr(new SparseUnionScalar(type_code, a.type())); + int8_t type_code = a.type_code(index_); + + ScalarVector children; + for (int i = 0; i < a.type()->num_fields(); ++i) { + children.emplace_back(); + ARROW_ASSIGN_OR_RAISE(children.back(), a.field(i)->GetScalar(index_)); } + + out_ = std::make_shared(std::move(children), type_code, a.type()); return Status::OK(); } @@ -124,11 +123,7 @@ struct ScalarFromArraySlotImpl { // need to look up the value based on offsets auto offset = a.value_offset(index_); ARROW_ASSIGN_OR_RAISE(auto value, arr->GetScalar(offset)); - if (value->is_valid) { - out_ = std::shared_ptr(new DenseUnionScalar(value, type_code, a.type())); - } else { - out_ = std::shared_ptr(new DenseUnionScalar(type_code, a.type())); - } + out_ = std::make_shared(value, type_code, a.type()); return Status::OK(); } diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 0d9afba6ece..d438557a330 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -561,16 +561,16 @@ static ScalarVector GetScalars() { }, struct_({field("min", int32()), 
field("max", int32())})), // Same values, different union type codes - std::make_shared(std::make_shared(100), 6, - sparse_union_ty), - std::make_shared(std::make_shared(100), 42, - sparse_union_ty), - std::make_shared(42, sparse_union_ty), + SparseUnionScalar::FromValue(std::make_shared(100), 1, + sparse_union_ty), + SparseUnionScalar::FromValue(std::make_shared(100), 2, + sparse_union_ty), + SparseUnionScalar::FromValue(MakeNullScalar(int32()), 2, sparse_union_ty), std::make_shared(std::make_shared(101), 6, dense_union_ty), std::make_shared(std::make_shared(101), 42, dense_union_ty), - std::make_shared(42, dense_union_ty), + std::make_shared(MakeNullScalar(int32()), 42, dense_union_ty), DictionaryScalar::Make(ScalarFromJSON(int8(), "1"), ArrayFromJSON(utf8(), R"(["foo", "bar"])")), DictionaryScalar::Make(ScalarFromJSON(uint8(), "1"), diff --git a/cpp/src/arrow/array/builder_base.cc b/cpp/src/arrow/array/builder_base.cc index 49abd8e0234..ff37cee5ba1 100644 --- a/cpp/src/arrow/array/builder_base.cc +++ b/cpp/src/arrow/array/builder_base.cc @@ -34,6 +34,8 @@ namespace arrow { +using internal::checked_cast; + Status ArrayBuilder::CheckArrayType(const std::shared_ptr& expected_type, const Array& array, const char* message) { if (!expected_type->Equals(*array.type())) { @@ -105,14 +107,13 @@ struct AppendScalarImpl { is_fixed_size_binary_type::value, Status> Visit(const T&) { - auto builder = internal::checked_cast::BuilderType*>(builder_); + auto builder = checked_cast::BuilderType*>(builder_); RETURN_NOT_OK(builder->Reserve(n_repeats_ * (scalars_end_ - scalars_begin_))); for (int64_t i = 0; i < n_repeats_; i++) { for (const std::shared_ptr* raw = scalars_begin_; raw != scalars_end_; raw++) { - auto scalar = - internal::checked_cast::ScalarType*>(raw->get()); + auto scalar = checked_cast::ScalarType*>(raw->get()); if (scalar->is_valid) { builder->UnsafeAppend(scalar->value); } else { @@ -128,22 +129,20 @@ struct AppendScalarImpl { int64_t data_size = 0; for (const std::shared_ptr* raw = scalars_begin_; raw != scalars_end_; raw++) { - auto scalar = - internal::checked_cast::ScalarType*>(raw->get()); + auto scalar = checked_cast::ScalarType*>(raw->get()); if (scalar->is_valid) { data_size += scalar->value->size(); } } - auto builder = internal::checked_cast::BuilderType*>(builder_); + auto builder = checked_cast::BuilderType*>(builder_); RETURN_NOT_OK(builder->Reserve(n_repeats_ * (scalars_end_ - scalars_begin_))); RETURN_NOT_OK(builder->ReserveData(n_repeats_ * data_size)); for (int64_t i = 0; i < n_repeats_; i++) { for (const std::shared_ptr* raw = scalars_begin_; raw != scalars_end_; raw++) { - auto scalar = - internal::checked_cast::ScalarType*>(raw->get()); + auto scalar = checked_cast::ScalarType*>(raw->get()); if (scalar->is_valid) { builder->UnsafeAppend(util::string_view{*scalar->value}); } else { @@ -156,13 +155,12 @@ struct AppendScalarImpl { template enable_if_list_like Visit(const T&) { - auto builder = internal::checked_cast::BuilderType*>(builder_); + auto builder = checked_cast::BuilderType*>(builder_); int64_t num_children = 0; for (const std::shared_ptr* scalar = scalars_begin_; scalar != scalars_end_; scalar++) { if (!(*scalar)->is_valid) continue; - num_children += - internal::checked_cast(**scalar).value->length(); + num_children += checked_cast(**scalar).value->length(); } RETURN_NOT_OK(builder->value_builder()->Reserve(num_children * n_repeats_)); @@ -171,8 +169,7 @@ struct AppendScalarImpl { scalar++) { if ((*scalar)->is_valid) { RETURN_NOT_OK(builder->Append()); - 
const Array& list = - *internal::checked_cast(**scalar).value; + const Array& list = *checked_cast(**scalar).value; for (int64_t i = 0; i < list.length(); i++) { ARROW_ASSIGN_OR_RAISE(auto scalar, list.GetScalar(i)); RETURN_NOT_OK(builder->value_builder()->AppendScalar(*scalar)); @@ -186,7 +183,7 @@ struct AppendScalarImpl { } Status Visit(const StructType& type) { - auto* builder = internal::checked_cast(builder_); + auto* builder = checked_cast(builder_); auto count = n_repeats_ * (scalars_end_ - scalars_begin_); RETURN_NOT_OK(builder->Reserve(count)); for (int field_index = 0; field_index < type.num_fields(); ++field_index) { @@ -194,7 +191,7 @@ struct AppendScalarImpl { } for (int64_t i = 0; i < n_repeats_; i++) { for (const std::shared_ptr* s = scalars_begin_; s != scalars_end_; s++) { - const auto& scalar = internal::checked_cast(**s); + const auto& scalar = checked_cast(**s); for (int field_index = 0; field_index < type.num_fields(); ++field_index) { if (!scalar.is_valid || !scalar.value[field_index]) { RETURN_NOT_OK(builder->field_builder(field_index)->AppendNull()); @@ -213,12 +210,54 @@ struct AppendScalarImpl { Status Visit(const DenseUnionType& type) { return MakeUnionArray(type); } + Status AppendUnionScalar(const DenseUnionType& type, const Scalar& s, + DenseUnionBuilder* builder) { + const auto& scalar = checked_cast(s); + const auto scalar_field_index = type.child_ids()[scalar.type_code]; + RETURN_NOT_OK(builder->Append(scalar.type_code)); + + for (int field_index = 0; field_index < type.num_fields(); ++field_index) { + auto* child_builder = builder->child_builder(field_index).get(); + if (field_index == scalar_field_index) { + if (scalar.is_valid) { + RETURN_NOT_OK(child_builder->AppendScalar(*scalar.value)); + } else { + RETURN_NOT_OK(child_builder->AppendNull()); + } + } + } + return Status::OK(); + } + + Status AppendUnionScalar(const SparseUnionType& type, const Scalar& s, + SparseUnionBuilder* builder) { + // For each scalar, + // 1. append the type code, + // 2. append the value to the corresponding child, + // 3. append null to the other children. + const auto& scalar = checked_cast(s); + RETURN_NOT_OK(builder->Append(scalar.type_code)); + + for (int field_index = 0; field_index < type.num_fields(); ++field_index) { + auto* child_builder = builder->child_builder(field_index).get(); + if (field_index == scalar.child_id) { + if (scalar.is_valid) { + RETURN_NOT_OK(child_builder->AppendScalar(*scalar.value[field_index])); + } else { + RETURN_NOT_OK(child_builder->AppendNull()); + } + } else { + RETURN_NOT_OK(child_builder->AppendNull()); + } + } + return Status::OK(); + } + template Status MakeUnionArray(const T& type) { using BuilderType = typename TypeTraits::BuilderType; - constexpr bool is_dense = std::is_same::value; - auto* builder = internal::checked_cast(builder_); + auto* builder = checked_cast(builder_); const auto count = n_repeats_ * (scalars_end_ - scalars_begin_); RETURN_NOT_OK(builder->Reserve(count)); @@ -230,26 +269,7 @@ struct AppendScalarImpl { for (int64_t i = 0; i < n_repeats_; i++) { for (const std::shared_ptr* s = scalars_begin_; s != scalars_end_; s++) { - // For each scalar, - // 1. append the type code, - // 2. append the value to the corresponding child, - // 3. if the union is sparse, append null to the other children. 
- const auto& scalar = internal::checked_cast(**s); - const auto scalar_field_index = type.child_ids()[scalar.type_code]; - RETURN_NOT_OK(builder->Append(scalar.type_code)); - - for (int field_index = 0; field_index < type.num_fields(); ++field_index) { - auto* child_builder = builder->child_builder(field_index).get(); - if (field_index == scalar_field_index) { - if (scalar.is_valid) { - RETURN_NOT_OK(child_builder->AppendScalar(*scalar.value)); - } else { - RETURN_NOT_OK(child_builder->AppendNull()); - } - } else if (!is_dense) { - RETURN_NOT_OK(child_builder->AppendNull()); - } - } + RETURN_NOT_OK(AppendUnionScalar(type, **s, builder)); } } return Status::OK(); diff --git a/cpp/src/arrow/array/builder_nested.h b/cpp/src/arrow/array/builder_nested.h index 3d36cb5f65e..306d861b09f 100644 --- a/cpp/src/arrow/array/builder_nested.h +++ b/cpp/src/arrow/array/builder_nested.h @@ -304,10 +304,12 @@ class ARROW_EXPORT MapBuilder : public ArrayBuilder { if (!validity || bit_util::GetBit(validity, array.offset + row)) { ARROW_RETURN_NOT_OK(Append()); const int64_t slot_length = offsets[row + 1] - offsets[row]; + // Add together the inner StructArray offset to the Map/List offset + int64_t key_value_offset = array.child_data[0].offset + offsets[row]; ARROW_RETURN_NOT_OK(key_builder_->AppendArraySlice( - array.child_data[0].child_data[0], offsets[row], slot_length)); + array.child_data[0].child_data[0], key_value_offset, slot_length)); ARROW_RETURN_NOT_OK(item_builder_->AppendArraySlice( - array.child_data[0].child_data[1], offsets[row], slot_length)); + array.child_data[0].child_data[1], key_value_offset, slot_length)); } else { ARROW_RETURN_NOT_OK(AppendNull()); } diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index 37db8ccb775..c1a597fea62 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -38,6 +38,7 @@ namespace arrow { +using internal::checked_cast; using internal::CountSetBits; static inline void AdjustNonNullable(Type::type type_id, int64_t length, @@ -147,7 +148,7 @@ void ArraySpan::SetMembers(const ArrayData& data) { if (buffer) { SetBuffer(i, buffer); } else { - ClearBuffer(i); + this->buffers[i] = {}; } } @@ -160,7 +161,7 @@ void ArraySpan::SetMembers(const ArrayData& data) { // Makes sure any other buffers are seen as null / non-existent for (int i = static_cast(data.buffers.size()); i < 3; ++i) { - ClearBuffer(i); + this->buffers[i] = {}; } if (this->type->id() == Type::DICTIONARY) { @@ -174,27 +175,203 @@ void ArraySpan::SetMembers(const ArrayData& data) { } } +namespace { + +template +void SetOffsetsForScalar(ArraySpan* span, offset_type* buffer, int64_t value_size, + int buffer_index = 1) { + buffer[0] = 0; + buffer[1] = static_cast(value_size); + span->buffers[buffer_index].data = reinterpret_cast(buffer); + span->buffers[buffer_index].size = 2 * sizeof(offset_type); +} + +int GetNumBuffers(const DataType& type) { + switch (type.id()) { + case Type::NA: + case Type::STRUCT: + case Type::FIXED_SIZE_LIST: + return 1; + case Type::BINARY: + case Type::LARGE_BINARY: + case Type::STRING: + case Type::LARGE_STRING: + case Type::DENSE_UNION: + return 3; + case Type::EXTENSION: + // The number of buffers depends on the storage type + return GetNumBuffers( + *internal::checked_cast(type).storage_type()); + default: + // Everything else has 2 buffers + return 2; + } +} + +} // namespace + +namespace internal { + +void FillZeroLengthArray(const DataType* type, ArraySpan* span) { + memset(span->scratch_space, 0x00, 
sizeof(span->scratch_space)); + + span->type = type; + span->length = 0; + int num_buffers = GetNumBuffers(*type); + for (int i = 0; i < num_buffers; ++i) { + span->buffers[i].data = span->scratch_space; + span->buffers[i].size = 0; + } + + for (int i = num_buffers; i < 3; ++i) { + span->buffers[i] = {}; + } + + // Fill children + span->child_data.resize(type->num_fields()); + for (int i = 0; i < type->num_fields(); ++i) { + FillZeroLengthArray(type->field(i)->type().get(), &span->child_data[i]); + } +} + +} // namespace internal + void ArraySpan::FillFromScalar(const Scalar& value) { - static const uint8_t kValidByte = 0x01; - static const uint8_t kNullByte = 0x00; + static uint8_t kTrueBit = 0x01; + static uint8_t kFalseBit = 0x00; this->type = value.type.get(); this->length = 1; - // Populate null count and validity bitmap + Type::type type_id = value.type->id(); + + // Populate null count and validity bitmap (only for non-union/null types) this->null_count = value.is_valid ? 0 : 1; - this->buffers[0].data = const_cast(value.is_valid ? &kValidByte : &kNullByte); - this->buffers[0].size = 1; + if (!is_union(type_id) && type_id != Type::NA) { + this->buffers[0].data = value.is_valid ? &kTrueBit : &kFalseBit; + this->buffers[0].size = 1; + } - if (is_primitive(value.type->id())) { - const auto& scalar = - internal::checked_cast(value); + if (type_id == Type::BOOL) { + const auto& scalar = checked_cast(value); + this->buffers[1].data = scalar.value ? &kTrueBit : &kFalseBit; + this->buffers[1].size = 1; + } else if (is_primitive(type_id) || is_decimal(type_id) || + type_id == Type::DICTIONARY) { + const auto& scalar = checked_cast(value); const uint8_t* scalar_data = reinterpret_cast(scalar.view().data()); this->buffers[1].data = const_cast(scalar_data); this->buffers[1].size = scalar.type->byte_width(); + if (type_id == Type::DICTIONARY) { + // Populate dictionary data + const auto& dict_scalar = checked_cast(value); + this->child_data.resize(1); + this->child_data[0].SetMembers(*dict_scalar.value.dictionary->data()); + } + } else if (is_base_binary_like(type_id)) { + const auto& scalar = checked_cast(value); + this->buffers[1].data = this->scratch_space; + const uint8_t* data_buffer = nullptr; + int64_t data_size = 0; + if (scalar.is_valid) { + data_buffer = scalar.value->data(); + data_size = scalar.value->size(); + } + if (is_binary_like(type_id)) { + SetOffsetsForScalar(this, reinterpret_cast(this->scratch_space), + data_size); + } else { + // is_large_binary_like + SetOffsetsForScalar(this, reinterpret_cast(this->scratch_space), + data_size); + } + this->buffers[2].data = const_cast(data_buffer); + this->buffers[2].size = data_size; + } else if (type_id == Type::FIXED_SIZE_BINARY) { + const auto& scalar = checked_cast(value); + this->buffers[1].data = const_cast(scalar.value->data()); + this->buffers[1].size = scalar.value->size(); + } else if (is_list_like(type_id)) { + const auto& scalar = checked_cast(value); + + int64_t value_length = 0; + this->child_data.resize(1); + if (scalar.value != nullptr) { + // When the scalar is null, scalar.value can also be null + this->child_data[0].SetMembers(*scalar.value->data()); + value_length = scalar.value->length(); + } else { + // Even when the value is null, we still must populate the + // child_data to yield a valid array. 
Tedious + internal::FillZeroLengthArray(this->type->field(0)->type().get(), + &this->child_data[0]); + } + + if (type_id == Type::LIST || type_id == Type::MAP) { + SetOffsetsForScalar(this, reinterpret_cast(this->scratch_space), + value_length); + } else if (type_id == Type::LARGE_LIST) { + SetOffsetsForScalar(this, reinterpret_cast(this->scratch_space), + value_length); + } else { + // FIXED_SIZE_LIST: does not have a second buffer + this->buffers[1] = {}; + } + } else if (type_id == Type::STRUCT) { + const auto& scalar = checked_cast(value); + this->child_data.resize(this->type->num_fields()); + DCHECK_EQ(this->type->num_fields(), static_cast(scalar.value.size())); + for (size_t i = 0; i < scalar.value.size(); ++i) { + this->child_data[i].FillFromScalar(*scalar.value[i]); + } + } else if (is_union(type_id)) { + // First buffer is kept null since unions have no validity vector + this->buffers[0] = {}; + + this->buffers[1].data = this->scratch_space; + this->buffers[1].size = 1; + int8_t* type_codes = reinterpret_cast(this->scratch_space); + type_codes[0] = checked_cast(value).type_code; + + this->child_data.resize(this->type->num_fields()); + if (type_id == Type::DENSE_UNION) { + const auto& scalar = checked_cast(value); + // Has offset; start 4 bytes in so it's aligned to a 32-bit boundaries + SetOffsetsForScalar(this, + reinterpret_cast(this->scratch_space) + 1, 1, + /*buffer_index=*/2); + // We can't "see" the other arrays in the union, but we put the "active" + // union array in the right place and fill zero-length arrays for the + // others + const std::vector& child_ids = + checked_cast(this->type)->child_ids(); + DCHECK_GE(scalar.type_code, 0); + DCHECK_LT(scalar.type_code, static_cast(child_ids.size())); + for (int i = 0; i < static_cast(this->child_data.size()); ++i) { + if (i == child_ids[scalar.type_code]) { + this->child_data[i].FillFromScalar(*scalar.value); + } else { + internal::FillZeroLengthArray(this->type->field(i)->type().get(), + &this->child_data[i]); + } + } + } else { + const auto& scalar = checked_cast(value); + // Sparse union scalars have a full complement of child values even + // though only one of them is relevant, so we just fill them in here + for (int i = 0; i < static_cast(this->child_data.size()); ++i) { + this->child_data[i].FillFromScalar(*scalar.value[i]); + } + } + } else if (type_id == Type::EXTENSION) { + // Pass through storage + const auto& scalar = checked_cast(value); + FillFromScalar(*scalar.value); + + // Restore the extension type + this->type = value.type.get(); } else { - // TODO(wesm): implement for other types - DCHECK(false) << "need to implement for other types"; + DCHECK_EQ(Type::NA, type_id) << "should be unreachable: " << *value.type; } } @@ -212,40 +389,14 @@ int64_t ArraySpan::GetNullCount() const { return precomputed; } -int GetNumBuffers(const DataType& type) { - switch (type.id()) { - case Type::NA: - case Type::STRUCT: - case Type::FIXED_SIZE_LIST: - return 1; - case Type::BINARY: - case Type::LARGE_BINARY: - case Type::STRING: - case Type::LARGE_STRING: - case Type::DENSE_UNION: - return 3; - case Type::EXTENSION: - // The number of buffers depends on the storage type - return GetNumBuffers( - *internal::checked_cast(type).storage_type()); - default: - // Everything else has 2 buffers - return 2; - } -} - int ArraySpan::num_buffers() const { return GetNumBuffers(*this->type); } std::shared_ptr ArraySpan::ToArrayData() const { - auto result = std::make_shared(this->type->Copy(), this->length, + auto result = 
std::make_shared(this->type->GetSharedPtr(), this->length, this->null_count, this->offset); for (int i = 0; i < this->num_buffers(); ++i) { - if (this->buffers[i].owner) { - result->buffers.emplace_back(this->GetBuffer(i)); - } else { - result->buffers.push_back(nullptr); - } + result->buffers.emplace_back(this->GetBuffer(i)); } if (this->type->id() == Type::NA) { diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h index df547aedfaf..fddc60293d8 100644 --- a/cpp/src/arrow/array/data.h +++ b/cpp/src/arrow/array/data.h @@ -266,6 +266,11 @@ struct ARROW_EXPORT ArraySpan { int64_t offset = 0; BufferSpan buffers[3]; + // 16 bytes of scratch space to enable this ArraySpan to be a view onto + // scalar values including binary scalars (where we need to create a buffer + // that looks like two 32-bit or 64-bit offsets) + alignas(64) uint8_t scratch_space[16]; + ArraySpan() = default; explicit ArraySpan(const DataType* type, int64_t length) : type(type), length(length) {} @@ -273,9 +278,7 @@ struct ARROW_EXPORT ArraySpan { ArraySpan(const ArrayData& data) { // NOLINT implicit conversion SetMembers(data); } - ArraySpan(const Scalar& data) { // NOLINT implicit converstion - FillFromScalar(data); - } + explicit ArraySpan(const Scalar& data) { FillFromScalar(data); } /// If dictionary-encoded, put dictionary in the first entry std::vector child_data; @@ -292,12 +295,6 @@ struct ARROW_EXPORT ArraySpan { this->buffers[index].owner = &buffer; } - void ClearBuffer(int index) { - this->buffers[index].data = NULLPTR; - this->buffers[index].size = 0; - this->buffers[index].owner = NULLPTR; - } - const ArraySpan& dictionary() const { return child_data[0]; } /// \brief Return the number of buffers (out of 3) that are used to @@ -343,10 +340,14 @@ struct ARROW_EXPORT ArraySpan { std::shared_ptr ToArray() const; std::shared_ptr GetBuffer(int index) const { - if (this->buffers[index].owner == NULLPTR) { - return NULLPTR; + const BufferSpan& buf = this->buffers[index]; + if (buf.owner) { + return *buf.owner; + } else if (buf.data != NULLPTR) { + // Buffer points to some memory without an owning buffer + return std::make_shared(buf.data, buf.size); } else { - return *this->buffers[index].owner; + return NULLPTR; } } @@ -372,6 +373,8 @@ struct ARROW_EXPORT ArraySpan { namespace internal { +void FillZeroLengthArray(const DataType* type, ArraySpan* span); + /// Construct a zero-copy view of this ArrayData with the given type. /// /// This method checks if the types are layout-compatible. diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index e5b4ab39493..c0cdcab730c 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -664,22 +664,20 @@ class RepeatedArrayFactory { } Status Visit(const SparseUnionType& type) { - const auto& union_scalar = checked_cast(scalar_); - const auto& union_type = checked_cast(*scalar_.type); + const auto& union_scalar = checked_cast(scalar_); const auto scalar_type_code = union_scalar.type_code; - const auto scalar_child_id = union_type.child_ids()[scalar_type_code]; // Create child arrays: most of them are all-null, except for the child array // for the given type code (if the scalar is valid). 
ArrayVector fields; for (int i = 0; i < type.num_fields(); ++i) { fields.emplace_back(); - if (i == scalar_child_id && scalar_.is_valid) { - ARROW_ASSIGN_OR_RAISE(fields.back(), - MakeArrayFromScalar(*union_scalar.value, length_, pool_)); - } else { + if (i == union_scalar.child_id && scalar_.is_valid) { ARROW_ASSIGN_OR_RAISE( - fields.back(), MakeArrayOfNull(union_type.field(i)->type(), length_, pool_)); + fields.back(), MakeArrayFromScalar(*union_scalar.value[i], length_, pool_)); + } else { + ARROW_ASSIGN_OR_RAISE(fields.back(), + MakeArrayOfNull(type.field(i)->type(), length_, pool_)); } } @@ -691,7 +689,7 @@ class RepeatedArrayFactory { } Status Visit(const DenseUnionType& type) { - const auto& union_scalar = checked_cast(scalar_); + const auto& union_scalar = checked_cast(scalar_); const auto& union_type = checked_cast(*scalar_.type); const auto scalar_type_code = union_scalar.type_code; const auto scalar_child_id = union_type.child_ids()[scalar_type_code]; diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 8af319ed9ea..c5406ee583f 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -796,12 +796,19 @@ class ScalarEqualsVisitor { return Status::OK(); } - Status Visit(const UnionScalar& left) { - const auto& right = checked_cast(right_); + Status Visit(const DenseUnionScalar& left) { + const auto& right = checked_cast(right_); result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); return Status::OK(); } + Status Visit(const SparseUnionScalar& left) { + const auto& right = checked_cast(right_); + result_ = ScalarEquals(*left.value[left.child_id], *right.value[right.child_id], + options_, floating_approximate_); + return Status::OK(); + } + Status Visit(const DictionaryScalar& left) { const auto& right = checked_cast(right_); result_ = ScalarEquals(*left.value.index, *right.value.index, options_, diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 4ebdecf5e78..ff1d6619905 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -347,11 +347,11 @@ Result Filter(const Datum& values, const Datum& filter, return CallFunction("filter", {values, filter}, &options, ctx); } -Result Take(const Datum& values, const Datum& filter, const TakeOptions& options, +Result Take(const Datum& values, const Datum& indices, const TakeOptions& options, ExecContext* ctx) { // Invoke metafunction which deals with Datum kinds other than just Array, // ChunkedArray. - return CallFunction("take", {values, filter}, &options, ctx); + return CallFunction("take", {values, indices}, &options, ctx); } Result> Take(const Array& values, const Array& indices, diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index bd49041b4f3..52aecf3e45a 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -66,25 +66,6 @@ void InitCastTable() { void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); } -// Private version of GetCastFunction with better error reporting -// if the input type is known. 
-Result> GetCastFunctionInternal( - const std::shared_ptr& to_type, const DataType* from_type = nullptr) { - internal::EnsureInitCastTable(); - auto it = internal::g_cast_table.find(static_cast(to_type->id())); - if (it == internal::g_cast_table.end()) { - if (from_type != nullptr) { - return Status::NotImplemented("Unsupported cast from ", *from_type, " to ", - *to_type, - " (no available cast function for target type)"); - } else { - return Status::NotImplemented("Unsupported cast to ", *to_type, - " (no available cast function for target type)"); - } - } - return it->second; -} - const FunctionDoc cast_doc{"Cast values to another data type", ("Behavior when values wouldn't fit in the target type\n" "can be controlled through CastOptions."), @@ -116,10 +97,13 @@ class CastMetaFunction : public MetaFunction { if (args[0].type()->Equals(*cast_options->to_type)) { return args[0]; } - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr cast_func, - GetCastFunctionInternal(cast_options->to_type, args[0].type().get())); - return cast_func->Execute(args, options, ctx); + Result> result = + GetCastFunction(*cast_options->to_type); + if (!result.ok()) { + Status s = result.status(); + return s.WithMessage(s.message(), " from ", *args[0].type()); + } + return (*result)->Execute(args, options, ctx); } }; @@ -139,18 +123,6 @@ void RegisterScalarCast(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::make_shared())); DCHECK_OK(registry->AddFunctionOptionsType(kCastOptionsType)); } -} // namespace internal - -CastOptions::CastOptions(bool safe) - : FunctionOptions(internal::kCastOptionsType), - allow_int_overflow(!safe), - allow_time_truncate(!safe), - allow_time_overflow(!safe), - allow_decimal_truncate(!safe), - allow_float_truncate(!safe), - allow_invalid_utf8(!safe) {} - -constexpr char CastOptions::kTypeName[]; CastFunction::CastFunction(std::string name, Type::type out_type_id) : ScalarFunction(std::move(name), Arity::Unary(), FunctionDoc::Empty()), @@ -177,18 +149,18 @@ Status CastFunction::AddKernel(Type::type in_type_id, std::vector in_ } Result CastFunction::DispatchExact( - const std::vector& values) const { - RETURN_NOT_OK(CheckArity(values)); + const std::vector& types) const { + RETURN_NOT_OK(CheckArity(types.size())); std::vector candidate_kernels; for (const auto& kernel : kernels_) { - if (kernel.signature->MatchesInputs(values)) { + if (kernel.signature->MatchesInputs(types)) { candidate_kernels.push_back(&kernel); } } if (candidate_kernels.size() == 0) { - return Status::NotImplemented("Unsupported cast from ", values[0].type->ToString(), + return Status::NotImplemented("Unsupported cast from ", types[0].type->ToString(), " to ", ToTypeName(out_type_id_), " using function ", this->name()); } @@ -213,28 +185,45 @@ Result CastFunction::DispatchExact( return candidate_kernels[0]; } +Result> GetCastFunction(const DataType& to_type) { + internal::EnsureInitCastTable(); + auto it = internal::g_cast_table.find(static_cast(to_type.id())); + if (it == internal::g_cast_table.end()) { + return Status::NotImplemented("Unsupported cast to ", to_type); + } + return it->second; +} + +} // namespace internal + +CastOptions::CastOptions(bool safe) + : FunctionOptions(internal::kCastOptionsType), + allow_int_overflow(!safe), + allow_time_truncate(!safe), + allow_time_overflow(!safe), + allow_decimal_truncate(!safe), + allow_float_truncate(!safe), + allow_invalid_utf8(!safe) {} + +constexpr char CastOptions::kTypeName[]; + Result Cast(const Datum& value, const CastOptions& options, ExecContext* 
ctx) { return CallFunction("cast", {value}, &options, ctx); } -Result Cast(const Datum& value, std::shared_ptr to_type, +Result Cast(const Datum& value, const TypeHolder& to_type, const CastOptions& options, ExecContext* ctx) { CastOptions options_with_to_type = options; options_with_to_type.to_type = to_type; return Cast(value, options_with_to_type, ctx); } -Result> Cast(const Array& value, std::shared_ptr to_type, +Result> Cast(const Array& value, const TypeHolder& to_type, const CastOptions& options, ExecContext* ctx) { ARROW_ASSIGN_OR_RAISE(Datum result, Cast(Datum(value), to_type, options, ctx)); return result.make_array(); } -Result> GetCastFunction( - const std::shared_ptr& to_type) { - return internal::GetCastFunctionInternal(to_type); -} - bool CanCast(const DataType& from_type, const DataType& to_type) { internal::EnsureInitCastTable(); auto it = internal::g_cast_table.find(static_cast(to_type.id())); @@ -242,7 +231,7 @@ bool CanCast(const DataType& from_type, const DataType& to_type) { return false; } - const CastFunction* function = it->second.get(); + const internal::CastFunction* function = it->second.get(); DCHECK_EQ(function->out_type_id(), to_type.id()); for (auto from_id : function->in_type_ids()) { @@ -253,21 +242,5 @@ bool CanCast(const DataType& from_type, const DataType& to_type) { return false; } -Result> Cast(std::vector datums, std::vector descrs, - ExecContext* ctx) { - for (size_t i = 0; i != datums.size(); ++i) { - if (descrs[i] != datums[i].descr()) { - if (descrs[i].shape != datums[i].shape()) { - return Status::NotImplemented("casting between Datum shapes"); - } - - ARROW_ASSIGN_OR_RAISE(datums[i], - Cast(datums[i], CastOptions::Safe(descrs[i].type), ctx)); - } - } - - return datums; -} - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h index e9c3cf55da9..7432933a124 100644 --- a/cpp/src/arrow/compute/cast.h +++ b/cpp/src/arrow/compute/cast.h @@ -22,8 +22,7 @@ #include #include "arrow/compute/function.h" -#include "arrow/compute/kernel.h" -#include "arrow/datum.h" +#include "arrow/compute/type_fwd.h" #include "arrow/result.h" #include "arrow/status.h" #include "arrow/type.h" @@ -46,13 +45,13 @@ class ARROW_EXPORT CastOptions : public FunctionOptions { explicit CastOptions(bool safe = true); static constexpr char const kTypeName[] = "CastOptions"; - static CastOptions Safe(std::shared_ptr to_type = NULLPTR) { + static CastOptions Safe(TypeHolder to_type = {}) { CastOptions safe(true); safe.to_type = std::move(to_type); return safe; } - static CastOptions Unsafe(std::shared_ptr to_type = NULLPTR) { + static CastOptions Unsafe(TypeHolder to_type = {}) { CastOptions unsafe(false); unsafe.to_type = std::move(to_type); return unsafe; @@ -60,7 +59,7 @@ class ARROW_EXPORT CastOptions : public FunctionOptions { // Type being casted to. 
May be passed separate to eager function // compute::Cast - std::shared_ptr to_type; + TypeHolder to_type; bool allow_int_overflow; bool allow_time_truncate; @@ -74,36 +73,6 @@ class ARROW_EXPORT CastOptions : public FunctionOptions { /// @} -// Cast functions are _not_ registered in the FunctionRegistry, though they use -// the same execution machinery -class CastFunction : public ScalarFunction { - public: - CastFunction(std::string name, Type::type out_type_id); - - Type::type out_type_id() const { return out_type_id_; } - const std::vector& in_type_ids() const { return in_type_ids_; } - - Status AddKernel(Type::type in_type_id, std::vector in_types, - OutputType out_type, ArrayKernelExec exec, - NullHandling::type = NullHandling::INTERSECTION, - MemAllocation::type = MemAllocation::PREALLOCATE); - - // Note, this function toggles off memory allocation and sets the init - // function to CastInit - Status AddKernel(Type::type in_type_id, ScalarKernel kernel); - - Result DispatchExact( - const std::vector& values) const override; - - private: - std::vector in_type_ids_; - const Type::type out_type_id_; -}; - -ARROW_EXPORT -Result> GetCastFunction( - const std::shared_ptr& to_type); - /// \brief Return true if a cast function is defined ARROW_EXPORT bool CanCast(const DataType& from_type, const DataType& to_type); @@ -121,7 +90,7 @@ bool CanCast(const DataType& from_type, const DataType& to_type); /// \since 1.0.0 /// \note API not yet finalized ARROW_EXPORT -Result> Cast(const Array& value, std::shared_ptr to_type, +Result> Cast(const Array& value, const TypeHolder& to_type, const CastOptions& options = CastOptions::Safe(), ExecContext* ctx = NULLPTR); @@ -147,21 +116,9 @@ Result Cast(const Datum& value, const CastOptions& options, /// \since 1.0.0 /// \note API not yet finalized ARROW_EXPORT -Result Cast(const Datum& value, std::shared_ptr to_type, +Result Cast(const Datum& value, const TypeHolder& to_type, const CastOptions& options = CastOptions::Safe(), ExecContext* ctx = NULLPTR); -/// \brief Cast several values simultaneously. Safe cast options are used. 
-/// \param[in] values datums to cast -/// \param[in] descrs ValueDescrs to cast to -/// \param[in] ctx the function execution context, optional -/// \return the resulting datums -/// -/// \since 4.0.0 -/// \note API not yet finalized -ARROW_EXPORT -Result> Cast(std::vector values, std::vector descrs, - ExecContext* ctx = NULLPTR); - } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/cast_internal.h b/cpp/src/arrow/compute/cast_internal.h index 0105d08a573..f00a6cdbf4d 100644 --- a/cpp/src/arrow/compute/cast_internal.h +++ b/cpp/src/arrow/compute/cast_internal.h @@ -30,6 +30,32 @@ namespace internal { using CastState = OptionsWrapper; +// Cast functions are _not_ registered in the FunctionRegistry, though they use +// the same execution machinery +class CastFunction : public ScalarFunction { + public: + CastFunction(std::string name, Type::type out_type_id); + + Type::type out_type_id() const { return out_type_id_; } + const std::vector& in_type_ids() const { return in_type_ids_; } + + Status AddKernel(Type::type in_type_id, std::vector in_types, + OutputType out_type, ArrayKernelExec exec, + NullHandling::type = NullHandling::INTERSECTION, + MemAllocation::type = MemAllocation::PREALLOCATE); + + // Note, this function toggles off memory allocation and sets the init + // function to CastInit + Status AddKernel(Type::type in_type_id, ScalarKernel kernel); + + Result DispatchExact( + const std::vector& types) const override; + + private: + std::vector in_type_ids_; + const Type::type out_type_id_; +}; + // See kernels/scalar_cast_*.cc for these std::vector> GetBooleanCasts(); std::vector> GetNumericCasts(); @@ -38,6 +64,9 @@ std::vector> GetBinaryLikeCasts(); std::vector> GetNestedCasts(); std::vector> GetDictionaryCasts(); +ARROW_EXPORT +Result> GetCastFunction(const DataType& to_type); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc index a612a83e7a8..e5e256ea6dd 100644 --- a/cpp/src/arrow/compute/exec.cc +++ b/cpp/src/arrow/compute/exec.cc @@ -219,16 +219,6 @@ void ComputeDataPreallocate(const DataType& type, namespace detail { -Status CheckAllValues(const std::vector& values) { - for (const auto& value : values) { - if (!value.is_value()) { - return Status::Invalid("Tried executing function with non-value type: ", - value.ToString()); - } - } - return Status::OK(); -} - ExecBatchIterator::ExecBatchIterator(std::vector args, int64_t length, int64_t max_chunksize) : args_(std::move(args)), @@ -249,9 +239,7 @@ Result> ExecBatchIterator::Make( } } - // If the arguments are all scalars, then the length is 1 - int64_t length = 1; - + int64_t length = -1; bool length_set = false; for (auto& arg : args) { if (arg.is_scalar()) { @@ -267,6 +255,11 @@ Result> ExecBatchIterator::Make( } } + if (!length_set) { + // All scalar case, to be removed soon + length = 1; + } + max_chunksize = std::min(length, max_chunksize); return std::unique_ptr( @@ -328,8 +321,34 @@ bool ExecBatchIterator::Next(ExecBatch* batch) { // ---------------------------------------------------------------------- // ExecSpanIterator; to eventually replace ExecBatchIterator -Status ExecSpanIterator::Init(const ExecBatch& batch, ValueDescr::Shape output_shape, - int64_t max_chunksize) { +namespace { + +void PromoteExecSpanScalars(ExecSpan* span) { + // In the "all scalar" case, we "promote" the scalars to ArraySpans of + // length 1, since the kernel implementations do not handle the all + // scalar case + for 
(int i = 0; i < span->num_values(); ++i) { + ExecValue* value = &span->values[i]; + if (value->is_scalar()) { + value->array.FillFromScalar(*value->scalar); + value->scalar = nullptr; + } + } +} + +bool CheckIfAllScalar(const ExecBatch& batch) { + for (const Datum& value : batch.values) { + if (!value.is_scalar()) { + DCHECK(value.is_arraylike()); + return false; + } + } + return batch.num_values() > 0; +} + +} // namespace + +Status ExecSpanIterator::Init(const ExecBatch& batch, int64_t max_chunksize) { if (batch.num_values() > 0) { // Validate arguments bool all_args_same_length = false; @@ -343,8 +362,9 @@ Status ExecSpanIterator::Init(const ExecBatch& batch, ValueDescr::Shape output_s } args_ = &batch.values; initialized_ = have_chunked_arrays_ = false; + have_all_scalars_ = CheckIfAllScalar(batch); position_ = 0; - length_ = output_shape == ValueDescr::SCALAR ? 1 : batch.length; + length_ = batch.length; chunk_indexes_.clear(); chunk_indexes_.resize(args_->size(), 0); value_positions_.clear(); @@ -358,8 +378,7 @@ Status ExecSpanIterator::Init(const ExecBatch& batch, ValueDescr::Shape output_s int64_t ExecSpanIterator::GetNextChunkSpan(int64_t iteration_size, ExecSpan* span) { for (size_t i = 0; i < args_->size() && iteration_size > 0; ++i) { // If the argument is not a chunked array, it's either a Scalar or Array, - // in which case it doesn't influence the size of this span. Note that if - // the args are all scalars the span length is 1 + // in which case it doesn't influence the size of this span if (!args_->at(i).is_chunked_array()) { continue; } @@ -386,12 +405,6 @@ int64_t ExecSpanIterator::GetNextChunkSpan(int64_t iteration_size, ExecSpan* spa } bool ExecSpanIterator::Next(ExecSpan* span) { - if (position_ == length_) { - // This also protects from degenerate cases like ChunkedArrays - // without any chunks - return false; - } - if (!initialized_) { span->length = 0; @@ -402,25 +415,37 @@ bool ExecSpanIterator::Next(ExecSpan* span) { // iteration span->values.resize(args_->size()); for (size_t i = 0; i < args_->size(); ++i) { - if (args_->at(i).is_scalar()) { - span->values[i].SetScalar(args_->at(i).scalar().get()); - } else if (args_->at(i).is_array()) { - const ArrayData& arr = *args_->at(i).array(); + const Datum& arg = (*args_)[i]; + if (arg.is_scalar()) { + span->values[i].SetScalar(arg.scalar().get()); + } else if (arg.is_array()) { + const ArrayData& arr = *arg.array(); span->values[i].SetArray(arr); value_offsets_[i] = arr.offset; } else { // Populate members from the first chunk - const Array* first_chunk = args_->at(i).chunked_array()->chunk(0).get(); - const ArrayData& arr = *first_chunk->data(); - span->values[i].SetArray(arr); - value_offsets_[i] = arr.offset; + const ChunkedArray& carr = *arg.chunked_array(); + if (carr.num_chunks() > 0) { + const ArrayData& arr = *carr.chunk(0)->data(); + span->values[i].SetArray(arr); + value_offsets_[i] = arr.offset; + } else { + // Fill as zero-length array + ::arrow::internal::FillZeroLengthArray(carr.type().get(), + &span->values[i].array); + span->values[i].scalar = nullptr; + } have_chunked_arrays_ = true; } } - initialized_ = true; - } - if (position_ == length_) { + if (have_all_scalars_) { + PromoteExecSpanScalars(span); + } + + initialized_ = true; + } else if (position_ == length_) { + // We've emitted at least one span and we're at the end so we are done return false; } @@ -441,6 +466,7 @@ bool ExecSpanIterator::Next(ExecSpan* span) { value_positions_[i] += iteration_size; } } + position_ += iteration_size; 
DCHECK_LE(position_, length_); return true; @@ -662,7 +688,7 @@ class NullPropagator { }; std::shared_ptr ToChunkedArray(const std::vector& values, - const std::shared_ptr& type) { + const TypeHolder& type) { std::vector> arrays; arrays.reserve(values.size()); for (const Datum& val : values) { @@ -672,7 +698,7 @@ std::shared_ptr ToChunkedArray(const std::vector& values, } arrays.emplace_back(val.make_array()); } - return std::make_shared(std::move(arrays), type); + return std::make_shared(std::move(arrays), type.GetSharedPtr()); } bool HaveChunkedArray(const std::vector& values) { @@ -691,9 +717,9 @@ class KernelExecutorImpl : public KernelExecutor { kernel_ctx_ = kernel_ctx; kernel_ = static_cast(args.kernel); - // Resolve the output descriptor for this kernel + // Resolve the output type for this kernel ARROW_ASSIGN_OR_RAISE( - output_descr_, kernel_->signature->out_type().Resolve(kernel_ctx_, args.inputs)); + output_type_, kernel_->signature->out_type().Resolve(kernel_ctx_, args.inputs)); return Status::OK(); } @@ -703,7 +729,7 @@ class KernelExecutorImpl : public KernelExecutor { // Kernel::mem_allocation is not MemAllocation::PREALLOCATE, then no // data buffers will be set Result> PrepareOutput(int64_t length) { - auto out = std::make_shared(output_descr_.type, length); + auto out = std::make_shared(output_type_.GetSharedPtr(), length); out->buffers.resize(output_num_buffers_); if (validity_preallocated_) { @@ -726,10 +752,10 @@ class KernelExecutorImpl : public KernelExecutor { Status CheckResultType(const Datum& out, const char* function_name) override { const auto& type = out.type(); - if (type != nullptr && !type->Equals(output_descr_.type)) { + if (type != nullptr && !type->Equals(*output_type_.type)) { return Status::TypeError( "kernel type result mismatch for function '", function_name, "': declared as ", - output_descr_.type->ToString(), ", actual is ", type->ToString()); + output_type_.type->ToString(), ", actual is ", type->ToString()); } return Status::OK(); } @@ -741,7 +767,7 @@ class KernelExecutorImpl : public KernelExecutor { KernelContext* kernel_ctx_; const KernelType* kernel_; - ValueDescr output_descr_; + TypeHolder output_type_; int output_num_buffers_; @@ -757,18 +783,23 @@ class KernelExecutorImpl : public KernelExecutor { class ScalarExecutor : public KernelExecutorImpl { public: Status Execute(const ExecBatch& batch, ExecListener* listener) override { - RETURN_NOT_OK(span_iterator_.Init(batch, output_descr_.shape, - exec_context()->exec_chunksize())); + RETURN_NOT_OK(span_iterator_.Init(batch, exec_context()->exec_chunksize())); - // TODO(wesm): remove if with ARROW-16757 - if (output_descr_.shape != ValueDescr::SCALAR) { - // If the executor is configured to produce a single large Array output for - // kernels supporting preallocation, then we do so up front and then - // iterate over slices of that large array. 
Otherwise, we preallocate prior - // to processing each span emitted from the ExecSpanIterator - RETURN_NOT_OK(SetupPreallocation(span_iterator_.length(), batch.values)); + if (batch.length == 0) { + // For zero-length batches, we do nothing except return a zero-length + // array of the correct output type + ARROW_ASSIGN_OR_RAISE(std::shared_ptr result, + MakeArrayOfNull(output_type_.GetSharedPtr(), /*length=*/0, + exec_context()->memory_pool())); + return EmitResult(result->data(), listener); } + // If the executor is configured to produce a single large Array output for + // kernels supporting preallocation, then we do so up front and then + // iterate over slices of that large array. Otherwise, we preallocate prior + // to processing each span emitted from the ExecSpanIterator + RETURN_NOT_OK(SetupPreallocation(span_iterator_.length(), batch.values)); + // ARROW-16756: Here we have to accommodate the distinct cases // // * Fully-preallocated contiguous output @@ -784,30 +815,28 @@ class ScalarExecutor : public KernelExecutorImpl { Datum WrapResults(const std::vector& inputs, const std::vector& outputs) override { - if (output_descr_.shape == ValueDescr::SCALAR) { - // TODO(wesm): to remove, see ARROW-16757 - DCHECK_EQ(outputs.size(), 1); - // Return as SCALAR - return outputs[0]; + // If execution yielded multiple chunks (because large arrays were split + // based on the ExecContext parameters, then the result is a ChunkedArray + if (HaveChunkedArray(inputs) || outputs.size() > 1) { + return ToChunkedArray(outputs, output_type_); } else { - // If execution yielded multiple chunks (because large arrays were split - // based on the ExecContext parameters, then the result is a ChunkedArray - if (HaveChunkedArray(inputs) || outputs.size() > 1) { - return ToChunkedArray(outputs, output_descr_.type); - } else if (outputs.size() == 1) { - // Outputs have just one element - return outputs[0]; - } else { - // XXX: In the case where no outputs are omitted, is returning a 0-length - // array always the correct move? - return MakeArrayOfNull(output_descr_.type, /*length=*/0, - exec_context()->memory_pool()) - .ValueOrDie(); - } + // Outputs have just one element + return outputs[0]; } } protected: + Status EmitResult(std::shared_ptr out, ExecListener* listener) { + if (span_iterator_.have_all_scalars()) { + // ARROW-16757 We boxed scalar inputs as ArraySpan, so now we have to + // unbox the output as a scalar + ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, MakeArray(out)->GetScalar(0)); + return listener->OnResult(std::move(scalar)); + } else { + return listener->OnResult(std::move(out)); + } + } + Status ExecuteSpans(ExecListener* listener) { // We put the preallocation in an ArraySpan to be passed to the // kernel which is expecting to receive that. 
More @@ -817,6 +846,7 @@ class ScalarExecutor : public KernelExecutorImpl { ExecSpan input; ExecResult output; ArraySpan* output_span = output.array_span(); + if (preallocate_contiguous_) { // Make one big output allocation ARROW_ASSIGN_OR_RAISE(preallocation, PrepareOutput(span_iterator_.length())); @@ -832,7 +862,7 @@ class ScalarExecutor : public KernelExecutorImpl { } // Kernel execution is complete; emit result - RETURN_NOT_OK(listener->OnResult(std::move(preallocation))); + return EmitResult(std::move(preallocation), listener); } else { // Fully preallocating, but not contiguously // We preallocate (maybe) only for the output of processing the current @@ -842,15 +872,15 @@ class ScalarExecutor : public KernelExecutorImpl { output_span->SetMembers(*preallocation); RETURN_NOT_OK(ExecuteSingleSpan(input, &output)); // Emit the result for this chunk - RETURN_NOT_OK(listener->OnResult(std::move(preallocation))); + RETURN_NOT_OK(EmitResult(std::move(preallocation), listener)); } + return Status::OK(); } - return Status::OK(); } Status ExecuteSingleSpan(const ExecSpan& input, ExecResult* out) { ArraySpan* result_span = out->array_span(); - if (output_descr_.type->id() == Type::NA) { + if (output_type_.type->id() == Type::NA) { result_span->null_count = result_span->length; } else if (kernel_->null_handling == NullHandling::INTERSECTION) { if (!elide_validity_bitmap_) { @@ -859,7 +889,10 @@ class ScalarExecutor : public KernelExecutorImpl { } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { result_span->null_count = 0; } - return kernel_->exec(kernel_ctx_, input, out); + RETURN_NOT_OK(kernel_->exec(kernel_ctx_, input, out)); + // Output type didn't change + DCHECK(out->is_array_span()); + return Status::OK(); } Status ExecuteNonSpans(ExecListener* listener) { @@ -873,60 +906,32 @@ class ScalarExecutor : public KernelExecutorImpl { ExecSpan input; ExecResult output; while (span_iterator_.Next(&input)) { - if (output_descr_.shape == ValueDescr::ARRAY) { - ARROW_ASSIGN_OR_RAISE(output.value, PrepareOutput(input.length)); - DCHECK(output.is_array_data()); - } else { - // For scalar outputs, we set a null scalar of the correct type to - // communicate the output type to the kernel if needed - // - // XXX: Is there some way to avoid this step? 
- // TODO: Remove this path in ARROW-16757 - output.value = MakeNullScalar(output_descr_.type); - } + ARROW_ASSIGN_OR_RAISE(output.value, PrepareOutput(input.length)); + DCHECK(output.is_array_data()); - if (output_descr_.shape == ValueDescr::ARRAY) { - ArrayData* out_arr = output.array_data().get(); - if (output_descr_.type->id() == Type::NA) { - out_arr->null_count = out_arr->length; - } else if (kernel_->null_handling == NullHandling::INTERSECTION) { - RETURN_NOT_OK(PropagateNulls(kernel_ctx_, input, out_arr)); - } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { - out_arr->null_count = 0; - } - } else { - // TODO(wesm): to remove, see ARROW-16757 - if (kernel_->null_handling == NullHandling::INTERSECTION) { - // set scalar validity - output.scalar()->is_valid = - std::all_of(input.values.begin(), input.values.end(), - [](const ExecValue& input) { return input.scalar->is_valid; }); - } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { - output.scalar()->is_valid = true; - } + ArrayData* out_arr = output.array_data().get(); + if (output_type_.type->id() == Type::NA) { + out_arr->null_count = out_arr->length; + } else if (kernel_->null_handling == NullHandling::INTERSECTION) { + RETURN_NOT_OK(PropagateNulls(kernel_ctx_, input, out_arr)); + } else if (kernel_->null_handling == NullHandling::OUTPUT_NOT_NULL) { + out_arr->null_count = 0; } RETURN_NOT_OK(kernel_->exec(kernel_ctx_, input, &output)); - // Assert that the kernel did not alter the shape of the output - // type. After ARROW-16577 delete this since ValueDescr::SCALAR will not - // exist anymore - DCHECK(((output_descr_.shape == ValueDescr::ARRAY) && output.is_array_data()) || - ((output_descr_.shape == ValueDescr::SCALAR) && output.is_scalar())); + // Output type didn't change + DCHECK(output.is_array_data()); // Emit a result for each chunk - if (output_descr_.shape == ValueDescr::ARRAY) { - RETURN_NOT_OK(listener->OnResult(output.array_data())); - } else { - RETURN_NOT_OK(listener->OnResult(output.scalar())); - } + RETURN_NOT_OK(EmitResult(std::move(output.array_data()), listener)); } return Status::OK(); } Status SetupPreallocation(int64_t total_length, const std::vector& args) { - output_num_buffers_ = static_cast(output_descr_.type->layout().buffers.size()); - auto out_type_id = output_descr_.type->id(); + output_num_buffers_ = static_cast(output_type_.type->layout().buffers.size()); + auto out_type_id = output_type_.type->id(); // Default to no validity pre-allocation for following cases: // - Output Array is NullArray // - kernel_->null_handling is COMPUTED_NO_PREALLOCATE or OUTPUT_NOT_NULL @@ -950,7 +955,7 @@ class ScalarExecutor : public KernelExecutorImpl { } } if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) { - ComputeDataPreallocate(*output_descr_.type, &data_preallocated_); + ComputeDataPreallocate(*output_type_.type, &data_preallocated_); } // Validity bitmap either preallocated or elided, and all data @@ -995,14 +1000,28 @@ class ScalarExecutor : public KernelExecutorImpl { ExecSpanIterator span_iterator_; }; +namespace { + +Status CheckCanExecuteChunked(const VectorKernel* kernel) { + if (kernel->exec_chunked == nullptr) { + return Status::Invalid( + "Vector kernel cannot execute chunkwise and no " + "chunked exec function was defined"); + } + + if (kernel->null_handling == NullHandling::INTERSECTION) { + return Status::Invalid( + "Null pre-propagation is unsupported for ChunkedArray " + "execution in vector kernels"); + } + return Status::OK(); +} + +} // namespace + 
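
// Illustrative sketch, not taken from the patch: CheckCanExecuteChunked()
// above captures the two requirements for whole-ChunkedArray execution: the
// VectorKernel must provide exec_chunked, and it must not rely on
// NullHandling::INTERSECTION pre-propagation. A hypothetical chunked exec
// matching the Status(KernelContext*, const ExecBatch&, Datum*) call made
// from VectorExecutor::ExecChunked() below:
arrow::Status PassThroughChunked(arrow::compute::KernelContext* /*ctx*/,
                                 const arrow::compute::ExecBatch& batch,
                                 arrow::Datum* out) {
  // Forward the first argument unchanged; it may be an Array or a
  // ChunkedArray, and ExecChunked() emits either kind via EmitResult().
  *out = batch[0];
  return arrow::Status::OK();
}
// A kernel registering this would also set, for example:
//   kernel.exec_chunked = PassThroughChunked;
//   kernel.null_handling = NullHandling::OUTPUT_NOT_NULL;
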
class VectorExecutor : public KernelExecutorImpl { public: Status Execute(const ExecBatch& batch, ExecListener* listener) override { - // TODO(wesm): remove in ARROW-16577 - if (output_descr_.shape == ValueDescr::SCALAR) { - return Status::Invalid("VectorExecutor only supports array output types"); - } - // Some vector kernels have a separate code path for handling // chunked arrays (VectorKernel::exec_chunked) so we check if we // have any chunked arrays. If we do and an exec_chunked function @@ -1012,19 +1031,18 @@ class VectorExecutor : public KernelExecutorImpl { if (arg.is_chunked_array()) have_chunked_arrays = true; } - output_num_buffers_ = static_cast(output_descr_.type->layout().buffers.size()); + output_num_buffers_ = static_cast(output_type_.type->layout().buffers.size()); // Decide if we need to preallocate memory for this kernel validity_preallocated_ = (kernel_->null_handling != NullHandling::COMPUTED_NO_PREALLOCATE && kernel_->null_handling != NullHandling::OUTPUT_NOT_NULL); if (kernel_->mem_allocation == MemAllocation::PREALLOCATE) { - ComputeDataPreallocate(*output_descr_.type, &data_preallocated_); + ComputeDataPreallocate(*output_type_.type, &data_preallocated_); } if (kernel_->can_execute_chunkwise) { - RETURN_NOT_OK(span_iterator_.Init(batch, output_descr_.shape, - exec_context()->exec_chunksize())); + RETURN_NOT_OK(span_iterator_.Init(batch, exec_context()->exec_chunksize())); ExecSpan span; while (span_iterator_.Next(&span)) { RETURN_NOT_OK(Exec(span, listener)); @@ -1038,7 +1056,11 @@ class VectorExecutor : public KernelExecutorImpl { } else { // No chunked arrays. We pack the args into an ExecSpan and // call the regular exec code path - RETURN_NOT_OK(Exec(ExecSpan(batch), listener)); + ExecSpan span(batch); + if (CheckIfAllScalar(batch)) { + PromoteExecSpanScalars(&span); + } + RETURN_NOT_OK(Exec(span, listener)); } } @@ -1058,63 +1080,46 @@ class VectorExecutor : public KernelExecutorImpl { // If execution yielded multiple chunks (because large arrays were split // based on the ExecContext parameters, then the result is a ChunkedArray if (kernel_->output_chunked && (HaveChunkedArray(inputs) || outputs.size() > 1)) { - return ToChunkedArray(outputs, output_descr_.type); - } else if (outputs.size() == 1) { + return ToChunkedArray(outputs, output_type_.GetSharedPtr()); + } else { // Outputs have just one element return outputs[0]; - } else { - // XXX: In the case where no outputs are omitted, is returning a 0-length - // array always the correct move? - return MakeArrayOfNull(output_descr_.type, /*length=*/0).ValueOrDie(); } } protected: - Status Exec(const ExecSpan& span, ExecListener* listener) { - ExecResult out; - - // We preallocate (maybe) only for the output of processing the current - // batch, but create an output ArrayData instance regardless - ARROW_ASSIGN_OR_RAISE(out.value, PrepareOutput(span.length)); - - if (kernel_->null_handling == NullHandling::INTERSECTION) { - RETURN_NOT_OK(PropagateNulls(kernel_ctx_, span, out.array_data().get())); - } - RETURN_NOT_OK(kernel_->exec(kernel_ctx_, span, &out)); + Status EmitResult(Datum result, ExecListener* listener) { if (!kernel_->finalize) { // If there is no result finalizer (e.g. 
for hash-based functions, we can // emit the processed batch right away rather than waiting - RETURN_NOT_OK(listener->OnResult(out.array_data())); + RETURN_NOT_OK(listener->OnResult(std::move(result))); } else { - results_.emplace_back(out.array_data()); + results_.emplace_back(std::move(result)); } return Status::OK(); } - Status ExecChunked(const ExecBatch& batch, ExecListener* listener) { - if (kernel_->exec_chunked == nullptr) { - return Status::Invalid( - "Vector kernel cannot execute chunkwise and no " - "chunked exec function was defined"); - } - + Status Exec(const ExecSpan& span, ExecListener* listener) { + ExecResult out; + ARROW_ASSIGN_OR_RAISE(out.value, PrepareOutput(span.length)); if (kernel_->null_handling == NullHandling::INTERSECTION) { - return Status::Invalid( - "Null pre-propagation is unsupported for ChunkedArray " - "execution in vector kernels"); + RETURN_NOT_OK(PropagateNulls(kernel_ctx_, span, out.array_data().get())); } + RETURN_NOT_OK(kernel_->exec(kernel_ctx_, span, &out)); + return EmitResult(std::move(out.array_data()), listener); + } + Status ExecChunked(const ExecBatch& batch, ExecListener* listener) { + RETURN_NOT_OK(CheckCanExecuteChunked(kernel_)); Datum out; ARROW_ASSIGN_OR_RAISE(out.value, PrepareOutput(batch.length)); RETURN_NOT_OK(kernel_->exec_chunked(kernel_ctx_, batch, &out)); - if (!kernel_->finalize) { - // If there is no result finalizer (e.g. for hash-based functions, we can - // emit the processed batch right away rather than waiting - RETURN_NOT_OK(listener->OnResult(std::move(out))); + if (out.is_array()) { + return EmitResult(std::move(out.array()), listener); } else { - results_.emplace_back(std::move(out)); + DCHECK(out.is_chunked_array()); + return EmitResult(std::move(out.chunked_array()), listener); } - return Status::OK(); } ExecSpanIterator span_iterator_; @@ -1124,7 +1129,7 @@ class VectorExecutor : public KernelExecutorImpl { class ScalarAggExecutor : public KernelExecutorImpl { public: Status Init(KernelContext* ctx, KernelInitArgs args) override { - input_descrs_ = &args.inputs; + input_types_ = &args.inputs; options_ = args.options; return KernelExecutorImpl::Init(ctx, args); } @@ -1160,9 +1165,8 @@ class ScalarAggExecutor : public KernelExecutorImpl { private: Status Consume(const ExecBatch& batch) { // FIXME(ARROW-11840) don't merge *any* aggegates for every batch - ARROW_ASSIGN_OR_RAISE( - auto batch_state, - kernel_->init(kernel_ctx_, {kernel_, *input_descrs_, options_})); + ARROW_ASSIGN_OR_RAISE(auto batch_state, + kernel_->init(kernel_ctx_, {kernel_, *input_types_, options_})); if (batch_state == nullptr) { return Status::Invalid("ScalarAggregation requires non-null kernel state"); @@ -1177,7 +1181,7 @@ class ScalarAggExecutor : public KernelExecutorImpl { } std::unique_ptr batch_iterator_; - const std::vector* input_descrs_; + const std::vector* input_types_; const FunctionOptions* options_; }; @@ -1358,8 +1362,7 @@ Result> SelectionVector::FromMask( Result CallFunction(const std::string& func_name, const std::vector& args, const FunctionOptions* options, ExecContext* ctx) { if (ctx == nullptr) { - ExecContext default_ctx; - return CallFunction(func_name, args, options, &default_ctx); + ctx = default_exec_context(); } ARROW_ASSIGN_OR_RAISE(std::shared_ptr func, ctx->func_registry()->GetFunction(func_name)); @@ -1374,8 +1377,7 @@ Result CallFunction(const std::string& func_name, const std::vector CallFunction(const std::string& func_name, const ExecBatch& batch, const FunctionOptions* options, ExecContext* ctx) { if (ctx 
== nullptr) { - ExecContext default_ctx; - return CallFunction(func_name, batch, options, &default_ctx); + ctx = default_exec_context(); } ARROW_ASSIGN_OR_RAISE(std::shared_ptr func, ctx->func_registry()->GetFunction(func_name)); diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index 8fd938ce299..f0b951dccb8 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -235,12 +235,11 @@ struct ARROW_EXPORT ExecBatch { ExecBatch Slice(int64_t offset, int64_t length) const; - /// \brief A convenience for returning the ValueDescr objects (types and - /// shapes) from the batch. - std::vector GetDescriptors() const { - std::vector result; + /// \brief A convenience for returning the types from the batch. + std::vector GetTypes() const { + std::vector result; for (const auto& value : this->values) { - result.emplace_back(value.descr()); + result.emplace_back(value.type()); } return result; } @@ -254,19 +253,16 @@ inline bool operator==(const ExecBatch& l, const ExecBatch& r) { return l.Equals inline bool operator!=(const ExecBatch& l, const ExecBatch& r) { return !l.Equals(r); } struct ExecValue { - enum Kind { ARRAY, SCALAR }; - Kind kind = ARRAY; ArraySpan array; - const Scalar* scalar; + const Scalar* scalar = NULLPTR; ExecValue(Scalar* scalar) // NOLINT implicit conversion - : kind(SCALAR), scalar(scalar) {} + : scalar(scalar) {} ExecValue(ArraySpan array) // NOLINT implicit conversion - : kind(ARRAY), array(std::move(array)) {} + : array(std::move(array)) {} - ExecValue(const ArrayData& array) // NOLINT implicit conversion - : kind(ARRAY) { + ExecValue(const ArrayData& array) { // NOLINT implicit conversion this->array.SetMembers(array); } @@ -278,31 +274,21 @@ struct ExecValue { int64_t length() const { return this->is_array() ? this->array.length : 1; } - bool is_array() const { return this->kind == ARRAY; } - bool is_scalar() const { return this->kind == SCALAR; } + bool is_array() const { return this->scalar == NULLPTR; } + bool is_scalar() const { return !this->is_array(); } void SetArray(const ArrayData& array) { - this->kind = ARRAY; this->array.SetMembers(array); + this->scalar = NULLPTR; } - void SetScalar(const Scalar* scalar) { - this->kind = SCALAR; - this->scalar = scalar; - } + void SetScalar(const Scalar* scalar) { this->scalar = scalar; } template const ExactType& scalar_as() const { return ::arrow::internal::checked_cast(*this->scalar); } - /// XXX: here only temporarily until type resolution can be cleaned - /// up to not use ValueDescr - ValueDescr descr() const { - ValueDescr::Shape shape = this->is_array() ? ValueDescr::ARRAY : ValueDescr::SCALAR; - return ValueDescr(const_cast(this->type())->shared_from_this(), shape); - } - /// XXX: here temporarily for compatibility with datum, see /// e.g. 
MakeStructExec in scalar_nested.cc int64_t null_count() const { @@ -314,7 +300,7 @@ struct ExecValue { } const DataType* type() const { - if (this->kind == ARRAY) { + if (this->is_array()) { return array.type; } else { return scalar->type.get(); @@ -324,29 +310,21 @@ struct ExecValue { struct ARROW_EXPORT ExecResult { // The default value of the variant is ArraySpan - // TODO(wesm): remove Scalar output modality in ARROW-16577 - util::Variant, std::shared_ptr> value; + util::Variant> value; int64_t length() const { if (this->is_array_span()) { return this->array_span()->length; - } else if (this->is_array_data()) { - return this->array_data()->length; } else { - // Should not reach here - return 1; + return this->array_data()->length; } } const DataType* type() const { - switch (this->value.index()) { - case 0: - return this->array_span()->type; - case 1: - return this->array_data()->type.get(); - default: - // scalar - return this->scalar()->type.get(); + if (this->is_array_span()) { + return this->array_span()->type; + } else { + return this->array_data()->type.get(); } } @@ -360,12 +338,6 @@ struct ARROW_EXPORT ExecResult { } bool is_array_data() const { return this->value.index() == 1; } - - const std::shared_ptr& scalar() const { - return util::get>(this->value); - } - - bool is_scalar() const { return this->value.index() == 2; } }; /// \brief A "lightweight" column batch object which contains no @@ -395,15 +367,6 @@ struct ARROW_EXPORT ExecSpan { } } - bool is_all_scalar() const { - for (const ExecValue& value : this->values) { - if (value.is_array()) { - return false; - } - } - return true; - } - /// \brief Return the value at the i-th index template inline const ExecValue& operator[](index_type i) const { @@ -412,7 +375,7 @@ struct ARROW_EXPORT ExecSpan { void AddOffset(int64_t offset) { for (ExecValue& value : values) { - if (value.kind == ExecValue::ARRAY) { + if (value.is_array()) { value.array.AddOffset(offset); } } @@ -420,7 +383,7 @@ struct ARROW_EXPORT ExecSpan { void SetOffset(int64_t offset) { for (ExecValue& value : values) { - if (value.kind == ExecValue::ARRAY) { + if (value.is_array()) { value.array.SetOffset(offset); } } @@ -429,12 +392,10 @@ struct ARROW_EXPORT ExecSpan { /// \brief A convenience for the number of values / arguments. 
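
// Illustrative sketch, not taken from the patch: with the Kind enum removed,
// an ExecValue is a scalar exactly when its `scalar` pointer is non-null, and
// kernels branch only on is_array()/is_scalar(). A hypothetical unary int32
// negation written against the span API declared here:
arrow::Status NegateInt32Exec(arrow::compute::KernelContext* /*ctx*/,
                              const arrow::compute::ExecSpan& batch,
                              arrow::compute::ExecResult* out) {
  const arrow::compute::ExecValue& arg = batch[0];
  int32_t* out_values = out->array_span()->GetValues<int32_t>(1);
  if (arg.is_array()) {
    const int32_t* values = arg.array.GetValues<int32_t>(1);
    for (int64_t i = 0; i < batch.length; ++i) {
      out_values[i] = -values[i];
    }
  } else {
    // Reachable only for kernels whose scalar arguments are not promoted to
    // length-1 spans by the executor
    out_values[0] = -arg.scalar_as<arrow::Int32Scalar>().value;
  }
  return arrow::Status::OK();
}
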
int num_values() const { return static_cast(values.size()); } - // XXX: eliminate the need for ValueDescr; copied temporarily from - // ExecBatch - std::vector GetDescriptors() const { - std::vector result; + std::vector GetTypes() const { + std::vector result; for (const auto& value : this->values) { - result.emplace_back(value.descr()); + result.emplace_back(value.type()); } return result; } diff --git a/cpp/src/arrow/compute/exec/aggregate.cc b/cpp/src/arrow/compute/exec/aggregate.cc index 41b5bb75b66..5cb9a9c5633 100644 --- a/cpp/src/arrow/compute/exec/aggregate.cc +++ b/cpp/src/arrow/compute/exec/aggregate.cc @@ -31,20 +31,19 @@ namespace internal { Result> GetKernels( ExecContext* ctx, const std::vector& aggregates, - const std::vector& in_descrs) { - if (aggregates.size() != in_descrs.size()) { + const std::vector& in_types) { + if (aggregates.size() != in_types.size()) { return Status::Invalid(aggregates.size(), " aggregate functions were specified but ", - in_descrs.size(), " arguments were provided."); + in_types.size(), " arguments were provided."); } - std::vector kernels(in_descrs.size()); + std::vector kernels(in_types.size()); for (size_t i = 0; i < aggregates.size(); ++i) { ARROW_ASSIGN_OR_RAISE(auto function, ctx->func_registry()->GetFunction(aggregates[i].function)); - ARROW_ASSIGN_OR_RAISE( - const Kernel* kernel, - function->DispatchExact({in_descrs[i], ValueDescr::Array(uint32())})); + ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, + function->DispatchExact({in_types[i], uint32()})); kernels[i] = static_cast(kernel); } return kernels; @@ -52,7 +51,7 @@ Result> GetKernels( Result>> InitKernels( const std::vector& kernels, ExecContext* ctx, - const std::vector& aggregates, const std::vector& in_descrs) { + const std::vector& aggregates, const std::vector& in_types) { std::vector> states(kernels.size()); for (size_t i = 0; i < aggregates.size(); ++i) { @@ -69,14 +68,13 @@ Result>> InitKernels( } KernelContext kernel_ctx{ctx}; - ARROW_ASSIGN_OR_RAISE( - states[i], - kernels[i]->init(&kernel_ctx, KernelInitArgs{kernels[i], - { - in_descrs[i], - ValueDescr::Array(uint32()), - }, - options})); + ARROW_ASSIGN_OR_RAISE(states[i], + kernels[i]->init(&kernel_ctx, KernelInitArgs{kernels[i], + { + in_types[i], + uint32(), + }, + options})); } return std::move(states); @@ -86,19 +84,16 @@ Result ResolveKernels( const std::vector& aggregates, const std::vector& kernels, const std::vector>& states, ExecContext* ctx, - const std::vector& descrs) { - FieldVector fields(descrs.size()); + const std::vector& types) { + FieldVector fields(types.size()); for (size_t i = 0; i < kernels.size(); ++i) { KernelContext kernel_ctx{ctx}; kernel_ctx.SetState(states[i].get()); - ARROW_ASSIGN_OR_RAISE(auto descr, kernels[i]->signature->out_type().Resolve( - &kernel_ctx, { - descrs[i], - ValueDescr::Array(uint32()), - })); - fields[i] = field(aggregates[i].function, std::move(descr.type)); + ARROW_ASSIGN_OR_RAISE(auto type, kernels[i]->signature->out_type().Resolve( + &kernel_ctx, {types[i], uint32()})); + fields[i] = field(aggregates[i].function, type.GetSharedPtr()); } return fields; } @@ -122,18 +117,17 @@ Result GroupBy(const std::vector& arguments, const std::vectorparallelism()); for (auto& state : states) { - ARROW_ASSIGN_OR_RAISE(state, - InitKernels(kernels, ctx, aggregates, argument_descrs)); + ARROW_ASSIGN_OR_RAISE(state, InitKernels(kernels, ctx, aggregates, argument_types)); } ARROW_ASSIGN_OR_RAISE( - out_fields, ResolveKernels(aggregates, kernels, states[0], ctx, argument_descrs)); + 
out_fields, ResolveKernels(aggregates, kernels, states[0], ctx, argument_types)); ARROW_ASSIGN_OR_RAISE( argument_batch_iterator, @@ -142,19 +136,19 @@ Result GroupBy(const std::vector& arguments, const std::vector> groupers(task_group->parallelism()); for (auto& grouper : groupers) { - ARROW_ASSIGN_OR_RAISE(grouper, Grouper::Make(key_descrs, ctx)); + ARROW_ASSIGN_OR_RAISE(grouper, Grouper::Make(key_types, ctx)); } std::mutex mutex; std::unordered_map thread_ids; int i = 0; - for (ValueDescr& key_descr : key_descrs) { - out_fields.push_back(field("key_" + std::to_string(i++), std::move(key_descr.type))); + for (const TypeHolder& key_type : key_types) { + out_fields.push_back(field("key_" + std::to_string(i++), key_type.GetSharedPtr())); } ARROW_ASSIGN_OR_RAISE( diff --git a/cpp/src/arrow/compute/exec/aggregate.h b/cpp/src/arrow/compute/exec/aggregate.h index 753b0a8c47e..72990f3b6e7 100644 --- a/cpp/src/arrow/compute/exec/aggregate.h +++ b/cpp/src/arrow/compute/exec/aggregate.h @@ -42,17 +42,17 @@ Result GroupBy(const std::vector& arguments, const std::vector> GetKernels( ExecContext* ctx, const std::vector& aggregates, - const std::vector& in_descrs); + const std::vector& in_types); Result>> InitKernels( const std::vector& kernels, ExecContext* ctx, - const std::vector& aggregates, const std::vector& in_descrs); + const std::vector& aggregates, const std::vector& in_types); Result ResolveKernels( const std::vector& aggregates, const std::vector& kernels, const std::vector>& states, ExecContext* ctx, - const std::vector& descrs); + const std::vector& in_types); } // namespace internal } // namespace compute diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc index 8c7899c41ec..0131319be3b 100644 --- a/cpp/src/arrow/compute/exec/aggregate_node.cc +++ b/cpp/src/arrow/compute/exec/aggregate_node.cc @@ -104,8 +104,7 @@ class ScalarAggregateNode : public ExecNode { aggregates[i].function); } - auto in_type = ValueDescr::Array(input_schema.field(target_field_ids[i])->type()); - + TypeHolder in_type(input_schema.field(target_field_ids[i])->type().get()); ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, function->DispatchExact({in_type})); kernels[i] = static_cast(kernel); @@ -125,10 +124,10 @@ class ScalarAggregateNode : public ExecNode { // pick one to resolve the kernel signature kernel_ctx.SetState(states[i][0].get()); - ARROW_ASSIGN_OR_RAISE( - auto descr, kernels[i]->signature->out_type().Resolve(&kernel_ctx, {in_type})); + ARROW_ASSIGN_OR_RAISE(auto out_type, kernels[i]->signature->out_type().Resolve( + &kernel_ctx, {in_type})); - fields[i] = field(aggregate_options.aggregates[i].name, std::move(descr.type)); + fields[i] = field(aggregate_options.aggregates[i].name, out_type.GetSharedPtr()); } return plan->EmplaceNode( @@ -313,25 +312,24 @@ class GroupByNode : public ExecNode { } // Build vector of aggregate source field data types - std::vector agg_src_descrs(aggs.size()); + std::vector agg_src_types(aggs.size()); for (size_t i = 0; i < aggs.size(); ++i) { auto agg_src_field_id = agg_src_field_ids[i]; - agg_src_descrs[i] = - ValueDescr(input_schema->field(agg_src_field_id)->type(), ValueDescr::ARRAY); + agg_src_types[i] = input_schema->field(agg_src_field_id)->type().get(); } auto ctx = input->plan()->exec_context(); // Construct aggregates ARROW_ASSIGN_OR_RAISE(auto agg_kernels, - internal::GetKernels(ctx, aggs, agg_src_descrs)); + internal::GetKernels(ctx, aggs, agg_src_types)); ARROW_ASSIGN_OR_RAISE(auto agg_states, - 
internal::InitKernels(agg_kernels, ctx, aggs, agg_src_descrs)); + internal::InitKernels(agg_kernels, ctx, aggs, agg_src_types)); ARROW_ASSIGN_OR_RAISE( FieldVector agg_result_fields, - internal::ResolveKernels(aggs, agg_kernels, agg_states, ctx, agg_src_descrs)); + internal::ResolveKernels(aggs, agg_kernels, agg_states, ctx, agg_src_types)); // Build field vector for output schema FieldVector output_fields{keys.size() + aggs.size()}; @@ -621,26 +619,24 @@ class GroupByNode : public ExecNode { if (state->grouper != nullptr) return Status::OK(); // Build vector of key field data types - std::vector key_descrs(key_field_ids_.size()); + std::vector key_types(key_field_ids_.size()); for (size_t i = 0; i < key_field_ids_.size(); ++i) { auto key_field_id = key_field_ids_[i]; - key_descrs[i] = ValueDescr(input_schema->field(key_field_id)->type()); + key_types[i] = input_schema->field(key_field_id)->type().get(); } // Construct grouper - ARROW_ASSIGN_OR_RAISE(state->grouper, Grouper::Make(key_descrs, ctx_)); + ARROW_ASSIGN_OR_RAISE(state->grouper, Grouper::Make(key_types, ctx_)); // Build vector of aggregate source field data types - std::vector agg_src_descrs(agg_kernels_.size()); + std::vector agg_src_types(agg_kernels_.size()); for (size_t i = 0; i < agg_kernels_.size(); ++i) { auto agg_src_field_id = agg_src_field_ids_[i]; - agg_src_descrs[i] = - ValueDescr(input_schema->field(agg_src_field_id)->type(), ValueDescr::ARRAY); + agg_src_types[i] = input_schema->field(agg_src_field_id)->type().get(); } - ARROW_ASSIGN_OR_RAISE( - state->agg_states, - internal::InitKernels(agg_kernels_, ctx_, aggs_, agg_src_descrs)); + ARROW_ASSIGN_OR_RAISE(state->agg_states, internal::InitKernels(agg_kernels_, ctx_, + aggs_, agg_src_types)); return Status::OK(); } diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index b796f5cda3b..c890b3c5935 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -64,7 +64,7 @@ Expression::Expression(Parameter parameter) Expression literal(Datum lit) { return Expression(std::move(lit)); } Expression field_ref(FieldRef ref) { - return Expression(Expression::Parameter{std::move(ref), ValueDescr{}, {-1}}); + return Expression(Expression::Parameter{std::move(ref), TypeHolder{}, {-1}}); } Expression call(std::string function, std::vector arguments, @@ -93,36 +93,18 @@ const Expression::Call* Expression::call() const { return util::get_if(impl_.get()); } -ValueDescr Expression::descr() const { - if (impl_ == nullptr) return {}; +const DataType* Expression::type() const { + if (impl_ == nullptr) return nullptr; - if (auto lit = literal()) { - return lit->descr(); - } - - if (auto parameter = this->parameter()) { - return parameter->descr; - } - - return CallNotNull(*this)->descr; -} - -// This is a module-global singleton to avoid synchronization costs of a -// function-static singleton. 
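
// Illustrative sketch, not taken from the patch: TypeHolder is the
// non-owning replacement for ValueDescr in these signatures. It wraps a
// const DataType* and only materializes shared ownership when asked for it.
// Hypothetical function name; uses only conversions exercised in this diff
// and assumes the enclosing namespace arrow::compute.
void TypeHolderSketch(const Schema& schema) {
  std::vector<TypeHolder> key_types;
  for (const std::shared_ptr<Field>& f : schema.fields()) {
    // Implicit conversion from shared_ptr<DataType>; assigning from
    // f->type().get() (a raw pointer) would be the non-owning variant
    key_types.emplace_back(f->type());
  }
  for (const TypeHolder& th : key_types) {
    // th.type is a const DataType*; GetSharedPtr() recovers a shared_ptr
    std::shared_ptr<DataType> owned = th.GetSharedPtr();
    DCHECK(owned->Equals(*th.type));
  }
}
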
-static const std::shared_ptr kNoType; - -const std::shared_ptr& Expression::type() const { - if (impl_ == nullptr) return kNoType; - - if (auto lit = literal()) { - return lit->type(); + if (const Datum* lit = literal()) { + return lit->type().get(); } - if (auto parameter = this->parameter()) { - return parameter->descr.type; + if (const Parameter* parameter = this->parameter()) { + return parameter->type.type; } - return CallNotNull(*this)->descr.type; + return CallNotNull(*this)->type.type; } namespace { @@ -276,7 +258,7 @@ size_t Expression::hash() const { bool Expression::IsBound() const { if (type() == nullptr) return false; - if (auto call = this->call()) { + if (const Call* call = this->call()) { if (call->kernel == nullptr) return false; for (const Expression& arg : call->arguments) { @@ -338,7 +320,7 @@ util::optional GetNullHandling( } // namespace bool Expression::IsSatisfiable() const { - if (!type()) return true; + if (type() == nullptr) return true; if (type()->id() != Type::BOOL) return true; if (auto lit = literal()) { @@ -382,25 +364,20 @@ Result BindNonRecursive(Expression::Call call, bool insert_implicit_ DCHECK(std::all_of(call.arguments.begin(), call.arguments.end(), [](const Expression& argument) { return argument.IsBound(); })); - auto descrs = GetDescriptors(call.arguments); + std::vector types = GetTypes(call.arguments); ARROW_ASSIGN_OR_RAISE(call.function, GetFunction(call, exec_context)); if (!insert_implicit_casts) { - ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchExact(descrs)); + ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchExact(types)); } else { - ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&descrs)); + ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&types)); - for (size_t i = 0; i < descrs.size(); ++i) { - if (descrs[i] == call.arguments[i].descr()) continue; + for (size_t i = 0; i < types.size(); ++i) { + if (types[i] == call.arguments[i].type()) continue; - if (descrs[i].shape != call.arguments[i].descr().shape) { - return Status::NotImplemented( - "Automatic broadcasting of scalars arguments to arrays in ", - Expression(std::move(call)).ToString()); - } - - if (auto lit = call.arguments[i].literal()) { - ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, descrs[i].type)); + if (const Datum* lit = call.arguments[i].literal()) { + ARROW_ASSIGN_OR_RAISE(Datum new_lit, + compute::Cast(*lit, types[i].GetSharedPtr())); call.arguments[i] = literal(std::move(new_lit)); continue; } @@ -409,8 +386,10 @@ Result BindNonRecursive(Expression::Call call, bool insert_implicit_ Expression::Call implicit_cast; implicit_cast.function_name = "cast"; implicit_cast.arguments = {std::move(call.arguments[i])}; + + // TODO(wesm): Use TypeHolder in options implicit_cast.options = std::make_shared( - compute::CastOptions::Safe(descrs[i].type)); + compute::CastOptions::Safe(types[i].GetSharedPtr())); ARROW_ASSIGN_OR_RAISE( call.arguments[i], @@ -425,43 +404,41 @@ Result BindNonRecursive(Expression::Call call, bool insert_implicit_ call.options ? 
call.options.get() : call.function->default_options(); ARROW_ASSIGN_OR_RAISE( call.kernel_state, - call.kernel->init(&kernel_context, {call.kernel, descrs, options})); + call.kernel->init(&kernel_context, {call.kernel, types, options})); kernel_context.SetState(call.kernel_state.get()); } ARROW_ASSIGN_OR_RAISE( - call.descr, call.kernel->signature->out_type().Resolve(&kernel_context, descrs)); + call.type, call.kernel->signature->out_type().Resolve(&kernel_context, types)); return Expression(std::move(call)); } template Result BindImpl(Expression expr, const TypeOrSchema& in, - ValueDescr::Shape shape, compute::ExecContext* exec_context) { + compute::ExecContext* exec_context) { if (exec_context == nullptr) { compute::ExecContext exec_context; - return BindImpl(std::move(expr), in, shape, &exec_context); + return BindImpl(std::move(expr), in, &exec_context); } if (expr.literal()) return expr; - if (auto ref = expr.field_ref()) { - ARROW_ASSIGN_OR_RAISE(auto path, ref->FindOne(in)); + if (const FieldRef* ref = expr.field_ref()) { + ARROW_ASSIGN_OR_RAISE(FieldPath path, ref->FindOne(in)); - auto bound = *expr.parameter(); - bound.indices.resize(path.indices().size()); - std::copy(path.indices().begin(), path.indices().end(), bound.indices.begin()); + Expression::Parameter param = *expr.parameter(); + param.indices.resize(path.indices().size()); + std::copy(path.indices().begin(), path.indices().end(), param.indices.begin()); ARROW_ASSIGN_OR_RAISE(auto field, path.Get(in)); - bound.descr.type = field->type(); - bound.descr.shape = shape; - return Expression{std::move(bound)}; + param.type = field->type(); + return Expression{std::move(param)}; } auto call = *CallNotNull(expr); for (auto& argument : call.arguments) { - ARROW_ASSIGN_OR_RAISE(argument, - BindImpl(std::move(argument), in, shape, exec_context)); + ARROW_ASSIGN_OR_RAISE(argument, BindImpl(std::move(argument), in, exec_context)); } return BindNonRecursive(std::move(call), /*insert_implicit_casts=*/true, exec_context); @@ -469,14 +446,14 @@ Result BindImpl(Expression expr, const TypeOrSchema& in, } // namespace -Result Expression::Bind(const ValueDescr& in, +Result Expression::Bind(const TypeHolder& in, compute::ExecContext* exec_context) const { - return BindImpl(*this, *in.type, in.shape, exec_context); + return BindImpl(*this, *in.type, exec_context); } Result Expression::Bind(const Schema& in_schema, compute::ExecContext* exec_context) const { - return BindImpl(*this, in_schema, ValueDescr::ARRAY, exec_context); + return BindImpl(*this, in_schema, exec_context); } Result MakeExecBatch(const Schema& full_schema, const Datum& partial) { @@ -558,7 +535,7 @@ Result ExecuteScalarExpression(const Expression& expr, const ExecBatch& i if (auto lit = expr.literal()) return *lit; if (auto param = expr.parameter()) { - if (param->descr.type->id() == Type::NA) { + if (param->type.id() == Type::NA) { return MakeNullScalar(null()); } @@ -569,10 +546,10 @@ Result ExecuteScalarExpression(const Expression& expr, const ExecBatch& i ARROW_ASSIGN_OR_RAISE( field, compute::CallFunction("struct_field", {std::move(field)}, &options)); } - if (!field.type()->Equals(param->descr.type)) { + if (!field.type()->Equals(*param->type.type)) { return Status::Invalid("Referenced field ", expr.ToString(), " was ", field.type()->ToString(), " but should have been ", - param->descr.type->ToString()); + param->type.ToString()); } return field; @@ -596,10 +573,10 @@ Result ExecuteScalarExpression(const Expression& expr, const ExecBatch& i compute::KernelContext 
kernel_context(exec_context, call->kernel); kernel_context.SetState(call->kernel_state.get()); - auto kernel = call->kernel; - auto descrs = GetDescriptors(arguments); + const Kernel* kernel = call->kernel; + std::vector types = GetTypes(arguments); auto options = call->options.get(); - RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, descrs, options})); + RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, types, options})); compute::detail::DatumAccumulator listener; RETURN_NOT_OK(executor->Execute( @@ -683,16 +660,16 @@ Result FoldConstants(Expression expr) { if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) { // kernels which always produce intersected validity can be resolved // to null *now* if any of their inputs is a null literal - if (!call->descr.type) { + if (!call->type.type) { return Status::Invalid("Cannot fold constants for unbound expression ", expr.ToString()); } - for (const auto& argument : call->arguments) { + for (const Expression& argument : call->arguments) { if (argument.IsNullLiteral()) { - if (argument.type()->Equals(*call->descr.type)) { + if (argument.type()->Equals(*call->type.type)) { return argument; } else { - return literal(MakeNullScalar(call->descr.type)); + return literal(MakeNullScalar(call->type.GetSharedPtr())); } } } @@ -815,7 +792,7 @@ Result ReplaceFieldsWithKnownValues(const KnownFieldValues& known_va auto it = known_values.map.find(*ref); if (it != known_values.map.end()) { Datum lit = it->second; - if (lit.descr() == expr.descr()) return literal(std::move(lit)); + if (lit.type()->Equals(*expr.type())) return literal(std::move(lit)); // type mismatch, try casting the known value to the correct type if (expr.type()->id() == Type::DICTIONARY && @@ -836,7 +813,7 @@ Result ReplaceFieldsWithKnownValues(const KnownFieldValues& known_va } } - ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, expr.type())); + ARROW_ASSIGN_OR_RAISE(lit, compute::Cast(lit, expr.type()->GetSharedPtr())); return literal(std::move(lit)); } } diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h index a1765d0fcca..e9026961aa9 100644 --- a/cpp/src/arrow/compute/exec/expression.h +++ b/cpp/src/arrow/compute/exec/expression.h @@ -55,7 +55,7 @@ class ARROW_EXPORT Expression { std::shared_ptr function; const Kernel* kernel = NULLPTR; std::shared_ptr kernel_state; - ValueDescr descr; + TypeHolder type; void ComputeHash(); }; @@ -70,7 +70,7 @@ class ARROW_EXPORT Expression { /// Bind this expression to the given input type, looking up Kernels and field types. /// Some expression simplification may be performed and implicit casts will be inserted. /// Any state necessary for execution will be initialized and returned. - Result Bind(const ValueDescr& in, ExecContext* = NULLPTR) const; + Result Bind(const TypeHolder& in, ExecContext* = NULLPTR) const; Result Bind(const Schema& in_schema, ExecContext* = NULLPTR) const; // XXX someday @@ -82,8 +82,8 @@ class ARROW_EXPORT Expression { // Result CloneState() const; // Status SetState(ExpressionState); - /// Return true if all an expression's field references have explicit ValueDescr and all - /// of its functions' kernels are looked up. + /// Return true if all an expression's field references have explicit types + /// and all of its functions' kernels are looked up. 
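
// Illustrative sketch, not taken from the patch: an expression's post-bind
// metadata is now just a type. Bind() accepts a Schema or a TypeHolder, and
// type() returns a non-owning const DataType* (nullptr while unbound).
// Hypothetical function name; assumes the enclosing namespace arrow::compute
// and the expression.h declarations above.
Status BindTypeSketch() {
  Expression expr = call("add", {field_ref("a"), literal(3)});
  DCHECK_EQ(expr.type(), nullptr);  // unbound: no output type yet
  ARROW_ASSIGN_OR_RAISE(expr, expr.Bind(Schema({field("a", int32())})));
  DCHECK(expr.IsBound());
  DCHECK(expr.type()->Equals(*int32()));  // a bare type, no ARRAY/SCALAR shape
  return Status::OK();
}
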
bool IsBound() const; /// Return true if this expression is composed only of Scalar literals, field @@ -107,9 +107,8 @@ class ARROW_EXPORT Expression { /// Access a FieldRef or return nullptr if this expression is not a field_ref const FieldRef* field_ref() const; - /// The type and shape to which this expression will evaluate - ValueDescr descr() const; - const std::shared_ptr& type() const; + /// The type to which this expression will evaluate + const DataType* type() const; // XXX someday // NullGeneralization::type nullable() const; @@ -117,7 +116,7 @@ class ARROW_EXPORT Expression { FieldRef ref; // post-bind properties - ValueDescr descr; + TypeHolder type; ::arrow::internal::SmallVector indices; }; const Parameter* parameter() const; diff --git a/cpp/src/arrow/compute/exec/expression_internal.h b/cpp/src/arrow/compute/exec/expression_internal.h index f8c686d2c81..027c954c6d0 100644 --- a/cpp/src/arrow/compute/exec/expression_internal.h +++ b/cpp/src/arrow/compute/exec/expression_internal.h @@ -23,6 +23,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/cast.h" +#include "arrow/compute/cast_internal.h" #include "arrow/compute/registry.h" #include "arrow/record_batch.h" #include "arrow/table.h" @@ -31,6 +32,8 @@ namespace arrow { namespace compute { +using internal::GetCastFunction; + struct KnownFieldValues { std::unordered_map map; }; @@ -41,21 +44,21 @@ inline const Expression::Call* CallNotNull(const Expression& expr) { return call; } -inline std::vector GetDescriptors(const std::vector& exprs) { - std::vector descrs(exprs.size()); +inline std::vector GetTypes(const std::vector& exprs) { + std::vector types(exprs.size()); for (size_t i = 0; i < exprs.size(); ++i) { DCHECK(exprs[i].IsBound()); - descrs[i] = exprs[i].descr(); + types[i] = exprs[i].type(); } - return descrs; + return types; } -inline std::vector GetDescriptors(const std::vector& values) { - std::vector descrs(values.size()); +inline std::vector GetTypes(const std::vector& values) { + std::vector types(values.size()); for (size_t i = 0; i < values.size(); ++i) { - descrs[i] = values[i].descr(); + types[i] = values[i].type(); } - return descrs; + return types; } struct Comparison { @@ -279,9 +282,9 @@ inline Result> GetFunction( return exec_context->func_registry()->GetFunction(call.function_name); } // XXX this special case is strange; why not make "cast" a ScalarFunction? - const auto& to_type = + const TypeHolder& to_type = ::arrow::internal::checked_cast(*call.options).to_type; - return compute::GetCastFunction(to_type); + return GetCastFunction(*to_type); } /// Modify an Expression with pre-order and post-order visitation. 
diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index 95adb1652eb..b4466d827eb 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -493,8 +493,8 @@ TEST(Expression, BindLiteral) { Datum(ArrayFromJSON(int32(), "[1,2,3]")), }) { // literals are always considered bound - auto expr = literal(dat); - EXPECT_EQ(expr.descr(), dat.descr()); + Expression expr = literal(dat); + EXPECT_TRUE(dat.type()->Equals(*expr.type())); EXPECT_TRUE(expr.IsBound()); } } @@ -518,13 +518,13 @@ void ExpectBindsTo(Expression expr, util::optional expected, } TEST(Expression, BindFieldRef) { - // an unbound field_ref does not have the output ValueDescr set + // an unbound field_ref does not have the output type set auto expr = field_ref("alpha"); - EXPECT_EQ(expr.descr(), ValueDescr{}); + EXPECT_EQ(expr.type(), nullptr); EXPECT_FALSE(expr.IsBound()); ExpectBindsTo(field_ref("i32"), no_change, &expr); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.type()->Equals(*int32())); // if the field is not found, an error will be raised ASSERT_RAISES(Invalid, field_ref("no such field").Bind(*kBoringSchema)); @@ -541,11 +541,11 @@ TEST(Expression, BindNestedFieldRef) { ExpectBindsTo(field_ref(FieldRef("a", "b")), no_change, &expr, schema); EXPECT_TRUE(expr.IsBound()); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.type()->Equals(*int32())); ExpectBindsTo(field_ref(FieldRef(FieldPath({0, 0}))), no_change, &expr, schema); EXPECT_TRUE(expr.IsBound()); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.type()->Equals(*int32())); ASSERT_RAISES(Invalid, field_ref(FieldPath({0, 1})).Bind(schema)); ASSERT_RAISES(Invalid, field_ref(FieldRef("a", "b")) @@ -558,7 +558,7 @@ TEST(Expression, BindCall) { EXPECT_FALSE(expr.IsBound()); ExpectBindsTo(expr, no_change, &expr); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.type()->Equals(*int32())); ExpectBindsTo(call("add", {field_ref("f32"), literal(3)}), call("add", {field_ref("f32"), literal(3.0F)})); @@ -607,7 +607,7 @@ TEST(Expression, BindNestedCall) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(Schema({field("a", int32()), field("b", int32()), field("c", int32()), field("d", int32())}))); - EXPECT_EQ(expr.descr(), ValueDescr::Array(int32())); + EXPECT_TRUE(expr.type()->Equals(*int32())); EXPECT_TRUE(expr.IsBound()); } @@ -615,7 +615,7 @@ TEST(Expression, ExecuteFieldRef) { auto ExpectRefIs = [](FieldRef ref, Datum in, Datum expected) { auto expr = field_ref(ref); - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.type())); ASSERT_OK_AND_ASSIGN(Datum actual, ExecuteScalarExpression(expr, Schema(in.type()->fields()), in)); @@ -716,8 +716,8 @@ Result NaiveExecuteScalarExpression(const Expression& expr, const Datum& compute::ExecContext exec_context; ARROW_ASSIGN_OR_RAISE(auto function, GetFunction(*call, &exec_context)); - auto descrs = GetDescriptors(call->arguments); - ARROW_ASSIGN_OR_RAISE(auto expected_kernel, function->DispatchExact(descrs)); + std::vector types = GetTypes(call->arguments); + ARROW_ASSIGN_OR_RAISE(auto expected_kernel, function->DispatchExact(types)); EXPECT_EQ(call->kernel, expected_kernel); return function->Execute(arguments, call->options.get(), &exec_context); @@ -726,7 +726,7 @@ Result NaiveExecuteScalarExpression(const Expression& expr, const Datum& void ExpectExecute(Expression expr, Datum in, Datum* 
actual_out = NULLPTR) { std::shared_ptr schm; if (in.is_value()) { - ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.descr())); + ASSERT_OK_AND_ASSIGN(expr, expr.Bind(in.type())); schm = schema(in.type()->fields()); } else { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*in.schema())); diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index a145863e597..a376fb5f57b 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -84,13 +84,11 @@ class HashJoinBasicImpl : public HashJoinImpl { private: void InitEncoder(int side, HashJoinProjection projection_handle, RowEncoder* encoder) { - std::vector data_types; + std::vector data_types; int num_cols = schema_mgr_->proj_maps[side].num_cols(projection_handle); data_types.resize(num_cols); for (int icol = 0; icol < num_cols; ++icol) { - data_types[icol] = - ValueDescr(schema_mgr_->proj_maps[side].data_type(projection_handle, icol), - ValueDescr::ARRAY); + data_types[icol] = schema_mgr_->proj_maps[side].data_type(projection_handle, icol); } encoder->Init(data_types, ctx_); encoder->Clear(); diff --git a/cpp/src/arrow/compute/exec/hash_join_dict.cc b/cpp/src/arrow/compute/exec/hash_join_dict.cc index 731a5662d7d..560b0ea8d4d 100644 --- a/cpp/src/arrow/compute/exec/hash_join_dict.cc +++ b/cpp/src/arrow/compute/exec/hash_join_dict.cc @@ -224,8 +224,8 @@ Status HashJoinDictBuild::Init(ExecContext* ctx, std::shared_ptr dictiona // Initialize encoder internal::RowEncoder encoder; - std::vector encoder_types; - encoder_types.emplace_back(value_type_, ValueDescr::ARRAY); + std::vector encoder_types; + encoder_types.emplace_back(value_type_); encoder.Init(encoder_types, ctx); // Encode all dictionary values @@ -285,8 +285,7 @@ Result> HashJoinDictBuild::RemapInputValues( // Initialize encoder // internal::RowEncoder encoder; - std::vector encoder_types; - encoder_types.emplace_back(value_type_, ValueDescr::ARRAY); + std::vector encoder_types = {value_type_}; encoder.Init(encoder_types, ctx); // Encode all @@ -422,8 +421,7 @@ Result> HashJoinDictProbe::RemapInput( remapped_ids_, opt_build_side->RemapInputValues(ctx, Datum(dict->data()), dict->length())); } else { - std::vector encoder_types; - encoder_types.emplace_back(dict_type.value_type(), ValueDescr::ARRAY); + std::vector encoder_types = {dict_type.value_type()}; encoder_.Init(encoder_types, ctx); RETURN_NOT_OK( encoder_.EncodeAndAppend(ExecSpan({*dict->data()}, dict->length()))); @@ -516,14 +514,14 @@ void HashJoinDictBuildMulti::InitEncoder( const SchemaProjectionMaps& proj_map, RowEncoder* encoder, ExecContext* ctx) { int num_cols = proj_map.num_cols(HashJoinProjection::KEY); - std::vector data_types(num_cols); + std::vector data_types(num_cols); for (int icol = 0; icol < num_cols; ++icol) { std::shared_ptr data_type = proj_map.data_type(HashJoinProjection::KEY, icol); if (HashJoinDictBuild::KeyNeedsProcessing(data_type)) { data_type = HashJoinDictBuild::DataTypeAfterRemapping(); } - data_types[icol] = ValueDescr(data_type, ValueDescr::ARRAY); + data_types[icol] = data_type; } encoder->Init(data_types, ctx); } @@ -610,7 +608,7 @@ void HashJoinDictProbeMulti::InitEncoder( const SchemaProjectionMaps& proj_map_build, RowEncoder* encoder, ExecContext* ctx) { int num_cols = proj_map_probe.num_cols(HashJoinProjection::KEY); - std::vector data_types(num_cols); + std::vector data_types(num_cols); for (int icol = 0; icol < num_cols; ++icol) { std::shared_ptr data_type = proj_map_probe.data_type(HashJoinProjection::KEY, icol); @@ -619,7 
+617,7 @@ void HashJoinDictProbeMulti::InitEncoder( if (HashJoinDictProbe::KeyNeedsProcessing(data_type, build_data_type)) { data_type = HashJoinDictProbe::DataTypeAfterRemapping(build_data_type); } - data_types[icol] = ValueDescr(data_type, ValueDescr::ARRAY); + data_types[icol] = data_type; } encoder->Init(data_types, ctx); } diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index 46600a96da3..9a3c7342788 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -44,13 +44,13 @@ BatchesWithSchema GenerateBatchesFromString( const std::vector& json_strings, int multiplicity = 1) { BatchesWithSchema out_batches{{}, schema}; - std::vector descrs; + std::vector types; for (auto&& field : schema->fields()) { - descrs.emplace_back(field->type()); + types.emplace_back(field->type()); } for (auto&& s : json_strings) { - out_batches.batches.push_back(ExecBatchFromJSON(descrs, s)); + out_batches.batches.push_back(ExecBatchFromJSON(types, s)); } size_t batch_count = out_batches.batches.size(); @@ -473,7 +473,7 @@ void TakeUsingVector(ExecContext* ctx, const std::vector> } } -// Generate random arrays given list of data type descriptions and null probabilities. +// Generate random arrays given list of data types and null probabilities. // Make sure that all generated records are unique. // The actual number of generated records may be lower than desired because duplicates // will be removed without replacement. @@ -485,12 +485,12 @@ std::vector> GenRandomUniqueRecords( GenRandomRecords(rng, data_types.data_types, num_desired); ExecContext* ctx = default_exec_context(); - std::vector val_descrs; + std::vector val_types; for (size_t i = 0; i < result.size(); ++i) { - val_descrs.push_back(ValueDescr(result[i]->type(), ValueDescr::ARRAY)); + val_types.push_back(result[i]->type()); } internal::RowEncoder encoder; - encoder.Init(val_descrs, ctx); + encoder.Init(val_types, ctx); ExecBatch batch({}, num_desired); batch.values.resize(result.size()); for (size_t i = 0; i < result.size(); ++i) { diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 9efa6623e5a..f67d541e1ea 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -1133,12 +1133,11 @@ TEST(ExecPlanExecution, SourceScalarAggSink) { }) .AddToPlan(plan.get())); - ASSERT_THAT( - StartAndCollect(plan.get(), sink_gen), - Finishes(ResultWith(UnorderedElementsAreArray({ - ExecBatchFromJSON({ValueDescr::Scalar(int64()), ValueDescr::Scalar(boolean())}, - "[[22, true]]"), - })))); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(UnorderedElementsAreArray({ + ExecBatchFromJSON({int64(), boolean()}, + {ArgShape::SCALAR, ArgShape::SCALAR}, "[[22, true]]"), + })))); } TEST(ExecPlanExecution, AggregationPreservesOptions) { @@ -1168,7 +1167,7 @@ TEST(ExecPlanExecution, AggregationPreservesOptions) { ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), Finishes(ResultWith(UnorderedElementsAreArray({ - ExecBatchFromJSON({ValueDescr::Array(float64())}, "[[5.5]]"), + ExecBatchFromJSON({float64()}, "[[5.5]]"), })))); } { @@ -1209,7 +1208,7 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) { BatchesWithSchema scalar_data; scalar_data.batches = { - ExecBatchFromJSON({ValueDescr::Scalar(int32()), ValueDescr::Scalar(boolean())}, + ExecBatchFromJSON({int32(), boolean()}, {ArgShape::SCALAR, ArgShape::SCALAR}, "[[5, false], [5, 
false], [5, false]]"), ExecBatchFromJSON({int32(), boolean()}, "[[5, true], [6, false], [7, true]]")}; scalar_data.schema = schema({field("a", int32()), field("b", boolean())}); @@ -1239,11 +1238,11 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) { StartAndCollect(plan.get(), sink_gen), Finishes(ResultWith(UnorderedElementsAreArray({ ExecBatchFromJSON( - {ValueDescr::Scalar(boolean()), ValueDescr::Scalar(boolean()), - ValueDescr::Scalar(int64()), ValueDescr::Scalar(float64()), - ValueDescr::Scalar(int64()), ValueDescr::Scalar(float64()), - ValueDescr::Scalar(int64()), ValueDescr::Array(float64()), - ValueDescr::Scalar(float64())}, + {boolean(), boolean(), int64(), float64(), int64(), float64(), int64(), + float64(), float64()}, + {ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, + ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::SCALAR, ArgShape::ARRAY, + ArgShape::SCALAR}, R"([[false, true, 6, 5.5, 26250, 0.7637626158259734, 33, 5.0, 0.5833333333333334]])"), })))); } @@ -1255,9 +1254,9 @@ TEST(ExecPlanExecution, ScalarSourceGroupedSum) { BatchesWithSchema scalar_data; scalar_data.batches = { - ExecBatchFromJSON({int32(), ValueDescr::Scalar(boolean())}, + ExecBatchFromJSON({int32(), boolean()}, {ArgShape::ARRAY, ArgShape::SCALAR}, "[[5, false], [6, false], [7, false]]"), - ExecBatchFromJSON({int32(), ValueDescr::Scalar(boolean())}, + ExecBatchFromJSON({int32(), boolean()}, {ArgShape::ARRAY, ArgShape::SCALAR}, "[[1, true], [2, true], [3, true]]"), }; scalar_data.schema = schema({field("a", int32()), field("b", boolean())}); diff --git a/cpp/src/arrow/compute/exec/project_node.cc b/cpp/src/arrow/compute/exec/project_node.cc index cad8d7c45ae..76925eb6139 100644 --- a/cpp/src/arrow/compute/exec/project_node.cc +++ b/cpp/src/arrow/compute/exec/project_node.cc @@ -67,7 +67,7 @@ class ProjectNode : public MapNode { ARROW_ASSIGN_OR_RAISE( expr, expr.Bind(*inputs[0]->output_schema(), plan->exec_context())); } - fields[i] = field(std::move(names[i]), expr.type()); + fields[i] = field(std::move(names[i]), expr.type()->GetSharedPtr()); ++i; } return plan->EmplaceNode(plan, std::move(inputs), @@ -82,7 +82,7 @@ class ProjectNode : public MapNode { for (size_t i = 0; i < exprs_.size(); ++i) { util::tracing::Span span; START_COMPUTE_SPAN(span, "Project", - {{"project.descr", exprs_[i].descr().ToString()}, + {{"project.type", exprs_[i].type()->ToString()}, {"project.length", target.length}, {"project.expression", exprs_[i].ToString()}}); ARROW_ASSIGN_OR_RAISE(Expression simplified_expr, diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 1e09cb742fa..330ee471126 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -143,16 +143,25 @@ ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector& descrs, +ExecBatch ExecBatchFromJSON(const std::vector& types, util::string_view json) { auto fields = ::arrow::internal::MapVector( - [](const ValueDescr& descr) { return field("", descr.type); }, descrs); + [](const TypeHolder& th) { return field("", th.GetSharedPtr()); }, types); ExecBatch batch{*RecordBatchFromJSON(schema(std::move(fields)), json)}; + return batch; +} + +ExecBatch ExecBatchFromJSON(const std::vector& types, + const std::vector& shapes, util::string_view json) { + DCHECK_EQ(types.size(), shapes.size()); + + ExecBatch batch = ExecBatchFromJSON(types, json); + auto value_it = batch.values.begin(); - for (const auto& descr : descrs) { - if (descr.shape == 
ValueDescr::SCALAR) { + for (ArgShape shape : shapes) { + if (shape == ArgShape::SCALAR) { if (batch.length == 0) { *value_it = MakeNullScalar(value_it->type()); } else { @@ -232,13 +241,13 @@ BatchesWithSchema MakeBatchesFromString( const std::vector& json_strings, int multiplicity) { BatchesWithSchema out_batches{{}, schema}; - std::vector descrs; + std::vector types; for (auto&& field : schema->fields()) { - descrs.emplace_back(field->type()); + types.emplace_back(field->type()); } for (auto&& s : json_strings) { - out_batches.batches.push_back(ExecBatchFromJSON(descrs, s)); + out_batches.batches.push_back(ExecBatchFromJSON(types, s)); } size_t batch_count = out_batches.batches.size(); diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index ba7e4bb3411..ddbded64d42 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -27,6 +27,7 @@ #include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/kernel.h" #include "arrow/testing/visibility.h" #include "arrow/util/async_generator.h" #include "arrow/util/pcg_random.h" @@ -44,8 +45,16 @@ ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector& descrs, - util::string_view json); +ExecBatch ExecBatchFromJSON(const std::vector& types, util::string_view json); + +/// \brief Shape qualifier for value types. In certain instances +/// (e.g. "map_lookup" kernel), an argument may only be a scalar, where in +/// other kernels arguments can be arrays or scalars +enum class ArgShape { ANY, ARRAY, SCALAR }; + +ARROW_TESTING_EXPORT +ExecBatch ExecBatchFromJSON(const std::vector& types, + const std::vector& shapes, util::string_view json); struct BatchesWithSchema { std::vector batches; diff --git a/cpp/src/arrow/compute/exec_internal.h b/cpp/src/arrow/compute/exec_internal.h index c475a61c1ba..afca289c20e 100644 --- a/cpp/src/arrow/compute/exec_internal.h +++ b/cpp/src/arrow/compute/exec_internal.h @@ -84,8 +84,7 @@ class ARROW_EXPORT ExecSpanIterator { /// \param[in] batch the input ExecBatch /// \param[in] max_chunksize the maximum length of each ExecSpan. Depending /// on the chunk layout of ChunkedArray. - Status Init(const ExecBatch& batch, ValueDescr::Shape output_shape = ValueDescr::ARRAY, - int64_t max_chunksize = kDefaultMaxChunksize); + Status Init(const ExecBatch& batch, int64_t max_chunksize = kDefaultMaxChunksize); /// \brief Compute the next span by updating the state of the /// previous span object. 
You must keep passing in the previous @@ -101,6 +100,8 @@ class ARROW_EXPORT ExecSpanIterator { int64_t length() const { return length_; } int64_t position() const { return position_; } + bool have_all_scalars() const { return have_all_scalars_; } + private: ExecSpanIterator(const std::vector& args, int64_t length, int64_t max_chunksize); @@ -108,6 +109,7 @@ class ARROW_EXPORT ExecSpanIterator { bool initialized_ = false; bool have_chunked_arrays_ = false; + bool have_all_scalars_ = false; const std::vector* args_; std::vector chunk_indexes_; std::vector value_positions_; @@ -117,8 +119,8 @@ class ARROW_EXPORT ExecSpanIterator { // from the relative position within each chunk (which is in // value_positions_) std::vector value_offsets_; - int64_t position_; - int64_t length_; + int64_t position_ = 0; + int64_t length_ = 0; int64_t max_chunksize_; }; @@ -147,11 +149,6 @@ class DatumAccumulator : public ExecListener { std::vector values_; }; -/// \brief Check that each Datum is of a "value" type, which means either -/// SCALAR, ARRAY, or CHUNKED_ARRAY. If there are chunked inputs, then these -/// inputs will be split into non-chunked ExecBatch values for execution -Status CheckAllValues(const std::vector& values); - class ARROW_EXPORT KernelExecutor { public: virtual ~KernelExecutor() = default; diff --git a/cpp/src/arrow/compute/exec_test.cc b/cpp/src/arrow/compute/exec_test.cc index bd344fb2297..573f4aee4a0 100644 --- a/cpp/src/arrow/compute/exec_test.cc +++ b/cpp/src/arrow/compute/exec_test.cc @@ -728,10 +728,10 @@ TEST_F(TestExecBatchIterator, Basics) { ASSERT_EQ(3, batch.num_values()); ASSERT_EQ(length, batch.length); - std::vector descrs = batch.GetDescriptors(); - ASSERT_EQ(ValueDescr::Array(int32()), descrs[0]); - ASSERT_EQ(ValueDescr::Array(float64()), descrs[1]); - ASSERT_EQ(ValueDescr::Scalar(int32()), descrs[2]); + std::vector types = batch.GetTypes(); + ASSERT_EQ(types[0], int32()); + ASSERT_EQ(types[1], float64()); + ASSERT_EQ(types[2], int32()); AssertArraysEqual(*args[0].make_array(), *batch[0].make_array()); AssertArraysEqual(*args[1].make_array(), *batch[1].make_array()); @@ -795,13 +795,12 @@ TEST_F(TestExecBatchIterator, ZeroLengthInputs) { class TestExecSpanIterator : public TestComputeInternals { public: void SetupIterator(const ExecBatch& batch, - ValueDescr::Shape output_shape = ValueDescr::ARRAY, int64_t max_chunksize = kDefaultMaxChunksize) { - ASSERT_OK(iterator_.Init(batch, output_shape, max_chunksize)); + ASSERT_OK(iterator_.Init(batch, max_chunksize)); } void CheckIteration(const ExecBatch& input, int chunksize, const std::vector& ex_batch_sizes) { - SetupIterator(input, ValueDescr::ARRAY, chunksize); + SetupIterator(input, chunksize); ExecSpan batch; int64_t position = 0; for (size_t i = 0; i < ex_batch_sizes.size(); ++i) { @@ -902,8 +901,10 @@ TEST_F(TestExecSpanIterator, ZeroLengthInputs) { auto CheckArgs = [&](const ExecBatch& batch) { ExecSpanIterator iterator; - ASSERT_OK(iterator.Init(batch, ValueDescr::ARRAY)); + ASSERT_OK(iterator.Init(batch)); ExecSpan iter_span; + ASSERT_TRUE(iterator.Next(&iter_span)); + ASSERT_EQ(0, iter_span.length); ASSERT_FALSE(iterator.Next(&iter_span)); }; @@ -1045,11 +1046,13 @@ Status ExecStateful(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) return Status::OK(); } -// TODO: remove this / refactor it in ARROW-16577 Status ExecAddInt32(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - const Int32Scalar& arg0 = batch[0].scalar_as(); - const Int32Scalar& arg1 = batch[1].scalar_as(); - out->value = 
std::make_shared(arg0.value + arg1.value); + const int32_t* left_data = batch[0].array.GetValues(1); + const int32_t* right_data = batch[1].array.GetValues(1); + int32_t* out_data = out->array_span()->GetValues(1); + for (int64_t i = 0; i < batch.length; ++i) { + *out_data++ = *left_data++ + *right_data++; + } return Status::OK(); } @@ -1078,16 +1081,15 @@ class TestCallScalarFunction : public TestComputeInternals { /*doc=*/FunctionDoc::Empty()); // Add a few kernels. Our implementation only accepts arrays - ASSERT_OK(func->AddKernel({InputType::Array(uint8())}, uint8(), ExecCopyArraySpan)); - ASSERT_OK(func->AddKernel({InputType::Array(int32())}, int32(), ExecCopyArraySpan)); - ASSERT_OK( - func->AddKernel({InputType::Array(float64())}, float64(), ExecCopyArraySpan)); + ASSERT_OK(func->AddKernel({uint8()}, uint8(), ExecCopyArraySpan)); + ASSERT_OK(func->AddKernel({int32()}, int32(), ExecCopyArraySpan)); + ASSERT_OK(func->AddKernel({float64()}, float64(), ExecCopyArraySpan)); ASSERT_OK(registry->AddFunction(func)); // A version which doesn't want the executor to call PropagateNulls auto func2 = std::make_shared( "test_copy_computed_bitmap", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); - ScalarKernel kernel({InputType::Array(uint8())}, uint8(), ExecComputedBitmap); + ScalarKernel kernel({uint8()}, uint8(), ExecComputedBitmap); kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; ASSERT_OK(func2->AddKernel(kernel)); ASSERT_OK(registry->AddFunction(func2)); @@ -1103,7 +1105,7 @@ class TestCallScalarFunction : public TestComputeInternals { auto f2 = std::make_shared( "test_nopre_validity_or_data", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); - ScalarKernel kernel({InputType::Array(uint8())}, uint8(), ExecNoPreallocatedData); + ScalarKernel kernel({uint8()}, uint8(), ExecNoPreallocatedData); kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; ASSERT_OK(f1->AddKernel(kernel)); @@ -1123,7 +1125,7 @@ class TestCallScalarFunction : public TestComputeInternals { auto func = std::make_shared("test_stateful", Arity::Unary(), /*doc=*/FunctionDoc::Empty()); - ScalarKernel kernel({InputType::Array(int32())}, int32(), ExecStateful, InitStateful); + ScalarKernel kernel({int32()}, int32(), ExecStateful, InitStateful); ASSERT_OK(func->AddKernel(kernel)); ASSERT_OK(registry->AddFunction(func)); } @@ -1133,8 +1135,7 @@ class TestCallScalarFunction : public TestComputeInternals { auto func = std::make_shared("test_scalar_add_int32", Arity::Binary(), /*doc=*/FunctionDoc::Empty()); - ASSERT_OK(func->AddKernel({InputType::Scalar(int32()), InputType::Scalar(int32())}, - int32(), ExecAddInt32)); + ASSERT_OK(func->AddKernel({int32(), int32()}, int32(), ExecAddInt32)); ASSERT_OK(registry->AddFunction(func)); } }; @@ -1154,8 +1155,9 @@ TEST_F(TestCallScalarFunction, ArgumentValidation) { ASSERT_RAISES(Invalid, CallFunction("test_copy", args)); // Cannot do scalar - args = {Datum(std::make_shared(5))}; - ASSERT_RAISES(NotImplemented, CallFunction("test_copy", args)); + Datum d1_scalar(std::make_shared(5)); + ASSERT_OK_AND_ASSIGN(auto result, CallFunction("test_copy", {d1})); + ASSERT_OK_AND_ASSIGN(result, CallFunction("test_copy", {d1_scalar})); } TEST_F(TestCallScalarFunction, PreallocationCases) { diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index b5ebc67d180..12d80a8c9ae 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -79,51 +79,35 @@ static const FunctionDoc kEmptyFunctionDoc{}; const FunctionDoc& FunctionDoc::Empty() { 
return kEmptyFunctionDoc; } -static Status CheckArityImpl(const Function& function, int passed_num_args, - const char* passed_num_args_label) { - if (function.arity().is_varargs && passed_num_args < function.arity().num_args) { - return Status::Invalid("VarArgs function '", function.name(), "' needs at least ", - function.arity().num_args, " arguments but ", - passed_num_args_label, " only ", passed_num_args); +static Status CheckArityImpl(const Function& func, int num_args) { + if (func.arity().is_varargs && num_args < func.arity().num_args) { + return Status::Invalid("VarArgs function '", func.name(), "' needs at least ", + func.arity().num_args, " arguments but only ", num_args, + " passed"); } - if (!function.arity().is_varargs && passed_num_args != function.arity().num_args) { - return Status::Invalid("Function '", function.name(), "' accepts ", - function.arity().num_args, " arguments but ", - passed_num_args_label, " ", passed_num_args); + if (!func.arity().is_varargs && num_args != func.arity().num_args) { + return Status::Invalid("Function '", func.name(), "' accepts ", func.arity().num_args, + " arguments but ", num_args, " passed"); } - return Status::OK(); } -Status Function::CheckArity(const std::vector& in_types) const { - return CheckArityImpl(*this, static_cast(in_types.size()), "kernel accepts"); -} - -Status Function::CheckArity(const std::vector& descrs) const { - return CheckArityImpl(*this, static_cast(descrs.size()), - "attempted to look up kernel(s) with"); -} - -static Status CheckOptions(const Function& function, const FunctionOptions* options) { - if (options == nullptr && function.doc().options_required) { - return Status::Invalid("Function '", function.name(), - "' cannot be called without options"); - } - return Status::OK(); +Status Function::CheckArity(size_t num_args) const { + return CheckArityImpl(*this, static_cast(num_args)); } namespace detail { -Status NoMatchingKernel(const Function* func, const std::vector& descrs) { +Status NoMatchingKernel(const Function* func, const std::vector& types) { return Status::NotImplemented("Function '", func->name(), "' has no kernel matching input types ", - ValueDescr::ToString(descrs)); + TypeHolder::ToString(types)); } template const KernelType* DispatchExactImpl(const std::vector& kernels, - const std::vector& values) { + const std::vector& values) { const KernelType* kernel_matches[SimdLevel::MAX] = {nullptr}; // Validate arity @@ -159,7 +143,7 @@ const KernelType* DispatchExactImpl(const std::vector& kernels, } const Kernel* DispatchExactImpl(const Function* func, - const std::vector& values) { + const std::vector& values) { if (func->kind() == Function::SCALAR) { return DispatchExactImpl(checked_cast(func)->kernels(), values); @@ -186,11 +170,11 @@ const Kernel* DispatchExactImpl(const Function* func, } // namespace detail Result Function::DispatchExact( - const std::vector& values) const { + const std::vector& values) const { if (kind_ == Function::META) { return Status::NotImplemented("Dispatch for a MetaFunction's Kernels"); } - RETURN_NOT_OK(CheckArity(values)); + RETURN_NOT_OK(CheckArity(values.size())); if (auto kernel = detail::DispatchExactImpl(this, values)) { return kernel; @@ -198,75 +182,92 @@ Result Function::DispatchExact( return detail::NoMatchingKernel(this, values); } -Result Function::DispatchBest(std::vector* values) const { +Result Function::DispatchBest(std::vector* values) const { // TODO(ARROW-11508) permit generic conversions here return DispatchExact(*values); } -Result 
Function::Execute(const std::vector& args, - const FunctionOptions* options, ExecContext* ctx) const { - return ExecuteInternal(args, /*passed_length=*/-1, options, ctx); +namespace { + +Status CheckAllArrayOrScalar(const std::vector& values) { + for (const auto& value : values) { + if (!value.is_value()) { + return Status::Invalid("Tried executing function with non-value type: ", + value.ToString()); + } + } + return Status::OK(); } -Result Function::Execute(const ExecBatch& batch, const FunctionOptions* options, - ExecContext* ctx) const { - return ExecuteInternal(batch.values, batch.length, options, ctx); +Status CheckOptions(const Function& function, const FunctionOptions* options) { + if (options == nullptr && function.doc().options_required) { + return Status::Invalid("Function '", function.name(), + "' cannot be called without options"); + } + return Status::OK(); } -Result Function::ExecuteInternal(const std::vector& args, - int64_t passed_length, - const FunctionOptions* options, - ExecContext* ctx) const { +Result ExecuteInternal(const Function& func, std::vector args, + int64_t passed_length, const FunctionOptions* options, + ExecContext* ctx) { + std::unique_ptr default_ctx; if (options == nullptr) { - RETURN_NOT_OK(CheckOptions(*this, options)); - options = default_options(); + RETURN_NOT_OK(CheckOptions(func, options)); + options = func.default_options(); } if (ctx == nullptr) { - ExecContext default_ctx; - return ExecuteInternal(args, passed_length, options, &default_ctx); + default_ctx.reset(new ExecContext()); + ctx = default_ctx.get(); } util::tracing::Span span; - START_COMPUTE_SPAN(span, name(), - {{"function.name", name()}, + START_COMPUTE_SPAN(span, func.name(), + {{"function.name", func.name()}, {"function.options", options ? options->ToString() : ""}, - {"function.kind", kind()}}); + {"function.kind", func.kind()}}); // type-check Datum arguments here. 
Really we'd like to avoid this as much as // possible - RETURN_NOT_OK(detail::CheckAllValues(args)); - std::vector inputs(args.size()); + RETURN_NOT_OK(CheckAllArrayOrScalar(args)); + std::vector in_types(args.size()); for (size_t i = 0; i != args.size(); ++i) { - inputs[i] = args[i].descr(); + in_types[i] = args[i].type().get(); } std::unique_ptr executor; - if (kind() == Function::SCALAR) { + if (func.kind() == Function::SCALAR) { executor = detail::KernelExecutor::MakeScalar(); - } else if (kind() == Function::VECTOR) { + } else if (func.kind() == Function::VECTOR) { executor = detail::KernelExecutor::MakeVector(); - } else if (kind() == Function::SCALAR_AGGREGATE) { + } else if (func.kind() == Function::SCALAR_AGGREGATE) { executor = detail::KernelExecutor::MakeScalarAggregate(); } else { return Status::NotImplemented("Direct execution of HASH_AGGREGATE functions"); } - ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, DispatchBest(&inputs)); - ARROW_ASSIGN_OR_RAISE(std::vector args_with_casts, Cast(args, inputs, ctx)); + ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, func.DispatchBest(&in_types)); + + // Cast arguments if necessary + for (size_t i = 0; i != args.size(); ++i) { + if (in_types[i] != args[i].type()) { + ARROW_ASSIGN_OR_RAISE(args[i], Cast(args[i], CastOptions::Safe(in_types[i]), ctx)); + } + } - std::unique_ptr state; KernelContext kernel_ctx{ctx, kernel}; + + std::unique_ptr state; if (kernel->init) { - ARROW_ASSIGN_OR_RAISE(state, kernel->init(&kernel_ctx, {kernel, inputs, options})); + ARROW_ASSIGN_OR_RAISE(state, kernel->init(&kernel_ctx, {kernel, in_types, options})); kernel_ctx.SetState(state.get()); } - RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, inputs, options})); + RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, in_types, options})); detail::DatumAccumulator listener; - ExecBatch input(std::move(args_with_casts), /*length=*/0); + ExecBatch input(std::move(args), /*length=*/0); if (input.num_values() == 0) { if (passed_length != -1) { input.length = passed_length; @@ -275,9 +276,13 @@ Result Function::ExecuteInternal(const std::vector& args, bool all_same_length = false; int64_t inferred_length = detail::InferBatchLength(input.values, &all_same_length); input.length = inferred_length; - if (kind() == Function::SCALAR) { - DCHECK(passed_length == -1 || passed_length == inferred_length); - } else if (kind() == Function::VECTOR) { + if (func.kind() == Function::SCALAR) { + if (passed_length != -1 && passed_length != inferred_length) { + return Status::Invalid( + "Passed batch length for execution did not match actual" + " length of values for scalar function execution"); + } + } else if (func.kind() == Function::VECTOR) { auto vkernel = static_cast(kernel); if (!(all_same_length || !vkernel->can_execute_chunkwise)) { return Status::Invalid("Vector kernel arguments must all be the same length"); @@ -287,12 +292,25 @@ Result Function::ExecuteInternal(const std::vector& args, RETURN_NOT_OK(executor->Execute(input, &listener)); const auto out = executor->WrapResults(input.values, listener.values()); #ifndef NDEBUG - DCHECK_OK(executor->CheckResultType(out, name_.c_str())); + DCHECK_OK(executor->CheckResultType(out, func.name().c_str())); #endif return out; } +} // namespace + +Result Function::Execute(const std::vector& args, + const FunctionOptions* options, ExecContext* ctx) const { + return ExecuteInternal(*this, args, /*passed_length=*/-1, options, ctx); +} + +Result Function::Execute(const ExecBatch& batch, const FunctionOptions* options, + ExecContext* ctx) 
const { + return ExecuteInternal(*this, batch.values, batch.length, options, ctx); +} + namespace { + Status ValidateFunctionSummary(const std::string& s) { if (s.find('\n') != s.npos) { return Status::Invalid("summary contains a newline"); @@ -347,7 +365,7 @@ Status Function::Validate() const { Status ScalarFunction::AddKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init) { - RETURN_NOT_OK(CheckArity(in_types)); + RETURN_NOT_OK(CheckArity(in_types.size())); if (arity_.is_varargs && in_types.size() != 1) { return Status::Invalid("VarArgs signatures must have exactly one input type"); @@ -359,7 +377,7 @@ Status ScalarFunction::AddKernel(std::vector in_types, OutputType out } Status ScalarFunction::AddKernel(ScalarKernel kernel) { - RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types().size())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -369,7 +387,7 @@ Status ScalarFunction::AddKernel(ScalarKernel kernel) { Status VectorFunction::AddKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init) { - RETURN_NOT_OK(CheckArity(in_types)); + RETURN_NOT_OK(CheckArity(in_types.size())); if (arity_.is_varargs && in_types.size() != 1) { return Status::Invalid("VarArgs signatures must have exactly one input type"); @@ -381,7 +399,7 @@ Status VectorFunction::AddKernel(std::vector in_types, OutputType out } Status VectorFunction::AddKernel(VectorKernel kernel) { - RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types().size())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -390,7 +408,7 @@ Status VectorFunction::AddKernel(VectorKernel kernel) { } Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { - RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types().size())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -399,7 +417,7 @@ Status ScalarAggregateFunction::AddKernel(ScalarAggregateKernel kernel) { } Status HashAggregateFunction::AddKernel(HashAggregateKernel kernel) { - RETURN_NOT_OK(CheckArity(kernel.signature->in_types())); + RETURN_NOT_OK(CheckArity(kernel.signature->in_types().size())); if (arity_.is_varargs && !kernel.signature->is_varargs()) { return Status::Invalid("Function accepts varargs but kernel signature does not"); } @@ -410,8 +428,7 @@ Status HashAggregateFunction::AddKernel(HashAggregateKernel kernel) { Result MetaFunction::Execute(const std::vector& args, const FunctionOptions* options, ExecContext* ctx) const { - RETURN_NOT_OK( - CheckArityImpl(*this, static_cast(args.size()), "attempted to Execute with")); + RETURN_NOT_OK(CheckArityImpl(*this, static_cast(args.size()))); RETURN_NOT_OK(CheckOptions(*this, options)); if (options == nullptr) { diff --git a/cpp/src/arrow/compute/function.h b/cpp/src/arrow/compute/function.h index c32c8766a91..7f2fba68caf 100644 --- a/cpp/src/arrow/compute/function.h +++ b/cpp/src/arrow/compute/function.h @@ -211,19 +211,19 @@ class ARROW_EXPORT Function { virtual int num_kernels() const = 0; /// \brief Return a kernel that can execute the function given the exact - /// argument types (without 
implicit type casts or scalar->array promotions). + /// argument types (without implicit type casts). /// /// NB: This function is overridden in CastFunction. - virtual Result<const Kernel*> DispatchExact( - const std::vector<ValueDescr>& values) const; + virtual Result<const Kernel*> DispatchExact(const std::vector<TypeHolder>& types) const; /// \brief Return a best-match kernel that can execute the function given the argument /// types, after implicit casts are applied. /// - /// \param[in,out] values Argument types. An element may be modified to indicate that - /// the returned kernel only approximately matches the input value descriptors; callers - /// are responsible for casting inputs to the type and shape required by the kernel. - virtual Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const; + /// \param[in,out] values Argument types. An element may be modified to + /// indicate that the returned kernel only approximately matches the input + /// value descriptors; callers are responsible for casting inputs to the type + /// required by the kernel. + virtual Result<const Kernel*> DispatchBest(std::vector<TypeHolder>* values) const; /// \brief Execute the function eagerly with the passed input arguments with /// kernel dispatch, batch iteration, and memory allocation details taken @@ -255,11 +255,7 @@ class ARROW_EXPORT Function { doc_(std::move(doc)), default_options_(default_options) {} - Result<Datum> ExecuteInternal(const std::vector<Datum>& args, int64_t passed_length, - const FunctionOptions* options, ExecContext* ctx) const; - - Status CheckArity(const std::vector<InputType>&) const; - Status CheckArity(const std::vector<ValueDescr>&) const; + Status CheckArity(size_t num_args) const; std::string name_; Function::Kind kind_; @@ -294,11 +290,11 @@ class FunctionImpl : public Function { /// \brief Look up a kernel in a function. If no Kernel is found, nullptr is returned. ARROW_EXPORT -const Kernel* DispatchExactImpl(const Function* func, const std::vector<ValueDescr>&); +const Kernel* DispatchExactImpl(const Function* func, const std::vector<TypeHolder>&); /// \brief Return an error message if no Kernel is found.
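With DispatchExact and DispatchExactImpl now keyed on TypeHolder rather than ValueDescr, kernel lookup needs only the argument types. A minimal sketch of how a caller might exercise the new entry point; the registered "add" function and the int64 argument types are illustrative choices, not taken from this patch:

    #include <iostream>
    #include <memory>
    #include <vector>

    #include "arrow/compute/function.h"
    #include "arrow/compute/registry.h"
    #include "arrow/result.h"
    #include "arrow/type.h"

    namespace cp = arrow::compute;

    arrow::Status DispatchExample() {
      // Look up a registered scalar function, then ask for the kernel matching
      // the exact argument types; array vs. scalar shape no longer participates.
      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<cp::Function> func,
                            cp::GetFunctionRegistry()->GetFunction("add"));
      std::vector<arrow::TypeHolder> types = {arrow::int64(), arrow::int64()};
      ARROW_ASSIGN_OR_RAISE(const cp::Kernel* kernel, func->DispatchExact(types));
      std::cout << kernel->signature->ToString() << std::endl;
      return arrow::Status::OK();
    }

    int main() {
      arrow::Status st = DispatchExample();
      if (!st.ok()) {
        std::cerr << st.ToString() << std::endl;
        return 1;
      }
      return 0;
    }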
ARROW_EXPORT -Status NoMatchingKernel(const Function* func, const std::vector&); +Status NoMatchingKernel(const Function* func, const std::vector&); } // namespace detail diff --git a/cpp/src/arrow/compute/function_benchmark.cc b/cpp/src/arrow/compute/function_benchmark.cc index b508ad047fb..791052358e7 100644 --- a/cpp/src/arrow/compute/function_benchmark.cc +++ b/cpp/src/arrow/compute/function_benchmark.cc @@ -19,6 +19,7 @@ #include "arrow/array/array_base.h" #include "arrow/compute/api.h" +#include "arrow/compute/cast_internal.h" #include "arrow/compute/exec_internal.h" #include "arrow/memory_pool.h" #include "arrow/scalar.h" @@ -67,14 +68,13 @@ void BM_CastDispatchBaseline(benchmark::State& state) { // Repeatedly invoke a trivial Cast with all dispatch outside the hot loop random::RandomArrayGenerator rag(kSeed); - auto int_scalars = ToScalars(rag.Int64(kScalarCount, 0, 1 << 20)); - + auto int_array = rag.Int64(1, 0, 1 << 20); auto double_type = float64(); CastOptions cast_options; cast_options.to_type = double_type; - ASSERT_OK_AND_ASSIGN(auto cast_function, GetCastFunction(double_type)); + ASSERT_OK_AND_ASSIGN(auto cast_function, internal::GetCastFunction(*double_type)); ASSERT_OK_AND_ASSIGN(auto cast_kernel, - cast_function->DispatchExact({int_scalars[0]->type})); + cast_function->DispatchExact({int_array->type()})); const auto& exec = static_cast(cast_kernel)->exec; ExecContext exec_context; @@ -85,15 +85,13 @@ void BM_CastDispatchBaseline(benchmark::State& state) { .ValueOrDie(); kernel_context.SetState(cast_state.get()); - ExecSpan input; - input.length = 1; + ExecSpan input({ExecValue(*int_array->data())}, 1); + ExecResult result; + ASSERT_OK_AND_ASSIGN(std::shared_ptr result_space, + MakeArrayOfNull(double_type, 1)); + result.array_span()->SetMembers(*result_space->data()); for (auto _ : state) { - ExecResult result; - result.value = MakeNullScalar(double_type); - for (const std::shared_ptr& int_scalar : int_scalars) { - input.values = {ExecValue(int_scalar.get())}; - ABORT_NOT_OK(exec(&kernel_context, input, &result)); - } + ABORT_NOT_OK(exec(&kernel_context, input, &result)); } state.SetItemsProcessed(state.iterations() * kScalarCount); @@ -153,31 +151,26 @@ void BM_ExecuteScalarFunctionOnScalar(benchmark::State& state) { void BM_ExecuteScalarKernelOnScalar(benchmark::State& state) { // Execute a trivial function, with argument dispatch outside the hot path - const int64_t N = 10000; - auto function = *GetFunctionRegistry()->GetFunction("is_valid"); - auto kernel = *function->DispatchExact({ValueDescr::Scalar(int64())}); + auto kernel = *function->DispatchExact({int64()}); const auto& exec = static_cast(*kernel).exec; - const auto scalars = MakeScalarsForIsValid(N); - ExecContext exec_context; KernelContext kernel_context(&exec_context); - ExecSpan input; - input.length = 1; + ASSERT_OK_AND_ASSIGN(std::shared_ptr input_arr, MakeArrayOfNull(int64(), 1)); + ExecSpan input({*input_arr->data()}, 1); + + ExecResult output; + ASSERT_OK_AND_ASSIGN(std::shared_ptr output_arr, MakeArrayOfNull(int64(), 1)); + output.array_span()->SetMembers(*output_arr->data()); + + const int64_t N = 10000; for (auto _ : state) { - int64_t total = 0; - for (const std::shared_ptr& scalar : scalars) { - ExecResult result; - result.value = MakeNullScalar(int64()); - input.values = {scalar.get()}; - ABORT_NOT_OK(exec(&kernel_context, input, &result)); - total += result.scalar()->is_valid; + for (int i = 0; i < N; ++i) { + ABORT_NOT_OK(exec(&kernel_context, input, &output)); } - 
benchmark::DoNotOptimize(total); } - state.SetItemsProcessed(state.iterations() * N); } diff --git a/cpp/src/arrow/compute/function_internal.h b/cpp/src/arrow/compute/function_internal.h index f2303b87d90..17261332619 100644 --- a/cpp/src/arrow/compute/function_internal.h +++ b/cpp/src/arrow/compute/function_internal.h @@ -345,6 +345,10 @@ static inline Result> GenericToScalar( return MakeNullScalar(value); } +static inline Result> GenericToScalar(const TypeHolder& value) { + return GenericToScalar(value.GetSharedPtr()); +} + static inline Result> GenericToScalar( const std::shared_ptr& value) { return value; @@ -430,6 +434,12 @@ static inline enable_if_same_result> GenericFromSca return value->type; } +template +static inline enable_if_same_result GenericFromScalar( + const std::shared_ptr& value) { + return value->type; +} + template static inline enable_if_same_result> GenericFromScalar( const std::shared_ptr& value) { diff --git a/cpp/src/arrow/compute/function_test.cc b/cpp/src/arrow/compute/function_test.cc index f06f225f5b9..94daa6baa96 100644 --- a/cpp/src/arrow/compute/function_test.cc +++ b/cpp/src/arrow/compute/function_test.cc @@ -230,9 +230,9 @@ void CheckAddDispatch(FunctionType* func, ExecType exec) { // Duplicate sig is okay ASSERT_OK(func->AddKernel(in_types1, out_type1, exec)); - // Add given a descr - KernelType descr({float64(), float64()}, float64(), exec); - ASSERT_OK(func->AddKernel(descr)); + // Add a kernel + KernelType kernel({float64(), float64()}, float64(), exec); + ASSERT_OK(func->AddKernel(kernel)); ASSERT_EQ(4, func->num_kernels()); ASSERT_EQ(4, func->kernels().size()); @@ -249,9 +249,9 @@ void CheckAddDispatch(FunctionType* func, ExecType exec) { KernelType invalid_kernel({boolean()}, boolean(), exec); ASSERT_RAISES(Invalid, func->AddKernel(invalid_kernel)); - ASSERT_OK_AND_ASSIGN(const Kernel* kernel, func->DispatchExact({int32(), int32()})); + ASSERT_OK_AND_ASSIGN(const Kernel* dispatched, func->DispatchExact({int32(), int32()})); KernelSignature expected_sig(in_types1, out_type1); - ASSERT_TRUE(kernel->signature->Equals(expected_sig)); + ASSERT_TRUE(dispatched->signature->Equals(expected_sig)); // No kernel available ASSERT_RAISES(NotImplemented, func->DispatchExact({utf8(), utf8()})); @@ -288,7 +288,7 @@ TEST(ArrayFunction, VarArgs) { ScalarKernel non_va_kernel(std::make_shared(va_args, int8()), ExecNYI); ASSERT_RAISES(Invalid, va_func.AddKernel(non_va_kernel)); - std::vector args = {ValueDescr::Scalar(int8()), int8(), int8()}; + std::vector args = {int8(), int8(), int8()}; ASSERT_OK_AND_ASSIGN(const Kernel* kernel, va_func.DispatchExact(args)); ASSERT_TRUE(kernel->signature->MatchesInputs(args)); @@ -319,7 +319,7 @@ Status NoopFinalize(KernelContext*, Datum*) { return Status::OK(); } TEST(ScalarAggregateFunction, DispatchExact) { ScalarAggregateFunction func("agg_test", Arity::Unary(), FunctionDoc::Empty()); - std::vector in_args = {ValueDescr::Array(int8())}; + std::vector in_args = {int8()}; ScalarAggregateKernel kernel(std::move(in_args), int64(), NoopInit, NoopConsume, NoopMerge, NoopFinalize); ASSERT_OK(func.AddKernel(kernel)); @@ -341,18 +341,14 @@ TEST(ScalarAggregateFunction, DispatchExact) { kernel.signature = std::make_shared(in_args, float64()); ASSERT_RAISES(Invalid, func.AddKernel(kernel)); - std::vector dispatch_args = {ValueDescr::Array(int8())}; + std::vector dispatch_args = {int8()}; ASSERT_OK_AND_ASSIGN(const Kernel* selected_kernel, func.DispatchExact(dispatch_args)); ASSERT_EQ(func.kernels()[0], selected_kernel); 
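The shape-based restriction removed in the next few lines ("only arrays are accepted") is the caller-visible side of this change: a scalar Datum is now executed through the same kernels as an array. A rough end-to-end illustration using the public "add" function, which is outside the files touched here:

    #include <iostream>
    #include <memory>

    #include "arrow/api.h"
    #include "arrow/compute/api.h"

    namespace cp = arrow::compute;

    arrow::Status MixedShapeExample() {
      // Build a small int64 array argument.
      arrow::Int64Builder builder;
      ARROW_RETURN_NOT_OK(builder.AppendValues({1, 2, 3}));
      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Array> values, builder.Finish());

      // The scalar argument goes through the same array-oriented kernel; there
      // is no separate scalar-shaped kernel to dispatch to anymore.
      arrow::Datum increment(std::make_shared<arrow::Int64Scalar>(10));
      ARROW_ASSIGN_OR_RAISE(
          arrow::Datum result,
          cp::CallFunction("add", {arrow::Datum(values), increment}));
      std::cout << result.make_array()->ToString() << std::endl;  // 11, 12, 13
      return arrow::Status::OK();
    }

    int main() { return MixedShapeExample().ok() ? 0 : 1; }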
ASSERT_TRUE(selected_kernel->signature->MatchesInputs(dispatch_args)); - // We declared that only arrays are accepted - dispatch_args[0] = {ValueDescr::Scalar(int8())}; - ASSERT_RAISES(NotImplemented, func.DispatchExact(dispatch_args)); - // Didn't qualify the float64() kernel so this actually dispatches (even // though that may not be what you want) - dispatch_args[0] = {ValueDescr::Scalar(float64())}; + dispatch_args[0] = {float64()}; ASSERT_OK_AND_ASSIGN(selected_kernel, func.DispatchExact(dispatch_args)); ASSERT_TRUE(selected_kernel->signature->MatchesInputs(dispatch_args)); } diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc index 909c2399c8e..9a0e9c986a2 100644 --- a/cpp/src/arrow/compute/kernel.cc +++ b/cpp/src/arrow/compute/kernel.cc @@ -282,7 +282,6 @@ std::shared_ptr FixedSizeBinaryLike() { size_t InputType::Hash() const { size_t result = kHashSeed; - hash_combine(result, static_cast(shape_)); hash_combine(result, static_cast(kind_)); switch (kind_) { case InputType::EXACT_TYPE: @@ -296,21 +295,6 @@ size_t InputType::Hash() const { std::string InputType::ToString() const { std::stringstream ss; - switch (shape_) { - case ValueDescr::ANY: - ss << "any"; - break; - case ValueDescr::ARRAY: - ss << "array"; - break; - case ValueDescr::SCALAR: - ss << "scalar"; - break; - default: - DCHECK(false); - break; - } - ss << "["; switch (kind_) { case InputType::ANY_TYPE: ss << "any"; @@ -325,7 +309,6 @@ std::string InputType::ToString() const { DCHECK(false); break; } - ss << "]"; return ss.str(); } @@ -333,7 +316,7 @@ bool InputType::Equals(const InputType& other) const { if (this == &other) { return true; } - if (kind_ != other.kind_ || shape_ != other.shape_) { + if (kind_ != other.kind_) { return false; } switch (kind_) { @@ -348,22 +331,30 @@ bool InputType::Equals(const InputType& other) const { } } -bool InputType::Matches(const ValueDescr& descr) const { - if (shape_ != ValueDescr::ANY && descr.shape != shape_) { - return false; - } +bool InputType::Matches(const DataType& type) const { switch (kind_) { case InputType::EXACT_TYPE: - return type_->Equals(*descr.type); + return type_->Equals(type); case InputType::USE_TYPE_MATCHER: - return type_matcher_->Matches(*descr.type); + return type_matcher_->Matches(type); default: // ANY_TYPE return true; } } -bool InputType::Matches(const Datum& value) const { return Matches(value.descr()); } +bool InputType::Matches(const Datum& value) const { + switch (value.kind()) { + case Datum::ARRAY: + case Datum::CHUNKED_ARRAY: + case Datum::SCALAR: + break; + default: + DCHECK(false); + return false; + } + return Matches(*value.type()); +} const std::shared_ptr& InputType::type() const { DCHECK_EQ(InputType::EXACT_TYPE, kind_); @@ -378,21 +369,12 @@ const TypeMatcher& InputType::type_matcher() const { // ---------------------------------------------------------------------- // OutputType -OutputType::OutputType(ValueDescr descr) : OutputType(descr.type) { - shape_ = descr.shape; -} - -Result OutputType::Resolve(KernelContext* ctx, - const std::vector& args) const { - ValueDescr::Shape broadcasted_shape = GetBroadcastShape(args); +Result OutputType::Resolve(KernelContext* ctx, + const std::vector& types) const { if (kind_ == OutputType::FIXED) { - return ValueDescr(type_, shape_ == ValueDescr::ANY ? 
broadcasted_shape : shape_); + return type_.get(); } else { - ARROW_ASSIGN_OR_RAISE(ValueDescr resolved_descr, resolver_(ctx, args)); - if (resolved_descr.shape == ValueDescr::ANY) { - resolved_descr.shape = broadcasted_shape; - } - return resolved_descr; + return resolver_(ctx, types); } } @@ -448,19 +430,19 @@ bool KernelSignature::Equals(const KernelSignature& other) const { return true; } -bool KernelSignature::MatchesInputs(const std::vector& args) const { +bool KernelSignature::MatchesInputs(const std::vector& types) const { if (is_varargs_) { - for (size_t i = 0; i < args.size(); ++i) { - if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(args[i])) { + for (size_t i = 0; i < types.size(); ++i) { + if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(*types[i])) { return false; } } } else { - if (args.size() != in_types_.size()) { + if (types.size() != in_types_.size()) { return false; } for (size_t i = 0; i < in_types_.size(); ++i) { - if (!in_types_[i].Matches(args[i])) { + if (!in_types_[i].Matches(*types[i])) { return false; } } @@ -495,7 +477,7 @@ std::string KernelSignature::ToString() const { ss << in_types_[i].ToString(); } if (is_varargs_) { - ss << "]"; + ss << "*]"; } else { ss << ")"; } diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 93a1c605a99..5463a2de579 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -143,10 +143,9 @@ ARROW_EXPORT std::shared_ptr Primitive(); } // namespace match -/// \brief An object used for type- and shape-checking arguments to be passed -/// to a kernel and stored in a KernelSignature. Distinguishes between ARRAY -/// and SCALAR arguments using ValueDescr::Shape. The type-checking rule can be -/// supplied either with an exact DataType instance or a custom TypeMatcher. +/// \brief An object used for type-checking arguments to be passed to a kernel +/// and stored in a KernelSignature. The type-checking rule can be supplied +/// either with an exact DataType instance or a custom TypeMatcher. class ARROW_EXPORT InputType { public: /// \brief The kind of type-checking rule that the InputType contains. @@ -163,29 +162,21 @@ class ARROW_EXPORT InputType { USE_TYPE_MATCHER }; - /// \brief Accept any value type but with a specific shape (e.g. any Array or - /// any Scalar). - InputType(ValueDescr::Shape shape = ValueDescr::ANY) // NOLINT implicit construction - : kind_(ANY_TYPE), shape_(shape) {} + /// \brief Accept any value type + InputType() : kind_(ANY_TYPE) {} /// \brief Accept an exact value type. - InputType(std::shared_ptr type, // NOLINT implicit construction - ValueDescr::Shape shape = ValueDescr::ANY) - : kind_(EXACT_TYPE), shape_(shape), type_(std::move(type)) {} - - /// \brief Accept an exact value type and shape provided by a ValueDescr. - InputType(const ValueDescr& descr) // NOLINT implicit construction - : InputType(descr.type, descr.shape) {} + InputType(std::shared_ptr type) // NOLINT implicit construction + : kind_(EXACT_TYPE), type_(std::move(type)) {} /// \brief Use the passed TypeMatcher to type check. - InputType(std::shared_ptr type_matcher, // NOLINT implicit construction - ValueDescr::Shape shape = ValueDescr::ANY) - : kind_(USE_TYPE_MATCHER), shape_(shape), type_matcher_(std::move(type_matcher)) {} + InputType(std::shared_ptr type_matcher) // NOLINT implicit construction + : kind_(USE_TYPE_MATCHER), type_matcher_(std::move(type_matcher)) {} /// \brief Match any type with the given Type::type. Uses a TypeMatcher for /// its implementation. 
- explicit InputType(Type::type type_id, ValueDescr::Shape shape = ValueDescr::ANY) - : InputType(match::SameTypeId(type_id), shape) {} + InputType(Type::type type_id) // NOLINT implicit construction + : InputType(match::SameTypeId(type_id)) {} InputType(const InputType& other) { CopyInto(other); } @@ -195,23 +186,8 @@ class ARROW_EXPORT InputType { void operator=(InputType&& other) { MoveInto(std::forward(other)); } - // \brief Match an array with the given exact type. Convenience constructor. - static InputType Array(std::shared_ptr type) { - return InputType(std::move(type), ValueDescr::ARRAY); - } - - // \brief Match a scalar with the given exact type. Convenience constructor. - static InputType Scalar(std::shared_ptr type) { - return InputType(std::move(type), ValueDescr::SCALAR); - } - - // \brief Match an array with the given Type::type id. Convenience - // constructor. - static InputType Array(Type::type id) { return InputType(id, ValueDescr::ARRAY); } - - // \brief Match a scalar with the given Type::type id. Convenience - // constructor. - static InputType Scalar(Type::type id) { return InputType(id, ValueDescr::SCALAR); } + // \brief Match any input (array, scalar of any type) + static InputType Any() { return InputType(); } /// \brief Return true if this input type matches the same type cases as the /// other. @@ -227,21 +203,16 @@ class ARROW_EXPORT InputType { /// \brief Render a human-readable string representation. std::string ToString() const; - /// \brief Return true if the value matches this argument kind in type - /// and shape. + /// \brief Return true if the Datum matches this argument kind in + /// type (and only allows scalar or array-like Datums). bool Matches(const Datum& value) const; - /// \brief Return true if the value descriptor matches this argument kind in - /// type and shape. - bool Matches(const ValueDescr& value) const; + /// \brief Return true if the type matches this InputType + bool Matches(const DataType& type) const; /// \brief The type matching rule that this InputType uses. Kind kind() const { return kind_; } - /// \brief Indicates whether this InputType matches Array (ValueDescr::ARRAY), - /// Scalar (ValueDescr::SCALAR) values, or both (ValueDescr::ANY). - ValueDescr::Shape shape() const { return shape_; } - /// \brief For InputType::EXACT_TYPE kind, the exact type that this InputType /// must match. Otherwise this function should not be used and will assert in /// debug builds. @@ -255,22 +226,18 @@ class ARROW_EXPORT InputType { private: void CopyInto(const InputType& other) { this->kind_ = other.kind_; - this->shape_ = other.shape_; this->type_ = other.type_; this->type_matcher_ = other.type_matcher_; } void MoveInto(InputType&& other) { this->kind_ = other.kind_; - this->shape_ = other.shape_; this->type_ = std::move(other.type_); this->type_matcher_ = std::move(other.type_matcher_); } Kind kind_; - ValueDescr::Shape shape_ = ValueDescr::ANY; - // For EXACT_TYPE Kind std::shared_ptr type_; @@ -279,43 +246,30 @@ class ARROW_EXPORT InputType { }; /// \brief Container to capture both exact and input-dependent output types. 
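Because the shape-broadcasting rules described in the comment being removed below no longer apply, a computed OutputType is simply a function from input TypeHolders to an output TypeHolder. A minimal sketch of the new Resolver style; the resolver (which merely echoes the first input type) is illustrative:

    #include <iostream>
    #include <vector>

    #include "arrow/compute/kernel.h"
    #include "arrow/result.h"
    #include "arrow/type.h"

    namespace cp = arrow::compute;

    // A resolver in the new style: capture-less, so it is compatible with the
    // function-pointer Resolver typedef. The output type mirrors the first input.
    arrow::Result<arrow::TypeHolder> FirstInputType(
        cp::KernelContext*, const std::vector<arrow::TypeHolder>& types) {
      if (types.empty()) {
        return arrow::Status::Invalid("Need at least one argument");
      }
      return types[0];
    }

    int main() {
      cp::OutputType out_type(FirstInputType);
      arrow::Result<arrow::TypeHolder> resolved =
          out_type.Resolve(/*ctx=*/nullptr, {arrow::utf8()});
      if (resolved.ok()) {
        std::cout << resolved->GetSharedPtr()->ToString() << std::endl;  // string
      }
      return 0;
    }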
-/// -/// The value shape returned by Resolve will be determined by broadcasting the -/// shapes of the input arguments, otherwise this is handled by the -/// user-defined resolver function: -/// -/// * Any ARRAY shape -> output shape is ARRAY -/// * All SCALAR shapes -> output shape is SCALAR class ARROW_EXPORT OutputType { public: /// \brief An enum indicating whether the value type is an invariant fixed /// value or one that's computed by a kernel-defined resolver function. enum ResolveKind { FIXED, COMPUTED }; - /// Type resolution function. Given input types and shapes, return output - /// type and shape. This function MAY may use the kernel state to decide - /// the output type based on the functionoptions. + /// Type resolution function. Given input types, return output type. This + /// function MAY may use the kernel state to decide the output type based on + /// the FunctionOptions. /// /// This function SHOULD _not_ be used to check for arity, that is to be /// performed one or more layers above. - using Resolver = - std::function(KernelContext*, const std::vector&)>; + typedef Result (*Resolver)(KernelContext*, const std::vector&); - /// \brief Output an exact type, but with shape determined by promoting the - /// shapes of the inputs (any ARRAY argument yields ARRAY). + /// \brief Output an exact type OutputType(std::shared_ptr type) // NOLINT implicit construction : kind_(FIXED), type_(std::move(type)) {} - /// \brief Output the exact type and shape provided by a ValueDescr - OutputType(ValueDescr descr); // NOLINT implicit construction - /// \brief Output a computed type depending on actual input types OutputType(Resolver resolver) // NOLINT implicit construction : kind_(COMPUTED), resolver_(std::move(resolver)) {} OutputType(const OutputType& other) { this->kind_ = other.kind_; - this->shape_ = other.shape_; this->type_ = other.type_; this->resolver_ = other.resolver_; } @@ -323,19 +277,17 @@ class ARROW_EXPORT OutputType { OutputType(OutputType&& other) { this->kind_ = other.kind_; this->type_ = std::move(other.type_); - this->shape_ = other.shape_; this->resolver_ = other.resolver_; } OutputType& operator=(const OutputType&) = default; OutputType& operator=(OutputType&&) = default; - /// \brief Return the shape and type of the expected output value of the - /// kernel given the value descriptors (shapes and types) of the input - /// arguments. The resolver may make use of state information kept in the - /// KernelContext. - Result Resolve(KernelContext* ctx, - const std::vector& args) const; + /// \brief Return the type of the expected output value of the kernel given + /// the input argument types. The resolver may make use of state information + /// kept in the KernelContext. + Result Resolve(KernelContext* ctx, + const std::vector& args) const; /// \brief The exact output value type for the FIXED kind. const std::shared_ptr& type() const; @@ -352,22 +304,14 @@ class ARROW_EXPORT OutputType { /// fixed/invariant or computed by a resolver. ResolveKind kind() const { return kind_; } - /// \brief If the shape is ANY, then Resolve will compute the shape based on - /// the input arguments. - ValueDescr::Shape shape() const { return shape_; } - private: ResolveKind kind_; // For FIXED resolution std::shared_ptr type_; - /// \brief The shape of the output type to return when using Resolve. If ANY - /// will promote the input shapes. 
- ValueDescr::Shape shape_ = ValueDescr::ANY; - // For COMPUTED resolution - Resolver resolver_; + Resolver resolver_ = NULLPTR; }; /// \brief Holds the input types and output type of the kernel. @@ -388,7 +332,7 @@ class ARROW_EXPORT KernelSignature { /// \brief Return true if the signature if compatible with the list of input /// value descriptors. - bool MatchesInputs(const std::vector& descriptors) const; + bool MatchesInputs(const std::vector& types) const; /// \brief Returns true if the input types of each signature are /// equal. Well-formed functions should have a deterministic output type @@ -408,9 +352,10 @@ class ARROW_EXPORT KernelSignature { /// function arguments. const std::vector& in_types() const { return in_types_; } - /// \brief The output type for the kernel. Use Resolve to return the exact - /// output given input argument ValueDescrs, since many kernels' output types - /// depend on their input types (or their type metadata). + /// \brief The output type for the kernel. Use Resolve to return the + /// exact output given input argument types, since many kernels' + /// output types depend on their input types (or their type + /// metadata). const OutputType& out_type() const { return out_type_; } /// \brief Render a human-readable string representation @@ -493,12 +438,9 @@ struct KernelInitArgs { /// depend on the kernel's KernelSignature or other data contained there. const Kernel* kernel; - /// \brief The types and shapes of the input arguments that the kernel is + /// \brief The types of the input arguments that the kernel is /// about to be executed against. - /// - /// TODO: should this be const std::vector*? const-ref is being - /// used to avoid the cost of copying the struct into the args struct. - const std::vector& inputs; + const std::vector& inputs; /// \brief Opaque options specific to this kernel. May be nullptr for functions /// that do not require options. @@ -523,7 +465,7 @@ struct Kernel { std::move(init)) {} /// \brief The "signature" of the kernel containing the InputType input - /// argument validators and OutputType output type and shape resolver. + /// argument validators and OutputType output type resolver. std::shared_ptr signature; /// \brief Create a new KernelState for invocations of this kernel, e.g. to @@ -546,6 +488,9 @@ struct Kernel { /// contain multiple kernels with the same signature but different levels of SIMD, /// so that the most optimized kernel supported on a host's processor can be chosen. SimdLevel::type simd_level = SimdLevel::NONE; + + // Additional kernel-specific data + std::shared_ptr data; }; /// \brief The scalar kernel execution API that must be implemented for SCALAR @@ -555,8 +500,7 @@ struct Kernel { /// endeavor to write into pre-allocated memory if they are able, though for /// some kernels (e.g. in cases when a builder like StringBuilder) must be /// employed this may not be possible. -using ArrayKernelExec = - std::function; +typedef Status (*ArrayKernelExec)(KernelContext*, const ExecSpan&, ExecResult*); /// \brief Kernel data structure for implementations of ScalarFunction. 
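Combining the new function-pointer ArrayKernelExec with DataType-only input lists, defining and registering a scalar kernel now looks roughly like the sketch below; the function name "double_int32" and its exec body are illustrative and not part of this patch:

    #include <cstdint>
    #include <memory>

    #include "arrow/compute/exec.h"
    #include "arrow/compute/function.h"
    #include "arrow/compute/kernel.h"
    #include "arrow/compute/registry.h"
    #include "arrow/status.h"
    #include "arrow/type.h"

    namespace cp = arrow::compute;

    // Exec functions always see array-shaped inputs; scalar arguments are
    // handled by the executor, so the kernel just reads the ExecSpan and
    // writes into the preallocated output ArraySpan.
    arrow::Status DoubleInt32(cp::KernelContext*, const cp::ExecSpan& batch,
                              cp::ExecResult* out) {
      const int32_t* in = batch[0].array.GetValues<int32_t>(1);
      int32_t* out_values = out->array_span()->GetValues<int32_t>(1);
      for (int64_t i = 0; i < batch.length; ++i) {
        out_values[i] = in[i] * 2;
      }
      return arrow::Status::OK();
    }

    arrow::Status RegisterDoubleInt32() {
      auto func = std::make_shared<cp::ScalarFunction>(
          "double_int32", cp::Arity::Unary(), cp::FunctionDoc::Empty());
      // A bare DataType is now enough; InputType::Array(int32()) is gone.
      cp::ScalarKernel kernel({arrow::int32()}, arrow::int32(), DoubleInt32);
      ARROW_RETURN_NOT_OK(func->AddKernel(std::move(kernel)));
      return cp::GetFunctionRegistry()->AddFunction(std::move(func));
    }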
In /// addition to the members found in Kernel, contains the null handling @@ -566,12 +510,11 @@ struct ScalarKernel : public Kernel { ScalarKernel(std::shared_ptr sig, ArrayKernelExec exec, KernelInit init = NULLPTR) - : Kernel(std::move(sig), init), exec(std::move(exec)) {} + : Kernel(std::move(sig), init), exec(exec) {} ScalarKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init = NULLPTR) - : Kernel(std::move(in_types), std::move(out_type), std::move(init)), - exec(std::move(exec)) {} + : Kernel(std::move(in_types), std::move(out_type), std::move(init)), exec(exec) {} /// \brief Perform a single invocation of this kernel. Depending on the /// implementation, it may only write into preallocated memory, while in some @@ -590,9 +533,6 @@ struct ScalarKernel : public Kernel { // bitmaps is a reasonable default NullHandling::type null_handling = NullHandling::INTERSECTION; MemAllocation::type mem_allocation = MemAllocation::PREALLOCATE; - - // Additional kernel-specific data - std::shared_ptr data; }; // ---------------------------------------------------------------------- @@ -615,13 +555,13 @@ struct VectorKernel : public Kernel { VectorKernel(std::vector in_types, OutputType out_type, ArrayKernelExec exec, KernelInit init = NULLPTR, FinalizeFunc finalize = NULLPTR) : Kernel(std::move(in_types), std::move(out_type), std::move(init)), - exec(std::move(exec)), + exec(exec), finalize(std::move(finalize)) {} VectorKernel(std::shared_ptr sig, ArrayKernelExec exec, KernelInit init = NULLPTR, FinalizeFunc finalize = NULLPTR) : Kernel(std::move(sig), std::move(init)), - exec(std::move(exec)), + exec(exec), finalize(std::move(finalize)) {} /// \brief Perform a single invocation of this kernel. Any required state is diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc index 2d427374426..2676e93c3d7 100644 --- a/cpp/src/arrow/compute/kernel_test.cc +++ b/cpp/src/arrow/compute/kernel_test.cc @@ -21,6 +21,7 @@ #include +#include "arrow/array/util.h" #include "arrow/compute/kernel.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" @@ -75,44 +76,24 @@ TEST(InputType, AnyTypeConstructor) { // Check the ANY_TYPE ctors InputType ty; ASSERT_EQ(InputType::ANY_TYPE, ty.kind()); - ASSERT_EQ(ValueDescr::ANY, ty.shape()); - - ty = InputType(ValueDescr::SCALAR); - ASSERT_EQ(ValueDescr::SCALAR, ty.shape()); - - ty = InputType(ValueDescr::ARRAY); - ASSERT_EQ(ValueDescr::ARRAY, ty.shape()); } TEST(InputType, Constructors) { // Exact type constructor InputType ty1(int8()); ASSERT_EQ(InputType::EXACT_TYPE, ty1.kind()); - ASSERT_EQ(ValueDescr::ANY, ty1.shape()); AssertTypeEqual(*int8(), *ty1.type()); InputType ty1_implicit = int8(); ASSERT_TRUE(ty1.Equals(ty1_implicit)); - InputType ty1_array(int8(), ValueDescr::ARRAY); - ASSERT_EQ(ValueDescr::ARRAY, ty1_array.shape()); - - InputType ty1_scalar(int8(), ValueDescr::SCALAR); - ASSERT_EQ(ValueDescr::SCALAR, ty1_scalar.shape()); - // Same type id constructor InputType ty2(Type::DECIMAL); ASSERT_EQ(InputType::USE_TYPE_MATCHER, ty2.kind()); - ASSERT_EQ("any[Type::DECIMAL128]", ty2.ToString()); + ASSERT_EQ("Type::DECIMAL128", ty2.ToString()); ASSERT_TRUE(ty2.type_matcher().Matches(*decimal(12, 2))); ASSERT_FALSE(ty2.type_matcher().Matches(*int16())); - InputType ty2_array(Type::DECIMAL, ValueDescr::ARRAY); - ASSERT_EQ(ValueDescr::ARRAY, ty2_array.shape()); - - InputType ty2_scalar(Type::DECIMAL, ValueDescr::SCALAR); - ASSERT_EQ(ValueDescr::SCALAR, ty2_scalar.shape()); - // Implicit 
construction in a vector std::vector types = {int8(), InputType(Type::DECIMAL)}; ASSERT_TRUE(types[0].Equals(ty1)); @@ -131,69 +112,33 @@ TEST(InputType, Constructors) { ASSERT_TRUE(ty6.Equals(ty2)); // ToString - ASSERT_EQ("any[int8]", ty1.ToString()); - ASSERT_EQ("array[int8]", ty1_array.ToString()); - ASSERT_EQ("scalar[int8]", ty1_scalar.ToString()); - - ASSERT_EQ("any[Type::DECIMAL128]", ty2.ToString()); - ASSERT_EQ("array[Type::DECIMAL128]", ty2_array.ToString()); - ASSERT_EQ("scalar[Type::DECIMAL128]", ty2_scalar.ToString()); + ASSERT_EQ("int8", ty1.ToString()); + ASSERT_EQ("Type::DECIMAL128", ty2.ToString()); InputType ty7(match::TimestampTypeUnit(TimeUnit::MICRO)); - ASSERT_EQ("any[timestamp(us)]", ty7.ToString()); + ASSERT_EQ("timestamp(us)", ty7.ToString()); InputType ty8; - InputType ty9(ValueDescr::ANY); - InputType ty10(ValueDescr::ARRAY); - InputType ty11(ValueDescr::SCALAR); - ASSERT_EQ("any[any]", ty8.ToString()); - ASSERT_EQ("any[any]", ty9.ToString()); - ASSERT_EQ("array[any]", ty10.ToString()); - ASSERT_EQ("scalar[any]", ty11.ToString()); + ASSERT_EQ("any", ty8.ToString()); } TEST(InputType, Equals) { InputType t1 = int8(); InputType t2 = int8(); - InputType t3(int8(), ValueDescr::ARRAY); - InputType t3_i32(int32(), ValueDescr::ARRAY); - InputType t3_scalar(int8(), ValueDescr::SCALAR); - InputType t4(int8(), ValueDescr::ARRAY); - InputType t4_i32(int32(), ValueDescr::ARRAY); + InputType t3 = int32(); InputType t5(Type::DECIMAL); InputType t6(Type::DECIMAL); - InputType t7(Type::DECIMAL, ValueDescr::SCALAR); - InputType t7_i32(Type::INT32, ValueDescr::SCALAR); - InputType t8(Type::DECIMAL, ValueDescr::SCALAR); - InputType t8_i32(Type::INT32, ValueDescr::SCALAR); ASSERT_TRUE(t1.Equals(t2)); ASSERT_EQ(t1, t2); - - // ANY vs SCALAR ASSERT_NE(t1, t3); - ASSERT_EQ(t3, t4); - - // both ARRAY, but different type - ASSERT_NE(t3, t3_i32); - - // ARRAY vs SCALAR - ASSERT_NE(t3, t3_scalar); - - ASSERT_EQ(t3_i32, t4_i32); - ASSERT_FALSE(t1.Equals(t5)); ASSERT_NE(t1, t5); ASSERT_EQ(t5, t5); ASSERT_EQ(t5, t6); - ASSERT_NE(t5, t7); - ASSERT_EQ(t7, t8); - ASSERT_EQ(t7, t8); - ASSERT_NE(t7, t7_i32); - ASSERT_EQ(t7_i32, t8_i32); // NOTE: For the time being, we treat int32() and Type::INT32 as being // different. 
This could obviously be fixed later to make these equivalent @@ -208,9 +153,6 @@ TEST(InputType, Equals) { TEST(InputType, Hash) { InputType t0; - InputType t0_scalar(ValueDescr::SCALAR); - InputType t0_array(ValueDescr::ARRAY); - InputType t1 = int8(); InputType t2(Type::DECIMAL); @@ -218,36 +160,32 @@ TEST(InputType, Hash) { // same value, and whether the elements of the type are all incorporated into // the Hash ASSERT_EQ(t0.Hash(), t0.Hash()); - ASSERT_NE(t0.Hash(), t0_scalar.Hash()); - ASSERT_NE(t0.Hash(), t0_array.Hash()); - ASSERT_NE(t0_scalar.Hash(), t0_array.Hash()); - ASSERT_EQ(t1.Hash(), t1.Hash()); ASSERT_EQ(t2.Hash(), t2.Hash()); - ASSERT_NE(t0.Hash(), t1.Hash()); ASSERT_NE(t0.Hash(), t2.Hash()); ASSERT_NE(t1.Hash(), t2.Hash()); } TEST(InputType, Matches) { - InputType ty1 = int8(); - - ASSERT_TRUE(ty1.Matches(ValueDescr::Scalar(int8()))); - ASSERT_TRUE(ty1.Matches(ValueDescr::Array(int8()))); - ASSERT_TRUE(ty1.Matches(ValueDescr::Any(int8()))); - ASSERT_FALSE(ty1.Matches(ValueDescr::Any(int16()))); - - InputType ty2(Type::DECIMAL); - ASSERT_TRUE(ty2.Matches(ValueDescr::Scalar(decimal(12, 2)))); - ASSERT_TRUE(ty2.Matches(ValueDescr::Array(decimal(12, 2)))); - ASSERT_FALSE(ty2.Matches(ValueDescr::Any(float64()))); - - InputType ty3(int64(), ValueDescr::SCALAR); - ASSERT_FALSE(ty3.Matches(ValueDescr::Array(int64()))); - ASSERT_TRUE(ty3.Matches(ValueDescr::Scalar(int64()))); - ASSERT_FALSE(ty3.Matches(ValueDescr::Scalar(int32()))); - ASSERT_FALSE(ty3.Matches(ValueDescr::Any(int64()))); + InputType input1 = int8(); + + ASSERT_TRUE(input1.Matches(*int8())); + ASSERT_TRUE(input1.Matches(*int8())); + ASSERT_FALSE(input1.Matches(*int16())); + + InputType input2(Type::DECIMAL); + ASSERT_TRUE(input2.Matches(*decimal(12, 2))); + + auto ty2 = decimal(12, 2); + auto ty3 = float64(); + ASSERT_OK_AND_ASSIGN(std::shared_ptr arr2, MakeArrayOfNull(ty2, 1)); + ASSERT_OK_AND_ASSIGN(std::shared_ptr arr3, MakeArrayOfNull(ty3, 1)); + ASSERT_OK_AND_ASSIGN(std::shared_ptr scalar2, arr2->GetScalar(0)); + ASSERT_TRUE(input2.Matches(Datum(arr2))); + ASSERT_TRUE(input2.Matches(Datum(scalar2))); + ASSERT_FALSE(input2.Matches(*ty3)); + ASSERT_FALSE(input2.Matches(arr3)); } // ---------------------------------------------------------------------- @@ -259,14 +197,14 @@ TEST(OutputType, Constructors) { AssertTypeEqual(*int8(), *ty1.type()); auto DummyResolver = [](KernelContext*, - const std::vector& args) -> Result { - return ValueDescr(int32(), GetBroadcastShape(args)); + const std::vector& args) -> Result { + return int32(); }; OutputType ty2(DummyResolver); ASSERT_EQ(OutputType::COMPUTED, ty2.kind()); - ASSERT_OK_AND_ASSIGN(ValueDescr out_descr2, ty2.Resolve(nullptr, {})); - ASSERT_EQ(ValueDescr::Array(int32()), out_descr2); + ASSERT_OK_AND_ASSIGN(TypeHolder out_type2, ty2.Resolve(nullptr, {})); + ASSERT_EQ(out_type2, int32()); // Copy constructor OutputType ty3 = ty1; @@ -275,8 +213,8 @@ TEST(OutputType, Constructors) { OutputType ty4 = ty2; ASSERT_EQ(OutputType::COMPUTED, ty4.kind()); - ASSERT_OK_AND_ASSIGN(ValueDescr out_descr4, ty4.Resolve(nullptr, {})); - ASSERT_EQ(ValueDescr::Array(int32()), out_descr4); + ASSERT_OK_AND_ASSIGN(TypeHolder out_type4, ty4.Resolve(nullptr, {})); + ASSERT_EQ(out_type4, int32()); // Move constructor OutputType ty5 = std::move(ty1); @@ -285,8 +223,8 @@ TEST(OutputType, Constructors) { OutputType ty6 = std::move(ty4); ASSERT_EQ(OutputType::COMPUTED, ty6.kind()); - ASSERT_OK_AND_ASSIGN(ValueDescr out_descr6, ty6.Resolve(nullptr, {})); - ASSERT_EQ(ValueDescr::Array(int32()), 
out_descr6); + ASSERT_OK_AND_ASSIGN(TypeHolder out_type6, ty6.Resolve(nullptr, {})); + ASSERT_EQ(out_type6, int32()); // ToString @@ -296,89 +234,63 @@ TEST(OutputType, Constructors) { } TEST(OutputType, Resolve) { - // Check shape promotion rules for FIXED kind OutputType ty1(int32()); - ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty1.Resolve(nullptr, {})); - ASSERT_EQ(ValueDescr::Array(int32()), descr); + ASSERT_OK_AND_ASSIGN(TypeHolder result, ty1.Resolve(nullptr, {})); + ASSERT_EQ(result, int32()); - ASSERT_OK_AND_ASSIGN(descr, - ty1.Resolve(nullptr, {ValueDescr(int8(), ValueDescr::SCALAR)})); - ASSERT_EQ(ValueDescr::Scalar(int32()), descr); + ASSERT_OK_AND_ASSIGN(result, ty1.Resolve(nullptr, {int8()})); + ASSERT_EQ(result, int32()); - ASSERT_OK_AND_ASSIGN(descr, - ty1.Resolve(nullptr, {ValueDescr(int8(), ValueDescr::SCALAR), - ValueDescr(int8(), ValueDescr::ARRAY)})); - ASSERT_EQ(ValueDescr::Array(int32()), descr); + ASSERT_OK_AND_ASSIGN(result, ty1.Resolve(nullptr, {int8(), int8()})); + ASSERT_EQ(result, int32()); - OutputType ty2([](KernelContext*, const std::vector& args) { - return ValueDescr(args[0].type, GetBroadcastShape(args)); - }); + auto resolver = [](KernelContext*, + const std::vector& args) -> Result { + return args[0]; + }; + OutputType ty2(resolver); - ASSERT_OK_AND_ASSIGN(descr, ty2.Resolve(nullptr, {ValueDescr::Array(utf8())})); - ASSERT_EQ(ValueDescr::Array(utf8()), descr); + ASSERT_OK_AND_ASSIGN(result, ty2.Resolve(nullptr, {utf8()})); + ASSERT_EQ(result, utf8()); // Type resolver that returns an error OutputType ty3( - [](KernelContext* ctx, const std::vector& args) -> Result { + [](KernelContext* ctx, const std::vector& types) -> Result { // NB: checking the value types versus the function arity should be // validated elsewhere, so this is just for illustration purposes - if (args.size() == 0) { + if (types.size() == 0) { return Status::Invalid("Need at least one argument"); } - return ValueDescr(args[0]); + return types[0]; }); ASSERT_RAISES(Invalid, ty3.Resolve(nullptr, {})); - // Type resolver that returns ValueDescr::ANY and needs type promotion + // Type resolver that returns a fixed value OutputType ty4( - [](KernelContext* ctx, const std::vector& args) -> Result { + [](KernelContext* ctx, const std::vector& types) -> Result { return int32(); }); - ASSERT_OK_AND_ASSIGN(descr, ty4.Resolve(nullptr, {ValueDescr::Array(int8())})); - ASSERT_EQ(ValueDescr::Array(int32()), descr); - ASSERT_OK_AND_ASSIGN(descr, ty4.Resolve(nullptr, {ValueDescr::Scalar(int8())})); - ASSERT_EQ(ValueDescr::Scalar(int32()), descr); -} - -TEST(OutputType, ResolveDescr) { - ValueDescr d1 = ValueDescr::Scalar(int32()); - ValueDescr d2 = ValueDescr::Array(int32()); - - OutputType ty1(d1); - OutputType ty2(d2); - - ASSERT_EQ(ValueDescr::SCALAR, ty1.shape()); - ASSERT_EQ(ValueDescr::ARRAY, ty2.shape()); - - { - ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty1.Resolve(nullptr, {})); - ASSERT_EQ(d1, descr); - } - - { - ASSERT_OK_AND_ASSIGN(ValueDescr descr, ty2.Resolve(nullptr, {})); - ASSERT_EQ(d2, descr); - } + ASSERT_OK_AND_ASSIGN(result, ty4.Resolve(nullptr, {int8()})); + ASSERT_EQ(result, int32()); + ASSERT_OK_AND_ASSIGN(result, ty4.Resolve(nullptr, {int8()})); + ASSERT_EQ(result, int32()); } // ---------------------------------------------------------------------- // KernelSignature TEST(KernelSignature, Basics) { - // (any[int8], scalar[decimal]) -> utf8 - std::vector in_types({int8(), InputType(Type::DECIMAL, ValueDescr::SCALAR)}); + // (int8, decimal) -> utf8 + std::vector in_types({int8(), 
InputType(Type::DECIMAL)}); OutputType out_type(utf8()); KernelSignature sig(in_types, out_type); ASSERT_EQ(2, sig.in_types().size()); ASSERT_TRUE(sig.in_types()[0].type()->Equals(*int8())); - ASSERT_TRUE(sig.in_types()[0].Matches(ValueDescr::Scalar(int8()))); - ASSERT_TRUE(sig.in_types()[0].Matches(ValueDescr::Array(int8()))); - - ASSERT_TRUE(sig.in_types()[1].Matches(ValueDescr::Scalar(decimal(12, 2)))); - ASSERT_FALSE(sig.in_types()[1].Matches(ValueDescr::Array(decimal(12, 2)))); + ASSERT_TRUE(sig.in_types()[0].Matches(*int8())); + ASSERT_TRUE(sig.in_types()[1].Matches(*decimal(12, 2))); } TEST(KernelSignature, Equals) { @@ -393,10 +305,6 @@ TEST(KernelSignature, Equals) { KernelSignature sig4_copy({int8(), int16()}, utf8()); KernelSignature sig5({int8(), int16(), int32()}, utf8()); - // Differ in shape - KernelSignature sig6({ValueDescr::Scalar(int8())}, utf8()); - KernelSignature sig7({ValueDescr::Array(int8())}, utf8()); - ASSERT_EQ(sig1, sig1); ASSERT_EQ(sig2, sig3); @@ -408,8 +316,6 @@ TEST(KernelSignature, Equals) { // Match first 2 args, but not third ASSERT_NE(sig4, sig5); - - ASSERT_NE(sig6, sig7); } TEST(KernelSignature, VarArgsEquals) { @@ -441,40 +347,32 @@ TEST(KernelSignature, MatchesInputs) { ASSERT_TRUE(sig1.MatchesInputs({})); ASSERT_FALSE(sig1.MatchesInputs({int8()})); - // (any[int8], any[decimal]) -> boolean + // (int8, decimal) -> boolean KernelSignature sig2({int8(), InputType(Type::DECIMAL)}, boolean()); ASSERT_FALSE(sig2.MatchesInputs({})); ASSERT_FALSE(sig2.MatchesInputs({int8()})); ASSERT_TRUE(sig2.MatchesInputs({int8(), decimal(12, 2)})); - ASSERT_TRUE(sig2.MatchesInputs( - {ValueDescr::Scalar(int8()), ValueDescr::Scalar(decimal(12, 2))})); - ASSERT_TRUE( - sig2.MatchesInputs({ValueDescr::Array(int8()), ValueDescr::Array(decimal(12, 2))})); - // (scalar[int8], array[int32]) -> boolean - KernelSignature sig3({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())}, - boolean()); + // (int8, int32) -> boolean + KernelSignature sig3({int8(), int32()}, boolean()); ASSERT_FALSE(sig3.MatchesInputs({})); // Unqualified, these are ANY type and do not match because the kernel // requires a scalar and an array - ASSERT_FALSE(sig3.MatchesInputs({int8(), int32()})); - ASSERT_TRUE( - sig3.MatchesInputs({ValueDescr::Scalar(int8()), ValueDescr::Array(int32())})); - ASSERT_FALSE( - sig3.MatchesInputs({ValueDescr::Array(int8()), ValueDescr::Array(int32())})); + ASSERT_TRUE(sig3.MatchesInputs({int8(), int32()})); + ASSERT_FALSE(sig3.MatchesInputs({int8(), int16()})); } TEST(KernelSignature, VarArgsMatchesInputs) { { KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true); - std::vector args = {int8()}; + std::vector args = {int8()}; ASSERT_TRUE(sig.MatchesInputs(args)); - args.push_back(ValueDescr::Scalar(int8())); - args.push_back(ValueDescr::Array(int8())); + args.push_back(int8()); + args.push_back(int8()); ASSERT_TRUE(sig.MatchesInputs(args)); args.push_back(int32()); ASSERT_FALSE(sig.MatchesInputs(args)); @@ -482,10 +380,10 @@ TEST(KernelSignature, VarArgsMatchesInputs) { { KernelSignature sig({int8(), utf8()}, utf8(), /*is_varargs=*/true); - std::vector args = {int8()}; + std::vector args = {int8()}; ASSERT_TRUE(sig.MatchesInputs(args)); - args.push_back(ValueDescr::Scalar(utf8())); - args.push_back(ValueDescr::Array(utf8())); + args.push_back(utf8()); + args.push_back(utf8()); ASSERT_TRUE(sig.MatchesInputs(args)); args.push_back(int32()); ASSERT_FALSE(sig.MatchesInputs(args)); @@ -493,23 +391,25 @@ TEST(KernelSignature, VarArgsMatchesInputs) { } TEST(KernelSignature, 
@@ -493,23 +391,25 @@ TEST(KernelSignature, ToString) {
-  std::vector<InputType> in_types = {InputType(int8(), ValueDescr::SCALAR),
-                                     InputType(Type::DECIMAL, ValueDescr::ARRAY),
+  std::vector<InputType> in_types = {InputType(int8()), InputType(Type::DECIMAL),
                                      InputType(utf8())};
   KernelSignature sig(in_types, utf8());
-  ASSERT_EQ("(scalar[int8], array[Type::DECIMAL128], any[string]) -> string",
-            sig.ToString());
-
-  OutputType out_type([](KernelContext*, const std::vector<ValueDescr>& args) {
-    return Status::Invalid("NYI");
-  });
-  KernelSignature sig2({int8(), InputType(Type::DECIMAL)}, out_type);
-  ASSERT_EQ("(any[int8], any[Type::DECIMAL128]) -> computed", sig2.ToString());
+  ASSERT_EQ("(int8, Type::DECIMAL128, string) -> string", sig.ToString());
+
+  OutputType out_type(
+      [](KernelContext*, const std::vector<TypeHolder>& args) -> Result<TypeHolder> {
+        return Status::Invalid("NYI");
+      });
+  KernelSignature sig2({int8(), Type::DECIMAL}, out_type);
+  ASSERT_EQ("(int8, Type::DECIMAL128) -> computed", sig2.ToString());
 }
 
 TEST(KernelSignature, VarArgsToString) {
   KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true);
-  ASSERT_EQ("varargs[any[int8]] -> string", sig.ToString());
+  ASSERT_EQ("varargs[int8*] -> string", sig.ToString());
+
+  KernelSignature sig2({utf8(), int8()}, utf8(), /*is_varargs=*/true);
+  ASSERT_EQ("varargs[string, int8*] -> string", sig2.ToString());
 }
 
 }  // namespace compute
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
index 661b6a4edb1..57cee87f00d 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -195,7 +195,7 @@ Result<std::unique_ptr<KernelState>> CountDistinctInit(KernelContext* ctx,
 
 template
 void AddCountDistinctKernel(InputType type, ScalarAggregateFunction* func) {
-  AddAggKernel(KernelSignature::Make({type}, ValueDescr::Scalar(int64())),
+  AddAggKernel(KernelSignature::Make({type}, int64()),
                CountDistinctInit, func);
 }
@@ -252,7 +252,7 @@ struct MeanImplDefault : public MeanImpl {
 Result<std::unique_ptr<KernelState>> SumInit(KernelContext* ctx,
                                              const KernelInitArgs& args) {
   SumLikeInit visitor(
-      ctx, args.inputs[0].type,
+      ctx, args.inputs[0].GetSharedPtr(),
       static_cast<const ScalarAggregateOptions&>(*args.options));
   return visitor.Create();
 }
@@ -260,7 +260,7 @@ Result<std::unique_ptr<KernelState>> SumInit(KernelContext* ctx,
 Result<std::unique_ptr<KernelState>> MeanInit(KernelContext* ctx,
                                               const KernelInitArgs& args) {
   MeanKernelInit visitor(
-      ctx, args.inputs[0].type,
+      ctx, args.inputs[0].GetSharedPtr(),
       static_cast<const ScalarAggregateOptions&>(*args.options));
   return visitor.Create();
 }
@@ -277,7 +277,7 @@ struct ProductImpl : public ScalarAggregator {
   using ProductType = typename TypeTraits::CType;
   using OutputType = typename TypeTraits::ScalarType;
 
-  explicit ProductImpl(const std::shared_ptr<DataType>& out_type,
+  explicit ProductImpl(std::shared_ptr<DataType> out_type,
                        const ScalarAggregateOptions& options)
       : out_type(out_type), options(options),
@@ -356,10 +356,10 @@ struct NullProductImpl : public NullImpl {
 struct ProductInit {
   std::unique_ptr<KernelState> state;
   KernelContext* ctx;
-  const std::shared_ptr<DataType>& type;
+  std::shared_ptr<DataType> type;
   const ScalarAggregateOptions& options;
 
-  ProductInit(KernelContext* ctx, const std::shared_ptr<DataType>& type,
+  ProductInit(KernelContext* ctx, std::shared_ptr<DataType> type,
               const ScalarAggregateOptions& options)
       : ctx(ctx), type(type), options(options) {}
@@ -402,7 +402,7 @@ struct ProductInit {
   static Result<std::unique_ptr<KernelState>> Init(KernelContext* ctx,
                                                    const KernelInitArgs& args) {
-    ProductInit visitor(ctx, args.inputs[0].type,
+    ProductInit visitor(ctx, args.inputs[0].GetSharedPtr(),
                         static_cast<const ScalarAggregateOptions&>(*args.options));
     return visitor.Create();
   }
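Not part of the patch: a sketch of a computed OutputType written against the TypeHolder-based resolver signature exercised by the ToString test above; the names ResolveToFirstInput and MakeExampleSignature are illustrative only.

#include <arrow/compute/kernel.h>
#include <arrow/result.h>
#include <arrow/type.h>

#include <memory>
#include <vector>

namespace cp = arrow::compute;

// Output type resolvers now take and return TypeHolder instead of ValueDescr.
arrow::Result<arrow::TypeHolder> ResolveToFirstInput(
    cp::KernelContext*, const std::vector<arrow::TypeHolder>& in_types) {
  return in_types.front();
}

std::shared_ptr<cp::KernelSignature> MakeExampleSignature() {
  // Renders as "(int8, string) -> computed" via ToString().
  return cp::KernelSignature::Make({arrow::int8(), arrow::utf8()},
                                   cp::OutputType(ResolveToFirstInput));
}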
@@ -413,10 +413,10 @@ struct ProductInit {
 
 Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx,
                                                 const KernelInitArgs& args) {
-  ARROW_ASSIGN_OR_RAISE(auto out_type,
+  ARROW_ASSIGN_OR_RAISE(TypeHolder out_type,
                         args.kernel->signature->out_type().Resolve(ctx, args.inputs));
   MinMaxInitState visitor(
-      ctx, *args.inputs[0].type, std::move(out_type.type),
+      ctx, *args.inputs[0], out_type.GetSharedPtr(),
       static_cast<const ScalarAggregateOptions&>(*args.options));
   return visitor.Create();
 }
@@ -425,14 +425,7 @@ Result<std::unique_ptr<KernelState>> MinMaxInit(KernelContext* ctx,
 
 template
 void AddMinOrMaxAggKernel(ScalarAggregateFunction* func,
                           ScalarAggregateFunction* min_max_func) {
-  auto sig = KernelSignature::Make(
-      {InputType(ValueDescr::ANY)},
-      OutputType([](KernelContext*,
-                    const std::vector<ValueDescr>& descrs) -> Result<ValueDescr> {
-        // any[T] -> scalar[T]
-        return ValueDescr::Scalar(descrs.front().type);
-      }));
-
+  auto sig = KernelSignature::Make({InputType::Any()}, FirstType);
   auto init = [min_max_func](
                   KernelContext* ctx,
                   const KernelInitArgs& args) -> Result<std::unique_ptr<KernelState>> {
@@ -775,8 +768,7 @@ void AddBasicAggKernels(KernelInit init,
                        SimdLevel::type simd_level) {
   for (const auto& ty : types) {
     // array[InT] -> scalar[OutT]
-    auto sig =
-        KernelSignature::Make({InputType::Array(ty->id())}, ValueDescr::Scalar(out_ty));
+    auto sig = KernelSignature::Make({ty->id()}, out_ty);
     AddAggKernel(std::move(sig), init, func, simd_level);
   }
 }
@@ -786,9 +778,7 @@ void AddScalarAggKernels(KernelInit init,
                         std::shared_ptr<DataType> out_ty,
                         ScalarAggregateFunction* func) {
   for (const auto& ty : types) {
-    // scalar[InT] -> scalar[OutT]
-    auto sig =
-        KernelSignature::Make({InputType::Scalar(ty->id())}, ValueDescr::Scalar(out_ty));
+    auto sig = KernelSignature::Make({ty->id()}, out_ty);
     AddAggKernel(std::move(sig), init, func, SimdLevel::NONE);
   }
 }
@@ -804,17 +794,17 @@ void AddArrayScalarAggKernels(KernelInit init,
 
 namespace {
 
-Result<ValueDescr> MinMaxType(KernelContext*, const std::vector<ValueDescr>& descrs) {
-  // any[T] -> scalar[struct<min: T, max: T>]
-  auto ty = descrs.front().type;
-  return ValueDescr::Scalar(struct_({field("min", ty), field("max", ty)}));
+Result<TypeHolder> MinMaxType(KernelContext*, const std::vector<TypeHolder>& types) {
+  // T -> struct<min: T, max: T>
+  auto ty = types.front().GetSharedPtr();
+  return struct_({field("min", ty), field("max", ty)});
 }
 
 }  // namespace
 
 void AddMinMaxKernel(KernelInit init, internal::detail::GetTypeId get_id,
                      ScalarAggregateFunction* func, SimdLevel::type simd_level) {
-  auto sig = KernelSignature::Make({InputType(get_id.id)}, OutputType(MinMaxType));
+  auto sig = KernelSignature::Make({InputType(get_id.id)}, MinMaxType);
   AddAggKernel(std::move(sig), init, func, simd_level);
 }
@@ -828,13 +818,6 @@ void AddMinMaxKernels(KernelInit init,
 
 namespace {
 
-Result<ValueDescr> ScalarFirstType(KernelContext*,
-                                   const std::vector<ValueDescr>& descrs) {
-  ValueDescr result = descrs.front();
-  result.shape = ValueDescr::SCALAR;
-  return result;
-}
-
 const FunctionDoc count_doc{"Count the number of null / non-null values",
                             ("By default, only non-null values are counted.\n"
                              "This can be changed through CountOptions."),
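Not part of the patch: a sketch of the "output type equals first input type" resolution that the decimal kernels below now rely on, under the assumption that a resolver which ignores its KernelContext may be resolved with a null context; the name FirstInputType is illustrative and distinct from the FirstType helper used by the patch.

#include <arrow/compute/kernel.h>
#include <arrow/result.h>
#include <arrow/type.h>

#include <cassert>
#include <vector>

namespace cp = arrow::compute;

// Same behavior the decimal sum/mean/product signatures need: the output
// type is simply the first input type.
arrow::Result<arrow::TypeHolder> FirstInputType(
    cp::KernelContext*, const std::vector<arrow::TypeHolder>& in_types) {
  return in_types.front();
}

int main() {
  auto sig = cp::KernelSignature::Make({cp::InputType(arrow::Type::DECIMAL128)},
                                       cp::OutputType(FirstInputType));
  // Resolving against a concrete decimal type yields that same type back.
  auto resolved = sig->out_type().Resolve(/*ctx=*/nullptr, {arrow::decimal128(38, 10)});
  assert(resolved.ok());
  assert(resolved->type->id() == arrow::Type::DECIMAL128);
  return 0;
}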
@@ -922,8 +905,7 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
 
   // Takes any input, outputs int64 scalar
   InputType any_input;
-  AddAggKernel(KernelSignature::Make({any_input}, ValueDescr::Scalar(int64())), CountInit,
-               func.get());
+  AddAggKernel(KernelSignature::Make({any_input}, int64()), CountInit, func.get());
   DCHECK_OK(registry->AddFunction(std::move(func)));
 
   func = std::make_shared<ScalarAggregateFunction>(
@@ -935,12 +917,10 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
 
   func = std::make_shared<ScalarAggregateFunction>("sum", Arity::Unary(), sum_doc,
                                                    &default_scalar_aggregate_options);
   AddArrayScalarAggKernels(SumInit, {boolean()}, uint64(), func.get());
-  AddAggKernel(
-      KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)),
-      SumInit, func.get(), SimdLevel::NONE);
-  AddAggKernel(
-      KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)),
-      SumInit, func.get(), SimdLevel::NONE);
+  AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, FirstType), SumInit, func.get(),
+               SimdLevel::NONE);
+  AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, FirstType), SumInit, func.get(),
+               SimdLevel::NONE);
   AddArrayScalarAggKernels(SumInit, SignedIntTypes(), int64(), func.get());
   AddArrayScalarAggKernels(SumInit, UnsignedIntTypes(), uint64(), func.get());
   AddArrayScalarAggKernels(SumInit, FloatingPointTypes(), float64(), func.get());
@@ -965,12 +945,10 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
                                                    &default_scalar_aggregate_options);
   AddArrayScalarAggKernels(MeanInit, {boolean()}, float64(), func.get());
   AddArrayScalarAggKernels(MeanInit, NumericTypes(), float64(), func.get());
-  AddAggKernel(
-      KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)),
-      MeanInit, func.get(), SimdLevel::NONE);
-  AddAggKernel(
-      KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)),
-      MeanInit, func.get(), SimdLevel::NONE);
+  AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, FirstType), MeanInit, func.get(),
+               SimdLevel::NONE);
+  AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, FirstType), MeanInit, func.get(),
+               SimdLevel::NONE);
   AddArrayScalarAggKernels(MeanInit, {null()}, float64(), func.get());
   // Add the SIMD variants for mean
 #if defined(ARROW_HAVE_RUNTIME_AVX2)
@@ -1028,12 +1006,10 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
   AddArrayScalarAggKernels(ProductInit::Init, UnsignedIntTypes(), uint64(), func.get());
   AddArrayScalarAggKernels(ProductInit::Init, FloatingPointTypes(), float64(),
                            func.get());
-  AddAggKernel(
-      KernelSignature::Make({InputType(Type::DECIMAL128)}, OutputType(ScalarFirstType)),
-      ProductInit::Init, func.get(), SimdLevel::NONE);
-  AddAggKernel(
-      KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)),
-      ProductInit::Init, func.get(), SimdLevel::NONE);
+  AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, FirstType), ProductInit::Init,
+               func.get(), SimdLevel::NONE);
+  AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, FirstType), ProductInit::Init,
+               func.get(), SimdLevel::NONE);
   AddArrayScalarAggKernels(ProductInit::Init, {null()}, int64(), func.get());
 
   DCHECK_OK(registry->AddFunction(std::move(func)));
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc b/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc
index 00e3e2e5fd4..03b45107eec 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc
@@ -37,7 +37,7 @@ struct MeanImplAvx2 : public MeanImpl {
 
 Result<std::unique_ptr<KernelState>> SumInitAvx2(KernelContext* ctx,
                                                  const KernelInitArgs& args) {
   SumLikeInit visitor(
-      ctx, args.inputs[0].type,
+      ctx, args.inputs[0].GetSharedPtr(),
       static_cast<const ScalarAggregateOptions&>(*args.options));
   return visitor.Create();
 }
@@ -45,7 +45,7 @@ Result<std::unique_ptr<KernelState>> SumInitAvx2(KernelContext* ctx,
 
 Result<std::unique_ptr<KernelState>> MeanInitAvx2(KernelContext* ctx,
                                                   const KernelInitArgs& args) {
   SumLikeInit visitor(
-      ctx, args.inputs[0].type,
+      ctx, args.inputs[0].GetSharedPtr(),
       static_cast<const ScalarAggregateOptions&>(*args.options));
   return visitor.Create();
 }
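Not part of the patch: a sketch of the KernelInit shape that the SumInit/MeanInit changes above are being moved to, where KernelInitArgs::inputs holds TypeHolder values. ExampleState and ExampleInit are hypothetical names, not anything defined by this patch.

#include <arrow/compute/kernel.h>
#include <arrow/result.h>
#include <arrow/type.h>

#include <memory>
#include <utility>

namespace cp = arrow::compute;

// Hypothetical kernel state; only the init signature and the TypeHolder
// accessors mirror the patch.
struct ExampleState : public cp::KernelState {
  explicit ExampleState(std::shared_ptr<arrow::DataType> ty) : out_type(std::move(ty)) {}
  std::shared_ptr<arrow::DataType> out_type;
};

arrow::Result<std::unique_ptr<cp::KernelState>> ExampleInit(
    cp::KernelContext*, const cp::KernelInitArgs& args) {
  // args.inputs now holds TypeHolder values; GetSharedPtr() recovers an
  // owning std::shared_ptr<DataType> when the state needs to keep the type.
  return std::make_unique<ExampleState>(args.inputs[0].GetSharedPtr());
}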
@@ -55,10 +55,10 @@ Result<std::unique_ptr<KernelState>> MeanInitAvx2(KernelContext* ctx,
 
 Result<std::unique_ptr<KernelState>> MinMaxInitAvx2(KernelContext* ctx,
                                                     const KernelInitArgs& args) {
-  ARROW_ASSIGN_OR_RAISE(auto out_type,
+  ARROW_ASSIGN_OR_RAISE(TypeHolder out_type,
                         args.kernel->signature->out_type().Resolve(ctx, args.inputs));
   MinMaxInitState visitor(
-      ctx, *args.inputs[0].type, std::move(out_type.type),
+      ctx, *args.inputs[0], out_type.GetSharedPtr(),
       static_cast<const ScalarAggregateOptions&>(*args.options));
   return visitor.Create();
 }
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc b/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc
index 8c10eb19b07..0d66ed2ec3e 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc
@@ -37,7 +37,7 @@ struct MeanImplAvx512 : public MeanImpl {
 
 Result<std::unique_ptr<KernelState>> SumInitAvx512(KernelContext* ctx,
                                                    const KernelInitArgs& args) {
   SumLikeInit visitor(
-      ctx, args.inputs[0].type,
+      ctx, args.inputs[0].GetSharedPtr(),
       static_cast<const ScalarAggregateOptions&>(*args.options));
   return visitor.Create();
 }
@@ -45,7 +45,7 @@ Result<std::unique_ptr<KernelState>> SumInitAvx512(KernelContext* ctx,
 
 Result<std::unique_ptr<KernelState>> MeanInitAvx512(KernelContext* ctx,
                                                     const KernelInitArgs& args) {
   SumLikeInit visitor(
-      ctx, args.inputs[0].type,
+      ctx, args.inputs[0].GetSharedPtr(),
       static_cast<const ScalarAggregateOptions&>(*args.options));
   return visitor.Create();
 }
@@ -55,10 +55,10 @@ Result<std::unique_ptr<KernelState>> MeanInitAvx512(KernelContext* ctx,
 
 Result<std::unique_ptr<KernelState>> MinMaxInitAvx512(KernelContext* ctx,
                                                       const KernelInitArgs& args) {
-  ARROW_ASSIGN_OR_RAISE(auto out_type,
+  ARROW_ASSIGN_OR_RAISE(TypeHolder out_type,
                         args.kernel->signature->out_type().Resolve(ctx, args.inputs));
   MinMaxInitState visitor(
-      ctx, *args.inputs[0].type, std::move(out_type.type),
+      ctx, *args.inputs[0], out_type.GetSharedPtr(),
       static_cast<const ScalarAggregateOptions&>(*args.options));
   return visitor.Create();
 }
diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
index a5b473793a9..6645e1a76bc 100644
--- a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
+++ b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
@@ -65,8 +65,7 @@ struct SumImpl : public ScalarAggregator {
   using SumCType = typename TypeTraits::CType;
   using OutputType = typename TypeTraits::ScalarType;
 
-  SumImpl(const std::shared_ptr<DataType>& out_type,
-          const ScalarAggregateOptions& options_)
+  SumImpl(std::shared_ptr<DataType> out_type, const ScalarAggregateOptions& options_)
       : out_type(out_type), options(options_) {}
 
   Status Consume(KernelContext*, const ExecBatch& batch) override {
@@ -216,10 +215,10 @@ template