Skip to content

Commit

Permalink
BIG, removes Eigen-test dep. && review
Browse files Browse the repository at this point in the history
  • Loading branch information
aryamanjeendgar committed Nov 4, 2024
1 parent 6da1463 commit 6ade616
Show file tree
Hide file tree
Showing 2 changed files with 232 additions and 292 deletions.
50 changes: 21 additions & 29 deletions RandBLAS/trig_skops.hh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "RandBLAS/base.hh"
#include "RandBLAS/exceptions.hh"
#include "RandBLAS/random_gen.hh"
#include <RandBLAS/sparse_skops.hh>
#include "RandBLAS/sparse_skops.hh"

#include <Random123/philox.h>
#include <blas.hh>
Expand All @@ -27,13 +27,13 @@ namespace RandBLAS {
///

// Generates a vector of Rademacher entries using the Random123 library
template<SignedInteger sint_t = int64_t, typename RNG = DefaultRNG>
RNGState<RNG> generate_rademacher_vector_r123(sint_t* buff, int64_t n, RNGState<RNG> seed_state) {
RNG rng;
template<SignedInteger sint_t = int64_t, typename state_t = RNGState<DefaultRNG>>
state_t generate_rademacher_vector_r123(sint_t* buff, int64_t n, state_t &seed_state) {
DefaultRNG rng;
auto [ctr, key] = seed_state;

for (int64_t i = 0; i < n; ++i) {
typename RNG::ctr_type r = rng(ctr, key);
typename DefaultRNG::ctr_type r = rng(ctr, key);

float rand_value = r123::u01fixedpt<float>(r.v[0]);

Expand All @@ -43,7 +43,7 @@ namespace RandBLAS {
}

// Return the updated RNGState (with the incremented counter)
return RNGState<RNG> {ctr, key};
return state_t {ctr, key};
}

// Catch-all method for applying the diagonal Rademacher
Expand Down Expand Up @@ -107,19 +107,15 @@ namespace RandBLAS {
// Use BLAS swap to swap the entire rows
// Swapping row 'selected' with row 'top'
blas::swap(cols, &A[top], rows, &A[selected_rows[i]], rows);
}
else
continue;
} // else, continue;
}
}
else {
// For `RowMajor` ordering
for (int64_t i=0; i < d; i++) {
if (selected_rows[i] != top) {
blas::swap(cols, &A[cols * selected_rows[i]], 1, &A[cols * top], 1);
}
else
continue;
} // else, continue;
}
}
}
Expand All @@ -142,8 +138,6 @@ namespace RandBLAS {
// Swapping col 'selected' with col 'top'
blas::swap(rows, &A[rows * selected_cols[i]], 1, &A[rows * left], 1);
}
else
continue;
}
}
else {
Expand All @@ -152,8 +146,6 @@ namespace RandBLAS {
if (selected_cols[i] != left) {
blas::swap(rows, &A[selected_cols[i]], cols, &A[left], cols);
}
else
continue;
}
}
}
Expand Down Expand Up @@ -208,8 +200,8 @@ namespace RandBLAS {
for (int64_t k = 0; k < s1; ++k) {
// For implicitly padding the input we just have to make sure
// we replace all out-of-bounds accesses with zeros
bool b1 = j + k < num_rows;
bool b2 = j + k + s1 < num_rows;
bool b1 = (j + k) * num_cols + col < num_rows * num_cols;
bool b2 = (j + k + s1) * num_cols + col < num_rows * num_cols;
T u = b1 ? buf[(j + k) * num_cols + col] : 0;
T v = b2 ? buf[(j + k + s1) * num_cols + col] : 0;
if(b1 && b2) {
Expand Down Expand Up @@ -277,8 +269,8 @@ namespace RandBLAS {
for (int64_t k = 0; k < s1; ++k) {
// For implicitly padding the input we just have to make sure
// we replace all out-of-bounds accesses with zeros
bool b1 = j + k < num_cols;
bool b2 = j + k + s1 < num_cols;
bool b1 = (j + k) * num_rows + row < num_cols * num_rows;
bool b2 = (j + k + s1) * num_rows + row < num_cols * num_rows;
T u = b1 ? buf[(j + k) * num_rows + row] : 0;
T v = b2 ? buf[(j + k + s1) * num_rows + row] : 0;
if(b1 && b2) {
Expand Down Expand Up @@ -346,8 +338,8 @@ struct HadamardMixingOp{

// Destructor
~HadamardMixingOp() {
free(this->diag_scale);
free(this->selected_idxs);
delete [] this->diag_scale;
delete [] this->selected_idxs;
}

private:
Expand All @@ -358,7 +350,7 @@ struct HadamardMixingOp{
* the inversion is also performed in-place
*/
template <typename T, SignedInteger sint_t = int64_t>
void invert(
void invert_hadamard(
HadamardMixingOp<sint_t> &hmo, // details about the transform
T* SA // sketched matrix
) {
Expand Down Expand Up @@ -406,17 +398,17 @@ void invert(
* A: (m x n), input dimensions of `A`
* d: The number of rows/columns that will be permuted by the action of $\Pi$
*/
template <typename T, typename RNG = DefaultRNG, SignedInteger sint_t = int64_t>
inline RNGState<RNG> miget(
template <typename T, typename state_t = RNGState<DefaultRNG>, SignedInteger sint_t = int64_t>
inline state_t miget(
HadamardMixingOp<sint_t> &hmo, // All information about `A` && the $\mathbb{\Pi\text{RHT}}$
const RNGState<RNG> &random_state,
const state_t &random_state,
T* A // The data-matrix
) {
auto [ctr, key] = random_state;

//Step 1: Scale with `D`
//Populating `diag`
RNGState<RNG> state_idxs = generate_rademacher_vector_r123(hmo.diag_scale, hmo.n, random_state);
auto next_state = generate_rademacher_vector_r123(hmo.diag_scale, hmo.n, random_state);
apply_diagonal_rademacher(hmo.left, hmo.layout, hmo.m, hmo.n, A, hmo.diag_scale);

//Step 2: Apply the Hadamard transform
Expand All @@ -433,12 +425,12 @@ inline RNGState<RNG> miget(

//Step 3: Permute the rows
// Uniformly samples `d` entries from the index set [0, ..., m - 1]
RNGState<RNG> next_state = repeated_fisher_yates<sint_t>(
next_state = repeated_fisher_yates<sint_t>(
hmo.d,
hmo.m,
1,
hmo.selected_idxs,
state_idxs
next_state
);

if(hmo.left)
Expand Down
Loading

0 comments on commit 6ade616

Please sign in to comment.