From 1606b9619fdb979c0083f967ceff55e731172e35 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 25 May 2022 11:36:38 -0400 Subject: [PATCH 01/10] ARROW-16657: [C++] Support nesting of extension-id-registries --- .../arrow/engine/substrait/extension_set.cc | 338 +++++++++++------- .../arrow/engine/substrait/extension_set.h | 23 +- .../arrow/engine/substrait/plan_internal.cc | 2 +- .../arrow/engine/substrait/plan_internal.h | 2 +- cpp/src/arrow/engine/substrait/util.cc | 4 + cpp/src/arrow/engine/substrait/util.h | 4 + 6 files changed, 247 insertions(+), 126 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 80cdf59f496..c0f061cd869 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -107,14 +107,14 @@ struct ExtensionSet::Impl { std::unordered_map types_, functions_; }; -ExtensionSet::ExtensionSet(ExtensionIdRegistry* registry) +ExtensionSet::ExtensionSet(const ExtensionIdRegistry* registry) : registry_(registry), impl_(new Impl(), [](Impl* impl) { delete impl; }) {} Result ExtensionSet::Make(std::vector uris, std::vector type_ids, std::vector type_is_variation, std::vector function_ids, - ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry) { ExtensionSet set; set.registry_ = registry; @@ -210,158 +210,254 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } -ExtensionIdRegistry* default_extension_id_registry() { - static struct Impl : ExtensionIdRegistry { - Impl() { - struct TypeName { - std::shared_ptr type; - util::string_view name; - }; - - // The type (variation) mappings listed below need to be kept in sync - // with the YAML at substrait/format/extension_types.yaml manually; - // see ARROW-15535. - for (TypeName e : { - TypeName{uint8(), "u8"}, - TypeName{uint16(), "u16"}, - TypeName{uint32(), "u32"}, - TypeName{uint64(), "u64"}, - TypeName{float16(), "fp16"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), - /*is_variation=*/true)); - } - - for (TypeName e : { - TypeName{null(), "null"}, - TypeName{month_interval(), "interval_month"}, - TypeName{day_time_interval(), "interval_day_milli"}, - TypeName{month_day_nano_interval(), "interval_month_day_nano"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), - /*is_variation=*/false)); - } - - // TODO: this is just a placeholder right now. We'll need a YAML file for - // all functions (and prototypes) that Arrow provides that are relevant - // for Substrait, and include mappings for all of them here. See - // ARROW-15535. - for (util::string_view name : { - "add", - }) { - DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); - } - } +namespace { + +struct ExtensionIdRegistryImpl : ExtensionIdRegistry { + virtual ~ExtensionIdRegistryImpl() {} - std::vector Uris() const override { - return {uris_.begin(), uris_.end()}; + std::vector Uris() const override { + return {uris_.begin(), uris_.end()}; + } + + util::optional GetType(const DataType& type) const override { + if (auto index = GetIndex(type_to_index_, &type)) { + return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; } + return {}; + } - util::optional GetType(const DataType& type) const override { - if (auto index = GetIndex(type_to_index_, &type)) { - return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; - } - return {}; + util::optional GetType(Id id, bool is_variation) const override { + if (auto index = GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) { + return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; } + return {}; + } - util::optional GetType(Id id, bool is_variation) const override { - if (auto index = - GetIndex(is_variation ? variation_id_to_index_ : id_to_index_, id)) { - return TypeRecord{type_ids_[*index], types_[*index], type_is_variation_[*index]}; - } - return {}; + Status CanRegisterType(Id id, const std::shared_ptr& type, + bool is_variation) const override { + auto& id_to_index = is_variation ? variation_id_to_index_ : id_to_index_; + if (id_to_index.find(id) != id_to_index.end()) { + return Status::Invalid("Type id was already registered"); + } + if (type_to_index_.find(&*type) != type_to_index_.end()) { + return Status::Invalid("Type was already registered"); } + return Status::OK(); + } - Status RegisterType(Id id, std::shared_ptr type, - bool is_variation) override { - DCHECK_EQ(type_ids_.size(), types_.size()); - DCHECK_EQ(type_ids_.size(), type_is_variation_.size()); + Status RegisterType(Id id, std::shared_ptr type, bool is_variation) override { + DCHECK_EQ(type_ids_.size(), types_.size()); + DCHECK_EQ(type_ids_.size(), type_is_variation_.size()); - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; - auto index = static_cast(type_ids_.size()); + auto index = static_cast(type_ids_.size()); - auto* id_to_index = is_variation ? &variation_id_to_index_ : &id_to_index_; - auto it_success = id_to_index->emplace(copied_id, index); + auto* id_to_index = is_variation ? &variation_id_to_index_ : &id_to_index_; + auto it_success = id_to_index->emplace(copied_id, index); - if (!it_success.second) { - return Status::Invalid("Type id was already registered"); - } + if (!it_success.second) { + return Status::Invalid("Type id was already registered"); + } - if (!type_to_index_.emplace(type.get(), index).second) { - id_to_index->erase(it_success.first); - return Status::Invalid("Type was already registered"); - } + if (!type_to_index_.emplace(type.get(), index).second) { + id_to_index->erase(it_success.first); + return Status::Invalid("Type was already registered"); + } + + type_ids_.push_back(copied_id); + types_.push_back(std::move(type)); + type_is_variation_.push_back(is_variation); + return Status::OK(); + } - type_ids_.push_back(copied_id); - types_.push_back(std::move(type)); - type_is_variation_.push_back(is_variation); - return Status::OK(); + util::optional GetFunction( + util::string_view arrow_function_name) const override { + if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; } + return {}; + } + + util::optional GetFunction(Id id) const override { + if (auto index = GetIndex(function_id_to_index_, id)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; + } + return {}; + } + + Status CanRegisterFunction(Id id, + const std::string& arrow_function_name) const override { + if (function_id_to_index_.find(id) != function_id_to_index_.end()) { + return Status::Invalid("Function id was already registered"); + } + if (function_name_to_index_.find(arrow_function_name) != + function_name_to_index_.end()) { + return Status::Invalid("Function name was already registered"); + } + return Status::OK(); + } + + Status RegisterFunction(Id id, std::string arrow_function_name) override { + DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); + + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; + + const std::string& copied_function_name{ + *function_names_.emplace(std::move(arrow_function_name)).first}; + + auto index = static_cast(function_ids_.size()); + + auto it_success = function_id_to_index_.emplace(copied_id, index); - util::optional GetFunction( - util::string_view arrow_function_name) const override { - if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; - } - return {}; + if (!it_success.second) { + return Status::Invalid("Function id was already registered"); } - util::optional GetFunction(Id id) const override { - if (auto index = GetIndex(function_id_to_index_, id)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; - } - return {}; + if (!function_name_to_index_.emplace(copied_function_name, index).second) { + function_id_to_index_.erase(it_success.first); + return Status::Invalid("Function name was already registered"); } - Status RegisterFunction(Id id, std::string arrow_function_name) override { - DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); + function_name_ptrs_.push_back(&copied_function_name); + function_ids_.push_back(copied_id); + return Status::OK(); + } - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + // owning storage of uris, names, (arrow::)function_names, types + // note that storing strings like this is safe since references into an + // unordered_set are not invalidated on insertion + std::unordered_set uris_, names_, function_names_; + DataTypeVector types_; + std::vector type_is_variation_; + + // non-owning lookup helpers + std::vector type_ids_, function_ids_; + std::unordered_map id_to_index_, variation_id_to_index_; + std::unordered_map type_to_index_; + + std::vector function_name_ptrs_; + std::unordered_map function_id_to_index_; + std::unordered_map + function_name_to_index_; +}; - const std::string& copied_function_name{ - *function_names_.emplace(std::move(arrow_function_name)).first}; +struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { + explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) + : parent_(parent) {} - auto index = static_cast(function_ids_.size()); + virtual ~NestedExtensionIdRegistryImpl() {} - auto it_success = function_id_to_index_.emplace(copied_id, index); + std::vector Uris() const override { + std::vector uris = parent_->Uris(); + std::unordered_set uri_set; + uri_set.insert(uris.begin(), uris.end()); + uri_set.insert(uris_.begin(), uris_.end()); + return std::vector(uris); + } - if (!it_success.second) { - return Status::Invalid("Function id was already registered"); - } + util::optional GetType(const DataType& type) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(type); + if (type_opt) { + return type_opt; + } + return parent_->GetType(type); + } + + util::optional GetType(Id id, bool is_variation) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(id, is_variation); + if (type_opt) { + return type_opt; + } + return parent_->GetType(id, is_variation); + } - if (!function_name_to_index_.emplace(copied_function_name, index).second) { - function_id_to_index_.erase(it_success.first); - return Status::Invalid("Function name was already registered"); - } + Status RegisterType(Id id, std::shared_ptr type, bool is_variation) override { + return parent_->CanRegisterType(id, type, is_variation) & + ExtensionIdRegistryImpl::RegisterType(id, type, is_variation); + } - function_name_ptrs_.push_back(&copied_function_name); - function_ids_.push_back(copied_id); - return Status::OK(); + util::optional GetFunction( + util::string_view arrow_function_name) const override { + auto func_opt = ExtensionIdRegistryImpl::GetFunction(arrow_function_name); + if (func_opt) { + return func_opt; } + return parent_->GetFunction(arrow_function_name); + } - // owning storage of uris, names, (arrow::)function_names, types - // note that storing strings like this is safe since references into an - // unordered_set are not invalidated on insertion - std::unordered_set uris_, names_, function_names_; - DataTypeVector types_; - std::vector type_is_variation_; + util::optional GetFunction(Id id) const override { + auto func_opt = ExtensionIdRegistryImpl::GetFunction(id); + if (func_opt) { + return func_opt; + } + return parent_->GetFunction(id); + } - // non-owning lookup helpers - std::vector type_ids_, function_ids_; - std::unordered_map id_to_index_, variation_id_to_index_; - std::unordered_map type_to_index_; + Status RegisterFunction(Id id, std::string arrow_function_name) override { + return parent_->CanRegisterFunction(id, arrow_function_name) & + ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name); + } - std::vector function_name_ptrs_; - std::unordered_map function_id_to_index_; - std::unordered_map - function_name_to_index_; - } impl_; + const ExtensionIdRegistry* parent_; +}; +struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { + DefaultExtensionIdRegistry() { + struct TypeName { + std::shared_ptr type; + util::string_view name; + }; + + // The type (variation) mappings listed below need to be kept in sync + // with the YAML at substrait/format/extension_types.yaml manually; + // see ARROW-15535. + for (TypeName e : { + TypeName{uint8(), "u8"}, + TypeName{uint16(), "u16"}, + TypeName{uint32(), "u32"}, + TypeName{uint64(), "u64"}, + TypeName{float16(), "fp16"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), + /*is_variation=*/true)); + } + + for (TypeName e : { + TypeName{null(), "null"}, + TypeName{month_interval(), "interval_month"}, + TypeName{day_time_interval(), "interval_day_milli"}, + TypeName{month_day_nano_interval(), "interval_month_day_nano"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type), + /*is_variation=*/false)); + } + + // TODO: this is just a placeholder right now. We'll need a YAML file for + // all functions (and prototypes) that Arrow provides that are relevant + // for Substrait, and include mappings for all of them here. See + // ARROW-15535. + for (util::string_view name : { + "add", + }) { + DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); + } + } +}; + +} // namespace + +ExtensionIdRegistry* default_extension_id_registry() { + static DefaultExtensionIdRegistry impl_; return &impl_; } +std::shared_ptr nested_extension_id_registry( + const ExtensionIdRegistry* parent) { + return std::make_shared(parent); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 951f7ffa3a1..17cbf7d0633 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -63,6 +63,8 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { }; virtual util::optional GetType(const DataType&) const = 0; virtual util::optional GetType(Id, bool is_variation) const = 0; + virtual Status CanRegisterType(Id, const std::shared_ptr& type, + bool is_variation) const = 0; virtual Status RegisterType(Id, std::shared_ptr, bool is_variation) = 0; /// \brief A mapping between a Substrait ID and an Arrow function @@ -84,6 +86,8 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { virtual util::optional GetFunction(Id) const = 0; virtual util::optional GetFunction( util::string_view arrow_function_name) const = 0; + virtual Status CanRegisterFunction(Id, + const std::string& arrow_function_name) const = 0; virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0; }; @@ -96,6 +100,19 @@ constexpr util::string_view kArrowExtTypesUri = /// Note: Function support is currently very minimal, see ARROW-15538 ARROW_ENGINE_EXPORT ExtensionIdRegistry* default_extension_id_registry(); +/// \brief Makes a nested registry with a given parent. +/// +/// A nested registry supports registering types and functions other and on top of those +/// already registered in its parent registry. No conflicts in IDs and names used for +/// lookup are allowed. Normally, the given parent is the default registry. +/// +/// One use case for a nested registry is for dynamic registration of functions defined +/// within a Substrait plan while keeping these registrations specific to the plan. When +/// the Substrait plan is disposed of, normally after its execution, the nested registry +/// can be disposed of as well. +ARROW_ENGINE_EXPORT std::shared_ptr nested_extension_id_registry( + const ExtensionIdRegistry* parent); + /// \brief A set of extensions used within a plan /// /// Each time an extension is used within a Substrait plan the extension @@ -140,7 +157,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { }; /// Construct an empty ExtensionSet to be populated during serialization. - explicit ExtensionSet(ExtensionIdRegistry* = default_extension_id_registry()); + explicit ExtensionSet(const ExtensionIdRegistry* = default_extension_id_registry()); ARROW_DEFAULT_MOVE_AND_ASSIGN(ExtensionSet); /// Construct an ExtensionSet with explicit extension ids for efficient referencing @@ -160,7 +177,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { static Result Make( std::vector uris, std::vector type_ids, std::vector type_is_variation, std::vector function_ids, - ExtensionIdRegistry* = default_extension_id_registry()); + const ExtensionIdRegistry* = default_extension_id_registry()); // index in these vectors == value of _anchor/_reference fields /// TODO(ARROW-15583) this assumes that _anchor/_references won't be huge, which is not @@ -224,7 +241,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { std::size_t num_functions() const { return functions_.size(); } private: - ExtensionIdRegistry* registry_; + const ExtensionIdRegistry* registry_; /// The subset of extension registry URIs referenced by this extension set std::vector uris_; std::vector types_; diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 5813dcde24c..8b3d53ae794 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -111,7 +111,7 @@ void SetElement(size_t i, const Element& element, std::vector* vector) { } // namespace Result GetExtensionSetFromPlan(const substrait::Plan& plan, - ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry) { std::vector uris; for (const auto& uri : plan.extension_uris()) { SetElement(uri.extension_uri_anchor(), uri.uri(), &uris); diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index 281cab0c0f3..dce23cdceba 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -49,7 +49,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) ARROW_ENGINE_EXPORT Result GetExtensionSetFromPlan( const substrait::Plan& plan, - ExtensionIdRegistry* registry = default_extension_id_registry()); + const ExtensionIdRegistry* registry = default_extension_id_registry()); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index bc2aa36856e..2ae3771f3fb 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -123,6 +123,10 @@ Result> SerializeJsonPlan(const std::string& substrait_j return engine::internal::SubstraitFromJSON("Plan", substrait_json); } +std::shared_ptr MakeExtensionIdRegistry() { + return nested_extension_id_registry(default_extension_id_registry()); +} + } // namespace substrait } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 860a459da2f..914da6a8ebf 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -37,6 +37,10 @@ ARROW_ENGINE_EXPORT Result> ExecuteSerialized ARROW_ENGINE_EXPORT Result> SerializeJsonPlan( const std::string& substrait_json); +/// \brief Makes a nested registry with the default registry as parent. +/// See arrow:engine::nested_extension_id_registry for details. +ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry(); + } // namespace substrait } // namespace engine From cb65dcb8ef56e99c82d7934d83487e3fff479d59 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 25 May 2022 12:50:30 -0400 Subject: [PATCH 02/10] fix merge --- cpp/src/arrow/engine/substrait/plan_internal.cc | 1 + testing | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index e382a42b1d0..fcee0b2188f 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -91,6 +91,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) } Result GetExtensionSetFromPlan(const substrait::Plan& plan, + const ExtensionIdRegistry* registry) { std::unordered_map uris; uris.reserve(plan.extension_uris_size()); for (const auto& uri : plan.extension_uris()) { diff --git a/testing b/testing index 53b49804710..634739c6644 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 53b498047109d9940fcfab388bd9d6aeb8c57425 +Subproject commit 634739c664433cec366b4b9a81d1e1044a8c5eda From 4fa003b27c35b86eee5864a94c8b03e68125da81 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 25 May 2022 13:03:55 -0400 Subject: [PATCH 03/10] lint --- cpp/src/arrow/engine/substrait/extension_set.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index b6c5608e33e..a235f1c72ef 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -96,8 +96,7 @@ Status ExtensionSet::AddUri(Id id) { Result ExtensionSet::Make( std::unordered_map uris, std::unordered_map type_ids, - std::unordered_map function_ids, - const ExtensionIdRegistry* registry) { + std::unordered_map function_ids, const ExtensionIdRegistry* registry) { ExtensionSet set; set.registry_ = registry; @@ -228,8 +227,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return {}; } - Status CanRegisterType(Id id, - const std::shared_ptr& type) const override { + Status CanRegisterType(Id id, const std::shared_ptr& type) const override { if (id_to_index_.find(id) != id_to_index_.end()) { return Status::Invalid("Type id was already registered"); } From b0e01125ec50dc71cdd67f7a1dbb1d5bbed6a539 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 26 May 2022 07:55:44 -0400 Subject: [PATCH 04/10] add tests --- cpp/src/arrow/engine/CMakeLists.txt | 1 + cpp/src/arrow/engine/substrait/ext_test.cc | 270 ++++++++++++++++++ .../arrow/engine/substrait/extension_set.cc | 11 + cpp/src/arrow/engine/substrait/util.h | 2 +- 4 files changed, 283 insertions(+), 1 deletion(-) create mode 100644 cpp/src/arrow/engine/substrait/ext_test.cc diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index ea9797ea1d7..8edd22900e6 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -66,6 +66,7 @@ endif() add_arrow_test(substrait_test SOURCES + substrait/ext_test.cc substrait/serde_test.cc EXTRA_LINK_LIBS ${ARROW_SUBSTRAIT_TEST_LINK_LIBS} diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc new file mode 100644 index 00000000000..26be0f4cdf2 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/ext_test.cc @@ -0,0 +1,270 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/substrait/util.h" + +#include +#include +#include +#include + +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" + +using testing::ElementsAre; +using testing::Eq; +using testing::HasSubstr; +using testing::UnorderedElementsAre; + +namespace arrow { + +using internal::checked_cast; + +namespace engine { + +// an extension-id-registry provider to be used as a test parameter +// +// we cannot pass a pointer to a nested registry as a test parameter because the +// shared_ptr in which it is made would not be held and get destructed too early, +// nor can we pass a shared_ptr to the default nested registry as a test parameter +// because it is global and must never be cleaned up, so we pass a shared_ptr to a +// provider that either owns or does not own the registry it provides, depending +// on the case. +struct ExtensionIdRegistryProvider { + virtual ExtensionIdRegistry *get() const = 0; +}; + +struct DefaultExtensionIdRegistryProvider + : public ExtensionIdRegistryProvider { + ExtensionIdRegistry *get() const override { + return default_extension_id_registry(); + } +}; + +struct NestedExtensionIdRegistryProvider + : public ExtensionIdRegistryProvider { + std::shared_ptr registry_ = + substrait::MakeExtensionIdRegistry(); + ExtensionIdRegistry *get() const override { + return &*registry_; + } +}; + +using Id = ExtensionIdRegistry::Id; + +bool operator==(const Id& id1, const Id& id2) { + return id1.uri == id2.uri && id1.name == id2.name; +} + +bool operator!=(const Id& id1, const Id& id2) { + return !(id1 == id2); +} + +struct TypeName { + std::shared_ptr type; + util::string_view name; +}; + +static const std::vector kTypeNames = { + TypeName{uint8(), "u8"}, + TypeName{uint16(), "u16"}, + TypeName{uint32(), "u32"}, + TypeName{uint64(), "u64"}, + TypeName{float16(), "fp16"}, + TypeName{null(), "null"}, + TypeName{month_interval(), "interval_month"}, + TypeName{day_time_interval(), "interval_day_milli"}, + TypeName{month_day_nano_interval(), "interval_month_day_nano"}, +}; + +static const std::vector kFunctionNames = { + "add", +}; + +static const std::vector kTempFunctionNames = { + "temp_func_1", + "temp_func_2", +}; + +static const std::vector kTempTypeNames = { + TypeName{timestamp(TimeUnit::SECOND, "temp_tz_1"), "temp_type_1"}, + TypeName{timestamp(TimeUnit::SECOND, "temp_tz_2"), "temp_type_2"}, +}; + +using ExtensionIdRegistryParams = + std::tuple, std::string>; + +struct ExtensionIdRegistryTest + : public testing::TestWithParam {}; + +TEST_P(ExtensionIdRegistryTest, GetTypes) { + auto provider = std::get<0>(GetParam()); + auto registry = provider->get(); + + for (TypeName e : kTypeNames) { + auto id = Id{kArrowExtTypesUri, e.name}; + for (auto typerec_opt : {registry->GetType(id), registry->GetType(*e.type)}) { + ASSERT_TRUE(typerec_opt); + auto typerec = typerec_opt.value(); + ASSERT_EQ(id, typerec.id); + ASSERT_EQ(*e.type, *typerec.type); + } + } +} + +TEST_P(ExtensionIdRegistryTest, ReregisterTypes) { + auto provider = std::get<0>(GetParam()); + auto registry = provider->get(); + + for (TypeName e : kTypeNames) { + auto id = Id{kArrowExtTypesUri, e.name}; + ASSERT_RAISES(Invalid, registry->CanRegisterType(id, e.type)); + ASSERT_RAISES(Invalid, registry->RegisterType(id, e.type)); + } +} + +TEST_P(ExtensionIdRegistryTest, GetFunctions) { + auto provider = std::get<0>(GetParam()); + auto registry = provider->get(); + + for (util::string_view name : kFunctionNames) { + auto id = Id{kArrowExtTypesUri, name}; + for (auto funcrec_opt : {registry->GetFunction(id), registry->GetFunction(name)}) { + ASSERT_TRUE(funcrec_opt); + auto funcrec = funcrec_opt.value(); + ASSERT_EQ(id, funcrec.id); + ASSERT_EQ(name, funcrec.function_name); + } + } +} + +TEST_P(ExtensionIdRegistryTest, ReregisterFunctions) { + auto provider = std::get<0>(GetParam()); + auto registry = provider->get(); + + for (util::string_view name : kFunctionNames) { + auto id = Id{kArrowExtTypesUri, name}; + ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, name.to_string())); + ASSERT_RAISES(Invalid, registry->RegisterFunction(id, name.to_string())); + } +} + +INSTANTIATE_TEST_SUITE_P( + Substrait, + ExtensionIdRegistryTest, + testing::Values( + std::make_tuple(std::make_shared(), "default"), + std::make_tuple(std::make_shared(), "nested"))); + +TEST(ExtensionIdRegistryTest, RegisterTempTypes) { + auto default_registry = default_extension_id_registry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry = substrait::MakeExtensionIdRegistry(); + + for (TypeName e : kTempTypeNames) { + auto id = Id{kArrowExtTypesUri, e.name}; + ASSERT_OK(registry->CanRegisterType(id, e.type)); + ASSERT_OK(registry->RegisterType(id, e.type)); + ASSERT_RAISES(Invalid, registry->CanRegisterType(id, e.type)); + ASSERT_RAISES(Invalid, registry->RegisterType(id, e.type)); + ASSERT_OK(default_registry->CanRegisterType(id, e.type)); + } + } +} + +TEST(ExtensionIdRegistryTest, RegisterTempFunctions) { + auto default_registry = default_extension_id_registry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry = substrait::MakeExtensionIdRegistry(); + + for (util::string_view name : kTempFunctionNames) { + auto id = Id{kArrowExtTypesUri, name}; + ASSERT_OK(registry->CanRegisterFunction(id, name.to_string())); + ASSERT_OK(registry->RegisterFunction(id, name.to_string())); + ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, name.to_string())); + ASSERT_RAISES(Invalid, registry->RegisterFunction(id, name.to_string())); + ASSERT_OK(default_registry->CanRegisterFunction(id, name.to_string())); + } + } +} + +TEST(ExtensionIdRegistryTest, RegisterNestedTypes) { + std::shared_ptr type1 = kTempTypeNames[0].type; + std::shared_ptr type2 = kTempTypeNames[1].type; + auto id1 = Id{kArrowExtTypesUri, kTempTypeNames[0].name}; + auto id2 = Id{kArrowExtTypesUri, kTempTypeNames[1].name}; + + auto default_registry = default_extension_id_registry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry1 = substrait::MakeExtensionIdRegistry(); + + ASSERT_OK(registry1->CanRegisterType(id1, type1)); + ASSERT_OK(registry1->RegisterType(id1, type1)); + + for (int j = 0; j < rounds; j++) { + auto registry2 = substrait::MakeExtensionIdRegistry(); + + ASSERT_OK(registry2->CanRegisterType(id2, type2)); + ASSERT_OK(registry2->RegisterType(id2, type2)); + ASSERT_RAISES(Invalid, registry2->CanRegisterType(id2, type2)); + ASSERT_RAISES(Invalid, registry2->RegisterType(id2, type2)); + ASSERT_OK(default_registry->CanRegisterType(id2, type2)); + } + + ASSERT_RAISES(Invalid, registry1->CanRegisterType(id1, type1)); + ASSERT_RAISES(Invalid, registry1->RegisterType(id1, type1)); + ASSERT_OK(default_registry->CanRegisterType(id1, type1)); + } +} + +TEST(ExtensionIdRegistryTest, RegisterNestedFunctions) { + util::string_view name1 = kTempFunctionNames[0]; + util::string_view name2 = kTempFunctionNames[1]; + auto id1 = Id{kArrowExtTypesUri, name1}; + auto id2 = Id{kArrowExtTypesUri, name2}; + + auto default_registry = default_extension_id_registry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry1 = substrait::MakeExtensionIdRegistry(); + + ASSERT_OK(registry1->CanRegisterFunction(id1, name1.to_string())); + ASSERT_OK(registry1->RegisterFunction(id1, name1.to_string())); + + for (int j = 0; j < rounds; j++) { + auto registry2 = substrait::MakeExtensionIdRegistry(); + + ASSERT_OK(registry2->CanRegisterFunction(id2, name2.to_string())); + ASSERT_OK(registry2->RegisterFunction(id2, name2.to_string())); + ASSERT_RAISES(Invalid, registry2->CanRegisterFunction(id2, name2.to_string())); + ASSERT_RAISES(Invalid, registry2->RegisterFunction(id2, name2.to_string())); + ASSERT_OK(default_registry->CanRegisterFunction(id2, name2.to_string())); + } + + ASSERT_RAISES(Invalid, registry1->CanRegisterFunction(id1, name1.to_string())); + ASSERT_RAISES(Invalid, registry1->RegisterFunction(id1, name1.to_string())); + ASSERT_OK(default_registry->CanRegisterFunction(id1, name1.to_string())); + } +} + + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index a235f1c72ef..648e019ff72 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -362,6 +362,11 @@ struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { return parent_->GetType(id); } + Status CanRegisterType(Id id, const std::shared_ptr& type) const override { + return parent_->CanRegisterType(id, type) & + ExtensionIdRegistryImpl::CanRegisterType(id, type); + } + Status RegisterType(Id id, std::shared_ptr type) override { return parent_->CanRegisterType(id, type) & ExtensionIdRegistryImpl::RegisterType(id, type); @@ -384,6 +389,12 @@ struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { return parent_->GetFunction(id); } + Status CanRegisterFunction(Id id, + const std::string& arrow_function_name) const override { + return parent_->CanRegisterFunction(id, arrow_function_name) & + ExtensionIdRegistryImpl::CanRegisterFunction(id, arrow_function_name); + } + Status RegisterFunction(Id id, std::string arrow_function_name) override { return parent_->CanRegisterFunction(id, arrow_function_name) & ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name); diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 914da6a8ebf..c7fc4c5d808 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -38,7 +38,7 @@ ARROW_ENGINE_EXPORT Result> SerializeJsonPlan( const std::string& substrait_json); /// \brief Makes a nested registry with the default registry as parent. -/// See arrow:engine::nested_extension_id_registry for details. +/// See arrow::engine::nested_extension_id_registry for details. ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry(); } // namespace substrait From 16a476f227d8eb648e4177322bde065d3f196542 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 26 May 2022 08:17:52 -0400 Subject: [PATCH 05/10] lint --- cpp/src/arrow/engine/substrait/ext_test.cc | 56 ++++++++++------------ 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc index 26be0f4cdf2..0b505b60cf5 100644 --- a/cpp/src/arrow/engine/substrait/ext_test.cc +++ b/cpp/src/arrow/engine/substrait/ext_test.cc @@ -46,23 +46,16 @@ namespace engine { // provider that either owns or does not own the registry it provides, depending // on the case. struct ExtensionIdRegistryProvider { - virtual ExtensionIdRegistry *get() const = 0; + virtual ExtensionIdRegistry* get() const = 0; }; -struct DefaultExtensionIdRegistryProvider - : public ExtensionIdRegistryProvider { - ExtensionIdRegistry *get() const override { - return default_extension_id_registry(); - } +struct DefaultExtensionIdRegistryProvider : public ExtensionIdRegistryProvider { + ExtensionIdRegistry *get() const override { return default_extension_id_registry(); } }; -struct NestedExtensionIdRegistryProvider - : public ExtensionIdRegistryProvider { - std::shared_ptr registry_ = - substrait::MakeExtensionIdRegistry(); - ExtensionIdRegistry *get() const override { - return &*registry_; - } +struct NestedExtensionIdRegistryProvider : public ExtensionIdRegistryProvider { + std::shared_ptr registry_ = substrait::MakeExtensionIdRegistry(); + ExtensionIdRegistry *get() const override { return &*registry_; } }; using Id = ExtensionIdRegistry::Id; @@ -71,9 +64,7 @@ bool operator==(const Id& id1, const Id& id2) { return id1.uri == id2.uri && id1.name == id2.name; } -bool operator!=(const Id& id1, const Id& id2) { - return !(id1 == id2); -} +bool operator!=(const Id& id1, const Id& id2) { return !(id1 == id2); } struct TypeName { std::shared_ptr type; @@ -81,15 +72,15 @@ struct TypeName { }; static const std::vector kTypeNames = { - TypeName{uint8(), "u8"}, - TypeName{uint16(), "u16"}, - TypeName{uint32(), "u32"}, - TypeName{uint64(), "u64"}, - TypeName{float16(), "fp16"}, - TypeName{null(), "null"}, - TypeName{month_interval(), "interval_month"}, - TypeName{day_time_interval(), "interval_day_milli"}, - TypeName{month_day_nano_interval(), "interval_month_day_nano"}, + TypeName{uint8(), "u8"}, + TypeName{uint16(), "u16"}, + TypeName{uint32(), "u32"}, + TypeName{uint64(), "u64"}, + TypeName{float16(), "fp16"}, + TypeName{null(), "null"}, + TypeName{month_interval(), "interval_month"}, + TypeName{day_time_interval(), "interval_day_milli"}, + TypeName{month_day_nano_interval(), "interval_month_day_nano"}, }; static const std::vector kFunctionNames = { @@ -102,12 +93,12 @@ static const std::vector kTempFunctionNames = { }; static const std::vector kTempTypeNames = { - TypeName{timestamp(TimeUnit::SECOND, "temp_tz_1"), "temp_type_1"}, - TypeName{timestamp(TimeUnit::SECOND, "temp_tz_2"), "temp_type_2"}, + TypeName{timestamp(TimeUnit::SECOND, "temp_tz_1"), "temp_type_1"}, + TypeName{timestamp(TimeUnit::SECOND, "temp_tz_2"), "temp_type_2"}, }; using ExtensionIdRegistryParams = - std::tuple, std::string>; + std::tuple, std::string>; struct ExtensionIdRegistryTest : public testing::TestWithParam {}; @@ -165,11 +156,12 @@ TEST_P(ExtensionIdRegistryTest, ReregisterFunctions) { } INSTANTIATE_TEST_SUITE_P( - Substrait, - ExtensionIdRegistryTest, + Substrait, ExtensionIdRegistryTest, testing::Values( - std::make_tuple(std::make_shared(), "default"), - std::make_tuple(std::make_shared(), "nested"))); + std::make_tuple(std::make_shared(), + "default"), + std::make_tuple(std::make_shared(), + "nested"))); TEST(ExtensionIdRegistryTest, RegisterTempTypes) { auto default_registry = default_extension_id_registry(); From 2592ffa65ce336b71c1259b520578fc30ed9f35b Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 26 May 2022 08:56:52 -0400 Subject: [PATCH 06/10] lint --- cpp/src/arrow/engine/substrait/ext_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc index 0b505b60cf5..77690dc6ec4 100644 --- a/cpp/src/arrow/engine/substrait/ext_test.cc +++ b/cpp/src/arrow/engine/substrait/ext_test.cc @@ -50,12 +50,12 @@ struct ExtensionIdRegistryProvider { }; struct DefaultExtensionIdRegistryProvider : public ExtensionIdRegistryProvider { - ExtensionIdRegistry *get() const override { return default_extension_id_registry(); } + ExtensionIdRegistry* get() const override { return default_extension_id_registry(); } }; struct NestedExtensionIdRegistryProvider : public ExtensionIdRegistryProvider { std::shared_ptr registry_ = substrait::MakeExtensionIdRegistry(); - ExtensionIdRegistry *get() const override { return &*registry_; } + ExtensionIdRegistry* get() const override { return &*registry_; } }; using Id = ExtensionIdRegistry::Id; From 76e0581b99efa2deb41a31dc3c5bf50b3caae542 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 26 May 2022 09:10:32 -0400 Subject: [PATCH 07/10] lint --- cpp/src/arrow/engine/substrait/ext_test.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc index 77690dc6ec4..cd02dc37a20 100644 --- a/cpp/src/arrow/engine/substrait/ext_test.cc +++ b/cpp/src/arrow/engine/substrait/ext_test.cc @@ -257,6 +257,5 @@ TEST(ExtensionIdRegistryTest, RegisterNestedFunctions) { } } - } // namespace engine } // namespace arrow From ff1a1ad66bb37ea5ce573289b04f5069153c51b5 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Thu, 26 May 2022 09:57:12 -0400 Subject: [PATCH 08/10] add virtual dtors --- cpp/src/arrow/engine/substrait/ext_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc index cd02dc37a20..ea6954dd00e 100644 --- a/cpp/src/arrow/engine/substrait/ext_test.cc +++ b/cpp/src/arrow/engine/substrait/ext_test.cc @@ -50,10 +50,12 @@ struct ExtensionIdRegistryProvider { }; struct DefaultExtensionIdRegistryProvider : public ExtensionIdRegistryProvider { + virtual ~DefaultExtensionIdRegistryProvider() {} ExtensionIdRegistry* get() const override { return default_extension_id_registry(); } }; struct NestedExtensionIdRegistryProvider : public ExtensionIdRegistryProvider { + virtual ~NestedExtensionIdRegistryProvider() {} std::shared_ptr registry_ = substrait::MakeExtensionIdRegistry(); ExtensionIdRegistry* get() const override { return &*registry_; } }; From e2e7f2b33bcc54b5069b20e52ef0363d8452e123 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Fri, 27 May 2022 03:41:44 -0400 Subject: [PATCH 09/10] fix nested test --- cpp/src/arrow/engine/substrait/ext_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc index ea6954dd00e..482212d75a6 100644 --- a/cpp/src/arrow/engine/substrait/ext_test.cc +++ b/cpp/src/arrow/engine/substrait/ext_test.cc @@ -208,13 +208,13 @@ TEST(ExtensionIdRegistryTest, RegisterNestedTypes) { auto default_registry = default_extension_id_registry(); constexpr int rounds = 3; for (int i = 0; i < rounds; i++) { - auto registry1 = substrait::MakeExtensionIdRegistry(); + auto registry1 = nested_extension_id_registry(default_registry); ASSERT_OK(registry1->CanRegisterType(id1, type1)); ASSERT_OK(registry1->RegisterType(id1, type1)); for (int j = 0; j < rounds; j++) { - auto registry2 = substrait::MakeExtensionIdRegistry(); + auto registry2 = nested_extension_id_registry(&*registry1); ASSERT_OK(registry2->CanRegisterType(id2, type2)); ASSERT_OK(registry2->RegisterType(id2, type2)); From 0eb19147bcce6e09798309019baca3be4531bcec Mon Sep 17 00:00:00 2001 From: Vibhatha Abeykoon Date: Thu, 2 Jun 2022 19:44:52 +0530 Subject: [PATCH 10/10] submodule update --- testing | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing b/testing index 634739c6644..53b49804710 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 634739c664433cec366b4b9a81d1e1044a8c5eda +Subproject commit 53b498047109d9940fcfab388bd9d6aeb8c57425