Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[ML-214] [GPU] use distributed covariance as the first step for PCA #234

Merged
merged 24 commits into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions mllib-dal/src/main/native/PCAImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#endif
#include "Communicator.hpp"
#include "OutputHelpers.hpp"
#include "oneapi/dal/algo/covariance.hpp"
#include "oneapi/dal/algo/pca.hpp"
#include "oneapi/dal/table/homogen.hpp"
#endif
Expand Down Expand Up @@ -192,30 +193,36 @@ static void doPCAOneAPICompute(JNIEnv *env, jint rankId, jlong pNumTabData,
const bool isRoot = (rankId == ccl_root);
homogen_table htable =
*reinterpret_cast<const homogen_table *>(pNumTabData);
const auto pca_desc = pca::descriptor{};

const auto cov_desc = covariance::descriptor{}.set_result_options(
covariance::result_options::cov_matrix);
auto queue = getQueue(device);
auto comm = preview::spmd::make_communicator<preview::spmd::backend::ccl>(
queue, executorNum, rankId, ipPort);
pca::train_input local_input{htable};

auto t1 = std::chrono::high_resolution_clock::now();
const auto result_train = preview::train(comm, pca_desc, local_input);
const auto result = preview::compute(comm, cov_desc, htable);
auto t2 = std::chrono::high_resolution_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count();
std::cout << "PCA (native): training step took " << duration / 1000
std::cout << "PCA (native): Covariance step took " << duration / 1000
<< " secs" << std::endl;
if (isRoot) {
std::cout << "Eigenvectors:\n"
<< result_train.get_eigenvectors() << std::endl;
std::cout << "Eigenvalues:\n"
<< result_train.get_eigenvalues() << std::endl;
using float_t = double;
using method_t = pca::method::precomputed;
using task_t = pca::task::dim_reduction;
using descriptor_t = pca::descriptor<float_t, method_t, task_t>;
const auto pca_desc = descriptor_t().set_deterministic(true);

auto t1 = std::chrono::high_resolution_clock::now();
const auto result_train =
train(queue, pca_desc, result.get_cov_matrix());
t2 = std::chrono::high_resolution_clock::now();
duration =
std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1)
.count();
std::cout << "PCA (native): rankid " << rankId
<< "; training step took " << duration / 1000
<< " secs in end. " << std::endl;
std::cout << "PCA (native): rankid " << rankId << "; Eigen step took "
<< duration / 1000 << " secs in end. " << std::endl;
// Return all eigenvalues & eigenvectors
// Get the class of the input object
jclass clazz = env->GetObjectClass(resultObj);
Expand All @@ -224,6 +231,10 @@ static void doPCAOneAPICompute(JNIEnv *env, jint rankId, jlong pNumTabData,
env->GetFieldID(clazz, "pcNumericTable", "J");
jfieldID explainedVarianceNumericTableField =
env->GetFieldID(clazz, "explainedVarianceNumericTable", "J");
std::cout << "Eigenvectors:\n"
<< result_train.get_eigenvectors() << std::endl;
std::cout << "Eigenvalues:\n"
<< result_train.get_eigenvalues() << std::endl;

HomogenTablePtr eigenvectors =
std::make_shared<homogen_table>(result_train.get_eigenvectors());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ class PCADALImpl(val k: Int,
}

private[mllib] def getPrincipleComponentsFromOneAPI(table: HomogenTable,
k: Int,
device: Common.ComputeDevice): DenseMatrix = {
k: Int,
device: Common.ComputeDevice): DenseMatrix = {
val numRows = table.getRowCount.toInt
val numCols = table.getColumnCount.toInt
require(k <= numRows, "k should be less or equal to row number")
Expand Down