Skip to content

Commit

Permalink
remove page in uniform sampling.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 12, 2023
1 parent 1b0dab2 commit f383f76
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
9 changes: 5 additions & 4 deletions src/tree/gpu_hist/gradient_based_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ GradientBasedSample ExternalMemoryNoSampling::Sample(Context const* ctx,
return {dmat->Info().num_row_, page_.get(), gpair};
}

UniformSampling::UniformSampling(EllpackPageImpl const* page, float subsample)
: page_(page), subsample_(subsample) {}
UniformSampling::UniformSampling(BatchParam batch_param, float subsample)
: batch_param_{std::move(batch_param)}, subsample_(subsample) {}

GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) {
Expand All @@ -185,7 +185,8 @@ GradientBasedSample UniformSampling::Sample(Context const* ctx, common::Span<Gra
thrust::replace_if(cuctx->CTP(), dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<std::size_t>(0),
BernoulliTrial(common::GlobalRandom()(), subsample_), GradientPair());
return {dmat->Info().num_row_, page_, gpair};
auto page = (*dmat->GetBatches<EllpackPage>(ctx, batch_param_).begin()).Impl();
return {dmat->Info().num_row_, page, gpair};
}

ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
Expand Down Expand Up @@ -331,7 +332,7 @@ GradientBasedSampler::GradientBasedSampler(Context const* ctx, EllpackPageImpl c
if (is_external_memory) {
strategy_.reset(new ExternalMemoryUniformSampling(n_rows, batch_param, subsample));
} else {
strategy_.reset(new UniformSampling(page, subsample));
strategy_.reset(new UniformSampling(batch_param, subsample));
}
break;
case TrainParam::kGradientBased:
Expand Down
4 changes: 2 additions & 2 deletions src/tree/gpu_hist/gradient_based_sampler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ class ExternalMemoryNoSampling : public SamplingStrategy {
/*! \brief Uniform sampling in in-memory mode. */
class UniformSampling : public SamplingStrategy {
public:
UniformSampling(EllpackPageImpl const* page, float subsample);
UniformSampling(BatchParam batch_param, float subsample);
GradientBasedSample Sample(Context const* ctx, common::Span<GradientPair> gpair,
DMatrix* dmat) override;

private:
EllpackPageImpl const* page_;
BatchParam batch_param_;
float subsample_;
};

Expand Down
2 changes: 1 addition & 1 deletion src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ struct GPUHistMakerDevice {
Context const* ctx_;

public:
EllpackPageImpl const* page;
EllpackPageImpl const* page{nullptr};
common::Span<FeatureType const> feature_types;
BatchParam batch_param;

Expand Down

0 comments on commit f383f76

Please sign in to comment.