diff --git a/arith/Cargo.toml b/arith/Cargo.toml index 04ac761e..b26092ae 100644 --- a/arith/Cargo.toml +++ b/arith/Cargo.toml @@ -19,4 +19,8 @@ criterion.workspace = true name = "field" harness = false +[[bench]] +name = "ext_field" +harness = false + [features] diff --git a/arith/benches/ext_field.rs b/arith/benches/ext_field.rs new file mode 100644 index 00000000..88ab9c93 --- /dev/null +++ b/arith/benches/ext_field.rs @@ -0,0 +1,160 @@ +use arith::{ExtensionField, Field, GF2_128x4, M31Ext3, M31Ext3x16, GF2_128}; +use ark_std::test_rng; +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use tynm::type_name; + +fn random_element() -> F { + let mut rng = test_rng(); + F::random_unsafe(&mut rng) +} + +pub(crate) fn bench_field(c: &mut Criterion) { + c.bench_function( + &format!( + "mul-by-base-throughput<{}> 100x times {}x ", + type_name::(), + F::SIZE * 8 / F::FIELD_SIZE + ), + |b| { + b.iter_batched( + || { + ( + random_element::(), + random_element::(), + random_element::(), + random_element::(), + random_element::(), + random_element::(), + random_element::(), + random_element::(), + ) + }, + |(mut x, mut y, mut z, mut w, xx, yy, zz, ww)| { + for _ in 0..25 { + (x, y, z, w) = ( + x.mul_by_base_field(&xx), + y.mul_by_base_field(&yy), + z.mul_by_base_field(&zz), + w.mul_by_base_field(&ww), + ); + } + (x, y, z, w) + }, + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!( + "mul-by-x-throughput<{}> 100x times {}x ", + type_name::(), + F::SIZE * 8 / F::FIELD_SIZE + ), + |b| { + b.iter_batched( + || { + ( + random_element::(), + random_element::(), + random_element::(), + random_element::(), + ) + }, + |(mut x, mut y, mut z, mut w)| { + for _ in 0..25 { + (x, y, z, w) = (x.mul_by_x(), y.mul_by_x(), z.mul_by_x(), w.mul_by_x()); + } + (x, y, z, w) + }, + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!( + "mul-by-base-latency<{}> 100x times {}x ", + type_name::(), + F::SIZE * 8 / F::FIELD_SIZE + ), + |b| { + b.iter_batched( + || (random_element::(), random_element::()), + |(mut x, xx)| { + for _ in 0..100 { + x = x.mul_by_base_field(&xx); + } + x + }, + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!( + "add-by-base-throughput<{}> 100x times {}x ", + type_name::(), + F::SIZE * 8 / F::FIELD_SIZE + ), + |b| { + b.iter_batched( + || { + ( + random_element::(), + random_element::(), + random_element::(), + random_element::(), + random_element::(), + random_element::(), + random_element::(), + random_element::(), + ) + }, + |(mut x, mut y, mut z, mut w, xx, yy, zz, ww)| { + for _ in 0..25 { + (x, y, z, w) = ( + x.add_by_base_field(&xx), + y.add_by_base_field(&yy), + z.add_by_base_field(&zz), + w.add_by_base_field(&ww), + ); + } + (x, y, z, w) + }, + BatchSize::SmallInput, + ) + }, + ); + + c.bench_function( + &format!( + "add-by-base-latency<{}> 100x times {}x ", + type_name::(), + F::SIZE * 8 / F::FIELD_SIZE + ), + |b| { + b.iter_batched( + || (random_element::(), random_element::()), + |(mut x, xx)| { + for _ in 0..100 { + x = x.add_by_base_field(&xx); + } + x + }, + BatchSize::SmallInput, + ) + }, + ); +} + +fn ext_by_base_benchmark(c: &mut Criterion) { + bench_field::(c); + bench_field::(c); + bench_field::(c); + bench_field::(c); +} + +criterion_group!(ext_by_base_benches, ext_by_base_benchmark); +criterion_main!(ext_by_base_benches); diff --git a/arith/src/extension_field.rs b/arith/src/extension_field.rs index f7dbba52..3399c712 100644 --- a/arith/src/extension_field.rs +++ b/arith/src/extension_field.rs @@ -11,17 +11,19 @@ pub use m31_ext::M31Ext3; pub use m31_ext3x16::M31Ext3x16; /// Configurations for Extension Field over -/// the Binomial polynomial x^DEGREE - W +/// - either the Binomial polynomial x^DEGREE - W +/// - or the AES polynomial x^128 + x^7 + x^2 + x + 1 // -// FIXME: Our binary extension field is no longer a binomial extension field -// will fix later -pub trait BinomialExtensionField: From + Field + FieldSerde { +pub trait ExtensionField: From + Field + FieldSerde { /// Degree of the Extension const DEGREE: usize; - /// Extension Field + /// constant term if the extension field is represented as a binomial polynomial const W: u32; + /// x, i.e, 0 + x + 0 x^2 + 0 x^3 + ... + const X: Self; + /// Base field for the extension type BaseField: Field + FieldSerde + Send; @@ -30,4 +32,7 @@ pub trait BinomialExtensionField: From + Field + FieldSerde { /// Add the extension field with the base field fn add_by_base_field(&self, base: &Self::BaseField) -> Self; + + /// Multiply the extension field element by x, i.e, 0 + x + 0 x^2 + 0 x^3 + ... + fn mul_by_x(&self) -> Self; } diff --git a/arith/src/extension_field/fr_ext.rs b/arith/src/extension_field/fr_ext.rs index 85679370..79ab4872 100644 --- a/arith/src/extension_field/fr_ext.rs +++ b/arith/src/extension_field/fr_ext.rs @@ -1,13 +1,16 @@ use halo2curves::bn256::Fr; -use super::BinomialExtensionField; +use super::ExtensionField; -impl BinomialExtensionField for Fr { +impl ExtensionField for Fr { const DEGREE: usize = 1; /// Extension Field over X-1 which is self const W: u32 = 1; + // placeholder, doesn't make sense for Fr + const X: Self = Fr::zero(); + /// Base field for the extension type BaseField = Self; @@ -20,4 +23,9 @@ impl BinomialExtensionField for Fr { fn add_by_base_field(&self, base: &Self::BaseField) -> Self { self + base } + + /// Multiply the extension field by x, i.e, 0 + x + 0 x^2 + 0 x^3 + ... + fn mul_by_x(&self) -> Self { + unimplemented!("mul_by_x for Fr doesn't make sense") + } } diff --git a/arith/src/extension_field/gf2_128/avx.rs b/arith/src/extension_field/gf2_128/avx.rs index 50f4beb0..0dbe218f 100644 --- a/arith/src/extension_field/gf2_128/avx.rs +++ b/arith/src/extension_field/gf2_128/avx.rs @@ -5,7 +5,7 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; -use crate::{field_common, BinomialExtensionField, Field, FieldSerde, FieldSerdeResult, GF2}; +use crate::{field_common, ExtensionField, Field, FieldSerde, FieldSerdeResult, GF2}; #[derive(Debug, Clone, Copy)] pub struct AVX512GF2_128 { @@ -48,13 +48,19 @@ impl FieldSerde for AVX512GF2_128 { impl Field for AVX512GF2_128 { const NAME: &'static str = "Galios Field 2^128"; + const SIZE: usize = 128 / 8; + const FIELD_SIZE: usize = 128; // in bits const ZERO: Self = AVX512GF2_128 { v: unsafe { std::mem::zeroed() }, }; + const ONE: Self = AVX512GF2_128 { + v: unsafe { std::mem::transmute::<[i32; 4], __m128i>([1, 0, 0, 0]) }, + }; + const INV_2: Self = AVX512GF2_128 { v: unsafe { std::mem::zeroed() }, }; // should not be used @@ -141,10 +147,15 @@ impl Field for AVX512GF2_128 { } } -impl BinomialExtensionField for AVX512GF2_128 { +impl ExtensionField for AVX512GF2_128 { const DEGREE: usize = 128; + const W: u32 = 0x87; + const X: Self = AVX512GF2_128 { + v: unsafe { std::mem::transmute::<[i32; 4], __m128i>([2, 0, 0, 0]) }, + }; + type BaseField = GF2; #[inline(always)] @@ -162,6 +173,37 @@ impl BinomialExtensionField for AVX512GF2_128 { res.v = unsafe { _mm_xor_si128(res.v, _mm_set_epi64x(0, base.v as i64)) }; res } + + #[inline] + fn mul_by_x(&self) -> Self { + unsafe { + // Shift left by 1 bit + let shifted = _mm_slli_epi64(self.v, 1); + + // Get the most significant bit and move it + let msb = _mm_srli_epi64(self.v, 63); + let msb_moved = _mm_slli_si128(msb, 8); + + // Combine the shifted value with the moved msb + let shifted_consolidated = _mm_or_si128(shifted, msb_moved); + + // Create the reduction value (0x87) and the comparison value (1) + let reduction = { + let multiplier = _mm_set_epi64x(0, 0x87); + let one = _mm_set_epi64x(0, 1); + + // Check if the MSB was 1 and create a mask + let mask = _mm_cmpeq_epi64(_mm_srli_si128(msb, 8), one); + + _mm_and_si128(mask, multiplier) + }; + + // Apply the reduction conditionally + let res = _mm_xor_si128(shifted_consolidated, reduction); + + Self { v: res } + } + } } impl From for AVX512GF2_128 { diff --git a/arith/src/extension_field/gf2_128/neon.rs b/arith/src/extension_field/gf2_128/neon.rs index 6ee7ad9b..526b7e3e 100644 --- a/arith/src/extension_field/gf2_128/neon.rs +++ b/arith/src/extension_field/gf2_128/neon.rs @@ -2,7 +2,7 @@ use std::iter::{Product, Sum}; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::{arch::aarch64::*, mem::transmute}; -use crate::{field_common, BinomialExtensionField, Field, FieldSerde, FieldSerdeResult, GF2}; +use crate::{field_common, ExtensionField, Field, FieldSerde, FieldSerdeResult, GF2}; #[derive(Clone, Copy, Debug)] pub struct NeonGF2_128 { @@ -67,13 +67,19 @@ impl FieldSerde for NeonGF2_128 { impl Field for NeonGF2_128 { const NAME: &'static str = "Galios Field 2^128"; + const SIZE: usize = 128 / 8; + const FIELD_SIZE: usize = 128; // in bits const ZERO: Self = NeonGF2_128 { v: unsafe { std::mem::zeroed() }, }; + const ONE: Self = NeonGF2_128 { + v: unsafe { transmute::<[u32; 4], uint32x4_t>([1, 0, 0, 0]) }, + }; + const INV_2: Self = NeonGF2_128 { v: unsafe { std::mem::zeroed() }, }; // should not be used @@ -160,10 +166,15 @@ impl Field for NeonGF2_128 { } } -impl BinomialExtensionField for NeonGF2_128 { +impl ExtensionField for NeonGF2_128 { const DEGREE: usize = 128; + const W: u32 = 0x87; + const X: Self = NeonGF2_128 { + v: unsafe { std::mem::transmute::<[i32; 4], uint32x4_t>([2, 0, 0, 0]) }, + }; + type BaseField = GF2; #[inline(always)] @@ -182,6 +193,13 @@ impl BinomialExtensionField for NeonGF2_128 { } add_internal(&Self::one(), self) } + + #[inline(always)] + fn mul_by_x(&self) -> Self { + Self { + v: mul_by_x_internal(&self.v), + } + } } impl From for NeonGF2_128 { @@ -343,3 +361,35 @@ pub(crate) unsafe fn gfmul(a: uint32x4_t, b: uint32x4_t) -> uint32x4_t { vreinterpretq_u32_u64(veorq_u64(tmp3, tmp6)) } + +#[inline] +pub(crate) fn mul_by_x_internal(a: &uint32x4_t) -> uint32x4_t { + unsafe { + let (high_bit, shifted_consolidated) = { + // Reinterpret uint32x4_t as uint64x2_t + let a_u64 = vreinterpretq_u64_u32(*a); + + // Extract the highest bit of both channels + let high_bit_first = vgetq_lane_u64(a_u64, 0) >> 63; + let high_bit_second = vgetq_lane_u64(a_u64, 1) >> 63; + + // shift to the left by 1 + let shifted = vshlq_n_u64(a_u64, 1); + + // Create a mask with the high bit in the lowest position of the second channel + let mask = vsetq_lane_u64(high_bit_first, vdupq_n_u64(0), 1); + + // OR the shifted value with the mask to set the low bit of the second channel + let shifted_consolidated = vorrq_u64(shifted, mask); + + (high_bit_second, shifted_consolidated) + }; + + let reduction = vcombine_u64(vdup_n_u64(0x87 * high_bit), vdup_n_u64(0)); + + let res = veorq_u64(shifted_consolidated, reduction); + + // Reinterpret uint64x2_t back to uint32x4_t + vreinterpretq_u32_u64(res) + } +} diff --git a/arith/src/extension_field/gf2_128x4/avx.rs b/arith/src/extension_field/gf2_128x4/avx.rs index 51d815d9..f608e4de 100644 --- a/arith/src/extension_field/gf2_128x4/avx.rs +++ b/arith/src/extension_field/gf2_128x4/avx.rs @@ -1,4 +1,4 @@ -use crate::field_common; +use crate::{field_common, ExtensionField}; use crate::{Field, FieldSerde, FieldSerdeResult, SimdField, GF2_128}; use std::fmt::Debug; @@ -81,6 +81,10 @@ impl Field for AVX512GF2_128x4 { const ZERO: Self = Self { data: PACKED_0 }; + const ONE: Self = Self { + data: unsafe { transmute::<[u64; 8], __m512i>([1, 0, 1, 0, 1, 0, 1, 0]) }, + }; + const INV_2: Self = Self { data: PACKED_INV_2 }; const FIELD_SIZE: usize = 128; @@ -394,6 +398,65 @@ impl SimdField for AVX512GF2_128x4 { } } +impl ExtensionField for AVX512GF2_128x4 { + const DEGREE: usize = GF2_128::DEGREE; + + const W: u32 = GF2_128::W; + + const X: Self = Self { + data: unsafe { transmute::<[u64; 8], __m512i>([2u64, 0, 2u64, 0, 2u64, 0, 2u64, 0]) }, + }; + + type BaseField = GF2_128; + + #[inline(always)] + fn mul_by_base_field(&self, base: &Self::BaseField) -> Self { + let simd_base = AVX512GF2_128x4::from(*base); + *self * simd_base + } + + #[inline(always)] + fn add_by_base_field(&self, base: &Self::BaseField) -> Self { + unsafe { + let base_vec = transmute::(*base); + let mut res = transmute::(*self); + res[0] ^= base_vec; + Self { + data: transmute::<[u128; 4], __m512i>(res), + } + } + } + + #[inline(always)] + fn mul_by_x(&self) -> Self { + unsafe { + // Shift left by 1 bit + let shifted = _mm512_slli_epi64(self.data, 1); + + // Get the most significant bit of each 64-bit part + let msb = _mm512_srli_epi64(self.data, 63); + + // Move the MSB from the high 64 bits to the LSB of the low 64 bits for each 128-bit element + let msb_moved = _mm512_bslli_epi128(msb, 8); + + // Combine the shifted value with the moved msb + let shifted_consolidated = _mm512_or_si512(shifted, msb_moved); + + // compute the reduced polynomial + let reduction = { + let odd_elements = _mm512_maskz_compress_epi64(0b10101010, msb); + let mask = _mm512_maskz_expand_epi64(0b01010101, odd_elements); + let multiplier = _mm512_set1_epi64(0x87); + _mm512_mul_epu32(multiplier, mask) + }; + + // Apply the reduction conditionally + let res = _mm512_xor_si512(shifted_consolidated, reduction); + AVX512GF2_128x4 { data: res } + } + } +} + #[inline(always)] fn add_internal(a: &AVX512GF2_128x4, b: &AVX512GF2_128x4) -> AVX512GF2_128x4 { unsafe { diff --git a/arith/src/extension_field/gf2_128x4/neon.rs b/arith/src/extension_field/gf2_128x4/neon.rs index 16e05529..6b84fe45 100644 --- a/arith/src/extension_field/gf2_128x4/neon.rs +++ b/arith/src/extension_field/gf2_128x4/neon.rs @@ -3,12 +3,13 @@ use std::iter::{Product, Sum}; use std::mem::transmute; use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; -use crate::SimdField; +use crate::neon::mul_by_x_internal; use crate::{ field_common, neon::{gfadd, gfmul, NeonGF2_128}, Field, FieldSerde, FieldSerdeResult, }; +use crate::{ExtensionField, SimdField}; #[derive(Clone, Copy, Debug)] pub struct NeonGF2_128x4 { @@ -78,6 +79,10 @@ impl Field for NeonGF2_128x4 { v: [unsafe { transmute::<[u32; 4], uint32x4_t>([0, 0, 0, 0]) }; 4], }; + const ONE: Self = NeonGF2_128x4 { + v: [unsafe { transmute::<[u32; 4], uint32x4_t>([1, 0, 0, 0]) }; 4], + }; + const INV_2: Self = NeonGF2_128x4 { v: [unsafe { transmute::<[u32; 4], uint32x4_t>([0, 0, 0, 0]) }; 4], }; // should not be used @@ -220,6 +225,46 @@ impl From for NeonGF2_128x4 { } } +impl ExtensionField for NeonGF2_128x4 { + const DEGREE: usize = NeonGF2_128::DEGREE; + + const W: u32 = NeonGF2_128::W; + + const X: Self = Self { + v: unsafe { transmute::<[u64; 8], [uint32x4_t; 4]>([2u64, 0, 2u64, 0, 2u64, 0, 2u64, 0]) }, + }; + + type BaseField = NeonGF2_128; + + #[inline(always)] + fn mul_by_base_field(&self, base: &Self::BaseField) -> Self { + let simd_base = Self::from(*base); + *self * simd_base + } + + #[inline(always)] + fn add_by_base_field(&self, base: &Self::BaseField) -> Self { + unsafe { + let base_vec = transmute::(*base); + let mut res = transmute::(*self); + res[0] ^= base_vec; + Self { + v: transmute::<[u128; 4], [uint32x4_t; 4]>(res), + } + } + } + + #[inline(always)] + fn mul_by_x(&self) -> Self { + let mut res = Self::default(); + res.v[0] = mul_by_x_internal(&self.v[0]); + res.v[1] = mul_by_x_internal(&self.v[1]); + res.v[2] = mul_by_x_internal(&self.v[2]); + res.v[3] = mul_by_x_internal(&self.v[3]); + res + } +} + #[inline(always)] fn add_internal(a: &NeonGF2_128x4, b: &NeonGF2_128x4) -> NeonGF2_128x4 { NeonGF2_128x4 { diff --git a/arith/src/extension_field/m31_ext.rs b/arith/src/extension_field/m31_ext.rs index bb5a05a1..28eb76db 100644 --- a/arith/src/extension_field/m31_ext.rs +++ b/arith/src/extension_field/m31_ext.rs @@ -8,7 +8,7 @@ use std::{ use crate::{field_common, mod_reduce_u32, Field, FieldSerde, FieldSerdeResult, M31}; -use super::BinomialExtensionField; +use super::ExtensionField; #[derive(Debug, Clone, Copy, Default, PartialEq)] pub struct M31Ext3 { @@ -56,12 +56,17 @@ impl Field for M31Ext3 { const NAME: &'static str = "Mersenne 31 Extension 3"; const SIZE: usize = 32 / 8 * 3; + const FIELD_SIZE: usize = 32 * 3; const ZERO: Self = M31Ext3 { v: [M31::ZERO, M31::ZERO, M31::ZERO], }; + const ONE: Self = M31Ext3 { + v: [M31::ONE, M31::ZERO, M31::ZERO], + }; + const INV_2: M31Ext3 = M31Ext3 { v: [M31::INV_2, M31 { v: 0 }, M31 { v: 0 }], }; @@ -157,12 +162,16 @@ impl Field for M31Ext3 { } } -impl BinomialExtensionField for M31Ext3 { +impl ExtensionField for M31Ext3 { const DEGREE: usize = 3; /// Extension Field const W: u32 = 5; + const X: Self = M31Ext3 { + v: [M31::ZERO, M31::ONE, M31::ZERO], + }; + /// Base field for the extension type BaseField = M31; @@ -183,6 +192,14 @@ impl BinomialExtensionField for M31Ext3 { res[0] += base; Self { v: res } } + + /// Multiply the extension field by x, i.e, 0 + x + 0 x^2 + 0 x^3 + ... + #[inline(always)] + fn mul_by_x(&self) -> Self { + Self { + v: [self.v[2].mul_by_5(), self.v[0], self.v[1]], + } + } } impl Add for M31Ext3 { diff --git a/arith/src/extension_field/m31_ext3x16.rs b/arith/src/extension_field/m31_ext3x16.rs index b8183cb4..cf60542d 100644 --- a/arith/src/extension_field/m31_ext3x16.rs +++ b/arith/src/extension_field/m31_ext3x16.rs @@ -5,8 +5,8 @@ use std::{ }; use crate::{ - field_common, BinomialExtensionField, Field, FieldSerde, FieldSerdeResult, M31Ext3, M31x16, - SimdField, M31, + field_common, ExtensionField, Field, FieldSerde, FieldSerdeResult, M31Ext3, M31x16, SimdField, + M31, }; #[derive(Debug, Clone, Copy, Default, PartialEq)] @@ -73,11 +73,15 @@ impl From for M31Ext3x16 { } } -impl BinomialExtensionField for M31Ext3x16 { +impl ExtensionField for M31Ext3x16 { const DEGREE: usize = 3; const W: u32 = 5; + const X: Self = M31Ext3x16 { + v: [M31x16::ZERO, M31x16::ONE, M31x16::ZERO], + }; + type BaseField = M31x16; #[inline(always)] @@ -93,6 +97,14 @@ impl BinomialExtensionField for M31Ext3x16 { v: [self.v[0] + base, self.v[1], self.v[2]], } } + + /// Multiply the extension field by x, i.e, 0 + x + 0 x^2 + 0 x^3 + ... + #[inline(always)] + fn mul_by_x(&self) -> Self { + Self { + v: [self.v[2].mul_by_5(), self.v[0], self.v[1]], + } + } } impl From for M31Ext3x16 { @@ -123,6 +135,10 @@ impl Field for M31Ext3x16 { v: [M31x16::ZERO; 3], }; + const ONE: Self = Self { + v: [M31x16::ONE, M31x16::ZERO, M31x16::ZERO], + }; + const INV_2: Self = Self { v: [M31x16::INV_2, M31x16::ZERO, M31x16::ZERO], }; diff --git a/arith/src/field.rs b/arith/src/field.rs index 4752a7dd..260ba90c 100644 --- a/arith/src/field.rs +++ b/arith/src/field.rs @@ -50,6 +50,9 @@ pub trait Field: /// zero const ZERO: Self; + /// One + const ONE: Self; + /// Inverse of 2 const INV_2: Self; diff --git a/arith/src/field/bn254.rs b/arith/src/field/bn254.rs index 0406f820..5bc1cffd 100644 --- a/arith/src/field/bn254.rs +++ b/arith/src/field/bn254.rs @@ -19,6 +19,9 @@ impl Field for Fr { /// zero const ZERO: Self = Fr::zero(); + /// One + const ONE: Self = Fr::one(); + /// Inverse of 2 const INV_2: Self = Fr::TWO_INV; diff --git a/arith/src/field/gf2.rs b/arith/src/field/gf2.rs index 8985edef..284667a1 100644 --- a/arith/src/field/gf2.rs +++ b/arith/src/field/gf2.rs @@ -46,11 +46,19 @@ impl FieldSerde for GF2 { impl Field for GF2 { // still will pack 8 bits into a u8 + const NAME: &'static str = "Galios Field 2"; + const SIZE: usize = 1; + const FIELD_SIZE: usize = 1; // in bits + const ZERO: Self = GF2 { v: 0 }; + + const ONE: Self = GF2 { v: 1 }; + const INV_2: Self = GF2 { v: 0 }; // should not be used + #[inline(always)] fn zero() -> Self { GF2 { v: 0 } diff --git a/arith/src/field/gf2/gf2x8.rs b/arith/src/field/gf2/gf2x8.rs index 001b3d8a..51fed85f 100644 --- a/arith/src/field/gf2/gf2x8.rs +++ b/arith/src/field/gf2/gf2x8.rs @@ -36,11 +36,19 @@ impl FieldSerde for GF2x8 { impl Field for GF2x8 { // still will pack 8 bits into a u8 + const NAME: &'static str = "Galios Field 2 SIMD"; + const SIZE: usize = 1; + const FIELD_SIZE: usize = 1; // in bits + const ZERO: Self = GF2x8 { v: 0 }; + + const ONE: Self = GF2x8 { v: 255 }; + const INV_2: Self = GF2x8 { v: 0 }; // should not be used + #[inline(always)] fn zero() -> Self { GF2x8 { v: 0 } diff --git a/arith/src/field/m31.rs b/arith/src/field/m31.rs index e1b69c66..c0b3dd78 100644 --- a/arith/src/field/m31.rs +++ b/arith/src/field/m31.rs @@ -90,6 +90,8 @@ impl Field for M31 { const ZERO: Self = M31 { v: 0 }; + const ONE: Self = M31 { v: 1 }; + const INV_2: M31 = M31 { v: 1 << 30 }; const FIELD_SIZE: usize = 32; diff --git a/arith/src/field/m31/m31_avx.rs b/arith/src/field/m31/m31_avx.rs index a3182bec..ea8e6886 100644 --- a/arith/src/field/m31/m31_avx.rs +++ b/arith/src/field/m31/m31_avx.rs @@ -82,6 +82,10 @@ impl Field for AVXM31 { const ZERO: Self = Self { v: PACKED_0 }; + const ONE: Self = Self { + v: unsafe { transmute::<[u32; 16], __m512i>([1; M31_PACK_SIZE]) }, + }; + const INV_2: Self = Self { v: PACKED_INV_2 }; const FIELD_SIZE: usize = 32; diff --git a/arith/src/field/m31/m31_neon.rs b/arith/src/field/m31/m31_neon.rs index 1c66010a..7676c554 100644 --- a/arith/src/field/m31/m31_neon.rs +++ b/arith/src/field/m31/m31_neon.rs @@ -95,6 +95,10 @@ impl Field for NeonM31 { const ZERO: Self = Self { v: [PACKED_0; 4] }; + const ONE: Self = Self { + v: [unsafe { transmute::<[u32; 4], uint32x4_t>([1; 4]) }; 4], + }; + const INV_2: Self = Self { v: [PACKED_INV_2; 4], }; diff --git a/arith/src/tests/extension_field.rs b/arith/src/tests/extension_field.rs index 5d91547c..4067229f 100644 --- a/arith/src/tests/extension_field.rs +++ b/arith/src/tests/extension_field.rs @@ -1,44 +1,53 @@ use ark_std::test_rng; use crate::field::Field; -use crate::BinomialExtensionField; +use crate::ExtensionField; -pub(crate) fn random_extension_field_tests(_name: String) { +pub(crate) fn random_extension_field_tests(_name: String) { let mut rng = test_rng(); - - { - let a = F::random_unsafe(&mut rng); - let s1 = F::BaseField::random_unsafe(&mut rng); - let s2 = F::BaseField::random_unsafe(&mut rng); - - assert_eq!( - a.mul_by_base_field(&s1).mul_by_base_field(&s2), - a.mul_by_base_field(&s2).mul_by_base_field(&s1), - ); - assert_eq!( - a.mul_by_base_field(&s1).mul_by_base_field(&s2), - a.mul_by_base_field(&(s1 * s2)), - ); - - assert_eq!( - a.add_by_base_field(&s1).add_by_base_field(&s2), - a.add_by_base_field(&s2).add_by_base_field(&s1), - ); - assert_eq!( - a.add_by_base_field(&s1).add_by_base_field(&s2), - a.add_by_base_field(&(s1 + s2)), - ); - } - - { - let a = F::random_unsafe(&mut rng); - let b = F::random_unsafe(&mut rng); - let s = F::BaseField::random_unsafe(&mut rng); - - assert_eq!(a.mul_by_base_field(&s) * b, (a * b).mul_by_base_field(&s),); - assert_eq!(b.mul_by_base_field(&s) * a, (a * b).mul_by_base_field(&s),); - - assert_eq!(a.add_by_base_field(&s) + b, (a + b).add_by_base_field(&s),); - assert_eq!(b.add_by_base_field(&s) + a, (a + b).add_by_base_field(&s),); + for _ in 0..1000 { + { + let a = F::random_unsafe(&mut rng); + let s1 = F::BaseField::random_unsafe(&mut rng); + let s2 = F::BaseField::random_unsafe(&mut rng); + + assert_eq!( + a.mul_by_base_field(&s1).mul_by_base_field(&s2), + a.mul_by_base_field(&s2).mul_by_base_field(&s1), + ); + assert_eq!( + a.mul_by_base_field(&s1).mul_by_base_field(&s2), + a.mul_by_base_field(&(s1 * s2)), + ); + + assert_eq!( + a.add_by_base_field(&s1).add_by_base_field(&s2), + a.add_by_base_field(&s2).add_by_base_field(&s1), + ); + assert_eq!( + a.add_by_base_field(&s1).add_by_base_field(&s2), + a.add_by_base_field(&(s1 + s2)), + ); + } + + { + let a = F::random_unsafe(&mut rng); + let b = F::random_unsafe(&mut rng); + let s = F::BaseField::random_unsafe(&mut rng); + + assert_eq!(a.mul_by_base_field(&s) * b, (a * b).mul_by_base_field(&s),); + assert_eq!(b.mul_by_base_field(&s) * a, (a * b).mul_by_base_field(&s),); + + assert_eq!(a.add_by_base_field(&s) + b, (a + b).add_by_base_field(&s),); + assert_eq!(b.add_by_base_field(&s) + a, (a + b).add_by_base_field(&s),); + } + + { + let a = F::random_unsafe(&mut rng); + let b = F::X; + let ax = a.mul_by_x(); + let ab = a * b; + assert_eq!(ax, ab); + } } } diff --git a/arith/src/tests/gf2_128.rs b/arith/src/tests/gf2_128.rs index e8b4f64f..5b1503ce 100644 --- a/arith/src/tests/gf2_128.rs +++ b/arith/src/tests/gf2_128.rs @@ -4,6 +4,7 @@ use std::io::Cursor; use crate::{FieldSerde, GF2_128x4, GF2_128}; use super::{ + extension_field::random_extension_field_tests, field::{random_field_tests, random_inversion_tests}, simd_field::random_simd_field_tests, }; @@ -11,7 +12,9 @@ use super::{ #[test] fn test_field() { random_field_tests::("GF2_128".to_string()); + random_extension_field_tests::("GF2_128".to_string()); random_field_tests::("Vectorized GF2_128".to_string()); + random_extension_field_tests::("Vectorized GF2_128".to_string()); let mut rng = test_rng(); random_inversion_tests::(&mut rng, "GF2_128".to_string()); diff --git a/src/config.rs b/src/config.rs index 6087adf3..d1637bdd 100644 --- a/src/config.rs +++ b/src/config.rs @@ -8,7 +8,7 @@ pub use bn254_sha2::BN254ConfigSha2; pub use m31_ext_keccak::M31ExtConfigKeccak; pub use m31_ext_sha2::M31ExtConfigSha2; -use arith::{BinomialExtensionField, Field, FieldSerde, SimdField}; +use arith::{ExtensionField, Field, FieldSerde, SimdField}; use crate::FiatShamirHash; @@ -91,10 +91,10 @@ pub trait GKRConfig: Default + Clone + Send + Sync + 'static { type CircuitField: Field + FieldSerde + Send; /// Field type for the challenge, e.g., M31Ext3 - type ChallengeField: BinomialExtensionField + Send; + type ChallengeField: ExtensionField + Send; /// Main field type for the scheme, e.g., M31Ext3x16 - type Field: BinomialExtensionField + SimdField + Send; + type Field: ExtensionField + SimdField + Send; /// Simd field for circuit type SimdCircuitField: SimdField + FieldSerde + Send; diff --git a/src/config/m31_ext_keccak.rs b/src/config/m31_ext_keccak.rs index 2a0d75d0..427d742b 100644 --- a/src/config/m31_ext_keccak.rs +++ b/src/config/m31_ext_keccak.rs @@ -1,4 +1,4 @@ -use arith::{BinomialExtensionField, M31Ext3, M31Ext3x16, M31x16, M31}; +use arith::{ExtensionField, M31Ext3, M31Ext3x16, M31x16, M31}; use crate::Keccak256hasher; diff --git a/src/config/m31_ext_sha2.rs b/src/config/m31_ext_sha2.rs index 336700ad..0eb6d8ba 100644 --- a/src/config/m31_ext_sha2.rs +++ b/src/config/m31_ext_sha2.rs @@ -1,4 +1,4 @@ -use arith::{BinomialExtensionField, M31Ext3, M31Ext3x16, M31x16, M31}; +use arith::{ExtensionField, M31Ext3, M31Ext3x16, M31x16, M31}; use crate::SHA256hasher;