Skip to content

Commit

Permalink
implemented compute_next_state for DenseSkOp and SparseSkOp
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Jul 27, 2024
1 parent 421197c commit ae624fc
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 15 deletions.
25 changes: 19 additions & 6 deletions RandBLAS/dense_skops.hh
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@
namespace RandBLAS::dense {


template <typename RNG, typename DD>
static inline RNGState<RNG> compute_next_state(DD dist, RNGState<RNG> state) {
// Need logic that depends on DenseDistName.
return RNGState<RNG>(0);
}

template <typename T_IN, typename T_OUT>
inline void copy_promote(int n, const T_IN &a, T_OUT* b) {
for (int i = 0; i < n; ++i)
Expand Down Expand Up @@ -184,6 +178,25 @@ static RNGState<RNG> fill_dense_submat_impl(
return RNGState<RNG> {max_c, k};
}

template <typename RNG, typename DD>
RNGState<RNG> compute_next_state(DD dist, RNGState<RNG> state) {
if (dist.major_axis == MajorAxis::Undefined) {
// implies dist.family = DenseDistName::BlackBox
return state;
}
int64_t major_len = major_axis_length(dist);
int64_t minor_len = dist.n_rows + (dist.n_cols - major_len);
int64_t ctr_size = RNG::ctr_type::static_size;
int64_t pad = 0;
if (major_len % ctr_size != 0) {
pad = ctr_size - major_len % ctr_size;
}
int64_t ctr_major_axis_stride = (major_len + pad) / ctr_size;
int64_t full_incr = safe_signed_int_product(ctr_major_axis_stride, minor_len);
state.counter.incr(full_incr);
return state;
}

} // end namespace RandBLAS::dense


Expand Down
27 changes: 18 additions & 9 deletions RandBLAS/sparse_skops.hh
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@

namespace RandBLAS::sparse {

template <typename RNG, typename SD>
static RNGState<RNG> compute_next_state(SD dist, RNGState<RNG> seed_state) {
return RNGState<RNG>(0);
}

// =============================================================================
/// WARNING: this function is not part of the public API.
Expand All @@ -73,11 +69,11 @@ static RNGState<RNG> repeated_fisher_yates(
auto [ctr, key] = state;
for (sint_t i = 0; i < dim_minor; ++i) {
sint_t offset = i * vec_nnz;
auto ctri = ctr;
ctri.incr(offset);
auto ctr_work = ctr;
ctr_work.incr(offset);
for (sint_t j = 0; j < vec_nnz; ++j) {
// one step of Fisher-Yates shuffling
auto rv = gen(ctri, key);
auto rv = gen(ctr_work, key);
sint_t ell = j + rv[0] % (dim_major - j);
pivots[j] = ell;
sint_t swap = vec_work[ell];
Expand All @@ -88,7 +84,7 @@ static RNGState<RNG> repeated_fisher_yates(
vals[j + offset] = (rv[1] % 2 == 0) ? 1.0 : -1.0;
idxs_minor[j + offset] = (sint_t) i;
// increment counter
ctri.incr();
ctr_work.incr();
}
// Restore vec_work for next iteration of Fisher-Yates.
// This isn't necessary from a statistical perspective,
Expand All @@ -101,10 +97,23 @@ static RNGState<RNG> repeated_fisher_yates(
vec_work[jj] = vec_work[ell];
vec_work[ell] = swap;
}
ctr = ctri;
}
return RNGState<RNG> {ctr, key};
}

template <typename RNG, typename SD>
static RNGState<RNG> compute_next_state(SD dist, RNGState<RNG> state) {
int64_t minor_len;
if (dist.major_axis == MajorAxis::Short) {
minor_len = std::min(dist.n_rows, dist.n_cols);
} else {
minor_len = std::max(dist.n_rows, dist.n_cols);
}
int64_t full_incr = minor_len * dist.vec_nnz;
state.counter.incr(full_incr);
return state;
}

}

namespace RandBLAS {
Expand Down

0 comments on commit ae624fc

Please sign in to comment.