From 0b1996482bfeaf0d8d5b1e642f33a9dbd6aa8a33 Mon Sep 17 00:00:00 2001 From: Ben Guidarelli Date: Wed, 5 Oct 2022 10:18:58 -0400 Subject: [PATCH 1/8] Add overriding name to method spec so it comes through in contract json --- pyteal/ast/router.py | 2 +- pyteal/ast/router_test.py | 1214 +++++++++++++++++---------------- pyteal/ast/subroutine.py | 5 +- pyteal/ast/subroutine_test.py | 11 + 4 files changed, 632 insertions(+), 600 deletions(-) diff --git a/pyteal/ast/router.py b/pyteal/ast/router.py index ff7ed1f0f..d213d4532 100644 --- a/pyteal/ast/router.py +++ b/pyteal/ast/router.py @@ -587,7 +587,7 @@ def add_method_handler( f"with {self.method_selector_to_sig[method_selector]}" ) - meth = method_call.method_spec() + meth = method_call.method_spec(overriding_name) if description is not None: meth.desc = description self.methods.append(meth) diff --git a/pyteal/ast/router_test.py b/pyteal/ast/router_test.py index c7abd54fd..769edb5fc 100644 --- a/pyteal/ast/router_test.py +++ b/pyteal/ast/router_test.py @@ -199,605 +199,625 @@ def multiple_txn( ] -def power_set(no_dup_list: list, length_override: int = None): - """ - This function serves as a generator for all possible elements in power_set - over `non_dup_list`, which is a list of non-duplicated elements (matches property of a set). - - The cardinality of a powerset is 2^|non_dup_list|, so we can iterate from 0 to 2^|non_dup_list| - 1 - to index each element in such power_set. - By binary representation of each index, we can see it as an allowance over each element in `no_dup_list`, - and generate a unique subset of `non_dup_list`, which yields as an element of power_set of `no_dup_list`. - - Args: - no_dup_list: a list of elements with no duplication - length_override: a number indicating the largest size of super_set element, - must be in range [1, len(no_dup_list)]. - """ - if length_override is None: - length_override = len(no_dup_list) - assert 1 <= length_override <= len(no_dup_list) - masks = [1 << i for i in range(length_override)] - for i in range(1 << len(no_dup_list)): - yield [elem for mask, elem in zip(masks, no_dup_list) if i & mask] - - -def full_ordered_combination_gen(non_dup_list: list, perm_length: int): - """ - This function serves as a generator for all possible vectors of length `perm_length`, - each of whose entries are one of the elements in `non_dup_list`, - which is a list of non-duplicated elements. - - Args: - non_dup_list: must be a list of elements with no duplication - perm_length: must be a non-negative number indicating resulting length of the vector - """ - if perm_length < 0: - raise pt.TealInputError("input permutation length must be non-negative") - elif len(set(non_dup_list)) != len(non_dup_list): - raise pt.TealInputError(f"input non_dup_list {non_dup_list} has duplications") - elif perm_length == 0: - yield [] - return - # we can index all possible cases of vectors with an index in range - # [0, |non_dup_list| ^ perm_length - 1] - # by converting an index into |non_dup_list|-based number, - # we can get the vector mapped by the index. - for index in range(len(non_dup_list) ** perm_length): - index_list_basis = [] - temp = index - for _ in range(perm_length): - index_list_basis.append(non_dup_list[temp % len(non_dup_list)]) - temp //= len(non_dup_list) - yield index_list_basis - - -def oncomplete_is_in_oc_list(sth: pt.EnumInt, oc_list: list[pt.EnumInt]): - return any(map(lambda x: str(x) == str(sth), oc_list)) - - -def assemble_helper(what: pt.Expr) -> pt.TealBlock: - assembled, _ = what.__teal__(options) - assembled.addIncoming() - assembled = pt.TealBlock.NormalizeBlocks(assembled) - return assembled - - -def camel_to_snake(name: str) -> str: - return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_") - - -def test_call_config(): - for cc in pt.CallConfig: - approval_cond_on_cc: pt.Expr | int = cc.approval_condition_under_config() - match approval_cond_on_cc: - case pt.Expr(): - expected_cc = ( - (pt.Txn.application_id() == pt.Int(0)) - if cc == pt.CallConfig.CREATE - else (pt.Txn.application_id() != pt.Int(0)) - ) - with pt.TealComponent.Context.ignoreExprEquality(): - assert assemble_helper(approval_cond_on_cc) == assemble_helper( - expected_cc - ) - case int(): - assert approval_cond_on_cc == int(cc) & 1 - case _: - raise pt.TealInternalError( - f"unexpected approval_cond_on_cc {approval_cond_on_cc}" - ) - - if cc in (pt.CallConfig.CREATE, pt.CallConfig.ALL): - with pytest.raises( - pt.TealInputError, - match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", - ): - cc.clear_state_condition_under_config() - continue - - clear_state_cond_on_cc: int = cc.clear_state_condition_under_config() - match clear_state_cond_on_cc: - case 0: - assert cc == pt.CallConfig.NEVER - case 1: - assert cc == pt.CallConfig.CALL - case _: - raise pt.TealInternalError( - f"unexpected clear_state_cond_on_cc {clear_state_cond_on_cc}" - ) - - -def test_method_config(): - never_mc = pt.MethodConfig(no_op=pt.CallConfig.NEVER) - assert never_mc.is_never() - assert never_mc.approval_cond() == 0 - assert never_mc.clear_state_cond() == 0 - - on_complete_pow_set = power_set(ON_COMPLETE_CASES) - approval_check_names_n_ocs = [ - (camel_to_snake(oc.name), oc) - for oc in ON_COMPLETE_CASES - if str(oc) != str(pt.OnComplete.ClearState) - ] - for on_complete_set in on_complete_pow_set: - oc_names = [camel_to_snake(oc.name) for oc in on_complete_set] - ordered_call_configs = full_ordered_combination_gen( - list(pt.CallConfig), len(on_complete_set) - ) - for call_configs in ordered_call_configs: - mc = pt.MethodConfig(**dict(zip(oc_names, call_configs))) - match mc.clear_state: - case pt.CallConfig.NEVER: - assert mc.clear_state_cond() == 0 - case pt.CallConfig.CALL: - assert mc.clear_state_cond() == 1 - case pt.CallConfig.CREATE | pt.CallConfig.ALL: - with pytest.raises( - pt.TealInputError, - match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", - ): - mc.clear_state_cond() - if mc.is_never() or all( - getattr(mc, i) == pt.CallConfig.NEVER - for i, _ in approval_check_names_n_ocs - ): - assert mc.approval_cond() == 0 - continue - elif all( - getattr(mc, i) == pt.CallConfig.ALL - for i, _ in approval_check_names_n_ocs - ): - assert mc.approval_cond() == 1 - continue - list_of_cc = [ - ( - typing.cast( - pt.CallConfig, getattr(mc, i) - ).approval_condition_under_config(), - oc, - ) - for i, oc in approval_check_names_n_ocs - ] - list_of_expressions = [] - for expr_or_int, oc in list_of_cc: - match expr_or_int: - case pt.Expr(): - list_of_expressions.append( - pt.And(pt.Txn.on_completion() == oc, expr_or_int) - ) - case 0: - continue - case 1: - list_of_expressions.append(pt.Txn.on_completion() == oc) - with pt.TealComponent.Context.ignoreExprEquality(): - assert assemble_helper(mc.approval_cond()) == assemble_helper( - pt.Or(*list_of_expressions) - ) - - -def test_on_complete_action(): - with pytest.raises(pt.TealInputError) as contradict_err: - pt.OnCompleteAction(action=pt.Seq(), call_config=pt.CallConfig.NEVER) - assert "contradicts" in str(contradict_err) - assert pt.OnCompleteAction.never().is_empty() - assert pt.OnCompleteAction.call_only(pt.Seq()).call_config == pt.CallConfig.CALL - assert pt.OnCompleteAction.create_only(pt.Seq()).call_config == pt.CallConfig.CREATE - assert pt.OnCompleteAction.always(pt.Seq()).call_config == pt.CallConfig.ALL - - -def test_wrap_handler_bare_call(): - BARE_CALL_CASES = [ - dummy_doing_nothing, - safe_clear_state_delete, - pt.Approve(), - pt.Log(pt.Bytes("message")), - ] - for bare_call in BARE_CALL_CASES: - wrapped: pt.Expr = ASTBuilder.wrap_handler(False, bare_call) - expected: pt.Expr - match bare_call: - case pt.Expr(): - if bare_call.has_return(): - expected = bare_call - else: - expected = pt.Seq(bare_call, pt.Approve()) - case pt.SubroutineFnWrapper() | pt.ABIReturnSubroutine(): - expected = pt.Seq(bare_call(), pt.Approve()) - case _: - raise pt.TealInputError("how you got here?") - wrapped_assemble = assemble_helper(wrapped) - wrapped_helper = assemble_helper(expected) - with pt.TealComponent.Context.ignoreExprEquality(): - assert wrapped_assemble == wrapped_helper - - ERROR_CASES = [ - ( - pt.Int(1), - f"bare appcall handler should be TealType.none not {pt.TealType.uint64}.", - ), - ( - returning_u64, - f"subroutine call should be returning TealType.none not {pt.TealType.uint64}.", - ), - ( - mult_over_u64_and_log, - "subroutine call should take 0 arg for bare-app call. this subroutine takes 2.", - ), - ( - eine_constant, - f"abi-returning subroutine call should be returning void not {pt.abi.Uint64TypeSpec()}.", - ), - ( - take_abi_and_log, - "abi-returning subroutine call should take 0 arg for bare-app call. this abi-returning subroutine takes 1.", - ), - ( - 1, - "bare appcall can only accept: none type Expr, or Subroutine/ABIReturnSubroutine with none return and no arg", - ), - ] - for error_case, error_msg in ERROR_CASES: - with pytest.raises(pt.TealInputError) as bug: - ASTBuilder.wrap_handler(False, error_case) - assert error_msg in str(bug) - - -def test_wrap_handler_method_call(): - with pytest.raises(pt.TealInputError) as bug: - ASTBuilder.wrap_handler(True, not_registrable) - assert "method call ABIReturnSubroutine is not routable" in str(bug) - - with pytest.raises(pt.TealInputError) as bug: - ASTBuilder.wrap_handler(True, safe_clear_state_delete) - assert "method call should be only registering ABIReturnSubroutine" in str(bug) - - ONLY_ABI_SUBROUTINE_CASES = list( - filter(lambda x: isinstance(x, pt.ABIReturnSubroutine), GOOD_SUBROUTINE_CASES) - ) - - for abi_subroutine in ONLY_ABI_SUBROUTINE_CASES: - wrapped: pt.Expr = ASTBuilder.wrap_handler(True, abi_subroutine) - actual: pt.TealBlock = assemble_helper(wrapped) - - args: list[pt.abi.BaseType] = [ - spec.new_instance() - for spec in typing.cast( - list[pt.abi.TypeSpec], abi_subroutine.subroutine.expected_arg_types - ) - ] - - app_args = [ - arg for arg in args if arg.type_spec() not in pt.abi.TransactionTypeSpecs - ] - - app_arg_cnt = len(app_args) - - txn_args: list[pt.abi.Transaction] = [ - arg for arg in args if arg.type_spec() in pt.abi.TransactionTypeSpecs - ] - - loading: list[pt.Expr] = [] - - if app_arg_cnt > pt.METHOD_ARG_NUM_CUTOFF: - sdk_last_arg = pt.abi.TupleTypeSpec( - *[arg.type_spec() for arg in app_args[pt.METHOD_ARG_NUM_CUTOFF - 1 :]] - ).new_instance() - - loading = [ - arg.decode(pt.Txn.application_args[index + 1]) - for index, arg in enumerate(app_args[: pt.METHOD_ARG_NUM_CUTOFF - 1]) - ] - - loading.append( - sdk_last_arg.decode(pt.Txn.application_args[pt.METHOD_ARG_NUM_CUTOFF]) - ) - else: - loading = [ - arg.decode(pt.Txn.application_args[index + 1]) - for index, arg in enumerate(app_args) - ] - - if len(txn_args) > 0: - for idx, txn_arg in enumerate(txn_args): - loading.append( - txn_arg._set_index( - pt.Txn.group_index() - pt.Int(len(txn_args) - idx) - ) - ) - if str(txn_arg.type_spec()) != "txn": - loading.append( - pt.Assert( - txn_arg.get().type_enum() - == txn_arg.type_spec().txn_type_enum() - ) - ) - - if app_arg_cnt > pt.METHOD_ARG_NUM_CUTOFF: - loading.extend( - [ - sdk_last_arg[idx].store_into(val) - for idx, val in enumerate(app_args[pt.METHOD_ARG_NUM_CUTOFF - 1 :]) - ] - ) - - evaluate: pt.Expr - if abi_subroutine.type_of() != "void": - output_temp = abi_subroutine.output_kwarg_info.abi_type.new_instance() - evaluate = pt.Seq( - abi_subroutine(*args).store_into(output_temp), - pt.abi.MethodReturn(output_temp), - ) - else: - evaluate = abi_subroutine(*args) - - expected = assemble_helper(pt.Seq(*loading, evaluate, pt.Approve())) - with pt.TealComponent.Context.ignoreScratchSlotEquality(), pt.TealComponent.Context.ignoreExprEquality(): - assert actual == expected - - assert pt.TealBlock.MatchScratchSlotReferences( - pt.TealBlock.GetReferencedScratchSlots(actual), - pt.TealBlock.GetReferencedScratchSlots(expected), - ) - - -def test_wrap_handler_method_txn_types(): - wrapped: pt.Expr = ASTBuilder.wrap_handler(True, multiple_txn) - actual: pt.TealBlock = assemble_helper(wrapped) - - args: list[pt.abi.Transaction] = [ - pt.abi.ApplicationCallTransaction(), - pt.abi.AssetTransferTransaction(), - pt.abi.PaymentTransaction(), - pt.abi.Transaction(), - ] - output_temp = pt.abi.Uint64() - expected_ast = pt.Seq( - args[0]._set_index(pt.Txn.group_index() - pt.Int(4)), - pt.Assert(args[0].get().type_enum() == pt.TxnType.ApplicationCall), - args[1]._set_index(pt.Txn.group_index() - pt.Int(3)), - pt.Assert(args[1].get().type_enum() == pt.TxnType.AssetTransfer), - args[2]._set_index(pt.Txn.group_index() - pt.Int(2)), - pt.Assert(args[2].get().type_enum() == pt.TxnType.Payment), - args[3]._set_index(pt.Txn.group_index() - pt.Int(1)), - multiple_txn(*args).store_into(output_temp), - pt.abi.MethodReturn(output_temp), - pt.Approve(), - ) - - expected = assemble_helper(expected_ast) - with pt.TealComponent.Context.ignoreScratchSlotEquality(), pt.TealComponent.Context.ignoreExprEquality(): - assert actual == expected - - assert pt.TealBlock.MatchScratchSlotReferences( - pt.TealBlock.GetReferencedScratchSlots(actual), - pt.TealBlock.GetReferencedScratchSlots(expected), - ) - - -def test_wrap_handler_method_call_many_args(): - wrapped: pt.Expr = ASTBuilder.wrap_handler(True, many_args) - actual: pt.TealBlock = assemble_helper(wrapped) - - args = [pt.abi.Uint64() for _ in range(20)] - last_arg = pt.abi.TupleTypeSpec( - *[pt.abi.Uint64TypeSpec() for _ in range(6)] - ).new_instance() - - output_temp = pt.abi.Uint64() - expected_ast = pt.Seq( - args[0].decode(pt.Txn.application_args[1]), - args[1].decode(pt.Txn.application_args[2]), - args[2].decode(pt.Txn.application_args[3]), - args[3].decode(pt.Txn.application_args[4]), - args[4].decode(pt.Txn.application_args[5]), - args[5].decode(pt.Txn.application_args[6]), - args[6].decode(pt.Txn.application_args[7]), - args[7].decode(pt.Txn.application_args[8]), - args[8].decode(pt.Txn.application_args[9]), - args[9].decode(pt.Txn.application_args[10]), - args[10].decode(pt.Txn.application_args[11]), - args[11].decode(pt.Txn.application_args[12]), - args[12].decode(pt.Txn.application_args[13]), - args[13].decode(pt.Txn.application_args[14]), - last_arg.decode(pt.Txn.application_args[15]), - last_arg[0].store_into(args[14]), - last_arg[1].store_into(args[15]), - last_arg[2].store_into(args[16]), - last_arg[3].store_into(args[17]), - last_arg[4].store_into(args[18]), - last_arg[5].store_into(args[19]), - many_args(*args).store_into(output_temp), - pt.abi.MethodReturn(output_temp), - pt.Approve(), - ) - expected = assemble_helper(expected_ast) - with pt.TealComponent.Context.ignoreScratchSlotEquality(), pt.TealComponent.Context.ignoreExprEquality(): - assert actual == expected - - assert pt.TealBlock.MatchScratchSlotReferences( - pt.TealBlock.GetReferencedScratchSlots(actual), - pt.TealBlock.GetReferencedScratchSlots(expected), - ) - - -def test_contract_json_obj(): - abi_subroutines = list( - filter(lambda x: isinstance(x, pt.ABIReturnSubroutine), GOOD_SUBROUTINE_CASES) - ) - contract_name = "contract_name" - on_complete_actions = pt.BareCallActions( - clear_state=pt.OnCompleteAction.call_only(safe_clear_state_delete) - ) - router = pt.Router(contract_name, on_complete_actions) - method_list: list[sdk_abi.Method] = [] - for subroutine in abi_subroutines: - - doc = subroutine.subroutine.implementation.__doc__ - desc = None - if doc is not None and doc.strip() == "replace me": - desc = "dope description" - - router.add_method_handler(subroutine, description=desc) - - ms = subroutine.method_spec() - - # Manually replace it since the override is applied in the method handler - # not attached to the ABIReturnSubroutine itself - ms.desc = desc if desc is not None else ms.desc - - sig_method = sdk_abi.Method.from_signature(subroutine.method_signature()) - - assert ms.name == sig_method.name - - for idx, arg in enumerate(ms.args): - assert arg.type == sig_method.args[idx].type - - method_list.append(ms) +# def power_set(no_dup_list: list, length_override: int = None): +# """ +# This function serves as a generator for all possible elements in power_set +# over `non_dup_list`, which is a list of non-duplicated elements (matches property of a set). +# +# The cardinality of a powerset is 2^|non_dup_list|, so we can iterate from 0 to 2^|non_dup_list| - 1 +# to index each element in such power_set. +# By binary representation of each index, we can see it as an allowance over each element in `no_dup_list`, +# and generate a unique subset of `non_dup_list`, which yields as an element of power_set of `no_dup_list`. +# +# Args: +# no_dup_list: a list of elements with no duplication +# length_override: a number indicating the largest size of super_set element, +# must be in range [1, len(no_dup_list)]. +# """ +# if length_override is None: +# length_override = len(no_dup_list) +# assert 1 <= length_override <= len(no_dup_list) +# masks = [1 << i for i in range(length_override)] +# for i in range(1 << len(no_dup_list)): +# yield [elem for mask, elem in zip(masks, no_dup_list) if i & mask] +# +# +# def full_ordered_combination_gen(non_dup_list: list, perm_length: int): +# """ +# This function serves as a generator for all possible vectors of length `perm_length`, +# each of whose entries are one of the elements in `non_dup_list`, +# which is a list of non-duplicated elements. +# +# Args: +# non_dup_list: must be a list of elements with no duplication +# perm_length: must be a non-negative number indicating resulting length of the vector +# """ +# if perm_length < 0: +# raise pt.TealInputError("input permutation length must be non-negative") +# elif len(set(non_dup_list)) != len(non_dup_list): +# raise pt.TealInputError(f"input non_dup_list {non_dup_list} has duplications") +# elif perm_length == 0: +# yield [] +# return +# # we can index all possible cases of vectors with an index in range +# # [0, |non_dup_list| ^ perm_length - 1] +# # by converting an index into |non_dup_list|-based number, +# # we can get the vector mapped by the index. +# for index in range(len(non_dup_list) ** perm_length): +# index_list_basis = [] +# temp = index +# for _ in range(perm_length): +# index_list_basis.append(non_dup_list[temp % len(non_dup_list)]) +# temp //= len(non_dup_list) +# yield index_list_basis +# +# +# def oncomplete_is_in_oc_list(sth: pt.EnumInt, oc_list: list[pt.EnumInt]): +# return any(map(lambda x: str(x) == str(sth), oc_list)) +# +# +# def assemble_helper(what: pt.Expr) -> pt.TealBlock: +# assembled, _ = what.__teal__(options) +# assembled.addIncoming() +# assembled = pt.TealBlock.NormalizeBlocks(assembled) +# return assembled +# +# +# def camel_to_snake(name: str) -> str: +# return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_") +# +# +# def test_call_config(): +# for cc in pt.CallConfig: +# approval_cond_on_cc: pt.Expr | int = cc.approval_condition_under_config() +# match approval_cond_on_cc: +# case pt.Expr(): +# expected_cc = ( +# (pt.Txn.application_id() == pt.Int(0)) +# if cc == pt.CallConfig.CREATE +# else (pt.Txn.application_id() != pt.Int(0)) +# ) +# with pt.TealComponent.Context.ignoreExprEquality(): +# assert assemble_helper(approval_cond_on_cc) == assemble_helper( +# expected_cc +# ) +# case int(): +# assert approval_cond_on_cc == int(cc) & 1 +# case _: +# raise pt.TealInternalError( +# f"unexpected approval_cond_on_cc {approval_cond_on_cc}" +# ) +# +# if cc in (pt.CallConfig.CREATE, pt.CallConfig.ALL): +# with pytest.raises( +# pt.TealInputError, +# match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", +# ): +# cc.clear_state_condition_under_config() +# continue +# +# clear_state_cond_on_cc: int = cc.clear_state_condition_under_config() +# match clear_state_cond_on_cc: +# case 0: +# assert cc == pt.CallConfig.NEVER +# case 1: +# assert cc == pt.CallConfig.CALL +# case _: +# raise pt.TealInternalError( +# f"unexpected clear_state_cond_on_cc {clear_state_cond_on_cc}" +# ) +# +# +# def test_method_config(): +# never_mc = pt.MethodConfig(no_op=pt.CallConfig.NEVER) +# assert never_mc.is_never() +# assert never_mc.approval_cond() == 0 +# assert never_mc.clear_state_cond() == 0 +# +# on_complete_pow_set = power_set(ON_COMPLETE_CASES) +# approval_check_names_n_ocs = [ +# (camel_to_snake(oc.name), oc) +# for oc in ON_COMPLETE_CASES +# if str(oc) != str(pt.OnComplete.ClearState) +# ] +# for on_complete_set in on_complete_pow_set: +# oc_names = [camel_to_snake(oc.name) for oc in on_complete_set] +# ordered_call_configs = full_ordered_combination_gen( +# list(pt.CallConfig), len(on_complete_set) +# ) +# for call_configs in ordered_call_configs: +# mc = pt.MethodConfig(**dict(zip(oc_names, call_configs))) +# match mc.clear_state: +# case pt.CallConfig.NEVER: +# assert mc.clear_state_cond() == 0 +# case pt.CallConfig.CALL: +# assert mc.clear_state_cond() == 1 +# case pt.CallConfig.CREATE | pt.CallConfig.ALL: +# with pytest.raises( +# pt.TealInputError, +# match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", +# ): +# mc.clear_state_cond() +# if mc.is_never() or all( +# getattr(mc, i) == pt.CallConfig.NEVER +# for i, _ in approval_check_names_n_ocs +# ): +# assert mc.approval_cond() == 0 +# continue +# elif all( +# getattr(mc, i) == pt.CallConfig.ALL +# for i, _ in approval_check_names_n_ocs +# ): +# assert mc.approval_cond() == 1 +# continue +# list_of_cc = [ +# ( +# typing.cast( +# pt.CallConfig, getattr(mc, i) +# ).approval_condition_under_config(), +# oc, +# ) +# for i, oc in approval_check_names_n_ocs +# ] +# list_of_expressions = [] +# for expr_or_int, oc in list_of_cc: +# match expr_or_int: +# case pt.Expr(): +# list_of_expressions.append( +# pt.And(pt.Txn.on_completion() == oc, expr_or_int) +# ) +# case 0: +# continue +# case 1: +# list_of_expressions.append(pt.Txn.on_completion() == oc) +# with pt.TealComponent.Context.ignoreExprEquality(): +# assert assemble_helper(mc.approval_cond()) == assemble_helper( +# pt.Or(*list_of_expressions) +# ) +# +# +# def test_on_complete_action(): +# with pytest.raises(pt.TealInputError) as contradict_err: +# pt.OnCompleteAction(action=pt.Seq(), call_config=pt.CallConfig.NEVER) +# assert "contradicts" in str(contradict_err) +# assert pt.OnCompleteAction.never().is_empty() +# assert pt.OnCompleteAction.call_only(pt.Seq()).call_config == pt.CallConfig.CALL +# assert pt.OnCompleteAction.create_only(pt.Seq()).call_config == pt.CallConfig.CREATE +# assert pt.OnCompleteAction.always(pt.Seq()).call_config == pt.CallConfig.ALL +# +# +# def test_wrap_handler_bare_call(): +# BARE_CALL_CASES = [ +# dummy_doing_nothing, +# safe_clear_state_delete, +# pt.Approve(), +# pt.Log(pt.Bytes("message")), +# ] +# for bare_call in BARE_CALL_CASES: +# wrapped: pt.Expr = ASTBuilder.wrap_handler(False, bare_call) +# expected: pt.Expr +# match bare_call: +# case pt.Expr(): +# if bare_call.has_return(): +# expected = bare_call +# else: +# expected = pt.Seq(bare_call, pt.Approve()) +# case pt.SubroutineFnWrapper() | pt.ABIReturnSubroutine(): +# expected = pt.Seq(bare_call(), pt.Approve()) +# case _: +# raise pt.TealInputError("how you got here?") +# wrapped_assemble = assemble_helper(wrapped) +# wrapped_helper = assemble_helper(expected) +# with pt.TealComponent.Context.ignoreExprEquality(): +# assert wrapped_assemble == wrapped_helper +# +# ERROR_CASES = [ +# ( +# pt.Int(1), +# f"bare appcall handler should be TealType.none not {pt.TealType.uint64}.", +# ), +# ( +# returning_u64, +# f"subroutine call should be returning TealType.none not {pt.TealType.uint64}.", +# ), +# ( +# mult_over_u64_and_log, +# "subroutine call should take 0 arg for bare-app call. this subroutine takes 2.", +# ), +# ( +# eine_constant, +# f"abi-returning subroutine call should be returning void not {pt.abi.Uint64TypeSpec()}.", +# ), +# ( +# take_abi_and_log, +# "abi-returning subroutine call should take 0 arg for bare-app call. this abi-returning subroutine takes 1.", +# ), +# ( +# 1, +# "bare appcall can only accept: none type Expr, or Subroutine/ABIReturnSubroutine with none return and no arg", +# ), +# ] +# for error_case, error_msg in ERROR_CASES: +# with pytest.raises(pt.TealInputError) as bug: +# ASTBuilder.wrap_handler(False, error_case) +# assert error_msg in str(bug) +# +# +# def test_wrap_handler_method_call(): +# with pytest.raises(pt.TealInputError) as bug: +# ASTBuilder.wrap_handler(True, not_registrable) +# assert "method call ABIReturnSubroutine is not routable" in str(bug) +# +# with pytest.raises(pt.TealInputError) as bug: +# ASTBuilder.wrap_handler(True, safe_clear_state_delete) +# assert "method call should be only registering ABIReturnSubroutine" in str(bug) +# +# ONLY_ABI_SUBROUTINE_CASES = list( +# filter(lambda x: isinstance(x, pt.ABIReturnSubroutine), GOOD_SUBROUTINE_CASES) +# ) +# +# for abi_subroutine in ONLY_ABI_SUBROUTINE_CASES: +# wrapped: pt.Expr = ASTBuilder.wrap_handler(True, abi_subroutine) +# actual: pt.TealBlock = assemble_helper(wrapped) +# +# args: list[pt.abi.BaseType] = [ +# spec.new_instance() +# for spec in typing.cast( +# list[pt.abi.TypeSpec], abi_subroutine.subroutine.expected_arg_types +# ) +# ] +# +# app_args = [ +# arg for arg in args if arg.type_spec() not in pt.abi.TransactionTypeSpecs +# ] +# +# app_arg_cnt = len(app_args) +# +# txn_args: list[pt.abi.Transaction] = [ +# arg for arg in args if arg.type_spec() in pt.abi.TransactionTypeSpecs +# ] +# +# loading: list[pt.Expr] = [] +# +# if app_arg_cnt > pt.METHOD_ARG_NUM_CUTOFF: +# sdk_last_arg = pt.abi.TupleTypeSpec( +# *[arg.type_spec() for arg in app_args[pt.METHOD_ARG_NUM_CUTOFF - 1 :]] +# ).new_instance() +# +# loading = [ +# arg.decode(pt.Txn.application_args[index + 1]) +# for index, arg in enumerate(app_args[: pt.METHOD_ARG_NUM_CUTOFF - 1]) +# ] +# +# loading.append( +# sdk_last_arg.decode(pt.Txn.application_args[pt.METHOD_ARG_NUM_CUTOFF]) +# ) +# else: +# loading = [ +# arg.decode(pt.Txn.application_args[index + 1]) +# for index, arg in enumerate(app_args) +# ] +# +# if len(txn_args) > 0: +# for idx, txn_arg in enumerate(txn_args): +# loading.append( +# txn_arg._set_index( +# pt.Txn.group_index() - pt.Int(len(txn_args) - idx) +# ) +# ) +# if str(txn_arg.type_spec()) != "txn": +# loading.append( +# pt.Assert( +# txn_arg.get().type_enum() +# == txn_arg.type_spec().txn_type_enum() +# ) +# ) +# +# if app_arg_cnt > pt.METHOD_ARG_NUM_CUTOFF: +# loading.extend( +# [ +# sdk_last_arg[idx].store_into(val) +# for idx, val in enumerate(app_args[pt.METHOD_ARG_NUM_CUTOFF - 1 :]) +# ] +# ) +# +# evaluate: pt.Expr +# if abi_subroutine.type_of() != "void": +# output_temp = abi_subroutine.output_kwarg_info.abi_type.new_instance() +# evaluate = pt.Seq( +# abi_subroutine(*args).store_into(output_temp), +# pt.abi.MethodReturn(output_temp), +# ) +# else: +# evaluate = abi_subroutine(*args) +# +# expected = assemble_helper(pt.Seq(*loading, evaluate, pt.Approve())) +# with pt.TealComponent.Context.ignoreScratchSlotEquality(), pt.TealComponent.Context.ignoreExprEquality(): +# assert actual == expected +# +# assert pt.TealBlock.MatchScratchSlotReferences( +# pt.TealBlock.GetReferencedScratchSlots(actual), +# pt.TealBlock.GetReferencedScratchSlots(expected), +# ) +# +# +# def test_wrap_handler_method_txn_types(): +# wrapped: pt.Expr = ASTBuilder.wrap_handler(True, multiple_txn) +# actual: pt.TealBlock = assemble_helper(wrapped) +# +# args: list[pt.abi.Transaction] = [ +# pt.abi.ApplicationCallTransaction(), +# pt.abi.AssetTransferTransaction(), +# pt.abi.PaymentTransaction(), +# pt.abi.Transaction(), +# ] +# output_temp = pt.abi.Uint64() +# expected_ast = pt.Seq( +# args[0]._set_index(pt.Txn.group_index() - pt.Int(4)), +# pt.Assert(args[0].get().type_enum() == pt.TxnType.ApplicationCall), +# args[1]._set_index(pt.Txn.group_index() - pt.Int(3)), +# pt.Assert(args[1].get().type_enum() == pt.TxnType.AssetTransfer), +# args[2]._set_index(pt.Txn.group_index() - pt.Int(2)), +# pt.Assert(args[2].get().type_enum() == pt.TxnType.Payment), +# args[3]._set_index(pt.Txn.group_index() - pt.Int(1)), +# multiple_txn(*args).store_into(output_temp), +# pt.abi.MethodReturn(output_temp), +# pt.Approve(), +# ) +# +# expected = assemble_helper(expected_ast) +# with pt.TealComponent.Context.ignoreScratchSlotEquality(), pt.TealComponent.Context.ignoreExprEquality(): +# assert actual == expected +# +# assert pt.TealBlock.MatchScratchSlotReferences( +# pt.TealBlock.GetReferencedScratchSlots(actual), +# pt.TealBlock.GetReferencedScratchSlots(expected), +# ) +# +# +# def test_wrap_handler_method_call_many_args(): +# wrapped: pt.Expr = ASTBuilder.wrap_handler(True, many_args) +# actual: pt.TealBlock = assemble_helper(wrapped) +# +# args = [pt.abi.Uint64() for _ in range(20)] +# last_arg = pt.abi.TupleTypeSpec( +# *[pt.abi.Uint64TypeSpec() for _ in range(6)] +# ).new_instance() +# +# output_temp = pt.abi.Uint64() +# expected_ast = pt.Seq( +# args[0].decode(pt.Txn.application_args[1]), +# args[1].decode(pt.Txn.application_args[2]), +# args[2].decode(pt.Txn.application_args[3]), +# args[3].decode(pt.Txn.application_args[4]), +# args[4].decode(pt.Txn.application_args[5]), +# args[5].decode(pt.Txn.application_args[6]), +# args[6].decode(pt.Txn.application_args[7]), +# args[7].decode(pt.Txn.application_args[8]), +# args[8].decode(pt.Txn.application_args[9]), +# args[9].decode(pt.Txn.application_args[10]), +# args[10].decode(pt.Txn.application_args[11]), +# args[11].decode(pt.Txn.application_args[12]), +# args[12].decode(pt.Txn.application_args[13]), +# args[13].decode(pt.Txn.application_args[14]), +# last_arg.decode(pt.Txn.application_args[15]), +# last_arg[0].store_into(args[14]), +# last_arg[1].store_into(args[15]), +# last_arg[2].store_into(args[16]), +# last_arg[3].store_into(args[17]), +# last_arg[4].store_into(args[18]), +# last_arg[5].store_into(args[19]), +# many_args(*args).store_into(output_temp), +# pt.abi.MethodReturn(output_temp), +# pt.Approve(), +# ) +# expected = assemble_helper(expected_ast) +# with pt.TealComponent.Context.ignoreScratchSlotEquality(), pt.TealComponent.Context.ignoreExprEquality(): +# assert actual == expected +# +# assert pt.TealBlock.MatchScratchSlotReferences( +# pt.TealBlock.GetReferencedScratchSlots(actual), +# pt.TealBlock.GetReferencedScratchSlots(expected), +# ) +# +# +# def test_contract_json_obj(): +# abi_subroutines = list( +# filter(lambda x: isinstance(x, pt.ABIReturnSubroutine), GOOD_SUBROUTINE_CASES) +# ) +# contract_name = "contract_name" +# on_complete_actions = pt.BareCallActions( +# clear_state=pt.OnCompleteAction.call_only(safe_clear_state_delete) +# ) +# router = pt.Router(contract_name, on_complete_actions) +# method_list: list[sdk_abi.Method] = [] +# for subroutine in abi_subroutines: +# +# doc = subroutine.subroutine.implementation.__doc__ +# desc = None +# if doc is not None and doc.strip() == "replace me": +# desc = "dope description" +# +# router.add_method_handler(subroutine, description=desc) +# +# ms = subroutine.method_spec() +# +# # Manually replace it since the override is applied in the method handler +# # not attached to the ABIReturnSubroutine itself +# ms.desc = desc if desc is not None else ms.desc +# +# sig_method = sdk_abi.Method.from_signature(subroutine.method_signature()) +# +# assert ms.name == sig_method.name +# +# for idx, arg in enumerate(ms.args): +# assert arg.type == sig_method.args[idx].type +# +# method_list.append(ms) +# +# sdk_contract = sdk_abi.Contract(contract_name, method_list) +# contract = router.contract_construct() +# assert contract == sdk_contract +# +# +# def test_build_program_all_empty(): +# router = pt.Router("test") +# +# approval, clear_state, contract = router.build_program() +# +# expected_empty_program = pt.TealSimpleBlock( +# [ +# pt.TealOp(None, pt.Op.int, 0), +# pt.TealOp(None, pt.Op.return_), +# ] +# ) +# +# with pt.TealComponent.Context.ignoreExprEquality(): +# assert assemble_helper(approval) == expected_empty_program +# assert assemble_helper(clear_state) == expected_empty_program +# +# expected_contract = sdk_abi.Contract("test", []) +# assert contract == expected_contract +# +# +# def test_build_program_approval_empty(): +# router = pt.Router( +# "test", +# pt.BareCallActions(clear_state=pt.OnCompleteAction.call_only(pt.Approve())), +# ) +# +# approval, clear_state, contract = router.build_program() +# +# expected_empty_program = pt.TealSimpleBlock( +# [ +# pt.TealOp(None, pt.Op.int, 0), +# pt.TealOp(None, pt.Op.return_), +# ] +# ) +# +# with pt.TealComponent.Context.ignoreExprEquality(): +# assert assemble_helper(approval) == expected_empty_program +# assert assemble_helper(clear_state) != expected_empty_program +# +# expected_contract = sdk_abi.Contract("test", []) +# assert contract == expected_contract +# +# +# def test_build_program_clear_state_empty(): +# router = pt.Router( +# "test", pt.BareCallActions(no_op=pt.OnCompleteAction.always(pt.Approve())) +# ) +# +# approval, clear_state, contract = router.build_program() +# +# expected_empty_program = pt.TealSimpleBlock( +# [ +# pt.TealOp(None, pt.Op.int, 0), +# pt.TealOp(None, pt.Op.return_), +# ] +# ) +# +# with pt.TealComponent.Context.ignoreExprEquality(): +# assert assemble_helper(approval) != expected_empty_program +# assert assemble_helper(clear_state) == expected_empty_program +# +# expected_contract = sdk_abi.Contract("test", []) +# assert contract == expected_contract +# +# +# def test_build_program_clear_state_invalid_config(): +# for config in (pt.CallConfig.CREATE, pt.CallConfig.ALL): +# bareCalls = pt.BareCallActions( +# clear_state=pt.OnCompleteAction(action=pt.Approve(), call_config=config) +# ) +# with pytest.raises( +# pt.TealInputError, +# match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", +# ): +# pt.Router("test", bareCalls) +# +# router = pt.Router("test") +# +# @pt.ABIReturnSubroutine +# def clear_state_method(): +# return pt.Approve() +# +# with pytest.raises( +# pt.TealInputError, +# match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", +# ): +# router.add_method_handler( +# clear_state_method, +# method_config=pt.MethodConfig(clear_state=config), +# ) +# +# +# def test_build_program_clear_state_valid_config(): +# action = pt.If(pt.Txn.fee() == pt.Int(4)).Then(pt.Approve()).Else(pt.Reject()) +# config = pt.CallConfig.CALL +# +# router_with_bare_call = pt.Router( +# "test", +# pt.BareCallActions( +# clear_state=pt.OnCompleteAction(action=action, call_config=config) +# ), +# ) +# _, actual_clear_state_with_bare_call, _ = router_with_bare_call.build_program() +# +# expected_clear_state_with_bare_call = assemble_helper( +# pt.Cond([pt.Txn.application_args.length() == pt.Int(0), action]) +# ) +# +# with pt.TealComponent.Context.ignoreExprEquality(): +# assert ( +# assemble_helper(actual_clear_state_with_bare_call) +# == expected_clear_state_with_bare_call +# ) +# +# router_with_method = pt.Router("test") +# +# @pt.ABIReturnSubroutine +# def clear_state_method(): +# return action +# +# router_with_method.add_method_handler( +# clear_state_method, method_config=pt.MethodConfig(clear_state=config) +# ) +# +# _, actual_clear_state_with_method, _ = router_with_method.build_program() +# +# expected_clear_state_with_method = assemble_helper( +# pt.Cond( +# [ +# pt.Txn.application_args[0] +# == pt.MethodSignature("clear_state_method()void"), +# pt.Seq(clear_state_method(), pt.Approve()), +# ] +# ) +# ) +# +# with pt.TealComponent.Context.ignoreExprEquality(): +# assert ( +# assemble_helper(actual_clear_state_with_method) +# == expected_clear_state_with_method +# ) +# +# +def test_override_names(): - sdk_contract = sdk_abi.Contract(contract_name, method_list) - contract = router.contract_construct() - assert contract == sdk_contract - - -def test_build_program_all_empty(): router = pt.Router("test") - approval, clear_state, contract = router.build_program() - - expected_empty_program = pt.TealSimpleBlock( - [ - pt.TealOp(None, pt.Op.int, 0), - pt.TealOp(None, pt.Op.return_), - ] - ) - - with pt.TealComponent.Context.ignoreExprEquality(): - assert assemble_helper(approval) == expected_empty_program - assert assemble_helper(clear_state) == expected_empty_program - - expected_contract = sdk_abi.Contract("test", []) - assert contract == expected_contract - - -def test_build_program_approval_empty(): - router = pt.Router( - "test", - pt.BareCallActions(clear_state=pt.OnCompleteAction.call_only(pt.Approve())), - ) - - approval, clear_state, contract = router.build_program() - - expected_empty_program = pt.TealSimpleBlock( - [ - pt.TealOp(None, pt.Op.int, 0), - pt.TealOp(None, pt.Op.return_), - ] - ) - - with pt.TealComponent.Context.ignoreExprEquality(): - assert assemble_helper(approval) == expected_empty_program - assert assemble_helper(clear_state) != expected_empty_program - - expected_contract = sdk_abi.Contract("test", []) - assert contract == expected_contract - - -def test_build_program_clear_state_empty(): - router = pt.Router( - "test", pt.BareCallActions(no_op=pt.OnCompleteAction.always(pt.Approve())) - ) - - approval, clear_state, contract = router.build_program() - - expected_empty_program = pt.TealSimpleBlock( - [ - pt.TealOp(None, pt.Op.int, 0), - pt.TealOp(None, pt.Op.return_), - ] - ) - - with pt.TealComponent.Context.ignoreExprEquality(): - assert assemble_helper(approval) != expected_empty_program - assert assemble_helper(clear_state) == expected_empty_program - - expected_contract = sdk_abi.Contract("test", []) - assert contract == expected_contract - - -def test_build_program_clear_state_invalid_config(): - for config in (pt.CallConfig.CREATE, pt.CallConfig.ALL): - bareCalls = pt.BareCallActions( - clear_state=pt.OnCompleteAction(action=pt.Approve(), call_config=config) - ) - with pytest.raises( - pt.TealInputError, - match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", - ): - pt.Router("test", bareCalls) - - router = pt.Router("test") - - @pt.ABIReturnSubroutine - def clear_state_method(): - return pt.Approve() - - with pytest.raises( - pt.TealInputError, - match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", - ): - router.add_method_handler( - clear_state_method, - method_config=pt.MethodConfig(clear_state=config), - ) - - -def test_build_program_clear_state_valid_config(): - action = pt.If(pt.Txn.fee() == pt.Int(4)).Then(pt.Approve()).Else(pt.Reject()) - config = pt.CallConfig.CALL - - router_with_bare_call = pt.Router( - "test", - pt.BareCallActions( - clear_state=pt.OnCompleteAction(action=action, call_config=config) - ), - ) - _, actual_clear_state_with_bare_call, _ = router_with_bare_call.build_program() - - expected_clear_state_with_bare_call = assemble_helper( - pt.Cond([pt.Txn.application_args.length() == pt.Int(0), action]) - ) - - with pt.TealComponent.Context.ignoreExprEquality(): - assert ( - assemble_helper(actual_clear_state_with_bare_call) - == expected_clear_state_with_bare_call - ) - - router_with_method = pt.Router("test") - - @pt.ABIReturnSubroutine - def clear_state_method(): - return action + @router.method(name="handle") + def handle_asa(deposit: pt.abi.AssetTransferTransaction): + """handles the deposit where the input is an asset transfer""" + return pt.Assert(deposit.get().asset_amount() > pt.Int(0)) - router_with_method.add_method_handler( - clear_state_method, method_config=pt.MethodConfig(clear_state=config) - ) - - _, actual_clear_state_with_method, _ = router_with_method.build_program() - - expected_clear_state_with_method = assemble_helper( - pt.Cond( - [ - pt.Txn.application_args[0] - == pt.MethodSignature("clear_state_method()void"), - pt.Seq(clear_state_method(), pt.Approve()), - ] - ) - ) + @router.method(name="handle") + def handle_algo(deposit: pt.abi.PaymentTransaction): + """handles the deposit where the input is a payment""" + return pt.Assert(deposit.get().amount() > pt.Int(0)) - with pt.TealComponent.Context.ignoreExprEquality(): - assert ( - assemble_helper(actual_clear_state_with_method) - == expected_clear_state_with_method - ) + approval, clear, contract = router.compile_program(version=7) + assert len(contract.methods) > 0 + for meth in contract.methods: + print(meth.dictify()) diff --git a/pyteal/ast/subroutine.py b/pyteal/ast/subroutine.py index fc550d442..feb9050ab 100644 --- a/pyteal/ast/subroutine.py +++ b/pyteal/ast/subroutine.py @@ -610,7 +610,7 @@ def method_signature(self, overriding_name: str = None) -> str: overriding_name = self.name() return f"{overriding_name}({','.join(args)}){self.type_of()}" - def method_spec(self) -> sdk_abi.Method: + def method_spec(self, overriding_name: str = None) -> sdk_abi.Method: desc: str = "" arg_descs: dict[str, str] = {} return_desc: str = "" @@ -674,7 +674,8 @@ def method_spec(self) -> sdk_abi.Method: return_obj["desc"] = return_desc # Create the method spec, adding description if set - spec = {"name": self.name(), "args": args, "returns": return_obj} + name = overriding_name if overriding_name is not None else self.name() + spec = {"name": name, "args": args, "returns": return_obj} if desc: spec["desc"] = desc diff --git a/pyteal/ast/subroutine_test.py b/pyteal/ast/subroutine_test.py index 9b5087efa..c3c2d01e0 100644 --- a/pyteal/ast/subroutine_test.py +++ b/pyteal/ast/subroutine_test.py @@ -1464,3 +1464,14 @@ def withdraw(amount: pt.abi.Uint64, recipient: pt.abi.Account): == "An account who will receive the withdrawn Algos. This may or may not be the same as the method call sender." ) assert "desc" not in mspec_dict["returns"] + + +def test_override_abi_method_name(): + def abi_meth(a: pt.abi.Uint64, b: pt.abi.Uint64, *, output: pt.abi.Uint64): + return output.set(a.get() + b.get()) + + mspec = ABIReturnSubroutine(abi_meth).method_spec().dictify() + assert mspec["name"] == "abi_meth" + + mspec = ABIReturnSubroutine(abi_meth).method_spec("add").dictify() + assert mspec["name"] == "add" From 38f7c81ee671024ed1f665cb1cfde550cfbee71e Mon Sep 17 00:00:00 2001 From: Ben Guidarelli Date: Wed, 5 Oct 2022 10:20:26 -0400 Subject: [PATCH 2/8] uncomment long tests --- pyteal/ast/router_test.py | 1208 ++++++++++++++++++------------------- 1 file changed, 604 insertions(+), 604 deletions(-) diff --git a/pyteal/ast/router_test.py b/pyteal/ast/router_test.py index 769edb5fc..764325dec 100644 --- a/pyteal/ast/router_test.py +++ b/pyteal/ast/router_test.py @@ -199,610 +199,610 @@ def multiple_txn( ] -# def power_set(no_dup_list: list, length_override: int = None): -# """ -# This function serves as a generator for all possible elements in power_set -# over `non_dup_list`, which is a list of non-duplicated elements (matches property of a set). -# -# The cardinality of a powerset is 2^|non_dup_list|, so we can iterate from 0 to 2^|non_dup_list| - 1 -# to index each element in such power_set. -# By binary representation of each index, we can see it as an allowance over each element in `no_dup_list`, -# and generate a unique subset of `non_dup_list`, which yields as an element of power_set of `no_dup_list`. -# -# Args: -# no_dup_list: a list of elements with no duplication -# length_override: a number indicating the largest size of super_set element, -# must be in range [1, len(no_dup_list)]. -# """ -# if length_override is None: -# length_override = len(no_dup_list) -# assert 1 <= length_override <= len(no_dup_list) -# masks = [1 << i for i in range(length_override)] -# for i in range(1 << len(no_dup_list)): -# yield [elem for mask, elem in zip(masks, no_dup_list) if i & mask] -# -# -# def full_ordered_combination_gen(non_dup_list: list, perm_length: int): -# """ -# This function serves as a generator for all possible vectors of length `perm_length`, -# each of whose entries are one of the elements in `non_dup_list`, -# which is a list of non-duplicated elements. -# -# Args: -# non_dup_list: must be a list of elements with no duplication -# perm_length: must be a non-negative number indicating resulting length of the vector -# """ -# if perm_length < 0: -# raise pt.TealInputError("input permutation length must be non-negative") -# elif len(set(non_dup_list)) != len(non_dup_list): -# raise pt.TealInputError(f"input non_dup_list {non_dup_list} has duplications") -# elif perm_length == 0: -# yield [] -# return -# # we can index all possible cases of vectors with an index in range -# # [0, |non_dup_list| ^ perm_length - 1] -# # by converting an index into |non_dup_list|-based number, -# # we can get the vector mapped by the index. -# for index in range(len(non_dup_list) ** perm_length): -# index_list_basis = [] -# temp = index -# for _ in range(perm_length): -# index_list_basis.append(non_dup_list[temp % len(non_dup_list)]) -# temp //= len(non_dup_list) -# yield index_list_basis -# -# -# def oncomplete_is_in_oc_list(sth: pt.EnumInt, oc_list: list[pt.EnumInt]): -# return any(map(lambda x: str(x) == str(sth), oc_list)) -# -# -# def assemble_helper(what: pt.Expr) -> pt.TealBlock: -# assembled, _ = what.__teal__(options) -# assembled.addIncoming() -# assembled = pt.TealBlock.NormalizeBlocks(assembled) -# return assembled -# -# -# def camel_to_snake(name: str) -> str: -# return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_") -# -# -# def test_call_config(): -# for cc in pt.CallConfig: -# approval_cond_on_cc: pt.Expr | int = cc.approval_condition_under_config() -# match approval_cond_on_cc: -# case pt.Expr(): -# expected_cc = ( -# (pt.Txn.application_id() == pt.Int(0)) -# if cc == pt.CallConfig.CREATE -# else (pt.Txn.application_id() != pt.Int(0)) -# ) -# with pt.TealComponent.Context.ignoreExprEquality(): -# assert assemble_helper(approval_cond_on_cc) == assemble_helper( -# expected_cc -# ) -# case int(): -# assert approval_cond_on_cc == int(cc) & 1 -# case _: -# raise pt.TealInternalError( -# f"unexpected approval_cond_on_cc {approval_cond_on_cc}" -# ) -# -# if cc in (pt.CallConfig.CREATE, pt.CallConfig.ALL): -# with pytest.raises( -# pt.TealInputError, -# match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", -# ): -# cc.clear_state_condition_under_config() -# continue -# -# clear_state_cond_on_cc: int = cc.clear_state_condition_under_config() -# match clear_state_cond_on_cc: -# case 0: -# assert cc == pt.CallConfig.NEVER -# case 1: -# assert cc == pt.CallConfig.CALL -# case _: -# raise pt.TealInternalError( -# f"unexpected clear_state_cond_on_cc {clear_state_cond_on_cc}" -# ) -# -# -# def test_method_config(): -# never_mc = pt.MethodConfig(no_op=pt.CallConfig.NEVER) -# assert never_mc.is_never() -# assert never_mc.approval_cond() == 0 -# assert never_mc.clear_state_cond() == 0 -# -# on_complete_pow_set = power_set(ON_COMPLETE_CASES) -# approval_check_names_n_ocs = [ -# (camel_to_snake(oc.name), oc) -# for oc in ON_COMPLETE_CASES -# if str(oc) != str(pt.OnComplete.ClearState) -# ] -# for on_complete_set in on_complete_pow_set: -# oc_names = [camel_to_snake(oc.name) for oc in on_complete_set] -# ordered_call_configs = full_ordered_combination_gen( -# list(pt.CallConfig), len(on_complete_set) -# ) -# for call_configs in ordered_call_configs: -# mc = pt.MethodConfig(**dict(zip(oc_names, call_configs))) -# match mc.clear_state: -# case pt.CallConfig.NEVER: -# assert mc.clear_state_cond() == 0 -# case pt.CallConfig.CALL: -# assert mc.clear_state_cond() == 1 -# case pt.CallConfig.CREATE | pt.CallConfig.ALL: -# with pytest.raises( -# pt.TealInputError, -# match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", -# ): -# mc.clear_state_cond() -# if mc.is_never() or all( -# getattr(mc, i) == pt.CallConfig.NEVER -# for i, _ in approval_check_names_n_ocs -# ): -# assert mc.approval_cond() == 0 -# continue -# elif all( -# getattr(mc, i) == pt.CallConfig.ALL -# for i, _ in approval_check_names_n_ocs -# ): -# assert mc.approval_cond() == 1 -# continue -# list_of_cc = [ -# ( -# typing.cast( -# pt.CallConfig, getattr(mc, i) -# ).approval_condition_under_config(), -# oc, -# ) -# for i, oc in approval_check_names_n_ocs -# ] -# list_of_expressions = [] -# for expr_or_int, oc in list_of_cc: -# match expr_or_int: -# case pt.Expr(): -# list_of_expressions.append( -# pt.And(pt.Txn.on_completion() == oc, expr_or_int) -# ) -# case 0: -# continue -# case 1: -# list_of_expressions.append(pt.Txn.on_completion() == oc) -# with pt.TealComponent.Context.ignoreExprEquality(): -# assert assemble_helper(mc.approval_cond()) == assemble_helper( -# pt.Or(*list_of_expressions) -# ) -# -# -# def test_on_complete_action(): -# with pytest.raises(pt.TealInputError) as contradict_err: -# pt.OnCompleteAction(action=pt.Seq(), call_config=pt.CallConfig.NEVER) -# assert "contradicts" in str(contradict_err) -# assert pt.OnCompleteAction.never().is_empty() -# assert pt.OnCompleteAction.call_only(pt.Seq()).call_config == pt.CallConfig.CALL -# assert pt.OnCompleteAction.create_only(pt.Seq()).call_config == pt.CallConfig.CREATE -# assert pt.OnCompleteAction.always(pt.Seq()).call_config == pt.CallConfig.ALL -# -# -# def test_wrap_handler_bare_call(): -# BARE_CALL_CASES = [ -# dummy_doing_nothing, -# safe_clear_state_delete, -# pt.Approve(), -# pt.Log(pt.Bytes("message")), -# ] -# for bare_call in BARE_CALL_CASES: -# wrapped: pt.Expr = ASTBuilder.wrap_handler(False, bare_call) -# expected: pt.Expr -# match bare_call: -# case pt.Expr(): -# if bare_call.has_return(): -# expected = bare_call -# else: -# expected = pt.Seq(bare_call, pt.Approve()) -# case pt.SubroutineFnWrapper() | pt.ABIReturnSubroutine(): -# expected = pt.Seq(bare_call(), pt.Approve()) -# case _: -# raise pt.TealInputError("how you got here?") -# wrapped_assemble = assemble_helper(wrapped) -# wrapped_helper = assemble_helper(expected) -# with pt.TealComponent.Context.ignoreExprEquality(): -# assert wrapped_assemble == wrapped_helper -# -# ERROR_CASES = [ -# ( -# pt.Int(1), -# f"bare appcall handler should be TealType.none not {pt.TealType.uint64}.", -# ), -# ( -# returning_u64, -# f"subroutine call should be returning TealType.none not {pt.TealType.uint64}.", -# ), -# ( -# mult_over_u64_and_log, -# "subroutine call should take 0 arg for bare-app call. this subroutine takes 2.", -# ), -# ( -# eine_constant, -# f"abi-returning subroutine call should be returning void not {pt.abi.Uint64TypeSpec()}.", -# ), -# ( -# take_abi_and_log, -# "abi-returning subroutine call should take 0 arg for bare-app call. this abi-returning subroutine takes 1.", -# ), -# ( -# 1, -# "bare appcall can only accept: none type Expr, or Subroutine/ABIReturnSubroutine with none return and no arg", -# ), -# ] -# for error_case, error_msg in ERROR_CASES: -# with pytest.raises(pt.TealInputError) as bug: -# ASTBuilder.wrap_handler(False, error_case) -# assert error_msg in str(bug) -# -# -# def test_wrap_handler_method_call(): -# with pytest.raises(pt.TealInputError) as bug: -# ASTBuilder.wrap_handler(True, not_registrable) -# assert "method call ABIReturnSubroutine is not routable" in str(bug) -# -# with pytest.raises(pt.TealInputError) as bug: -# ASTBuilder.wrap_handler(True, safe_clear_state_delete) -# assert "method call should be only registering ABIReturnSubroutine" in str(bug) -# -# ONLY_ABI_SUBROUTINE_CASES = list( -# filter(lambda x: isinstance(x, pt.ABIReturnSubroutine), GOOD_SUBROUTINE_CASES) -# ) -# -# for abi_subroutine in ONLY_ABI_SUBROUTINE_CASES: -# wrapped: pt.Expr = ASTBuilder.wrap_handler(True, abi_subroutine) -# actual: pt.TealBlock = assemble_helper(wrapped) -# -# args: list[pt.abi.BaseType] = [ -# spec.new_instance() -# for spec in typing.cast( -# list[pt.abi.TypeSpec], abi_subroutine.subroutine.expected_arg_types -# ) -# ] -# -# app_args = [ -# arg for arg in args if arg.type_spec() not in pt.abi.TransactionTypeSpecs -# ] -# -# app_arg_cnt = len(app_args) -# -# txn_args: list[pt.abi.Transaction] = [ -# arg for arg in args if arg.type_spec() in pt.abi.TransactionTypeSpecs -# ] -# -# loading: list[pt.Expr] = [] -# -# if app_arg_cnt > pt.METHOD_ARG_NUM_CUTOFF: -# sdk_last_arg = pt.abi.TupleTypeSpec( -# *[arg.type_spec() for arg in app_args[pt.METHOD_ARG_NUM_CUTOFF - 1 :]] -# ).new_instance() -# -# loading = [ -# arg.decode(pt.Txn.application_args[index + 1]) -# for index, arg in enumerate(app_args[: pt.METHOD_ARG_NUM_CUTOFF - 1]) -# ] -# -# loading.append( -# sdk_last_arg.decode(pt.Txn.application_args[pt.METHOD_ARG_NUM_CUTOFF]) -# ) -# else: -# loading = [ -# arg.decode(pt.Txn.application_args[index + 1]) -# for index, arg in enumerate(app_args) -# ] -# -# if len(txn_args) > 0: -# for idx, txn_arg in enumerate(txn_args): -# loading.append( -# txn_arg._set_index( -# pt.Txn.group_index() - pt.Int(len(txn_args) - idx) -# ) -# ) -# if str(txn_arg.type_spec()) != "txn": -# loading.append( -# pt.Assert( -# txn_arg.get().type_enum() -# == txn_arg.type_spec().txn_type_enum() -# ) -# ) -# -# if app_arg_cnt > pt.METHOD_ARG_NUM_CUTOFF: -# loading.extend( -# [ -# sdk_last_arg[idx].store_into(val) -# for idx, val in enumerate(app_args[pt.METHOD_ARG_NUM_CUTOFF - 1 :]) -# ] -# ) -# -# evaluate: pt.Expr -# if abi_subroutine.type_of() != "void": -# output_temp = abi_subroutine.output_kwarg_info.abi_type.new_instance() -# evaluate = pt.Seq( -# abi_subroutine(*args).store_into(output_temp), -# pt.abi.MethodReturn(output_temp), -# ) -# else: -# evaluate = abi_subroutine(*args) -# -# expected = assemble_helper(pt.Seq(*loading, evaluate, pt.Approve())) -# with pt.TealComponent.Context.ignoreScratchSlotEquality(), pt.TealComponent.Context.ignoreExprEquality(): -# assert actual == expected -# -# assert pt.TealBlock.MatchScratchSlotReferences( -# pt.TealBlock.GetReferencedScratchSlots(actual), -# pt.TealBlock.GetReferencedScratchSlots(expected), -# ) -# -# -# def test_wrap_handler_method_txn_types(): -# wrapped: pt.Expr = ASTBuilder.wrap_handler(True, multiple_txn) -# actual: pt.TealBlock = assemble_helper(wrapped) -# -# args: list[pt.abi.Transaction] = [ -# pt.abi.ApplicationCallTransaction(), -# pt.abi.AssetTransferTransaction(), -# pt.abi.PaymentTransaction(), -# pt.abi.Transaction(), -# ] -# output_temp = pt.abi.Uint64() -# expected_ast = pt.Seq( -# args[0]._set_index(pt.Txn.group_index() - pt.Int(4)), -# pt.Assert(args[0].get().type_enum() == pt.TxnType.ApplicationCall), -# args[1]._set_index(pt.Txn.group_index() - pt.Int(3)), -# pt.Assert(args[1].get().type_enum() == pt.TxnType.AssetTransfer), -# args[2]._set_index(pt.Txn.group_index() - pt.Int(2)), -# pt.Assert(args[2].get().type_enum() == pt.TxnType.Payment), -# args[3]._set_index(pt.Txn.group_index() - pt.Int(1)), -# multiple_txn(*args).store_into(output_temp), -# pt.abi.MethodReturn(output_temp), -# pt.Approve(), -# ) -# -# expected = assemble_helper(expected_ast) -# with pt.TealComponent.Context.ignoreScratchSlotEquality(), pt.TealComponent.Context.ignoreExprEquality(): -# assert actual == expected -# -# assert pt.TealBlock.MatchScratchSlotReferences( -# pt.TealBlock.GetReferencedScratchSlots(actual), -# pt.TealBlock.GetReferencedScratchSlots(expected), -# ) -# -# -# def test_wrap_handler_method_call_many_args(): -# wrapped: pt.Expr = ASTBuilder.wrap_handler(True, many_args) -# actual: pt.TealBlock = assemble_helper(wrapped) -# -# args = [pt.abi.Uint64() for _ in range(20)] -# last_arg = pt.abi.TupleTypeSpec( -# *[pt.abi.Uint64TypeSpec() for _ in range(6)] -# ).new_instance() -# -# output_temp = pt.abi.Uint64() -# expected_ast = pt.Seq( -# args[0].decode(pt.Txn.application_args[1]), -# args[1].decode(pt.Txn.application_args[2]), -# args[2].decode(pt.Txn.application_args[3]), -# args[3].decode(pt.Txn.application_args[4]), -# args[4].decode(pt.Txn.application_args[5]), -# args[5].decode(pt.Txn.application_args[6]), -# args[6].decode(pt.Txn.application_args[7]), -# args[7].decode(pt.Txn.application_args[8]), -# args[8].decode(pt.Txn.application_args[9]), -# args[9].decode(pt.Txn.application_args[10]), -# args[10].decode(pt.Txn.application_args[11]), -# args[11].decode(pt.Txn.application_args[12]), -# args[12].decode(pt.Txn.application_args[13]), -# args[13].decode(pt.Txn.application_args[14]), -# last_arg.decode(pt.Txn.application_args[15]), -# last_arg[0].store_into(args[14]), -# last_arg[1].store_into(args[15]), -# last_arg[2].store_into(args[16]), -# last_arg[3].store_into(args[17]), -# last_arg[4].store_into(args[18]), -# last_arg[5].store_into(args[19]), -# many_args(*args).store_into(output_temp), -# pt.abi.MethodReturn(output_temp), -# pt.Approve(), -# ) -# expected = assemble_helper(expected_ast) -# with pt.TealComponent.Context.ignoreScratchSlotEquality(), pt.TealComponent.Context.ignoreExprEquality(): -# assert actual == expected -# -# assert pt.TealBlock.MatchScratchSlotReferences( -# pt.TealBlock.GetReferencedScratchSlots(actual), -# pt.TealBlock.GetReferencedScratchSlots(expected), -# ) -# -# -# def test_contract_json_obj(): -# abi_subroutines = list( -# filter(lambda x: isinstance(x, pt.ABIReturnSubroutine), GOOD_SUBROUTINE_CASES) -# ) -# contract_name = "contract_name" -# on_complete_actions = pt.BareCallActions( -# clear_state=pt.OnCompleteAction.call_only(safe_clear_state_delete) -# ) -# router = pt.Router(contract_name, on_complete_actions) -# method_list: list[sdk_abi.Method] = [] -# for subroutine in abi_subroutines: -# -# doc = subroutine.subroutine.implementation.__doc__ -# desc = None -# if doc is not None and doc.strip() == "replace me": -# desc = "dope description" -# -# router.add_method_handler(subroutine, description=desc) -# -# ms = subroutine.method_spec() -# -# # Manually replace it since the override is applied in the method handler -# # not attached to the ABIReturnSubroutine itself -# ms.desc = desc if desc is not None else ms.desc -# -# sig_method = sdk_abi.Method.from_signature(subroutine.method_signature()) -# -# assert ms.name == sig_method.name -# -# for idx, arg in enumerate(ms.args): -# assert arg.type == sig_method.args[idx].type -# -# method_list.append(ms) -# -# sdk_contract = sdk_abi.Contract(contract_name, method_list) -# contract = router.contract_construct() -# assert contract == sdk_contract -# -# -# def test_build_program_all_empty(): -# router = pt.Router("test") -# -# approval, clear_state, contract = router.build_program() -# -# expected_empty_program = pt.TealSimpleBlock( -# [ -# pt.TealOp(None, pt.Op.int, 0), -# pt.TealOp(None, pt.Op.return_), -# ] -# ) -# -# with pt.TealComponent.Context.ignoreExprEquality(): -# assert assemble_helper(approval) == expected_empty_program -# assert assemble_helper(clear_state) == expected_empty_program -# -# expected_contract = sdk_abi.Contract("test", []) -# assert contract == expected_contract -# -# -# def test_build_program_approval_empty(): -# router = pt.Router( -# "test", -# pt.BareCallActions(clear_state=pt.OnCompleteAction.call_only(pt.Approve())), -# ) -# -# approval, clear_state, contract = router.build_program() -# -# expected_empty_program = pt.TealSimpleBlock( -# [ -# pt.TealOp(None, pt.Op.int, 0), -# pt.TealOp(None, pt.Op.return_), -# ] -# ) -# -# with pt.TealComponent.Context.ignoreExprEquality(): -# assert assemble_helper(approval) == expected_empty_program -# assert assemble_helper(clear_state) != expected_empty_program -# -# expected_contract = sdk_abi.Contract("test", []) -# assert contract == expected_contract -# -# -# def test_build_program_clear_state_empty(): -# router = pt.Router( -# "test", pt.BareCallActions(no_op=pt.OnCompleteAction.always(pt.Approve())) -# ) -# -# approval, clear_state, contract = router.build_program() -# -# expected_empty_program = pt.TealSimpleBlock( -# [ -# pt.TealOp(None, pt.Op.int, 0), -# pt.TealOp(None, pt.Op.return_), -# ] -# ) -# -# with pt.TealComponent.Context.ignoreExprEquality(): -# assert assemble_helper(approval) != expected_empty_program -# assert assemble_helper(clear_state) == expected_empty_program -# -# expected_contract = sdk_abi.Contract("test", []) -# assert contract == expected_contract -# -# -# def test_build_program_clear_state_invalid_config(): -# for config in (pt.CallConfig.CREATE, pt.CallConfig.ALL): -# bareCalls = pt.BareCallActions( -# clear_state=pt.OnCompleteAction(action=pt.Approve(), call_config=config) -# ) -# with pytest.raises( -# pt.TealInputError, -# match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", -# ): -# pt.Router("test", bareCalls) -# -# router = pt.Router("test") -# -# @pt.ABIReturnSubroutine -# def clear_state_method(): -# return pt.Approve() -# -# with pytest.raises( -# pt.TealInputError, -# match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", -# ): -# router.add_method_handler( -# clear_state_method, -# method_config=pt.MethodConfig(clear_state=config), -# ) -# -# -# def test_build_program_clear_state_valid_config(): -# action = pt.If(pt.Txn.fee() == pt.Int(4)).Then(pt.Approve()).Else(pt.Reject()) -# config = pt.CallConfig.CALL -# -# router_with_bare_call = pt.Router( -# "test", -# pt.BareCallActions( -# clear_state=pt.OnCompleteAction(action=action, call_config=config) -# ), -# ) -# _, actual_clear_state_with_bare_call, _ = router_with_bare_call.build_program() -# -# expected_clear_state_with_bare_call = assemble_helper( -# pt.Cond([pt.Txn.application_args.length() == pt.Int(0), action]) -# ) -# -# with pt.TealComponent.Context.ignoreExprEquality(): -# assert ( -# assemble_helper(actual_clear_state_with_bare_call) -# == expected_clear_state_with_bare_call -# ) -# -# router_with_method = pt.Router("test") -# -# @pt.ABIReturnSubroutine -# def clear_state_method(): -# return action -# -# router_with_method.add_method_handler( -# clear_state_method, method_config=pt.MethodConfig(clear_state=config) -# ) -# -# _, actual_clear_state_with_method, _ = router_with_method.build_program() -# -# expected_clear_state_with_method = assemble_helper( -# pt.Cond( -# [ -# pt.Txn.application_args[0] -# == pt.MethodSignature("clear_state_method()void"), -# pt.Seq(clear_state_method(), pt.Approve()), -# ] -# ) -# ) -# -# with pt.TealComponent.Context.ignoreExprEquality(): -# assert ( -# assemble_helper(actual_clear_state_with_method) -# == expected_clear_state_with_method -# ) -# -# +def power_set(no_dup_list: list, length_override: int = None): + """ + This function serves as a generator for all possible elements in power_set + over `non_dup_list`, which is a list of non-duplicated elements (matches property of a set). + + The cardinality of a powerset is 2^|non_dup_list|, so we can iterate from 0 to 2^|non_dup_list| - 1 + to index each element in such power_set. + By binary representation of each index, we can see it as an allowance over each element in `no_dup_list`, + and generate a unique subset of `non_dup_list`, which yields as an element of power_set of `no_dup_list`. + + Args: + no_dup_list: a list of elements with no duplication + length_override: a number indicating the largest size of super_set element, + must be in range [1, len(no_dup_list)]. + """ + if length_override is None: + length_override = len(no_dup_list) + assert 1 <= length_override <= len(no_dup_list) + masks = [1 << i for i in range(length_override)] + for i in range(1 << len(no_dup_list)): + yield [elem for mask, elem in zip(masks, no_dup_list) if i & mask] + + +def full_ordered_combination_gen(non_dup_list: list, perm_length: int): + """ + This function serves as a generator for all possible vectors of length `perm_length`, + each of whose entries are one of the elements in `non_dup_list`, + which is a list of non-duplicated elements. + + Args: + non_dup_list: must be a list of elements with no duplication + perm_length: must be a non-negative number indicating resulting length of the vector + """ + if perm_length < 0: + raise pt.TealInputError("input permutation length must be non-negative") + elif len(set(non_dup_list)) != len(non_dup_list): + raise pt.TealInputError(f"input non_dup_list {non_dup_list} has duplications") + elif perm_length == 0: + yield [] + return + # we can index all possible cases of vectors with an index in range + # [0, |non_dup_list| ^ perm_length - 1] + # by converting an index into |non_dup_list|-based number, + # we can get the vector mapped by the index. + for index in range(len(non_dup_list) ** perm_length): + index_list_basis = [] + temp = index + for _ in range(perm_length): + index_list_basis.append(non_dup_list[temp % len(non_dup_list)]) + temp //= len(non_dup_list) + yield index_list_basis + + +def oncomplete_is_in_oc_list(sth: pt.EnumInt, oc_list: list[pt.EnumInt]): + return any(map(lambda x: str(x) == str(sth), oc_list)) + + +def assemble_helper(what: pt.Expr) -> pt.TealBlock: + assembled, _ = what.__teal__(options) + assembled.addIncoming() + assembled = pt.TealBlock.NormalizeBlocks(assembled) + return assembled + + +def camel_to_snake(name: str) -> str: + return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_") + + +def test_call_config(): + for cc in pt.CallConfig: + approval_cond_on_cc: pt.Expr | int = cc.approval_condition_under_config() + match approval_cond_on_cc: + case pt.Expr(): + expected_cc = ( + (pt.Txn.application_id() == pt.Int(0)) + if cc == pt.CallConfig.CREATE + else (pt.Txn.application_id() != pt.Int(0)) + ) + with pt.TealComponent.Context.ignoreExprEquality(): + assert assemble_helper(approval_cond_on_cc) == assemble_helper( + expected_cc + ) + case int(): + assert approval_cond_on_cc == int(cc) & 1 + case _: + raise pt.TealInternalError( + f"unexpected approval_cond_on_cc {approval_cond_on_cc}" + ) + + if cc in (pt.CallConfig.CREATE, pt.CallConfig.ALL): + with pytest.raises( + pt.TealInputError, + match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", + ): + cc.clear_state_condition_under_config() + continue + + clear_state_cond_on_cc: int = cc.clear_state_condition_under_config() + match clear_state_cond_on_cc: + case 0: + assert cc == pt.CallConfig.NEVER + case 1: + assert cc == pt.CallConfig.CALL + case _: + raise pt.TealInternalError( + f"unexpected clear_state_cond_on_cc {clear_state_cond_on_cc}" + ) + + +def test_method_config(): + never_mc = pt.MethodConfig(no_op=pt.CallConfig.NEVER) + assert never_mc.is_never() + assert never_mc.approval_cond() == 0 + assert never_mc.clear_state_cond() == 0 + + on_complete_pow_set = power_set(ON_COMPLETE_CASES) + approval_check_names_n_ocs = [ + (camel_to_snake(oc.name), oc) + for oc in ON_COMPLETE_CASES + if str(oc) != str(pt.OnComplete.ClearState) + ] + for on_complete_set in on_complete_pow_set: + oc_names = [camel_to_snake(oc.name) for oc in on_complete_set] + ordered_call_configs = full_ordered_combination_gen( + list(pt.CallConfig), len(on_complete_set) + ) + for call_configs in ordered_call_configs: + mc = pt.MethodConfig(**dict(zip(oc_names, call_configs))) + match mc.clear_state: + case pt.CallConfig.NEVER: + assert mc.clear_state_cond() == 0 + case pt.CallConfig.CALL: + assert mc.clear_state_cond() == 1 + case pt.CallConfig.CREATE | pt.CallConfig.ALL: + with pytest.raises( + pt.TealInputError, + match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", + ): + mc.clear_state_cond() + if mc.is_never() or all( + getattr(mc, i) == pt.CallConfig.NEVER + for i, _ in approval_check_names_n_ocs + ): + assert mc.approval_cond() == 0 + continue + elif all( + getattr(mc, i) == pt.CallConfig.ALL + for i, _ in approval_check_names_n_ocs + ): + assert mc.approval_cond() == 1 + continue + list_of_cc = [ + ( + typing.cast( + pt.CallConfig, getattr(mc, i) + ).approval_condition_under_config(), + oc, + ) + for i, oc in approval_check_names_n_ocs + ] + list_of_expressions = [] + for expr_or_int, oc in list_of_cc: + match expr_or_int: + case pt.Expr(): + list_of_expressions.append( + pt.And(pt.Txn.on_completion() == oc, expr_or_int) + ) + case 0: + continue + case 1: + list_of_expressions.append(pt.Txn.on_completion() == oc) + with pt.TealComponent.Context.ignoreExprEquality(): + assert assemble_helper(mc.approval_cond()) == assemble_helper( + pt.Or(*list_of_expressions) + ) + + +def test_on_complete_action(): + with pytest.raises(pt.TealInputError) as contradict_err: + pt.OnCompleteAction(action=pt.Seq(), call_config=pt.CallConfig.NEVER) + assert "contradicts" in str(contradict_err) + assert pt.OnCompleteAction.never().is_empty() + assert pt.OnCompleteAction.call_only(pt.Seq()).call_config == pt.CallConfig.CALL + assert pt.OnCompleteAction.create_only(pt.Seq()).call_config == pt.CallConfig.CREATE + assert pt.OnCompleteAction.always(pt.Seq()).call_config == pt.CallConfig.ALL + + +def test_wrap_handler_bare_call(): + BARE_CALL_CASES = [ + dummy_doing_nothing, + safe_clear_state_delete, + pt.Approve(), + pt.Log(pt.Bytes("message")), + ] + for bare_call in BARE_CALL_CASES: + wrapped: pt.Expr = ASTBuilder.wrap_handler(False, bare_call) + expected: pt.Expr + match bare_call: + case pt.Expr(): + if bare_call.has_return(): + expected = bare_call + else: + expected = pt.Seq(bare_call, pt.Approve()) + case pt.SubroutineFnWrapper() | pt.ABIReturnSubroutine(): + expected = pt.Seq(bare_call(), pt.Approve()) + case _: + raise pt.TealInputError("how you got here?") + wrapped_assemble = assemble_helper(wrapped) + wrapped_helper = assemble_helper(expected) + with pt.TealComponent.Context.ignoreExprEquality(): + assert wrapped_assemble == wrapped_helper + + ERROR_CASES = [ + ( + pt.Int(1), + f"bare appcall handler should be TealType.none not {pt.TealType.uint64}.", + ), + ( + returning_u64, + f"subroutine call should be returning TealType.none not {pt.TealType.uint64}.", + ), + ( + mult_over_u64_and_log, + "subroutine call should take 0 arg for bare-app call. this subroutine takes 2.", + ), + ( + eine_constant, + f"abi-returning subroutine call should be returning void not {pt.abi.Uint64TypeSpec()}.", + ), + ( + take_abi_and_log, + "abi-returning subroutine call should take 0 arg for bare-app call. this abi-returning subroutine takes 1.", + ), + ( + 1, + "bare appcall can only accept: none type Expr, or Subroutine/ABIReturnSubroutine with none return and no arg", + ), + ] + for error_case, error_msg in ERROR_CASES: + with pytest.raises(pt.TealInputError) as bug: + ASTBuilder.wrap_handler(False, error_case) + assert error_msg in str(bug) + + +def test_wrap_handler_method_call(): + with pytest.raises(pt.TealInputError) as bug: + ASTBuilder.wrap_handler(True, not_registrable) + assert "method call ABIReturnSubroutine is not routable" in str(bug) + + with pytest.raises(pt.TealInputError) as bug: + ASTBuilder.wrap_handler(True, safe_clear_state_delete) + assert "method call should be only registering ABIReturnSubroutine" in str(bug) + + ONLY_ABI_SUBROUTINE_CASES = list( + filter(lambda x: isinstance(x, pt.ABIReturnSubroutine), GOOD_SUBROUTINE_CASES) + ) + + for abi_subroutine in ONLY_ABI_SUBROUTINE_CASES: + wrapped: pt.Expr = ASTBuilder.wrap_handler(True, abi_subroutine) + actual: pt.TealBlock = assemble_helper(wrapped) + + args: list[pt.abi.BaseType] = [ + spec.new_instance() + for spec in typing.cast( + list[pt.abi.TypeSpec], abi_subroutine.subroutine.expected_arg_types + ) + ] + + app_args = [ + arg for arg in args if arg.type_spec() not in pt.abi.TransactionTypeSpecs + ] + + app_arg_cnt = len(app_args) + + txn_args: list[pt.abi.Transaction] = [ + arg for arg in args if arg.type_spec() in pt.abi.TransactionTypeSpecs + ] + + loading: list[pt.Expr] = [] + + if app_arg_cnt > pt.METHOD_ARG_NUM_CUTOFF: + sdk_last_arg = pt.abi.TupleTypeSpec( + *[arg.type_spec() for arg in app_args[pt.METHOD_ARG_NUM_CUTOFF - 1 :]] + ).new_instance() + + loading = [ + arg.decode(pt.Txn.application_args[index + 1]) + for index, arg in enumerate(app_args[: pt.METHOD_ARG_NUM_CUTOFF - 1]) + ] + + loading.append( + sdk_last_arg.decode(pt.Txn.application_args[pt.METHOD_ARG_NUM_CUTOFF]) + ) + else: + loading = [ + arg.decode(pt.Txn.application_args[index + 1]) + for index, arg in enumerate(app_args) + ] + + if len(txn_args) > 0: + for idx, txn_arg in enumerate(txn_args): + loading.append( + txn_arg._set_index( + pt.Txn.group_index() - pt.Int(len(txn_args) - idx) + ) + ) + if str(txn_arg.type_spec()) != "txn": + loading.append( + pt.Assert( + txn_arg.get().type_enum() + == txn_arg.type_spec().txn_type_enum() + ) + ) + + if app_arg_cnt > pt.METHOD_ARG_NUM_CUTOFF: + loading.extend( + [ + sdk_last_arg[idx].store_into(val) + for idx, val in enumerate(app_args[pt.METHOD_ARG_NUM_CUTOFF - 1 :]) + ] + ) + + evaluate: pt.Expr + if abi_subroutine.type_of() != "void": + output_temp = abi_subroutine.output_kwarg_info.abi_type.new_instance() + evaluate = pt.Seq( + abi_subroutine(*args).store_into(output_temp), + pt.abi.MethodReturn(output_temp), + ) + else: + evaluate = abi_subroutine(*args) + + expected = assemble_helper(pt.Seq(*loading, evaluate, pt.Approve())) + with pt.TealComponent.Context.ignoreScratchSlotEquality(), pt.TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + assert pt.TealBlock.MatchScratchSlotReferences( + pt.TealBlock.GetReferencedScratchSlots(actual), + pt.TealBlock.GetReferencedScratchSlots(expected), + ) + + +def test_wrap_handler_method_txn_types(): + wrapped: pt.Expr = ASTBuilder.wrap_handler(True, multiple_txn) + actual: pt.TealBlock = assemble_helper(wrapped) + + args: list[pt.abi.Transaction] = [ + pt.abi.ApplicationCallTransaction(), + pt.abi.AssetTransferTransaction(), + pt.abi.PaymentTransaction(), + pt.abi.Transaction(), + ] + output_temp = pt.abi.Uint64() + expected_ast = pt.Seq( + args[0]._set_index(pt.Txn.group_index() - pt.Int(4)), + pt.Assert(args[0].get().type_enum() == pt.TxnType.ApplicationCall), + args[1]._set_index(pt.Txn.group_index() - pt.Int(3)), + pt.Assert(args[1].get().type_enum() == pt.TxnType.AssetTransfer), + args[2]._set_index(pt.Txn.group_index() - pt.Int(2)), + pt.Assert(args[2].get().type_enum() == pt.TxnType.Payment), + args[3]._set_index(pt.Txn.group_index() - pt.Int(1)), + multiple_txn(*args).store_into(output_temp), + pt.abi.MethodReturn(output_temp), + pt.Approve(), + ) + + expected = assemble_helper(expected_ast) + with pt.TealComponent.Context.ignoreScratchSlotEquality(), pt.TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + assert pt.TealBlock.MatchScratchSlotReferences( + pt.TealBlock.GetReferencedScratchSlots(actual), + pt.TealBlock.GetReferencedScratchSlots(expected), + ) + + +def test_wrap_handler_method_call_many_args(): + wrapped: pt.Expr = ASTBuilder.wrap_handler(True, many_args) + actual: pt.TealBlock = assemble_helper(wrapped) + + args = [pt.abi.Uint64() for _ in range(20)] + last_arg = pt.abi.TupleTypeSpec( + *[pt.abi.Uint64TypeSpec() for _ in range(6)] + ).new_instance() + + output_temp = pt.abi.Uint64() + expected_ast = pt.Seq( + args[0].decode(pt.Txn.application_args[1]), + args[1].decode(pt.Txn.application_args[2]), + args[2].decode(pt.Txn.application_args[3]), + args[3].decode(pt.Txn.application_args[4]), + args[4].decode(pt.Txn.application_args[5]), + args[5].decode(pt.Txn.application_args[6]), + args[6].decode(pt.Txn.application_args[7]), + args[7].decode(pt.Txn.application_args[8]), + args[8].decode(pt.Txn.application_args[9]), + args[9].decode(pt.Txn.application_args[10]), + args[10].decode(pt.Txn.application_args[11]), + args[11].decode(pt.Txn.application_args[12]), + args[12].decode(pt.Txn.application_args[13]), + args[13].decode(pt.Txn.application_args[14]), + last_arg.decode(pt.Txn.application_args[15]), + last_arg[0].store_into(args[14]), + last_arg[1].store_into(args[15]), + last_arg[2].store_into(args[16]), + last_arg[3].store_into(args[17]), + last_arg[4].store_into(args[18]), + last_arg[5].store_into(args[19]), + many_args(*args).store_into(output_temp), + pt.abi.MethodReturn(output_temp), + pt.Approve(), + ) + expected = assemble_helper(expected_ast) + with pt.TealComponent.Context.ignoreScratchSlotEquality(), pt.TealComponent.Context.ignoreExprEquality(): + assert actual == expected + + assert pt.TealBlock.MatchScratchSlotReferences( + pt.TealBlock.GetReferencedScratchSlots(actual), + pt.TealBlock.GetReferencedScratchSlots(expected), + ) + + +def test_contract_json_obj(): + abi_subroutines = list( + filter(lambda x: isinstance(x, pt.ABIReturnSubroutine), GOOD_SUBROUTINE_CASES) + ) + contract_name = "contract_name" + on_complete_actions = pt.BareCallActions( + clear_state=pt.OnCompleteAction.call_only(safe_clear_state_delete) + ) + router = pt.Router(contract_name, on_complete_actions) + method_list: list[sdk_abi.Method] = [] + for subroutine in abi_subroutines: + + doc = subroutine.subroutine.implementation.__doc__ + desc = None + if doc is not None and doc.strip() == "replace me": + desc = "dope description" + + router.add_method_handler(subroutine, description=desc) + + ms = subroutine.method_spec() + + # Manually replace it since the override is applied in the method handler + # not attached to the ABIReturnSubroutine itself + ms.desc = desc if desc is not None else ms.desc + + sig_method = sdk_abi.Method.from_signature(subroutine.method_signature()) + + assert ms.name == sig_method.name + + for idx, arg in enumerate(ms.args): + assert arg.type == sig_method.args[idx].type + + method_list.append(ms) + + sdk_contract = sdk_abi.Contract(contract_name, method_list) + contract = router.contract_construct() + assert contract == sdk_contract + + +def test_build_program_all_empty(): + router = pt.Router("test") + + approval, clear_state, contract = router.build_program() + + expected_empty_program = pt.TealSimpleBlock( + [ + pt.TealOp(None, pt.Op.int, 0), + pt.TealOp(None, pt.Op.return_), + ] + ) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert assemble_helper(approval) == expected_empty_program + assert assemble_helper(clear_state) == expected_empty_program + + expected_contract = sdk_abi.Contract("test", []) + assert contract == expected_contract + + +def test_build_program_approval_empty(): + router = pt.Router( + "test", + pt.BareCallActions(clear_state=pt.OnCompleteAction.call_only(pt.Approve())), + ) + + approval, clear_state, contract = router.build_program() + + expected_empty_program = pt.TealSimpleBlock( + [ + pt.TealOp(None, pt.Op.int, 0), + pt.TealOp(None, pt.Op.return_), + ] + ) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert assemble_helper(approval) == expected_empty_program + assert assemble_helper(clear_state) != expected_empty_program + + expected_contract = sdk_abi.Contract("test", []) + assert contract == expected_contract + + +def test_build_program_clear_state_empty(): + router = pt.Router( + "test", pt.BareCallActions(no_op=pt.OnCompleteAction.always(pt.Approve())) + ) + + approval, clear_state, contract = router.build_program() + + expected_empty_program = pt.TealSimpleBlock( + [ + pt.TealOp(None, pt.Op.int, 0), + pt.TealOp(None, pt.Op.return_), + ] + ) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert assemble_helper(approval) != expected_empty_program + assert assemble_helper(clear_state) == expected_empty_program + + expected_contract = sdk_abi.Contract("test", []) + assert contract == expected_contract + + +def test_build_program_clear_state_invalid_config(): + for config in (pt.CallConfig.CREATE, pt.CallConfig.ALL): + bareCalls = pt.BareCallActions( + clear_state=pt.OnCompleteAction(action=pt.Approve(), call_config=config) + ) + with pytest.raises( + pt.TealInputError, + match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", + ): + pt.Router("test", bareCalls) + + router = pt.Router("test") + + @pt.ABIReturnSubroutine + def clear_state_method(): + return pt.Approve() + + with pytest.raises( + pt.TealInputError, + match=r"Only CallConfig.CALL or CallConfig.NEVER are valid for a clear state CallConfig, since clear state can never be invoked during creation$", + ): + router.add_method_handler( + clear_state_method, + method_config=pt.MethodConfig(clear_state=config), + ) + + +def test_build_program_clear_state_valid_config(): + action = pt.If(pt.Txn.fee() == pt.Int(4)).Then(pt.Approve()).Else(pt.Reject()) + config = pt.CallConfig.CALL + + router_with_bare_call = pt.Router( + "test", + pt.BareCallActions( + clear_state=pt.OnCompleteAction(action=action, call_config=config) + ), + ) + _, actual_clear_state_with_bare_call, _ = router_with_bare_call.build_program() + + expected_clear_state_with_bare_call = assemble_helper( + pt.Cond([pt.Txn.application_args.length() == pt.Int(0), action]) + ) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert ( + assemble_helper(actual_clear_state_with_bare_call) + == expected_clear_state_with_bare_call + ) + + router_with_method = pt.Router("test") + + @pt.ABIReturnSubroutine + def clear_state_method(): + return action + + router_with_method.add_method_handler( + clear_state_method, method_config=pt.MethodConfig(clear_state=config) + ) + + _, actual_clear_state_with_method, _ = router_with_method.build_program() + + expected_clear_state_with_method = assemble_helper( + pt.Cond( + [ + pt.Txn.application_args[0] + == pt.MethodSignature("clear_state_method()void"), + pt.Seq(clear_state_method(), pt.Approve()), + ] + ) + ) + + with pt.TealComponent.Context.ignoreExprEquality(): + assert ( + assemble_helper(actual_clear_state_with_method) + == expected_clear_state_with_method + ) + + def test_override_names(): router = pt.Router("test") From 8639e32f05ac19eca9025ef59e93217377ee1334 Mon Sep 17 00:00:00 2001 From: Ben Guidarelli Date: Wed, 5 Oct 2022 10:21:26 -0400 Subject: [PATCH 3/8] actually assert --- pyteal/ast/router_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyteal/ast/router_test.py b/pyteal/ast/router_test.py index 764325dec..1d0395356 100644 --- a/pyteal/ast/router_test.py +++ b/pyteal/ast/router_test.py @@ -817,7 +817,8 @@ def handle_algo(deposit: pt.abi.PaymentTransaction): """handles the deposit where the input is a payment""" return pt.Assert(deposit.get().amount() > pt.Int(0)) - approval, clear, contract = router.compile_program(version=7) + _, _, contract = router.compile_program(version=7) assert len(contract.methods) > 0 for meth in contract.methods: - print(meth.dictify()) + dmeth = meth.dictify() + assert dmeth["name"] == "handle" From dc8e6eedc174b6399c5637dcf3d7107b6f4696ec Mon Sep 17 00:00:00 2001 From: Hang Su <87964331+ahangsu@users.noreply.github.com> Date: Wed, 5 Oct 2022 13:00:09 -0400 Subject: [PATCH 4/8] Proposing API minor change on `ABIReturnSubroutine` for name overriding (#551) * Typo Fix for `abi.Uint64TypeSpec` (#549) * minor suggestion on ABIReturnSubroutine API change --- pyteal/ast/abi/uint.py | 2 +- pyteal/ast/subroutine.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pyteal/ast/abi/uint.py b/pyteal/ast/abi/uint.py index 53d2cc797..9855c29b9 100644 --- a/pyteal/ast/abi/uint.py +++ b/pyteal/ast/abi/uint.py @@ -224,7 +224,7 @@ def annotation_type(self) -> "type[Uint64]": return Uint64 -Uint32TypeSpec.__module__ = "pyteal.abi" +Uint64TypeSpec.__module__ = "pyteal.abi" class Uint(BaseType): diff --git a/pyteal/ast/subroutine.py b/pyteal/ast/subroutine.py index feb9050ab..b156ddcbd 100644 --- a/pyteal/ast/subroutine.py +++ b/pyteal/ast/subroutine.py @@ -519,6 +519,9 @@ def abi_sum(toSum: abi.DynamicArray[abi.Uint64], *, output: abi.Uint64) -> Expr: def __init__( self, fn_implementation: Callable[..., Expr], + /, + *, + overriding_name: Optional[str] = None, ) -> None: self.output_kwarg_info: Optional[OutputKwArgInfo] = self._get_output_kwarg_info( fn_implementation @@ -526,9 +529,17 @@ def __init__( self.subroutine = SubroutineDefinition( fn_implementation, return_type=TealType.none, + name_str=overriding_name, has_abi_output=self.output_kwarg_info is not None, ) + @staticmethod + def name_override(name: str) -> Callable[..., "ABIReturnSubroutine"]: + def wrapper(fn_impl: Callable[..., Expr]) -> ABIReturnSubroutine: + return ABIReturnSubroutine(fn_impl, overriding_name=name) + + return wrapper + @classmethod def _get_output_kwarg_info( cls, fn_implementation: Callable[..., Expr] From 758f52002900ab58763d33fff7b8db217af6974c Mon Sep 17 00:00:00 2001 From: Ben Guidarelli Date: Wed, 5 Oct 2022 13:11:30 -0400 Subject: [PATCH 5/8] make tests pass --- pyteal/ast/router.py | 4 ++-- pyteal/ast/subroutine.py | 5 ++--- pyteal/ast/subroutine_test.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pyteal/ast/router.py b/pyteal/ast/router.py index d213d4532..846499e43 100644 --- a/pyteal/ast/router.py +++ b/pyteal/ast/router.py @@ -587,7 +587,7 @@ def add_method_handler( f"with {self.method_selector_to_sig[method_selector]}" ) - meth = method_call.method_spec(overriding_name) + meth = method_call.method_spec() if description is not None: meth.desc = description self.methods.append(meth) @@ -647,7 +647,7 @@ def method( # - CallConfig.Never # both cases evaluate to False in if statement. def wrap(_func) -> ABIReturnSubroutine: - wrapped_subroutine = ABIReturnSubroutine(_func) + wrapped_subroutine = ABIReturnSubroutine(_func, overriding_name=name) call_configs: MethodConfig if ( no_op is None diff --git a/pyteal/ast/subroutine.py b/pyteal/ast/subroutine.py index b156ddcbd..62b3c3f43 100644 --- a/pyteal/ast/subroutine.py +++ b/pyteal/ast/subroutine.py @@ -621,7 +621,7 @@ def method_signature(self, overriding_name: str = None) -> str: overriding_name = self.name() return f"{overriding_name}({','.join(args)}){self.type_of()}" - def method_spec(self, overriding_name: str = None) -> sdk_abi.Method: + def method_spec(self) -> sdk_abi.Method: desc: str = "" arg_descs: dict[str, str] = {} return_desc: str = "" @@ -685,8 +685,7 @@ def method_spec(self, overriding_name: str = None) -> sdk_abi.Method: return_obj["desc"] = return_desc # Create the method spec, adding description if set - name = overriding_name if overriding_name is not None else self.name() - spec = {"name": name, "args": args, "returns": return_obj} + spec = {"name": self.name(), "args": args, "returns": return_obj} if desc: spec["desc"] = desc diff --git a/pyteal/ast/subroutine_test.py b/pyteal/ast/subroutine_test.py index c3c2d01e0..70fd9ccd3 100644 --- a/pyteal/ast/subroutine_test.py +++ b/pyteal/ast/subroutine_test.py @@ -1473,5 +1473,5 @@ def abi_meth(a: pt.abi.Uint64, b: pt.abi.Uint64, *, output: pt.abi.Uint64): mspec = ABIReturnSubroutine(abi_meth).method_spec().dictify() assert mspec["name"] == "abi_meth" - mspec = ABIReturnSubroutine(abi_meth).method_spec("add").dictify() + mspec = ABIReturnSubroutine(abi_meth, overriding_name="add").method_spec().dictify() assert mspec["name"] == "add" From 2e2383309721635794a1af99a59d86390d528dfb Mon Sep 17 00:00:00 2001 From: Michael Diamant Date: Fri, 7 Oct 2022 09:44:19 -0400 Subject: [PATCH 6/8] Add sanity check to test_override_names (#556) --- pyteal/ast/router_test.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/pyteal/ast/router_test.py b/pyteal/ast/router_test.py index 1d0395356..148d48c1a 100644 --- a/pyteal/ast/router_test.py +++ b/pyteal/ast/router_test.py @@ -804,21 +804,37 @@ def clear_state_method(): def test_override_names(): + r1 = pt.Router("test") - router = pt.Router("test") - - @router.method(name="handle") + @r1.method(name="handle") def handle_asa(deposit: pt.abi.AssetTransferTransaction): """handles the deposit where the input is an asset transfer""" return pt.Assert(deposit.get().asset_amount() > pt.Int(0)) - @router.method(name="handle") + @r1.method(name="handle") def handle_algo(deposit: pt.abi.PaymentTransaction): """handles the deposit where the input is a payment""" return pt.Assert(deposit.get().amount() > pt.Int(0)) - _, _, contract = router.compile_program(version=7) - assert len(contract.methods) > 0 - for meth in contract.methods: + ap1, cs1, c1 = r1.compile_program(version=pt.compiler.MAX_PROGRAM_VERSION) + assert len(c1.methods) == 2 + for meth in c1.methods: dmeth = meth.dictify() assert dmeth["name"] == "handle" + + # Confirm an equivalent router definition _without_ `name` overrides produces the same output. + r2 = pt.Router("test") + + @r2.method() + def handle(deposit: pt.abi.AssetTransferTransaction): + """handles the deposit where the input is an asset transfer""" + return pt.Assert(deposit.get().asset_amount() > pt.Int(0)) + + @r2.method() + def handle(deposit: pt.abi.PaymentTransaction): # noqa: F811 + """handles the deposit where the input is a payment""" + return pt.Assert(deposit.get().amount() > pt.Int(0)) + + ap2, cs2, c2 = r2.compile_program(version=pt.compiler.MAX_PROGRAM_VERSION) + + assert (ap1, cs1, c1) == (ap2, cs2, c2) From f2b2ceab1cadbae54be4a62cb1731b57135a0ec4 Mon Sep 17 00:00:00 2001 From: Hang Su <87964331+ahangsu@users.noreply.github.com> Date: Fri, 7 Oct 2022 13:18:52 -0400 Subject: [PATCH 7/8] `ABIReturnSubroutine` doc/test-case on `name_override` decorator (#557) * abi overridden name subr doc testcase * minor typo fix Co-authored-by: Michael Diamant Co-authored-by: Michael Diamant --- docs/abi.rst | 15 +++++++++++++++ pyteal/ast/subroutine_test.py | 7 +++++++ 2 files changed, 22 insertions(+) diff --git a/docs/abi.rst b/docs/abi.rst index fed0509ff..63d65147c 100644 --- a/docs/abi.rst +++ b/docs/abi.rst @@ -682,6 +682,21 @@ Notice that even though the original :code:`get_account_status` function returns The only exception to this transformation is if the subroutine has no return value. Without a return value, a :code:`ComputedValue` is unnecessary and the subroutine will still return an :code:`Expr` to the caller. In this case, the :code:`@ABIReturnSubroutine` decorator acts identically the :code:`@Subroutine` decorator. +The name of the subroutine constructed by the :code:`@ABIReturnSubroutine` decorator is by default the function name. In order to override the default subroutine name, the decorator :any:`ABIReturnSubroutine.name_override ` is introduced to construct a subroutine with its name overridden. An example is below: + +.. code-block:: python + + from pyteal import * + + @ABIReturnSubroutine.name_override("increment") + def add_by_one(prev: abi.Uint32, *, output: abi.Uint32) -> Expr: + return output.set(prev.get() + Int(1)) + + # NOTE! In this case, the `ABIReturnSubroutine` is initialized with a name "increment" + # overriding its original name "add_by_one" + assert add_by_one.method_spec().dictify()["name"] == "increment" + + Creating an ARC-4 Program ---------------------------------------------------- diff --git a/pyteal/ast/subroutine_test.py b/pyteal/ast/subroutine_test.py index 70fd9ccd3..28b0f0aa9 100644 --- a/pyteal/ast/subroutine_test.py +++ b/pyteal/ast/subroutine_test.py @@ -1475,3 +1475,10 @@ def abi_meth(a: pt.abi.Uint64, b: pt.abi.Uint64, *, output: pt.abi.Uint64): mspec = ABIReturnSubroutine(abi_meth, overriding_name="add").method_spec().dictify() assert mspec["name"] == "add" + + @ABIReturnSubroutine.name_override("overriden_add") + def abi_meth_2(a: pt.abi.Uint64, b: pt.abi.Uint64, *, output: pt.abi.Uint64): + return output.set(a.get() + b.get()) + + mspec = abi_meth_2.method_spec().dictify() + assert mspec["name"] == "overriden_add" From 4318d65cb4c462c5b9aed036b645f5ecbdef5c9a Mon Sep 17 00:00:00 2001 From: Ben Guidarelli Date: Fri, 7 Oct 2022 14:07:21 -0400 Subject: [PATCH 8/8] adding changelog note --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f68ff7c92..526057338 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Fixed * Erroring on constructing an odd length hex string. ([#539](https://github.com/algorand/pyteal/pull/539)) +* Incorrect behavior when overriding a method name ([#550](https://github.com/algorand/pyteal/pull/550)) # 0.18.1