Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Updated thrust shuffle to use improved bijective function #1566

Merged
merged 4 commits into from
Jan 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions internal/benchmark/bench.cu
Original file line number Diff line number Diff line change
Expand Up @@ -992,15 +992,13 @@ void run_core_primitives_experiments_for_type()
, RegularTrials
>::run_experiment();

#if THRUST_CPP_DIALECT >= 2011
experiment_driver<
shuffle_tester
, ElementMetaType
, Elements / sizeof(typename ElementMetaType::type)
, BaselineTrials
, RegularTrials
>::run_experiment();
#endif
}

///////////////////////////////////////////////////////////////////////////////
Expand Down
20 changes: 9 additions & 11 deletions testing/shuffle.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include <thrust/detail/config.h>

#if THRUST_CPP_DIALECT >= 2011
#include <map>
#include <limits>
#include <thrust/random.h>
Expand Down Expand Up @@ -383,7 +382,7 @@ void TestFunctionIsBijection(size_t m) {
thrust::system::detail::generic::feistel_bijection host_f(m, host_g);
thrust::system::detail::generic::feistel_bijection device_f(m, device_g);

if (host_f.nearest_power_of_two() >= std::numeric_limits<T>::max() || m == 0) {
if (static_cast<double>(host_f.nearest_power_of_two()) >= static_cast<double>(std::numeric_limits<T>::max()) || m == 0) {
return;
}

Expand All @@ -410,17 +409,17 @@ DECLARE_VARIABLE_UNITTEST(TestFunctionIsBijection);
// Checks the domain size reported by feistel_bijection::nearest_power_of_two()
// for a few requested lengths m. The bijection is re-created for every m with
// the same random engine instance.
void TestBijectionLength() {
  thrust::default_random_engine g(0xD5);

  // A length just below a power of two is expected to round up to it.
  uint64_t m = 31;
  thrust::system::detail::generic::feistel_bijection f(m, g);
  ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(32));

  // A length that already is a power of two is expected to be kept.
  m = 32;
  f = thrust::system::detail::generic::feistel_bijection(m, g);
  ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(32));

  // A very small length is expected to report a domain of 16.
  m = 1;
  f = thrust::system::detail::generic::feistel_bijection(m, g);
  ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(16));
}
DECLARE_UNITTEST(TestBijectionLength);

Expand Down Expand Up @@ -515,7 +514,7 @@ void TestShuffleEvenSpacingBetweenOccurances() {
thrust::host_vector<T> h_results;
Vector sequence(shuffle_size);
thrust::sequence(sequence.begin(), sequence.end(), 0);
thrust::default_random_engine g(0xD5);
thrust::default_random_engine g(0xD6);
for (auto i = 0ull; i < num_samples; i++) {
thrust::shuffle(sequence.begin(), sequence.end(), g);
thrust::host_vector<T> tmp(sequence.begin(), sequence.end());
Expand Down Expand Up @@ -561,7 +560,7 @@ void TestShuffleEvenDistribution() {
const uint64_t shuffle_sizes[] = {10, 100, 500};
thrust::default_random_engine g(0xD5);
for (auto shuffle_size : shuffle_sizes) {
if(shuffle_size > std::numeric_limits<T>::max())
if(shuffle_size > (uint64_t)std::numeric_limits<T>::max())
continue;
const uint64_t num_samples = shuffle_size == 500 ? 1000 : 200;

Expand Down Expand Up @@ -601,4 +600,3 @@ void TestShuffleEvenDistribution() {
}
}
DECLARE_INTEGRAL_VECTOR_UNITTEST(TestShuffleEvenDistribution);
#endif
75 changes: 24 additions & 51 deletions thrust/system/detail/generic/shuffle.inl
Original file line number Diff line number Diff line change
Expand Up @@ -48,36 +48,42 @@ class feistel_bijection {
right_side_bits = total_bits - left_side_bits;
right_side_mask = (1ull << right_side_bits) - 1;

for (std::uint64_t i = 0; i < num_rounds; i++) {
for (std::uint32_t i = 0; i < num_rounds; i++) {
key[i] = g();
}
}

// Returns 2 raised to the combined width of both halves, i.e.
// 2^(left_side_bits + right_side_bits).
__host__ __device__ std::uint64_t nearest_power_of_two() const {
  const std::uint64_t total_width = left_side_bits + right_side_bits;
  return 1ull << total_width;
}
// Maps val through the keyed Feistel network.
//
// val is split into a left half (the bits above right_side_bits) and a right
// half (the low right_side_bits bits, selected with right_side_mask). Each of
// the num_rounds rounds multiplies the left half by a fixed odd constant,
// mixes the product with the round key and the right half, and masks both
// halves back to their widths.
//
// Returns the two halves recombined, with the left half in the upper bits.
__host__ __device__ std::uint64_t operator()(const std::uint64_t val) const {
  std::uint32_t state[2] = {static_cast<std::uint32_t>(val >> right_side_bits),
                            static_cast<std::uint32_t>(val & right_side_mask)};
  for (std::uint32_t i = 0; i < num_rounds; i++) {
    std::uint32_t hi, lo;
    constexpr std::uint64_t M0 = UINT64_C(0xD2B74407B1CE6E93);
    mulhilo(M0, state[0], hi, lo);
    // New right half: low product word shifted up by the width difference of
    // the halves, filled from below with the bits of the old right half that
    // lie above left_side_bits.
    lo = (lo << (right_side_bits - left_side_bits)) | state[1] >> left_side_bits;
    // New left half: high product word mixed with the round key and the old
    // right half, masked to left_side_bits.
    state[0] = ((hi ^ key[i]) ^ state[1]) & left_side_mask;
    state[1] = lo & right_side_mask;
  }

  // Combine the left and right sides together to get result.
  // Widen the left half to 64 bits BEFORE shifting: a shift applied to the
  // 32-bit value is evaluated in 32-bit arithmetic and would drop the upper
  // bits whenever left_side_bits + right_side_bits exceeds 32.
  return (static_cast<std::uint64_t>(state[0]) << right_side_bits) |
         static_cast<std::uint64_t>(state[1]);
}

private:
// Multiplies a by b in 64-bit unsigned arithmetic and splits the 64-bit
// product into two 32-bit words: hi receives the upper word, lo the lower.
static __host__ __device__ void mulhilo( std::uint64_t a, std::uint64_t b, std::uint32_t& hi, std::uint32_t& lo )
{
  const std::uint64_t wide = a * b;
  const std::uint64_t word_mask = 0xFFFFFFFFull;
  hi = static_cast<std::uint32_t>( ( wide >> 32 ) & word_mask );
  lo = static_cast<std::uint32_t>( wide & word_mask );
}

// Find the nearest power of two
__host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) {
if (m == 0) return 0;
static __host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) {
if (m <= 16) return 4;
std::uint64_t i = 0;
m--;
while (m != 0) {
Expand All @@ -87,45 +93,12 @@ class feistel_bijection {
return i;
}

// Equivalent to boost::hash_combine
// Mixes rhs into lhs and returns the mixed value: lhs is XOR-ed with the sum
// of rhs, the constant 0x9e3779b9, lhs shifted left by 6 and lhs shifted
// right by 2. Both parameters are taken by value, so the caller's arguments
// are not modified.
__host__ __device__
std::size_t hash_combine(std::uint64_t lhs, std::uint64_t rhs) const {
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}

// Round function, a 'pseudorandom function' whose output is indistinguishable
// from random for each key value input. This is not cryptographically secure
// but sufficient for generating permutations.
//
// value: input to hash; key_: key mixed into the hash.
// Returns the combined hash masked with left_side_mask and narrowed to
// 32 bits.
__host__ __device__ std::uint32_t round_function(std::uint64_t value,
const std::uint64_t key_) const {
// First output of a taus88 engine seeded with the low 32 bits of value.
std::uint64_t hash0 = thrust::random::taus88(static_cast<std::uint32_t>(value))();
// First output of a ranlux48 engine seeded with value.
std::uint64_t hash1 = thrust::random::ranlux48(value)();
// Fold the key into hash0, then fold hash1 into that result.
return static_cast<std::uint32_t>(
hash_combine(hash_combine(hash0, key_), hash1) & left_side_mask);
}

// Performs one Feistel round on state using key[round] and returns the new
// {left, right} pair.
//
// The new left side is the old right side masked with left_side_mask. The
// new right side is built from the old left side XOR-ed with
// round_function(old right side, key[round]).
__host__ __device__ round_state do_round(const round_state state,
const std::uint64_t round) const {
const std::uint32_t new_left = state.right & left_side_mask;
const std::uint32_t round_function_res =
state.left ^ round_function(state.right, key[round]);
if (right_side_bits != left_side_bits) {
// Upper bit of the old right becomes lower bit of new right if we have
// odd length feistel
const std::uint32_t new_right =
(round_function_res << 1ull) | state.right >> left_side_bits;
return {new_left, new_right};
}
// Sides have equal width: the round result is used as the new right side
// unchanged.
return {new_left, round_function_res};
}

static constexpr std::uint64_t num_rounds = 16;
static constexpr std::uint32_t num_rounds = 24;
std::uint64_t right_side_bits;
std::uint64_t left_side_bits;
std::uint64_t right_side_mask;
std::uint64_t left_side_mask;
std::uint64_t key[num_rounds];
std::uint32_t key[num_rounds];
};

struct key_flag_tuple {
Expand Down