Skip to content

Commit

Permalink
Fix EltwiseReduceMod (#90)
Browse files Browse the repository at this point in the history
* Fix AVVX512DQ EltwiseReduceMod
  • Loading branch information
fboemer authored Nov 1, 2021
1 parent 8a976dd commit accf7a5
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 7 deletions.
19 changes: 12 additions & 7 deletions hexl/eltwise/eltwise-reduce-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,23 @@ void EltwiseReduceMod(uint64_t* result, const uint64_t* operand, uint64_t n,
}
return;
}

#ifdef HEXL_HAS_AVX512IFMA
if (has_avx512ifma && modulus < (1ULL << 52)) {
EltwiseReduceModAVX512<52>(result, operand, n, modulus, input_mod_factor,
output_mod_factor);
return;
}
#endif

#ifdef HEXL_HAS_AVX512DQ
if (has_avx512dq) {
if (modulus < (1ULL << 52)) {
EltwiseReduceModAVX512<52>(result, operand, n, modulus, input_mod_factor,
output_mod_factor);
} else {
EltwiseReduceModAVX512<64>(result, operand, n, modulus, input_mod_factor,
output_mod_factor);
}
EltwiseReduceModAVX512<64>(result, operand, n, modulus, input_mod_factor,
output_mod_factor);
return;
}
#endif

HEXL_VLOG(3, "Calling EltwiseReduceModNative");
EltwiseReduceModNative(result, operand, n, modulus, input_mod_factor,
output_mod_factor);
Expand Down
44 changes: 44 additions & 0 deletions test/test-eltwise-reduce-mod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "hexl/logging/logging.hpp"
#include "hexl/number-theory/number-theory.hpp"
#include "test-util.hpp"
#include "util/util-internal.hpp"

namespace intel {
namespace hexl {
Expand Down Expand Up @@ -79,5 +80,48 @@ TEST(EltwiseReduceMod, 4_2) {
CheckEqual(result, exp_out);
}

// First parameter is the number of bits in the modulus
// Second parameter is whether or not to prefer small moduli
class EltwiseReduceModTest
: public ::testing::TestWithParam<std::tuple<uint64_t, bool>> {
protected:
void SetUp() override {
m_modulus_bits = std::get<0>(GetParam());
m_prefer_small_primes = std::get<1>(GetParam());
m_modulus = GeneratePrimes(1, m_modulus_bits, m_prefer_small_primes)[0];
}

void TearDown() override {}

public:
uint64_t m_N{1024 + 7}; // m_N % 8 = 7 to test AVX512 boundary case
uint64_t m_modulus_bits;
bool m_prefer_small_primes;
uint64_t m_modulus;
};

// Test public API matches Native implementation on random values
TEST_P(EltwiseReduceModTest, Random) {
uint64_t upper_bound =
m_modulus < (1ULL << 32) ? m_modulus * m_modulus : 1ULL << 63;

auto input = GenerateInsecureUniformRandomValues(m_N, 0, upper_bound);
std::vector<uint64_t> result_native(m_N, 0);
std::vector<uint64_t> result_public_api(m_N, 0);

EltwiseReduceModNative(result_native.data(), input.data(), m_N, m_modulus,
m_modulus, 1);
EltwiseReduceMod(result_public_api.data(), input.data(), m_N, m_modulus,
m_modulus, 1);
AssertEqual(result_native, result_public_api);
}

INSTANTIATE_TEST_SUITE_P(
EltwiseReduceMod, EltwiseReduceModTest,
::testing::Combine(::testing::ValuesIn(AlignedVector64<uint64_t>{
20, 25, 30, 31, 32, 33, 35, 40, 48, 49, 50, 51, 52,
55, 58, 59, 60}),
::testing::ValuesIn(std::vector<bool>{false, true})));

} // namespace hexl
} // namespace intel

0 comments on commit accf7a5

Please sign in to comment.