Skip to content

Commit

Permalink
Matrix transpose bitorder (#27)
Browse files Browse the repository at this point in the history
* add naive test

* Change algorithm to work for LSB0 encoding instead of MSB0

* add test

---------

Co-authored-by: th4s <th4s@metavoid.xyz>
Co-authored-by: themighty1 <you@example.com>
  • Loading branch information
3 people authored Aug 11, 2023
1 parent 57a6ccf commit 6c4af5a
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 10 deletions.
1 change: 1 addition & 0 deletions matrix-transpose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ thiserror.workspace = true
[dev-dependencies]
rand.workspace = true
criterion.workspace = true
itybity.workspace = true

[features]
simd-transpose = []
Expand Down
63 changes: 56 additions & 7 deletions matrix-transpose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use thiserror::Error;

/// This function transposes a matrix on the bit-level.
///
/// Assumes an LSB0 bit encoding of the matrix.
/// This implementation requires that the number of rows is a power of 2
/// and that the number of columns is a multiple of 8
pub fn transpose_bits(matrix: &mut [u8], rows: usize) -> Result<(), TransposeError> {
Expand Down Expand Up @@ -87,17 +88,65 @@ mod tests {
(0..elements).map(|_| rng.gen::<T>()).collect()
}

fn transpose_naive(data: &[u8], row_width: usize) -> Vec<u8> {
use itybity::*;

let bits: Vec<Vec<bool>> = data.chunks(row_width).map(|x| x.to_lsb0_vec()).collect();
let col_count = bits[0].len();
let row_count = bits.len();

let mut bits_: Vec<Vec<bool>> = vec![vec![false; row_count]; col_count];
for j in 0..row_count {
for i in 0..col_count {
bits_[i][j] = bits[j][i];
}
}

bits_.into_iter().flat_map(Vec::<u8>::from_lsb0).collect()
}

#[test]
fn test_transpose_bits() {
let mut rows = 512;
let rows = 512;
let columns = 256;

let mut matrix: Vec<u8> = random_vec::<u8>(columns * rows);
let original = matrix.clone();
let naive = transpose_naive(&matrix, columns);

transpose_bits(&mut matrix, rows).unwrap();
rows = columns;
transpose_bits(&mut matrix, 8 * rows).unwrap();
assert_eq!(original, matrix);

assert_eq!(naive, matrix);
}

#[test]
fn test_transpose_naive() {
let matrix = [
// ------- bits in lsb0
3u8, // 1 1 0 0 0 0 0 0
76u8, // 0 0 1 1 0 0 1 0
120u8, // 0 0 0 1 1 1 1 0
9u8, // 1 0 0 1 0 0 0 0
17u8, // 1 0 0 0 1 0 0 0
102u8, // 0 1 1 0 0 1 1 0
53u8, // 1 0 1 0 1 1 0 0
125u8, // 1 0 1 1 1 1 1 0
];

let expected = [
// ------- bits in lsb0
217u8, // 1 0 0 1 1 0 1 1
33u8, // 1 0 0 0 0 1 0 0
226u8, // 0 1 0 0 0 1 1 1
142u8, // 0 1 1 1 0 0 0 1
212u8, // 0 0 1 0 1 0 1 1
228u8, // 0 0 1 0 0 1 1 1
166u8, // 0 1 1 0 0 1 0 1
0u8, // 0 0 0 0 0 0 0 0
];

let naive = transpose_naive(&matrix, 1);

assert_eq!(naive, expected);
}

#[test]
Expand Down Expand Up @@ -141,12 +190,12 @@ mod tests {
for k in 0..8 {
for (l, chunk) in row.chunks(8).enumerate() {
let expected: u8 = chunk.iter().enumerate().fold(0, |acc, (m, element)| {
acc + (element >> 7) * 2_u8.pow(7_u32 - m as u32)
acc + (element & 1) * 2_u8.pow(m as u32)
});
let actual = matrix[row_index * columns + columns / 8 * k + l];
assert_eq!(expected, actual);
}
let shifted_row = row.iter_mut().map(|el| *el << 1).collect::<Vec<u8>>();
let shifted_row = row.iter_mut().map(|el| *el >> 1).collect::<Vec<u8>>();
row.copy_from_slice(&shifted_row);
}
}
Expand Down
5 changes: 3 additions & 2 deletions matrix-transpose/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ where

/// Single-row bit-mask shift
///
/// Assumes an LSB0 bit encoding of the matrix.
/// This function is an implementation of the bit-level transpose in
/// https://docs.rs/oblivious-transfer/latest/oblivious_transfer/extension/fn.transpose128.html
/// Caller has to make sure that columns is a multiple of 8
Expand All @@ -48,8 +49,8 @@ pub fn bitmask_shift(matrix: &mut [u8], columns: usize) {
for bytes in row.chunks_mut(8) {
let mut high_bits: u8 = 0b00000000;
bytes.iter_mut().enumerate().for_each(|(k, b)| {
high_bits |= (0b10000000 & *b) >> k;
*b <<= 1;
high_bits |= (0b00000001 & *b) << k;
*b >>= 1;
});
shifted_row.push(high_bits);
}
Expand Down
10 changes: 9 additions & 1 deletion matrix-transpose/src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::{

/// SIMD version for bit-level transposition
///
/// Assumes an LSB0 bit encoding of the matrix.
/// This SIMD implementation additionally requires that the matrix has at least
/// 16 (WASM) or 32 (x86_64) columns and rows
#[cfg(any(target_arch = "x86_64", target_arch = "wasm32"))]
Expand Down Expand Up @@ -73,6 +74,7 @@ where

/// Unsafe single-row bit-mask shift
///
/// Assumes an LSB0 bit encoding of the matrix.
/// This function is an implementation of the bit-level transpose in
/// https://docs.rs/oblivious-transfer/latest/oblivious_transfer/extension/fn.transpose128.html
/// Caller has to make sure that columns is a multiple of 16 or 32
Expand All @@ -84,6 +86,7 @@ pub unsafe fn bitmask_shift_unchecked(matrix: &mut [u8], columns: usize) {
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::_mm256_movemask_epi8;

matrix.iter_mut().for_each(|b| *b = b.reverse_bits());
let simd_one = Simd::<u8, LANE_COUNT>::splat(1);
let mut s: Simd<u8, LANE_COUNT>;
for row in matrix.chunks_mut(columns) {
Expand All @@ -95,7 +98,12 @@ pub unsafe fn bitmask_shift_unchecked(matrix: &mut [u8], columns: usize) {
let high_bits = _mm256_movemask_epi8(s.reverse().into());
#[cfg(target_arch = "wasm32")]
let high_bits = u8x16_bitmask(s.reverse().into());
shifted_row.extend_from_slice(&high_bits.to_be_bytes());
let high_bits: Vec<u8> = high_bits
.to_be_bytes()
.into_iter()
.map(|b| b.reverse_bits())
.collect();
shifted_row.extend_from_slice(&high_bits);
s.shl_assign(simd_one);
*chunk = s.to_array();
}
Expand Down

0 comments on commit 6c4af5a

Please sign in to comment.