Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matrix transpose bitorder #27

Merged
merged 3 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions matrix-transpose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ thiserror.workspace = true
[dev-dependencies]
rand.workspace = true
criterion.workspace = true
itybity.workspace = true

[features]
simd-transpose = []
Expand Down
63 changes: 56 additions & 7 deletions matrix-transpose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use thiserror::Error;

/// This function transposes a matrix on the bit-level.
///
/// Assumes an LSB0 bit encoding of the matrix.
/// This implementation requires that the number of rows is a power of 2
/// and that the number of columns is a multiple of 8.
pub fn transpose_bits(matrix: &mut [u8], rows: usize) -> Result<(), TransposeError> {
Expand Down Expand Up @@ -87,17 +88,65 @@ mod tests {
(0..elements).map(|_| rng.gen::<T>()).collect()
}

fn transpose_naive(data: &[u8], row_width: usize) -> Vec<u8> {
use itybity::*;

let bits: Vec<Vec<bool>> = data.chunks(row_width).map(|x| x.to_lsb0_vec()).collect();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't `row_width` the width in bits? Here it is treated as if it were in bytes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is correct here to use the row width in bytes, because we iterate in chunks of byte rows and turn this chunk into bit rows. So in the beginning it is bytes and after this operation it is bits.

let col_count = bits[0].len();
let row_count = bits.len();

let mut bits_: Vec<Vec<bool>> = vec![vec![false; row_count]; col_count];
for j in 0..row_count {
for i in 0..col_count {
bits_[i][j] = bits[j][i];
}
}

bits_.into_iter().flat_map(Vec::<u8>::from_lsb0).collect()
}

#[test]
fn test_transpose_bits() {
let mut rows = 512;
let rows = 512;
let columns = 256;

let mut matrix: Vec<u8> = random_vec::<u8>(columns * rows);
let original = matrix.clone();
let naive = transpose_naive(&matrix, columns);

transpose_bits(&mut matrix, rows).unwrap();
rows = columns;
transpose_bits(&mut matrix, 8 * rows).unwrap();
assert_eq!(original, matrix);

assert_eq!(naive, matrix);
}

/// Checks `transpose_naive` on a fixed 8x8 bit matrix against a literal
/// expected result.
///
/// The row width passed to `transpose_naive` is 1, so every input byte is
/// treated as its own row. The trailing comments spell out each byte's bits
/// in LSB0 order (least-significant bit first).
#[test]
fn test_transpose_naive() {
    // Input rows, one byte each; bits listed in LSB0 order.
    let input: [u8; 8] = [
        3,   // 1 1 0 0 0 0 0 0
        76,  // 0 0 1 1 0 0 1 0
        120, // 0 0 0 1 1 1 1 0
        9,   // 1 0 0 1 0 0 0 0
        17,  // 1 0 0 0 1 0 0 0
        102, // 0 1 1 0 0 1 1 0
        53,  // 1 0 1 0 1 1 0 0
        125, // 1 0 1 1 1 1 1 0
    ];

    // Rows of the transposed matrix; bits listed in LSB0 order.
    let want: [u8; 8] = [
        217, // 1 0 0 1 1 0 1 1
        33,  // 1 0 0 0 0 1 0 0
        226, // 0 1 0 0 0 1 1 1
        142, // 0 1 1 1 0 0 0 1
        212, // 0 0 1 0 1 0 1 1
        228, // 0 0 1 0 0 1 1 1
        166, // 0 1 1 0 0 1 0 1
        0,   // 0 0 0 0 0 0 0 0
    ];

    let got = transpose_naive(&input, 1);

    assert_eq!(got, want);
}

#[test]
Expand Down Expand Up @@ -141,12 +190,12 @@ mod tests {
for k in 0..8 {
for (l, chunk) in row.chunks(8).enumerate() {
let expected: u8 = chunk.iter().enumerate().fold(0, |acc, (m, element)| {
acc + (element >> 7) * 2_u8.pow(7_u32 - m as u32)
acc + (element & 1) * 2_u8.pow(m as u32)
});
let actual = matrix[row_index * columns + columns / 8 * k + l];
assert_eq!(expected, actual);
}
let shifted_row = row.iter_mut().map(|el| *el << 1).collect::<Vec<u8>>();
let shifted_row = row.iter_mut().map(|el| *el >> 1).collect::<Vec<u8>>();
row.copy_from_slice(&shifted_row);
}
}
Expand Down
5 changes: 3 additions & 2 deletions matrix-transpose/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ where

/// Single-row bit-mask shift
///
/// Assumes an LSB0 bit encoding of the matrix.
/// This function is an implementation of the bit-level transpose in
/// https://docs.rs/oblivious-transfer/latest/oblivious_transfer/extension/fn.transpose128.html
/// Caller has to make sure that columns is a multiple of 8
Expand All @@ -48,8 +49,8 @@ pub fn bitmask_shift(matrix: &mut [u8], columns: usize) {
for bytes in row.chunks_mut(8) {
let mut high_bits: u8 = 0b00000000;
bytes.iter_mut().enumerate().for_each(|(k, b)| {
high_bits |= (0b10000000 & *b) >> k;
*b <<= 1;
high_bits |= (0b00000001 & *b) << k;
*b >>= 1;
});
shifted_row.push(high_bits);
}
Expand Down
10 changes: 9 additions & 1 deletion matrix-transpose/src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::{

/// SIMD version for bit-level transposition
///
/// Assumes an LSB0 bit encoding of the matrix.
/// This SIMD implementation additionally requires that the matrix has at least
/// 16 (WASM) or 32 (x86_64) columns and rows
#[cfg(any(target_arch = "x86_64", target_arch = "wasm32"))]
Expand Down Expand Up @@ -73,6 +74,7 @@ where

/// Unsafe single-row bit-mask shift
///
/// Assumes an LSB0 bit encoding of the matrix.
/// This function is an implementation of the bit-level transpose in
/// https://docs.rs/oblivious-transfer/latest/oblivious_transfer/extension/fn.transpose128.html
/// Caller has to make sure that columns is a multiple of 16 or 32
Expand All @@ -84,6 +86,7 @@ pub unsafe fn bitmask_shift_unchecked(matrix: &mut [u8], columns: usize) {
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::_mm256_movemask_epi8;

matrix.iter_mut().for_each(|b| *b = b.reverse_bits());
let simd_one = Simd::<u8, LANE_COUNT>::splat(1);
let mut s: Simd<u8, LANE_COUNT>;
for row in matrix.chunks_mut(columns) {
Expand All @@ -95,7 +98,12 @@ pub unsafe fn bitmask_shift_unchecked(matrix: &mut [u8], columns: usize) {
let high_bits = _mm256_movemask_epi8(s.reverse().into());
#[cfg(target_arch = "wasm32")]
let high_bits = u8x16_bitmask(s.reverse().into());
shifted_row.extend_from_slice(&high_bits.to_be_bytes());
let high_bits: Vec<u8> = high_bits
.to_be_bytes()
.into_iter()
.map(|b| b.reverse_bits())
.collect();
shifted_row.extend_from_slice(&high_bits);
s.shl_assign(simd_one);
*chunk = s.to_array();
}
Expand Down