diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 0f82277..a8d7fd2 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -22,8 +22,5 @@ jobs: run: | go get -v -t -d ./... - - name: TestBlsCurve - run: go test -v ./bls12_381/... --bench=. -cover - - name: TestBenchmark - run: go test -v --bench=. -cover \ No newline at end of file + run: go test -v --bench=. -cover diff --git a/README.md b/README.md index 0ad842c..b6681c0 100644 --- a/README.md +++ b/README.md @@ -27,14 +27,14 @@ func main() { // generate round constants for poseidon hash. // width=len(input)+1. - cons, _ := GenPoseidonConstants(4) + cons, _ := GenPoseidonConstants[*fr.Element](4) // use OptimizedStatic hash mode. - h1, _ := Hash(input, cons, OptimizedStatic) + h1, _ := Hash[*fr.Element](input, cons, OptimizedStatic) // use OptimizedDynamic hash mode. - h2, _ := Hash(input, cons, OptimizedDynamic) + h2, _ := Hash[*fr.Element](input, cons, OptimizedDynamic) // use Correct hash mode. - h3, _ := Hash(input, cons, Correct) + h3, _ := Hash[*fr.Element](input, cons, Correct) } ``` # Benchmark diff --git a/bls12_381/arith.go b/bls12_381/arith.go deleted file mode 100644 index c4d332f..0000000 --- a/bls12_381/arith.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package bls12_381 - -import ( - "math/bits" -) - -// madd0 hi = a*b + c (discards lo bits) -func madd0(a, b, c uint64) (hi uint64) { - var carry, lo uint64 - hi, lo = bits.Mul64(a, b) - _, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, 0, carry) - return -} - -// madd1 hi, lo = a*b + c -func madd1(a, b, c uint64) (hi uint64, lo uint64) { - var carry uint64 - hi, lo = bits.Mul64(a, b) - lo, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, 0, carry) - return -} - -// madd2 hi, lo = a*b + c + d -func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { - var carry uint64 - hi, lo = bits.Mul64(a, b) - c, carry = bits.Add64(c, d, 0) - hi, _ = bits.Add64(hi, 0, carry) - lo, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, 0, carry) - return -} - -func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { - var carry uint64 - hi, lo = bits.Mul64(a, b) - c, carry = bits.Add64(c, d, 0) - hi, _ = bits.Add64(hi, 0, carry) - lo, carry = bits.Add64(lo, c, 0) - hi, _ = bits.Add64(hi, e, carry) - return -} diff --git a/bls12_381/asm.go b/bls12_381/asm.go deleted file mode 100644 index d35aa2e..0000000 --- a/bls12_381/asm.go +++ /dev/null @@ -1,23 +0,0 @@ -// +build !noadx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package bls12_381 - -import "golang.org/x/sys/cpu" - -var supportAdx = cpu.X86.HasADX && cpu.X86.HasBMI2 diff --git a/bls12_381/asm_noadx.go b/bls12_381/asm_noadx.go deleted file mode 100644 index 4a418cc..0000000 --- a/bls12_381/asm_noadx.go +++ /dev/null @@ -1,24 +0,0 @@ -// +build noadx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package bls12_381 - -// note: this is needed for test purposes, as dynamically changing supportAdx doesn't flag -// certain errors (like fatal error: missing stackmap) -// this ensures we test all asm path. -var supportAdx = false diff --git a/bls12_381/doc.go b/bls12_381/doc.go deleted file mode 100644 index ca7e12e..0000000 --- a/bls12_381/doc.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -// Package bls12_381 contains field arithmetic operations for modulus = 0x73eda7...000001. -// -// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x for the modular multiplication on amd64, see also https://hackmd.io/@zkteam/modular_multiplication) -// -// The modulus is hardcoded in all the operations. -// -// Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [4]uint64 -// -// Example API signature -// // Mul z = x * y mod q -// func (z *Element) Mul(x, y *Element) *Element -// -// and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) -// -// Modulus -// 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 // base 16 -// 52435875175126190479447740508185965837690552500527637822603658699938581184513 // base 10 -package bls12_381 diff --git a/bls12_381/element.go b/bls12_381/element.go deleted file mode 100644 index 983a7c8..0000000 --- a/bls12_381/element.go +++ /dev/null @@ -1,952 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package bls12_381 - -// /!\ WARNING /!\ -// this code has not been audited and is provided as-is. In particular, -// there is no security guarantees such as constant time implementation -// or side-channel attack resistance -// /!\ WARNING /!\ - -import ( - "crypto/rand" - "encoding/binary" - "io" - "math/big" - "math/bits" - "strconv" - "sync" -) - -// Element represents a field element stored on 4 words (uint64) -// Element are assumed to be in Montgomery form in all methods -// field modulus q = -// -// 52435875175126190479447740508185965837690552500527637822603658699938581184513 -type Element [4]uint64 - -// Limbs number of 64 bits words needed to represent Element -const Limbs = 4 - -// Bits number bits needed to represent Element -const Bits = 255 - -// Bytes number bytes needed to represent Element -const Bytes = Limbs * 8 - -// field modulus stored as big.Int -var _modulus big.Int - -// Modulus returns q as a big.Int -// q = -// -// 52435875175126190479447740508185965837690552500527637822603658699938581184513 -func Modulus() *big.Int { - return new(big.Int).Set(&_modulus) -} - -// Returns true if the big integer represents a valid field element, i.e. it's smaller -// than the modulus. -func IsValid(z *big.Int) bool { - return z.Cmp(Modulus()) == -1 -} - -// q (modulus) -var qElement = Element{ - 18446744069414584321, - 6034159408538082302, - 3691218898639771653, - 8353516859464449352, -} - -// rSquare -var rSquare = Element{ - 14526898881837571181, - 3129137299524312099, - 419701826671360399, - 524908885293268753, -} - -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - -func init() { - _modulus.SetString("52435875175126190479447740508185965837690552500527637822603658699938581184513", 10) -} - -// SetUint64 z = v, sets z LSB to v (non-Montgomery form) and convert z to Montgomery form -func (z *Element) SetUint64(v uint64) *Element { - *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() -} - -// Set z = x -func (z *Element) Set(x *Element) *Element { - z[0] = x[0] - z[1] = x[1] - z[2] = x[2] - z[3] = x[3] - return z -} - -// SetInterface converts i1 from uint64, int, string, or Element, big.Int into Element -// panic if provided type is not supported -func (z *Element) SetInterface(i1 interface{}) *Element { - switch c1 := i1.(type) { - case Element: - return z.Set(&c1) - case *Element: - return z.Set(c1) - case uint64: - return z.SetUint64(c1) - case int: - return z.SetString(strconv.Itoa(c1)) - case string: - return z.SetString(c1) - case *big.Int: - return z.SetBigInt(c1) - case big.Int: - return z.SetBigInt(&c1) - case []byte: - return z.SetBytes(c1) - default: - panic("invalid type") - } -} - -// SetZero z = 0 -func (z *Element) SetZero() *Element { - z[0] = 0 - z[1] = 0 - z[2] = 0 - z[3] = 0 - return z -} - -// SetOne z = 1 (in Montgomery form) -func (z *Element) SetOne() *Element { - z[0] = 8589934590 - z[1] = 6378425256633387010 - z[2] = 11064306276430008309 - z[3] = 1739710354780652911 - return z -} - -// Div z = x*y^-1 mod q -func (z *Element) Div(x, y *Element) *Element { - var yInv Element - yInv.Inverse(y) - z.Mul(x, &yInv) - return z -} - -// Equal returns z == x -func (z *Element) Equal(x *Element) bool { - return (z[3] == x[3]) && (z[2] == x[2]) && (z[1] == x[1]) && (z[0] == x[0]) -} - -// IsZero returns z == 0 -func (z *Element) IsZero() bool { - return (z[3] | z[2] | z[1] | z[0]) == 0 -} - -// Cmp compares (lexicographic order) z and x and returns: -// -// -1 if z < x -// 0 if z == x -// +1 if z > x -// -func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() - if _z[3] > _x[3] { - return 1 - } else if _z[3] < _x[3] { - return -1 - } - if _z[2] > _x[2] { - return 1 - } else if _z[2] < _x[2] { - return -1 - } - if _z[1] > _x[1] { - return 1 - } else if _z[1] < _x[1] { - return -1 - } - if _z[0] > _x[0] { - return 1 - } else if _z[0] < _x[0] { - return -1 - } - return 0 -} - -// LexicographicallyLargest returns true if this element is strictly lexicographically -// larger than its negation, false otherwise -func (z *Element) LexicographicallyLargest() bool { - // adapted from github.com/zkcrypto/bls12_381 - // we check if the element is larger than (q-1) / 2 - // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - - _z := *z - _z.FromMont() - - var b uint64 - _, b = bits.Sub64(_z[0], 9223372034707292161, 0) - _, b = bits.Sub64(_z[1], 12240451741123816959, b) - _, b = bits.Sub64(_z[2], 1845609449319885826, b) - _, b = bits.Sub64(_z[3], 4176758429732224676, b) - - return b == 0 -} - -// SetRandom sets z to a random element < q -func (z *Element) SetRandom() (*Element, error) { - var bytes [32]byte - if _, err := io.ReadFull(rand.Reader, bytes[:]); err != nil { - return nil, err - } - z[0] = binary.BigEndian.Uint64(bytes[0:8]) - z[1] = binary.BigEndian.Uint64(bytes[8:16]) - z[2] = binary.BigEndian.Uint64(bytes[16:24]) - z[3] = binary.BigEndian.Uint64(bytes[24:32]) - z[3] %= 8353516859464449352 - - // if z > q --> z -= q - // note: this is NOT constant time - if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 18446744069414584321, 0) - z[1], b = bits.Sub64(z[1], 6034159408538082302, b) - z[2], b = bits.Sub64(z[2], 3691218898639771653, b) - z[3], _ = bits.Sub64(z[3], 8353516859464449352, b) - } - - return z, nil -} - -// One returns 1 (in montgommery form) -func One() Element { - var one Element - one.SetOne() - return one -} - -// API with assembly impl - -// Mul z = x * y mod q -// see https://hackmd.io/@zkteam/modular_multiplication -func (z *Element) Mul(x, y *Element) *Element { - mul(z, x, y) - return z -} - -// Square z = x * x mod q -// see https://hackmd.io/@zkteam/modular_multiplication -func (z *Element) Square(x *Element) *Element { - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation -// sets and returns z = z * 1 -func (z *Element) FromMont() *Element { - fromMont(z) - return z -} - -// Add z = x + y mod q -func (z *Element) Add(x, y *Element) *Element { - add(z, x, y) - return z -} - -// Double z = x + x mod q, aka Lsh 1 -func (z *Element) Double(x *Element) *Element { - double(z, x) - return z -} - -// Sub z = x - y mod q -func (z *Element) Sub(x, y *Element) *Element { - sub(z, x, y) - return z -} - -// Neg z = q - x -func (z *Element) Neg(x *Element) *Element { - neg(z, x) - return z -} - -// Generic (no ADX instructions, no AMD64) versions of multiplication and squaring algorithms - -func _mulGeneric(z, x, y *Element) { - - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * 18446744069414584319 - c[2] = madd0(m, 18446744069414584321, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, 6034159408538082302, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, 3691218898639771653, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - t[3], t[2] = madd3(m, 8353516859464449352, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * 18446744069414584319 - c[2] = madd0(m, 18446744069414584321, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, 6034159408538082302, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, 3691218898639771653, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, 8353516859464449352, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * 18446744069414584319 - c[2] = madd0(m, 18446744069414584321, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, 6034159408538082302, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, 3691218898639771653, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, 8353516859464449352, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * 18446744069414584319 - c[2] = madd0(m, 18446744069414584321, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, 6034159408538082302, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, 3691218898639771653, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - z[3], z[2] = madd3(m, 8353516859464449352, c[0], c[2], c[1]) - } - - // if z > q --> z -= q - // note: this is NOT constant time - if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 18446744069414584321, 0) - z[1], b = bits.Sub64(z[1], 6034159408538082302, b) - z[2], b = bits.Sub64(z[2], 3691218898639771653, b) - z[3], _ = bits.Sub64(z[3], 8353516859464449352, b) - } -} - -func _fromMontGeneric(z *Element) { - // the following lines implement z = z * 1 - // with a modified CIOS montgomery multiplication - { - // m = z[0]n'[0] mod W - m := z[0] * 18446744069414584319 - C := madd0(m, 18446744069414584321, z[0]) - C, z[0] = madd2(m, 6034159408538082302, z[1], C) - C, z[1] = madd2(m, 3691218898639771653, z[2], C) - C, z[2] = madd2(m, 8353516859464449352, z[3], C) - z[3] = C - } - { - // m = z[0]n'[0] mod W - m := z[0] * 18446744069414584319 - C := madd0(m, 18446744069414584321, z[0]) - C, z[0] = madd2(m, 6034159408538082302, z[1], C) - C, z[1] = madd2(m, 3691218898639771653, z[2], C) - C, z[2] = madd2(m, 8353516859464449352, z[3], C) - z[3] = C - } - { - // m = z[0]n'[0] mod W - m := z[0] * 18446744069414584319 - C := madd0(m, 18446744069414584321, z[0]) - C, z[0] = madd2(m, 6034159408538082302, z[1], C) - C, z[1] = madd2(m, 3691218898639771653, z[2], C) - C, z[2] = madd2(m, 8353516859464449352, z[3], C) - z[3] = C - } - { - // m = z[0]n'[0] mod W - m := z[0] * 18446744069414584319 - C := madd0(m, 18446744069414584321, z[0]) - C, z[0] = madd2(m, 6034159408538082302, z[1], C) - C, z[1] = madd2(m, 3691218898639771653, z[2], C) - C, z[2] = madd2(m, 8353516859464449352, z[3], C) - z[3] = C - } - - // if z > q --> z -= q - // note: this is NOT constant time - if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 18446744069414584321, 0) - z[1], b = bits.Sub64(z[1], 6034159408538082302, b) - z[2], b = bits.Sub64(z[2], 3691218898639771653, b) - z[3], _ = bits.Sub64(z[3], 8353516859464449352, b) - } -} - -func _addGeneric(z, x, y *Element) { - var carry uint64 - - z[0], carry = bits.Add64(x[0], y[0], 0) - z[1], carry = bits.Add64(x[1], y[1], carry) - z[2], carry = bits.Add64(x[2], y[2], carry) - z[3], _ = bits.Add64(x[3], y[3], carry) - - // if z > q --> z -= q - // note: this is NOT constant time - if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 18446744069414584321, 0) - z[1], b = bits.Sub64(z[1], 6034159408538082302, b) - z[2], b = bits.Sub64(z[2], 3691218898639771653, b) - z[3], _ = bits.Sub64(z[3], 8353516859464449352, b) - } -} - -func _doubleGeneric(z, x *Element) { - var carry uint64 - - z[0], carry = bits.Add64(x[0], x[0], 0) - z[1], carry = bits.Add64(x[1], x[1], carry) - z[2], carry = bits.Add64(x[2], x[2], carry) - z[3], _ = bits.Add64(x[3], x[3], carry) - - // if z > q --> z -= q - // note: this is NOT constant time - if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 18446744069414584321, 0) - z[1], b = bits.Sub64(z[1], 6034159408538082302, b) - z[2], b = bits.Sub64(z[2], 3691218898639771653, b) - z[3], _ = bits.Sub64(z[3], 8353516859464449352, b) - } -} - -func _subGeneric(z, x, y *Element) { - var b uint64 - z[0], b = bits.Sub64(x[0], y[0], 0) - z[1], b = bits.Sub64(x[1], y[1], b) - z[2], b = bits.Sub64(x[2], y[2], b) - z[3], b = bits.Sub64(x[3], y[3], b) - if b != 0 { - var c uint64 - z[0], c = bits.Add64(z[0], 18446744069414584321, 0) - z[1], c = bits.Add64(z[1], 6034159408538082302, c) - z[2], c = bits.Add64(z[2], 3691218898639771653, c) - z[3], _ = bits.Add64(z[3], 8353516859464449352, c) - } -} - -func _negGeneric(z, x *Element) { - if x.IsZero() { - z.SetZero() - return - } - var borrow uint64 - z[0], borrow = bits.Sub64(18446744069414584321, x[0], 0) - z[1], borrow = bits.Sub64(6034159408538082302, x[1], borrow) - z[2], borrow = bits.Sub64(3691218898639771653, x[2], borrow) - z[3], _ = bits.Sub64(8353516859464449352, x[3], borrow) -} - -func _reduceGeneric(z *Element) { - - // if z > q --> z -= q - // note: this is NOT constant time - if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { - var b uint64 - z[0], b = bits.Sub64(z[0], 18446744069414584321, 0) - z[1], b = bits.Sub64(z[1], 6034159408538082302, b) - z[2], b = bits.Sub64(z[2], 3691218898639771653, b) - z[3], _ = bits.Sub64(z[3], 8353516859464449352, b) - } -} - -func mulByConstant(z *Element, c uint8) { - switch c { - case 0: - z.SetZero() - return - case 1: - return - case 2: - z.Double(z) - return - case 3: - _z := *z - z.Double(z).Add(z, &_z) - case 5: - _z := *z - z.Double(z).Double(z).Add(z, &_z) - default: - var y Element - y.SetUint64(uint64(c)) - z.Mul(z, &y) - } -} - -// Exp z = x^exponent mod q -func (z *Element) Exp(x Element, exponent *big.Int) *Element { - var bZero big.Int - if exponent.Cmp(&bZero) == 0 { - return z.SetOne() - } - - z.Set(&x) - - for i := exponent.BitLen() - 2; i >= 0; i-- { - z.Square(z) - if exponent.Bit(i) == 1 { - z.Mul(z, &x) - } - } - - return z -} - -// ToMont converts z to Montgomery form -// sets and returns z = z * r^2 -func (z *Element) ToMont() *Element { - return z.Mul(z, &rSquare) -} - -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - -// String returns the string form of an Element in Montgomery form -func (z *Element) String() string { - vv := bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(vv) - return z.ToBigIntRegular(vv).String() -} - -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[24:32], z[0]) - binary.BigEndian.PutUint64(b[16:24], z[1]) - binary.BigEndian.PutUint64(b[8:16], z[2]) - binary.BigEndian.PutUint64(b[0:8], z[3]) - - return res.SetBytes(b[:]) -} - -// ToBigIntRegular returns z as a big.Int in regular form -func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) -} - -// Bytes returns the regular (non montgomery) value -// of z as a big-endian byte array. -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[24:32], _z[0]) - binary.BigEndian.PutUint64(res[16:24], _z[1]) - binary.BigEndian.PutUint64(res[8:16], _z[2]) - binary.BigEndian.PutUint64(res[0:8], _z[3]) - - return -} - -// SetBytes interprets e as the bytes of a big-endian unsigned integer, -// sets z to that value (in Montgomery form), and returns z. -func (z *Element) SetBytes(e []byte) *Element { - // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) - vv.SetBytes(e) - - // set big int - z.SetBigInt(vv) - - // put temporary object back in pool - bigIntPool.Put(vv) - - return z -} - -// SetBigInt sets z to v (regular form) and returns z in Montgomery form -func (z *Element) SetBigInt(v *big.Int) *Element { - z.SetZero() - - var zero big.Int - - // fast path - c := v.Cmp(&_modulus) - if c == 0 { - // v == 0 - return z - } else if c != 1 && v.Cmp(&zero) != -1 { - // 0 < v < q - return z.setBigInt(v) - } - - // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) - - // copy input + modular reduction - vv.Set(v) - vv.Mod(v, &_modulus) - - // set big int byte value - z.setBigInt(vv) - - // release object into pool - bigIntPool.Put(vv) - return z -} - -// setBigInt assumes 0 <= v < q -func (z *Element) setBigInt(v *big.Int) *Element { - vBits := v.Bits() - - if bits.UintSize == 64 { - for i := 0; i < len(vBits); i++ { - z[i] = uint64(vBits[i]) - } - } else { - for i := 0; i < len(vBits); i++ { - if i%2 == 0 { - z[i/2] = uint64(vBits[i]) - } else { - z[i/2] |= uint64(vBits[i]) << 32 - } - } - } - - return z.ToMont() -} - -// SetString creates a big.Int with s (in base 10) and calls SetBigInt on z -func (z *Element) SetString(s string) *Element { - // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) - - if _, ok := vv.SetString(s, 10); !ok { - panic("Element.SetString failed -> can't parse number in base10 into a big.Int") - } - z.SetBigInt(vv) - - // release object into pool - bigIntPool.Put(vv) - - return z -} - -// SetHexString creates a big.Int with s (in base 16) and calls SetBigInt on z -func (z *Element) SetHexString(s string) *Element { - // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) - - if _, ok := vv.SetString(s, 16); !ok { - panic("Element.SetString failed -> can't parse number in base10 into a big.Int") - } - z.SetBigInt(vv) - - // release object into pool - bigIntPool.Put(vv) - - return z -} - -var ( - _bLegendreExponentElement *big.Int - _bSqrtExponentElement *big.Int -) - -func init() { - _bLegendreExponentElement, _ = new(big.Int).SetString("39f6d3a994cebea4199cec0404d0ec02a9ded2017fff2dff7fffffff80000000", 16) - const sqrtExponentElement = "39f6d3a994cebea4199cec0404d0ec02a9ded2017fff2dff7fffffff" - _bSqrtExponentElement, _ = new(big.Int).SetString(sqrtExponentElement, 16) -} - -// Legendre returns the Legendre symbol of z (either +1, -1, or 0.) -func (z *Element) Legendre() int { - var l Element - // z^((q-1)/2) - l.Exp(*z, _bLegendreExponentElement) - - if l.IsZero() { - return 0 - } - - // if l == 1 - if (l[3] == 1739710354780652911) && (l[2] == 11064306276430008309) && (l[1] == 6378425256633387010) && (l[0] == 8589934590) { - return 1 - } - return -1 -} - -// Sqrt z = √x mod q -// if the square root doesn't exist (x is not a square mod q) -// Sqrt leaves z unchanged and returns nil -func (z *Element) Sqrt(x *Element) *Element { - // q ≡ 1 (mod 4) - // see modSqrtTonelliShanks in math/big/int.go - // using https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf - - var y, b, t, w Element - // w = x^((s-1)/2)) - w.Exp(*x, _bSqrtExponentElement) - - // y = x^((s+1)/2)) = w * x - y.Mul(x, &w) - - // b = x^s = w * w * x = y * x - b.Mul(&w, &y) - - // g = nonResidue ^ s - var g = Element{ - 11289237133041595516, - 2081200955273736677, - 967625415375836421, - 4543825880697944938, - } - r := uint64(32) - - // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s - t = b - for i := uint64(0); i < r-1; i++ { - t.Square(&t) - } - if t.IsZero() { - return z.SetZero() - } - if !((t[3] == 1739710354780652911) && (t[2] == 11064306276430008309) && (t[1] == 6378425256633387010) && (t[0] == 8589934590)) { - // t != 1, we don't have a square root - return nil - } - for { - var m uint64 - t = b - - // for t != 1 - for !((t[3] == 1739710354780652911) && (t[2] == 11064306276430008309) && (t[1] == 6378425256633387010) && (t[0] == 8589934590)) { - t.Square(&t) - m++ - } - - if m == 0 { - return z.Set(&y) - } - // t = g^(2^(r-m-1)) mod q - ge := int(r - m - 1) - t = g - for ge > 0 { - t.Square(&t) - ge-- - } - - g.Square(&t) - y.Mul(&y, &t) - b.Mul(&b, &g) - r = m - } -} - -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x -func (z *Element) Inverse(x *Element) *Element { - if x.IsZero() { - return z.Set(x) - } - - // initialize u = q - var u = Element{ - 18446744069414584321, - 6034159408538082302, - 3691218898639771653, - 8353516859464449352, - } - - // initialize s = r^2 - var s = Element{ - 14526898881837571181, - 3129137299524312099, - 419701826671360399, - 524908885293268753, - } - - // r = 0 - r := Element{} - - v := *x - - var carry, borrow, t, t2 uint64 - var bigger bool - - for { - for v[0]&1 == 0 { - - // v = v >> 1 - t2 = v[3] << 63 - v[3] >>= 1 - t = t2 - t2 = v[2] << 63 - v[2] = (v[2] >> 1) | t - t = t2 - t2 = v[1] << 63 - v[1] = (v[1] >> 1) | t - t = t2 - v[0] = (v[0] >> 1) | t - - if s[0]&1 == 1 { - - // s = s + q - s[0], carry = bits.Add64(s[0], 18446744069414584321, 0) - s[1], carry = bits.Add64(s[1], 6034159408538082302, carry) - s[2], carry = bits.Add64(s[2], 3691218898639771653, carry) - s[3], _ = bits.Add64(s[3], 8353516859464449352, carry) - - } - - // s = s >> 1 - t2 = s[3] << 63 - s[3] >>= 1 - t = t2 - t2 = s[2] << 63 - s[2] = (s[2] >> 1) | t - t = t2 - t2 = s[1] << 63 - s[1] = (s[1] >> 1) | t - t = t2 - s[0] = (s[0] >> 1) | t - - } - for u[0]&1 == 0 { - - // u = u >> 1 - t2 = u[3] << 63 - u[3] >>= 1 - t = t2 - t2 = u[2] << 63 - u[2] = (u[2] >> 1) | t - t = t2 - t2 = u[1] << 63 - u[1] = (u[1] >> 1) | t - t = t2 - u[0] = (u[0] >> 1) | t - - if r[0]&1 == 1 { - - // r = r + q - r[0], carry = bits.Add64(r[0], 18446744069414584321, 0) - r[1], carry = bits.Add64(r[1], 6034159408538082302, carry) - r[2], carry = bits.Add64(r[2], 3691218898639771653, carry) - r[3], _ = bits.Add64(r[3], 8353516859464449352, carry) - - } - - // r = r >> 1 - t2 = r[3] << 63 - r[3] >>= 1 - t = t2 - t2 = r[2] << 63 - r[2] = (r[2] >> 1) | t - t = t2 - t2 = r[1] << 63 - r[1] = (r[1] >> 1) | t - t = t2 - r[0] = (r[0] >> 1) | t - - } - - // v >= u - bigger = !(v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))) - - if bigger { - - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], _ = bits.Sub64(v[3], u[3], borrow) - - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) - - if borrow == 1 { - - // s = s + q - s[0], carry = bits.Add64(s[0], 18446744069414584321, 0) - s[1], carry = bits.Add64(s[1], 6034159408538082302, carry) - s[2], carry = bits.Add64(s[2], 3691218898639771653, carry) - s[3], _ = bits.Add64(s[3], 8353516859464449352, carry) - - } - } else { - - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], _ = bits.Sub64(u[3], v[3], borrow) - - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) - - if borrow == 1 { - - // r = r + q - r[0], carry = bits.Add64(r[0], 18446744069414584321, 0) - r[1], carry = bits.Add64(r[1], 6034159408538082302, carry) - r[2], carry = bits.Add64(r[2], 3691218898639771653, carry) - r[3], _ = bits.Add64(r[3], 8353516859464449352, carry) - - } - } - if (u[0] == 1) && (u[3]|u[2]|u[1]) == 0 { - return z.Set(&r) - } - if (v[0] == 1) && (v[3]|v[2]|v[1]) == 0 { - return z.Set(&s) - } - } - -} diff --git a/bls12_381/element_mul_adx_amd64.s b/bls12_381/element_mul_adx_amd64.s deleted file mode 100644 index 369a6b7..0000000 --- a/bls12_381/element_mul_adx_amd64.s +++ /dev/null @@ -1,466 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xffffffff00000001 -DATA q<>+8(SB)/8, $0x53bda402fffe5bfe -DATA q<>+16(SB)/8, $0x3339d80809a1d805 -DATA q<>+24(SB)/8, $0x73eda753299d7d48 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xfffffffeffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described here - // https://hackmd.io/@zkteam/modular_multiplication - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R8, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R8, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R8, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R15,CX,BX) using temp registers (R13,SI,R12,R11) - REDUCE(R14,R15,CX,BX,R13,SI,R12,R11) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@zkteam/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R15,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET diff --git a/bls12_381/element_mul_amd64.s b/bls12_381/element_mul_amd64.s deleted file mode 100644 index 28570b1..0000000 --- a/bls12_381/element_mul_amd64.s +++ /dev/null @@ -1,488 +0,0 @@ -// +build !amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xffffffff00000001 -DATA q<>+8(SB)/8, $0x53bda402fffe5bfe -DATA q<>+16(SB)/8, $0x3339d80809a1d805 -DATA q<>+24(SB)/8, $0x73eda753299d7d48 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xfffffffeffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $24-24 - - // the algorithm is described here - // https://hackmd.io/@zkteam/modular_multiplication - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - CMPB ·supportAdx(SB), $1 - JNE l1 - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R8, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R8, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R8, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R15,CX,BX) using temp registers (R13,SI,R12,R11) - REDUCE(R14,R15,CX,BX,R13,SI,R12,R11) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -l1: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - MOVQ x+8(FP), AX - MOVQ AX, 8(SP) - MOVQ y+16(FP), AX - MOVQ AX, 16(SP) - CALL ·_mulGeneric(SB) - RET - -TEXT ·fromMont(SB), $8-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@zkteam/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - CMPB ·supportAdx(SB), $1 - JNE l2 - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R15,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -l2: - MOVQ res+0(FP), AX - MOVQ AX, (SP) - CALL ·_fromMontGeneric(SB) - RET diff --git a/bls12_381/element_ops_amd64.go b/bls12_381/element_ops_amd64.go deleted file mode 100644 index 0da4db2..0000000 --- a/bls12_381/element_ops_amd64.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package bls12_381 - -//go:noescape -func MulBy3(x *Element) - -//go:noescape -func MulBy5(x *Element) - -//go:noescape -func MulBy13(x *Element) - -//go:noescape -func add(res, x, y *Element) - -//go:noescape -func sub(res, x, y *Element) - -//go:noescape -func neg(res, x *Element) - -//go:noescape -func double(res, x *Element) - -//go:noescape -func mul(res, x, y *Element) - -//go:noescape -func fromMont(res *Element) - -//go:noescape -func reduce(res *Element) diff --git a/bls12_381/element_ops_amd64.s b/bls12_381/element_ops_amd64.s deleted file mode 100644 index 9b2606e..0000000 --- a/bls12_381/element_ops_amd64.s +++ /dev/null @@ -1,292 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xffffffff00000001 -DATA q<>+8(SB)/8, $0x53bda402fffe5bfe -DATA q<>+16(SB)/8, $0x3339d80809a1d805 -DATA q<>+24(SB)/8, $0x73eda753299d7d48 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xfffffffeffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// add(res, x, y *Element) -TEXT ·add(SB), NOSPLIT, $0-24 - MOVQ x+8(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ y+16(FP), DX - ADDQ 0(DX), CX - ADCQ 8(DX), BX - ADCQ 16(DX), SI - ADCQ 24(DX), DI - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ res+0(FP), R12 - MOVQ CX, 0(R12) - MOVQ BX, 8(R12) - MOVQ SI, 16(R12) - MOVQ DI, 24(R12) - RET - -// sub(res, x, y *Element) -TEXT ·sub(SB), NOSPLIT, $0-24 - XORQ DI, DI - MOVQ x+8(FP), SI - MOVQ 0(SI), AX - MOVQ 8(SI), DX - MOVQ 16(SI), CX - MOVQ 24(SI), BX - MOVQ y+16(FP), SI - SUBQ 0(SI), AX - SBBQ 8(SI), DX - SBBQ 16(SI), CX - SBBQ 24(SI), BX - MOVQ $0xffffffff00000001, R8 - MOVQ $0x53bda402fffe5bfe, R9 - MOVQ $0x3339d80809a1d805, R10 - MOVQ $0x73eda753299d7d48, R11 - CMOVQCC DI, R8 - CMOVQCC DI, R9 - CMOVQCC DI, R10 - CMOVQCC DI, R11 - ADDQ R8, AX - ADCQ R9, DX - ADCQ R10, CX - ADCQ R11, BX - MOVQ res+0(FP), R12 - MOVQ AX, 0(R12) - MOVQ DX, 8(R12) - MOVQ CX, 16(R12) - MOVQ BX, 24(R12) - RET - -// double(res, x *Element) -TEXT ·double(SB), NOSPLIT, $0-16 - MOVQ x+8(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ res+0(FP), R11 - MOVQ DX, 0(R11) - MOVQ CX, 8(R11) - MOVQ BX, 16(R11) - MOVQ SI, 24(R11) - RET - -// neg(res, x *Element) -TEXT ·neg(SB), NOSPLIT, $0-16 - MOVQ res+0(FP), DI - MOVQ x+8(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - MOVQ DX, AX - ORQ CX, AX - ORQ BX, AX - ORQ SI, AX - TESTQ AX, AX - JEQ l1 - MOVQ $0xffffffff00000001, R8 - SUBQ DX, R8 - MOVQ R8, 0(DI) - MOVQ $0x53bda402fffe5bfe, R8 - SBBQ CX, R8 - MOVQ R8, 8(DI) - MOVQ $0x3339d80809a1d805, R8 - SBBQ BX, R8 - MOVQ R8, 16(DI) - MOVQ $0x73eda753299d7d48, R8 - SBBQ SI, R8 - MOVQ R8, 24(DI) - RET - -l1: - MOVQ AX, 0(DI) - MOVQ AX, 8(DI) - MOVQ AX, 16(DI) - MOVQ AX, 24(DI) - RET - -TEXT ·reduce(SB), NOSPLIT, $0-8 - MOVQ res+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy3(x *Element) -TEXT ·MulBy3(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy5(x *Element) -TEXT ·MulBy5(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (R15,DI,R8,R9) - REDUCE(DX,CX,BX,SI,R15,DI,R8,R9) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET - -// MulBy13(x *Element) -TEXT ·MulBy13(SB), NOSPLIT, $0-8 - MOVQ x+0(FP), AX - MOVQ 0(AX), DX - MOVQ 8(AX), CX - MOVQ 16(AX), BX - MOVQ 24(AX), SI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (R11,R12,R13,R14) - REDUCE(DX,CX,BX,SI,R11,R12,R13,R14) - - MOVQ DX, R11 - MOVQ CX, R12 - MOVQ BX, R13 - MOVQ SI, R14 - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - ADDQ 0(AX), DX - ADCQ 8(AX), CX - ADCQ 16(AX), BX - ADCQ 24(AX), SI - - // reduce element(DX,CX,BX,SI) using temp registers (DI,R8,R9,R10) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10) - - MOVQ DX, 0(AX) - MOVQ CX, 8(AX) - MOVQ BX, 16(AX) - MOVQ SI, 24(AX) - RET diff --git a/bls12_381/element_ops_noasm.go b/bls12_381/element_ops_noasm.go deleted file mode 100644 index 6134d1c..0000000 --- a/bls12_381/element_ops_noasm.go +++ /dev/null @@ -1,70 +0,0 @@ -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package bls12_381 - -// /!\ WARNING /!\ -// this code has not been audited and is provided as-is. In particular, -// there is no security guarantees such as constant time implementation -// or side-channel attack resistance -// /!\ WARNING /!\ - -// MulBy3 x *= 3 -func MulBy3(x *Element) { - mulByConstant(x, 3) -} - -// MulBy5 x *= 5 -func MulBy5(x *Element) { - mulByConstant(x, 5) -} - -// MulBy13 x *= 13 -func MulBy13(x *Element) { - mulByConstant(x, 13) -} - -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation -// sets and returns z = z * 1 -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func add(z, x, y *Element) { - _addGeneric(z, x, y) -} - -func double(z, x *Element) { - _doubleGeneric(z, x) -} - -func sub(z, x, y *Element) { - _subGeneric(z, x, y) -} - -func neg(z, x *Element) { - _negGeneric(z, x) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/bls12_381/element_test.go b/bls12_381/element_test.go deleted file mode 100644 index fa1ec20..0000000 --- a/bls12_381/element_test.go +++ /dev/null @@ -1,1796 +0,0 @@ -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package bls12_381 - -import ( - "crypto/rand" - "math/big" - "math/bits" - "testing" - - "github.com/leanovate/gopter" - "github.com/leanovate/gopter/prop" -) - -// ------------------------------------------------------------------------------------------------- -// benchmarks -// most benchmarks are rudimentary and should sample a large number of random inputs -// or be run multiple times to ensure it didn't measure the fastest path of the function - -var benchResElement Element - -func BenchmarkElementSetBytes(b *testing.B) { - var x Element - x.SetRandom() - bb := x.Bytes() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - benchResElement.SetBytes(bb[:]) - } - -} - -func BenchmarkElementMulByConstants(b *testing.B) { - b.Run("mulBy3", func(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - MulBy3(&benchResElement) - } - }) - b.Run("mulBy5", func(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - MulBy5(&benchResElement) - } - }) - b.Run("mulBy13", func(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - MulBy13(&benchResElement) - } - }) -} - -func BenchmarkElementInverse(b *testing.B) { - var x Element - x.SetRandom() - benchResElement.SetRandom() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - benchResElement.Inverse(&x) - } - -} - -func BenchmarkElementExp(b *testing.B) { - var x Element - x.SetRandom() - benchResElement.SetRandom() - b1, _ := rand.Int(rand.Reader, Modulus()) - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.Exp(x, b1) - } -} - -func BenchmarkElementDouble(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.Double(&benchResElement) - } -} - -func BenchmarkElementAdd(b *testing.B) { - var x Element - x.SetRandom() - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.Add(&x, &benchResElement) - } -} - -func BenchmarkElementSub(b *testing.B) { - var x Element - x.SetRandom() - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.Sub(&x, &benchResElement) - } -} - -func BenchmarkElementNeg(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.Neg(&benchResElement) - } -} - -func BenchmarkElementDiv(b *testing.B) { - var x Element - x.SetRandom() - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.Div(&x, &benchResElement) - } -} - -func BenchmarkElementFromMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.FromMont() - } -} - -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} -func BenchmarkElementSquare(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.Square(&benchResElement) - } -} - -func BenchmarkElementSqrt(b *testing.B) { - var a Element - a.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.Sqrt(&a) - } -} - -func BenchmarkElementMul(b *testing.B) { - x := Element{ - 14526898881837571181, - 3129137299524312099, - 419701826671360399, - 524908885293268753, - } - benchResElement.SetOne() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.Mul(&benchResElement, &x) - } -} - -func BenchmarkElementCmp(b *testing.B) { - x := Element{ - 14526898881837571181, - 3129137299524312099, - 419701826671360399, - 524908885293268753, - } - benchResElement = x - benchResElement[0] = 0 - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.Cmp(&x) - } -} - -func TestElementCmp(t *testing.T) { - var x, y Element - - if x.Cmp(&y) != 0 { - t.Fatal("x == y") - } - - one := One() - y.Sub(&y, &one) - - if x.Cmp(&y) != -1 { - t.Fatal("x < y") - } - if y.Cmp(&x) != 1 { - t.Fatal("x < y") - } - - x = y - if x.Cmp(&y) != 0 { - t.Fatal("x == y") - } - - x.Sub(&x, &one) - if x.Cmp(&y) != -1 { - t.Fatal("x < y") - } - if y.Cmp(&x) != 1 { - t.Fatal("x < y") - } -} - -func TestElementIsRandom(t *testing.T) { - for i := 0; i < 50; i++ { - var x, y Element - x.SetRandom() - y.SetRandom() - if x.Equal(&y) { - t.Fatal("2 random numbers are unlikely to be equal") - } - } -} - -// ------------------------------------------------------------------------------------------------- -// Gopter tests -// most of them are generated with a template - -const ( - nbFuzzShort = 200 - nbFuzz = 1000 -) - -// special values to be used in tests -var staticTestValues []Element - -func init() { - staticTestValues = append(staticTestValues, Element{}) // zero - staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 - var e, one Element - one.SetOne() - e.Sub(&qElement, &one) - staticTestValues = append(staticTestValues, e) // q - 1 - e.Double(&one) - staticTestValues = append(staticTestValues, e) // 2 - - { - a := qElement - a[3]-- - staticTestValues = append(staticTestValues, a) - } - { - a := qElement - a[0]-- - staticTestValues = append(staticTestValues, a) - } - - for i := 0; i <= 3; i++ { - staticTestValues = append(staticTestValues, Element{uint64(i)}) - staticTestValues = append(staticTestValues, Element{0, uint64(i)}) - } - - { - a := qElement - a[3]-- - a[0]++ - staticTestValues = append(staticTestValues, a) - } - -} - -func TestElementNegZero(t *testing.T) { - var a, b Element - b.SetZero() - for a.IsZero() { - a.SetRandom() - } - a.Neg(&b) - if !a.IsZero() { - t.Fatal("neg(0) != 0") - } -} - -func TestElementReduce(t *testing.T) { - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, s := range testValues { - expected := s - reduce(&s) - _reduceGeneric(&expected) - if !s.Equal(&expected) { - t.Fatal("reduce failed: asm and generic impl don't match") - } - } - - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := genFull() - - properties.Property("reduce should output a result smaller than modulus", prop.ForAll( - func(a Element) bool { - b := a - reduce(&a) - _reduceGeneric(&b) - return !a.biggerOrEqualModulus() && a.Equal(&b) - }, - genA, - )) - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - supportAdx = true - } - -} - -func TestElementBytes(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - - properties.Property("SetBytes(Bytes()) should stayt constant", prop.ForAll( - func(a testPairElement) bool { - var b Element - bytes := a.element.Bytes() - b.SetBytes(bytes[:]) - return a.element.Equal(&b) - }, - genA, - )) - - properties.TestingRun(t, gopter.ConsoleReporter(false)) -} - -func TestElementInverseExp(t *testing.T) { - // inverse must be equal to exp^-2 - exp := Modulus() - exp.Sub(exp, new(big.Int).SetUint64(2)) - - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - - properties.Property("inv == exp^-2", prop.ForAll( - func(a testPairElement) bool { - var b Element - b.Set(&a.element) - a.element.Inverse(&a.element) - b.Exp(b, exp) - - return a.element.Equal(&b) - }, - genA, - )) - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - supportAdx = true - } -} - -func TestElementMulByConstants(t *testing.T) { - - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - - implemented := []uint8{0, 1, 2, 3, 5, 13} - properties.Property("mulByConstant", prop.ForAll( - func(a testPairElement) bool { - for _, c := range implemented { - var constant Element - constant.SetUint64(uint64(c)) - - b := a.element - b.Mul(&b, &constant) - - aa := a.element - mulByConstant(&aa, c) - - if !aa.Equal(&b) { - return false - } - } - - return true - }, - genA, - )) - - properties.Property("MulBy3(x) == Mul(x, 3)", prop.ForAll( - func(a testPairElement) bool { - var constant Element - constant.SetUint64(3) - - b := a.element - b.Mul(&b, &constant) - - MulBy3(&a.element) - - return a.element.Equal(&b) - }, - genA, - )) - - properties.Property("MulBy5(x) == Mul(x, 5)", prop.ForAll( - func(a testPairElement) bool { - var constant Element - constant.SetUint64(5) - - b := a.element - b.Mul(&b, &constant) - - MulBy5(&a.element) - - return a.element.Equal(&b) - }, - genA, - )) - - properties.Property("MulBy13(x) == Mul(x, 13)", prop.ForAll( - func(a testPairElement) bool { - var constant Element - constant.SetUint64(13) - - b := a.element - b.Mul(&b, &constant) - - MulBy13(&a.element) - - return a.element.Equal(&b) - }, - genA, - )) - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - supportAdx = true - } - -} - -func TestElementLegendre(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - - properties.Property("legendre should output same result than big.Int.Jacobi", prop.ForAll( - func(a testPairElement) bool { - return a.element.Legendre() == big.Jacobi(&a.bigint, Modulus()) - }, - genA, - )) - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - supportAdx = true - } - -} - -func TestElementLexicographicallyLargest(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - - properties.Property("element.Cmp should match LexicographicallyLargest output", prop.ForAll( - func(a testPairElement) bool { - var negA Element - negA.Neg(&a.element) - - cmpResult := a.element.Cmp(&negA) - lResult := a.element.LexicographicallyLargest() - - if lResult && cmpResult == 1 { - return true - } - if !lResult && cmpResult != 1 { - return true - } - return false - }, - genA, - )) - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - supportAdx = true - } - -} - -func TestElementAdd(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - genB := gen() - - properties.Property("Add: having the receiver as operand should output the same result", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - d.Set(&a.element) - - c.Add(&a.element, &b.element) - a.element.Add(&a.element, &b.element) - b.element.Add(&d, &b.element) - - return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) - }, - genA, - genB, - )) - - properties.Property("Add: operation result must match big.Int result", prop.ForAll( - func(a, b testPairElement) bool { - { - var c Element - - c.Add(&a.element, &b.element) - - var d, e big.Int - d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - return false - } - } - - // fixed elements - // a is random - // r takes special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, r := range testValues { - var d, e, rb big.Int - r.ToBigIntRegular(&rb) - - var c Element - c.Add(&a.element, &r) - d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - - // checking generic impl against asm path - var cGeneric Element - _addGeneric(&cGeneric, &a.element, &r) - if !cGeneric.Equal(&c) { - // need to give context to failing error. - return false - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - return false - } - } - return true - }, - genA, - genB, - )) - - properties.Property("Add: operation result must be smaller than modulus", prop.ForAll( - func(a, b testPairElement) bool { - var c Element - - c.Add(&a.element, &b.element) - - return !c.biggerOrEqualModulus() - }, - genA, - genB, - )) - - properties.Property("Add: assembly implementation must be consistent with generic one", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - c.Add(&a.element, &b.element) - _addGeneric(&d, &a.element, &b.element) - return c.Equal(&d) - }, - genA, - genB, - )) - - specialValueTest := func() { - // test special values against special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, a := range testValues { - var aBig big.Int - a.ToBigIntRegular(&aBig) - for _, b := range testValues { - - var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) - - var c Element - c.Add(&a, &b) - d.Add(&aBig, &bBig).Mod(&d, Modulus()) - - // checking asm against generic impl - var cGeneric Element - _addGeneric(&cGeneric, &a, &b) - if !cGeneric.Equal(&c) { - t.Fatal("Add failed special test values: asm and generic impl don't match") - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - t.Fatal("Add failed special test values") - } - } - } - } - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } -} - -func TestElementSub(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - genB := gen() - - properties.Property("Sub: having the receiver as operand should output the same result", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - d.Set(&a.element) - - c.Sub(&a.element, &b.element) - a.element.Sub(&a.element, &b.element) - b.element.Sub(&d, &b.element) - - return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) - }, - genA, - genB, - )) - - properties.Property("Sub: operation result must match big.Int result", prop.ForAll( - func(a, b testPairElement) bool { - { - var c Element - - c.Sub(&a.element, &b.element) - - var d, e big.Int - d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - return false - } - } - - // fixed elements - // a is random - // r takes special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, r := range testValues { - var d, e, rb big.Int - r.ToBigIntRegular(&rb) - - var c Element - c.Sub(&a.element, &r) - d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - - // checking generic impl against asm path - var cGeneric Element - _subGeneric(&cGeneric, &a.element, &r) - if !cGeneric.Equal(&c) { - // need to give context to failing error. - return false - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - return false - } - } - return true - }, - genA, - genB, - )) - - properties.Property("Sub: operation result must be smaller than modulus", prop.ForAll( - func(a, b testPairElement) bool { - var c Element - - c.Sub(&a.element, &b.element) - - return !c.biggerOrEqualModulus() - }, - genA, - genB, - )) - - properties.Property("Sub: assembly implementation must be consistent with generic one", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - c.Sub(&a.element, &b.element) - _subGeneric(&d, &a.element, &b.element) - return c.Equal(&d) - }, - genA, - genB, - )) - - specialValueTest := func() { - // test special values against special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, a := range testValues { - var aBig big.Int - a.ToBigIntRegular(&aBig) - for _, b := range testValues { - - var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) - - var c Element - c.Sub(&a, &b) - d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - - // checking asm against generic impl - var cGeneric Element - _subGeneric(&cGeneric, &a, &b) - if !cGeneric.Equal(&c) { - t.Fatal("Sub failed special test values: asm and generic impl don't match") - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - t.Fatal("Sub failed special test values") - } - } - } - } - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } -} - -func TestElementMul(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - genB := gen() - - properties.Property("Mul: having the receiver as operand should output the same result", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - d.Set(&a.element) - - c.Mul(&a.element, &b.element) - a.element.Mul(&a.element, &b.element) - b.element.Mul(&d, &b.element) - - return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) - }, - genA, - genB, - )) - - properties.Property("Mul: operation result must match big.Int result", prop.ForAll( - func(a, b testPairElement) bool { - { - var c Element - - c.Mul(&a.element, &b.element) - - var d, e big.Int - d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - return false - } - } - - // fixed elements - // a is random - // r takes special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, r := range testValues { - var d, e, rb big.Int - r.ToBigIntRegular(&rb) - - var c Element - c.Mul(&a.element, &r) - d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) - - // checking generic impl against asm path - var cGeneric Element - _mulGeneric(&cGeneric, &a.element, &r) - if !cGeneric.Equal(&c) { - // need to give context to failing error. - return false - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - return false - } - } - return true - }, - genA, - genB, - )) - - properties.Property("Mul: operation result must be smaller than modulus", prop.ForAll( - func(a, b testPairElement) bool { - var c Element - - c.Mul(&a.element, &b.element) - - return !c.biggerOrEqualModulus() - }, - genA, - genB, - )) - - properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - c.Mul(&a.element, &b.element) - _mulGeneric(&d, &a.element, &b.element) - return c.Equal(&d) - }, - genA, - genB, - )) - - specialValueTest := func() { - // test special values against special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, a := range testValues { - var aBig big.Int - a.ToBigIntRegular(&aBig) - for _, b := range testValues { - - var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) - - var c Element - c.Mul(&a, &b) - d.Mul(&aBig, &bBig).Mod(&d, Modulus()) - - // checking asm against generic impl - var cGeneric Element - _mulGeneric(&cGeneric, &a, &b) - if !cGeneric.Equal(&c) { - t.Fatal("Mul failed special test values: asm and generic impl don't match") - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - t.Fatal("Mul failed special test values") - } - } - } - } - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } -} - -func TestElementDiv(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - genB := gen() - - properties.Property("Div: having the receiver as operand should output the same result", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - d.Set(&a.element) - - c.Div(&a.element, &b.element) - a.element.Div(&a.element, &b.element) - b.element.Div(&d, &b.element) - - return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) - }, - genA, - genB, - )) - - properties.Property("Div: operation result must match big.Int result", prop.ForAll( - func(a, b testPairElement) bool { - { - var c Element - - c.Div(&a.element, &b.element) - - var d, e big.Int - d.ModInverse(&b.bigint, Modulus()) - d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - return false - } - } - - // fixed elements - // a is random - // r takes special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, r := range testValues { - var d, e, rb big.Int - r.ToBigIntRegular(&rb) - - var c Element - c.Div(&a.element, &r) - d.ModInverse(&rb, Modulus()) - d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - return false - } - } - return true - }, - genA, - genB, - )) - - properties.Property("Div: operation result must be smaller than modulus", prop.ForAll( - func(a, b testPairElement) bool { - var c Element - - c.Div(&a.element, &b.element) - - return !c.biggerOrEqualModulus() - }, - genA, - genB, - )) - - specialValueTest := func() { - // test special values against special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, a := range testValues { - var aBig big.Int - a.ToBigIntRegular(&aBig) - for _, b := range testValues { - - var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) - - var c Element - c.Div(&a, &b) - d.ModInverse(&bBig, Modulus()) - d.Mul(&d, &aBig).Mod(&d, Modulus()) - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - t.Fatal("Div failed special test values") - } - } - } - } - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } -} - -func TestElementExp(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - genB := gen() - - properties.Property("Exp: having the receiver as operand should output the same result", prop.ForAll( - func(a, b testPairElement) bool { - var c, d Element - d.Set(&a.element) - - c.Exp(a.element, &b.bigint) - a.element.Exp(a.element, &b.bigint) - b.element.Exp(d, &b.bigint) - - return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) - }, - genA, - genB, - )) - - properties.Property("Exp: operation result must match big.Int result", prop.ForAll( - func(a, b testPairElement) bool { - { - var c Element - - c.Exp(a.element, &b.bigint) - - var d, e big.Int - d.Exp(&a.bigint, &b.bigint, Modulus()) - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - return false - } - } - - // fixed elements - // a is random - // r takes special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, r := range testValues { - var d, e, rb big.Int - r.ToBigIntRegular(&rb) - - var c Element - c.Exp(a.element, &rb) - d.Exp(&a.bigint, &rb, Modulus()) - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - return false - } - } - return true - }, - genA, - genB, - )) - - properties.Property("Exp: operation result must be smaller than modulus", prop.ForAll( - func(a, b testPairElement) bool { - var c Element - - c.Exp(a.element, &b.bigint) - - return !c.biggerOrEqualModulus() - }, - genA, - genB, - )) - - specialValueTest := func() { - // test special values against special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, a := range testValues { - var aBig big.Int - a.ToBigIntRegular(&aBig) - for _, b := range testValues { - - var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) - - var c Element - c.Exp(a, &bBig) - d.Exp(&aBig, &bBig, Modulus()) - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - t.Fatal("Exp failed special test values") - } - } - } - } - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - t.Log("disabling ADX") - supportAdx = false - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } -} - -func TestElementSquare(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - - properties.Property("Square: having the receiver as operand should output the same result", prop.ForAll( - func(a testPairElement) bool { - - var b Element - - b.Square(&a.element) - a.element.Square(&a.element) - return a.element.Equal(&b) - }, - genA, - )) - - properties.Property("Square: operation result must match big.Int result", prop.ForAll( - func(a testPairElement) bool { - var c Element - c.Square(&a.element) - - var d, e big.Int - d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 - }, - genA, - )) - - properties.Property("Square: operation result must be smaller than modulus", prop.ForAll( - func(a testPairElement) bool { - var c Element - c.Square(&a.element) - return !c.biggerOrEqualModulus() - }, - genA, - )) - - specialValueTest := func() { - // test special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, a := range testValues { - var aBig big.Int - a.ToBigIntRegular(&aBig) - var c Element - c.Square(&a) - - var d, e big.Int - d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - t.Fatal("Square failed special test values") - } - } - } - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - supportAdx = false - t.Log("disabling ADX") - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } -} - -func TestElementInverse(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - - properties.Property("Inverse: having the receiver as operand should output the same result", prop.ForAll( - func(a testPairElement) bool { - - var b Element - - b.Inverse(&a.element) - a.element.Inverse(&a.element) - return a.element.Equal(&b) - }, - genA, - )) - - properties.Property("Inverse: operation result must match big.Int result", prop.ForAll( - func(a testPairElement) bool { - var c Element - c.Inverse(&a.element) - - var d, e big.Int - d.ModInverse(&a.bigint, Modulus()) - - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 - }, - genA, - )) - - properties.Property("Inverse: operation result must be smaller than modulus", prop.ForAll( - func(a testPairElement) bool { - var c Element - c.Inverse(&a.element) - return !c.biggerOrEqualModulus() - }, - genA, - )) - - specialValueTest := func() { - // test special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, a := range testValues { - var aBig big.Int - a.ToBigIntRegular(&aBig) - var c Element - c.Inverse(&a) - - var d, e big.Int - d.ModInverse(&aBig, Modulus()) - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - t.Fatal("Inverse failed special test values") - } - } - } - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - supportAdx = false - t.Log("disabling ADX") - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } -} - -func TestElementSqrt(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - - properties.Property("Sqrt: having the receiver as operand should output the same result", prop.ForAll( - func(a testPairElement) bool { - - b := a.element - - b.Sqrt(&a.element) - a.element.Sqrt(&a.element) - return a.element.Equal(&b) - }, - genA, - )) - - properties.Property("Sqrt: operation result must match big.Int result", prop.ForAll( - func(a testPairElement) bool { - var c Element - c.Sqrt(&a.element) - - var d, e big.Int - d.ModSqrt(&a.bigint, Modulus()) - - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 - }, - genA, - )) - - properties.Property("Sqrt: operation result must be smaller than modulus", prop.ForAll( - func(a testPairElement) bool { - var c Element - c.Sqrt(&a.element) - return !c.biggerOrEqualModulus() - }, - genA, - )) - - specialValueTest := func() { - // test special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, a := range testValues { - var aBig big.Int - a.ToBigIntRegular(&aBig) - var c Element - c.Sqrt(&a) - - var d, e big.Int - d.ModSqrt(&aBig, Modulus()) - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - t.Fatal("Sqrt failed special test values") - } - } - } - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - supportAdx = false - t.Log("disabling ADX") - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } -} - -func TestElementDouble(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - - properties.Property("Double: having the receiver as operand should output the same result", prop.ForAll( - func(a testPairElement) bool { - - var b Element - - b.Double(&a.element) - a.element.Double(&a.element) - return a.element.Equal(&b) - }, - genA, - )) - - properties.Property("Double: operation result must match big.Int result", prop.ForAll( - func(a testPairElement) bool { - var c Element - c.Double(&a.element) - - var d, e big.Int - d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 - }, - genA, - )) - - properties.Property("Double: operation result must be smaller than modulus", prop.ForAll( - func(a testPairElement) bool { - var c Element - c.Double(&a.element) - return !c.biggerOrEqualModulus() - }, - genA, - )) - - properties.Property("Double: assembly implementation must be consistent with generic one", prop.ForAll( - func(a testPairElement) bool { - var c, d Element - c.Double(&a.element) - _doubleGeneric(&d, &a.element) - return c.Equal(&d) - }, - genA, - )) - - specialValueTest := func() { - // test special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, a := range testValues { - var aBig big.Int - a.ToBigIntRegular(&aBig) - var c Element - c.Double(&a) - - var d, e big.Int - d.Lsh(&aBig, 1).Mod(&d, Modulus()) - - // checking asm against generic impl - var cGeneric Element - _doubleGeneric(&cGeneric, &a) - if !cGeneric.Equal(&c) { - t.Fatal("Double failed special test values: asm and generic impl don't match") - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - t.Fatal("Double failed special test values") - } - } - } - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - supportAdx = false - t.Log("disabling ADX") - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } -} - -func TestElementNeg(t *testing.T) { - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - - properties.Property("Neg: having the receiver as operand should output the same result", prop.ForAll( - func(a testPairElement) bool { - - var b Element - - b.Neg(&a.element) - a.element.Neg(&a.element) - return a.element.Equal(&b) - }, - genA, - )) - - properties.Property("Neg: operation result must match big.Int result", prop.ForAll( - func(a testPairElement) bool { - var c Element - c.Neg(&a.element) - - var d, e big.Int - d.Neg(&a.bigint).Mod(&d, Modulus()) - - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 - }, - genA, - )) - - properties.Property("Neg: operation result must be smaller than modulus", prop.ForAll( - func(a testPairElement) bool { - var c Element - c.Neg(&a.element) - return !c.biggerOrEqualModulus() - }, - genA, - )) - - properties.Property("Neg: assembly implementation must be consistent with generic one", prop.ForAll( - func(a testPairElement) bool { - var c, d Element - c.Neg(&a.element) - _negGeneric(&d, &a.element) - return c.Equal(&d) - }, - genA, - )) - - specialValueTest := func() { - // test special values - testValues := make([]Element, len(staticTestValues)) - copy(testValues, staticTestValues) - - for _, a := range testValues { - var aBig big.Int - a.ToBigIntRegular(&aBig) - var c Element - c.Neg(&a) - - var d, e big.Int - d.Neg(&aBig).Mod(&d, Modulus()) - - // checking asm against generic impl - var cGeneric Element - _negGeneric(&cGeneric, &a) - if !cGeneric.Equal(&c) { - t.Fatal("Neg failed special test values: asm and generic impl don't match") - } - - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { - t.Fatal("Neg failed special test values") - } - } - } - - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - // if we have ADX instruction enabled, test both path in assembly - if supportAdx { - supportAdx = false - t.Log("disabling ADX") - properties.TestingRun(t, gopter.ConsoleReporter(false)) - specialValueTest() - supportAdx = true - } -} - -func TestElementFromMont(t *testing.T) { - - parameters := gopter.DefaultTestParameters() - if testing.Short() { - parameters.MinSuccessfulTests = nbFuzzShort - } else { - parameters.MinSuccessfulTests = nbFuzz - } - - properties := gopter.NewProperties(parameters) - - genA := gen() - - properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( - func(a testPairElement) bool { - c := a.element - d := a.element - c.FromMont() - _fromMontGeneric(&d) - return c.Equal(&d) - }, - genA, - )) - - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( - func(a testPairElement) bool { - c := a.element - c.FromMont().ToMont() - return c.Equal(&a.element) - }, - genA, - )) - - properties.TestingRun(t, gopter.ConsoleReporter(false)) -} - -type testPairElement struct { - element Element - bigint big.Int -} - -func (z *Element) biggerOrEqualModulus() bool { - if z[3] > qElement[3] { - return true - } - if z[3] < qElement[3] { - return false - } - - if z[2] > qElement[2] { - return true - } - if z[2] < qElement[2] { - return false - } - - if z[1] > qElement[1] { - return true - } - if z[1] < qElement[1] { - return false - } - - return z[0] >= qElement[0] -} - -func gen() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { - var g testPairElement - - g.element = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g.element[3] %= (qElement[3] + 1) - } - - for g.element.biggerOrEqualModulus() { - g.element = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g.element[3] %= (qElement[3] + 1) - } - } - - g.element.ToBigIntRegular(&g.bigint) - genResult := gopter.NewGenResult(g, gopter.NoShrinker) - return genResult - } -} - -func genFull() gopter.Gen { - return func(genParams *gopter.GenParameters) *gopter.GenResult { - - genRandomFq := func() Element { - var g Element - - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } - - for g.biggerOrEqualModulus() { - g = Element{ - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - genParams.NextUint64(), - } - if qElement[3] != ^uint64(0) { - g[3] %= (qElement[3] + 1) - } - } - - return g - } - a := genRandomFq() - - var carry uint64 - a[0], carry = bits.Add64(a[0], qElement[0], carry) - a[1], carry = bits.Add64(a[1], qElement[1], carry) - a[2], carry = bits.Add64(a[2], qElement[2], carry) - a[3], _ = bits.Add64(a[3], qElement[3], carry) - - genResult := gopter.NewGenResult(a, gopter.NoShrinker) - return genResult - } -} \ No newline at end of file diff --git a/element.go b/element.go new file mode 100644 index 0000000..7cabce0 --- /dev/null +++ b/element.go @@ -0,0 +1,88 @@ +package poseidon + +import ( + "math/big" + "reflect" + + "github.com/consensys/gnark-crypto/field/pool" +) + +type Element[E any] interface { + SetUint64(uint64) E + SetBigInt(*big.Int) E + SetBytes([]byte) E + SetString(string) (E, error) + BigInt(*big.Int) *big.Int + SetOne() E + SetZero() E + Inverse(E) E + Set(E) E + Square(E) E + Mul(E, E) E + Add(E, E) E + Sub(E, E) E + Cmp(x E) int +} + +func NewElement[E Element[E]]() E { + typ := reflect.TypeOf((*E)(nil)).Elem() + val := reflect.New(typ.Elem()) + return val.Interface().(E) +} + +func isNil[E Element[E]](t E) bool { + v := reflect.ValueOf(t) + return v.IsNil() +} + +func zero[E Element[E]]() E { + return NewElement[E]().SetZero() +} + +func one[E Element[E]]() E { + return NewElement[E]().SetOne() +} + +func Modulus[E Element[E]]() *big.Int { + e := NewElement[E]().SetZero() + e.Sub(e, NewElement[E]().SetOne()) + b := e.BigInt(new(big.Int)) + return b.Add(b, big.NewInt(1)) +} + +func Bits[E Element[E]]() int { + return Modulus[E]().BitLen() +} + +func Bytes[E Element[E]]() int { + return (Bits[E]() + 7) / 8 +} + +// Exp is a copy of gnark-crypto's implementation, but takes a pointer argument +func Exp[E Element[E]](z, x E, k *big.Int) { + if k.IsUint64() && k.Uint64() == 0 { + z.SetOne() + } + + e := k + if k.Sign() == -1 { + // negative k, we invert + // if k < 0: xᵏ (mod q) == (x⁻¹)ᵏ (mod q) + x.Inverse(x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = pool.BigInt.Get() + defer pool.BigInt.Put(e) + e.Neg(k) + } + + z.Set(x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, x) + } + } +} diff --git a/go.mod b/go.mod index 7cd464d..ab31db8 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,18 @@ module github.com/triplewz/poseidon -go 1.14 +go 1.21 require ( - github.com/leanovate/gopter v0.2.9 // indirect - github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.7.0 // indirect - golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 + github.com/consensys/gnark-crypto v0.12.1 + github.com/stretchr/testify v1.8.2 +) + +require ( + github.com/bits-and-blooms/bitset v1.7.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.12.0 // indirect + golang.org/x/sys v0.9.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ecbbf54..59bfa32 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,33 @@ -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/bits-and-blooms/bitset v1.7.0 h1:YjAGVd3XmtK9ktAbX8Zg2g2PwLIMjGREZJHlV4j7NEo= +github.com/bits-and-blooms/bitset v1.7.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= +github.com/consensys/gnark-crypto v0.12.1 h1:lHH39WuuFgVHONRl3J0LRBtuYdQTumFSDtJF7HpyG8M= +github.com/consensys/gnark-crypto v0.12.1/go.mod h1:v2Gy7L/4ZRosZ7Ivs+9SfUDr0f5UlG+EM5t7MPHiLuY= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= github.com/leanovate/gopter v0.2.9/go.mod h1:U2L/78B+KVFIx2VmW6onHJQzXtFb+p5y3y2Sh+Jxxv8= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -golang.org/x/sys v0.0.0-20210616094352-59db8d763f22 h1:RqytpXGR1iVNX7psjB3ff8y7sNFinVFvkx1c8SjBkio= -golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/matrix.go b/matrix.go index fbee84e..1bd07f8 100644 --- a/matrix.go +++ b/matrix.go @@ -1,19 +1,16 @@ package poseidon import ( - "github.com/pkg/errors" - ff "github.com/triplewz/poseidon/bls12_381" + "errors" + "fmt" ) -type Matrix [][]*ff.Element +type Matrix[E Element[E]] [][]E -type Vector []*ff.Element - -var one = new(ff.Element).SetOne() -var zero = new(ff.Element).SetZero() +type Vector[E Element[E]] []E // return the column numbers of the matrix. -func column(m Matrix) int { +func column[E Element[E]](m Matrix[E]) int { if len(m) > 0 { length := len(m[0]) for i := 1; i < len(m); i++ { @@ -28,17 +25,17 @@ func column(m Matrix) int { } // return the row numbers of the matrix. -func row(m Matrix) int { +func row[E Element[E]](m Matrix[E]) int { return len(m) } // for 0 <= i < row, 0 <= j < column, compute M_ij*scalar. -func ScalarMul(scalar *ff.Element, m Matrix) Matrix { - res := make([][]*ff.Element, len(m)) +func ScalarMul[E Element[E]](scalar E, m Matrix[E]) Matrix[E] { + res := make([][]E, len(m)) for i := 0; i < len(m); i++ { - res[i] = make([]*ff.Element, len(m[i])) + res[i] = make([]E, len(m[i])) for j := 0; j < len(m[i]); j++ { - res[i][j] = new(ff.Element).Mul(scalar, m[i][j]) + res[i][j] = NewElement[E]().Mul(scalar, m[i][j]) } } @@ -46,58 +43,58 @@ func ScalarMul(scalar *ff.Element, m Matrix) Matrix { } // for 0 <= i < length, compute v_i*scalar. -func ScalarVecMul(scalar *ff.Element, v Vector) Vector { - res := make([]*ff.Element, len(v)) +func ScalarVecMul[E Element[E]](scalar E, v Vector[E]) Vector[E] { + res := make([]E, len(v)) for i := 0; i < len(v); i++ { - res[i] = new(ff.Element).Mul(scalar, v[i]) + res[i] = NewElement[E]().Mul(scalar, v[i]) } return res } -func VecAdd(a, b Vector) (Vector, error) { +func VecAdd[E Element[E]](a, b Vector[E]) (Vector[E], error) { if len(a) != len(b) { - return nil, errors.New("length err: cannot compute vector add!") + return nil, errors.New("length err: cannot compute vector add") } - res := make([]*ff.Element, len(a)) + res := make([]E, len(a)) for i := 0; i < len(a); i++ { - res[i] = new(ff.Element).Add(a[i], b[i]) + res[i] = NewElement[E]().Add(a[i], b[i]) } return res, nil } -func VecSub(a, b Vector) (Vector, error) { +func VecSub[E Element[E]](a, b Vector[E]) (Vector[E], error) { if len(a) != len(b) { - return nil, errors.New("length err: cannot compute vector sub!") + return nil, errors.New("length err: cannot compute vector sub") } - res := make([]*ff.Element, len(a)) + res := make([]E, len(a)) for i := 0; i < len(a); i++ { - res[i] = new(ff.Element).Sub(a[i], b[i]) + res[i] = NewElement[E]().Sub(a[i], b[i]) } return res, nil } // compute the product between two vectors. -func VecMul(a, b Vector) (*ff.Element, error) { +func VecMul[E Element[E]](a, b Vector[E]) (E, error) { + res := NewElement[E]() if len(a) != len(b) { - return nil, errors.New("length err: cannot compute vector mul!") + return res, errors.New("length err: cannot compute vector mul!") } - res := new(ff.Element) for i := 0; i < len(a); i++ { - tmp := new(ff.Element).Mul(a[i], b[i]) + tmp := NewElement[E]().Mul(a[i], b[i]) res.Add(res, tmp) } return res, nil } -func IsVecEqual(a, b Vector) bool { +func IsVecEqual[E Element[E]](a, b Vector[E]) bool { if len(a) != len(b) { return false } @@ -121,21 +118,21 @@ func IsVecEqual(a, b Vector) bool { // if delta(m)≠0, m is invertible. // so we can transform m to the upper triangular matrix, // and if all upper diagonal elements are not zero, then m is invertible. -func IsInvertible(m Matrix) bool { +func IsInvertible[E Element[E]](m Matrix[E]) bool { // need to copy m. tmp := copyMatrixRows(m, 0, row(m)) if !IsSquareMatrix(tmp) { return false } - shadow := MakeIdentity(row(tmp)) + shadow := MakeIdentity[E](row(tmp)) upper, _, err := upperTriangular(tmp, shadow) if err != nil { panic(err) } for i := 0; i < row(tmp); i++ { - if upper[i][i].Cmp(zero) == 0 { + if upper[i][i].Cmp(zero[E]()) == 0 { return false } } @@ -144,21 +141,21 @@ func IsInvertible(m Matrix) bool { } // compute the product between two matrices. -func MatMul(a, b Matrix) (Matrix, error) { +func MatMul[E Element[E]](a, b Matrix[E]) (Matrix[E], error) { if row(a) != column(b) { - return nil, errors.New("cannot compute the result!") + return nil, errors.New("cannot compute the result") } transb := transpose(b) var err error - res := make([][]*ff.Element, row(a)) + res := make([][]E, row(a)) for i := 0; i < row(a); i++ { - res[i] = make([]*ff.Element, column(b)) + res[i] = make([]E, column(b)) for j := 0; j < column(b); j++ { res[i][j], err = VecMul(a[i], transb[j]) if err != nil { - return nil, errors.Errorf("vec mul err: %s", err) + return nil, fmt.Errorf("vec mul err: %w", err) } } } @@ -167,21 +164,21 @@ func MatMul(a, b Matrix) (Matrix, error) { } // left Matrix multiplication, denote by M*V, where M is the matrix, and V is the vector. -func LeftMatMul(m Matrix, v Vector) (Vector, error) { +func LeftMatMul[E Element[E]](m Matrix[E], v Vector[E]) (Vector[E], error) { if !IsSquareMatrix(m) { panic("matrix is not square!") } if row(m) != len(v) { - return nil, errors.New("length err: cannot compute matrix multiplication with the vector!") + return nil, errors.New("length err: cannot compute matrix multiplication with the vector") } - res := make([]*ff.Element, len(v)) + res := make([]E, len(v)) var err error for i := 0; i < len(v); i++ { - res[i], err = VecMul(m[i], v) + res[i], err = VecMul[E](m[i], v) if err != nil { - return nil, errors.Errorf("vector mul err:%s", err) + return nil, fmt.Errorf("vector mul err: %w", err) } } @@ -189,22 +186,22 @@ func LeftMatMul(m Matrix, v Vector) (Vector, error) { } // right Matrix multiplication, denote by V*M, where V is the vector, and M is the matrix. -func RightMatMul(v Vector, m Matrix) (Vector, error) { +func RightMatMul[E Element[E]](v Vector[E], m Matrix[E]) (Vector[E], error) { if !IsSquareMatrix(m) { - return nil, errors.New("matrix is not square!") + return nil, errors.New("matrix is not square") } if row(m) != len(v) { - return nil, errors.New("length err: cannot compute matrix multiplication with the vector!") + return nil, errors.New("length err: cannot compute matrix multiplication with the vector") } transm := transpose(m) - res := make([]*ff.Element, len(v)) + res := make([]E, len(v)) var err error for i := 0; i < len(v); i++ { res[i], err = VecMul(transm[i], v) if err != nil { - return nil, errors.Errorf("vector mul err:%s", err) + return nil, fmt.Errorf("vector mul err: %w", err) } } @@ -212,11 +209,11 @@ func RightMatMul(v Vector, m Matrix) (Vector, error) { } // swap rows and columns of the matrix. -func transpose(m Matrix) Matrix { - res := make([][]*ff.Element, column(m)) +func transpose[E Element[E]](m Matrix[E]) Matrix[E] { + res := make([][]E, column(m)) for j := 0; j < column(m); j++ { - res[j] = make([]*ff.Element, len(m)) + res[j] = make([]E, len(m)) for i := 0; i < len(m); i++ { res[j][i] = m[i][j] } @@ -226,21 +223,21 @@ func transpose(m Matrix) Matrix { } // the square matrix is a t*t matrix. -func IsSquareMatrix(m Matrix) bool { +func IsSquareMatrix[E Element[E]](m Matrix[E]) bool { return row(m) == column(m) } // make t*t identity matrix. -func MakeIdentity(t int) Matrix { - res := make([][]*ff.Element, t) +func MakeIdentity[E Element[E]](t int) Matrix[E] { + res := make([][]E, t) for i := 0; i < t; i++ { - res[i] = make([]*ff.Element, t) + res[i] = make([]E, t) for j := 0; j < t; j++ { if i == j { - res[i][j] = one + res[i][j] = one[E]() } else { - res[i][j] = zero + res[i][j] = zero[E]() } } } @@ -249,10 +246,10 @@ func MakeIdentity(t int) Matrix { } // determine if a matrix is identity. -func IsIdentity(m Matrix) bool { +func IsIdentity[E Element[E]](m Matrix[E]) bool { for i := 0; i < row(m); i++ { for j := 0; j < column(m); j++ { - if ((i == j) && m[i][j].Cmp(one) != 0) || ((i != j) && (m[i][j].Cmp(zero) != 0)) { + if ((i == j) && m[i][j].Cmp(one[E]()) != 0) || ((i != j) && (m[i][j].Cmp(zero[E]()) != 0)) { return false } } @@ -261,7 +258,7 @@ func IsIdentity(m Matrix) bool { return true } -func IsEqual(a, b Matrix) bool { +func IsEqual[E Element[E]](a, b Matrix[E]) bool { if row(a) != row(b) || column(a) != column(b) { return false } @@ -287,12 +284,12 @@ func IsEqual(a, b Matrix) bool { } // remove i-th row and j-th column of the matrix. -func minor(m Matrix, rowIndex, columnIndex int) (Matrix, error) { +func minor[E Element[E]](m Matrix[E], rowIndex, columnIndex int) (Matrix[E], error) { if !IsSquareMatrix(m) { return nil, errors.New("matrix is not square!") } - res := make([][]*ff.Element, row(m)-1) + res := make([][]E, row(m)-1) for i := 0; i < row(m); i++ { if i < rowIndex { @@ -314,13 +311,13 @@ func minor(m Matrix, rowIndex, columnIndex int) (Matrix, error) { } // determine if the first k elements are zero. -func isFirstKZero(v Vector, k int) bool { - if k == 0 && v[0].Cmp(zero) == 0 { +func isFirstKZero[E Element[E]](v Vector[E], k int) bool { + if k == 0 && v[0].Cmp(zero[E]()) == 0 { return false } for i := 0; i < k; i++ { - if v[i].Cmp(zero) != 0 { + if v[i].Cmp(zero[E]()) != 0 { return false } } @@ -328,15 +325,15 @@ func isFirstKZero(v Vector, k int) bool { } // find the first non-zero element in the given column. -func findNonZero(m Matrix, index int) (pivot *ff.Element, pivotIndex int, err error) { +func findNonZero[E Element[E]](m Matrix[E], index int) (pivot E, pivotIndex int, err error) { pivotIndex = -1 if index > column(m) { - return nil, -1, errors.New("index out of range!") + return NewElement[E](), -1, errors.New("index out of range!") } for i := 0; i < row(m); i++ { - if m[i][index].Cmp(zero) != 0 { + if m[i][index].Cmp(zero[E]()) != 0 { pivot = m[i][index] pivotIndex = i break @@ -347,27 +344,27 @@ func findNonZero(m Matrix, index int) (pivot *ff.Element, pivotIndex int, err er } // assume matrix is partially reduced to upper triangular. -func eliminate(m, shadow Matrix, columnIndex int) (Matrix, Matrix, error) { +func eliminate[E Element[E]](m, shadow Matrix[E], columnIndex int) (Matrix[E], Matrix[E], error) { pivot, pivotIndex, err := findNonZero(m, columnIndex) if err != nil || pivotIndex == -1 { - return nil, nil, errors.Errorf("cannot find non-zero element: %s", err) + return nil, nil, fmt.Errorf("cannot find non-zero element: %w", err) } - pivotInv := new(ff.Element).Inverse(pivot) + pivotInv := NewElement[E]().Inverse(pivot) for i := 0; i < row(m); i++ { if i == pivotIndex { continue } - if m[i][columnIndex].Cmp(zero) != 0 { - factor := new(ff.Element).Mul(m[i][columnIndex], pivotInv) + if m[i][columnIndex].Cmp(zero[E]()) != 0 { + factor := NewElement[E]().Mul(m[i][columnIndex], pivotInv) scalarPivot := ScalarVecMul(factor, m[pivotIndex]) m[i], err = VecSub(m[i], scalarPivot) if err != nil { - return nil, nil, errors.Errorf("matrix m eliminate failed, vec sub err: %s", err) + return nil, nil, fmt.Errorf("matrix m eliminate failed, vec sub err: %w", err) } shadowPivot := shadow[pivotIndex] @@ -376,7 +373,7 @@ func eliminate(m, shadow Matrix, columnIndex int) (Matrix, Matrix, error) { shadow[i], err = VecSub(shadow[i], scalarShadowPivot) if err != nil { - return nil, nil, errors.Errorf("matrix shadow eliminate failed, vec sub err: %s", err) + return nil, nil, fmt.Errorf("matrix shadow eliminate failed, vec sub err: %w", err) } } } @@ -385,15 +382,15 @@ func eliminate(m, shadow Matrix, columnIndex int) (Matrix, Matrix, error) { } // copy rows between start index and end index. -func copyMatrixRows(m Matrix, startIndex, endIndex int) Matrix { +func copyMatrixRows[E Element[E]](m Matrix[E], startIndex, endIndex int) Matrix[E] { if startIndex >= endIndex { panic("start index should be less than end index!") } - res := make([][]*ff.Element, endIndex-startIndex) + res := make([][]E, endIndex-startIndex) for i := 0; i < endIndex-startIndex; i++ { - res[i] = make([]*ff.Element, column(m)) + res[i] = make([]E, column(m)) copy(res[i], m[i+startIndex]) } @@ -401,11 +398,11 @@ func copyMatrixRows(m Matrix, startIndex, endIndex int) Matrix { } // reverse rows of the matrix. -func reverseRows(m Matrix) Matrix { - res := make([][]*ff.Element, row(m)) +func reverseRows[E Element[E]](m Matrix[E]) Matrix[E] { + res := make([][]E, row(m)) for i := 0; i < row(m); i++ { - res[i] = make([]*ff.Element, column(m)) + res[i] = make([]E, column(m)) copy(res[i], m[row(m)-i-1]) } @@ -413,10 +410,10 @@ func reverseRows(m Matrix) Matrix { } // determine if numbers of zero elements equals to n. -func zeroNums(v Vector, n int) bool { +func zeroNums[E Element[E]](v Vector[E], n int) bool { count := 0 for i := 0; i < len(v); i++ { - if v[i].Cmp(zero) != 0 { + if v[i].Cmp(zero[E]()) != 0 { break } count++ @@ -430,7 +427,7 @@ func zeroNums(v Vector, n int) bool { } // determine if a matrix is upper triangular. -func isUpperTriangular(m Matrix) bool { +func isUpperTriangular[E Element[E]](m Matrix[E]) bool { for i := 0; i < row(m); i++ { if !zeroNums(m[i], i) { return false @@ -441,23 +438,23 @@ func isUpperTriangular(m Matrix) bool { } // transform a square matrix to upper triangular matrix. -func upperTriangular(m, shadow Matrix) (Matrix, Matrix, error) { +func upperTriangular[E Element[E]](m, shadow Matrix[E]) (Matrix[E], Matrix[E], error) { if !IsSquareMatrix(m) { return nil, nil, errors.New("matrix is not square!") } curr := copyMatrixRows(m, 0, row(m)) currShadow := copyMatrixRows(shadow, 0, row(shadow)) - result := make([][]*ff.Element, row(m)) - shadowResult := make([][]*ff.Element, row(shadow)) + result := make([][]E, row(m)) + shadowResult := make([][]E, row(shadow)) c := 0 var err error for row(curr) > 1 { - result[c] = make([]*ff.Element, column(m)) - shadowResult[c] = make([]*ff.Element, column(shadow)) + result[c] = make([]E, column(m)) + shadowResult[c] = make([]E, column(shadow)) curr, currShadow, err = eliminate(curr, currShadow, c) if err != nil { - return nil, nil, errors.Errorf("matrix eliminate err: %s", err) + return nil, nil, fmt.Errorf("matrix eliminate err: %w", err) } copy(result[c], curr[0]) @@ -468,8 +465,8 @@ func upperTriangular(m, shadow Matrix) (Matrix, Matrix, error) { curr = copyMatrixRows(curr, 1, row(curr)) currShadow = copyMatrixRows(currShadow, 1, row(currShadow)) } - result[c] = make([]*ff.Element, column(m)) - shadowResult[c] = make([]*ff.Element, column(shadow)) + result[c] = make([]E, column(m)) + shadowResult[c] = make([]E, column(shadow)) copy(result[c], curr[0]) copy(shadowResult[c], currShadow[0]) @@ -477,22 +474,22 @@ func upperTriangular(m, shadow Matrix) (Matrix, Matrix, error) { } // reduce a upper triangular matrix to identity matrix. -func reduceToIdentity(m, shadow Matrix) (Matrix, Matrix, error) { +func reduceToIdentity[E Element[E]](m, shadow Matrix[E]) (Matrix[E], Matrix[E], error) { var err error - result := make([][]*ff.Element, row(m)) - shadowResult := make([][]*ff.Element, row(shadow)) + result := make([][]E, row(m)) + shadowResult := make([][]E, row(shadow)) for i := 0; i < row(m); i++ { - result[i] = make([]*ff.Element, column(m)) - shadowResult[i] = make([]*ff.Element, column(shadow)) + result[i] = make([]E, column(m)) + shadowResult[i] = make([]E, column(shadow)) indexi := row(m) - i - 1 factor := m[indexi][indexi] - if factor.Cmp(zero) == 0 { + if factor.Cmp(zero[E]()) == 0 { return nil, nil, errors.New("cannot compute the result!") } - factorInv := new(ff.Element).Inverse(factor) + factorInv := NewElement[E]().Inverse(factor) norm := ScalarVecMul(factorInv, m[indexi]) @@ -507,12 +504,12 @@ func reduceToIdentity(m, shadow Matrix) (Matrix, Matrix, error) { norm, err = VecSub(norm, scalarVal) if err != nil { - return nil, nil, errors.Errorf("reduces to identity matrix failed, err: %s", err) + return nil, nil, fmt.Errorf("reduces to identity matrix failed, err: %w", err) } shadowNorm, err = VecSub(shadowNorm, scalarShadow) if err != nil { - return nil, nil, errors.Errorf("reduces to identity matrix failed, err: %s", err) + return nil, nil, fmt.Errorf("reduces to identity matrix failed, err: %w", err) } } copy(result[i], norm) @@ -527,30 +524,30 @@ func reduceToIdentity(m, shadow Matrix) (Matrix, Matrix, error) { // use Gaussian elimination to invert a matrix. // A|I -> I|A^-1. -func Invert(m Matrix) (Matrix, error) { +func Invert[E Element[E]](m Matrix[E]) (Matrix[E], error) { if !IsInvertible(m) { - return nil, errors.Errorf("the matrix is not invertible!") + return nil, fmt.Errorf("the matrix is not invertible") } - shadow := MakeIdentity(row(m)) + shadow := MakeIdentity[E](row(m)) up, upShadow, err := upperTriangular(m, shadow) if err != nil { - return nil, errors.Errorf("transform to upper triangular matrix failed, err: %s", err) + return nil, fmt.Errorf("transform to upper triangular matrix failed, err: %w", err) } if !isUpperTriangular(up) { - return nil, errors.Errorf("the matrix should be upper triangular before reducing!") + return nil, fmt.Errorf("the matrix should be upper triangular before reducing") } // reduce m to identity, so shadow matrix transforms to the inverse of m. reduce, reducedShadow, err := reduceToIdentity(up, upShadow) if err != nil { - return nil, errors.Errorf("reduce to identity failed, err: %s", err) + return nil, fmt.Errorf("reduce to identity failed, err: %w", err) } if !IsIdentity(reduce) { - return nil, errors.New("reduces failed, the result is not the identity matrix!") + return nil, errors.New("reduces failed, the result is not the identity matrix") } return reducedShadow, nil diff --git a/matrix_test.go b/matrix_test.go index 464cc32..3272bb5 100644 --- a/matrix_test.go +++ b/matrix_test.go @@ -1,31 +1,34 @@ package poseidon import ( - "github.com/stretchr/testify/assert" - ff "github.com/triplewz/poseidon/bls12_381" "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/stretchr/testify/assert" ) -var two = new(ff.Element).SetUint64(2) -var three = new(ff.Element).SetUint64(3) -var four = new(ff.Element).SetUint64(4) -var five = new(ff.Element).SetUint64(5) -var six = new(ff.Element).SetUint64(6) -var seven = new(ff.Element).SetUint64(7) -var eight = new(ff.Element).SetUint64(8) -var nine = new(ff.Element).SetUint64(9) +var zeroE = new(fr.Element).SetUint64(0) +var oneE = new(fr.Element).SetUint64(1) +var two = new(fr.Element).SetUint64(2) +var three = new(fr.Element).SetUint64(3) +var four = new(fr.Element).SetUint64(4) +var five = new(fr.Element).SetUint64(5) +var six = new(fr.Element).SetUint64(6) +var seven = new(fr.Element).SetUint64(7) +var eight = new(fr.Element).SetUint64(8) +var nine = new(fr.Element).SetUint64(9) func TestVector(t *testing.T) { - negTwo := new(ff.Element).Neg(two) + negTwo := new(fr.Element).Neg(two) sub := []struct { - v1, v2 Vector - want Vector + v1, v2 Vector[*fr.Element] + want Vector[*fr.Element] }{ - {Vector{one, two}, Vector{one, two}, Vector{zero, zero}}, - {Vector{one, two}, Vector{zero, zero}, Vector{one, two}}, - {Vector{three, four}, Vector{one, two}, Vector{two, two}}, - {Vector{one, two}, Vector{three, four}, Vector{negTwo, negTwo}}, + {Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{zeroE, zeroE}}, + {Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{zeroE, zeroE}, Vector[*fr.Element]{oneE, two}}, + {Vector[*fr.Element]{three, four}, Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{two, two}}, + {Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{three, four}, Vector[*fr.Element]{negTwo, negTwo}}, } for _, cases := range sub { @@ -35,12 +38,12 @@ func TestVector(t *testing.T) { } add := []struct { - v1, v2 Vector - want Vector + v1, v2 Vector[*fr.Element] + want Vector[*fr.Element] }{ - {Vector{one, two}, Vector{one, two}, Vector{two, four}}, - {Vector{one, two}, Vector{zero, zero}, Vector{one, two}}, - {Vector{one, two}, Vector{one, negTwo}, Vector{two, zero}}, + {Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{two, four}}, + {Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{zeroE, zeroE}, Vector[*fr.Element]{oneE, two}}, + {Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{oneE, negTwo}, Vector[*fr.Element]{two, zeroE}}, } for _, cases := range add { @@ -50,13 +53,13 @@ func TestVector(t *testing.T) { } scalarmul := []struct { - scalar *ff.Element - v Vector - want Vector + scalar *fr.Element + v Vector[*fr.Element] + want Vector[*fr.Element] }{ - {zero, Vector{one, two}, Vector{zero, zero}}, - {one, Vector{one, two}, Vector{one, two}}, - {two, Vector{one, two}, Vector{two, four}}, + {zeroE, Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{zeroE, zeroE}}, + {oneE, Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{oneE, two}}, + {two, Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{two, four}}, } for _, cases := range scalarmul { @@ -65,12 +68,12 @@ func TestVector(t *testing.T) { } vecmul := []struct { - v1, v2 Vector - want *ff.Element + v1, v2 Vector[*fr.Element] + want *fr.Element }{ - {Vector{one, two}, Vector{one, two}, five}, - {Vector{one, two}, Vector{zero, zero}, zero}, - {Vector{one, two}, Vector{negTwo, one}, zero}, + {Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{oneE, two}, five}, + {Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{zeroE, zeroE}, zeroE}, + {Vector[*fr.Element]{oneE, two}, Vector[*fr.Element]{negTwo, oneE}, zeroE}, } for _, cases := range vecmul { @@ -82,13 +85,13 @@ func TestVector(t *testing.T) { func TestMatrixScalarMul(t *testing.T) { scalarmul := []struct { - scalar *ff.Element - m Matrix - want Matrix + scalar *fr.Element + m Matrix[*fr.Element] + want Matrix[*fr.Element] }{ - {zero, Matrix{{one, two}, {one, two}}, Matrix{{zero, zero}, {zero, zero}}}, - {one, Matrix{{one, two}, {one, two}}, Matrix{{one, two}, {one, two}}}, - {two, Matrix{{one, two}, {three, four}}, Matrix{{two, four}, {six, eight}}}, + {zeroE, Matrix[*fr.Element]{{oneE, two}, {oneE, two}}, Matrix[*fr.Element]{{zeroE, zeroE}, {zeroE, zeroE}}}, + {oneE, Matrix[*fr.Element]{{oneE, two}, {oneE, two}}, Matrix[*fr.Element]{{oneE, two}, {oneE, two}}}, + {two, Matrix[*fr.Element]{{oneE, two}, {three, four}}, Matrix[*fr.Element]{{two, four}, {six, eight}}}, } for _, cases := range scalarmul { @@ -98,27 +101,27 @@ func TestMatrixScalarMul(t *testing.T) { } func TestIdentity(t *testing.T) { - get := MakeIdentity(3) - want := Matrix{{one, zero, zero}, {zero, one, zero}, {zero, zero, one}} + get := MakeIdentity[*fr.Element](3) + want := Matrix[*fr.Element]{{oneE, zeroE, zeroE}, {zeroE, oneE, zeroE}, {zeroE, zeroE, oneE}} assert.Equal(t, get, want) } func TestMinor(t *testing.T) { - m := Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}} + m := Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}} testMatrix := []struct { i, j int - want Matrix + want Matrix[*fr.Element] }{ - {0, 0, Matrix{{five, six}, {eight, nine}}}, - {0, 1, Matrix{{four, six}, {seven, nine}}}, - {0, 2, Matrix{{four, five}, {seven, eight}}}, - {1, 0, Matrix{{two, three}, {eight, nine}}}, - {1, 1, Matrix{{one, three}, {seven, nine}}}, - {1, 2, Matrix{{one, two}, {seven, eight}}}, - {2, 0, Matrix{{two, three}, {five, six}}}, - {2, 1, Matrix{{one, three}, {four, six}}}, - {2, 2, Matrix{{one, two}, {four, five}}}, + {0, 0, Matrix[*fr.Element]{{five, six}, {eight, nine}}}, + {0, 1, Matrix[*fr.Element]{{four, six}, {seven, nine}}}, + {0, 2, Matrix[*fr.Element]{{four, five}, {seven, eight}}}, + {1, 0, Matrix[*fr.Element]{{two, three}, {eight, nine}}}, + {1, 1, Matrix[*fr.Element]{{oneE, three}, {seven, nine}}}, + {1, 2, Matrix[*fr.Element]{{oneE, two}, {seven, eight}}}, + {2, 0, Matrix[*fr.Element]{{two, three}, {five, six}}}, + {2, 1, Matrix[*fr.Element]{{oneE, three}, {four, six}}}, + {2, 2, Matrix[*fr.Element]{{oneE, two}, {four, five}}}, } for _, cases := range testMatrix { @@ -128,19 +131,19 @@ func TestMinor(t *testing.T) { } } -func TestCopyMatrix(t *testing.T) { - m := Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}} +func TestcopyMatrix(t *testing.T) { + m := Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}} testMatrix := []struct { start, end int - want Matrix + want Matrix[*fr.Element] }{ - {0, 1, Matrix{{one, two, three}}}, - {0, 2, Matrix{{one, two, three}, {four, five, six}}}, - {0, 3, Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}}}, - {1, 2, Matrix{{four, five, six}}}, - {1, 3, Matrix{{four, five, six}, {seven, eight, nine}}}, - {2, 3, Matrix{{seven, eight, nine}}}, + {0, 1, Matrix[*fr.Element]{{oneE, two, three}}}, + {0, 2, Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}}}, + {0, 3, Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}}}, + {1, 2, Matrix[*fr.Element]{{four, five, six}}}, + {1, 3, Matrix[*fr.Element]{{four, five, six}, {seven, eight, nine}}}, + {2, 3, Matrix[*fr.Element]{{seven, eight, nine}}}, } for _, cases := range testMatrix { @@ -151,10 +154,10 @@ func TestCopyMatrix(t *testing.T) { func TestTranspose(t *testing.T) { testMatrix := []struct { - input, want Matrix + input, want Matrix[*fr.Element] }{ - {Matrix{{one, two}, {three, four}}, Matrix{{one, three}, {two, four}}}, - {Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}}, Matrix{{one, four, seven}, {two, five, eight}, {three, six, nine}}}, + {Matrix[*fr.Element]{{oneE, two}, {three, four}}, Matrix[*fr.Element]{{oneE, three}, {two, four}}}, + {Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}}, Matrix[*fr.Element]{{oneE, four, seven}, {two, five, eight}, {three, six, nine}}}, } for _, cases := range testMatrix { @@ -164,15 +167,15 @@ func TestTranspose(t *testing.T) { } func TestUpperTriangular(t *testing.T) { - shadow := MakeIdentity(3) + shadow := MakeIdentity[*fr.Element](3) testMatrix := []struct { - m, s Matrix + m, s Matrix[*fr.Element] want bool }{ - {Matrix{{two, three, four}, {four, five, six}, {seven, eight, eight}}, shadow, true}, - {Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}}, shadow, false}, - {Matrix{{one, two, three}, {zero, three, four}, {zero, zero, three}}, shadow, true}, - {Matrix{{two, three, four}, {zero, two, four}, {zero, zero, one}}, shadow, true}, + {Matrix[*fr.Element]{{two, three, four}, {four, five, six}, {seven, eight, eight}}, shadow, true}, + {Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}}, shadow, false}, + {Matrix[*fr.Element]{{oneE, two, three}, {zeroE, three, four}, {zeroE, zeroE, three}}, shadow, true}, + {Matrix[*fr.Element]{{two, three, four}, {zeroE, two, four}, {zeroE, zeroE, oneE}}, shadow, true}, } for _, cases := range testMatrix { @@ -183,19 +186,19 @@ func TestUpperTriangular(t *testing.T) { } } -func TestFindNonZero(t *testing.T) { +func TestFindNonzeroE(t *testing.T) { vectorSet := []struct { k int - v Vector + v Vector[*fr.Element] want bool }{ - {0, Vector{zero, one, two, three}, false}, - {1, Vector{zero, one, two, three}, true}, - {2, Vector{zero, one, two, three}, false}, - {2, Vector{zero, zero, zero, one}, true}, - {3, Vector{zero, zero, zero, one}, true}, - {3, Vector{zero, one, two, three}, false}, - {4, Vector{zero, one, two, three}, false}, + {0, Vector[*fr.Element]{zeroE, oneE, two, three}, false}, + {1, Vector[*fr.Element]{zeroE, oneE, two, three}, true}, + {2, Vector[*fr.Element]{zeroE, oneE, two, three}, false}, + {2, Vector[*fr.Element]{zeroE, zeroE, zeroE, oneE}, true}, + {3, Vector[*fr.Element]{zeroE, zeroE, zeroE, oneE}, true}, + {3, Vector[*fr.Element]{zeroE, oneE, two, three}, false}, + {4, Vector[*fr.Element]{zeroE, oneE, two, three}, false}, } for _, cases := range vectorSet { @@ -203,55 +206,55 @@ func TestFindNonZero(t *testing.T) { assert.Equal(t, get, cases.want) } - nonzeroSet := []struct { - m Matrix + nonzeroESet := []struct { + m Matrix[*fr.Element] c int want struct { - e *ff.Element + e *fr.Element index int } }{ - {Matrix{{two, three, four}, {four, five, six}, {seven, eight, eight}}, 0, struct { - e *ff.Element + {Matrix[*fr.Element]{{two, three, four}, {four, five, six}, {seven, eight, eight}}, 0, struct { + e *fr.Element index int }{two, 0}}, - {Matrix{{two, three, four}, {four, five, six}, {seven, eight, eight}}, 1, struct { - e *ff.Element + {Matrix[*fr.Element]{{two, three, four}, {four, five, six}, {seven, eight, eight}}, 1, struct { + e *fr.Element index int }{three, 0}}, - {Matrix{{two, three, four}, {four, five, six}, {seven, eight, eight}}, 2, struct { - e *ff.Element + {Matrix[*fr.Element]{{two, three, four}, {four, five, six}, {seven, eight, eight}}, 2, struct { + e *fr.Element index int }{four, 0}}, - {Matrix{{one, zero, zero}, {two, three, zero}, {four, five, zero}}, 0, struct { - e *ff.Element + {Matrix[*fr.Element]{{oneE, zeroE, zeroE}, {two, three, zeroE}, {four, five, zeroE}}, 0, struct { + e *fr.Element index int - }{one, 0}}, - {Matrix{{one, zero, zero}, {two, three, zero}, {four, five, zero}}, 1, struct { - e *ff.Element + }{oneE, 0}}, + {Matrix[*fr.Element]{{oneE, zeroE, zeroE}, {two, three, zeroE}, {four, five, zeroE}}, 1, struct { + e *fr.Element index int }{three, 1}}, - {Matrix{{one, zero, zero}, {two, three, zero}, {four, five, zero}}, 2, struct { - e *ff.Element + {Matrix[*fr.Element]{{oneE, zeroE, zeroE}, {two, three, zeroE}, {four, five, zeroE}}, 2, struct { + e *fr.Element index int }{nil, -1}}, } - for _, cases := range nonzeroSet { + for _, cases := range nonzeroESet { gete, geti, err := findNonZero(cases.m, cases.c) assert.NoError(t, err) if gete != nil && cases.want.e != nil { if gete.Cmp(cases.want.e) != 0 || geti != cases.want.index { - t.Errorf("find non zero failed, get element: %v, want element: %v, get index: %d, want index: %d", gete, cases.want.e, geti, cases.want.index) + t.Errorf("find non zeroE failed, get element: %v, want element: %v, get index: %d, want index: %d", gete, cases.want.e, geti, cases.want.index) return } } else if gete == nil && cases.want.e == nil { if geti != cases.want.index || geti != -1 { - t.Errorf("find non zero failed, get element: %v, want element: %v, get index: %d, want index: %d", gete, cases.want.e, geti, cases.want.index) + t.Errorf("find non zeroE failed, get element: %v, want element: %v, get index: %d, want index: %d", gete, cases.want.e, geti, cases.want.index) return } } else { - t.Errorf("find non zero failed, get element: %v, want element: %v, get index: %d, want index: %d", gete, cases.want.e, geti, cases.want.index) + t.Errorf("find non zeroE failed, get element: %v, want element: %v, get index: %d, want index: %d", gete, cases.want.e, geti, cases.want.index) return } } @@ -260,30 +263,30 @@ func TestFindNonZero(t *testing.T) { func TestMatMul(t *testing.T) { // [[1,2,3],[4,5,6],[7,8,9]]*[[2,3,4],[4,5,6],[7,8,8]] // =[[31,37,40],[70,85,95],[109,133,148]] - m00 := new(ff.Element).SetUint64(31) - m01 := new(ff.Element).SetUint64(37) - m02 := new(ff.Element).SetUint64(40) - m10 := new(ff.Element).SetUint64(70) - m11 := new(ff.Element).SetUint64(85) - m12 := new(ff.Element).SetUint64(94) - m20 := new(ff.Element).SetUint64(109) - m21 := new(ff.Element).SetUint64(133) - m22 := new(ff.Element).SetUint64(148) - - thirteen := new(ff.Element).SetUint64(13) - sixteen := new(ff.Element).SetUint64(16) - eighteen := new(ff.Element).SetUint64(18) + m00 := new(fr.Element).SetUint64(31) + m01 := new(fr.Element).SetUint64(37) + m02 := new(fr.Element).SetUint64(40) + m10 := new(fr.Element).SetUint64(70) + m11 := new(fr.Element).SetUint64(85) + m12 := new(fr.Element).SetUint64(94) + m20 := new(fr.Element).SetUint64(109) + m21 := new(fr.Element).SetUint64(133) + m22 := new(fr.Element).SetUint64(148) + + thirteen := new(fr.Element).SetUint64(13) + sixteen := new(fr.Element).SetUint64(16) + eighteen := new(fr.Element).SetUint64(18) testMatrix := []struct { - m1, m2 Matrix - want Matrix + m1, m2 Matrix[*fr.Element] + want Matrix[*fr.Element] }{ - {Matrix{{zero, zero}, {zero, zero}}, Matrix{{one, two}, {one, two}}, Matrix{{zero, zero}, {zero, zero}}}, - {Matrix{{one, two}, {two, three}}, Matrix{{one, two}, {one, zero}}, Matrix{{three, two}, {five, four}}}, - {Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}}, Matrix{{two, three, four}, {four, five, six}, {seven, eight, eight}}, Matrix{{m00, m01, m02}, {m10, m11, m12}, {m20, m21, m22}}}, - {Matrix{{one, one, one}, {one, one, one}, {one, one, one}}, Matrix{{two, three, four}, {four, five, six}, {seven, eight, eight}}, Matrix{{thirteen, sixteen, eighteen}, {thirteen, sixteen, eighteen}, {thirteen, sixteen, eighteen}}}, - {Matrix{{zero, zero, zero}, {zero, zero, zero}, {zero, zero, zero}}, Matrix{{two, three, four}, {four, five, six}, {seven, eight, eight}}, Matrix{{zero, zero, zero}, {zero, zero, zero}, {zero, zero, zero}}}, - {Matrix{{one, zero, zero}, {zero, one, zero}, {zero, zero, one}}, Matrix{{two, three, four}, {four, five, six}, {seven, eight, eight}}, Matrix{{two, three, four}, {four, five, six}, {seven, eight, eight}}}, + {Matrix[*fr.Element]{{zeroE, zeroE}, {zeroE, zeroE}}, Matrix[*fr.Element]{{oneE, two}, {oneE, two}}, Matrix[*fr.Element]{{zeroE, zeroE}, {zeroE, zeroE}}}, + {Matrix[*fr.Element]{{oneE, two}, {two, three}}, Matrix[*fr.Element]{{oneE, two}, {oneE, zeroE}}, Matrix[*fr.Element]{{three, two}, {five, four}}}, + {Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}}, Matrix[*fr.Element]{{two, three, four}, {four, five, six}, {seven, eight, eight}}, Matrix[*fr.Element]{{m00, m01, m02}, {m10, m11, m12}, {m20, m21, m22}}}, + {Matrix[*fr.Element]{{oneE, oneE, oneE}, {oneE, oneE, oneE}, {oneE, oneE, oneE}}, Matrix[*fr.Element]{{two, three, four}, {four, five, six}, {seven, eight, eight}}, Matrix[*fr.Element]{{thirteen, sixteen, eighteen}, {thirteen, sixteen, eighteen}, {thirteen, sixteen, eighteen}}}, + {Matrix[*fr.Element]{{zeroE, zeroE, zeroE}, {zeroE, zeroE, zeroE}, {zeroE, zeroE, zeroE}}, Matrix[*fr.Element]{{two, three, four}, {four, five, six}, {seven, eight, eight}}, Matrix[*fr.Element]{{zeroE, zeroE, zeroE}, {zeroE, zeroE, zeroE}, {zeroE, zeroE, zeroE}}}, + {Matrix[*fr.Element]{{oneE, zeroE, zeroE}, {zeroE, oneE, zeroE}, {zeroE, zeroE, oneE}}, Matrix[*fr.Element]{{two, three, four}, {four, five, six}, {seven, eight, eight}}, Matrix[*fr.Element]{{two, three, four}, {four, five, six}, {seven, eight, eight}}}, } for _, cases := range testMatrix { @@ -294,17 +297,17 @@ func TestMatMul(t *testing.T) { // [[1,2,3],[4,5,6],[7,8,9]]*[1,1,1] // =[6,15,24] - fifteen := new(ff.Element).SetUint64(15) - twentyfour := new(ff.Element).SetUint64(24) + fifteen := new(fr.Element).SetUint64(15) + twentyfour := new(fr.Element).SetUint64(24) testLeftMul := []struct { - m Matrix - v Vector - want Vector + m Matrix[*fr.Element] + v Vector[*fr.Element] + want Vector[*fr.Element] }{ - {Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}}, Vector{zero, zero, zero}, Vector{zero, zero, zero}}, - {Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}}, Vector{one, zero, zero}, Vector{one, four, seven}}, - {Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}}, Vector{one, one, one}, Vector{six, fifteen, twentyfour}}, + {Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}}, Vector[*fr.Element]{zeroE, zeroE, zeroE}, Vector[*fr.Element]{zeroE, zeroE, zeroE}}, + {Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}}, Vector[*fr.Element]{oneE, zeroE, zeroE}, Vector[*fr.Element]{oneE, four, seven}}, + {Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}}, Vector[*fr.Element]{oneE, oneE, oneE}, Vector[*fr.Element]{six, fifteen, twentyfour}}, } for _, cases := range testLeftMul { @@ -315,16 +318,16 @@ func TestMatMul(t *testing.T) { // [1,1,1]*[[1,2,3],[4,5,6],[7,8,9]] // =[12,15,18] - twelve := new(ff.Element).SetUint64(12) + twelve := new(fr.Element).SetUint64(12) testRightMul := []struct { - v Vector - m Matrix - want Vector + v Vector[*fr.Element] + m Matrix[*fr.Element] + want Vector[*fr.Element] }{ - {Vector{zero, zero, zero}, Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}}, Vector{zero, zero, zero}}, - {Vector{one, zero, zero}, Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}}, Vector{one, two, three}}, - {Vector{one, one, one}, Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}}, Vector{twelve, fifteen, eighteen}}, + {Vector[*fr.Element]{zeroE, zeroE, zeroE}, Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}}, Vector[*fr.Element]{zeroE, zeroE, zeroE}}, + {Vector[*fr.Element]{oneE, zeroE, zeroE}, Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}}, Vector[*fr.Element]{oneE, two, three}}, + {Vector[*fr.Element]{oneE, oneE, oneE}, Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}}, Vector[*fr.Element]{twelve, fifteen, eighteen}}, } for _, cases := range testRightMul { @@ -335,36 +338,36 @@ func TestMatMul(t *testing.T) { } func TestEliminate(t *testing.T) { - m := Matrix{{two, three, four}, {four, five, six}, {seven, eight, eight}} - shadow := MakeIdentity(3) + m := Matrix[*fr.Element]{{two, three, four}, {four, five, six}, {seven, eight, eight}} + shadow := MakeIdentity[*fr.Element](3) // result of eliminating the first column. // [[2,3,4],[0,-1,-2],[0,-5/2,-6]] - negone := new(ff.Element).Neg(one) - negtwo := new(ff.Element).Neg(two) - negFiveDivTwo := new(ff.Element).Neg(five) + negoneE := new(fr.Element).Neg(oneE) + negtwo := new(fr.Element).Neg(two) + negFiveDivTwo := new(fr.Element).Neg(five) negFiveDivTwo.Div(negFiveDivTwo, two) - negsix := new(ff.Element).Neg(six) + negsix := new(fr.Element).Neg(six) // result of eliminating the second column. // [[2,3,4],[2/3,0,-2/3],[5/3,0,-8/3]] - twoDivThree := new(ff.Element).Div(two, three) - negTwoDivThree := new(ff.Element).Neg(twoDivThree) - fiveDivThree := new(ff.Element).Div(five, three) - negEightDivThree := new(ff.Element).Div(eight, three) + twoDivThree := new(fr.Element).Div(two, three) + negTwoDivThree := new(fr.Element).Neg(twoDivThree) + fiveDivThree := new(fr.Element).Div(five, three) + negEightDivThree := new(fr.Element).Div(eight, three) negEightDivThree.Neg(negEightDivThree) // result of eliminating the third column. // [[2,3,4],[1,1/2,0],[3,2,0]] - oneDivTwo := new(ff.Element).Div(one, two) + oneEDivTwo := new(fr.Element).Div(oneE, two) testMatrix := []struct { c int - want Matrix + want Matrix[*fr.Element] }{ - {0, Matrix{{two, three, four}, {zero, negone, negtwo}, {zero, negFiveDivTwo, negsix}}}, - {1, Matrix{{two, three, four}, {twoDivThree, zero, negTwoDivThree}, {fiveDivThree, zero, negEightDivThree}}}, - {2, Matrix{{two, three, four}, {one, oneDivTwo, zero}, {three, two, zero}}}, + {0, Matrix[*fr.Element]{{two, three, four}, {zeroE, negoneE, negtwo}, {zeroE, negFiveDivTwo, negsix}}}, + {1, Matrix[*fr.Element]{{two, three, four}, {twoDivThree, zeroE, negTwoDivThree}, {fiveDivThree, zeroE, negEightDivThree}}}, + {2, Matrix[*fr.Element]{{two, three, four}, {oneE, oneEDivTwo, zeroE}, {three, two, zeroE}}}, } for _, cases := range testMatrix { @@ -377,29 +380,29 @@ func TestEliminate(t *testing.T) { func TestReduceToIdentity(t *testing.T) { // m=[[1,2,3],[0,3,4],[0,0,3]] // m^-1=[[1,-2/3,-1/9],[0,1/3,-4/9],[0,0,1/3]] - negTwoDivThree := new(ff.Element).Div(two, three) + negTwoDivThree := new(fr.Element).Div(two, three) negTwoDivThree.Neg(negTwoDivThree) - negOneDivNine := new(ff.Element).Div(one, nine) - negOneDivNine.Neg(negOneDivNine) - oneDivThree := new(ff.Element).Div(one, three) - negFourDivNine := new(ff.Element).Div(four, nine) + negoneEDivNine := new(fr.Element).Div(oneE, nine) + negoneEDivNine.Neg(negoneEDivNine) + oneEDivThree := new(fr.Element).Div(oneE, three) + negFourDivNine := new(fr.Element).Div(four, nine) negFourDivNine.Neg(negFourDivNine) // m=[[2,3,4],[0,2,4],[0,0,1]] // m^-1=[[1/2,-3/4,1],[0,1/2,-2],[0,0,1]] - oneDivTwo := new(ff.Element).Div(one, two) - negThreeDivFour := new(ff.Element).Div(three, four) + oneEDivTwo := new(fr.Element).Div(oneE, two) + negThreeDivFour := new(fr.Element).Div(three, four) negThreeDivFour.Neg(negThreeDivFour) - negtwo := new(ff.Element).Neg(two) + negtwo := new(fr.Element).Neg(two) - shadow := MakeIdentity(3) + shadow := MakeIdentity[*fr.Element](3) testMatrix := []struct { - m Matrix - want Matrix + m Matrix[*fr.Element] + want Matrix[*fr.Element] }{ - {Matrix{{one, two, three}, {zero, three, four}, {zero, zero, three}}, Matrix{{one, negTwoDivThree, negOneDivNine}, {zero, oneDivThree, negFourDivNine}, {zero, zero, oneDivThree}}}, - {Matrix{{two, three, four}, {zero, two, four}, {zero, zero, one}}, Matrix{{oneDivTwo, negThreeDivFour, one}, {zero, oneDivTwo, negtwo}, {zero, zero, one}}}, + {Matrix[*fr.Element]{{oneE, two, three}, {zeroE, three, four}, {zeroE, zeroE, three}}, Matrix[*fr.Element]{{oneE, negTwoDivThree, negoneEDivNine}, {zeroE, oneEDivThree, negFourDivNine}, {zeroE, zeroE, oneEDivThree}}}, + {Matrix[*fr.Element]{{two, three, four}, {zeroE, two, four}, {zeroE, zeroE, oneE}}, Matrix[*fr.Element]{{oneEDivTwo, negThreeDivFour, oneE}, {zeroE, oneEDivTwo, negtwo}, {zeroE, zeroE, oneE}}}, } for _, cases := range testMatrix { @@ -411,13 +414,13 @@ func TestReduceToIdentity(t *testing.T) { func TestIsInvertible(t *testing.T) { testMatrix := []struct { - m Matrix + m Matrix[*fr.Element] want bool }{ - {Matrix{{two, three, four}, {four, five, six}, {seven, eight, eight}}, true}, - {Matrix{{one, two, three}, {zero, three, four}, {zero, zero, three}}, true}, - {Matrix{{two, three, four}, {zero, two, four}, {zero, zero, one}}, true}, - {Matrix{{one, two, three}, {four, five, six}, {seven, eight, nine}}, false}, + {Matrix[*fr.Element]{{two, three, four}, {four, five, six}, {seven, eight, eight}}, true}, + {Matrix[*fr.Element]{{oneE, two, three}, {zeroE, three, four}, {zeroE, zeroE, three}}, true}, + {Matrix[*fr.Element]{{two, three, four}, {zeroE, two, four}, {zeroE, zeroE, oneE}}, true}, + {Matrix[*fr.Element]{{oneE, two, three}, {four, five, six}, {seven, eight, nine}}, false}, } for _, cases := range testMatrix { @@ -433,8 +436,8 @@ func TestInvert(t *testing.T) { // m^-1: // [7 -3] // [-2 1] - negtwo := new(ff.Element).Neg(two) - negthree := new(ff.Element).Neg(three) + negtwo := new(fr.Element).Neg(two) + negthree := new(fr.Element).Neg(three) // 3*3 m: // [1 2 3] @@ -444,12 +447,12 @@ func TestInvert(t *testing.T) { // [1 -2/3 -1/9] // [0 1/3 -4/9] // [0 0 1/3] - negTwoDivThree := new(ff.Element).Div(two, three) + negTwoDivThree := new(fr.Element).Div(two, three) negTwoDivThree.Neg(negTwoDivThree) - negOneDivNine := new(ff.Element).Div(one, nine) - negOneDivNine.Neg(negOneDivNine) - oneDivThree := new(ff.Element).Div(one, three) - negFourDivNine := new(ff.Element).Div(four, nine) + negoneEDivNine := new(fr.Element).Div(oneE, nine) + negoneEDivNine.Neg(negoneEDivNine) + oneEDivThree := new(fr.Element).Div(oneE, three) + negFourDivNine := new(fr.Element).Div(four, nine) negFourDivNine.Neg(negFourDivNine) // 3*3 m: @@ -460,20 +463,20 @@ func TestInvert(t *testing.T) { // [-4 4 -1] // [5 -6 2] // [-3/2 5/2 -1] - negone := new(ff.Element).Neg(one) - negfour := new(ff.Element).Neg(four) - negsix := new(ff.Element).Neg(six) - negThreeDivTwo := new(ff.Element).Div(three, two) + negoneE := new(fr.Element).Neg(oneE) + negfour := new(fr.Element).Neg(four) + negsix := new(fr.Element).Neg(six) + negThreeDivTwo := new(fr.Element).Div(three, two) negThreeDivTwo.Neg(negThreeDivTwo) - fiveDivTwo := new(ff.Element).Div(five, two) + fiveDivTwo := new(fr.Element).Div(five, two) testMatrix := []struct { - m Matrix - want Matrix + m Matrix[*fr.Element] + want Matrix[*fr.Element] }{ - {Matrix{{one, three}, {two, seven}}, Matrix{{seven, negthree}, {negtwo, one}}}, - {Matrix{{one, two, three}, {zero, three, four}, {zero, zero, three}}, Matrix{{one, negTwoDivThree, negOneDivNine}, {zero, oneDivThree, negFourDivNine}, {zero, zero, oneDivThree}}}, - {Matrix{{two, three, four}, {four, five, six}, {seven, eight, eight}}, Matrix{{negfour, four, negone}, {five, negsix, two}, {negThreeDivTwo, fiveDivTwo, negone}}}, + {Matrix[*fr.Element]{{oneE, three}, {two, seven}}, Matrix[*fr.Element]{{seven, negthree}, {negtwo, oneE}}}, + {Matrix[*fr.Element]{{oneE, two, three}, {zeroE, three, four}, {zeroE, zeroE, three}}, Matrix[*fr.Element]{{oneE, negTwoDivThree, negoneEDivNine}, {zeroE, oneEDivThree, negFourDivNine}, {zeroE, zeroE, oneEDivThree}}}, + {Matrix[*fr.Element]{{two, three, four}, {four, five, six}, {seven, eight, eight}}, Matrix[*fr.Element]{{negfour, four, negoneE}, {five, negsix, two}, {negThreeDivTwo, fiveDivTwo, negoneE}}}, } for _, cases := range testMatrix { @@ -486,4 +489,4 @@ func TestInvert(t *testing.T) { assert.NoError(t, err) assert.Equal(t, get, cases.want) } -} \ No newline at end of file +} diff --git a/mds.go b/mds.go index 06f5a60..6f16892 100644 --- a/mds.go +++ b/mds.go @@ -1,26 +1,25 @@ package poseidon import ( - "github.com/pkg/errors" - ff "github.com/triplewz/poseidon/bls12_381" + "fmt" ) // mdsMatrices is matrices for improving the efficiency of Poseidon hash. // see more details in the paper https://eprint.iacr.org/2019/458.pdf page 20. -type mdsMatrices struct { +type mdsMatrices[E Element[E]] struct { // the input mds matrix. - m Matrix + m Matrix[E] // mInv is the inverse of the mds matrix. - mInv Matrix + mInv Matrix[E] // mHat is the matrix by eliminating the first row and column of the matrix. - mHat Matrix + mHat Matrix[E] // mHatInv is the inverse of the mHat matrix. - mHatInv Matrix + mHatInv Matrix[E] // mPrime is the matrix m' in the paper, and it holds m = m'*m''. // mPrime consists of: // 1 | 0 // 0 | mHat - mPrime Matrix + mPrime Matrix[E] // mDoublePrime is the matrix m'' in the paper, and it holds m = m'*m''. // mDoublePrime consists of: // m_00 | v @@ -28,48 +27,41 @@ type mdsMatrices struct { // where M_00 is the first element of the mds matrix, // w_hat and v are t-1 length vectors, // I is the (t-1)*(t-1) identity matrix. - mDoublePrime Matrix + mDoublePrime Matrix[E] } // SparseMatrix is specifically one of the form of m''. // This means its first row and column are each dense, and the interior matrix // (minor to the element in both the row and column) is the identity. // For simplicity, we omit the identity matrix in m''. -type SparseMatrix struct { - // wHat is the first column of the M'' matrix, this is a little different with the wHat in the paper because - // we add M_00 to the beginning of the wHat. - wHat Vector - // v contains all but the first element, because it is already included in wHat. - v Vector -} - -// create the mds matrices. -func createMDSMatrix(t int) (*mdsMatrices, error) { - m := genMDS(t) - - return deriveMatrices(m) +type SparseMatrix[E Element[E]] struct { + // WHat is the first column of the M'' matrix, this is a little different with the WHat in the paper because + // we add M_00 to the beginning of the WHat. + WHat Vector[E] + // V contains all but the first element, because it is already included in WHat. + V Vector[E] } // generate the mds (cauchy) matrix, which is invertible, and // its sub-matrices are invertible as well. -func genMDS(t int) Matrix { - xVec := make([]*ff.Element, t) - yVec := make([]*ff.Element, t) +func genMDS[E Element[E]](t int) Matrix[E] { + xVec := make([]E, t) + yVec := make([]E, t) regen: // generate x and y value where x[i] != y[i] to allow the values to be inverted, and // there are no duplicates in the x vector or y vector, so that // the determinant is always non-zero. for i := 0; i < t; i++ { - xVec[i] = new(ff.Element).SetUint64(uint64(i)) - yVec[i] = new(ff.Element).SetUint64(uint64(i + t)) + xVec[i] = NewElement[E]().SetUint64(uint64(i)) + yVec[i] = NewElement[E]().SetUint64(uint64(i + t)) } - m := make([][]*ff.Element, t) + m := make([][]E, t) for i := 0; i < t; i++ { - m[i] = make([]*ff.Element, t) + m[i] = make([]E, t) for j := 0; j < t; j++ { - m[i][j] = new(ff.Element).Add(xVec[i], yVec[j]) + m[i][j] = NewElement[E]().Add(xVec[i], yVec[j]) m[i][j].Inverse(m[i][j]) } } @@ -90,43 +82,43 @@ regen: } // derive the mds matrices from m. -func deriveMatrices(m Matrix) (*mdsMatrices, error) { +func deriveMatrices[E Element[E]](m Matrix[E]) (*mdsMatrices[E], error) { mInv, err := Invert(m) if err != nil { - return nil, errors.Errorf("gen mInv failed, err: %s", err) + return nil, fmt.Errorf("gen mInv failed, err: %w", err) } mHat, err := minor(m, 0, 0) if err != nil { - return nil, errors.Errorf("gen mHat failed, err: %s", err) + return nil, fmt.Errorf("gen mHat failed, err: %w", err) } mHatInv, err := Invert(mHat) if err != nil { - return nil, errors.Errorf("gen mHatInv failed, err: %s", err) + return nil, fmt.Errorf("gen mHatInv failed, err: %w", err) } mPrime := genPrime(m) mDoublePrime, err := genDoublePrime(m, mHatInv) if err != nil { - return nil, errors.Errorf("gen double prime m failed, err: %s", err) + return nil, fmt.Errorf("gen double prime m failed, err: %w", err) } - return &mdsMatrices{m, mInv, mHat, mHatInv, mPrime, mDoublePrime}, nil + return &mdsMatrices[E]{m, mInv, mHat, mHatInv, mPrime, mDoublePrime}, nil } // generate the matrix m', where m = m'*m''. -func genPrime(m Matrix) Matrix { - prime := make([][]*ff.Element, row(m)) - prime[0] = append(prime[0], one) +func genPrime[E Element[E]](m Matrix[E]) Matrix[E] { + prime := make([][]E, row(m)) + prime[0] = append(prime[0], one[E]()) for i := 1; i < column(m); i++ { - prime[0] = append(prime[0], zero) + prime[0] = append(prime[0], zero[E]()) } for i := 1; i < row(m); i++ { - prime[i] = make([]*ff.Element, column(m)) - prime[i][0] = zero + prime[i] = make([]E, column(m)) + prime[i][0] = zero[E]() for j := 1; j < column(m); j++ { prime[i][j] = m[i][j] } @@ -135,24 +127,24 @@ func genPrime(m Matrix) Matrix { } // generate the matrix m'', where m = m'*m''. -func genDoublePrime(m, mHatInv Matrix) (Matrix, error) { +func genDoublePrime[E Element[E]](m, mHatInv Matrix[E]) (Matrix[E], error) { w, v := genPreVectors(m) wHat, err := LeftMatMul(mHatInv, w) if err != nil { - return nil, errors.Errorf("compute wHat failed, err: %s", err) + return nil, fmt.Errorf("compute WHat failed, err: %w", err) } - doublePrime := make([][]*ff.Element, row(m)) - doublePrime[0] = append([]*ff.Element{m[0][0]}, v...) + doublePrime := make([][]E, row(m)) + doublePrime[0] = append([]E{m[0][0]}, v...) for i := 1; i < row(m); i++ { - doublePrime[i] = make([]*ff.Element, column(m)) + doublePrime[i] = make([]E, column(m)) doublePrime[i][0] = wHat[i-1] for j := 1; j < column(m); j++ { if j == i { - doublePrime[i][j] = one + doublePrime[i][j] = one[E]() } else { - doublePrime[i][j] = zero + doublePrime[i][j] = zero[E]() } } } @@ -161,11 +153,11 @@ func genDoublePrime(m, mHatInv Matrix) (Matrix, error) { } // generate pre-computed vectors used in the sparse matrix. -func genPreVectors(m Matrix) (Vector, Vector) { - v := make([]*ff.Element, column(m)-1) +func genPreVectors[E Element[E]](m Matrix[E]) (Vector[E], Vector[E]) { + v := make([]E, column(m)-1) copy(v, m[0][1:]) - w := make([]*ff.Element, row(m)-1) + w := make([]E, row(m)-1) for i := 1; i < row(m); i++ { w[i-1] = m[i][0] } @@ -174,27 +166,27 @@ func genPreVectors(m Matrix) (Vector, Vector) { } // parseSparseMatrix parses the sparse matrix. -func parseSparseMatrix(m Matrix) (*SparseMatrix, error) { +func parseSparseMatrix[E Element[E]](m Matrix[E]) (*SparseMatrix[E], error) { sub, err := minor(m, 0, 0) if err != nil { - return nil, errors.Errorf("get the sub matrix err: %s", err) + return nil, fmt.Errorf("get the sub matrix err: %w", err) } // m should be the sparse matrix, which has a (t-1)*(t-1) sub identity matrix. if !IsSquareMatrix(m) || !IsIdentity(sub) { - return nil, errors.Errorf("cannot parse the sparse matrix!") + return nil, fmt.Errorf("cannot parse the sparse matrix") } - // wHat is the first column of the sparse matrix. - sparse := new(SparseMatrix) - sparse.wHat = make([]*ff.Element, row(m)) + // WHat is the first column of the sparse matrix. + sparse := new(SparseMatrix[E]) + sparse.WHat = make([]E, row(m)) for i := 0; i < column(m); i++ { - sparse.wHat[i] = m[i][0] + sparse.WHat[i] = m[i][0] } - // v contains all but the first element. - sparse.v = make([]*ff.Element, column(m)-1) - copy(sparse.v, m[0][1:]) + // V contains all but the first element. + sparse.V = make([]E, column(m)-1) + copy(sparse.V, m[0][1:]) return sparse, nil } @@ -207,26 +199,26 @@ func parseSparseMatrix(m Matrix) (*SparseMatrix, error) { // use the sparse matrix m'' as the mds matrix, // then the previous layer's m is replaced by m x m' = m*. // from the last partial round, do the same work to the first partial round. -func genSparseMatrix(m Matrix, rp int) ([]*SparseMatrix, Matrix, error) { - sparses := make([]*SparseMatrix, rp) +func genSparseMatrix[E Element[E]](m Matrix[E], rp int) ([]*SparseMatrix[E], Matrix[E], error) { + sparses := make([]*SparseMatrix[E], rp) preSparse := copyMatrixRows(m, 0, row(m)) for i := 0; i < rp; i++ { mds, err := deriveMatrices(preSparse) if err != nil { - return nil, nil, errors.Errorf("derive mds matrices err: %s", err) + return nil, nil, fmt.Errorf("derive mds matrices err: %w", err) } // m* = m x m' mat, err := MatMul(m, mds.mPrime) if err != nil { - return nil, nil, errors.Errorf("get the previous layer's matrix err: %s", err) + return nil, nil, fmt.Errorf("get the previous layer's matrix err: %w", err) } // parse the sparse matrix by reverse order. sparses[rp-i-1], err = parseSparseMatrix(mds.mDoublePrime) if err != nil { - return nil, nil, errors.Errorf("parse sparse matrix err: %s", err) + return nil, nil, fmt.Errorf("parse sparse matrix err: %w", err) } preSparse = copyMatrixRows(mat, 0, row(mat)) diff --git a/mds_test.go b/mds_test.go index 182d434..eca8cd5 100644 --- a/mds_test.go +++ b/mds_test.go @@ -1,13 +1,16 @@ package poseidon import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/stretchr/testify/assert" ) func TestMDS(t *testing.T) { for i := 2; i < 50; i++ { - mds, err := createMDSMatrix(i) + m := genMDS[*fr.Element](i) + mds, err := deriveMatrices(m) assert.NoError(t, err) mul0, err := MatMul(mds.m, mds.mInv) diff --git a/param.go b/param.go index 538caca..f9a15e1 100644 --- a/param.go +++ b/param.go @@ -1,8 +1,7 @@ package poseidon import ( - "github.com/pkg/errors" - ff "github.com/triplewz/poseidon/bls12_381" + "fmt" "math" "math/big" ) @@ -10,20 +9,16 @@ import ( // security level (in bits) const SecurityLevel int = 128 -// for bls12_381 modular p, since p ≠ 1 mod 5, we set Alpha = 5. -// see https://eprint.iacr.org/2019/458.pdf page 6. -const Alpha int = 5 - // we refer the rust implement and supplementary material shown in the paper to generate the round numbers. // see https://extgit.iaik.tugraz.at/krypto/hadeshash. -func calcRoundNumbers(t int, securityMargin bool) (rf, rp int) { +func calcRoundNumbers[E Element[E]](t int, securityMargin bool) (rf, rp int) { rf, rp = 0, 0 min := math.MaxInt64 // Brute-force approach for rft := 2; rft <= 1000; rft += 2 { for rpt := 4; rpt < 200; rpt++ { - if isRoundNumberSecure(t, rft, rpt) { + if isRoundNumberSecure[E](t, rft, rpt) { // https://eprint.iacr.org/2019/458.pdf page 9. if securityMargin { rft += 2 @@ -43,9 +38,9 @@ func calcRoundNumbers(t int, securityMargin bool) (rf, rp int) { } // isRoundNumberSecure determines if the round numbers are secure. -func isRoundNumberSecure(t, rf, rp int) bool { +func isRoundNumberSecure[E Element[E]](t, rf, rp int) bool { // n is the number of bits of p. - n := ff.Bits + n := Bits[E]() // Statistical Attacks // https://eprint.iacr.org/2019/458.pdf page 10. @@ -87,7 +82,7 @@ func appendBits(bits []byte, n, size int) []byte { // genNewBits generates new 80-bits slice and returns the newly generated bit. func genNewBits(bits []byte) byte { - newBit := byte(bits[0] ^ bits[13] ^ bits[23] ^ bits[38] ^ bits[51] ^ bits[62]) + newBit := bits[0] ^ bits[13] ^ bits[23] ^ bits[38] ^ bits[51] ^ bits[62] newBits := append(bits, newBit) copy(bits, newBits[1:]) return newBit @@ -113,15 +108,16 @@ func nextByte(bits []byte, bitCount int) byte { } // getBytes generates a random byte slice. -func getBytes(bits []byte, fieldsize int) []byte { +func getBytes[E Element[E]](bits []byte, fieldsize int) []byte { // Only prime fields are supported, and they always have reminder bits. remainderBits := fieldsize % 8 - buf := make([]byte, ff.Bytes) + bytes := Bytes[E]() + buf := make([]byte, bytes) buf[0] = nextByte(bits, remainderBits) // The first byte is already set. - for i := 1; i < ff.Bytes; i++ { + for i := 1; i < bytes; i++ { buf[i] = nextByte(bits, 8) } @@ -148,7 +144,7 @@ func getBytes(bits []byte, fieldsize int) []byte { // parameters (e.g., n and t) are the same. // Note that cryptographically strong randomness is not needed for the // round constants, and other methods can also be used. -func genRoundConstants(field, sbox int, fieldsize, t, rf, rp int) []*ff.Element { +func genRoundConstants[E Element[E]](field, sbox int, fieldsize, t, rf, rp int) []E { numCons := (rf + rp) * t var bits []byte @@ -164,14 +160,14 @@ func genRoundConstants(field, sbox int, fieldsize, t, rf, rp int) []*ff.Element genNewBits(bits) } - roundConsts := make([]*ff.Element, numCons) + roundConsts := make([]E, numCons) for i := 0; i < numCons; i++ { for { - buf := getBytes(bits, fieldsize) + buf := getBytes[E](bits, fieldsize) bufBigint := new(big.Int).SetBytes(buf) // Skip all buffers that would result in invalid field elements. - if ff.IsValid(bufBigint) { - roundConsts[i] = new(ff.Element).SetBytes(buf) + if IsValid[E](bufBigint) { + roundConsts[i] = NewElement[E]().SetBytes(buf) break } } @@ -180,12 +176,16 @@ func genRoundConstants(field, sbox int, fieldsize, t, rf, rp int) []*ff.Element return roundConsts } +func IsValid[E Element[E]](z *big.Int) bool { + return z.Cmp(Modulus[E]()) == -1 +} + // compress constants by pushing them back through linear layers and through the identity components of partial layers. // as a result, constants need only be added after each S-box. // see https://eprint.iacr.org/2019/458.pdf page 20. // in our implementation, we compress all constants in partial rounds. -func genCompressedRoundConstants(width, rf, rp int, roundConstants []*ff.Element, mds *mdsMatrices) ([]*ff.Element, error) { - comRoundConstants := make([]*ff.Element, rf*width+rp) +func genCompressedRoundConstants[E Element[E]](width, rf, rp int, roundConstants []E, mds *mdsMatrices[E]) ([]E, error) { + comRoundConstants := make([]E, rf*width+rp) mInv := mds.mInv // first round constants @@ -197,7 +197,7 @@ func genCompressedRoundConstants(width, rf, rp int, roundConstants []*ff.Element nextRound := roundConstants[(i+1)*width : (i+2)*width] inv, err := RightMatMul(nextRound, mInv) if err != nil { - return nil, errors.Errorf("full round constants mul err: %s", err) + return nil, fmt.Errorf("full round constants mul err: %w", err) } copy(comRoundConstants[(i+1)*width:(i+2)*width], inv) } @@ -206,28 +206,28 @@ func genCompressedRoundConstants(width, rf, rp int, roundConstants []*ff.Element lastPartialRound := rf/2 + rp lastPartialRoundKey := roundConstants[lastPartialRound*width : (lastPartialRound+1)*width] - partialKeys := make([]*ff.Element, rp) - roundAcc := make([]*ff.Element, width) - preRoundKeys := make([]*ff.Element, width) + partialKeys := make([]E, rp) + roundAcc := make([]E, width) + preRoundKeys := make([]E, width) copy(roundAcc, lastPartialRoundKey) for i := 0; i < rp; i++ { inv, err := RightMatMul(roundAcc, mInv) if err != nil { - return nil, errors.Errorf("partial key err: %s", err) + return nil, fmt.Errorf("partial key err: %w", err) } partialKeys[i] = inv[0] - inv[0] = zero + inv[0] = zero[E]() copy(preRoundKeys, roundConstants[(lastPartialRound-i-1)*width:(lastPartialRound-i)*width]) roundAcc, err = VecAdd(preRoundKeys, inv) if err != nil { - return nil, errors.Errorf("round accumulated err: %s", err) + return nil, fmt.Errorf("round accumulated err: %w", err) } } // the accumulated result. acc, err := RightMatMul(roundAcc, mInv) if err != nil { - return nil, errors.Errorf("last round key err: %s", err) + return nil, fmt.Errorf("last round key err: %w", err) } copy(comRoundConstants[(rf/2)*width:(rf/2+1)*width], acc) @@ -241,7 +241,7 @@ func genCompressedRoundConstants(width, rf, rp int, roundConstants []*ff.Element constants := roundConstants[(rf/2+rp+i)*width : (rf/2+rp+i+1)*width] inv, err := RightMatMul(constants, mInv) if err != nil { - return nil, errors.Errorf("final full round key err: %s", err) + return nil, fmt.Errorf("final full round key err: %w", err) } copy(comRoundConstants[(rf/2+i)*width+rp:(rf/2+i+1)*width+rp], inv) } diff --git a/param_test.go b/param_test.go index 0798446..dfe6af6 100644 --- a/param_test.go +++ b/param_test.go @@ -1,8 +1,10 @@ package poseidon import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/stretchr/testify/assert" ) func TestCalcRoundNum(t *testing.T) { @@ -52,7 +54,7 @@ func TestCalcRoundNum(t *testing.T) { } for _, cases := range tests { - getRf, getRp := calcRoundNumbers(cases.t, cases.s) + getRf, getRp := calcRoundNumbers[*fr.Element](cases.t, cases.s) assert.Equal(t, getRf, cases.want.rf) assert.Equal(t, getRp, cases.want.rp) } @@ -71,7 +73,7 @@ func TestGenRoundConstants(t *testing.T) { } for _, cases := range tests { - get := genRoundConstants(1, 1, 255, cases.t, cases.rf, cases.rp) + get := genRoundConstants[*fr.Element](1, 1, 255, cases.t, cases.rf, cases.rp) assert.Equal(t, len(get), cases.want) } } @@ -89,8 +91,9 @@ func TestGenCompressedRoundConstants(t *testing.T) { } for _, cases := range tests { - roundContants := genRoundConstants(1, 1, 255, cases.t, cases.rf, cases.rp) - mds, _ := createMDSMatrix(cases.t) + roundContants := genRoundConstants[*fr.Element](1, 1, 255, cases.t, cases.rf, cases.rp) + m := genMDS[*fr.Element](cases.t) + mds, _ := deriveMatrices(m) comRoundContantsm, err := genCompressedRoundConstants(cases.t, cases.rf, cases.rp, roundContants, mds) assert.NoError(t, err) diff --git a/poseidon.go b/poseidon.go index a27b4bf..2bb343f 100644 --- a/poseidon.go +++ b/poseidon.go @@ -1,20 +1,19 @@ package poseidon import ( - "github.com/pkg/errors" - ff "github.com/triplewz/poseidon/bls12_381" + "fmt" "math/big" ) -type PoseidonConst struct { - Mds *mdsMatrices - RoundConsts []*ff.Element - ComRoundConts []*ff.Element - PreSparse Matrix - Sparse []*SparseMatrix - FullRounds int - HalfFullRounds int - PartialRounds int +type PoseidonConst[E Element[E]] struct { + Mds *mdsMatrices[E] + RoundConsts []E + CompRoundConsts []E + PreSparse Matrix[E] + Sparse []*SparseMatrix[E] + FullRounds int + HalfFullRounds int + PartialRounds int } // provide three hash modes. @@ -36,12 +35,15 @@ var PoseidonExp = new(big.Int).SetUint64(5) // we refer the rust implement (OptimizedStatic mode), see https://github.com/filecoin-project/neptune. // the input length is a slice of big integers. // the output of poseidon hash is a big integer. -func Hash(input []*big.Int, pdsContants *PoseidonConst, hash HashMode) (*big.Int, error) { - state := bigToElement(input) +func Hash[E Element[E]](input []*big.Int, pdsContants *PoseidonConst[E], hash HashMode) (*big.Int, error) { + state := bigToElement[E](input) // Neptune (a Rust implementation of Poseidon) is using domain tag 0x3 by default. - domain_tag := new(ff.Element).SetString("3") - state = append([]*ff.Element{domain_tag}, state...) + domain_tag, err := NewElement[E]().SetString("3") + if err != nil { + return nil, err + } + state = append([]E{domain_tag}, state...) //pdsContants, err := genPoseidonConstants(t) //if err != nil { @@ -61,52 +63,59 @@ func Hash(input []*big.Int, pdsContants *PoseidonConst, hash HashMode) (*big.Int } // generate poseidon constants used in the poseidon hash. -func GenPoseidonConstants(width int) (*PoseidonConst, error) { +func GenPoseidonConstants[E Element[E]](width int) (*PoseidonConst[E], error) { // round numbers. - rf, rp := calcRoundNumbers(width, true) + rf, rp := calcRoundNumbers[E](width, true) if rf%2 != 0 { - return nil, errors.Errorf("full rounds should be even!") + return nil, fmt.Errorf("full rounds should be even") } + + // generate mds matrix + mds := genMDS[E](width) + + return GenCustomPoseidonConstants[E](width, 1, 1, rf, rp, mds) +} + +func GenCustomPoseidonConstants[E Element[E]](width, field, sbox, rf, rp int, mds Matrix[E]) (*PoseidonConst[E], error) { half := rf / 2 - // round constants. - constants := genRoundConstants(1, 1, ff.Bits, width, rf, rp) + constants := genRoundConstants[E](field, sbox, Bits[E](), width, rf, rp) // mds matrices. - mds, err := createMDSMatrix(width) + mdsm, err := deriveMatrices(mds) if err != nil { - return nil, errors.Errorf("create mds matrix err: %s", err) + return nil, fmt.Errorf("create mds matrix err: %w", err) } // compressed round constants. - compress, err := genCompressedRoundConstants(width, rf, rp, constants, mds) + compress, err := genCompressedRoundConstants(width, rf, rp, constants, mdsm) if err != nil { - return nil, errors.Errorf("generate compressed round constants err: %s", err) + return nil, fmt.Errorf("generate compressed round constants err: %w", err) } // sparse and pre-sparse matrices. - sparse, preSparse, err := genSparseMatrix(mds.m, rp) + sparse, preSparse, err := genSparseMatrix(mdsm.m, rp) if err != nil { - return nil, errors.Errorf("generate sparse matrix err: %s", err) + return nil, fmt.Errorf("generate sparse matrix err: %w", err) } - return &PoseidonConst{ - Mds: mds, - RoundConsts: constants, - ComRoundConts: compress, - PreSparse: preSparse, - Sparse: sparse, - FullRounds: rf, - PartialRounds: rp, - HalfFullRounds: half, + return &PoseidonConst[E]{ + Mds: mdsm, + RoundConsts: constants, + CompRoundConsts: compress, + PreSparse: preSparse, + Sparse: sparse, + FullRounds: rf, + PartialRounds: rp, + HalfFullRounds: half, }, nil } -func optimizedStaticHash(state []*ff.Element, pdsConsts *PoseidonConst) (*big.Int, error) { +func optimizedStaticHash[E Element[E]](state []E, pdsConsts *PoseidonConst[E]) (*big.Int, error) { t := len(state) // The first full round should use the initial constants. for i := 0; i < t; i++ { - state[i].Add(state[i], pdsConsts.ComRoundConts[i]) + state[i].Add(state[i], pdsConsts.CompRoundConsts[i]) } // do the first half full rounds @@ -129,12 +138,12 @@ func optimizedStaticHash(state []*ff.Element, pdsConsts *PoseidonConst) (*big.In // output state[1] h := new(big.Int) - state[1].ToBigIntRegular(h) + state[1].BigInt(h) return h, nil } -func optimizedDynamicHash(state []*ff.Element, pdsConsts *PoseidonConst) (*big.Int, error) { +func optimizedDynamicHash[E Element[E]](state []E, pdsConsts *PoseidonConst[E]) (*big.Int, error) { t := len(state) // The first full round should use the initial constants. state = dynamicFullRounds(state, true, true, 0, pdsConsts) @@ -154,12 +163,12 @@ func optimizedDynamicHash(state []*ff.Element, pdsConsts *PoseidonConst) (*big.I // output state[1] h := new(big.Int) - state[1].ToBigIntRegular(h) + state[1].BigInt(h) return h, nil } -func correctHash(state []*ff.Element, pdsConsts *PoseidonConst) (*big.Int, error) { +func correctHash[E Element[E]](state []E, pdsConsts *PoseidonConst[E]) (*big.Int, error) { t := len(state) // do the first half full rounds. @@ -179,13 +188,13 @@ func correctHash(state []*ff.Element, pdsConsts *PoseidonConst) (*big.Int, error // output state[1] h := new(big.Int) - state[1].ToBigIntRegular(h) + state[1].BigInt(h) return h, nil } // addRoundConsts adds round constants to the input. -func addRoundConsts(state []*ff.Element, RoundConsts []*ff.Element) []*ff.Element { +func addRoundConsts[E Element[E]](state []E, RoundConsts []E) []E { for i := 0; i < len(state); i++ { state[i].Add(state[i], RoundConsts[i]) } @@ -194,17 +203,18 @@ func addRoundConsts(state []*ff.Element, RoundConsts []*ff.Element) []*ff.Elemen } // sbox computes x^5 mod p -func sbox(e *ff.Element, pre, post *ff.Element) { +func sbox[E Element[E]](e E, pre, post *E) { //if pre is not nil, add round constants before computing the sbox. - if pre != nil { - e.Add(e, pre) + if pre != nil && !isNil(*pre) { + e.Add(e, *pre) } - e.Exp(*e, PoseidonExp) + x := NewElement[E]().Set(e) + Exp(e, x, PoseidonExp) // if post is not nil, add round constants after computing the sbox. - if post != nil { - e.Add(e, post) + if post != nil && !isNil(*post) { + e.Add(e, *post) } } @@ -212,10 +222,10 @@ func sbox(e *ff.Element, pre, post *ff.Element) { // see https://eprint.iacr.org/2019/458.pdf page 6. // The partial round is the same as the full round, with the difference // that we apply the S-Box only to the first element. -func staticPartialRounds(state []*ff.Element, offset int, pdsConsts *PoseidonConst) []*ff.Element { +func staticPartialRounds[E Element[E]](state []E, offset int, pdsConsts *PoseidonConst[E]) []E { // swap the order of the linear layer and the round constant addition, // see https://eprint.iacr.org/2019/458.pdf page 20. - sbox(state[0], nil, pdsConsts.ComRoundConts[offset]) + sbox(state[0], nil, &pdsConsts.CompRoundConsts[offset]) state = productSparseMatrix(state, offset-len(state)*(pdsConsts.HalfFullRounds+1), pdsConsts.Sparse) return state @@ -223,7 +233,7 @@ func staticPartialRounds(state []*ff.Element, offset int, pdsConsts *PoseidonCon // staticFullRounds computes arc->sbox->M, which has full sbox layers, // see https://eprint.iacr.org/2019/458.pdf page 6. -func staticFullRounds(state []*ff.Element, lastRound bool, offset int, pdsConsts *PoseidonConst) []*ff.Element { +func staticFullRounds[E Element[E]](state []E, lastRound bool, offset int, pdsConsts *PoseidonConst[E]) []E { // in the last round, there is no need to add round constants because // we have swapped the order of the linear layer and the round constant addition. // see https://eprint.iacr.org/2019/458.pdf page 20. @@ -233,8 +243,8 @@ func staticFullRounds(state []*ff.Element, lastRound bool, offset int, pdsConsts } } else { for i := 0; i < len(state); i++ { - postKey := pdsConsts.ComRoundConts[offset+i] - sbox(state[i], nil, postKey) + postKey := pdsConsts.CompRoundConsts[offset+i] + sbox(state[i], nil, &postKey) } } @@ -250,7 +260,7 @@ func staticFullRounds(state []*ff.Element, lastRound bool, offset int, pdsConsts } // dynamic partial rounds used in the dynamic hash mode. -func dynamicPartialRounds(state []*ff.Element, pdsContants *PoseidonConst) []*ff.Element { +func dynamicPartialRounds[E Element[E]](state []E, pdsContants *PoseidonConst[E]) []E { // sbox layer. sbox(state[0], nil, nil) @@ -261,10 +271,10 @@ func dynamicPartialRounds(state []*ff.Element, pdsContants *PoseidonConst) []*ff } // dynamic full rounds used in the dynamic hash mode. -func dynamicFullRounds(state []*ff.Element, current, next bool, offset int, pdsContants *PoseidonConst) []*ff.Element { +func dynamicFullRounds[E Element[E]](state []E, current, next bool, offset int, pdsContants *PoseidonConst[E]) []E { t := len(state) - preRoundKeys := make([]*ff.Element, t) - postVec := make([]*ff.Element, t) + preRoundKeys := make([]E, t) + postVec := make([]E, t) // if `current` is true, we need to add the round constants before the sbox layer. if current { @@ -285,17 +295,17 @@ func dynamicFullRounds(state []*ff.Element, current, next bool, offset int, pdsC panic(err) } - postRoundKeys := make([]*ff.Element, t) + postRoundKeys := make([]E, t) copy(postRoundKeys, inv) // sbox layer. for i := 0; i < t; i++ { - sbox(state[i], preRoundKeys[i], postRoundKeys[i]) + sbox(state[i], &preRoundKeys[i], &postRoundKeys[i]) } } else { // sbox layer. for i := 0; i < t; i++ { - sbox(state[i], preRoundKeys[i], nil) + sbox(state[i], &preRoundKeys[i], nil) } } @@ -305,7 +315,7 @@ func dynamicFullRounds(state []*ff.Element, current, next bool, offset int, pdsC } // partial rounds used in the correct hash mode. -func partialRounds(state []*ff.Element, offset int, pdsConsts *PoseidonConst) []*ff.Element { +func partialRounds[E Element[E]](state []E, offset int, pdsConsts *PoseidonConst[E]) []E { // ark. state = addRoundConsts(state, pdsConsts.RoundConsts[offset:offset+len(state)]) @@ -319,10 +329,10 @@ func partialRounds(state []*ff.Element, offset int, pdsConsts *PoseidonConst) [] } // full rounds used in the correct hash mode. -func fullRounds(state []*ff.Element, offset int, pdsConsts *PoseidonConst) []*ff.Element { +func fullRounds[E Element[E]](state []E, offset int, pdsConsts *PoseidonConst[E]) []E { // sbox layer. for i := 0; i < len(state); i++ { - sbox(state[i], pdsConsts.RoundConsts[offset+i], nil) + sbox(state[i], &pdsConsts.RoundConsts[offset+i], nil) } // mixed layer, multiply the elements by the constant MDS matrix. @@ -332,16 +342,16 @@ func fullRounds(state []*ff.Element, offset int, pdsConsts *PoseidonConst) []*ff } // productMdsMatrix computes the product between the elements and the mds matrix. -func productMdsMatrix(state []*ff.Element, mds Matrix) []*ff.Element { +func productMdsMatrix[E Element[E]](state []E, mds Matrix[E]) []E { if len(state) != len(mds) { panic("cannot compute the product !") } - var res []*ff.Element + var res []E for j := 0; j < len(state); j++ { - tmp1 := new(ff.Element) + tmp1 := NewElement[E]() for i := 0; i < len(state); i++ { - tmp2 := new(ff.Element).Mul(state[i], mds[i][j]) + tmp2 := NewElement[E]().Mul(state[i], mds[i][j]) tmp1.Add(tmp1, tmp2) } res = append(res, tmp1) @@ -355,16 +365,16 @@ func productMdsMatrix(state []*ff.Element, mds Matrix) []*ff.Element { } // productPreSparseMatrix computes the product between the elements and the pre-sparse matrix. -func productPreSparseMatrix(state []*ff.Element, preSparseMatrix Matrix) []*ff.Element { +func productPreSparseMatrix[E Element[E]](state []E, preSparseMatrix Matrix[E]) []E { if len(state) != len(preSparseMatrix) { panic("cannot compute the product !") } - var res []*ff.Element + var res []E for j := 0; j < len(state); j++ { - tmp1 := new(ff.Element) + tmp1 := NewElement[E]() for i := 0; i < len(state); i++ { - tmp2 := new(ff.Element).Mul(state[i], preSparseMatrix[i][j]) + tmp2 := NewElement[E]().Mul(state[i], preSparseMatrix[i][j]) tmp1.Add(tmp1, tmp2) } res = append(res, tmp1) @@ -374,7 +384,7 @@ func productPreSparseMatrix(state []*ff.Element, preSparseMatrix Matrix) []*ff.E } // productSparseMatrix computes the product between the elements and the sparse matrix. -func productSparseMatrix(state []*ff.Element, offset int, sparse []*SparseMatrix) []*ff.Element { +func productSparseMatrix[E Element[E]](state []E, offset int, sparse []*SparseMatrix[E]) []E { // this part is described in https://eprint.iacr.org/2019/458.pdf page 20. // the sparse matrix M'' consists of: // @@ -388,16 +398,16 @@ func productSparseMatrix(state []*ff.Element, offset int, sparse []*SparseMatrix // we can first compute ret[0] = state * [M_00, w_hat], // then for 1 <= i < t, // compute ret[i] = state[0] * v[i-1] + state[i]. - res := make([]*ff.Element, len(state)) - res[0] = new(ff.Element) + res := make([]E, len(state)) + res[0] = NewElement[E]() for i := 0; i < len(state); i++ { - tmp := new(ff.Element).Mul(state[i], sparse[offset].wHat[i]) + tmp := NewElement[E]().Mul(state[i], sparse[offset].WHat[i]) res[0].Add(res[0], tmp) } for i := 1; i < len(state); i++ { - tmp := new(ff.Element).Mul(state[0], sparse[offset].v[i-1]) - res[i] = new(ff.Element).Add(state[i], tmp) + tmp := NewElement[E]().Mul(state[0], sparse[offset].V[i-1]) + res[i] = NewElement[E]().Add(state[i], tmp) } return res diff --git a/poseidon_test.go b/poseidon_test.go index bad95b9..248adee 100644 --- a/poseidon_test.go +++ b/poseidon_test.go @@ -2,11 +2,12 @@ package poseidon import ( "encoding/json" - "github.com/stretchr/testify/assert" - ff "github.com/triplewz/poseidon/bls12_381" "math/big" "os" "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/stretchr/testify/assert" ) func TestPoseidonConstans(t *testing.T) { @@ -38,29 +39,29 @@ func TestPoseidonConstans(t *testing.T) { assert.NoError(t, err) // compressed round constants - comRoundConstants := hexToElement(strs.CompressedRoundConstants) + comRoundConstants := hexToElement[*fr.Element](strs.CompressedRoundConstants) // round constants - roundConstants := hexToElement(strs.RoundConstants) + roundConstants := hexToElement[*fr.Element](strs.RoundConstants) // mds matrix - mdsMatrix := make([][]*ff.Element, len(strs.Mds)) + mdsMatrix := make([][]*fr.Element, len(strs.Mds)) for i := 0; i < len(strs.Mds); i++ { - mdsMatrix[i] = hexToElement(strs.Mds[i]) + mdsMatrix[i] = hexToElement[*fr.Element](strs.Mds[i]) } // pre-sparse matrix - preSparseMatrix := make([][]*ff.Element, len(strs.PreSparse)) + preSparseMatrix := make([][]*fr.Element, len(strs.PreSparse)) for i := 0; i < len(strs.PreSparse); i++ { - preSparseMatrix[i] = hexToElement(strs.PreSparse[i]) + preSparseMatrix[i] = hexToElement[*fr.Element](strs.PreSparse[i]) } // sparse matrix - sparseMatrix := make([][][]*ff.Element, len(strs.Sparse)) + sparseMatrix := make([][][]*fr.Element, len(strs.Sparse)) for i := 0; i < len(strs.Sparse); i++ { - sparseMatrix[i] = make([][]*ff.Element, len(strs.Sparse[i])) + sparseMatrix[i] = make([][]*fr.Element, len(strs.Sparse[i])) for j := 0; j < len(strs.Sparse[i]); j++ { - sparseMatrix[i][j] = hexToElement(strs.Sparse[i][j]) + sparseMatrix[i][j] = hexToElement[*fr.Element](strs.Sparse[i][j]) } } @@ -81,7 +82,7 @@ func TestPoseidonConstans(t *testing.T) { } for i := 0; i < 57; i++ { - if !IsVecEqual(sparseMatrix[i][0], sparse[i].wHat) || !IsVecEqual(sparseMatrix[i][1], sparse[i].v) { + if !IsVecEqual(sparseMatrix[i][0], sparse[i].WHat) || !IsVecEqual(sparseMatrix[i][1], sparse[i].V) { t.Error("got wrong sparse matrix!") return } @@ -106,7 +107,7 @@ var strs = [][]string{ func TestPoseidonHash(t *testing.T) { for i := 0; i < len(strs); i++ { - cons, _ := GenPoseidonConstants(len(strs[i]) + 1) + cons, _ := GenPoseidonConstants[*fr.Element](len(strs[i]) + 1) input := hexToBig(strs[i]) h1, _ := Hash(input, cons, OptimizedStatic) h2, _ := Hash(input, cons, OptimizedDynamic) @@ -117,15 +118,15 @@ func TestPoseidonHash(t *testing.T) { } func TestPoseidonHashFixed(t *testing.T) { - cons, _ := GenPoseidonConstants(3) + cons, _ := GenPoseidonConstants[*fr.Element](3) input := []*big.Int{big.NewInt(0), big.NewInt(0)} hash, _ := Hash(input, cons, OptimizedStatic) expected, _ := new(big.Int).SetString("48fe0b1331196f6cdb33a7c6e5af61b76fd388e1ef1d3d418be5147f0e4613d4", 16) - assert.Equal(t, hash, expected) + assert.Equal(t, expected, hash) } func benchmarkStatic(b *testing.B, str []string) { - cons, _ := GenPoseidonConstants(len(str) + 1) + cons, _ := GenPoseidonConstants[*fr.Element](len(str) + 1) input := hexToBig(str) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -145,7 +146,7 @@ func BenchmarkOptimizedStaticWith9Inputs(b *testing.B) { benchmarkStatic(b, str func BenchmarkOptimizedStaticWith10Inputs(b *testing.B) { benchmarkStatic(b, strs[9]) } func benchmarkDynamic(b *testing.B, str []string) { - cons, _ := GenPoseidonConstants(len(str) + 1) + cons, _ := GenPoseidonConstants[*fr.Element](len(str) + 1) input := hexToBig(str) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -165,7 +166,7 @@ func BenchmarkOptimizedDynamicWith9Inputs(b *testing.B) { benchmarkDynamic(b, s func BenchmarkOptimizedDynamicWith10Inputs(b *testing.B) { benchmarkDynamic(b, strs[9]) } func benchmarkCorrect(b *testing.B, str []string) { - cons, _ := GenPoseidonConstants(len(str) + 1) + cons, _ := GenPoseidonConstants[*fr.Element](len(str) + 1) input := hexToBig(str) b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/util.go b/util.go index 6cce86c..44ff10f 100644 --- a/util.go +++ b/util.go @@ -1,28 +1,31 @@ package poseidon import ( - ff "github.com/triplewz/poseidon/bls12_381" "math/big" ) // hexToElement converts hex-strings to finite field elements -func hexToElement(hex []string) []*ff.Element { - elementArray := make([]*ff.Element, len(hex)) +func hexToElement[E Element[E]](hex []string) []E { + elementArray := make([]E, len(hex)) for i := 0; i < len(hex); i++ { - elementArray[i] = new(ff.Element) - elementArray[i].SetHexString(hex[i]) + elementArray[i] = NewElement[E]() + b, ok := new(big.Int).SetString(hex[i], 16) + if !ok { + panic("Element.SetString failed -> can't parse number in base16 into a big.Int") + } + elementArray[i].SetBigInt(b) } return elementArray } // bigToElement converts big integers to finite field elements -func bigToElement(big []*big.Int) []*ff.Element { - elementArray := make([]*ff.Element, len(big)) +func bigToElement[E Element[E]](big []*big.Int) []E { + elementArray := make([]E, len(big)) for i := 0; i < len(big); i++ { - elementArray[i] = new(ff.Element) + elementArray[i] = NewElement[E]() elementArray[i].SetBigInt(big[i]) }