diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index ea9797ea1d7..8edd22900e6 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -66,6 +66,7 @@ endif() add_arrow_test(substrait_test SOURCES + substrait/ext_test.cc substrait/serde_test.cc EXTRA_LINK_LIBS ${ARROW_SUBSTRAIT_TEST_LINK_LIBS} diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc new file mode 100644 index 00000000000..482212d75a6 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/ext_test.cc @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/substrait/util.h" + +#include +#include +#include +#include + +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" + +using testing::ElementsAre; +using testing::Eq; +using testing::HasSubstr; +using testing::UnorderedElementsAre; + +namespace arrow { + +using internal::checked_cast; + +namespace engine { + +// an extension-id-registry provider to be used as a test parameter +// +// we cannot pass a pointer to a nested registry as a test parameter because the +// shared_ptr in which it is made would not be held and get destructed too early, +// nor can we pass a shared_ptr to the default nested registry as a test parameter +// because it is global and must never be cleaned up, so we pass a shared_ptr to a +// provider that either owns or does not own the registry it provides, depending +// on the case. +struct ExtensionIdRegistryProvider { + virtual ExtensionIdRegistry* get() const = 0; +}; + +struct DefaultExtensionIdRegistryProvider : public ExtensionIdRegistryProvider { + virtual ~DefaultExtensionIdRegistryProvider() {} + ExtensionIdRegistry* get() const override { return default_extension_id_registry(); } +}; + +struct NestedExtensionIdRegistryProvider : public ExtensionIdRegistryProvider { + virtual ~NestedExtensionIdRegistryProvider() {} + std::shared_ptr registry_ = substrait::MakeExtensionIdRegistry(); + ExtensionIdRegistry* get() const override { return &*registry_; } +}; + +using Id = ExtensionIdRegistry::Id; + +bool operator==(const Id& id1, const Id& id2) { + return id1.uri == id2.uri && id1.name == id2.name; +} + +bool operator!=(const Id& id1, const Id& id2) { return !(id1 == id2); } + +struct TypeName { + std::shared_ptr type; + util::string_view name; +}; + +static const std::vector kTypeNames = { + TypeName{uint8(), "u8"}, + TypeName{uint16(), "u16"}, + TypeName{uint32(), "u32"}, + TypeName{uint64(), "u64"}, + TypeName{float16(), "fp16"}, + 
TypeName{null(), "null"}, + TypeName{month_interval(), "interval_month"}, + TypeName{day_time_interval(), "interval_day_milli"}, + TypeName{month_day_nano_interval(), "interval_month_day_nano"}, +}; + +static const std::vector kFunctionNames = { + "add", +}; + +static const std::vector kTempFunctionNames = { + "temp_func_1", + "temp_func_2", +}; + +static const std::vector kTempTypeNames = { + TypeName{timestamp(TimeUnit::SECOND, "temp_tz_1"), "temp_type_1"}, + TypeName{timestamp(TimeUnit::SECOND, "temp_tz_2"), "temp_type_2"}, +}; + +using ExtensionIdRegistryParams = + std::tuple, std::string>; + +struct ExtensionIdRegistryTest + : public testing::TestWithParam {}; + +TEST_P(ExtensionIdRegistryTest, GetTypes) { + auto provider = std::get<0>(GetParam()); + auto registry = provider->get(); + + for (TypeName e : kTypeNames) { + auto id = Id{kArrowExtTypesUri, e.name}; + for (auto typerec_opt : {registry->GetType(id), registry->GetType(*e.type)}) { + ASSERT_TRUE(typerec_opt); + auto typerec = typerec_opt.value(); + ASSERT_EQ(id, typerec.id); + ASSERT_EQ(*e.type, *typerec.type); + } + } +} + +TEST_P(ExtensionIdRegistryTest, ReregisterTypes) { + auto provider = std::get<0>(GetParam()); + auto registry = provider->get(); + + for (TypeName e : kTypeNames) { + auto id = Id{kArrowExtTypesUri, e.name}; + ASSERT_RAISES(Invalid, registry->CanRegisterType(id, e.type)); + ASSERT_RAISES(Invalid, registry->RegisterType(id, e.type)); + } +} + +TEST_P(ExtensionIdRegistryTest, GetFunctions) { + auto provider = std::get<0>(GetParam()); + auto registry = provider->get(); + + for (util::string_view name : kFunctionNames) { + auto id = Id{kArrowExtTypesUri, name}; + for (auto funcrec_opt : {registry->GetFunction(id), registry->GetFunction(name)}) { + ASSERT_TRUE(funcrec_opt); + auto funcrec = funcrec_opt.value(); + ASSERT_EQ(id, funcrec.id); + ASSERT_EQ(name, funcrec.function_name); + } + } +} + +TEST_P(ExtensionIdRegistryTest, ReregisterFunctions) { + auto provider = 
std::get<0>(GetParam()); + auto registry = provider->get(); + + for (util::string_view name : kFunctionNames) { + auto id = Id{kArrowExtTypesUri, name}; + ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, name.to_string())); + ASSERT_RAISES(Invalid, registry->RegisterFunction(id, name.to_string())); + } +} + +INSTANTIATE_TEST_SUITE_P( + Substrait, ExtensionIdRegistryTest, + testing::Values( + std::make_tuple(std::make_shared(), + "default"), + std::make_tuple(std::make_shared(), + "nested"))); + +TEST(ExtensionIdRegistryTest, RegisterTempTypes) { + auto default_registry = default_extension_id_registry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry = substrait::MakeExtensionIdRegistry(); + + for (TypeName e : kTempTypeNames) { + auto id = Id{kArrowExtTypesUri, e.name}; + ASSERT_OK(registry->CanRegisterType(id, e.type)); + ASSERT_OK(registry->RegisterType(id, e.type)); + ASSERT_RAISES(Invalid, registry->CanRegisterType(id, e.type)); + ASSERT_RAISES(Invalid, registry->RegisterType(id, e.type)); + ASSERT_OK(default_registry->CanRegisterType(id, e.type)); + } + } +} + +TEST(ExtensionIdRegistryTest, RegisterTempFunctions) { + auto default_registry = default_extension_id_registry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry = substrait::MakeExtensionIdRegistry(); + + for (util::string_view name : kTempFunctionNames) { + auto id = Id{kArrowExtTypesUri, name}; + ASSERT_OK(registry->CanRegisterFunction(id, name.to_string())); + ASSERT_OK(registry->RegisterFunction(id, name.to_string())); + ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, name.to_string())); + ASSERT_RAISES(Invalid, registry->RegisterFunction(id, name.to_string())); + ASSERT_OK(default_registry->CanRegisterFunction(id, name.to_string())); + } + } +} + +TEST(ExtensionIdRegistryTest, RegisterNestedTypes) { + std::shared_ptr type1 = kTempTypeNames[0].type; + std::shared_ptr type2 = kTempTypeNames[1].type; + auto id1 = 
Id{kArrowExtTypesUri, kTempTypeNames[0].name}; + auto id2 = Id{kArrowExtTypesUri, kTempTypeNames[1].name}; + + auto default_registry = default_extension_id_registry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry1 = nested_extension_id_registry(default_registry); + + ASSERT_OK(registry1->CanRegisterType(id1, type1)); + ASSERT_OK(registry1->RegisterType(id1, type1)); + + for (int j = 0; j < rounds; j++) { + auto registry2 = nested_extension_id_registry(&*registry1); + + ASSERT_OK(registry2->CanRegisterType(id2, type2)); + ASSERT_OK(registry2->RegisterType(id2, type2)); + ASSERT_RAISES(Invalid, registry2->CanRegisterType(id2, type2)); + ASSERT_RAISES(Invalid, registry2->RegisterType(id2, type2)); + ASSERT_OK(default_registry->CanRegisterType(id2, type2)); + } + + ASSERT_RAISES(Invalid, registry1->CanRegisterType(id1, type1)); + ASSERT_RAISES(Invalid, registry1->RegisterType(id1, type1)); + ASSERT_OK(default_registry->CanRegisterType(id1, type1)); + } +} + +TEST(ExtensionIdRegistryTest, RegisterNestedFunctions) { + util::string_view name1 = kTempFunctionNames[0]; + util::string_view name2 = kTempFunctionNames[1]; + auto id1 = Id{kArrowExtTypesUri, name1}; + auto id2 = Id{kArrowExtTypesUri, name2}; + + auto default_registry = default_extension_id_registry(); + constexpr int rounds = 3; + for (int i = 0; i < rounds; i++) { + auto registry1 = substrait::MakeExtensionIdRegistry(); + + ASSERT_OK(registry1->CanRegisterFunction(id1, name1.to_string())); + ASSERT_OK(registry1->RegisterFunction(id1, name1.to_string())); + + for (int j = 0; j < rounds; j++) { + auto registry2 = nested_extension_id_registry(&*registry1); // child of registry1, so two nesting levels are exercised as in RegisterNestedTypes + + ASSERT_OK(registry2->CanRegisterFunction(id2, name2.to_string())); + ASSERT_OK(registry2->RegisterFunction(id2, name2.to_string())); + ASSERT_RAISES(Invalid, registry2->CanRegisterFunction(id2, name2.to_string())); + ASSERT_RAISES(Invalid, registry2->RegisterFunction(id2, name2.to_string())); +
ASSERT_OK(default_registry->CanRegisterFunction(id2, name2.to_string())); + } + + ASSERT_RAISES(Invalid, registry1->CanRegisterFunction(id1, name1.to_string())); + ASSERT_RAISES(Invalid, registry1->RegisterFunction(id1, name1.to_string())); + ASSERT_OK(default_registry->CanRegisterFunction(id1, name1.to_string())); + } +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index cd85678a72c..a30c740b181 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -57,7 +57,7 @@ bool ExtensionIdRegistry::IdHashEq::operator()(ExtensionIdRegistry::Id l, // A builder used when creating a Substrait plan from an Arrow execution plan. In // that situation we do not have a set of anchor values already defined so we keep // a map of what Ids we have seen. -ExtensionSet::ExtensionSet(ExtensionIdRegistry* registry) : registry_(registry) {} +ExtensionSet::ExtensionSet(const ExtensionIdRegistry* registry) : registry_(registry) {} Status ExtensionSet::CheckHasUri(util::string_view uri) { auto it = @@ -96,7 +96,7 @@ Status ExtensionSet::AddUri(Id id) { Result ExtensionSet::Make( std::unordered_map uris, std::unordered_map type_ids, - std::unordered_map function_ids, ExtensionIdRegistry* registry) { + std::unordered_map function_ids, const ExtensionIdRegistry* registry) { ExtensionSet set; set.registry_ = registry; @@ -204,152 +204,259 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { return &it->second; } -ExtensionIdRegistry* default_extension_id_registry() { - static struct Impl : ExtensionIdRegistry { - Impl() { - struct TypeName { - std::shared_ptr type; - util::string_view name; - }; - - // The type (variation) mappings listed below need to be kept in sync - // with the YAML at substrait/format/extension_types.yaml manually; - // see ARROW-15535. 
- for (TypeName e : { - TypeName{uint8(), "u8"}, - TypeName{uint16(), "u16"}, - TypeName{uint32(), "u32"}, - TypeName{uint64(), "u64"}, - TypeName{float16(), "fp16"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type))); - } - - for (TypeName e : { - TypeName{null(), "null"}, - TypeName{month_interval(), "interval_month"}, - TypeName{day_time_interval(), "interval_day_milli"}, - TypeName{month_day_nano_interval(), "interval_month_day_nano"}, - }) { - DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type))); - } - - // TODO: this is just a placeholder right now. We'll need a YAML file for - // all functions (and prototypes) that Arrow provides that are relevant - // for Substrait, and include mappings for all of them here. See - // ARROW-15535. - for (util::string_view name : { - "add", - "equal", - "is_not_distinct_from", - }) { - DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); - } +namespace { + +struct ExtensionIdRegistryImpl : ExtensionIdRegistry { + virtual ~ExtensionIdRegistryImpl() {} + + std::vector Uris() const override { + return {uris_.begin(), uris_.end()}; + } + + util::optional GetType(const DataType& type) const override { + if (auto index = GetIndex(type_to_index_, &type)) { + return TypeRecord{type_ids_[*index], types_[*index]}; + } + return {}; + } + + util::optional GetType(Id id) const override { + if (auto index = GetIndex(id_to_index_, id)) { + return TypeRecord{type_ids_[*index], types_[*index]}; + } + return {}; + } + + Status CanRegisterType(Id id, const std::shared_ptr& type) const override { + if (id_to_index_.find(id) != id_to_index_.end()) { + return Status::Invalid("Type id was already registered"); + } + if (type_to_index_.find(&*type) != type_to_index_.end()) { + return Status::Invalid("Type was already registered"); + } + return Status::OK(); + } + + Status RegisterType(Id id, std::shared_ptr type) override { + DCHECK_EQ(type_ids_.size(), types_.size()); + + Id 
copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; + + auto index = static_cast(type_ids_.size()); + + auto it_success = id_to_index_.emplace(copied_id, index); + + if (!it_success.second) { + return Status::Invalid("Type id was already registered"); + } + + if (!type_to_index_.emplace(type.get(), index).second) { + id_to_index_.erase(it_success.first); + return Status::Invalid("Type was already registered"); } - std::vector Uris() const override { - return {uris_.begin(), uris_.end()}; + type_ids_.push_back(copied_id); + types_.push_back(std::move(type)); + return Status::OK(); + } + + util::optional GetFunction( + util::string_view arrow_function_name) const override { + if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; } + return {}; + } - util::optional GetType(const DataType& type) const override { - if (auto index = GetIndex(type_to_index_, &type)) { - return TypeRecord{type_ids_[*index], types_[*index]}; - } - return {}; + util::optional GetFunction(Id id) const override { + if (auto index = GetIndex(function_id_to_index_, id)) { + return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; } + return {}; + } - util::optional GetType(Id id) const override { - if (auto index = GetIndex(id_to_index_, id)) { - return TypeRecord{type_ids_[*index], types_[*index]}; - } - return {}; + Status CanRegisterFunction(Id id, + const std::string& arrow_function_name) const override { + if (function_id_to_index_.find(id) != function_id_to_index_.end()) { + return Status::Invalid("Function id was already registered"); + } + if (function_name_to_index_.find(arrow_function_name) != + function_name_to_index_.end()) { + return Status::Invalid("Function name was already registered"); } + return Status::OK(); + } - Status RegisterType(Id id, std::shared_ptr type) override { - DCHECK_EQ(type_ids_.size(), 
types_.size()); + Status RegisterFunction(Id id, std::string arrow_function_name) override { + DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + Id copied_id{*uris_.emplace(id.uri.to_string()).first, + *names_.emplace(id.name.to_string()).first}; - auto index = static_cast(type_ids_.size()); + const std::string& copied_function_name{ + *function_names_.emplace(std::move(arrow_function_name)).first}; - auto it_success = id_to_index_.emplace(copied_id, index); + auto index = static_cast(function_ids_.size()); - if (!it_success.second) { - return Status::Invalid("Type id was already registered"); - } + auto it_success = function_id_to_index_.emplace(copied_id, index); - if (!type_to_index_.emplace(type.get(), index).second) { - id_to_index_.erase(it_success.first); - return Status::Invalid("Type was already registered"); - } + if (!it_success.second) { + return Status::Invalid("Function id was already registered"); + } - type_ids_.push_back(copied_id); - types_.push_back(std::move(type)); - return Status::OK(); + if (!function_name_to_index_.emplace(copied_function_name, index).second) { + function_id_to_index_.erase(it_success.first); + return Status::Invalid("Function name was already registered"); } - util::optional GetFunction( - util::string_view arrow_function_name) const override { - if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; - } - return {}; + function_name_ptrs_.push_back(&copied_function_name); + function_ids_.push_back(copied_id); + return Status::OK(); + } + + // owning storage of uris, names, (arrow::)function_names, types + // note that storing strings like this is safe since references into an + // unordered_set are not invalidated on insertion + std::unordered_set uris_, names_, function_names_; + DataTypeVector types_; + + // 
non-owning lookup helpers + std::vector type_ids_, function_ids_; + std::unordered_map id_to_index_; + std::unordered_map type_to_index_; + + std::vector function_name_ptrs_; + std::unordered_map function_id_to_index_; + std::unordered_map + function_name_to_index_; +}; + +struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { + explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) + : parent_(parent) {} + + virtual ~NestedExtensionIdRegistryImpl() {} + + std::vector Uris() const override { + std::vector uris = parent_->Uris(); + std::unordered_set uri_set; + uri_set.insert(uris.begin(), uris.end()); + uri_set.insert(uris_.begin(), uris_.end()); + return {uri_set.begin(), uri_set.end()}; // de-duplicated union of the parent URIs and the locally registered URIs + } + + util::optional GetType(const DataType& type) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(type); + if (type_opt) { + return type_opt; } + return parent_->GetType(type); + } - util::optional GetFunction(Id id) const override { - if (auto index = GetIndex(function_id_to_index_, id)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; - } - return {}; + util::optional GetType(Id id) const override { + auto type_opt = ExtensionIdRegistryImpl::GetType(id); + if (type_opt) { + return type_opt; } + return parent_->GetType(id); + } - Status RegisterFunction(Id id, std::string arrow_function_name) override { - DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); + Status CanRegisterType(Id id, const std::shared_ptr& type) const override { + return parent_->CanRegisterType(id, type) & + ExtensionIdRegistryImpl::CanRegisterType(id, type); + } - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + Status RegisterType(Id id, std::shared_ptr type) override { + ARROW_RETURN_NOT_OK(parent_->CanRegisterType(id, type)); // return before mutating local state when the parent rejects; operator& would evaluate both sides + return ExtensionIdRegistryImpl::RegisterType(id, std::move(type)); + } - const std::string& copied_function_name{ - *function_names_.emplace(std::move(arrow_function_name)).first};
+ util::optional GetFunction( + util::string_view arrow_function_name) const override { + auto func_opt = ExtensionIdRegistryImpl::GetFunction(arrow_function_name); + if (func_opt) { + return func_opt; + } + return parent_->GetFunction(arrow_function_name); + } - auto index = static_cast(function_ids_.size()); + util::optional GetFunction(Id id) const override { + auto func_opt = ExtensionIdRegistryImpl::GetFunction(id); + if (func_opt) { + return func_opt; + } + return parent_->GetFunction(id); + } - auto it_success = function_id_to_index_.emplace(copied_id, index); + Status CanRegisterFunction(Id id, + const std::string& arrow_function_name) const override { + return parent_->CanRegisterFunction(id, arrow_function_name) & + ExtensionIdRegistryImpl::CanRegisterFunction(id, arrow_function_name); + } - if (!it_success.second) { - return Status::Invalid("Function id was already registered"); - } + Status RegisterFunction(Id id, std::string arrow_function_name) override { + ARROW_RETURN_NOT_OK(parent_->CanRegisterFunction(id, arrow_function_name)); // same guard as RegisterType: no local registration after a parent rejection + return ExtensionIdRegistryImpl::RegisterFunction(id, std::move(arrow_function_name)); + } - if (!function_name_to_index_.emplace(copied_function_name, index).second) { - function_id_to_index_.erase(it_success.first); - return Status::Invalid("Function name was already registered"); - } + const ExtensionIdRegistry* parent_; +}; - function_name_ptrs_.push_back(&copied_function_name); - function_ids_.push_back(copied_id); - return Status::OK(); +struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { + DefaultExtensionIdRegistry() { + struct TypeName { + std::shared_ptr type; + util::string_view name; + }; + + // The type (variation) mappings listed below need to be kept in sync + // with the YAML at substrait/format/extension_types.yaml manually; + // see ARROW-15535.
+ for (TypeName e : { + TypeName{uint8(), "u8"}, + TypeName{uint16(), "u16"}, + TypeName{uint32(), "u32"}, + TypeName{uint64(), "u64"}, + TypeName{float16(), "fp16"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type))); } - // owning storage of uris, names, (arrow::)function_names, types - // note that storing strings like this is safe since references into an - // unordered_set are not invalidated on insertion - std::unordered_set uris_, names_, function_names_; - DataTypeVector types_; + for (TypeName e : { + TypeName{null(), "null"}, + TypeName{month_interval(), "interval_month"}, + TypeName{day_time_interval(), "interval_day_milli"}, + TypeName{month_day_nano_interval(), "interval_month_day_nano"}, + }) { + DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type))); + } - // non-owning lookup helpers - std::vector type_ids_, function_ids_; - std::unordered_map id_to_index_; - std::unordered_map type_to_index_; + // TODO: this is just a placeholder right now. We'll need a YAML file for + // all functions (and prototypes) that Arrow provides that are relevant + // for Substrait, and include mappings for all of them here. See + // ARROW-15535. 
+ for (util::string_view name : { + "add", + "equal", + "is_not_distinct_from", + }) { + DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); + } + } +}; - std::vector function_name_ptrs_; - std::unordered_map function_id_to_index_; - std::unordered_map - function_name_to_index_; - } impl_; +} // namespace +ExtensionIdRegistry* default_extension_id_registry() { + static DefaultExtensionIdRegistry impl_; return &impl_; } +std::shared_ptr nested_extension_id_registry( + const ExtensionIdRegistry* parent) { + return std::make_shared(parent); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 55ea4d02324..638a354c6f2 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -70,6 +70,7 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { }; virtual util::optional GetType(const DataType&) const = 0; virtual util::optional GetType(Id) const = 0; + virtual Status CanRegisterType(Id, const std::shared_ptr& type) const = 0; virtual Status RegisterType(Id, std::shared_ptr) = 0; /// \brief A mapping between a Substrait ID and an Arrow function @@ -91,6 +92,8 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { virtual util::optional GetFunction(Id) const = 0; virtual util::optional GetFunction( util::string_view arrow_function_name) const = 0; + virtual Status CanRegisterFunction(Id, + const std::string& arrow_function_name) const = 0; virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0; }; @@ -103,6 +106,19 @@ constexpr util::string_view kArrowExtTypesUri = /// Note: Function support is currently very minimal, see ARROW-15538 ARROW_ENGINE_EXPORT ExtensionIdRegistry* default_extension_id_registry(); +/// \brief Makes a nested registry with a given parent. 
+/// +/// A nested registry supports registering types and functions in addition to those +/// already registered in its parent registry. No conflicts in IDs and names used for +/// lookup are allowed. Normally, the given parent is the default registry. +/// +/// One use case for a nested registry is for dynamic registration of functions defined +/// within a Substrait plan while keeping these registrations specific to the plan. When +/// the Substrait plan is disposed of, normally after its execution, the nested registry +/// can be disposed of as well. +ARROW_ENGINE_EXPORT std::shared_ptr nested_extension_id_registry( + const ExtensionIdRegistry* parent); + /// \brief A set of extensions used within a plan /// /// Each time an extension is used within a Substrait plan the extension @@ -147,7 +163,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { }; /// Construct an empty ExtensionSet to be populated during serialization. - explicit ExtensionSet(ExtensionIdRegistry* = default_extension_id_registry()); + explicit ExtensionSet(const ExtensionIdRegistry* = default_extension_id_registry()); ARROW_DEFAULT_MOVE_AND_ASSIGN(ExtensionSet); /// Construct an ExtensionSet with explicit extension ids for efficient referencing @@ -168,7 +184,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { std::unordered_map uris, std::unordered_map type_ids, std::unordered_map function_ids, - ExtensionIdRegistry* = default_extension_id_registry()); + const ExtensionIdRegistry* = default_extension_id_registry()); const std::unordered_map& uris() const { return uris_; } @@ -229,7 +245,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { std::size_t num_functions() const { return functions_.size(); } private: - ExtensionIdRegistry* registry_; + const ExtensionIdRegistry* registry_; // Map from anchor values to URI values referenced by this extension set std::unordered_map uris_; diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index
cc20de6da64..fcee0b2188f 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -91,7 +91,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) } Result GetExtensionSetFromPlan(const substrait::Plan& plan, - ExtensionIdRegistry* registry) { + const ExtensionIdRegistry* registry) { std::unordered_map uris; uris.reserve(plan.extension_uris_size()); for (const auto& uri : plan.extension_uris()) { diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index 281cab0c0f3..dce23cdceba 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -49,7 +49,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) ARROW_ENGINE_EXPORT Result GetExtensionSetFromPlan( const substrait::Plan& plan, - ExtensionIdRegistry* registry = default_extension_id_registry()); + const ExtensionIdRegistry* registry = default_extension_id_registry()); } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index bc2aa36856e..2ae3771f3fb 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -123,6 +123,10 @@ Result> SerializeJsonPlan(const std::string& substrait_j return engine::internal::SubstraitFromJSON("Plan", substrait_json); } +std::shared_ptr MakeExtensionIdRegistry() { + return nested_extension_id_registry(default_extension_id_registry()); +} + } // namespace substrait } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 860a459da2f..c7fc4c5d808 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -37,6 +37,10 @@ ARROW_ENGINE_EXPORT Result> ExecuteSerialized ARROW_ENGINE_EXPORT Result> SerializeJsonPlan( const std::string& 
substrait_json); +/// \brief Makes a nested registry with the default registry as parent. +/// See arrow::engine::nested_extension_id_registry for details. +ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry(); + } // namespace substrait } // namespace engine diff --git a/testing b/testing index 53b49804710..634739c6644 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 53b498047109d9940fcfab388bd9d6aeb8c57425 +Subproject commit 634739c664433cec366b4b9a81d1e1044a8c5eda