diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc index 8f28bfa218c8..83e9243a2fa4 100644 --- a/tests/cpp/tree/test_approx.cc +++ b/tests/cpp/tree/test_approx.cc @@ -4,10 +4,12 @@ #include #include "../../../src/tree/common_row_partitioner.h" +#include "../../../src/tree/param.h" // for TrainParam #include "../collective/test_worker.h" // for TestDistributedGlobal #include "../helpers.h" #include "test_column_split.h" // for TestColumnSplit #include "test_partitioner.h" +#include "xgboost/tree_model.h" // for RegTree namespace xgboost::tree { namespace { @@ -76,6 +78,53 @@ TEST(Approx, Partitioner) { } } +TEST(Approx, InteractionConstraint) { + auto constexpr kRows = 32; + auto constexpr kCols = 16; + auto p_dmat = GenerateCatDMatrix(kRows, kCols, 0.6f, false); + Context ctx; + + linalg::Matrix gpair({kRows}, ctx.Device()); + gpair.Data()->Copy(GenerateRandomGradients(kRows)); + + ObjInfo task{ObjInfo::kRegression}; + { + // With constraints + RegTree tree{1, kCols}; + + std::unique_ptr updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; + TrainParam param; + param.UpdateAllowUnknown( + Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}}); + std::vector> position(1); + updater->Configure(Args{}); + updater->Update(¶m, &gpair, p_dmat.get(), position, {&tree}); + + ASSERT_EQ(tree.NumExtraNodes(), 4); + ASSERT_EQ(tree[0].SplitIndex(), 1); + + ASSERT_EQ(tree[tree[0].LeftChild()].SplitIndex(), 0); + ASSERT_EQ(tree[tree[0].RightChild()].SplitIndex(), 0); + } + { + // Without constraints + RegTree tree{1u, kCols}; + + std::unique_ptr updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; + std::vector> position(1); + TrainParam param; + param.Init(Args{}); + updater->Configure(Args{}); + updater->Update(¶m, &gpair, p_dmat.get(), position, {&tree}); + + ASSERT_EQ(tree.NumExtraNodes(), 10); + ASSERT_EQ(tree[0].SplitIndex(), 1); + + ASSERT_NE(tree[tree[0].LeftChild()].SplitIndex(), 0); + ASSERT_NE(tree[tree[0].RightChild()].SplitIndex(), 0); + } +} + namespace { void TestColumnSplitPartitioner(size_t n_samples, size_t base_rowid, std::shared_ptr Xy, std::vector* hess, float min_value, float mid_value, diff --git a/tests/cpp/tree/test_column_split.h b/tests/cpp/tree/test_column_split.h index b03597f38681..eba452a15a1c 100644 --- a/tests/cpp/tree/test_column_split.h +++ b/tests/cpp/tree/test_column_split.h @@ -23,9 +23,13 @@ inline std::shared_ptr GenerateCatDMatrix(std::size_t rows, std::size_t for (size_t i = 0; i < ft.size(); ++i) { ft[i] = (i % 3 == 0) ? FeatureType::kNumerical : FeatureType::kCategorical; } - return RandomDataGenerator(rows, cols, 0.6f).Seed(3).Type(ft).MaxCategory(17).GenerateDMatrix(); + return RandomDataGenerator(rows, cols, sparsity) + .Seed(3) + .Type(ft) + .MaxCategory(17) + .GenerateDMatrix(); } else { - return RandomDataGenerator{rows, cols, 0.6f}.Seed(3).GenerateDMatrix(); + return RandomDataGenerator{rows, cols, sparsity}.Seed(3).GenerateDMatrix(); } } diff --git a/tests/cpp/tree/test_histmaker.cc b/tests/cpp/tree/test_histmaker.cc deleted file mode 100644 index 888790aa7c3c..000000000000 --- a/tests/cpp/tree/test_histmaker.cc +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2019-2024, XGBoost Contributors - */ -#include -#include -#include - -#include "../../../src/tree/param.h" // for TrainParam -#include "../helpers.h" -#include "test_column_split.h" // for GenerateCatDMatrix - -namespace xgboost::tree { -TEST(GrowHistMaker, InteractionConstraint) { - auto constexpr kRows = 32; - auto constexpr kCols = 16; - auto p_dmat = GenerateCatDMatrix(kRows, kCols, 0.0, false); - Context ctx; - - linalg::Matrix gpair({kRows}, ctx.Device()); - gpair.Data()->Copy(GenerateRandomGradients(kRows)); - - ObjInfo task{ObjInfo::kRegression}; - { - // With constraints - RegTree tree{1, kCols}; - - std::unique_ptr updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; - TrainParam param; - param.UpdateAllowUnknown( - Args{{"interaction_constraints", "[[0, 1]]"}, {"num_feature", std::to_string(kCols)}}); - std::vector> position(1); - updater->Configure(Args{}); - updater->Update(¶m, &gpair, p_dmat.get(), position, {&tree}); - - ASSERT_EQ(tree.NumExtraNodes(), 4); - ASSERT_EQ(tree[0].SplitIndex(), 1); - - ASSERT_EQ(tree[tree[0].LeftChild()].SplitIndex(), 0); - ASSERT_EQ(tree[tree[0].RightChild()].SplitIndex(), 0); - } - { - // Without constraints - RegTree tree{1u, kCols}; - - std::unique_ptr updater{TreeUpdater::Create("grow_histmaker", &ctx, &task)}; - std::vector> position(1); - TrainParam param; - param.Init(Args{}); - updater->Configure(Args{}); - updater->Update(¶m, &gpair, p_dmat.get(), position, {&tree}); - - ASSERT_EQ(tree.NumExtraNodes(), 10); - ASSERT_EQ(tree[0].SplitIndex(), 1); - - ASSERT_NE(tree[tree[0].LeftChild()].SplitIndex(), 0); - ASSERT_NE(tree[tree[0].RightChild()].SplitIndex(), 0); - } -} -} // namespace xgboost::tree