diff --git a/tests/cpp/tree/test_approx.cc b/tests/cpp/tree/test_approx.cc index 4eb049cdb170..57a0cd3545bb 100644 --- a/tests/cpp/tree/test_approx.cc +++ b/tests/cpp/tree/test_approx.cc @@ -8,19 +8,6 @@ namespace xgboost { namespace tree { -namespace { -void GetSplit(RegTree *tree, float split_value, std::vector *candidates) { - tree->ExpandNode( - /*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value, - /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, - /*left_sum=*/0.0f, - /*right_sum=*/0.0f); - candidates->front().split.split_value = split_value; - candidates->front().split.sindex = 0; - candidates->front().split.sindex |= (1U << 31); -} -} // anonymous namespace - TEST(Approx, Partitioner) { size_t n_samples = 1024, n_features = 1, base_rowid = 0; ApproxRowPartitioner partitioner{n_samples, base_rowid}; @@ -33,18 +20,20 @@ TEST(Approx, Partitioner) { ctx.InitAllowUnknown(Args{}); std::vector candidates{{0, 0, 0.4}}; - auto grad = GenerateRandomGradients(n_samples); - std::vector hess(grad.Size()); - std::transform(grad.HostVector().cbegin(), grad.HostVector().cend(), hess.begin(), - [](auto gpair) { return gpair.GetHess(); }); - - for (auto const &page : Xy->GetBatches({GenericParameter::kCpuId, 64, hess})) { - bst_feature_t const split_ind = 0; + for (auto const &page : Xy->GetBatches({GenericParameter::kCpuId, 64})) { + bst_feature_t split_ind = 0; { auto min_value = page.cut.MinValues()[split_ind]; RegTree tree; + tree.ExpandNode( + /*nid=*/0, /*split_index=*/0, /*split_value=*/min_value, + /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + /*left_sum=*/0.0f, + /*right_sum=*/0.0f); ApproxRowPartitioner partitioner{n_samples, base_rowid}; - GetSplit(&tree, min_value, &candidates); + candidates.front().split.split_value = min_value; + candidates.front().split.sindex = 0; + candidates.front().split.sindex |= (1U << 31); partitioner.UpdatePosition(&ctx, page, candidates, &tree); ASSERT_EQ(partitioner.Size(), 3); ASSERT_EQ(partitioner[1].Size(), 0); @@ -55,8 +44,16 @@ TEST(Approx, Partitioner) { auto ptr = page.cut.Ptrs()[split_ind + 1]; float split_value = page.cut.Values().at(ptr / 2); RegTree tree; - GetSplit(&tree, split_value, &candidates); + tree.ExpandNode( + /*nid=*/RegTree::kRoot, /*split_index=*/split_ind, + /*split_value=*/split_value, + /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + /*left_sum=*/0.0f, + /*right_sum=*/0.0f); auto left_nidx = tree[RegTree::kRoot].LeftChild(); + candidates.front().split.split_value = split_value; + candidates.front().split.sindex = 0; + candidates.front().split.sindex |= (1U << 31); partitioner.UpdatePosition(&ctx, page, candidates, &tree); auto elem = partitioner[left_nidx];