From 8b03492d88aba19f3c643694701109ee30d95bd1 Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Thu, 11 Aug 2022 14:54:49 +0200 Subject: [PATCH 1/4] core/vm, tests: optimized modexp + fuzzer --- common/math/modexp.go | 72 +++++++++++++++++++++++ core/vm/contracts.go | 8 ++- oss-fuzz.sh | 2 + tests/fuzzers/modexp/modexp-fuzzer.go | 84 +++++++++++++++++++++++++++ 4 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 common/math/modexp.go create mode 100644 tests/fuzzers/modexp/modexp-fuzzer.go diff --git a/common/math/modexp.go b/common/math/modexp.go new file mode 100644 index 000000000000..9c5086ea48c2 --- /dev/null +++ b/common/math/modexp.go @@ -0,0 +1,72 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package math + +import ( + "math/big" + "math/bits" + + "github.com/ethereum/go-ethereum/common" +) + +// FastExp is semantically equivalent to x.Exp(x,y, m), but is faster in +// when the mod is even. +func FastExp(x, y, m *big.Int) *big.Int { + // Split m = m1 × m2 where m1 = 2ⁿ + n := m.TrailingZeroBits() + m1 := new(big.Int).Lsh(common.Big1, n) + mask := new(big.Int).Sub(m1, common.Big1) + m2 := new(big.Int).Rsh(m, n) + + // We want z = x**y mod m. + // z1 = x**y mod m1 = (x**y mod m) mod m1 = z mod m1 + // z2 = x**y mod m2 = (x**y mod m) mod m2 = z mod m2 + z1 := fastExpPow2(x, y, mask) + z2 := new(big.Int).Exp(x, y, m2) + + // Reconstruct z from z1, z2 using CRT, using algorithm from paper, + // which uses only a single modInverse. + // p = (z1 - z2) * m2⁻¹ (mod m1) + // z = z2 + p * m2 + z := new(big.Int).Set(z2) + + // Compute (z1 - z2) mod m1 [m1 == 2**n] into z1. + z1 = z1.And(z1, mask) + z2 = z2.And(z2, mask) + z1 = z1.Sub(z1, z2) + if z1.Sign() < 0 { + z1 = z1.Add(z1, m1) + } + + // Reuse z2 for p = z1 * m2inv. + m2inv := new(big.Int).ModInverse(m2, m1) + z2 = z2.Mul(z1, m2inv) + z2 = z2.And(z2, mask) + + // Reuse z1 for m2 * p. + z = z.Add(z, z1.Mul(z2, m2)) + z = z.Rem(z, m) + + return z +} + +func fastExpPow2(x, y *big.Int, mask *big.Int) *big.Int { + z := big.NewInt(1) + p := new(big.Int).Set(x) + t := new(big.Int) + + for _, b := range y.Bits() { + for i := 0; i < bits.UintSize; i++ { + if b&1 != 0 { + z, t = t.Mul(z, p), z + z = z.And(z, mask) + } + p, t = t.Mul(p, p), p + p = p.And(p, mask) + b >>= 1 + } + } + return z +} diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 1b832b638695..a14f23f1c148 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -385,7 +385,13 @@ func (c *bigModExp) Run(input []byte) ([]byte, error) { // Modulo 0 is undefined, return zero return common.LeftPadBytes([]byte{}, int(modLen)), nil } - return common.LeftPadBytes(base.Exp(base, exp, mod).Bytes(), int(modLen)), nil + var v []byte + if mod.Bit(0) == 0 { // modulo is even + v = math.FastExp(base, exp, mod).Bytes() + } else { + v = base.Exp(base, exp, mod).Bytes() + } + return common.LeftPadBytes(v, int(modLen)), nil } // newCurvePoint unmarshals a binary blob into a bn256 elliptic curve point, diff --git a/oss-fuzz.sh b/oss-fuzz.sh index 745a5ba7c7c0..7f454ff307b4 100644 --- a/oss-fuzz.sh +++ b/oss-fuzz.sh @@ -125,5 +125,7 @@ compile_fuzzer tests/fuzzers/snap FuzzSRange fuzz_storage_range compile_fuzzer tests/fuzzers/snap FuzzByteCodes fuzz_byte_codes compile_fuzzer tests/fuzzers/snap FuzzTrieNodes fuzz_trie_nodes +compile_fuzzer tests/fuzzers/modexp Fuzz fuzzModexp + #TODO: move this to tests/fuzzers, if possible compile_fuzzer crypto/blake2b Fuzz fuzzBlake2b diff --git a/tests/fuzzers/modexp/modexp-fuzzer.go b/tests/fuzzers/modexp/modexp-fuzzer.go new file mode 100644 index 000000000000..0068c5030259 --- /dev/null +++ b/tests/fuzzers/modexp/modexp-fuzzer.go @@ -0,0 +1,84 @@ +// Copyright 2022 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package modexp + +import ( + "fmt" + "math/big" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/math" + "github.com/ethereum/go-ethereum/core/vm" +) + +// The function must return +// 1 if the fuzzer should increase priority of the +// given input during subsequent fuzzing (for example, the input is lexically +// correct and was parsed successfully); +// -1 if the input must not be added to corpus even if gives new coverage; and +// 0 otherwise +// other values are reserved for future use. +func Fuzz(input []byte) int { + if len(input) <= 96 { + return -1 + } + // Abort on too expensive inputs + precomp := vm.PrecompiledContractsBerlin[common.BytesToAddress([]byte{5})] + if gas := precomp.RequiredGas(input); gas > 40_000_000 { + return 0 + } + var ( + baseLen = new(big.Int).SetBytes(getData(input, 0, 32)).Uint64() + expLen = new(big.Int).SetBytes(getData(input, 32, 32)).Uint64() + modLen = new(big.Int).SetBytes(getData(input, 64, 32)).Uint64() + ) + // Handle a special case when both the base and mod length is zero + if baseLen == 0 && modLen == 0 { + return -1 + } + input = input[96:] + // Retrieve the operands and execute the exponentiation + var ( + base = new(big.Int).SetBytes(getData(input, 0, baseLen)) + exp = new(big.Int).SetBytes(getData(input, baseLen, expLen)) + mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen)) + ) + if mod.BitLen() == 0 { + // Modulo 0 is undefined, return zero + return -1 + } + var a = math.FastExp(new(big.Int).Set(base), new(big.Int).Set(exp), new(big.Int).Set(mod)) + var b = base.Exp(base, exp, mod) + if a.Cmp(b) != 0 { + panic(fmt.Sprintf("Inequality %x != %x", a, b)) + } + return 1 +} + +// getData returns a slice from the data based on the start and size and pads +// up to size with zero's. This function is overflow safe. +func getData(data []byte, start uint64, size uint64) []byte { + length := uint64(len(data)) + if start > length { + start = length + } + end := start + size + if end > length { + end = length + } + return common.RightPadBytes(data[start:end], int(size)) +} From e1bd851ee84afa40165f2e4da95892c331d36941 Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Wed, 17 Aug 2022 19:05:13 +0200 Subject: [PATCH 2/4] common/math: modexp optimizations --- common/math/modexp.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/common/math/modexp.go b/common/math/modexp.go index 9c5086ea48c2..b0a32e8c2739 100644 --- a/common/math/modexp.go +++ b/common/math/modexp.go @@ -11,8 +11,8 @@ import ( "github.com/ethereum/go-ethereum/common" ) -// FastExp is semantically equivalent to x.Exp(x,y, m), but is faster in -// when the mod is even. +// FastExp is semantically equivalent to x.Exp(x,y, m), but is faster for even +// modulus. func FastExp(x, y, m *big.Int) *big.Int { // Split m = m1 × m2 where m1 = 2ⁿ n := m.TrailingZeroBits() @@ -54,7 +54,17 @@ func FastExp(x, y, m *big.Int) *big.Int { func fastExpPow2(x, y *big.Int, mask *big.Int) *big.Int { z := big.NewInt(1) + if y.Sign() == 0 { + return z + } p := new(big.Int).Set(x) + p = p.And(p, mask) + if p.Cmp(z) <= 0 { // p <= 1 + return p + } + if y.Cmp(mask) > 0 { + y = new(big.Int).And(y, mask) + } t := new(big.Int) for _, b := range y.Bits() { From 39a88553d12e416f9e1809caf63e3601058da25c Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Fri, 26 Aug 2022 15:06:14 +0200 Subject: [PATCH 3/4] core/vm: special case base 1 in big modexp --- core/vm/contracts.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/core/vm/contracts.go b/core/vm/contracts.go index a14f23f1c148..24f64b68d761 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -380,15 +380,20 @@ func (c *bigModExp) Run(input []byte) ([]byte, error) { base = new(big.Int).SetBytes(getData(input, 0, baseLen)) exp = new(big.Int).SetBytes(getData(input, baseLen, expLen)) mod = new(big.Int).SetBytes(getData(input, baseLen+expLen, modLen)) + v []byte ) - if mod.BitLen() == 0 { + switch { + case mod.BitLen() == 0: // Modulo 0 is undefined, return zero return common.LeftPadBytes([]byte{}, int(modLen)), nil - } - var v []byte - if mod.Bit(0) == 0 { // modulo is even + case base.Cmp(common.Big1) == 0: + //If base == 1, then we can just return base % mod (if mod >= 1, which it is) + v = base.Mod(base, mod).Bytes() + case mod.Bit(0) == 0: + // Modulo is even v = math.FastExp(base, exp, mod).Bytes() - } else { + default: + // Modulo is odd v = base.Exp(base, exp, mod).Bytes() } return common.LeftPadBytes(v, int(modLen)), nil From 56c3ac0602b4a6b7985ee1d6e94b254dea089533 Mon Sep 17 00:00:00 2001 From: Martin Holst Swende Date: Tue, 4 Oct 2022 15:09:09 +0200 Subject: [PATCH 4/4] core/vm: disable fastexp --- core/vm/contracts.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/vm/contracts.go b/core/vm/contracts.go index 24f64b68d761..2e6753b16cb5 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -389,9 +389,9 @@ func (c *bigModExp) Run(input []byte) ([]byte, error) { case base.Cmp(common.Big1) == 0: //If base == 1, then we can just return base % mod (if mod >= 1, which it is) v = base.Mod(base, mod).Bytes() - case mod.Bit(0) == 0: - // Modulo is even - v = math.FastExp(base, exp, mod).Bytes() + //case mod.Bit(0) == 0: + // // Modulo is even + // v = math.FastExp(base, exp, mod).Bytes() default: // Modulo is odd v = base.Exp(base, exp, mod).Bytes()