Skip to content

Commit

Permalink
feat: more arithmetic optimizations (vyperlang#2647)
Browse files Browse the repository at this point in the history
this is a small rewrite of the IR optimizer. it changes the structure of
the binop optimizations so that it is easier to add more optimizations.
it also refactors the `clamp` optimizations to be in terms of an
`assert` statement, so that the clamp conditions can be optimized using
the binop optimizer code.

Co-authored-by: El De-dog-lo <3859395+fubuloubu@users.noreply.github.com>
  • Loading branch information
tserg and fubuloubu committed May 13, 2022
1 parent ed13745 commit 7219a8a
Show file tree
Hide file tree
Showing 25 changed files with 592 additions and 471 deletions.
4 changes: 1 addition & 3 deletions examples/auctions/blind_auction.vy
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,7 @@ def reveal(_numBids: int128, _values: uint256[128], _fakes: bool[128], _secrets:

# Bid was not actually revealed
# Do not refund deposit
if (blindedBid != bidToCheck.blindedBid):
assert 1 == 0
continue
assert blindedBid == bidToCheck.blindedBid

# Add deposit to refund if bid was indeed revealed
refund += bidToCheck.deposit
Expand Down
21 changes: 0 additions & 21 deletions tests/compiler/LLL/test_optimize_lll.py

This file was deleted.

File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def test_ir_compile_fail(bad_ir, get_contract_from_ir, assert_compile_failed):

valid_list = [
["pass"],
["clamplt", ["mload", 0], 300],
["clampgt", ["mload", 0], -1],
["uclampgt", 1, ["mload", 0]],
["uclampge", ["mload", 0], 0],
["assert", ["slt", ["mload", 0], 300]],
["assert", ["sgt", ["mload", 0], -1]],
["assert", ["gt", 1, ["mload", 0]]],
["assert", ["ge", ["mload", 0], 0]],
]


Expand Down
147 changes: 147 additions & 0 deletions tests/compiler/ir/test_optimize_ir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import pytest

from vyper.codegen.ir_node import IRnode
from vyper.exceptions import StaticAssertionException
from vyper.ir import optimizer

optimize_list = [
(["eq", 1, 2], [0]),
(["lt", 1, 2], [1]),
(["eq", "x", 0], ["iszero", "x"]),
# branch pruner
(["if", ["eq", 1, 2], "pass"], ["seq"]),
(["if", ["eq", 1, 1], 3, 4], [3]),
(["if", ["eq", 1, 2], 3, 4], [4]),
(["seq", ["assert", ["lt", 1, 2]]], ["seq"]),
(["seq", ["assert", ["lt", 1, 2]], 2], [2]),
# condition rewriter
(["if", ["eq", "x", "y"], "pass"], ["if", ["iszero", ["sub", "x", "y"]], "pass"]),
(["if", "cond", 1, 0], ["if", ["iszero", "cond"], 0, 1]),
(["assert", ["eq", "x", "y"]], ["assert", ["iszero", ["sub", "x", "y"]]]),
# nesting
(["mstore", 0, ["eq", 1, 2]], ["mstore", 0, 0]),
# conditions
(["ge", "x", 0], [1]), # x >= 0 == True
(["iszero", ["gt", "x", 2 ** 256 - 1]], [1]), # x >= MAX_UINT256 == False
(["iszero", ["sgt", "x", 2 ** 255 - 1]], [1]), # signed x >= MAX_INT256 == False
(["le", "x", 0], ["iszero", "x"]),
(["le", 0, "x"], [1]),
(["lt", "x", 0], [0]),
(["lt", 0, "x"], ["iszero", ["iszero", "x"]]),
(["gt", 5, "x"], ["lt", "x", 5]),
(["ge", 5, "x"], ["le", "x", 5]),
(["lt", 5, "x"], ["gt", "x", 5]),
(["le", 5, "x"], ["ge", "x", 5]),
(["sgt", 5, "x"], ["slt", "x", 5]),
(["sge", 5, "x"], ["sle", "x", 5]),
(["slt", 5, "x"], ["sgt", "x", 5]),
(["sle", 5, "x"], ["sge", "x", 5]),
(["slt", "x", -(2 ** 255)], ["slt", "x", -(2 ** 255)]), # unimplemented
# tricky conditions
(["sgt", 2 ** 256 - 1, 0], [0]), # -1 > 0
(["gt", 2 ** 256 - 1, 0], [1]), # -1 > 0
(["gt", 2 ** 255, 0], [1]), # 0x80 > 0
(["sgt", 2 ** 255, 0], [0]), # 0x80 > 0
(["sgt", 2 ** 255, 2 ** 255 - 1], [0]), # 0x80 > 0x81
(["gt", -(2 ** 255), 2 ** 255 - 1], [1]), # 0x80 > 0x81
(["slt", 2 ** 255, 2 ** 255 - 1], [1]), # 0x80 < 0x7f
(["lt", -(2 ** 255), 2 ** 255 - 1], [0]), # 0x80 < 0x7f
(["sle", -1, 2 ** 256 - 1], [1]), # -1 <= -1
(["sge", -(2 ** 255), 2 ** 255], [1]), # 0x80 >= 0x80
(["sgt", -(2 ** 255), 2 ** 255], [0]), # 0x80 > 0x80
(["slt", 2 ** 255, -(2 ** 255)], [0]), # 0x80 < 0x80
# arithmetic
(["add", "x", 0], ["x"]),
(["add", 0, "x"], ["x"]),
(["sub", "x", 0], ["x"]),
(["sub", "x", "x"], [0]),
(["sub", ["sload", 0], ["sload", 0]], ["sub", ["sload", 0], ["sload", 0]]), # no-op
(["sub", ["callvalue"], ["callvalue"]], ["sub", ["callvalue"], ["callvalue"]]), # no-op
(["mul", "x", 1], ["x"]),
(["div", "x", 1], ["x"]),
(["sdiv", "x", 1], ["x"]),
(["mod", "x", 1], [0]),
(["smod", "x", 1], [0]),
(["mul", "x", -1], ["sub", 0, "x"]),
(["sdiv", "x", -1], ["sub", 0, "x"]),
(["mul", "x", 0], [0]),
(["div", "x", 0], [0]),
(["sdiv", "x", 0], [0]),
(["mod", "x", 0], [0]),
(["smod", "x", 0], [0]),
(["mul", "x", 32], ["shl", 5, "x"]),
(["div", "x", 64], ["shr", 6, "x"]),
(["mod", "x", 128], ["and", "x", 127]),
(["sdiv", "x", 64], ["sdiv", "x", 64]), # no-op
(["smod", "x", 64], ["smod", "x", 64]), # no-op
# bitwise ops
(["shr", 0, "x"], ["x"]),
(["sar", 0, "x"], ["x"]),
(["shl", 0, "x"], ["x"]),
(["and", 1, 2], [0]),
(["or", 1, 2], [3]),
(["xor", 1, 2], [3]),
(["xor", 3, 2], [1]),
(["and", 0, "x"], [0]),
(["and", "x", 0], [0]),
(["or", "x", 0], ["x"]),
(["or", 0, "x"], ["x"]),
(["xor", "x", 0], ["x"]),
(["xor", "x", 1], ["xor", "x", 1]), # no-op
(["and", "x", 1], ["and", "x", 1]), # no-op
(["or", "x", 1], ["or", "x", 1]), # no-op
(["xor", 0, "x"], ["x"]),
(["iszero", ["or", "x", 1]], [0]),
(["iszero", ["or", 2, "x"]], [0]),
# nested optimizations
(["eq", 0, ["sub", 1, 1]], [1]),
(["eq", 0, ["add", 2 ** 255, 2 ** 255]], [1]), # test compile-time wrapping
(["eq", 0, ["add", 2 ** 255, -(2 ** 255)]], [1]), # test compile-time wrapping
(["eq", -1, ["add", 0, -1]], [1]), # test compile-time wrapping
(["eq", -1, ["add", 2 ** 255, 2 ** 255 - 1]], [1]), # test compile-time wrapping
(["eq", -1, ["add", -(2 ** 255), 2 ** 255 - 1]], [1]), # test compile-time wrapping
(["eq", -2, ["add", 2 ** 256 - 1, 2 ** 256 - 1]], [1]), # test compile-time wrapping
]


@pytest.mark.parametrize("ir", optimize_list)
def test_ir_optimizer(ir):
optimized = optimizer.optimize(IRnode.from_list(ir[0]))
optimized.repr_show_gas = True
hand_optimized = IRnode.from_list(ir[1])
hand_optimized.repr_show_gas = True
assert optimized == hand_optimized


static_assertions_list = [
["assert", ["eq", 2, 1]],
["assert", ["ne", 1, 1]],
["assert", ["sub", 1, 1]],
["assert", ["lt", 2, 1]],
["assert", ["lt", 1, 1]],
["assert", ["lt", "x", 0]], # +x < 0
["assert", ["le", 1, 0]],
["assert", ["le", 2 ** 256 - 1, 0]],
["assert", ["gt", 1, 2]],
["assert", ["gt", 1, 1]],
["assert", ["gt", 0, 2 ** 256 - 1]],
["assert", ["gt", "x", 2 ** 256 - 1]],
["assert", ["ge", 1, 2]],
["assert", ["ge", 1, 2]],
["assert", ["slt", 2, 1]],
["assert", ["slt", 1, 1]],
["assert", ["slt", 0, 2 ** 256 - 1]], # 0 < -1
["assert", ["slt", -(2 ** 255), 2 ** 255]], # 0x80 < 0x80
["assert", ["sle", 0, 2 ** 255]], # 0 < 0x80
["assert", ["sgt", 1, 2]],
["assert", ["sgt", 1, 1]],
["assert", ["sgt", 2 ** 256 - 1, 0]], # -1 > 0
["assert", ["sgt", 2 ** 255, -(2 ** 255)]], # 0x80 > 0x80
["assert", ["sge", 2 ** 255, 0]], # 0x80 > 0
]


@pytest.mark.parametrize("ir", static_assertions_list)
def test_static_assertions(ir, assert_compile_failed):
ir = IRnode.from_list(ir)
assert_compile_failed(lambda: optimizer.optimize(ir), StaticAssertionException)
File renamed without changes.
File renamed without changes.
94 changes: 0 additions & 94 deletions tests/compiler/test_clamps.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/parser/features/test_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_assert_refund(w3, get_contract_with_gas_estimation, assert_tx_failed):
code = """
@external
def foo():
assert 1 == 2
raise
"""
c = get_contract_with_gas_estimation(code)
a0 = w3.eth.accounts[0]
Expand Down
2 changes: 1 addition & 1 deletion tests/parser/features/test_assert_unreachable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ def test_assure_refund(w3, get_contract):
code = """
@external
def foo():
assert 1 == 2, UNREACHABLE
assert msg.sender != msg.sender, UNREACHABLE
"""

c = get_contract(code)
Expand Down
3 changes: 1 addition & 2 deletions tests/parser/functions/test_raw_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ def test_multiple_levels2(assert_tx_failed, get_contract_with_gas_estimation):
inner_code = """
@external
def returnten() -> int128:
assert False
return 10
raise
"""

c = get_contract_with_gas_estimation(inner_code)
Expand Down
4 changes: 2 additions & 2 deletions tests/parser/syntax/test_unbalanced_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test() -> int128:
if True:
return 0
else:
assert False
assert msg.sender != msg.sender
""",
FunctionDeclarationException,
),
Expand Down Expand Up @@ -108,7 +108,7 @@ def test() -> int128:
if 1 == 1 :
return 1
else:
assert False
assert msg.sender != msg.sender
return 0
""",
"""
Expand Down
10 changes: 5 additions & 5 deletions vyper/builtin_functions/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
IRnode,
bytes_clamp,
bytes_data_ptr,
clamp,
clamp_basetype,
get_bytearray_length,
int_clamp,
Expand Down Expand Up @@ -119,15 +120,14 @@ def _clamp_numeric_convert(arg, arg_bounds, out_bounds, arg_is_signed):
if arg_lo < out_lo:
# if not arg_is_signed, arg_lo is 0, so this branch cannot be hit
assert arg_is_signed, "bad assumption in numeric convert"
CLAMPGE = "clampge"
arg = [CLAMPGE, arg, out_lo]
arg = clamp("sge", arg, out_lo)

if arg_hi > out_hi:
# out_hi must be smaller than MAX_UINT256, so clample makes sense.
# add an assertion, just in case this assumption ever changes.
assert out_hi < 2 ** 256 - 1, "bad assumption in numeric convert"
CLAMPLE = "clample" if arg_is_signed else "uclample"
arg = [CLAMPLE, arg, out_hi]
CLAMP_OP = "sle" if arg_is_signed else "le"
arg = clamp(CLAMP_OP, arg, out_hi)

return arg

Expand Down Expand Up @@ -194,7 +194,7 @@ def _int_to_int(arg, out_typ):

else:
# note: this also works for out_bits == 256.
arg = IRnode.from_list(["clampge", arg, 0])
arg = clamp("sge", arg, 0)

elif not arg_info.is_signed and out_info.is_signed:
# e.g. (uclample (uclampge arg 0) (2**127 - 1))
Expand Down
Loading

0 comments on commit 7219a8a

Please sign in to comment.