Add multi-node-multi-gpu Logistic Regression in C++ (#5477)
This PR enables multi-node multi-GPU (MNMG) logistic regression in C++. It mostly reuses the existing code of single-GPU logistic regression (i.e. GLMWithData and min_lbfgs) and does not modify any existing code.

Pytest code for a Spark cluster was added, and the tests run successfully with 2 GPUs on a random dataset. The resulting coef_ and intercept_ match those of single-GPU cuml.LogisticRegression.fit. The pytest code can be found here: https://github.com/lijinf2/spark-rapids-ml/blob/lr/python/tests/test_logistic_regression.py

Authors:
  - Jinfeng Li (https://github.com/lijinf2)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #5477
lijinf2 authored Jul 24, 2023
1 parent 1b3ada9 commit e23167c
Showing 11 changed files with 732 additions and 6 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
@@ -524,6 +524,7 @@ if(BUILD_CUML_CPP_LIBRARY)
src/glm/ols_mg.cu
src/glm/preprocess_mg.cu
src/glm/ridge_mg.cu
src/glm/qn_mg.cu
src/kmeans/kmeans_mg.cu
src/knn/knn_mg.cu
src/knn/knn_classify_mg.cu
56 changes: 56 additions & 0 deletions cpp/include/cuml/linear_model/qn_mg.hpp
@@ -0,0 +1,56 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuda_runtime.h>
#include <cuml/common/logger.hpp>
#include <cuml/linear_model/qn.h>
#include <raft/core/comms.hpp>

#include <cumlprims/opg/matrix/data.hpp>
#include <cumlprims/opg/matrix/part_descriptor.hpp>
using namespace MLCommon;

namespace ML {
namespace GLM {
namespace opg {

/**
* @brief Performs the MNMG fit operation for logistic regression using quasi-Newton methods
* @param[in] handle: the internal cuml handle object
* @param[in] input_data: vector holding all partitions for that rank
* @param[in] input_desc: PartDescriptor object for the input
* @param[in] labels: labels data
* @param[out] coef: learned coefficients
* @param[in] pams: model parameters
* @param[in] X_col_major: true if X is stored column-major
* @param[in] n_classes: number of outputs (number of classes or `1` for regression)
* @param[out] f: host pointer holding the final objective value
* @param[out] num_iters: host pointer holding the actual number of iterations taken
*/
void qnFit(raft::handle_t& handle,
std::vector<Matrix::Data<float>*>& input_data,
Matrix::PartDescriptor& input_desc,
std::vector<Matrix::Data<float>*>& labels,
float* coef,
const qn_params& pams,
bool X_col_major,
int n_classes,
float* f,
int* num_iters);

}; // namespace opg
}; // namespace GLM
}; // namespace ML
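
For orientation, here is a minimal usage sketch of the new entry point declared above. It is not part of the commit: the function name fit_logistic_mg is hypothetical, the handle is assumed to carry an already-initialized communicator, input_data/input_desc/labels are assumed to describe this rank's partitions (e.g. built with the cumlprims OPG helpers), and the exact namespace qualification of qn_params may differ; only the qn_params fields exercised by this PR are set.

// Hypothetical usage sketch (not part of this commit). Assumes `handle` carries an
// initialized communicator and that `input_data`, `input_desc`, `labels` describe this
// rank's partitions.
#include <cuml/linear_model/qn.h>
#include <cuml/linear_model/qn_mg.hpp>
#include <raft/core/handle.hpp>
#include <vector>

void fit_logistic_mg(raft::handle_t& handle,
                     std::vector<MLCommon::Matrix::Data<float>*>& input_data,
                     MLCommon::Matrix::PartDescriptor& input_desc,
                     std::vector<MLCommon::Matrix::Data<float>*>& labels,
                     float* coef)  // device buffer of size D + fit_intercept
{
  ML::GLM::qn_params pams;                               // namespace qualification assumed
  pams.loss               = ML::GLM::QN_LOSS_LOGISTIC;   // only loss supported by this PR
  pams.fit_intercept      = true;
  pams.penalty_l2         = 1.0;
  pams.penalty_normalized = true;                        // divide l2 by the global sample count

  float objective = 0;
  int num_iters   = 0;
  ML::GLM::opg::qnFit(handle, input_data, input_desc, labels, coef, pams,
                      /*X_col_major=*/false, /*n_classes=*/2, &objective, &num_iters);
}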
166 changes: 166 additions & 0 deletions cpp/src/glm/qn/glm_base_mg.cuh
@@ -0,0 +1,166 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/core/comms.hpp>
#include <raft/core/handle.hpp>
#include <raft/linalg/multiply.cuh>
#include <raft/util/cudart_utils.hpp>

#include <glm/qn/glm_base.cuh>
#include <glm/qn/glm_logistic.cuh>
#include <glm/qn/glm_regularizer.cuh>
#include <glm/qn/qn_solvers.cuh>
#include <glm/qn/qn_util.cuh>

namespace ML {
namespace GLM {
namespace opg {
// multi-GPU version of linearBwd
template <typename T>
inline void linearBwdMG(const raft::handle_t& handle,
SimpleDenseMat<T>& G,
const SimpleMat<T>& X,
const SimpleDenseMat<T>& dZ,
bool setZero,
const int64_t n_samples,
const int n_ranks)
{
cudaStream_t stream = handle.get_stream();
// Backward pass:
// - compute G <- dZ * X.T
// - for bias: Gb = mean(dZ, 1)

const bool has_bias = X.n != G.n;
const int D = X.n;
const T beta = setZero ? T(0) : T(1);

if (has_bias) {
SimpleVec<T> Gbias;
SimpleDenseMat<T> Gweights;

col_ref(G, Gbias, D);

col_slice(G, Gweights, 0, D);

// TODO can this be fused somehow?
Gweights.assign_gemm(handle, 1.0 / n_samples, dZ, false, X, false, beta / n_ranks, stream);

raft::stats::mean(Gbias.data, dZ.data, dZ.m, dZ.n, false, true, stream);
T bias_factor = 1.0 * dZ.n / n_samples;
raft::linalg::multiplyScalar(Gbias.data, Gbias.data, bias_factor, dZ.m, stream);

} else {
CUML_LOG_DEBUG("has bias not enabled");
G.assign_gemm(handle, 1.0 / n_samples, dZ, false, X, false, beta / n_ranks, stream);
}
}

/**
* @brief Aggregates local gradient vectors and loss values from local training data. This
* class is the multi-node-multi-GPU version of GLMWithData.
*
* The implementation overrides the existing GLMWithData::operator(). Its purpose is to
* aggregate local gradient vectors and loss values from the distributed X and y, where X
* holds the input vectors and y holds the labels.
*
* GLMWithData::operator() currently invokes three functions: linearFwd, getLossAndDz and
* linearBwd. linearFwd multiplies the local input vectors with the coefficient vector (i.e.
* coef_) and therefore requires no communication. getLossAndDz computes the local loss, so an
* allreduce is required to obtain the global loss. linearBwd computes the local gradient
* vector, so an allreduce is required to obtain the global gradient vector. The global loss
* and the global gradient vector are then used by min_lbfgs to update the coefficients. The
* update runs independently on every GPU, and when it finishes all GPUs hold the same
* coefficient values.
*/
template <typename T, class GLMObjective>
struct GLMWithDataMG : ML::GLM::detail::GLMWithData<T, GLMObjective> {
const raft::handle_t* handle_p;
int rank;
int64_t n_samples;
int n_ranks;

GLMWithDataMG(raft::handle_t const& handle,
int rank,
int n_ranks,
int64_t n_samples,
GLMObjective* obj,
const SimpleMat<T>& X,
const SimpleVec<T>& y,
SimpleDenseMat<T>& Z)
: ML::GLM::detail::GLMWithData<T, GLMObjective>(obj, X, y, Z)
{
this->handle_p = &handle;
this->rank = rank;
this->n_ranks = n_ranks;
this->n_samples = n_samples;
}

inline T operator()(const SimpleVec<T>& wFlat,
SimpleVec<T>& gradFlat,
T* dev_scalar,
cudaStream_t stream)
{
SimpleDenseMat<T> W(wFlat.data, this->C, this->dims);
SimpleDenseMat<T> G(gradFlat.data, this->C, this->dims);
SimpleVec<T> lossVal(dev_scalar, 1);

// apply regularization
auto regularizer_obj = this->objective;
auto lossFunc = regularizer_obj->loss;
auto reg = regularizer_obj->reg;
G.fill(0, stream);
reg->reg_grad(dev_scalar, G, W, lossFunc->fit_intercept, stream);
T reg_host;
raft::update_host(&reg_host, dev_scalar, 1, stream);
// note: avoid syncing here because there's a sync before reg_host is used.

// apply linearFwd, getLossAndDz, linearBwd
ML::GLM::detail::linearFwd(
lossFunc->handle, *(this->Z), *(this->X), W); // linear part: forward pass

raft::comms::comms_t const& communicator = raft::resource::get_comms(*(this->handle_p));

lossFunc->getLossAndDZ(dev_scalar, *(this->Z), *(this->y), stream); // loss specific part

// normalize local loss before allreduce sum
T factor = 1.0 * (*this->y).len / this->n_samples;
raft::linalg::multiplyScalar(dev_scalar, dev_scalar, factor, 1, stream);

communicator.allreduce(dev_scalar, dev_scalar, 1, raft::comms::op_t::SUM, stream);
communicator.sync_stream(stream);

linearBwdMG(lossFunc->handle,
G,
*(this->X),
*(this->Z),
false,
n_samples,
n_ranks); // linear part: backward pass

communicator.allreduce(G.data, G.data, this->C * this->dims, raft::comms::op_t::SUM, stream);
communicator.sync_stream(stream);

T loss_host;
raft::update_host(&loss_host, dev_scalar, 1, stream);
raft::resource::sync_stream(*(this->handle_p));
loss_host += reg_host;
lossVal.fill(loss_host, stream);

return loss_host;
}
};
}; // namespace opg
}; // namespace GLM
}; // namespace ML
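
To make the loss normalization in operator() above concrete: each rank scales its local mean loss by factor = local_sample_count / n_samples before the allreduce, so the summed value equals the global mean loss a single GPU would compute; the gradient GEMM applies the analogous 1 / n_samples scaling. Below is a small, self-contained illustration of that arithmetic only (plain host C++ with made-up numbers, not part of the commit).

// Illustration of the per-rank loss scaling only; not part of the commit.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
  const double n_samples = 6.0;                       // global number of training samples
  std::vector<double> local_mean_loss = {0.9, 0.3};   // mean loss computed on each rank
  std::vector<double> local_len       = {2.0, 4.0};   // samples owned by each rank

  // Each rank contributes factor * local_mean_loss, with factor = local_len / n_samples;
  // an allreduce(SUM) over these contributions yields the global mean loss.
  double global_loss = 0.0;
  for (size_t r = 0; r < local_mean_loss.size(); ++r) {
    global_loss += local_mean_loss[r] * (local_len[r] / n_samples);
  }

  // Matches the mean over all 6 samples: (2 * 0.9 + 4 * 0.3) / 6 = 0.5
  assert(std::abs(global_loss - 0.5) < 1e-12);
  std::printf("global mean loss = %f\n", global_loss);
  return 0;
}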
157 changes: 157 additions & 0 deletions cpp/src/glm/qn_mg.cu
@@ -0,0 +1,157 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "qn/glm_logistic.cuh"
#include "qn/glm_regularizer.cuh"
#include "qn/qn_util.cuh"
#include "qn/simple_mat/dense.hpp"
#include <cuml/common/logger.hpp>
#include <cuml/linear_model/qn.h>
#include <cuml/linear_model/qn_mg.hpp>
#include <raft/core/comms.hpp>
#include <raft/core/error.hpp>
#include <raft/core/handle.hpp>
#include <raft/util/cudart_utils.hpp>
using namespace MLCommon;

#include "qn/glm_base_mg.cuh"

#include <cuda_runtime.h>

namespace ML {
namespace GLM {
namespace opg {

template <typename T>
void qnFit_impl(const raft::handle_t& handle,
const qn_params& pams,
T* X,
bool X_col_major,
T* y,
size_t N,
size_t D,
size_t C,
T* w0,
T* f,
int* num_iters,
size_t n_samples,
int rank,
int n_ranks)
{
switch (pams.loss) {
case QN_LOSS_LOGISTIC: {
RAFT_EXPECTS(
C == 2,
"qn_mg.cu: only the LOGISTIC loss is supported currently. The number of classes must be 2");
} break;
default: {
RAFT_EXPECTS(false, "qn_mg.cu: unknown loss function type (id = %d).", pams.loss);
}
}

cudaStream_t stream = raft::resource::get_cuda_stream(handle);
auto X_simple = SimpleDenseMat<T>(X, N, D, X_col_major ? COL_MAJOR : ROW_MAJOR);
auto y_simple = SimpleVec<T>(y, N);
SimpleVec<T> coef_simple(w0, D + pams.fit_intercept);

ML::GLM::detail::LBFGSParam<T> opt_param(pams);

// prepare regularizer regularizer_obj
ML::GLM::detail::LogisticLoss<T> loss_func(handle, D, pams.fit_intercept);
T l2 = pams.penalty_l2;
if (pams.penalty_normalized) {
l2 /= n_samples; // normalize the penalty by the global number of samples
}
ML::GLM::detail::Tikhonov<T> reg(l2);
ML::GLM::detail::RegularizedGLM<T, ML::GLM::detail::LogisticLoss<T>, decltype(reg)>
regularizer_obj(&loss_func, &reg);

// prepare GLMWithDataMG
int n_targets = C == 2 ? 1 : C;
rmm::device_uvector<T> tmp(n_targets * N, stream);
SimpleDenseMat<T> Z(tmp.data(), n_targets, N);
auto obj_function =
GLMWithDataMG(handle, rank, n_ranks, n_samples, &regularizer_obj, X_simple, y_simple, Z);

// prepare temporary variables fx, k, workspace
T fx = -1;
int k = -1;
rmm::device_uvector<T> tmp_workspace(lbfgs_workspace_size(opt_param, coef_simple.len), stream);
SimpleVec<T> workspace(tmp_workspace.data(), tmp_workspace.size());

// call min_lbfgs
min_lbfgs(opt_param, obj_function, coef_simple, fx, &k, workspace, stream, 5);

// report the final objective value and iteration count through the documented output pointers
if (f != nullptr) { *f = fx; }
if (num_iters != nullptr) { *num_iters = k; }
}

template <typename T>
void qnFit_impl(raft::handle_t& handle,
std::vector<Matrix::Data<T>*>& input_data,
Matrix::PartDescriptor& input_desc,
std::vector<Matrix::Data<T>*>& labels,
T* coef,
const qn_params& pams,
bool X_col_major,
int n_classes,
T* f,
int* num_iters)
{
RAFT_EXPECTS(input_data.size() == 1,
"qn_mg.cu currently does not accept more than one input matrix");
RAFT_EXPECTS(labels.size() == input_data.size(), "labels size does not equal input_data size");

auto data_X = input_data[0];
auto data_y = labels[0];

size_t n_samples = 0;
for (auto p : input_desc.partsToRanks) {
n_samples += p->size;
}

qnFit_impl<T>(handle,
pams,
data_X->ptr,
X_col_major,
data_y->ptr,
input_desc.totalElementsOwnedBy(input_desc.rank),
input_desc.N,
n_classes,
coef,
f,
num_iters,
input_desc.M,
input_desc.rank,
input_desc.uniqueRanks().size());
}

void qnFit(raft::handle_t& handle,
std::vector<Matrix::Data<float>*>& input_data,
Matrix::PartDescriptor& input_desc,
std::vector<Matrix::Data<float>*>& labels,
float* coef,
const qn_params& pams,
bool X_col_major,
int n_classes,
float* f,
int* num_iters)
{
qnFit_impl<float>(
handle, input_data, input_desc, labels, coef, pams, X_col_major, n_classes, f, num_iters);
}

}; // namespace opg
}; // namespace GLM
}; // namespace ML