diff --git a/pyteal/compiler/compiler_test.py b/pyteal/compiler/compiler_test.py index 7a36fcf6c..74b5799b5 100644 --- a/pyteal/compiler/compiler_test.py +++ b/pyteal/compiler/compiler_test.py @@ -1153,7 +1153,7 @@ def multiplyByAdding(a, b): assert actual == expected -def test_compile_subroutine_mutually_recursive(): +def test_compile_subroutine_mutually_recursive_4(): @Subroutine(TealType.uint64) def isEven(i: Expr) -> Expr: return If(i == Int(0), Int(1), Not(isOdd(i - Int(1)))) @@ -1285,6 +1285,147 @@ def isOdd(i: Expr) -> Expr: assert actual == expected +def test_compile_subroutine_mutually_recursive_different_arg_count_4(): + @Subroutine(TealType.uint64) + def factorial(i: Expr) -> Expr: + return If( + i <= Int(1), + Int(1), + factorial_intermediate(i - Int(1), Bytes("inconsequential")) * i, + ) + + @Subroutine(TealType.uint64) + def factorial_intermediate(i: Expr, j: Expr) -> Expr: + return Seq(Pop(j), factorial(i)) + + program = Return(factorial(Int(4)) == Int(24)) + + expected = """#pragma version 4 +int 4 +callsub factorial_0 +int 24 +== +return + +// factorial +factorial_0: +store 0 +load 0 +int 1 +<= +bnz factorial_0_l2 +load 0 +int 1 +- +byte "inconsequential" +load 0 +dig 2 +dig 2 +callsub factorialintermediate_1 +swap +store 0 +swap +pop +swap +pop +load 0 +* +b factorial_0_l3 +factorial_0_l2: +int 1 +factorial_0_l3: +retsub + +// factorial_intermediate +factorialintermediate_1: +store 2 +store 1 +load 2 +pop +load 1 +load 1 +load 2 +dig 2 +callsub factorial_0 +store 1 +store 2 +load 1 +swap +store 1 +swap +pop +retsub + """.strip() + actual = compileTeal(program, Mode.Application, version=4, assembleConstants=False) + assert actual == expected + + +def test_compile_subroutine_mutually_recursive_different_arg_count_5(): + @Subroutine(TealType.uint64) + def factorial(i: Expr) -> Expr: + return If( + i <= Int(1), + Int(1), + factorial_intermediate(i - Int(1), Bytes("inconsequential")) * i, + ) + + @Subroutine(TealType.uint64) + def factorial_intermediate(i: Expr, j: Expr) -> Expr: + return Seq(Log(j), factorial(i)) + + program = Return(factorial(Int(4)) == Int(24)) + + expected = """#pragma version 5 +int 4 +callsub factorial_0 +int 24 +== +return + +// factorial +factorial_0: +store 0 +load 0 +int 1 +<= +bnz factorial_0_l2 +load 0 +int 1 +- +byte "inconsequential" +load 0 +cover 2 +callsub factorialintermediate_1 +swap +store 0 +load 0 +* +b factorial_0_l3 +factorial_0_l2: +int 1 +factorial_0_l3: +retsub + +// factorial_intermediate +factorialintermediate_1: +store 2 +store 1 +load 2 +log +load 1 +load 1 +load 2 +uncover 2 +callsub factorial_0 +cover 2 +store 2 +store 1 +retsub + """.strip() + actual = compileTeal(program, Mode.Application, version=5, assembleConstants=False) + assert actual == expected + + def test_compile_loop_in_subroutine(): @Subroutine(TealType.none) def setState(value: Expr) -> Expr: diff --git a/pyteal/compiler/subroutines.py b/pyteal/compiler/subroutines.py index 1b5e63e62..78bfa7ada 100644 --- a/pyteal/compiler/subroutines.py +++ b/pyteal/compiler/subroutines.py @@ -94,7 +94,6 @@ def spillLocalSlotsDuringRecursion( for subroutine, reentryPoints in recursivePoints.items(): slots = list(sorted(slot for slot in localSlots[subroutine])) - numArgs = subroutine.argumentCount() if len(reentryPoints) == 0 or len(slots) == 0: # no need to spill slots @@ -107,13 +106,26 @@ def spillLocalSlotsDuringRecursion( before: List[TealComponent] = [] after: List[TealComponent] = [] - if len(reentryPoints.intersection(stmt.getSubroutines())) != 0: + calledSubroutines = stmt.getSubroutines() + # the only opcode that references subroutines is callsub, and it should only ever + # reference one subroutine at a time + assert ( + len(calledSubroutines) <= 1 + ), "Multiple subroutines are called from the same TealComponent" + + reentrySubroutineCalls = list(reentryPoints.intersection(calledSubroutines)) + if len(reentrySubroutineCalls) != 0: # A subroutine is being called which may reenter the current subroutine, so insert # ops to spill local slots to the stack before calling the subroutine and also to # restore the local slots after returning from the subroutine. This prevents a # reentry into the current subroutine from modifying variables we are currently # using. + # reentrySubroutineCalls should have a length of 1, since calledSubroutines has a + # maximum length of 1 + reentrySubroutineCall = reentrySubroutineCalls[0] + numArgs = reentrySubroutineCall.argumentCount() + digArgs = True coverSpilledSlots = False uncoverArgs = False