Skip to content

Commit

Permalink
Merge pull request #197 from zkemail/circuit-fixes
Browse files Browse the repository at this point in the history
Circuit fixes
  • Loading branch information
saleel authored May 27, 2024
2 parents 1b11429 + 9917437 commit 95cd901
Show file tree
Hide file tree
Showing 13 changed files with 134 additions and 835 deletions.
10 changes: 6 additions & 4 deletions packages/circuits/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ Base64Decode: Decodes a base64 encoded string into binary data.
</details>

## Utils
This section provides an overview of utility circom templates available in the `@zk-email/circuits/utils` directory. These templates assist in the construction of zk circuits for various applications beyond the core ZK Email functionalities.
This section provides an overview of utility circom templates available in the `@zk-email/circuits/utils` directory. These templates assist in the construction of ZK circuits for various applications beyond the core ZK Email functionalities.

> Important: When using these templates outside of zk-email, please ensure you read the assumptions on the input signals that are documented above each template source code. You would need to constrain the inputs accordingly before you pass them to these utility circuits.
### `utils/array.circom`

Expand Down Expand Up @@ -276,14 +278,14 @@ Constants: Defines a set of constants used across various circom circuits for st

<details>
<summary>
log2Ceil: Calculates the ceiling of the base 2 logarithm of a given number.
log2Ceil: Calculate log2 of a number and round it up
</summary>

- **[Source](utils/functions.circom#L2-L10)**
- **Inputs**:
- `a`: The input number for which the base 2 logarithm ceiling is to be calculated.
- `a`: The input number for which the `ceil(log2())` needs to be calculated.
- **Outputs**:
- Returns the smallest integer greater than or equal to the base 2 logarithm of the input number.
- Returns `ceil(log2())` of the input number.
</details>


Expand Down
23 changes: 17 additions & 6 deletions packages/circuits/email-verifier.circom
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,13 @@ template EmailVerifier(maxHeadersLength, maxBodyLength, n, k, ignoreBodyHashChec
signal output pubkeyHash;


// Assert emailHeader data after emailHeaderLength are zeros
AssertZeroPadding(maxHeadersLength)(emailHeader, emailHeaderLength + 1);
// Assert `emailHeaderLength` fits in `ceil(log2(maxHeadersLength))`
component n2bHeaderLength = Num2Bits(log2Ceil(maxHeadersLength));
n2bHeaderLength.in <== emailHeaderLength;


// Assert `emailHeader` data after `emailHeaderLength` are zeros
AssertZeroPadding(maxHeadersLength)(emailHeader, emailHeaderLength);


// Calculate SHA256 hash of the `emailHeader` - 506,670 constraints
Expand Down Expand Up @@ -84,8 +89,15 @@ template EmailVerifier(maxHeadersLength, maxBodyLength, n, k, ignoreBodyHashChec
signal input emailBody[maxBodyLength];
signal input emailBodyLength;

// Assert data after the body (maxBodyLength - emailBody.length) is all zeroes
AssertZeroPadding(maxBodyLength)(emailBody, emailBodyLength + 1);

// Assert `emailBodyLength` fits in `ceil(log2(maxBodyLength))`
component n2bBodyLength = Num2Bits(log2Ceil(maxBodyLength));
n2bBodyLength.in <== emailBodyLength;


// Assert data after the body (`maxBodyLength - emailBody.length`) is all zeroes
AssertZeroPadding(maxBodyLength)(emailBody, emailBodyLength);


// Body hash regex - 617,597 constraints
// Extract the body hash from the header (i.e. the part after bh= within the DKIM-signature section)
Expand All @@ -96,8 +108,7 @@ template EmailVerifier(maxHeadersLength, maxBodyLength, n, k, ignoreBodyHashChec
signal bhBase64[shaB64Length] <== SelectRegexReveal(maxHeadersLength, shaB64Length)(bhReveal, bodyHashIndex);
signal headerBodyHash[32] <== Base64Decode(32)(bhBase64);


// Compute SHA256 of email body : 760,142 constraints
// Compute SHA256 of email body : 760,142 constraints (for maxBodyLength = 1536)
// We are using a technique to save constraints by precomputing the SHA hash of the body till the area we want to extract
// It doesn't have an impact on security since a user must have known the pre-image of a signed message to be able to fake it
signal computedBodyHash[256] <== Sha256BytesPartial(maxBodyLength)(emailBody, emailBodyLength, precomputedSHA);
Expand Down
7 changes: 5 additions & 2 deletions packages/circuits/lib/base64.circom
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ include "circomlib/circuits/comparators.circom";

/// @title Base64Decode
/// @notice Decodes a Base64 encoded string to array of bytes.
/// @notice Only support inputs with length = `byteLength` (no 0 padding).
/// @notice It is known that padding char '=' can be replaed with `A` to produce the same output
/// as Base64Lookup returns `0` for both, but a pracical attack from this is unlikely.
/// @param byteLength Byte length of the encoded value - length of the output array.
/// @input in Base64 encoded string.
/// @input in Base64 encoded string; assumes elements to be valid Base64 characters.
/// @output out Decoded array of bytes.
template Base64Decode(byteLength) {
var charLength = 4 * ((byteLength + 2) \ 3); // 4 chars encode 3 bytes
Expand Down Expand Up @@ -63,7 +66,7 @@ template Base64Decode(byteLength) {

/// @title Base64Lookup
/// @notice http://0x80.pl/notesen/2016-01-17-sse-base64-decoding.html#vector-lookup-base
/// @input in input character.
/// @input in input character; assumes input to be valid Base64 character (though constrained implicitly).
/// @output out output bit value.
template Base64Lookup() {
signal input in;
Expand Down
184 changes: 3 additions & 181 deletions packages/circuits/lib/bigint-func.circom
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
pragma circom 2.1.6;

function isNegative(x) {
// half babyjubjub field size
return x > 10944121435919637611123202872628637544274182200208017171849102093287904247808 ? 1 : 0;
}

function div_ceil(m, n) {
var ret = 0;
Expand All @@ -26,14 +22,6 @@ function log_ceil(n) {
return 254;
}

function SplitFn(in, n, m) {
return [in % (1 << n), (in \ (1 << n)) % (1 << m)];
}

function SplitThreeFn(in, n, m, k) {
return [in % (1 << n), (in \ (1 << n)) % (1 << m), (in \ (1 << n + m)) % (1 << k)];
}

// m bits per overflowed register (values are potentially negative)
// n bits per properly-sized register
// in has k registers
Expand Down Expand Up @@ -188,7 +176,7 @@ function long_div(n, k, m, a, b){
}
m -= k;

var remainder[200];
var remainder[100];
for (var i = 0; i < m + k; i++) {
remainder[i] = a[i];
}
Expand Down Expand Up @@ -262,9 +250,9 @@ function short_div(n, k, a, b) {
var scale = (1 << n) \ (1 + b[k - 1]);

// k + 2 registers now
var norm_a[200] = long_scalar_mult(n, k + 1, scale, a);
var norm_a[100] = long_scalar_mult(n, k + 1, scale, a);
// k + 1 registers now
var norm_b[200] = long_scalar_mult(n, k, scale, b);
var norm_b[100] = long_scalar_mult(n, k, scale, b);

var ret;
if (norm_b[k] != 0) {
Expand All @@ -274,169 +262,3 @@ function short_div(n, k, a, b) {
}
return ret;
}

// n bits per register
// a and b both have k registers
// out[0] has length 2 * k
// adapted from BigMulShortLong and LongToShortNoEndCarry2 witness computation
function prod(n, k, a, b) {
// first compute the intermediate values. taken from BigMulShortLong
var prod_val[100]; // length is 2 * k - 1
for (var i = 0; i < 2 * k - 1; i++) {
prod_val[i] = 0;
if (i < k) {
for (var a_idx = 0; a_idx <= i; a_idx++) {
prod_val[i] = prod_val[i] + a[a_idx] * b[i - a_idx];
}
} else {
for (var a_idx = i - k + 1; a_idx < k; a_idx++) {
prod_val[i] = prod_val[i] + a[a_idx] * b[i - a_idx];
}
}
}

// now do a bunch of carrying to make sure registers not overflowed. taken from LongToShortNoEndCarry2
var out[100]; // length is 2 * k

var split[100][3]; // first dimension has length 2 * k - 1
for (var i = 0; i < 2 * k - 1; i++) {
split[i] = SplitThreeFn(prod_val[i], n, n, n);
}

var carry[100]; // length is 2 * k - 1
carry[0] = 0;
out[0] = split[0][0];
if (2 * k - 1 > 1) {
var sumAndCarry[2] = SplitFn(split[0][1] + split[1][0], n, n);
out[1] = sumAndCarry[0];
carry[1] = sumAndCarry[1];
}
if (2 * k - 1 > 2) {
for (var i = 2; i < 2 * k - 1; i++) {
var sumAndCarry[2] = SplitFn(split[i][0] + split[i-1][1] + split[i-2][2] + carry[i-1], n, n);
out[i] = sumAndCarry[0];
carry[i] = sumAndCarry[1];
}
out[2 * k - 1] = split[2*k-2][1] + split[2*k-3][2] + carry[2*k-2];
}
return out;
}

// n bits per register
// a has k registers
// p has k registers
// e has k registers
// k * n <= 500
// p is a prime
// computes a^e mod p
function mod_exp(n, k, a, p, e) {
var eBits[500]; // length is k * n
for (var i = 0; i < k; i++) {
for (var j = 0; j < n; j++) {
eBits[j + n * i] = (e[i] >> j) & 1;
}
}

var out[100]; // length is k
for (var i = 0; i < 100; i++) {
out[i] = 0;
}
out[0] = 1;

// repeated squaring
for (var i = k * n - 1; i >= 0; i--) {
// multiply by a if bit is 0
if (eBits[i] == 1) {
var temp[200]; // length 2 * k
temp = prod(n, k, out, a);
var temp2[2][100];
temp2 = long_div(n, k, k, temp, p);
out = temp2[1];
}

// square, unless we're at the end
if (i > 0) {
var temp[200]; // length 2 * k
temp = prod(n, k, out, out);
var temp2[2][100];
temp2 = long_div(n, k, k, temp, p);
out = temp2[1];
}

}
return out;
}

// n bits per register
// a has k registers
// p has k registers
// k * n <= 500
// p is a prime
// if a == 0 mod p, returns 0
// else computes inv = a^(p-2) mod p
function mod_inv(n, k, a, p) {
var isZero = 1;
for (var i = 0; i < k; i++) {
if (a[i] != 0) {
isZero = 0;
}
}
if (isZero == 1) {
var ret[100];
for (var i = 0; i < k; i++) {
ret[i] = 0;
}
return ret;
}

var pCopy[100];
for (var i = 0; i < 100; i++) {
if (i < k) {
pCopy[i] = p[i];
} else {
pCopy[i] = 0;
}
}

var two[100];
for (var i = 0; i < 100; i++) {
two[i] = 0;
}
two[0] = 2;

var pMinusTwo[100];
pMinusTwo = long_sub(n, k, pCopy, two); // length k
var out[100];
out = mod_exp(n, k, a, pCopy, pMinusTwo);
return out;
}

// a, b and out are all n bits k registers
function long_sub_mod_p(n, k, a, b, p){
var gt = long_gt(n, k, a, b);
var tmp[100];
if(gt){
tmp = long_sub(n, k, a, b);
}
else{
tmp = long_sub(n, k, b, a);
}
var out[2][100];
for(var i = k;i < 2 * k; i++){
tmp[i] = 0;
}
out = long_div(n, k, k, tmp, p);
if(gt==0){
tmp = long_sub(n, k, p, out[1]);
}
return tmp;
}

// a, b, p and out are all n bits k registers
function prod_mod_p(n, k, a, b, p){
var tmp[100];
var result[2][100];
tmp = prod(n, k, a, b);
result = long_div(n, k, k, tmp, p);
return result[1];
}
Loading

0 comments on commit 95cd901

Please sign in to comment.