Skip to content

Commit

Permalink
fixup add useConstFeatures parameter to oneapi interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Mar 21, 2023
1 parent fb5556f commit 6ba776f
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ static result_t call_daal_kernel(const context_cpu& ctx,
dal::detail::integral_cast<std::size_t>(desc.get_max_leaf_nodes());
daal_parameter.maxBins = dal::detail::integral_cast<std::size_t>(desc.get_max_bins());
daal_parameter.minBinSize = dal::detail::integral_cast<std::size_t>(desc.get_min_bin_size());
daal_parameter.useConstFeatures = desc.get_use_const_features();

daal_parameter.resultsToCompute = static_cast<std::uint64_t>(desc.get_error_metric_mode());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ static result_t call_daal_kernel(const context_cpu& ctx,
dal::detail::integral_cast<std::size_t>(desc.get_max_leaf_nodes());
daal_parameter.maxBins = dal::detail::integral_cast<std::size_t>(desc.get_max_bins());
daal_parameter.minBinSize = dal::detail::integral_cast<std::size_t>(desc.get_min_bin_size());
daal_parameter.useConstFeatures = desc.get_use_const_features();

daal_parameter.resultsToCompute = static_cast<std::uint64_t>(desc.get_error_metric_mode());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ void train_kernel_hist_impl<Float, Bin, Index, Task>::init_params(train_context_

ctx.bootstrap_ = desc.get_bootstrap();
ctx.max_tree_depth_ = desc.get_max_tree_depth();
ctx.use_const_features_ = desc.get_use_const_features();

if constexpr (std::is_same_v<Task, task::classification>) {
ctx.selected_ftr_count_ = desc.get_features_per_node() ? desc.get_features_per_node()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ struct train_context {
bool bootstrap_ = false;
bool use_private_mem_buf_ = true; // valuable for classification only
// for switching between private mem and other buffers(local, global) for storing class hist
bool use_const_features_ = false;

Index total_bin_count_ = 0;
Index max_bin_count_among_ftrs_ = 0;
Expand Down
11 changes: 11 additions & 0 deletions cpp/oneapi/dal/algo/decision_forest/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class descriptor_impl : public base {

bool memory_saving_mode = false;
bool bootstrap = true;
bool use_const_features = false;

variable_importance_mode variable_importance_mode_value = variable_importance_mode::none;
voting_mode voting_mode_value = voting_mode::weighted;
Expand Down Expand Up @@ -148,6 +149,11 @@ bool descriptor_base<Task>::get_bootstrap() const {
return impl_->bootstrap;
}

template <typename Task>
bool descriptor_base<Task>::get_use_const_features() const {
return impl_->use_const_features;
}

template <typename Task>
variable_importance_mode descriptor_base<Task>::get_variable_importance_mode() const {
return impl_->variable_importance_mode_value;
Expand Down Expand Up @@ -267,6 +273,11 @@ void descriptor_base<Task>::set_bootstrap_impl(bool value) {
impl_->bootstrap = value;
}

template <typename Task>
void descriptor_base<Task>::set_use_const_features_impl(bool value) {
impl_->use_const_features = value;
}

template <typename Task>
void descriptor_base<Task>::set_variable_importance_mode_impl(variable_importance_mode value) {
impl_->variable_importance_mode_value = value;
Expand Down
14 changes: 14 additions & 0 deletions cpp/oneapi/dal/algo/decision_forest/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ class descriptor_base : public base {
std::int64_t get_min_bin_size() const;
bool get_memory_saving_mode() const;
bool get_bootstrap() const;
bool get_use_const_features() const;
error_metric_mode get_error_metric_mode() const;
variable_importance_mode get_variable_importance_mode() const;

Expand Down Expand Up @@ -247,6 +248,7 @@ class descriptor_base : public base {
void set_min_bin_size_impl(std::int64_t value);
void set_memory_saving_mode_impl(bool value);
void set_bootstrap_impl(bool value);
void set_use_const_features_impl(bool value);
void set_error_metric_mode_impl(error_metric_mode value);
void set_variable_importance_mode_impl(variable_importance_mode value);
void set_class_count_impl(std::int64_t value);
Expand Down Expand Up @@ -503,6 +505,18 @@ class descriptor : public detail::descriptor_base<Task> {
return *this;
}

/// The use constant features mode, if true, constant-valued features are
/// considered for node splits
/// @remark default = false
bool get_use_const_features() const {
return base_t::get_use_const_features();
}

auto& set_use_const_features(bool value) {
base_t::set_use_const_features_impl(value);
return *this;
}

/// The error metric mode
/// @remark default = error_metric_mode::none
error_metric_mode get_error_metric_mode() const {
Expand Down

0 comments on commit 6ba776f

Please sign in to comment.