From 6ba776fdbad071f08a0c506a7c45925101a7e402 Mon Sep 17 00:00:00 2001 From: Andreas Huber Date: Tue, 21 Mar 2023 04:06:35 -0700 Subject: [PATCH] fixup add useConstFeatures parameter to oneapi interface --- .../backend/cpu/train_kernel_cls.cpp | 1 + .../backend/cpu/train_kernel_reg.cpp | 1 + .../backend/gpu/train_kernel_hist_impl_dpc.cpp | 1 + .../backend/gpu/train_misc_structs.hpp | 1 + cpp/oneapi/dal/algo/decision_forest/common.cpp | 11 +++++++++++ cpp/oneapi/dal/algo/decision_forest/common.hpp | 14 ++++++++++++++ 6 files changed, 29 insertions(+) diff --git a/cpp/oneapi/dal/algo/decision_forest/backend/cpu/train_kernel_cls.cpp b/cpp/oneapi/dal/algo/decision_forest/backend/cpu/train_kernel_cls.cpp index 290ba43d6ed..467d3f3cf6d 100644 --- a/cpp/oneapi/dal/algo/decision_forest/backend/cpu/train_kernel_cls.cpp +++ b/cpp/oneapi/dal/algo/decision_forest/backend/cpu/train_kernel_cls.cpp @@ -91,6 +91,7 @@ static result_t call_daal_kernel(const context_cpu& ctx, dal::detail::integral_cast(desc.get_max_leaf_nodes()); daal_parameter.maxBins = dal::detail::integral_cast(desc.get_max_bins()); daal_parameter.minBinSize = dal::detail::integral_cast(desc.get_min_bin_size()); + daal_parameter.useConstFeatures = desc.get_use_const_features(); daal_parameter.resultsToCompute = static_cast(desc.get_error_metric_mode()); diff --git a/cpp/oneapi/dal/algo/decision_forest/backend/cpu/train_kernel_reg.cpp b/cpp/oneapi/dal/algo/decision_forest/backend/cpu/train_kernel_reg.cpp index 2e29a6dfc70..26cf4d4ae57 100644 --- a/cpp/oneapi/dal/algo/decision_forest/backend/cpu/train_kernel_reg.cpp +++ b/cpp/oneapi/dal/algo/decision_forest/backend/cpu/train_kernel_reg.cpp @@ -90,6 +90,7 @@ static result_t call_daal_kernel(const context_cpu& ctx, dal::detail::integral_cast(desc.get_max_leaf_nodes()); daal_parameter.maxBins = dal::detail::integral_cast(desc.get_max_bins()); daal_parameter.minBinSize = dal::detail::integral_cast(desc.get_min_bin_size()); + daal_parameter.useConstFeatures = desc.get_use_const_features(); daal_parameter.resultsToCompute = static_cast(desc.get_error_metric_mode()); diff --git a/cpp/oneapi/dal/algo/decision_forest/backend/gpu/train_kernel_hist_impl_dpc.cpp b/cpp/oneapi/dal/algo/decision_forest/backend/gpu/train_kernel_hist_impl_dpc.cpp index b0bee12e38a..6e8df90dbaf 100644 --- a/cpp/oneapi/dal/algo/decision_forest/backend/gpu/train_kernel_hist_impl_dpc.cpp +++ b/cpp/oneapi/dal/algo/decision_forest/backend/gpu/train_kernel_hist_impl_dpc.cpp @@ -173,6 +173,7 @@ void train_kernel_hist_impl::init_params(train_context_ ctx.bootstrap_ = desc.get_bootstrap(); ctx.max_tree_depth_ = desc.get_max_tree_depth(); + ctx.use_const_features_ = desc.get_use_const_features(); if constexpr (std::is_same_v) { ctx.selected_ftr_count_ = desc.get_features_per_node() ? desc.get_features_per_node() diff --git a/cpp/oneapi/dal/algo/decision_forest/backend/gpu/train_misc_structs.hpp b/cpp/oneapi/dal/algo/decision_forest/backend/gpu/train_misc_structs.hpp index 6895d237da7..25977a62d8a 100644 --- a/cpp/oneapi/dal/algo/decision_forest/backend/gpu/train_misc_structs.hpp +++ b/cpp/oneapi/dal/algo/decision_forest/backend/gpu/train_misc_structs.hpp @@ -136,6 +136,7 @@ struct train_context { bool bootstrap_ = false; bool use_private_mem_buf_ = true; // valuable for classification only // for switching between private mem and other buffers(local, global) for storing class hist + bool use_const_features_ = false; Index total_bin_count_ = 0; Index max_bin_count_among_ftrs_ = 0; diff --git a/cpp/oneapi/dal/algo/decision_forest/common.cpp b/cpp/oneapi/dal/algo/decision_forest/common.cpp index a5c574c70ef..78f220091f4 100644 --- a/cpp/oneapi/dal/algo/decision_forest/common.cpp +++ b/cpp/oneapi/dal/algo/decision_forest/common.cpp @@ -63,6 +63,7 @@ class descriptor_impl : public base { bool memory_saving_mode = false; bool bootstrap = true; + bool use_const_features = false; variable_importance_mode variable_importance_mode_value = variable_importance_mode::none; voting_mode voting_mode_value = voting_mode::weighted; @@ -148,6 +149,11 @@ bool descriptor_base::get_bootstrap() const { return impl_->bootstrap; } +template +bool descriptor_base::get_use_const_features() const { + return impl_->use_const_features; +} + template variable_importance_mode descriptor_base::get_variable_importance_mode() const { return impl_->variable_importance_mode_value; @@ -267,6 +273,11 @@ void descriptor_base::set_bootstrap_impl(bool value) { impl_->bootstrap = value; } +template +void descriptor_base::set_use_const_features_impl(bool value) { + impl_->use_const_features = value; +} + template void descriptor_base::set_variable_importance_mode_impl(variable_importance_mode value) { impl_->variable_importance_mode_value = value; diff --git a/cpp/oneapi/dal/algo/decision_forest/common.hpp b/cpp/oneapi/dal/algo/decision_forest/common.hpp index 319a37a7d96..bba0cde2a34 100644 --- a/cpp/oneapi/dal/algo/decision_forest/common.hpp +++ b/cpp/oneapi/dal/algo/decision_forest/common.hpp @@ -212,6 +212,7 @@ class descriptor_base : public base { std::int64_t get_min_bin_size() const; bool get_memory_saving_mode() const; bool get_bootstrap() const; + bool get_use_const_features() const; error_metric_mode get_error_metric_mode() const; variable_importance_mode get_variable_importance_mode() const; @@ -247,6 +248,7 @@ class descriptor_base : public base { void set_min_bin_size_impl(std::int64_t value); void set_memory_saving_mode_impl(bool value); void set_bootstrap_impl(bool value); + void set_use_const_features_impl(bool value); void set_error_metric_mode_impl(error_metric_mode value); void set_variable_importance_mode_impl(variable_importance_mode value); void set_class_count_impl(std::int64_t value); @@ -503,6 +505,18 @@ class descriptor : public detail::descriptor_base { return *this; } + /// The use constant features mode, if true, constant-valued features are + /// considered for node splits + /// @remark default = false + bool get_use_const_features() const { + return base_t::get_use_const_features(); + } + + auto& set_use_const_features(bool value) { + base_t::set_use_const_features_impl(value); + return *this; + } + /// The error metric mode /// @remark default = error_metric_mode::none error_metric_mode get_error_metric_mode() const {