CP Block Coordinate Descent #157

Open
wants to merge 26 commits into base: master

Changes from all commits (26 commits)
f859521
Start making a block coordinate diagonal decomposition
kmp5VT May 3, 2023
f3c84a2
Merge branch 'kmp5/debug/fix_cp_test' into kmp5/feature/cp-bcd
kmp5VT May 3, 2023
bc235f0
Use target instead of assuming tensor_ref
kmp5VT May 3, 2023
4079f1a
Start fixing BCD
kmp5VT May 3, 2023
f4a5ab8
Fix cp_bcd
kmp5VT May 4, 2023
1f2f05a
Add function to change norm of converge_test
kmp5VT May 4, 2023
62da071
Update cp_bcd.h
kmp5VT May 8, 2023
df4938b
BCD now completely working!
kmp5VT May 8, 2023
5bb3bc8
BCD wasn't calculating gradient correctly
kmp5VT May 18, 2023
d66b529
Add BCD unit tests
kmp5VT May 18, 2023
d56085e
Allow more than 1 sweep of block coordinates, this allows each block …
kmp5VT May 25, 2023
c79a7cc
Fix for complex tensors
kmp5VT May 25, 2023
132eecb
one is const, one is not
kmp5VT May 26, 2023
622e258
Merge branch 'master' into kmp5/feature/cp-bcd
kmp5VT May 26, 2023
2b3a4a0
create `compute_full` fit function
kmp5VT Jul 18, 2023
0bd5473
Merge branch 'master' into kmp5/feature/cp-bcd
kmp5VT Jul 24, 2023
dbd76aa
Add ability to use different block sizes
kmp5VT Aug 17, 2023
4008a5e
Merge branch 'master' into kmp5/feature/cp-bcd
kmp5VT Aug 17, 2023
4a1304c
update BCD printing
kmp5VT Aug 22, 2023
562f2c2
Merge branch 'master' into kmp5/feature/cp-bcd
kmp5VT Dec 22, 2023
8bcee70
Goal here is to try make new CP
kmp5VT Jan 10, 2024
d74188f
Merge branch 'kmp5/debug/fix_flatten' into kmp5/feature/cp-bcd
kmp5VT Jan 10, 2024
fdb1b8d
Fix testing area
kmp5VT Jan 24, 2024
b9490ec
Merge remote-tracking branch 'origin/master' into kmp5/feature/cp-bcd
kmp5VT Jan 24, 2024
0614c06
Merge branch 'master' into kmp5/feature/cp-bcd
kmp5VT Jan 30, 2024
8007360
Bump VG tag
kmp5VT Jan 30, 2024
2 changes: 2 additions & 0 deletions btas/btas.h
@@ -9,9 +9,11 @@

#include <btas/generic/cp_als.h>
#include <btas/generic/cp_rals.h>
//#include <btas/generic/cp_id.h>
#include <btas/generic/cp_df_als.h>
#include <btas/generic/tuck_cp_als.h>
#include <btas/generic/coupled_cp_als.h>
#include <btas/generic/cp_bcd.h>
#include <btas/generic/dot_impl.h>
#include <btas/generic/scal_impl.h>
#include <btas/generic/axpy_impl.h>
2 changes: 1 addition & 1 deletion btas/generic/cp.h
@@ -571,7 +571,7 @@ namespace btas {
/// \param[in] Mat the matrix whose 2-norm is computed
/// \return the 2-norm.

auto norm(const Tensor &Mat) { return sqrt(abs(dot(Mat, Mat))); }
auto norm(const Tensor &Mat) { return sqrt(abs(dot(Mat, btas::impl::conj(Mat)))); }
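The conjugation added here matters for complex-valued tensors: with an unconjugated dot product, dot(Mat, Mat) sums Mat_k * Mat_k, whose magnitude is not in general the squared 2-norm, whereas summing Mat_k * conj(Mat_k) recovers it. A minimal standalone illustration with plain std::complex rather than the BTAS types (everything below is illustrative only, not part of the change):

#include <cmath>
#include <complex>
#include <vector>

int main() {
  // Two-element complex "tensor" z = (1, i)
  std::vector<std::complex<double>> z{{1.0, 0.0}, {0.0, 1.0}};
  std::complex<double> plain{0.0, 0.0}, hermitian{0.0, 0.0};
  for (const auto &x : z) {
    plain += x * x;                 // unconjugated: 1 + i*i = 0
    hermitian += x * std::conj(x);  // conjugated:   1 + 1   = 2
  }
  double wrong = std::sqrt(std::abs(plain));      // 0, not the 2-norm
  double right = std::sqrt(std::abs(hermitian));  // sqrt(2), the true 2-norm
  return (right > wrong) ? 0 : 1;
}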

/// SVD referencing code from
/// http://www.netlib.org/lapack/explore-html/de/ddd/lapacke_8h_af31b3cb47f7cc3b9f6541303a2968c9f.html
78 changes: 77 additions & 1 deletion btas/generic/cp_als.h
@@ -339,6 +339,7 @@ namespace btas {
Tensor &tensor_ref; // Tensor to be decomposed
ord_t size; // Total number of elements
bool factors_set = false; // Are the factors preset (not implemented yet).
std::vector<std::vector<int64_t>> pivs; // Cached column-pivot orderings from pivoted QR, one per mode

/// Creates an initial guess by computing the SVD of each mode
/// If the rank of the mode is smaller than the CP rank requested
@@ -658,7 +659,82 @@ namespace btas {
double lambda = 0) {
Tensor temp(A[n].extent(0), rank);
Tensor an(A[n].range());

// Experimental (currently disabled): test whether a pivoted-QR column sampling of the
// flattened tensor can help the factor update; see the standalone sketch after this diff
if (false) {
// First compute a column-pivot ordering of the flattened tensor_ref via pivoted QR
auto f = flatten(tensor_ref, n);
auto square_dim = f.extent(0), full = f.extent(1);
auto scale_factor = double(full) / double(square_dim);
auto extended = (scale_factor > 2.0 ? square_dim * (scale_factor - 1.0) : full);
std::vector<int64_t> piv;
if (pivs.size() < (n + 1)) {
piv = std::vector<int64_t>(f.extent(1));
std::vector<T> tau(f.extent(1));
btas::geqp3_pivot(blas::Layout::RowMajor, f.extent(0), f.extent(1), f.data(), f.extent(1), piv.data(),
tau.data());
Tensor R(full, square_dim);
R.fill(0.0);
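// Re-flatten: geqp3_pivot factors its input in place, so recover the original matricization of tensor_ref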
f = flatten(tensor_ref, n);
pivs.emplace_back(piv);
} else {
piv = pivs[n];
}

auto K = this->generate_KRP(n, rank, true);
// For reference, compute the matricized tensor times Khatri-Rao product (MtKRP) for the convergence test
gemm(blas::Op::NoTrans, blas::Op::NoTrans, this->one, f, K, this->zero, temp);
detail::set_MtKRP(converge_test, temp);

Tensor Fp(square_dim, square_dim);
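// The scope below gathers the pivot-ordered columns of f and keeps the leading square_dim x square_dim block in Fp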
{
Tensor t(square_dim, full);
for (auto i = 0; i < square_dim; ++i) {
for (auto j = 0; j < full; ++j) {
int v = pivs[n][j];
t(i, j) = f(i, v);
}
}

for (auto i = 0; i < square_dim; ++i) {
for (auto j = 0; j < square_dim; ++j) {
Fp(i, j) = t(i, j);
}
}
}

Tensor Kp(square_dim, rank);
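// The scope below gathers the matching pivot rows of the Khatri-Rao product into Kp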
{
Tensor t(full, rank);
for (auto j = 0; j < full; ++j) {
int v = pivs[n][j];
for (auto r = 0; r < rank; ++r) {
t(j, r) = K(v, r);
}
}

for (auto j = 0; j < square_dim; ++j) {
for (auto r = 0; r < rank; ++r) {
Kp(j, r) = t(j, r);
}
}
}

// Contract the pivot-sampled flattened tensor with the pseudoinverse of the
// pivot-sampled Khatri-Rao product to produce an optimized factor matrix
fast_pI = false;
auto pInv = pseudoInverse(Kp, fast_pI);

gemm(blas::Op::NoTrans, blas::Op::NoTrans, this->one, Fp, pInv, this->zero, temp);

// Normalize the columns of the new factor matrix
this->normCol(temp);

// Replace the old factor matrix with the new optimized result
A[n] = temp;
return;
}
#ifdef BTAS_HAS_INTEL_MKL

// Computes the Khatri-Rao product intermediate
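For orientation on the experiment in the cp_als.h hunk above: rather than solving the full least-squares update over all `full` columns of the flattened tensor, only the square_dim columns selected by the QR pivot are kept, together with the matching rows of the Khatri-Rao product, and the reduced system Fp ≈ A[n] * Kp^T is then solved through a pseudoinverse. Below is a self-contained sketch of that sampling step using a plain row-major matrix instead of btas::Tensor; Mat and sample_by_pivot are illustrative helpers, not BTAS API.

#include <cstdint>
#include <vector>

// Minimal dense row-major matrix, purely for illustration.
struct Mat {
  int64_t rows, cols;
  std::vector<double> v;
  Mat(int64_t r, int64_t c) : rows(r), cols(c), v(r * c, 0.0) {}
  double &operator()(int64_t i, int64_t j) { return v[i * cols + j]; }
  double operator()(int64_t i, int64_t j) const { return v[i * cols + j]; }
};

// Keep the pivot-selected columns of F (square_dim x full) and the matching pivot
// rows of K (full x rank); the factor update then solves Fp ≈ A[n] * Kp^T.
void sample_by_pivot(const Mat &F, const Mat &K, const std::vector<int64_t> &piv,
                     Mat &Fp, Mat &Kp) {
  for (int64_t j = 0; j < Fp.cols; ++j) {
    const int64_t p = piv[j];  // column of F / row of K chosen by pivoted QR
    for (int64_t i = 0; i < Fp.rows; ++i) Fp(i, j) = F(i, p);
    for (int64_t r = 0; r < Kp.cols; ++r) Kp(j, r) = K(p, r);
  }
}

int main() {
  const int64_t square_dim = 2, full = 4, rank = 3;
  Mat F(square_dim, full), K(full, rank), Fp(square_dim, square_dim), Kp(square_dim, rank);
  std::vector<int64_t> piv{2, 0, 3, 1};  // e.g. an ordering returned by pivoted QR
  sample_by_pivot(F, K, piv, Fp, Kp);    // Fp and Kp now hold the sampled blocks
  return 0;
}

Compared with the loops in the diff, this gathers the sampled entries directly and skips the intermediate full-width temporaries (the square_dim x full and full x rank copies).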