diff --git a/include/xgboost/span.h b/include/xgboost/span.h
index 579737a59033..f7c2ead41183 100644
--- a/include/xgboost/span.h
+++ b/include/xgboost/span.h
@@ -433,7 +433,7 @@ class Span {
   using const_reverse_iterator = const std::reverse_iterator<const_iterator>;  // NOLINT

   // constructors
-  constexpr Span() __span_noexcept = default;
+  constexpr Span() = default;

   XGBOOST_DEVICE Span(pointer _ptr, index_type _count) :
       size_(_count), data_(_ptr) {
@@ -480,16 +480,11 @@ class Span {
       __span_noexcept : size_(_other.size()), data_(_other.data()) {}

-  XGBOOST_DEVICE constexpr Span(const Span& _other)
-      __span_noexcept : size_(_other.size()), data_(_other.data()) {}
-
-  XGBOOST_DEVICE Span& operator=(const Span& _other) __span_noexcept {
-    size_ = _other.size();
-    data_ = _other.data();
-    return *this;
-  }
-
-  XGBOOST_DEVICE ~Span() __span_noexcept {};  // NOLINT
+  constexpr Span(Span const& _other) noexcept(true) = default;
+  constexpr Span& operator=(Span const& _other) noexcept(true) = default;
+  constexpr Span(Span&& _other) noexcept(true) = default;
+  constexpr Span& operator=(Span&& _other) noexcept(true) = default;
+  ~Span() noexcept(true) = default;

   XGBOOST_DEVICE constexpr iterator begin() const __span_noexcept {  // NOLINT
     return {this, 0};
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 0b5468d2913b..d2610ff1586b 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1098,7 +1098,6 @@ XGB_DLL int XGBoosterTrainOneIter(BoosterHandle handle, DMatrixHandle dtrain, in
   auto ctx = learner->Ctx();
   if (!grad_is_cuda) {
     gpair.Reshape(i_grad.Shape(0), i_grad.Shape(1));
-    auto const shape = gpair.Shape();
     auto h_gpair = gpair.HostView();
     DispatchDType(i_grad, DeviceOrd::CPU(), [&](auto &&t_grad) {
       DispatchDType(i_hess, DeviceOrd::CPU(), [&](auto &&t_hess) {
diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index 4b3a3cae644f..8b80e86ec2b7 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -180,10 +180,9 @@ void RemoveDuplicatedCategories(Context const* ctx, MetaInfo const& info,
   sorted_entries.resize(n_uniques);

   // Renew the column scan and cut scan based on categorical data.
-  auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan);
   dh::caching_device_vector new_cuts_size(info.num_col_ + 1);
   CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
-  dh::LaunchN(new_column_scan.size(),
+  dh::LaunchN(new_column_scan.size(), ctx->CUDACtx()->Stream(),
               [=, d_new_cuts_size = dh::ToSpan(new_cuts_size),
                   d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan),
                   d_new_columns_ptr = dh::ToSpan(new_column_scan)] __device__(size_t idx) {
@@ -277,7 +276,6 @@ void ProcessWeightedBatch(Context const* ctx, const SparsePage& page, MetaInfo c
                           HostDeviceVector<float>* p_out_weight) {
   if (hessian.empty()) {
     if (info.IsRanking() && !info.weights_.Empty()) {
-      common::Span<float const> group_weight = info.weights_.ConstDeviceSpan();
       dh::device_vector<bst_group_t> group_ptr(info.group_ptr_);
       auto d_group_ptr = dh::ToSpan(group_ptr);
       CHECK_GE(d_group_ptr.size(), 2) << "Must have at least 1 group for ranking.";
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 5d016dfc72b4..d12befae8f42 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -167,7 +167,6 @@ void CopyGradient(Context const* ctx, linalg::Matrix<GradientPair> const* in_gpa
     GPUCopyGradient(ctx, in_gpair, group_id, out_gpair);
   } else {
     auto const& in = *in_gpair;
-    auto target_gpair = in.Slice(linalg::All(), group_id);
     auto h_tmp = out_gpair->HostView();
     auto h_in = in.HostView().Slice(linalg::All(), group_id);
     CHECK_EQ(h_tmp.Size(), h_in.Size());
diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc
index 53841c05135e..d0247c14b3d2 100644
--- a/src/metric/rank_metric.cc
+++ b/src/metric/rank_metric.cc
@@ -329,7 +329,6 @@ class EvalPrecision : public EvalRankWithCache {
     auto gptr = p_cache->DataGroupPtr(ctx_);
     auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
-    auto h_predt = linalg::MakeTensorView(ctx_, &predt, predt.Size());
     auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan());
     auto weight = common::MakeOptionalWeights(ctx_, info.weights_);
@@ -433,7 +432,6 @@ class EvalMAPScore : public EvalRankWithCache {
     auto gptr = p_cache->DataGroupPtr(ctx_);
     auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
-    auto h_predt = linalg::MakeTensorView(ctx_, &predt, predt.Size());
     auto map_gloc = p_cache->Map(ctx_);
     std::fill_n(map_gloc.data(), map_gloc.size(), 0.0);
diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu
index 25c5d138ccc7..c8da43748ae8 100644
--- a/src/objective/lambdarank_obj.cu
+++ b/src/objective/lambdarank_obj.cu
@@ -488,9 +488,6 @@ void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
   info.labels.SetDevice(device);
   predt.SetDevice(device);
-  auto d_predt = predt.ConstDeviceSpan();
-  auto const d_sorted_idx = p_cache->SortedIdx(ctx, d_predt);
-
   auto delta = [] XGBOOST_DEVICE(float, float, std::size_t, std::size_t, bst_group_t) {
     return 1.0;
   };
diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu
index 387eeda91b28..423cbe9b5253 100644
--- a/src/tree/gpu_hist/evaluate_splits.cu
+++ b/src/tree/gpu_hist/evaluate_splits.cu
@@ -7,6 +7,7 @@
 #include "../../collective/allgather.h"
 #include "../../collective/communicator-inl.h"  // for GetWorldSize, GetRank
 #include "../../common/categorical.h"
+#include "../../common/cuda_context.cuh"  // for CUDAContext
 #include "evaluate_splits.cuh"
 #include "expand_entry.cuh"
@@ -472,8 +473,8 @@ void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector
-  dh::TemporaryArray<EvaluateSplitInputs> inputs(1);
-  dh::safe_cuda(cudaMemcpyAsync(inputs.data().get(), &input, sizeof(input), cudaMemcpyDefault));
+  dh::CachingDeviceUVector<EvaluateSplitInputs> inputs(1);
+  dh::safe_cuda(cudaMemcpyAsync(inputs.data(), &input, sizeof(input), cudaMemcpyDefault));
   dh::TemporaryArray<GPUExpandEntry> out_entries(1);
   this->EvaluateSplits(ctx, {input.nidx}, input.feature_set.size(), dh::ToSpan(inputs),
                        dh::ToSpan(out_entries));
diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh
index 19e8f2f931d6..5abcf3d04065 100644
--- a/src/tree/gpu_hist/evaluate_splits.cuh
+++ b/src/tree/gpu_hist/evaluate_splits.cuh
@@ -10,7 +10,6 @@
 #include "../split_evaluator.h"
 #include "../updater_gpu_common.cuh"
 #include "expand_entry.cuh"
-#include "histogram.cuh"

 namespace xgboost {
 namespace common {
diff --git a/src/tree/gpu_hist/evaluator.cu b/src/tree/gpu_hist/evaluator.cu
index ee542a94a825..90e83c2c3b3d 100644
--- a/src/tree/gpu_hist/evaluator.cu
+++ b/src/tree/gpu_hist/evaluator.cu
@@ -7,6 +7,7 @@
 #include <thrust/logical.h>  // thrust::any_of
 #include <thrust/sort.h>     // thrust::stable_sort

+#include "../../common/cuda_context.cuh"  // for CUDAContext
 #include "../../common/device_helpers.cuh"
 #include "../../common/hist_util.h"  // common::HistogramCuts
 #include "evaluate_splits.cuh"
@@ -53,9 +54,7 @@ void GPUHistEvaluator::Reset(Context const *ctx, common::HistogramCuts const &cu
     * cache feature index binary search result
     */
    feature_idx_.resize(cat_sorted_idx_.size());
-   auto d_fidxes = dh::ToSpan(feature_idx_);
    auto it = thrust::make_counting_iterator(0ul);
-   auto values = cuts.cut_values_.ConstDeviceSpan();
    thrust::transform(ctx->CUDACtx()->CTP(), it, it + feature_idx_.size(), feature_idx_.begin(),
                      [=] XGBOOST_DEVICE(size_t i) {
                        auto fidx = dh::SegmentId(ptrs, i);
diff --git a/tests/cpp/common/test_host_device_vector.cu b/tests/cpp/common/test_host_device_vector.cu
index 65b8135bfadc..7c3c2cd070cf 100644
--- a/tests/cpp/common/test_host_device_vector.cu
+++ b/tests/cpp/common/test_host_device_vector.cu
@@ -146,7 +146,7 @@ TEST(HostDeviceVector, SetDevice) {
   vec.SetDevice(device);
   ASSERT_EQ(vec.Size(), h_vec.size());
-  auto span = vec.DeviceSpan();  // sync to device
+  vec.DeviceSpan();  // sync to device
   vec.SetDevice(DeviceOrd::CPU());  // pull back to cpu.
   ASSERT_EQ(vec.Size(), h_vec.size());
diff --git a/tests/cpp/common/test_linalg.cc b/tests/cpp/common/test_linalg.cc
index 21c5ad30d566..a9adccd0a511 100644
--- a/tests/cpp/common/test_linalg.cc
+++ b/tests/cpp/common/test_linalg.cc
@@ -117,9 +117,8 @@ TEST(Linalg, TensorView) {
   {
     // Don't assign the initial dimension, tensor should be able to deduce the correct dim
     // for Slice.
-    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
-    auto s = t.Slice(1, 2, All());
-    static_assert(decltype(s)::kDimension == 1);
+    static_assert(decltype(MakeTensorView(&ctx, data, 2, 3, 4).Slice(1, 2, All()))::kDimension ==
+                  1);
   }
   {
     auto t = MakeTensorView(&ctx, data, 2, 3, 4);
diff --git a/tests/cpp/common/test_span.cc b/tests/cpp/common/test_span.cc
index 486896c24891..53ce77d16c69 100644
--- a/tests/cpp/common/test_span.cc
+++ b/tests/cpp/common/test_span.cc
@@ -11,6 +11,15 @@
 #include "../../../src/common/transform_iterator.h"  // for MakeIndexTransformIter

 namespace xgboost::common {
+namespace {
+using ST = common::Span<float>;
+static_assert(std::is_trivially_copyable_v<ST>);
+static_assert(std::is_trivially_move_assignable_v<ST>);
+static_assert(std::is_trivially_move_constructible_v<ST>);
+static_assert(std::is_trivially_copy_assignable_v<ST>);
+static_assert(std::is_trivially_copy_constructible_v<ST>);
+}  // namespace
+
 TEST(Span, TestStatus) {
   int status = 1;
   TestTestStatus {&status}();
diff --git a/tests/cpp/data/test_extmem_quantile_dmatrix.cu b/tests/cpp/data/test_extmem_quantile_dmatrix.cu
index 6f700d98cee7..6d91ac7f60b1 100644
--- a/tests/cpp/data/test_extmem_quantile_dmatrix.cu
+++ b/tests/cpp/data/test_extmem_quantile_dmatrix.cu
@@ -23,8 +23,8 @@ class ExtMemQuantileDMatrixGpu : public ::testing::TestWithParam
     std::vector<common::CompressedByteT> h_orig, h_sparse;
-    auto orig_acc = orig.Impl()->GetHostAccessor(ctx, &h_orig, {});
-    auto sparse_acc = sparse.Impl()->GetHostAccessor(ctx, &h_sparse, {});
+    [[maybe_unused]] auto orig_acc = orig.Impl()->GetHostAccessor(ctx, &h_orig, {});
+    [[maybe_unused]] auto sparse_acc = sparse.Impl()->GetHostAccessor(ctx, &h_sparse, {});
     ASSERT_EQ(h_orig.size(), h_sparse.size());
     auto equal = std::equal(h_orig.cbegin(), h_orig.cend(), h_sparse.cbegin());
diff --git a/tests/cpp/data/test_simple_dmatrix.cu b/tests/cpp/data/test_simple_dmatrix.cu
index 04859ed1e300..6bf76e37acc1 100644
--- a/tests/cpp/data/test_simple_dmatrix.cu
+++ b/tests/cpp/data/test_simple_dmatrix.cu
@@ -360,8 +360,6 @@ TEST(SimpleDMatrix, FromCupySparse){
   auto& batch = *dmat.GetBatches<SparsePage>().begin();
   auto page = batch.GetView();
-  auto inst0 = page[0];
-  auto inst1 = page[1];
   EXPECT_EQ(page[0].size(), 1);
   EXPECT_EQ(page[1].size(), 1);
   EXPECT_EQ(page[0][0].fvalue, 0.0f);
diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu
index a59d07137cc2..ce34c3d3b561 100644
--- a/tests/cpp/tree/gpu_hist/test_histogram.cu
+++ b/tests/cpp/tree/gpu_hist/test_histogram.cu
@@ -104,7 +104,6 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   auto ctx = MakeCUDACtx(0);
   auto page = BuildEllpackPage(&ctx, kNRows, kNCols);
-  BatchParam batch_param{};
   xgboost::SimpleLCG gen;
   xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f);
@@ -448,7 +447,6 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam
-        auto d_matrix = impl->GetDeviceAccessor(&ctx);
         fg = std::make_unique<FeatureGroups>(impl->Cuts());
         auto init = GradientPairInt64{0, 0};
         multi_hist = decltype(multi_hist)(impl->Cuts().TotalBins(), init);