Skip to content

Commit

Permalink
Fix subroutine mutual recursion with different argument counts bug (#234
Browse files Browse the repository at this point in the history
)

* Fix mutual recursion bug

* Remove usage of set.pop
  • Loading branch information
jasonpaulos authored Mar 7, 2022
1 parent f2598da commit b01f86e
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 3 deletions.
143 changes: 142 additions & 1 deletion pyteal/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,7 @@ def multiplyByAdding(a, b):
assert actual == expected


def test_compile_subroutine_mutually_recursive():
def test_compile_subroutine_mutually_recursive_4():
@Subroutine(TealType.uint64)
def isEven(i: Expr) -> Expr:
return If(i == Int(0), Int(1), Not(isOdd(i - Int(1))))
Expand Down Expand Up @@ -1285,6 +1285,147 @@ def isOdd(i: Expr) -> Expr:
assert actual == expected


def test_compile_subroutine_mutually_recursive_different_arg_count_4():
@Subroutine(TealType.uint64)
def factorial(i: Expr) -> Expr:
return If(
i <= Int(1),
Int(1),
factorial_intermediate(i - Int(1), Bytes("inconsequential")) * i,
)

@Subroutine(TealType.uint64)
def factorial_intermediate(i: Expr, j: Expr) -> Expr:
return Seq(Pop(j), factorial(i))

program = Return(factorial(Int(4)) == Int(24))

expected = """#pragma version 4
int 4
callsub factorial_0
int 24
==
return
// factorial
factorial_0:
store 0
load 0
int 1
<=
bnz factorial_0_l2
load 0
int 1
-
byte "inconsequential"
load 0
dig 2
dig 2
callsub factorialintermediate_1
swap
store 0
swap
pop
swap
pop
load 0
*
b factorial_0_l3
factorial_0_l2:
int 1
factorial_0_l3:
retsub
// factorial_intermediate
factorialintermediate_1:
store 2
store 1
load 2
pop
load 1
load 1
load 2
dig 2
callsub factorial_0
store 1
store 2
load 1
swap
store 1
swap
pop
retsub
""".strip()
actual = compileTeal(program, Mode.Application, version=4, assembleConstants=False)
assert actual == expected


def test_compile_subroutine_mutually_recursive_different_arg_count_5():
@Subroutine(TealType.uint64)
def factorial(i: Expr) -> Expr:
return If(
i <= Int(1),
Int(1),
factorial_intermediate(i - Int(1), Bytes("inconsequential")) * i,
)

@Subroutine(TealType.uint64)
def factorial_intermediate(i: Expr, j: Expr) -> Expr:
return Seq(Log(j), factorial(i))

program = Return(factorial(Int(4)) == Int(24))

expected = """#pragma version 5
int 4
callsub factorial_0
int 24
==
return
// factorial
factorial_0:
store 0
load 0
int 1
<=
bnz factorial_0_l2
load 0
int 1
-
byte "inconsequential"
load 0
cover 2
callsub factorialintermediate_1
swap
store 0
load 0
*
b factorial_0_l3
factorial_0_l2:
int 1
factorial_0_l3:
retsub
// factorial_intermediate
factorialintermediate_1:
store 2
store 1
load 2
log
load 1
load 1
load 2
uncover 2
callsub factorial_0
cover 2
store 2
store 1
retsub
""".strip()
actual = compileTeal(program, Mode.Application, version=5, assembleConstants=False)
assert actual == expected


def test_compile_loop_in_subroutine():
@Subroutine(TealType.none)
def setState(value: Expr) -> Expr:
Expand Down
16 changes: 14 additions & 2 deletions pyteal/compiler/subroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def spillLocalSlotsDuringRecursion(

for subroutine, reentryPoints in recursivePoints.items():
slots = list(sorted(slot for slot in localSlots[subroutine]))
numArgs = subroutine.argumentCount()

if len(reentryPoints) == 0 or len(slots) == 0:
# no need to spill slots
Expand All @@ -107,13 +106,26 @@ def spillLocalSlotsDuringRecursion(
before: List[TealComponent] = []
after: List[TealComponent] = []

if len(reentryPoints.intersection(stmt.getSubroutines())) != 0:
calledSubroutines = stmt.getSubroutines()
# the only opcode that references subroutines is callsub, and it should only ever
# reference one subroutine at a time
assert (
len(calledSubroutines) <= 1
), "Multiple subroutines are called from the same TealComponent"

reentrySubroutineCalls = list(reentryPoints.intersection(calledSubroutines))
if len(reentrySubroutineCalls) != 0:
# A subroutine is being called which may reenter the current subroutine, so insert
# ops to spill local slots to the stack before calling the subroutine and also to
# restore the local slots after returning from the subroutine. This prevents a
# reentry into the current subroutine from modifying variables we are currently
# using.

# reentrySubroutineCalls should have a length of 1, since calledSubroutines has a
# maximum length of 1
reentrySubroutineCall = reentrySubroutineCalls[0]
numArgs = reentrySubroutineCall.argumentCount()

digArgs = True
coverSpilledSlots = False
uncoverArgs = False
Expand Down

0 comments on commit b01f86e

Please sign in to comment.