Skip to content

Commit

Permalink
Add ff/pow.hpp and deduplicate operator^() methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
dot-asm committed Sep 18, 2024
1 parent 41439f3 commit 5e740bb
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 174 deletions.
40 changes: 4 additions & 36 deletions ff/baby_bear.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#ifndef __SPPARK_FF_BABY_BEAR_HPP__
#define __SPPARK_FF_BABY_BEAR_HPP__

#include "pow.hpp"

#ifdef __CUDACC__ // CUDA device-side field types
# include <cassert>
# include "mont32_t.cuh"
Expand Down Expand Up @@ -638,49 +640,15 @@ class __align__(16) bb31_4_t {
// raise to a variable power, variable with respect to threadIdx,
// but mind the ^ operator's precedence (lower than *, + and ==)!
// In-place power: replaces *this with (*this)^p and returns *this.
// NOTE(review): this page is a diff shown without +/- markers, so the old
// braced body and the new one-line body both follow the signature; presumably
// only the pow_byref() line survives the commit -- confirm.
inline bb31_4_t& operator^=(uint32_t p)
{
// running square of the original value: x, x^2, x^4, ...
bb31_4_t sqr = *this;

// bit 0 of p clear: x itself is not a factor, so restart the product at
// (1, 0, 0, 0); otherwise *this already holds the x^1 factor
if (!(p&1)) {
c[0] = bb31_t{1};
c[1] = c[2] = c[3] = 0;
}

// "unroll 1" asks the compiler to keep the loop rolled; each pass drops the
// bit just handled and the loop ends once no set bits remain
#pragma unroll 1
while (p >>= 1) {
sqr.sqr();
// fold the current square into the product when its bit is set
if (p&1)
mul(sqr);
}

return *this;
}
// one-line form: the work is handed to pow_byref() (presumably provided by
// "pow.hpp" -- confirm) and its result is returned
{ return pow_byref(*this, p); }
// Value-returning power: the base arrives by value, so the compound
// assignment below only touches the local copy.
friend inline bb31_4_t operator^(bb31_4_t a, uint32_t p)
{
    a ^= p;
    return a;
}
// Call-syntax spelling of exponentiation: x(p) evaluates x ^ p.
inline bb31_4_t operator()(uint32_t p)
{
    bb31_4_t& self = *this;
    return self ^ p;
}

// raise to a constant power, e.g. x^7, to be unrolled at compile time
// In-place power for an exponent p >= 2 (asserted below); returns *this.
// NOTE(review): this page is a diff shown without +/- markers, so the old
// braced body and the new one-line body both follow the signature; presumably
// only the pow_byref() line survives the commit -- confirm.
inline bb31_4_t& operator^=(int p)
{
assert(p >= 2);

// running square of the original value
bb31_4_t sqr = *this;
// even p: square once per trailing zero bit of p, halving p each time, then
// start the product from that power instead of from x
if ((p&1) == 0) {
do {
sqr.sqr();
p >>= 1;
} while ((p&1) == 0);
*this = sqr;
}
// p is odd at this point; walk its remaining bits from low to high
for (p >>= 1; p; p >>= 1) {
sqr.sqr();
// multiply the current square in when its bit is set
if (p&1)
mul(sqr);
}

return *this;
}
// one-line form: the work is handed to pow_byref() (presumably provided by
// "pow.hpp" -- confirm) and its result is returned
{ return pow_byref(*this, p); }
// Value-returning power for an int exponent: only the by-value copy of the
// base is modified.
friend inline bb31_4_t operator^(bb31_4_t a, int p)
{
    a ^= p;
    return a;
}
inline bb31_4_t operator()(int p)
Expand Down
40 changes: 5 additions & 35 deletions ff/gl64_t.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
#if defined(__CUDACC__) && !defined(__SPPARK_FF_GL64_T_CUH__)
#define __SPPARK_FF_GL64_T_CUH__

# include <cstddef>
# include <cstdint>
# include "pow.hpp"

# define inline __device__ __forceinline__
# ifdef __GNUC__
# define asm __asm__ __volatile__
Expand Down Expand Up @@ -308,48 +311,15 @@ public:
// raise to a variable power, variable with respect to threadIdx,
// but mind the ^ operator's precedence (lower than *, + and ==)!
// In-place power: replaces *this with (*this)^p and returns *this.
// NOTE(review): this page is a diff shown without +/- markers, so the old
// braced body and the new one-line body both follow the signature; presumably
// only the pow_byref() line survives the commit -- confirm.
inline gl64_t& operator^=(uint32_t p)
{
// running square of the original value
gl64_t sqr = *this;
// initial product chosen by bit 0 of p: presumably *this when the bit is set
// and one() otherwise -- csel() is defined outside this view, confirm
*this = csel(*this, one(), p&1);

// "unroll 1" asks the compiler to keep the loop rolled; each pass drops the
// bit just handled and the loop ends once no set bits remain
#pragma unroll 1
while (p >>= 1) {
// squaring is done as a self-multiplication
sqr.mul(sqr);
if (p&1)
mul(sqr);
}
// to() is applied to the finished product; its purpose is not visible
// here -- TODO confirm
to();

return *this;
}
// one-line form: pow_byref() does the exponentiation, then the same to()
// call is applied before returning *this
{ pow_byref(*this, p); to(); return *this; }
// Value-returning power: the base arrives by value, so the compound
// assignment below only touches the local copy.
friend inline gl64_t operator^(gl64_t a, uint32_t p)
{
    a ^= p;
    return a;
}
// Call-syntax spelling of exponentiation: x(p) evaluates x ^ p.
inline gl64_t operator()(uint32_t p)
{
    gl64_t& self = *this;
    return self ^ p;
}

// raise to a constant power, e.g. x^7, to be unrolled at compile time
// In-place power for an exponent p >= 2; returns *this.
// NOTE(review): this page is a diff shown without +/- markers, so the old
// braced body and the new one-line body both follow the signature; presumably
// only the pow_byref() line survives the commit -- confirm.
inline gl64_t& operator^=(int p)
{
// exponents below 2 are rejected by executing the "trap" instruction
if (p < 2)
asm("trap;");

// running square of the original value
gl64_t sqr = *this;
// even p: square once per trailing zero bit of p, halving p each time, then
// start the product from that power instead of from x
if ((p&1) == 0) {
do {
sqr.mul(sqr);
p >>= 1;
} while ((p&1) == 0);
*this = sqr;
}
// p is odd at this point; walk its remaining bits from low to high
for (p >>= 1; p; p >>= 1) {
sqr.mul(sqr);
if (p&1)
mul(sqr);
}
// to() is applied to the finished product; its purpose is not visible
// here -- TODO confirm
to();

return *this;
}
// one-line form: pow_byref() does the exponentiation, then the same to()
// call is applied before returning *this
{ pow_byref(*this, p); to(); return *this; }
// Value-returning power for an int exponent: only the by-value copy of the
// base is modified.
friend inline gl64_t operator^(gl64_t a, int p)
{
    a ^= p;
    return a;
}
inline gl64_t operator()(int p)
Expand Down
37 changes: 5 additions & 32 deletions ff/mersenne31.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#ifndef __SPPARK_FF_MERSENNE31_HPP__
#define __SPPARK_FF_MERSENNE31_HPP__

#include "pow.hpp"

#ifdef __CUDACC__ // CUDA device-side field types
# include "mont32_t.cuh"
# define inline __device__ __forceinline__
Expand Down Expand Up @@ -61,6 +63,7 @@ struct mrs31_t : public mrs31_base {
};
# undef inline
#else
# include <cstddef>
# include <cstdint>
# include <cassert>
# if defined(__CUDACC__) || defined(__HIPCC__)
Expand Down Expand Up @@ -181,45 +184,15 @@ class mrs31_t {
// raise to a variable power, variable with respect to threadIdx,
// but mind the ^ operator's precedence (lower than *, + and ==)!
// In-place power: replaces *this with (*this)^p and returns *this.
// NOTE(review): this page is a diff shown without +/- markers, so the old
// braced body and the new one-line body both follow the signature; presumably
// only the pow_byref() line survives the commit -- confirm.
inline mrs31_t& operator^=(uint32_t p)
{
// running square of the original value
mrs31_t sqr = *this;
// initial product chosen by bit 0 of p: presumably val when the bit is set
// and 1 otherwise -- csel() is defined outside this view, confirm
*this = csel(val, 1, p&1);

// "unroll 1" asks the compiler to keep the loop rolled; each pass drops the
// bit just handled and the loop ends once no set bits remain
#pragma unroll 1
while (p >>= 1) {
// squaring is done as a self-multiplication
sqr.mul(sqr);
if (p&1)
mul(sqr);
}

return *this;
}
// one-line form: the work is handed to pow_byref() (presumably provided by
// "pow.hpp" -- confirm) and its result is returned
{ return pow_byref(*this, p); }
// Value-returning power: the base arrives by value, so the compound
// assignment below only touches the local copy.
friend inline mrs31_t operator^(mrs31_t a, uint32_t p)
{
    a ^= p;
    return a;
}
// Call-syntax spelling of exponentiation: x(p) evaluates x ^ p.
inline mrs31_t operator()(uint32_t p)
{
    mrs31_t& self = *this;
    return self ^ p;
}

// raise to a constant power, e.g. x^7, to be unrolled at compile time
// In-place power for an exponent p >= 2 (asserted below); returns *this.
// NOTE(review): this page is a diff shown without +/- markers, so the old
// braced body and the new one-line body both follow the signature; presumably
// only the pow_byref() line survives the commit -- confirm.
inline mrs31_t& operator^=(int p)
{
assert(p >= 2);

// running square of the original value
mrs31_t sqr = *this;
// even p: square once per trailing zero bit of p, halving p each time, then
// start the product from that power instead of from x
if ((p&1) == 0) {
do {
sqr.mul(sqr);
p >>= 1;
} while ((p&1) == 0);
*this = sqr;
}
// p is odd at this point; walk its remaining bits from low to high
for (p >>= 1; p; p >>= 1) {
sqr.mul(sqr);
if (p&1)
mul(sqr);
}

return *this;
}
// one-line form: the work is handed to pow_byref() (presumably provided by
// "pow.hpp" -- confirm) and its result is returned
{ return pow_byref(*this, p); }
// Value-returning power for an int exponent: only the by-value copy of the
// base is modified.
friend inline mrs31_t operator^(mrs31_t a, int p)
{
    a ^= p;
    return a;
}
inline mrs31_t operator()(int p)
Expand Down
44 changes: 5 additions & 39 deletions ff/mont32_t.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
#if defined(__CUDACC__) && !defined(__SPPARK_FF_MONT32_T_CUH__)
#define __SPPARK_FF_MONT32_T_CUH__

# include <cstddef>
# include <cstdint>
# include "pow.hpp"

# define inline __device__ __forceinline__
# ifdef __GNUC__
# define asm __asm__ __volatile__
Expand Down Expand Up @@ -228,52 +231,15 @@ public:
// raise to a variable power, variable with respect to threadIdx,
// but mind the ^ operator's precedence (lower than *, + and ==)!
// In-place power: replaces *this with (*this)^p and returns *this.
// NOTE(review): this page is a diff shown without +/- markers, so the old
// braced body and the new one-line body both follow the signature; presumably
// only the pow_byref() line survives the commit -- confirm.
inline mont32_t& operator^=(uint32_t p)
{
// running square of the original value
mont32_t sqr = *this;
// initial product chosen by bit 0 of p: presumably val when the bit is set
// and ONE otherwise -- csel() is defined outside this view, confirm
*this = csel(val, ONE, p&1);

// "unroll 1" asks the compiler to keep the loop rolled; each pass drops the
// bit just handled and the loop ends once no set bits remain
#pragma unroll 1
while (p >>= 1) {
// squaring is done as a self-multiplication
sqr.mul(sqr);
if (p&1)
mul(sqr);
}

return *this;
}
// one-line form: the work is handed to pow_byref() (presumably provided by
// "pow.hpp" -- confirm) and its result is returned
{ return pow_byref(*this, p); }
// Value-returning power: the base arrives by value, so the compound
// assignment below only touches the local copy.
friend inline mont32_t operator^(mont32_t a, uint32_t p)
{
    a ^= p;
    return a;
}
// Call-syntax spelling of exponentiation: x(p) evaluates x ^ p.
inline mont32_t operator()(uint32_t p)
{
    mont32_t& self = *this;
    return self ^ p;
}

// raise to a constant power, e.g. x^7, to be unrolled at compile time
// In-place power for an exponent p >= 2; returns *this.
// NOTE(review): this page is a diff shown without +/- markers, so the old
// braced body and the new one-line body both follow the signature; presumably
// only the pow_byref() line survives the commit -- confirm.
inline mont32_t& operator^=(int p)
{
// exponents below 2 are rejected by executing the "trap" instruction
if (p < 2)
asm("trap;");

// dedicated path for p == 7 built from two sqr_n_mul() calls; presumably
// each call squares its first argument once and multiplies by the third,
// giving x^3 and then x^7 -- sqr_n_mul() is not visible here, confirm
if (p == 7) {
mont32_t temp = sqr_n_mul(*this, 1, *this);
*this = sqr_n_mul(temp, 1, *this);
return *this;
}

// running square of the original value
mont32_t sqr = *this;
// even p: square once per trailing zero bit of p, halving p each time, then
// start the product from that power instead of from x
if ((p&1) == 0) {
do {
sqr.mul(sqr);
p >>= 1;
} while ((p&1) == 0);
*this = sqr;
}
// p is odd at this point; walk its remaining bits from low to high
for (p >>= 1; p; p >>= 1) {
sqr.mul(sqr);
if (p&1)
mul(sqr);
}

return *this;
}
// one-line form: the work is handed to pow_byref() and its result returned.
// NOTE(review): the p == 7 shortcut above does not appear in this body;
// whether pow_byref() keeps an equivalent path cannot be told from here.
{ return pow_byref(*this, p); }
// Value-returning power for an int exponent: only the by-value copy of the
// base is modified.
friend inline mont32_t operator^(mont32_t a, int p)
{
    a ^= p;
    return a;
}
inline mont32_t operator()(int p)
Expand Down
35 changes: 3 additions & 32 deletions ff/mont_t.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

# include <cstddef>
# include <cstdint>
# include "pow.hpp"

# define inline __device__ __forceinline__
# ifdef __GNUC__
Expand Down Expand Up @@ -425,45 +426,15 @@ public:
// raise to a variable power, variable with respect to threadIdx,
// but mind the ^ operator's precedence (lower than *, + and ==)!
// In-place power: replaces *this with (*this)^p and returns *this.
// NOTE(review): this page is a diff shown without +/- markers, so the old
// braced body and the new one-line body both follow the signature; presumably
// only the pow_byref() line survives the commit -- confirm.
inline mont_t& operator^=(uint32_t p)
{
// running square of the original value
mont_t sqr = *this;
// initial product chosen by bit 0 of p: presumably *this when the bit is set
// and one() otherwise -- csel() is defined outside this view, confirm
*this = csel(*this, one(), p&1);

// "unroll 1" asks the compiler to keep the loop rolled; each pass drops the
// bit just handled and the loop ends once no set bits remain
#pragma unroll 1
while (p >>= 1) {
sqr.sqr();
// fold the current square into the product when its bit is set
if (p&1)
*this *= sqr;
}

return *this;
}
// one-line form: the work is handed to pow_byref() (presumably provided by
// "pow.hpp" -- confirm) and its result is returned
{ return pow_byref(*this, p); }
// Value-returning power: the base arrives by value, so the compound
// assignment below only touches the local copy.
friend inline mont_t operator^(mont_t a, uint32_t p)
{
    a ^= p;
    return a;
}
// Call-syntax spelling of exponentiation: x(p) evaluates x ^ p.
inline mont_t operator()(uint32_t p)
{
    mont_t& self = *this;
    return self ^ p;
}

// raise to a constant power, e.g. x^7, to be unrolled at compile time
// In-place power for an exponent p >= 2; returns *this.
// NOTE(review): this page is a diff shown without +/- markers, so the old
// braced body and the new one-line body both follow the signature; presumably
// only the pow_byref() line survives the commit -- confirm.
inline mont_t& operator^=(int p)
{
// exponents below 2 are rejected by executing the "trap" instruction
if (p < 2)
asm("trap;");

// running square of the original value
mont_t sqr = *this;
// even p: square once per trailing zero bit of p, halving p each time, then
// start the product from that power instead of from x
if ((p&1) == 0) {
do {
sqr.sqr();
p >>= 1;
} while ((p&1) == 0);
*this = sqr;
}
// p is odd at this point; walk its remaining bits from low to high
for (p >>= 1; p; p >>= 1) {
sqr.sqr();
// multiply the current square in when its bit is set
if (p&1)
*this *= sqr;
}
return *this;
}
// one-line form: the work is handed to pow_byref() (presumably provided by
// "pow.hpp" -- confirm) and its result is returned
{ return pow_byref(*this, p); }
// Value-returning power for an int exponent. p == 2 is answered by building
// a wide_t from the base and converting it back to mont_t (presumably a
// dedicated squaring path -- wide_t is not visible here, confirm); every
// other exponent goes through the compound assignment on the local copy.
friend inline mont_t operator^(mont_t a, int p)
{
    if (p == 2)
        return (mont_t)wide_t{a};
    a ^= p;
    return a;
}
inline mont_t operator()(int p)
Expand Down
Loading

0 comments on commit 5e740bb

Please sign in to comment.