Support standardization for sparse vectors in logistic regression MG (rapidsai#5806)

* Align scratch space to resolve the cudaErrorMisalignedAddress error in cusparseSpMM

* Revise int n_samples to size_t; when the variance comes out negative, add back the mean-square term (not yet tested)

* Revise per review comments

* Support mean/variance calculation in chunks to avoid the precision loss of adding a small value to a large running sum

* Double the training data to 2e5 when n_classes is larger than 5, to increase the stability of the tests
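For intuition on the chunking bullet: in float32, once a running sum reaches 2^24 the spacing between representable values is 2, so adding 1.0f is silently dropped. A minimal standalone sketch of the failure mode and the chunked fix (illustration only, not code from this commit):

#include <cstdio>

int main()
{
  // Above 2^24 = 16777216, consecutive float32 values are 2 apart,
  // so adding 1.0f leaves the accumulator unchanged.
  float big = 16777216.0f;
  printf("%.1f\n", big + 1.0f);  // 16777216.0 -- the increment is lost

  // Chunked accumulation keeps each partial sum small and exact, then
  // combines the partials with a few large adds, as the new mean code does.
  float partial = 0.0f;
  for (int i = 0; i < 100; ++i) partial += 1.0f;  // exact: 100.0
  printf("%.1f\n", big + partial);                // 16777316.0
  return 0;
}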
lijinf2 authored Apr 3, 2024
1 parent 2c876b5 commit 669fad2
Showing 5 changed files with 284 additions and 27 deletions.
3 changes: 1 addition & 2 deletions cpp/src/glm/qn/mg/glm_base_mg.cuh
@@ -191,8 +191,7 @@ struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
communicator.sync_stream(stream);

if (stder_p != NULL) {
-   stder_p->adapt_gradient_for_linearBwd(
-     *handle_p, G, *(this->Z), (this->X)->n != G.n, n_samples);
+   stder_p->adapt_gradient_for_linearBwd(*handle_p, G, *(this->Z), (this->X)->n != G.n);
raft::copy(wFlat.data, wFlatOrigin.data(), this->C * this->dims, stream);
}

162 changes: 157 additions & 5 deletions cpp/src/glm/qn/mg/standardization.cuh
@@ -25,6 +25,7 @@
#include <raft/linalg/divide.cuh>
#include <raft/linalg/multiply.cuh>
#include <raft/linalg/sqrt.cuh>
+ #include <raft/linalg/subtract.cuh>
#include <raft/matrix/math.hpp>
#include <raft/sparse/op/row_op.cuh>
#include <raft/stats/stddev.cuh>
@@ -42,7 +43,7 @@ namespace opg {
template <typename T>
void mean_stddev(const raft::handle_t& handle,
const SimpleDenseMat<T>& X,
-                int n_samples,
+                size_t n_samples,
T* mean_vector,
T* stddev_vector)
{
@@ -65,9 +66,129 @@ void mean_stddev(const raft::handle_t& handle,
comm.allreduce(stddev_vector, stddev_vector, D, raft::comms::op_t::SUM, stream);
comm.sync_stream(stream);

+   // avoid negative variance that is due to precision loss of floating point arithmetic
+   weight = n_samples < 1 ? T(0) : T(1) / T(n_samples - 1);
+   weight = n_samples * weight;
+   auto no_neg_op = [weight] __device__(const T a, const T b) -> T {
+     if (a >= 0) return a;
+
+     return a + weight * b * b;
+   };
+
+   raft::linalg::binaryOp(stddev_vector, stddev_vector, mean_vector, D, no_neg_op, stream);
+
raft::linalg::sqrt(stddev_vector, stddev_vector, D, handle.get_stream());
}
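A gloss on the guard above (my reading of the code, not text from the commit): after the allreduce, stddev_vector holds the unbiased variance estimate

    s^2 = n/(n-1) * (mean(x^2) - mean(x)^2)

whose parenthesized difference can round below zero in floating point. Adding weight * b * b back, with weight = n/(n-1) and b the feature mean, is equivalent to

    s^2 = n/(n-1) * mean(x^2)

which is always nonnegative, so the sqrt that follows is safe.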

+ template <typename T>
+ SimpleSparseMat<T> get_sub_mat(const raft::handle_t& handle,
+                                SimpleSparseMat<T> mat,
+                                int start,
+                                int end,
+                                rmm::device_uvector<int>& buff_row_ids)
+ {
+   end        = end <= mat.m ? end : mat.m;
+   int n_rows = end - start;
+   int n_cols = mat.n;
+   auto stream = handle.get_stream();
+
+   RAFT_EXPECTS(start < end, "start index must be smaller than end index");
+   RAFT_EXPECTS(buff_row_ids.size() >= n_rows + 1,
+                "the size of buff_row_ids should be at least end - start + 1");
+   raft::copy(buff_row_ids.data(), mat.row_ids + start, n_rows + 1, stream);
+
+   int idx;
+   raft::copy(&idx, buff_row_ids.data(), 1, stream);
+   raft::resource::sync_stream(handle);
+
+   auto subtract_op = [idx] __device__(const int a) { return a - idx; };
+   raft::linalg::unaryOp(buff_row_ids.data(), buff_row_ids.data(), n_rows + 1, subtract_op, stream);
+
+   int nnz;
+   raft::copy(&nnz, buff_row_ids.data() + n_rows, 1, stream);
+   raft::resource::sync_stream(handle);
+
+   SimpleSparseMat<T> res(
+     mat.values + idx, mat.cols + idx, buff_row_ids.data(), nnz, n_rows, n_cols);
+   return res;
+ }
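get_sub_mat is a zero-copy row slice: only the (n_rows + 1)-entry row-pointer window is copied and rebased, while values and cols are aliased at offset idx. A host-side sketch of the same rebasing with made-up numbers (hypothetical, for illustration):

#include <cstdio>
#include <vector>

int main()
{
  // Hypothetical 3-row CSR matrix with 7 nonzeros.
  std::vector<int> row_ids = {0, 2, 5, 7};
  int start = 1, end = 3;  // slice rows [1, 3)
  int n_rows = end - start;

  // Copy the row-pointer window, then rebase it to start at zero,
  // mirroring the device-side unaryOp with subtract_op.
  std::vector<int> sub(row_ids.begin() + start, row_ids.begin() + start + n_rows + 1);
  int idx = sub[0];
  for (int& r : sub) r -= idx;
  int nnz = sub[n_rows];

  printf("idx=%d nnz=%d row_ids={%d,%d,%d}\n", idx, nnz, sub[0], sub[1], sub[2]);
  // idx=2 nnz=5 row_ids={0,3,5}; values + idx and cols + idx alias the
  // original arrays, so no nonzero data is copied.
  return 0;
}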

+ template <typename T>
+ void mean(const raft::handle_t& handle,
+           const SimpleSparseMat<T>& X,
+           size_t n_samples,
+           T* mean_vector)
+ {
+   int D        = X.n;
+   int num_rows = X.m;
+   auto stream  = handle.get_stream();
+   auto& comm   = handle.get_comms();
+
+   int chunk_size = 500000;  // split matrix by rows for better numeric precision
+   rmm::device_uvector<int> buff_row_ids(chunk_size + 1, stream);
+
+   rmm::device_uvector<T> ones(chunk_size, stream);
+   SimpleVec<T> ones_vec(ones.data(), chunk_size);
+   ones_vec.fill(1.0, stream);
+
+   rmm::device_uvector<T> buff_D(D, stream);
+   SimpleDenseMat<T> buff_D_mat(buff_D.data(), 1, D);
+
+   // calculate mean
+   SimpleDenseMat<T> mean_mat(mean_vector, 1, D);
+   mean_mat.fill(0., stream);
+
+   for (int i = 0; i < X.m; i += chunk_size) {
+     // get X[i:i + chunk_size]
+     SimpleSparseMat<T> X_sub = get_sub_mat(handle, X, i, i + chunk_size, buff_row_ids);
+     SimpleDenseMat<T> ones_mat(ones.data(), 1, X_sub.m);
+
+     X_sub.gemmb(handle, 1., ones_mat, false, false, 0., buff_D_mat, stream);
+     raft::linalg::binaryOp(mean_vector, mean_vector, buff_D_mat.data, D, raft::add_op(), stream);
+   }
+
+   T weight = T(1) / T(n_samples);
+   raft::linalg::multiplyScalar(mean_vector, mean_vector, weight, D, stream);
+   comm.allreduce(mean_vector, mean_vector, D, raft::comms::op_t::SUM, stream);
+   comm.sync_stream(stream);
+ }
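Two things worth noting in the loop above (my gloss): ones^T * X_sub is a GEMM that yields the column sums of the chunk without densifying X, and each rank scales by the global n_samples before the allreduce, so summing the per-rank contributions directly produces the global mean. A dense host analogue of the ones-vector reduction (illustration only):

#include <cstdio>

int main()
{
  // ones^T * X equals the column sums of X.
  const int m = 4, n = 3;
  float X[m][n] = {{1, 2, 0}, {0, 1, 3}, {2, 0, 0}, {1, 1, 1}};
  float ones[m] = {1, 1, 1, 1};
  float colsum[n] = {0, 0, 0};
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i)
      colsum[j] += ones[i] * X[i][j];
  printf("%g %g %g\n", colsum[0], colsum[1], colsum[2]);  // 4 4 4
  return 0;
}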

template <typename T>
void mean_stddev(const raft::handle_t& handle,
const SimpleSparseMat<T>& X,
size_t n_samples,
T* mean_vector,
T* stddev_vector)
{
auto stream = handle.get_stream();
int D = X.n;
mean(handle, X, n_samples, mean_vector);

// calculate stdev.S
rmm::device_uvector<T> X_values_squared(X.nnz, stream);
raft::copy(X_values_squared.data(), X.values, X.nnz, stream);
auto square_op = [] __device__(const T a) { return a * a; };
raft::linalg::unaryOp(X_values_squared.data(), X_values_squared.data(), X.nnz, square_op, stream);

auto X_squared = SimpleSparseMat<T>(X_values_squared.data(), X.cols, X.row_ids, X.nnz, X.m, X.n);

mean(handle, X_squared, n_samples, stddev_vector);

T weight = n_samples / T(n_samples - 1);
auto submean_no_neg_op = [weight] __device__(const T a, const T b) -> T {
T res = weight * (a - b * b);
if (res < 0) {
// return sum(x^2) / (n - 1) if negative variance (due to precision loss of floating point
// arithmetic)
res = weight * a;
}
return res;
};
raft::linalg::binaryOp(stddev_vector, stddev_vector, mean_vector, X.n, submean_no_neg_op, stream);

raft::linalg::sqrt(stddev_vector, stddev_vector, X.n, handle.get_stream());
}
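A remark on the mean-of-squares step (my gloss, not commit text): since x = 0 implies x^2 = 0, squaring only the values array yields the elementwise square of X with the identical cols/row_ids arrays, which is why the chunked mean() above can be reused unchanged. A trivial host sketch with a hypothetical CSR view type:

#include <cstdio>
#include <vector>

// Hypothetical stand-in for SimpleSparseMat<T>.
struct CsrView {
  const float* values;
  const int* cols;
  const int* row_ids;
  int nnz, m, n;
};

int main()
{
  std::vector<float> values  = {1.5f, -2.0f, 0.5f};
  std::vector<int>   cols    = {0, 2, 1};
  std::vector<int>   row_ids = {0, 2, 3};  // 2 x 3 matrix, 3 nonzeros

  // Square the values only; the sparsity pattern is untouched.
  std::vector<float> squared(values.size());
  for (size_t i = 0; i < values.size(); ++i) squared[i] = values[i] * values[i];

  CsrView X2{squared.data(), cols.data(), row_ids.data(), 3, 2, 3};
  printf("nnz=%d values={%g,%g,%g}\n", X2.nnz, X2.values[0], X2.values[1], X2.values[2]);
  // nnz=3 values={2.25,4,0.25}
  return 0;
}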

struct inverse_op {
template <typename T>
constexpr RAFT_INLINE_FUNCTION auto operator()(const T& a) const
@@ -85,11 +206,11 @@ struct Standardizer {

Standardizer(const raft::handle_t& handle,
const SimpleDenseMat<T>& X,
-              int n_samples,
+              size_t n_samples,
rmm::device_uvector<T>& mean_std_buff)
{
int D = X.n;
ASSERT(mean_std_buff.size() == 4 * D, "buff size must be four times the dimension");
ASSERT(mean_std_buff.size() == 4 * D, "mean_std_buff size must be four times the dimension");

auto stream = handle.get_stream();

@@ -106,6 +227,38 @@
raft::linalg::binaryOp(scaled_mean.data, std_inv.data, mean.data, D, raft::mul_op(), stream);
}

+   Standardizer(const raft::handle_t& handle,
+                const SimpleSparseMat<T>& X,
+                size_t n_samples,
+                rmm::device_uvector<T>& mean_std_buff,
+                size_t vec_size)
+   {
+     int D = X.n;
+     ASSERT(mean_std_buff.size() == 4 * vec_size,
+            "mean_std_buff size must be four times the aligned size");
+
+     auto stream = handle.get_stream();
+
+     T* p_ws = mean_std_buff.data();
+
+     mean.reset(p_ws, D);
+     p_ws += vec_size;
+
+     std.reset(p_ws, D);
+     p_ws += vec_size;
+
+     std_inv.reset(p_ws, D);
+     p_ws += vec_size;
+
+     scaled_mean.reset(p_ws, D);
+
+     mean_stddev(handle, X, n_samples, mean.data, std.data);
+     raft::linalg::unaryOp(std_inv.data, std.data, D, inverse_op(), stream);
+
+     // scale mean by the standard deviation
+     raft::linalg::binaryOp(scaled_mean.data, std_inv.data, mean.data, D, raft::mul_op(), stream);
+   }
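The vec_size stride in this constructor is what resolves the misaligned-address error from the first commit bullet: each of mean/std/std_inv/scaled_mean starts at a multiple of vec_size, hence on a qn_align boundary, rather than packed back-to-back at stride D. A host sketch of the stride computation (align_to is a hypothetical stand-in for raft::alignTo; the numbers are made up):

#include <cstdio>
#include <cstddef>

// Round value up to the next multiple of alignment.
size_t align_to(size_t value, size_t alignment)
{
  return ((value + alignment - 1) / alignment) * alignment;
}

int main()
{
  const size_t D = 13, qn_align = 16;  // hypothetical feature count and alignment
  size_t vec_size = align_to(sizeof(float) * D, qn_align);
  printf("vec_size = %zu\n", vec_size);  // 52 rounded up to 64

  // The four sub-vectors start at 0, vec_size, 2*vec_size, 3*vec_size,
  // so each begins on a qn_align-aligned offset.
  for (size_t k = 0; k < 4; ++k) printf("offset %zu: %zu\n", k, k * vec_size);
  return 0;
}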

void adapt_model_for_linearFwd(
const raft::handle_t& handle, T* coef, int n_targets, int D, bool has_bias) const
{
@@ -139,8 +292,7 @@
void adapt_gradient_for_linearBwd(const raft::handle_t& handle,
SimpleDenseMat<T>& G,
const SimpleDenseMat<T>& dZ,
-                                    bool has_bias,
-                                    int n_samples) const
+                                    bool has_bias) const
{
auto stream = handle.get_stream();
int D = mean.len;
26 changes: 14 additions & 12 deletions cpp/src/glm/qn_mg.cu
@@ -114,8 +114,8 @@ void qnFit_impl(const raft::handle_t& handle,
auto X_simple = SimpleDenseMat<T>(X, N, D, X_col_major ? COL_MAJOR : ROW_MAJOR);

rmm::device_uvector<T> mean_std_buff(4 * D, handle.get_stream());
- Standardizer<T>* stder = NULL;
- if (standardization) stder = new Standardizer(handle, X_simple, n_samples, mean_std_buff);
+ Standardizer<T>* std_obj = NULL;
+ if (standardization) std_obj = new Standardizer(handle, X_simple, n_samples, mean_std_buff);

ML::GLM::opg::qn_fit_x_mg(handle,
pams,
@@ -128,12 +128,12 @@
n_samples,
rank,
n_ranks,
-                             stder);  // ignore sample_weight, svr_eps
+                             std_obj);  // ignore sample_weight, svr_eps

if (standardization) {
int n_targets = ML::GLM::detail::qn_is_classification(pams.loss) && C == 2 ? 1 : C;
-   stder->adapt_model_for_linearFwd(handle, w0, n_targets, D, pams.fit_intercept);
-   delete stder;
+   std_obj->adapt_model_for_linearFwd(handle, w0, n_targets, D, pams.fit_intercept);
+   delete std_obj;
}

return;
@@ -238,12 +238,14 @@ void qnFitSparse_impl(const raft::handle_t& handle,
int rank,
int n_ranks)
{
- RAFT_EXPECTS(standardization == false, "standardization for sparse vectors is not supported yet");
-
auto X_simple = SimpleSparseMat<T>(X_values, X_cols, X_row_ids, X_nnz, N, D);

- rmm::device_uvector<T> mean_std_buff(4 * D, handle.get_stream());
- Standardizer<T>* stder = NULL;
+ size_t vec_size = raft::alignTo<size_t>(sizeof(T) * D, ML::GLM::detail::qn_align);
+ rmm::device_uvector<T> mean_std_buff(4 * vec_size, handle.get_stream());
+ Standardizer<T>* std_obj = NULL;
+
+ if (standardization)
+   std_obj = new Standardizer(handle, X_simple, n_samples, mean_std_buff, vec_size);

ML::GLM::opg::qn_fit_x_mg(handle,
pams,
@@ -256,12 +258,12 @@
n_samples,
rank,
n_ranks,
-                             stder);  // ignore sample_weight, svr_eps
+                             std_obj);  // ignore sample_weight, svr_eps

if (standardization) {
int n_targets = ML::GLM::detail::qn_is_classification(pams.loss) && C == 2 ? 1 : C;
-   stder->adapt_model_for_linearFwd(handle, w0, n_targets, D, pams.fit_intercept);
-   delete stder;
+   std_obj->adapt_model_for_linearFwd(handle, w0, n_targets, D, pams.fit_intercept);
+   delete std_obj;
}

return;
3 changes: 0 additions & 3 deletions python/cuml/linear_model/logistic_regression_mg.pyx
@@ -218,9 +218,6 @@ class LogisticRegressionMG(MGFitMixin, LogisticRegression):

sparse_input = isinstance(X, list)

-         if self.standardization:
-             assert not sparse_input, "standardization for sparse vectors is not supported yet"
-
if self.dtype == np.float32:
if sparse_input is False:
qnFit(