Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sample_weight to Coordinate Descent solver (Lasso and ElasticNet) #4867

Merged
merged 12 commits into from
Aug 31, 2022
56 changes: 53 additions & 3 deletions cpp/include/cuml/solvers/solver.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -101,6 +101,54 @@ void sgdPredictBinaryClass(raft::handle_t& handle,
double* preds,
int loss);

/**
* Fits a linear, lasso, and elastic-net regression model using Coordinate Descent solver.
*
* i.e. finds coefficients that minimize the following loss function:
*
* f(coef) = 1/2 * || labels - input * coef ||^2
* + 1/2 * alpha * (1 - l1_ratio) * ||coef||^2
* + alpha * l1_ratio * ||coef||_1
*
*
* @param handle
* Reference of raft::handle_t
* @param input
* pointer to an array in column-major format (size of n_rows, n_cols)
* @param n_rows
* n_samples or rows in input
* @param n_cols
* n_features or columns in X
* @param labels
* pointer to an array for labels (size of n_rows)
* @param coef
* pointer to an array for coefficients (size of n_cols). This will be filled with
* coefficients once the function is executed.
* @param intercept
* pointer to a scalar for intercept. This will be filled
* once the function is executed
* @param fit_intercept
* boolean parameter to control if the intercept will be fitted or not
* @param normalize
* boolean parameter to control if the data will be normalized or not;
* NB: the input is scaled by the column-wise biased sample standard deviation estimator.
* @param epochs
* Maximum number of iterations that solver will run
* @param loss
* enum to use different loss functions. Only linear regression loss functions is supported
* right now
* @param alpha
* L1 parameter
* @param l1_ratio
* ratio of alpha will be used for L1. (1 - l1_ratio) * alpha will be used for L2
* @param shuffle
* boolean parameter to control whether coordinates will be picked randomly or not
* @param tol
* tolerance to stop the solver
* @param sample_weight
* device pointer to sample weight vector of length n_rows (nullptr or uniform weights)
* This vector is modified during the computation
*/
void cdFit(raft::handle_t& handle,
float* input,
int n_rows,
Expand All @@ -115,7 +163,8 @@ void cdFit(raft::handle_t& handle,
float alpha,
float l1_ratio,
bool shuffle,
float tol);
float tol,
float* sample_weight = nullptr);

void cdFit(raft::handle_t& handle,
double* input,
Expand All @@ -131,7 +180,8 @@ void cdFit(raft::handle_t& handle,
double alpha,
double l1_ratio,
bool shuffle,
double tol);
double tol,
double* sample_weight = nullptr);

void cdPredict(raft::handle_t& handle,
const float* input,
Expand Down
77 changes: 58 additions & 19 deletions cpp/src/solver/cd.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,20 @@
#include <raft/common/nvtx.hpp>
#include <raft/cuda_utils.cuh>
#include <raft/cudart_utils.h>
#include <raft/linalg/add.hpp>
#include <raft/linalg/axpy.hpp>
#include <raft/linalg/eltwise.hpp>
#include <raft/linalg/gemm.hpp>
#include <raft/linalg/gemv.hpp>
#include <raft/linalg/multiply.hpp>
#include <raft/linalg/subtract.hpp>
#include <raft/linalg/unary_op.hpp>
#include <raft/matrix/math.hpp>
#include <raft/matrix/matrix.hpp>
#include <raft/linalg/add.cuh>
#include <raft/linalg/axpy.cuh>
#include <raft/linalg/eltwise.cuh>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/gemv.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/multiply.cuh>
#include <raft/linalg/power.cuh>
#include <raft/linalg/sqrt.cuh>
#include <raft/linalg/subtract.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/math.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/stats/sum.cuh>

namespace ML {
namespace Solver {
Expand Down Expand Up @@ -123,8 +127,9 @@ __global__ void __launch_bounds__(1, 1) cdUpdateCoefKernel(math_t* coefLoc,
* boolean parameter to control whether coordinates will be picked randomly or not
* @param tol
* tolerance to stop the solver
* @param stream
* cuda stream
* @param sample_weight
* device pointer to sample weight vector of length n_rows (nullptr or uniform weights)
* This vector is modified during the computation
*/
template <typename math_t>
void cdFit(const raft::handle_t& handle,
Expand All @@ -142,20 +147,30 @@ void cdFit(const raft::handle_t& handle,
math_t l1_ratio,
bool shuffle,
math_t tol,
cudaStream_t stream)
math_t* sample_weight = nullptr)
{
raft::common::nvtx::range fun_scope("ML::Solver::cdFit-%d-%d", n_rows, n_cols);
ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");
ASSERT(loss == ML::loss_funct::SQRD_LOSS,
"Parameter loss: Only SQRT_LOSS function is supported for now");

cudaStream_t stream = handle.get_stream();
rmm::device_uvector<math_t> residual(n_rows, stream);
rmm::device_uvector<math_t> squared(n_cols, stream);
rmm::device_uvector<math_t> mu_input(0, stream);
rmm::device_uvector<math_t> mu_labels(0, stream);
rmm::device_uvector<math_t> norm2_input(0, stream);
math_t h_sum_sw = 0;

if (sample_weight != nullptr) {
rmm::device_scalar<math_t> sum_sw(stream);
raft::stats::sum(sum_sw.data(), sample_weight, 1, n_rows, true, stream);
raft::update_host(&h_sum_sw, sum_sw.data(), 1, stream);

raft::linalg::multiplyScalar(
sample_weight, sample_weight, (math_t)n_rows / h_sum_sw, n_rows, stream);
}
if (fit_intercept) {
mu_input.resize(n_cols, stream);
mu_labels.resize(1, stream);
Expand All @@ -171,7 +186,20 @@ void cdFit(const raft::handle_t& handle,
mu_labels.data(),
norm2_input.data(),
fit_intercept,
normalize);
normalize,
sample_weight);
}
if (sample_weight != nullptr) {
raft::linalg::sqrt(sample_weight, sample_weight, n_rows, stream);
raft::matrix::matrixVectorBinaryMult(
input, sample_weight, n_rows, n_cols, false, false, stream);
raft::linalg::map(
labels,
n_rows,
[] __device__(math_t a, math_t b) { return a * b; },
stream,
labels,
sample_weight);
}

std::vector<int> ri(n_cols);
Expand Down Expand Up @@ -254,6 +282,20 @@ void cdFit(const raft::handle_t& handle,
if (h_convState.coefMax < tol || (h_convState.diffMax / h_convState.coefMax) < tol) break;
}

if (sample_weight != nullptr) {
raft::matrix::matrixVectorBinaryDivSkipZero(
input, sample_weight, n_rows, n_cols, false, false, stream);
raft::linalg::map(
labels,
n_rows,
[] __device__(math_t a, math_t b) { return a / b; },
stream,
labels,
sample_weight);
raft::linalg::powerScalar(sample_weight, sample_weight, (math_t)2, n_rows, stream);
raft::linalg::multiplyScalar(sample_weight, sample_weight, h_sum_sw / n_rows, n_rows, stream);
}

if (fit_intercept) {
GLM::postProcessData(handle,
input,
Expand Down Expand Up @@ -293,8 +335,6 @@ void cdFit(const raft::handle_t& handle,
* @param loss
* enum to use different loss functions. Only linear regression loss functions is supported
* right now.
* @param stream
* cuda stream
*/
template <typename math_t>
void cdPredict(const raft::handle_t& handle,
Expand All @@ -304,15 +344,14 @@ void cdPredict(const raft::handle_t& handle,
const math_t* coef,
math_t intercept,
math_t* preds,
ML::loss_funct loss,
cudaStream_t stream)
ML::loss_funct loss)
{
ASSERT(n_cols > 0, "Parameter n_cols: number of columns cannot be less than one");
ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two");
ASSERT(loss == ML::loss_funct::SQRD_LOSS,
"Parameter loss: Only SQRT_LOSS function is supported for now");

Functions::linearRegH(handle, input, n_rows, n_cols, coef, preds, intercept, stream);
Functions::linearRegH(handle, input, n_rows, n_cols, coef, preds, intercept, handle.get_stream());
}

}; // namespace Solver
Expand Down
14 changes: 8 additions & 6 deletions cpp/src/solver/solver.cu
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ void cdFit(raft::handle_t& handle,
float alpha,
float l1_ratio,
bool shuffle,
float tol)
float tol,
float* sample_weight = nullptr)
{
ASSERT(loss == 0, "Parameter loss: Only SQRT_LOSS function is supported for now");

Expand All @@ -318,7 +319,7 @@ void cdFit(raft::handle_t& handle,
l1_ratio,
shuffle,
tol,
handle.get_stream());
sample_weight);
}

void cdFit(raft::handle_t& handle,
Expand All @@ -335,7 +336,8 @@ void cdFit(raft::handle_t& handle,
double alpha,
double l1_ratio,
bool shuffle,
double tol)
double tol,
double* sample_weight = nullptr)
{
ASSERT(loss == 0, "Parameter loss: Only SQRT_LOSS function is supported for now");

Expand All @@ -356,7 +358,7 @@ void cdFit(raft::handle_t& handle,
l1_ratio,
shuffle,
tol,
handle.get_stream());
sample_weight);
}

void cdPredict(raft::handle_t& handle,
Expand All @@ -375,7 +377,7 @@ void cdPredict(raft::handle_t& handle,
ASSERT(false, "glm.cu: other functions are not supported yet.");
}

cdPredict(handle, input, n_rows, n_cols, coef, intercept, preds, loss_funct, handle.get_stream());
cdPredict(handle, input, n_rows, n_cols, coef, intercept, preds, loss_funct);
}

void cdPredict(raft::handle_t& handle,
Expand All @@ -394,7 +396,7 @@ void cdPredict(raft::handle_t& handle,
ASSERT(false, "glm.cu: other functions are not supported yet.");
}

cdPredict(handle, input, n_rows, n_cols, coef, intercept, preds, loss_funct, handle.get_stream());
cdPredict(handle, input, n_rows, n_cols, coef, intercept, preds, loss_funct);
}

} // namespace Solver
Expand Down
Loading