From 5e740bbaecfc811748d74377d7e1acc6d30ff536 Mon Sep 17 00:00:00 2001 From: Andy Polyakov Date: Wed, 18 Sep 2024 13:50:24 +0200 Subject: [PATCH] Add ff/pow.hpp and deduplicate operator^() methods. --- ff/baby_bear.hpp | 40 ++------------------ ff/gl64_t.cuh | 40 +++----------------- ff/mersenne31.hpp | 37 +++--------------- ff/mont32_t.cuh | 44 +++------------------- ff/mont_t.cuh | 35 ++--------------- ff/pow.hpp | 95 +++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 117 insertions(+), 174 deletions(-) create mode 100644 ff/pow.hpp diff --git a/ff/baby_bear.hpp b/ff/baby_bear.hpp index 6ff3a5f..6ac55e7 100644 --- a/ff/baby_bear.hpp +++ b/ff/baby_bear.hpp @@ -5,6 +5,8 @@ #ifndef __SPPARK_FF_BABY_BEAR_HPP__ #define __SPPARK_FF_BABY_BEAR_HPP__ +#include "pow.hpp" + #ifdef __CUDACC__ // CUDA device-side field types # include # include "mont32_t.cuh" @@ -638,23 +640,7 @@ class __align__(16) bb31_4_t { // raise to a variable power, variable in respect to threadIdx, // but mind the ^ operator's precedence! inline bb31_4_t& operator^=(uint32_t p) - { - bb31_4_t sqr = *this; - - if (!(p&1)) { - c[0] = bb31_t{1}; - c[1] = c[2] = c[3] = 0; - } - - #pragma unroll 1 - while (p >>= 1) { - sqr.sqr(); - if (p&1) - mul(sqr); - } - - return *this; - } + { return pow_byref(*this, p); } friend inline bb31_4_t operator^(bb31_4_t a, uint32_t p) { return a ^= p; } inline bb31_4_t operator()(uint32_t p) @@ -662,25 +648,7 @@ class __align__(16) bb31_4_t { // raise to a constant power, e.g. x^7, to be unrolled at compile time inline bb31_4_t& operator^=(int p) - { - assert(p >= 2); - - bb31_4_t sqr = *this; - if ((p&1) == 0) { - do { - sqr.sqr(); - p >>= 1; - } while ((p&1) == 0); - *this = sqr; - } - for (p >>= 1; p; p >>= 1) { - sqr.sqr(); - if (p&1) - mul(sqr); - } - - return *this; - } + { return pow_byref(*this, p); } friend inline bb31_4_t operator^(bb31_4_t a, int p) { return a ^= p; } inline bb31_4_t operator()(int p) diff --git a/ff/gl64_t.cuh b/ff/gl64_t.cuh index 03ac4de..4e6f6c7 100644 --- a/ff/gl64_t.cuh +++ b/ff/gl64_t.cuh @@ -5,7 +5,10 @@ #if defined(__CUDACC__) && !defined(__SPPARK_FF_GL64_T_CUH__) #define __SPPARK_FF_GL64_T_CUH__ +# include # include +# include "pow.hpp" + # define inline __device__ __forceinline__ # ifdef __GNUC__ # define asm __asm__ __volatile__ @@ -308,20 +311,7 @@ public: // raise to a variable power, variable in respect to threadIdx, // but mind the ^ operator's precedence! inline gl64_t& operator^=(uint32_t p) - { - gl64_t sqr = *this; - *this = csel(*this, one(), p&1); - - #pragma unroll 1 - while (p >>= 1) { - sqr.mul(sqr); - if (p&1) - mul(sqr); - } - to(); - - return *this; - } + { pow_byref(*this, p); to(); return *this; } friend inline gl64_t operator^(gl64_t a, uint32_t p) { return a ^= p; } inline gl64_t operator()(uint32_t p) @@ -329,27 +319,7 @@ public: // raise to a constant power, e.g. x^7, to be unrolled at compile time inline gl64_t& operator^=(int p) - { - if (p < 2) - asm("trap;"); - - gl64_t sqr = *this; - if ((p&1) == 0) { - do { - sqr.mul(sqr); - p >>= 1; - } while ((p&1) == 0); - *this = sqr; - } - for (p >>= 1; p; p >>= 1) { - sqr.mul(sqr); - if (p&1) - mul(sqr); - } - to(); - - return *this; - } + { pow_byref(*this, p); to(); return *this; } friend inline gl64_t operator^(gl64_t a, int p) { return a ^= p; } inline gl64_t operator()(int p) diff --git a/ff/mersenne31.hpp b/ff/mersenne31.hpp index 9fe400b..13d1865 100644 --- a/ff/mersenne31.hpp +++ b/ff/mersenne31.hpp @@ -5,6 +5,8 @@ #ifndef __SPPARK_FF_MERSENNE31_HPP__ #define __SPPARK_FF_MERSENNE31_HPP__ +#include "pow.hpp" + #ifdef __CUDACC__ // CUDA device-side field types # include "mont32_t.cuh" # define inline __device__ __forceinline__ @@ -61,6 +63,7 @@ struct mrs31_t : public mrs31_base { }; # undef inline #else +# include # include # include # if defined(__CUDACC__) || defined(__HIPCC__) @@ -181,19 +184,7 @@ class mrs31_t { // raise to a variable power, variable in respect to threadIdx, // but mind the ^ operator's precedence! inline mrs31_t& operator^=(uint32_t p) - { - mrs31_t sqr = *this; - *this = csel(val, 1, p&1); - - #pragma unroll 1 - while (p >>= 1) { - sqr.mul(sqr); - if (p&1) - mul(sqr); - } - - return *this; - } + { return pow_byref(*this, p); } friend inline mrs31_t operator^(mrs31_t a, uint32_t p) { return a ^= p; } inline mrs31_t operator()(uint32_t p) @@ -201,25 +192,7 @@ class mrs31_t { // raise to a constant power, e.g. x^7, to be unrolled at compile time inline mrs31_t& operator^=(int p) - { - assert(p >= 2); - - mrs31_t sqr = *this; - if ((p&1) == 0) { - do { - sqr.mul(sqr); - p >>= 1; - } while ((p&1) == 0); - *this = sqr; - } - for (p >>= 1; p; p >>= 1) { - sqr.mul(sqr); - if (p&1) - mul(sqr); - } - - return *this; - } + { return pow_byref(*this, p); } friend inline mrs31_t operator^(mrs31_t a, int p) { return a ^= p; } inline mrs31_t operator()(int p) diff --git a/ff/mont32_t.cuh b/ff/mont32_t.cuh index 7daa430..2d7b96b 100644 --- a/ff/mont32_t.cuh +++ b/ff/mont32_t.cuh @@ -5,7 +5,10 @@ #if defined(__CUDACC__) && !defined(__SPPARK_FF_MONT32_T_CUH__) #define __SPPARK_FF_MONT32_T_CUH__ +# include # include +# include "pow.hpp" + # define inline __device__ __forceinline__ # ifdef __GNUC__ # define asm __asm__ __volatile__ @@ -228,19 +231,7 @@ public: // raise to a variable power, variable in respect to threadIdx, // but mind the ^ operator's precedence! inline mont32_t& operator^=(uint32_t p) - { - mont32_t sqr = *this; - *this = csel(val, ONE, p&1); - - #pragma unroll 1 - while (p >>= 1) { - sqr.mul(sqr); - if (p&1) - mul(sqr); - } - - return *this; - } + { return pow_byref(*this, p); } friend inline mont32_t operator^(mont32_t a, uint32_t p) { return a ^= p; } inline mont32_t operator()(uint32_t p) @@ -248,32 +239,7 @@ public: // raise to a constant power, e.g. x^7, to be unrolled at compile time inline mont32_t& operator^=(int p) - { - if (p < 2) - asm("trap;"); - - if (p == 7) { - mont32_t temp = sqr_n_mul(*this, 1, *this); - *this = sqr_n_mul(temp, 1, *this); - return *this; - } - - mont32_t sqr = *this; - if ((p&1) == 0) { - do { - sqr.mul(sqr); - p >>= 1; - } while ((p&1) == 0); - *this = sqr; - } - for (p >>= 1; p; p >>= 1) { - sqr.mul(sqr); - if (p&1) - mul(sqr); - } - - return *this; - } + { return pow_byref(*this, p); } friend inline mont32_t operator^(mont32_t a, int p) { return a ^= p; } inline mont32_t operator()(int p) diff --git a/ff/mont_t.cuh b/ff/mont_t.cuh index aa5e075..02864f7 100644 --- a/ff/mont_t.cuh +++ b/ff/mont_t.cuh @@ -7,6 +7,7 @@ # include # include +# include "pow.hpp" # define inline __device__ __forceinline__ # ifdef __GNUC__ @@ -425,19 +426,7 @@ public: // raise to a variable power, variable in respect to threadIdx, // but mind the ^ operator's precedence! inline mont_t& operator^=(uint32_t p) - { - mont_t sqr = *this; - *this = csel(*this, one(), p&1); - - #pragma unroll 1 - while (p >>= 1) { - sqr.sqr(); - if (p&1) - *this *= sqr; - } - - return *this; - } + { return pow_byref(*this, p); } friend inline mont_t operator^(mont_t a, uint32_t p) { return a ^= p; } inline mont_t operator()(uint32_t p) @@ -445,25 +434,7 @@ public: // raise to a constant power, e.g. x^7, to be unrolled at compile time inline mont_t& operator^=(int p) - { - if (p < 2) - asm("trap;"); - - mont_t sqr = *this; - if ((p&1) == 0) { - do { - sqr.sqr(); - p >>= 1; - } while ((p&1) == 0); - *this = sqr; - } - for (p >>= 1; p; p >>= 1) { - sqr.sqr(); - if (p&1) - *this *= sqr; - } - return *this; - } + { return pow_byref(*this, p); } friend inline mont_t operator^(mont_t a, int p) { return p == 2 ? (mont_t)wide_t{a} : a ^= p; } inline mont_t operator()(int p) diff --git a/ff/pow.hpp b/ff/pow.hpp new file mode 100644 index 0000000..a25cb23 --- /dev/null +++ b/ff/pow.hpp @@ -0,0 +1,95 @@ +// Copyright Supranational LLC +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 + +#ifndef __SPPARK_FF_POW_HPP__ +#define __SPPARK_FF_POW_HPP__ + +#if defined(__CUDACC__) || defined(__HIPCC__) +# define inline __host__ __device__ __forceinline__ +#endif + +#if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable: 4068) +#elif defined(__GNUC__) && !defined(__clang__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wunknown-pragmas" +#endif + +/* + * Raise to a variable power, e.g. variable in respect to threadIdx. + */ +template +inline T& pow_byref(T& val, U p) +{ + T sqr = val; + val = T::csel(val, T::one(), p&1); + + #pragma unroll 1 + while (p >>= 1) { + sqr.sqr(); + if (p&1) + val *= sqr; + } + + return val; +} + +#if defined(__CUDACC__) || defined(__HIPCC__) +/* + * This is meant to be used for code size optimization by deduplicating + * otherwise inlined pow_byref. + */ +template __device__ __noinline__ +T pow_byval(T val, unsigned p) +{ return pow_byref(val, p); } +#endif + +#include + +/* + * Raise to a constant power, e.g. x^7. The idea is to let compiler + * "decide" how to unroll with expectation that for small constants + * it will be fully inrolled. + */ +template +inline T& pow_byref(T& val, int p) +{ + assert(p >= 2); + + T sqr = val; + if ((p&1) == 0) { + do { + sqr.sqr(); + p >>= 1; + } while ((p&1) == 0); + val = sqr; + } + for (p >>= 1; p; p >>= 1) { + sqr.sqr(); + if (p&1) + val *= sqr; + } + return val; +} + +#if defined(__CUDACC__) || defined(__HIPCC__) +/* + * This is meant to be used for code size optimization by deduplicating + * otherwise inlined pow_byref. + */ +template __device__ __noinline__ +T pow_byval(T val, int p) +{ return pow_byref(val, p); } + +# undef inline +#endif + +#if defined(_MSC_VER) +# pragma warning(pop) +#elif defined(__GNUC__) && !defined(__clang__) +# pragma GCC diagnostic pop +#endif + +#endif