Skip to content

Commit

Permalink
Merge pull request #50 from 0xPolygonHermez/feature/fix-modexp-sha256
Browse files Browse the repository at this point in the history
Fix counters modexp & add sha256
  • Loading branch information
laisolizq authored Nov 21, 2023
2 parents 894b2eb + 6c3652d commit 0b12c56
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 54 deletions.
65 changes: 41 additions & 24 deletions main/modexp/modexp_utils.zkasm
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,25 @@ modexp_getBase:
B :MSTORE(tmpVarBmodexp)
C :MSTORE(tmpVarCmodexp)
D :MSTORE(tmpVarDmodexp)
; offset init
E :MSTORE(offsetInitModexp)
E + C => E ;E = offset final
;E = offset final
E + C => E
0 :MSTORE(modExpArrayIndex)
0 :MSTORE(modexp_Blen)
32 :MSTORE(readXFromCalldataLength)

modexp_getBaseLoop:

%MAX_CNT_BINARY - CNT_BINARY - 6 :JMPN(outOfCountersBinary)
%MAX_CNT_STEPS - STEP - 100 :JMPN(outOfCountersStep)

; C length to read
C => A
0 => B
; if C (length) == 0 --> modexp_saveBaseLen
$ :EQ,JMPC(modexp_saveBaseLen)
32 => B
; if C (length) < 32 --> modexp_getBaseMloadX
$ :LT,JMPC(modexp_getBaseMloadX)
E - 32 => E
E :MSTORE(readXFromCalldataOffset), CALL(readFromCalldataOffset); in: [readXFromCalldataOffset: offset value, readXFromCalldataLength: length value], out: [readXFromCalldataResult: result value]
Expand All @@ -44,30 +49,32 @@ modexp_getBaseMloadX:
0 => C

modexp_getBaseMstore:
; mstore base at index E
E :MSTORE(tmpVarEmodexp)
A => E
$ => B :MLOAD(modExpArrayIndex)
$ => A :ADD
0 => B
$ :EQ,JMPC(modexp_getBaseFinal)
E => A
$ => E :MLOAD(modExpArrayIndex)
A :MSTORE(modexp_B+E)
; update modExpArrayIndex + 1
E + 1 => B :MSTORE(modExpArrayIndex)

modexp_getBaseFinal:
$ => E :MLOAD(tmpVarEmodexp),JMP(modexp_getBaseLoop)

modexp_saveBaseLen:

%MAX_CNT_BINARY - CNT_BINARY - 2 :JMPN(outOfCountersBinary)
%MAX_CNT_STEPS - STEP - 20 :JMPN(outOfCountersStep)

; if modExpArrayIndex == 0 --> modexp_getReturn
$ => A :MLOAD(modExpArrayIndex)
0 => B
$ :EQ,JMPC(modexp_getReturn)
; update modExpArrayIndex = modExpArrayIndex - 1
A - 1 => E :MSTORE(modExpArrayIndex)
; get value of the last index
$ => A :MLOAD(modexp_B + E)
; if last value == 0 --> modexp_saveBaseLen
$ :EQ,JMPC(modexp_saveBaseLen)
; else Blen == modExpArrayIndex + 1
E + 1 :MSTORE(modexp_Blen),JMP(modexp_getReturn)

modexp_getExp:
Expand All @@ -79,8 +86,10 @@ modexp_getExp:
B :MSTORE(tmpVarBmodexp)
C :MSTORE(tmpVarCmodexp)
D :MSTORE(tmpVarDmodexp)
; offset init
E :MSTORE(offsetInitModexp)
E + C => E ;E = offset final
;E = offset final
E + C => E
0 :MSTORE(modExpArrayIndex)
0 :MSTORE(modexp_Elen)
32 :MSTORE(readXFromCalldataLength)
Expand All @@ -89,11 +98,13 @@ modexp_getExpLoop:

%MAX_CNT_BINARY - CNT_BINARY - 6 :JMPN(outOfCountersBinary)
%MAX_CNT_STEPS - STEP - 100 :JMPN(outOfCountersStep)

; C length to read
C => A
0 => B
; if C (length) == 0 --> modexp_saveExpLen
$ :EQ,JMPC(modexp_saveExpLen)
32 => B
; if C (length) < 32 --> modexp_getExpMloadX
$ :LT,JMPC(modexp_getExpMloadX)
E - 32 => E
E :MSTORE(readXFromCalldataOffset), CALL(readFromCalldataOffset); in: [readXFromCalldataOffset: offset value, readXFromCalldataLength: length value], out: [readXFromCalldataResult: result value]
Expand All @@ -109,15 +120,11 @@ modexp_getExpMloadX:
0 => C

modexp_getExpMstore:
; mstore exp at index E
E :MSTORE(tmpVarEmodexp)
A => E
$ => B :MLOAD(modExpArrayIndex)
$ => A :ADD
0 => B
$ :EQ,JMPC(modexp_getExpFinal)
E => A
$ => E :MLOAD(modExpArrayIndex)
A :MSTORE(modexp_E+E)
; update modExpArrayIndex + 1
E + 1 => B :MSTORE(modExpArrayIndex)

modexp_getExpFinal:
Expand All @@ -128,12 +135,17 @@ modexp_saveExpLen:
%MAX_CNT_BINARY - CNT_BINARY - 2 :JMPN(outOfCountersBinary)
%MAX_CNT_STEPS - STEP - 20 :JMPN(outOfCountersStep)

; if modExpArrayIndex == 0 --> modexp_getReturn
$ => A :MLOAD(modExpArrayIndex)
0 => B
$ :EQ,JMPC(modexp_getReturn)
; update modExpArrayIndex = modExpArrayIndex - 1
A - 1 => E :MSTORE(modExpArrayIndex)
; get value of the last index
$ => A :MLOAD(modexp_E + E)
; if last value == 0 --> modexp_saveExpLen
$ :EQ,JMPC(modexp_saveExpLen)
; else Elen == modExpArrayIndex + 1
E + 1 :MSTORE(modexp_Elen),JMP(modexp_getReturn)

modexp_getMod:
Expand All @@ -145,8 +157,10 @@ modexp_getMod:
B :MSTORE(tmpVarBmodexp)
C :MSTORE(tmpVarCmodexp)
D :MSTORE(tmpVarDmodexp)
; offset init
E :MSTORE(offsetInitModexp)
E + C => E ;E = offset final
;E = offset final
E + C => E
0 :MSTORE(modExpArrayIndex)
0 :MSTORE(modexp_Mlen)
32 :MSTORE(readXFromCalldataLength)
Expand All @@ -155,11 +169,13 @@ modexp_getModLoop:

%MAX_CNT_BINARY - CNT_BINARY - 6 :JMPN(outOfCountersBinary)
%MAX_CNT_STEPS - STEP - 100 :JMPN(outOfCountersStep)

; C length to read
C => A
0 => B
; if C (length) == 0 --> modexp_saveModLen
$ :EQ,JMPC(modexp_saveModLen)
32 => B
; if C (length) < 32 --> modexp_getModMloadX
$ :LT,JMPC(modexp_getModMloadX)
E - 32 => E
E :MSTORE(readXFromCalldataOffset), CALL(readFromCalldataOffset); in: [readXFromCalldataOffset: offset value, readXFromCalldataLength: length value], out: [readXFromCalldataResult: result value]
Expand All @@ -175,15 +191,11 @@ modexp_getModMloadX:
0 => C

modexp_getModMstore:
; mstore mod at index E
E :MSTORE(tmpVarEmodexp)
A => E
$ => B :MLOAD(modExpArrayIndex)
$ => A :ADD
0 => B
$ :EQ,JMPC(modexp_getModFinal)
E => A
$ => E :MLOAD(modExpArrayIndex)
A :MSTORE(modexp_M+E)
; update modExpArrayIndex + 1
E + 1 => B :MSTORE(modExpArrayIndex)

modexp_getModFinal:
Expand All @@ -194,12 +206,17 @@ modexp_saveModLen:
%MAX_CNT_BINARY - CNT_BINARY - 2 :JMPN(outOfCountersBinary)
%MAX_CNT_STEPS - STEP - 20 :JMPN(outOfCountersStep)

; if modExpArrayIndex == 0 --> modexp_getReturn
$ => A :MLOAD(modExpArrayIndex)
0 => B
$ :EQ,JMPC(modexp_getReturn)
; update modExpArrayIndex = modExpArrayIndex - 1
A - 1 => E :MSTORE(modExpArrayIndex)
; get value of the last index
$ => A :MLOAD(modexp_M + E)
; if last value == 0 --> modexp_saveModLen
$ :EQ,JMPC(modexp_saveModLen)
; else Mlen == modExpArrayIndex + 1
E + 1 :MSTORE(modexp_Mlen),JMP(modexp_getReturn)

modexp_getReturn:
Expand Down
5 changes: 2 additions & 3 deletions main/precompiled/pre-ecAdd.zkasm
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
* - stack output: [x, y]
*/
funcEcAdd:

%MAX_CNT_BINARY - CNT_BINARY - 10 :JMPN(outOfCountersBinary)
%MAX_CNT_ARITH - CNT_ARITH - 50 :JMPN(outOfCountersArith)
%MAX_CNT_STEPS - STEP - 500 :JMPN(outOfCountersStep)
%MAX_CNT_BINARY - CNT_BINARY - 50 :JMPN(outOfCountersBinary)
%MAX_CNT_STEPS - STEP - 800 :JMPN(outOfCountersStep)

; Move balances if value > 0 just before executing the contract CALL
$ => B :MLOAD(txValue)
Expand Down
69 changes: 45 additions & 24 deletions main/precompiled/pre-modexp.zkasm
Original file line number Diff line number Diff line change
Expand Up @@ -64,32 +64,41 @@ funcModexp:
; store offset modexp num values
E + 32 => E
E :MSTORE(modexp_offset)
; get exp offset = 96 bytes (Bsize | Esize | Msize) + Bsize
$ => A :MLOAD(modexp_Bsize)
A :MSTORE(arithA)
96 :MSTORE(arithB),CALL(addARITH)
$ => A :MLOAD(arithRes1)
$ => B :MLOAD(txCalldataLen)
; expLenBits = bit length of first 32 bytes of exp
0 :MSTORE(expLenBits)
; if 96 + Bsize (exp offset) < txCalldataLen --> setExpBits, else --> expLenBits = 0
$ :LT,JMPC(setExpBits,setMaxLen)

setExpBits:
; E exp offset
A => E
$ => A,C :MLOAD(modexp_Esize)
33 => B
; A, C = Esize
; if Esize <= 32 --> setExpBitsContinue
$ => B :LT,JMPC(setExpBitsContinue)
32 => C

setExpBitsContinue:
; read a length of bytes (C) from exp offset (E)
C :MSTORE(readXFromCalldataLength)
E :MSTORE(readXFromCalldataOffset), CALL(readFromCalldataOffset); in: [readXFromCalldataOffset: offset value, readXFromCalldataLength: length value], out: [readXFromCalldataResult: result value]
$ => A :MLOAD(readXFromCalldataResult)
; A = first 32 bytes of exp
32 - C => D :CALL(SHRarith)
A => B :CALL(getLenBits); A bits length first 32 bytes
; A = bit length of first 32 bytes of exp
A :JMPZ(setMaxLen)
A - 1 :MSTORE(expLenBits)

setMaxLen:
; set B with max length (max(Blen, Mlen))
$ => B :MLOAD(modexp_Msize)
$ => A :MLOAD(modexp_Bsize)
$ :LT, JMPC(calculateGas)
Expand All @@ -107,7 +116,7 @@ calculateGas:
$ => B :MLOAD(arithRes1)
B :MSTORE(arithA)
8 :MSTORE(arithB),CALL(divARITH)
; C: words = (max_length + 7) / 8
; B: words = (max_length + 7) / 8
$ => B :MLOAD(arithRes1)
%MAX_GAS_WORD_MODEXP => A
$ :LT,JMPC(outOfGas)
Expand Down Expand Up @@ -164,17 +173,17 @@ dinamicGas:
%TX_GAS_LIMIT => B
$ :LT,JMPNC(outOfGas)
200 => B
$ :LT,JMPC(callMODEXP)
$ :LT,JMPC(lastChecks)
A => B

callMODEXP:
lastChecks:
; B = max(200, multiplication_complexity * iteration_count / 3)
GAS - B => GAS :JMPN(outOfGas)
0 => A
$ => B :MLOAD(modexp_Msize)
$ :EQ,JMPC(save0outMod0) ; if Msize = 0 --> save0outMod0
$ => B :MLOAD(modexp_Bsize)
$ :EQ,JMPC(save0out) ; if Bsize = 0 --> save0outMod0
$ :EQ,JMPC(save0out) ; if Bsize = 0 --> save0out
%MAX_SIZE_MODEXP => A
$ => B :MLOAD(modexp_Bsize)
$ :LT,JMPC(endMODEXPFail) ; if Bsize > MAX_SIZE_MODEXP --> endMODEXPFail
Expand All @@ -184,35 +193,46 @@ callMODEXP:
$ :LT,JMPC(endMODEXPFail) ; if Msize > MAX_SIZE_MODEXP --> endMODEXPFail
$ => E :MLOAD(modexp_offset)
$ => C :MLOAD(modexp_Bsize)
; get base value
:CALL(modexp_getBase)
$ => C :MLOAD(modexp_Esize)
; get exp value
:CALL(modexp_getExp)
$ => C :MLOAD(modexp_Msize)
; if Msize+offset > MAX_SAFE_INTEGER_MODEXP --> endMODEXPFail
E :MSTORE(arithA)
C :MSTORE(arithB),CALL(addARITH)
$ => B :MLOAD(arithRes1)
%MAX_SAFE_INTEGER_MODEXP => A
$ :LT,JMPC(endMODEXPFail) ; if Msize+offset > MAX_SAFE_INTEGER_MODEXP --> endMODEXPFail
$ :LT,JMPC(endMODEXPFail)
; get mod value
:CALL(modexp_getMod)
$ => A :MLOAD(modexp_Elen),JMPZ(modexpExp0)
$ => A :MLOAD(modexp_Mlen),JMPZ(modexpMod0)
$ => A :MLOAD(modexp_Blen),JMPZ(modexpBase0)
1 => B
; if mod == 0 --> return 0
$ => A :MLOAD(modexp_Mlen),JMPZ(save0out)
; if Mlen != 1 --> checkBaseLen
$ :EQ,JMPNC(checkExpLen)
; if Mlen == 1 && mod == 1 --> return 0
$ => A :MLOAD(modexp_M)
$ :EQ,JMPC(save0out)

checkExpLen:
; if exp == 0 --> return 1
$ => A :MLOAD(modexp_Elen),JMPZ(save1out)

checkBaseLen:
; if base == 0 --> return 0
$ => A :MLOAD(modexp_Blen),JMPZ(save0out)
; if Blen != 1 --> checkExpLen
$ :EQ,JMPNC(callMODEXP)
; if Blen == 1 && base == 1 --> return 1
$ => A :MLOAD(modexp_B)
$ :EQ,JMPC(save1out)

callMODEXP:
:CALL(modexp)
:JMP(finalMODEXP)

modexpBase0:
$ => A :MLOAD(modexp_Elen)
A :JMPZ(save1out)
0 :MSTORE(modexp_out),JMP(finalMODEXP)

modexpExp0:
modexpMod0:
$ => A :MLOAD(modexp_Mlen)
0 :MSTORE(modexp_out)
A :JMPZ(finalMODEXP)
A - 1 :JMPNZ(save1out)
$ => A :MLOAD(modexp_M)
A - 1 :JMPZ(finalMODEXP)
save1out:
1 :MSTORE(modexp_out),JMP(finalMODEXP)

Expand All @@ -223,7 +243,7 @@ save0out:
0 :MSTORE(modexp_out),JMP(finalMODEXP)

save0outMod0:
0 :MSTORE(modexp_out),JMP(endMODEXP)
0 :MSTORE(modexp_out),JMP(preEndMODEXP)

finalMODEXP:

Expand Down Expand Up @@ -283,7 +303,8 @@ endMODEXPFail:
CTX :MSTORE(currentCTX), JMP(preEndFail)

preEndMODEXP:
$ => CTX :MLOAD(originCTX)
$ => A :MLOAD(originCTX), JMPZ(handleGas)
A => CTX

endMODEXP:
CTX :MSTORE(currentCTX), JMP(preEnd)
CTX :MSTORE(currentCTX),JMP(preEnd)
Loading

0 comments on commit 0b12c56

Please sign in to comment.