Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce non-prime finite fields in fp #147

Closed
wants to merge 13 commits into from
4 changes: 2 additions & 2 deletions ext/crates/algebra/src/algebra/combinatorics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub const MAX_XI_TAU: usize = MAX_MULTINOMIAL_LEN;
/// If p is the nth prime, then `XI_DEGREES[n][i - 1]` is the degree of $ξ_i$ at the prime p divided by
/// q, where q = 2p - 2 if p != 2 and 1 if p = 2.
const XI_DEGREES: [[i32; MAX_XI_TAU]; NUM_PRIMES] = {
let mut res = [[0; 10]; 8];
let mut res = [[0; MAX_XI_TAU]; NUM_PRIMES];
JoeyBF marked this conversation as resolved.
Show resolved Hide resolved
const_for! { p_idx in 0 .. NUM_PRIMES {
let p = PRIMES[p_idx];
let mut p_to_the_i = p;
Expand All @@ -28,7 +28,7 @@ const XI_DEGREES: [[i32; MAX_XI_TAU]; NUM_PRIMES] = {
/// If p is the nth prime, then `TAU_DEGREES[n][i]` is the degree of $τ_i$ at the prime p. Its value is
/// nonsense at the prime 2
const TAU_DEGREES: [[i32; MAX_XI_TAU]; NUM_PRIMES] = {
let mut res = [[0; 10]; 8];
let mut res = [[0; MAX_XI_TAU]; NUM_PRIMES];
const_for! { p_idx in 0 .. NUM_PRIMES {
let p = PRIMES[p_idx];
let mut p_to_the_i: u32 = 1;
Expand Down
2 changes: 2 additions & 0 deletions ext/crates/fp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ edition = "2021"
build_const = "0.2.2"
byteorder = "1.4.3"
cfg-if = "1.0.0"
dashmap = "5"
itertools = { version = "0.10.0", default-features = false }
once_cell = "1.19.0"
serde = "1.0.0"
serde_json = "1.0.0"

Expand Down
20 changes: 8 additions & 12 deletions ext/crates/fp/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ use build_const::ConstWriter;
type Limb = u64;

fn main() -> Result<(), Error> {
let num_primes = 8;
let primes = first_n_primes(num_primes);
// We want primes up to 2^8 - 1, because those will be the characteristics of the fields that
// have degree at least 2 and order at most 2^16 - 1. We will use PRIME_TO_INDEX_MAP when
// computing Zech logarithms.
let prime_bound = u8::MAX;
let primes = primes_up_to(prime_bound);
let num_primes = primes.len();
let max_prime = *primes.last().unwrap();
let not_a_prime: usize = u32::MAX as usize; // Hack for 32-bit architectures
let max_multinomial_len = 10;
Expand Down Expand Up @@ -55,16 +59,8 @@ fn main() -> Result<(), Error> {
Ok(())
}

fn first_n_primes(n: usize) -> Vec<u32> {
let mut acc = vec![];
let mut i = 2;
while acc.len() < n {
if is_prime(i) {
acc.push(i);
}
i += 1;
}
acc
fn primes_up_to(n: impl Into<u32>) -> Vec<u32> {
(2..=n.into()).filter(|&i| is_prime(i)).collect()
}

fn is_prime(i: u32) -> bool {
Expand Down
135 changes: 135 additions & 0 deletions ext/crates/fp/src/field/fp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
use super::{limb::LimbMethods, Field, FieldElement};
use crate::{constants::BITS_PER_LIMB, limb::Limb, prime::Prime};

/// A prime field. This is just a wrapper around a prime.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct Fp<P>(pub(crate) P);

impl<P: Prime> Field for Fp<P> {
#[cfg(feature = "odd-primes")]
type Characteristic = P;

#[cfg(feature = "odd-primes")]
fn characteristic(self) -> Self::Characteristic {
self.0
}

fn degree(self) -> u32 {
1
}

fn zero(self) -> Self::Element {
0
}

fn one(self) -> Self::Element {
1
}

fn add(self, a: Self::Element, b: Self::Element) -> Self::Element {
self.0.sum(a, b)
}

fn mul(self, a: Self::Element, b: Self::Element) -> Self::Element {
self.0.product(a, b)
}

fn inv(self, a: Self::Element) -> Option<Self::Element> {
if a == 0 {
None
} else {
// By Fermat's little theorem, `a^(p-1) = 1`. Therefore, `a^(-1) = a^(p-2)`.
Some(self.0.pow_mod(a, self.0.as_u32() - 2))
}
}

fn neg(self, a: Self::Element) -> Self::Element {
if a > 0 {
self.0.as_u32() - a
} else {
0
}
}

fn frobenius(self, a: Self::Element) -> Self::Element {
a
}
}

impl<P: Prime> LimbMethods for Fp<P> {
type Element = u32;

fn encode(self, element: Self::Element) -> Limb {
element as Limb
}

fn decode(self, element: Limb) -> Self::Element {
(element % self.0.as_u32() as Limb) as u32
}

fn bit_length(self) -> usize {
let p = self.characteristic().as_u32() as u64;
match p {
2 => 1,
_ => (BITS_PER_LIMB as u32 - (p * (p - 1)).leading_zeros()) as usize,
}
}

fn fma_limb(self, limb_a: Limb, limb_b: Limb, coeff: Self::Element) -> Limb {
if self.characteristic() == 2 {
limb_a ^ (coeff as Limb * limb_b)
} else {
limb_a + (coeff as Limb) * limb_b
}
}

/// Contributed by Robert Burklund.
fn reduce(self, limb: Limb) -> Limb {
match self.characteristic().as_u32() {
2 => limb,
3 => {
// Set top bit to 1 in every limb
const TOP_BIT: Limb = (!0 / 7) << (2 - BITS_PER_LIMB % 3);
let mut limb_2 = ((limb & TOP_BIT) >> 2) + (limb & (!TOP_BIT));
let mut limb_3s = limb_2 & (limb_2 >> 1);
limb_3s |= limb_3s << 1;
limb_2 ^= limb_3s;
limb_2
}
5 => {
// Set bottom bit to 1 in every limb
const BOTTOM_BIT: Limb = (!0 / 31) >> (BITS_PER_LIMB % 5);
const BOTTOM_TWO_BITS: Limb = BOTTOM_BIT | (BOTTOM_BIT << 1);
const BOTTOM_THREE_BITS: Limb = BOTTOM_BIT | (BOTTOM_TWO_BITS << 1);
let a = (limb >> 2) & BOTTOM_THREE_BITS;
let b = limb & BOTTOM_TWO_BITS;
let m = (BOTTOM_BIT << 3) - a + b;
let mut c = (m >> 3) & BOTTOM_BIT;
c |= c << 1;
let d = m & BOTTOM_THREE_BITS;
d + c - BOTTOM_TWO_BITS
}
_ => self.pack(self.unpack(limb)),
}
}
}

impl<P> std::ops::Deref for Fp<P> {
type Target = P;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<P: Prime> From<P> for Fp<P> {
fn from(p: P) -> Self {
Self(p)
}
}

impl FieldElement for u32 {
fn is_zero(&self) -> bool {
*self == 0
}
}
149 changes: 149 additions & 0 deletions ext/crates/fp/src/field/limb.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// According to
// https://doc.rust-lang.org/stable/rustc/lints/listing/warn-by-default.html#private-interfaces:
//
// "Having something private in primary interface guarantees that the item will be unusable from
// outer modules due to type privacy."
//
// In our case, this is a feature. We want to be able to use the `LimbMethods` trait in this crate
// and we also want it to be inaccessible from outside the crate.
#![allow(private_interfaces)]

use std::ops::Range;

use super::FieldElement;
use crate::{
constants::BITS_PER_LIMB,
limb::{Limb, LimbBitIndexPair},
};

/// Methods that lets us interact with the underlying `Limb` type.
///
/// In practice this is an extension trait of a `Field`, so we treat it as such. We can't make it a
/// supertrait of `Field` because `Field` is already a supertrait of `LimbMethods`.
pub trait LimbMethods: Clone + Copy + Sized {
type Element: FieldElement;

/// Encode a field element into a `Limb`. The limbs of an `FpVectorP<Self>` will consist of the
/// coordinates of the vector, packed together using this method. It is assumed that the output
/// value occupies at most `self.bit_length()` bits with the rest padded with zeros, and that
/// the limb is reduced.
///
/// It is required that `self.encode(self.zero()) == 0` (whenever `Self` implements `Field`).
fn encode(self, element: Self::Element) -> Limb;

/// Decode a `Limb` into a field element. The argument will always contain a single encoded
/// field element, padded with zeros. This is the inverse of [`encode`].
fn decode(self, element: Limb) -> Self::Element;

/// Return the number of bits a `Self::Element` occupies in a limb.
fn bit_length(self) -> usize;

/// Fused multiply-add. Return the `Limb` whose `i`th entry is `limb_a[i] + coeff * limb_b[i]`.
/// Both `limb_a` and `limb_b` are assumed to be reduced, and the result does not have to be
/// reduced.
fn fma_limb(self, limb_a: Limb, limb_b: Limb, coeff: Self::Element) -> Limb;

/// Return the `Limb` whose entries are the entries of `limb` reduced modulo `P`.
fn reduce(self, limb: Limb) -> Limb;

/// If `l` is a limb of `Self::Element`s, then `l & F.bitmask()` is the value of the
/// first entry of `l`.
fn bitmask(self) -> Limb {
(1 << self.bit_length()) - 1
}

/// The number of `Self::Element`s that fit in a single limb.
fn entries_per_limb(self) -> usize {
BITS_PER_LIMB / self.bit_length()
}

fn limb_bit_index_pair(self, idx: usize) -> LimbBitIndexPair {
LimbBitIndexPair {
limb: idx / self.entries_per_limb(),
bit_index: (idx % self.entries_per_limb() * self.bit_length()),
}
}

/// Check whether or not a limb is reduced, i.e. whether every entry is a value in the range `0..P`.
/// This is currently **not** faster than calling [`reduce`] directly.
fn is_reduced(self, limb: Limb) -> bool {
limb == self.reduce(limb)
}

/// Given an interator of `Self::Element`s, pack all of them into a single limb in order.
/// It is assumed that
/// - The values of the iterator are less than P
/// - The values of the iterator fit into a single limb
///
/// If these assumptions are violated, the result will be nonsense.
fn pack<T: Iterator<Item = Self::Element>>(self, entries: T) -> Limb {
let bit_length = self.bit_length();
let mut result: Limb = 0;
let mut shift = 0;
for entry in entries {
result += self.encode(entry) << shift;
shift += bit_length;
}
result
}

/// Give an iterator over the entries of `limb`.
fn unpack(self, limb: Limb) -> LimbIterator<Self> {
LimbIterator {
fq: self,
limb,
bit_length: self.bit_length(),
bit_mask: self.bitmask(),
}
}

/// Return the number of limbs required to hold `dim` entries.
fn number(self, dim: usize) -> usize {
if dim == 0 {
0
} else {
self.limb_bit_index_pair(dim - 1).limb + 1
}
}

/// Return the `Range<usize>` starting at the index of the limb containing the `start`th entry, and
/// ending at the index of the limb containing the `end`th entry (including the latter).
fn range(self, start: usize, end: usize) -> Range<usize> {
let min = self.limb_bit_index_pair(start).limb;
let max = if end > 0 {
self.limb_bit_index_pair(end - 1).limb + 1
} else {
0
};
min..max
}

/// Return either `Some(sum)` if no carries happen in the limb, or `None` if some carry does happen.
fn truncate(self, sum: Limb) -> Option<Limb> {
if self.is_reduced(sum) {
Some(sum)
} else {
None
}
}
}

pub(crate) struct LimbIterator<F> {
fq: F,
limb: Limb,
bit_length: usize,
bit_mask: Limb,
}

impl<F: LimbMethods> Iterator for LimbIterator<F> {
type Item = F::Element;

fn next(&mut self) -> Option<Self::Item> {
if self.limb == 0 {
return None;
}
let result = self.limb & self.bit_mask;
self.limb >>= self.bit_length;
Some(self.fq.decode(result))
}
}
Loading
Loading