diff --git a/pyteal/__init__.pyi b/pyteal/__init__.pyi index e88e0f6de..4f1477a90 100644 --- a/pyteal/__init__.pyi +++ b/pyteal/__init__.pyi @@ -23,6 +23,7 @@ from pyteal.errors import ( from pyteal.config import MAX_GROUP_SIZE, NUM_SLOTS __all__ = [ + "ABIReturnSubroutine", "AccountParam", "Add", "Addr", diff --git a/pyteal/ast/__init__.py b/pyteal/ast/__init__.py index 0a05795a2..79f3beee5 100644 --- a/pyteal/ast/__init__.py +++ b/pyteal/ast/__init__.py @@ -102,7 +102,7 @@ from pyteal.ast.substring import Substring, Extract, Suffix # more ops -from pyteal.ast.naryexpr import NaryExpr, And, Add, Mul, Or, Concat +from pyteal.ast.naryexpr import NaryExpr, Add, And, Mul, Or, Concat from pyteal.ast.widemath import WideRatio # control flow @@ -118,6 +118,7 @@ SubroutineDeclaration, SubroutineCall, SubroutineFnWrapper, + ABIReturnSubroutine, ) from pyteal.ast.while_ import While from pyteal.ast.for_ import For @@ -242,6 +243,7 @@ "SubroutineDeclaration", "SubroutineCall", "SubroutineFnWrapper", + "ABIReturnSubroutine", "ScratchIndex", "ScratchLoad", "ScratchSlot", diff --git a/pyteal/ast/abi/__init__.py b/pyteal/ast/abi/__init__.py index 47e61cbab..b9cb974ce 100644 --- a/pyteal/ast/abi/__init__.py +++ b/pyteal/ast/abi/__init__.py @@ -4,7 +4,7 @@ Address, AddressLength, ) -from pyteal.ast.abi.type import TypeSpec, BaseType, ComputedValue +from pyteal.ast.abi.type import TypeSpec, BaseType, ComputedValue, ReturnedValue from pyteal.ast.abi.bool import BoolTypeSpec, Bool from pyteal.ast.abi.uint import ( UintTypeSpec, @@ -47,6 +47,7 @@ "TypeSpec", "BaseType", "ComputedValue", + "ReturnedValue", "BoolTypeSpec", "Bool", "UintTypeSpec", diff --git a/pyteal/ast/abi/type.py b/pyteal/ast/abi/type.py index 97304062f..52a57fec8 100644 --- a/pyteal/ast/abi/type.py +++ b/pyteal/ast/abi/type.py @@ -147,7 +147,7 @@ class ComputedValue(ABC, Generic[T]): """Represents an ABI Type whose value must be computed by an expression.""" @abstractmethod - def produced_type_spec(cls) -> TypeSpec: + def produced_type_spec(self) -> TypeSpec: """Get the ABI TypeSpec that this object produces.""" pass @@ -182,3 +182,40 @@ def use(self, action: Callable[[T], Expr]) -> Expr: ComputedValue.__module__ = "pyteal" + + +class ReturnedValue(ComputedValue): + def __init__(self, type_spec: TypeSpec, computation_expr: Expr): + from pyteal.ast.subroutine import SubroutineCall + + self.type_spec = type_spec + if not isinstance(computation_expr, SubroutineCall): + raise TealInputError( + f"Expecting computation_expr to be SubroutineCall but get {type(computation_expr)}" + ) + self.computation = computation_expr + + def produced_type_spec(self) -> TypeSpec: + return self.type_spec + + def store_into(self, output: BaseType) -> Expr: + if output.type_spec() != self.produced_type_spec(): + raise TealInputError( + f"expected type_spec {self.produced_type_spec()} but get {output.type_spec()}" + ) + + declaration = self.computation.subroutine.get_declaration() + + if declaration.deferred_expr is None: + raise TealInputError( + "ABI return subroutine must have deferred_expr to be not-None." + ) + if declaration.deferred_expr.type_of() != output.type_spec().storage_type(): + raise TealInputError( + f"ABI return subroutine deferred_expr is expected to be typed {output.type_spec().storage_type()}, " + f"but has type {declaration.deferred_expr.type_of()}." + ) + return output.stored_value.slot.store(self.computation) + + +ReturnedValue.__module__ = "pyteal" diff --git a/pyteal/ast/abi/util.py b/pyteal/ast/abi/util.py index d3393f3e7..6cd3aa7c9 100644 --- a/pyteal/ast/abi/util.py +++ b/pyteal/ast/abi/util.py @@ -12,7 +12,7 @@ def substringForDecoding( *, startIndex: Expr = None, endIndex: Expr = None, - length: Expr = None + length: Expr = None, ) -> Expr: """A helper function for getting the substring to decode according to the rules of BaseType.decode.""" if length is not None and endIndex is not None: diff --git a/pyteal/ast/subroutine.py b/pyteal/ast/subroutine.py index 41aafe7e2..bd6af5918 100644 --- a/pyteal/ast/subroutine.py +++ b/pyteal/ast/subroutine.py @@ -1,19 +1,9 @@ -from inspect import get_annotations, Parameter, signature -from typing import ( - Callable, - Dict, - List, - Optional, - Union, - TYPE_CHECKING, - Tuple, - cast, - Any, -) +from dataclasses import dataclass +from inspect import isclass, Parameter, signature, get_annotations from types import MappingProxyType +from typing import Callable, Optional, TYPE_CHECKING, cast, Any, Final from pyteal.ast import abi -from pyteal.ast.abi.type import TypeSpec from pyteal.ast.expr import Expr from pyteal.ast.seq import Seq from pyteal.ast.scratchvar import DynamicScratchVar, ScratchVar @@ -37,6 +27,7 @@ def __init__( implementation: Callable[..., Expr], return_type: TealType, name_str: Optional[str] = None, + has_abi_output: bool = False, ) -> None: """ Args: @@ -44,6 +35,7 @@ def __init__( return_type: the TealType to be returned by the subroutine name_str (optional): the name that is used to identify the subroutine. If omitted, the name defaults to the implementation's __name__ attribute + has_abi_output (optional): the boolean that tells if ABI output kwarg for subroutine is used. """ super().__init__() self.id = SubroutineDefinition.nextSubroutineId @@ -53,38 +45,35 @@ def __init__( self.declaration: Optional["SubroutineDeclaration"] = None self.implementation: Callable = implementation + self.has_abi_output: bool = has_abi_output + + self.implementation_params: MappingProxyType[str, Parameter] + self.annotations: dict[str, type] + self.expected_arg_types: list[type[Expr] | type[ScratchVar] | abi.TypeSpec] + self.by_ref_args: set[str] + self.abi_args: dict[str, abi.TypeSpec] + self.output_kwarg: dict[str, abi.TypeSpec] ( - impl_params, - annotations, - expected_arg_types, - by_ref_args, - abi_args, + self.implementation_params, + self.annotations, + self.expected_arg_types, + self.by_ref_args, + self.abi_args, + self.output_kwarg, ) = self._validate() - self.implementation_params: MappingProxyType[str, Parameter] = impl_params - self.annotations: dict[str, Expr | ScratchVar] = annotations - self.expected_arg_types: list[type | TypeSpec] = expected_arg_types - self.by_ref_args: set[str] = by_ref_args - self.abi_args: Dict[str, abi.TypeSpec] = abi_args self.__name: str = name_str if name_str else self.implementation.__name__ - @staticmethod - def is_abi_annotation(obj: Any) -> bool: - try: - abi.type_spec_from_annotation(obj) - return True - except TypeError: - return False - def _validate( self, input_types: list[TealType] = None ) -> tuple[ MappingProxyType[str, Parameter], - dict[str, Expr | ScratchVar], - list[type | TypeSpec], + dict[str, type], + list[type[Expr] | type[ScratchVar] | abi.TypeSpec], set[str], - Dict[str, abi.TypeSpec], + dict[str, abi.TypeSpec], + dict[str, abi.TypeSpec], ]: """Validate the full function signature and annotations for subroutine definition. @@ -108,35 +97,58 @@ def _validate( We load the ABI scratch space stored value to stack, and store them later in subroutine's local ABI values. Args: - input_types: optional, containing the TealType of input expression. + input_types (optional): for testing purposes - expected `TealType`s of each parameter + Returns: + impl_params: a map from python function implementation's argument name, to argument's parameter. + annotations: a dict whose keys are names of type-annotated arguments, + and values are appearing type-annotations. + arg_types: a list of argument type inferred from python function implementation, + containing [type[Expr]| type[ScratchVar] | abi.TypeSpec]. + by_ref_args: a list of argument names that are passed in Subroutine with by-reference mechanism. + abi_args: a dict whose keys are names of ABI arguments, and values are their ABI type-specs. + abi_output_kwarg (might be empty): a dict whose key is the name of ABI output keyword argument, + and the value is the corresponding ABI type-spec. + NOTE: this dict might be empty, when we are defining a normal subroutine, + but it has at most one element when we define an ABI-returning subroutine. """ - implementation = self.implementation - if not callable(implementation): + + if not callable(self.implementation): raise TealInputError("Input to SubroutineDefinition is not callable") - implementation_params: MappingProxyType[str, Parameter] = signature( - implementation + impl_params: MappingProxyType[str, Parameter] = signature( + self.implementation ).parameters - annotations: dict[str, Expr | ScratchVar] = get_annotations(implementation) - expected_arg_types: list[type | TypeSpec] = [] + annotations: dict[str, type] = get_annotations(self.implementation) + arg_types: list[type[Expr] | type[ScratchVar] | abi.TypeSpec] = [] by_ref_args: set[str] = set() abi_args: dict[str, abi.TypeSpec] = {} + abi_output_kwarg: dict[str, abi.TypeSpec] = {} - if input_types is not None and len(input_types) != len(implementation_params): - raise TealInputError( - f"Provided number of input_types ({len(input_types)}) " - f"does not match detected number of parameters ({len(implementation_params)})" - ) + if input_types is not None: + if len(input_types) != len(impl_params): + raise TealInputError( + f"Provided number of input_types ({len(input_types)}) " + f"does not match detected number of parameters ({len(impl_params)})" + ) + for in_type, name in zip(input_types, impl_params): + if not isinstance(in_type, TealType): + raise TealInputError( + f"Function has input type {in_type} for parameter {name} which is not a TealType" + ) if "return" in annotations and annotations["return"] is not Expr: raise TealInputError( f"Function has return of disallowed type {annotations['return']}. Only Expr is allowed" ) - for i, (name, param) in enumerate(implementation_params.items()): + for name, param in impl_params.items(): if param.kind not in ( Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD, + ) and not ( + param.kind is Parameter.KEYWORD_ONLY + and self.has_abi_output + and name == ABIReturnSubroutine.OUTPUT_ARG_NAME ): raise TealInputError( f"Function has a parameter type that is not allowed in a subroutine: parameter {name} with type {param.kind}" @@ -147,35 +159,47 @@ def _validate( f"Function has a parameter with a default value, which is not allowed in a subroutine: {name}" ) - if input_types: - intype = input_types[i] - if not isinstance(intype, TealType): + expected_arg_type = self._validate_annotation(annotations, name) + + if param.kind is Parameter.KEYWORD_ONLY: + # this case is only entered when + # - `self.has_abi_output is True` + # - `name == ABIReturnSubroutine.OUTPUT_ARG_NAME` + if not isinstance(expected_arg_type, abi.TypeSpec): raise TealInputError( - f"Function has input type {intype} for parameter {name} which is not a TealType" + f"Function keyword parameter {name} has type {expected_arg_type}" ) + abi_output_kwarg[name] = expected_arg_type + continue - expected_arg_type = self._validate_parameter_type(annotations, name) - - expected_arg_types.append(expected_arg_type) + arg_types.append(expected_arg_type) if expected_arg_type is ScratchVar: by_ref_args.add(name) if isinstance(expected_arg_type, abi.TypeSpec): abi_args[name] = expected_arg_type return ( - implementation_params, + impl_params, annotations, - expected_arg_types, + arg_types, by_ref_args, abi_args, + abi_output_kwarg, ) @staticmethod - def _validate_parameter_type( - user_defined_annotations: dict, parameter_name: str - ) -> type | TypeSpec: - ptype = user_defined_annotations.get(parameter_name, None) + def _is_abi_annotation(obj: Any) -> bool: + try: + abi.type_spec_from_annotation(obj) + return True + except TypeError: + return False + @staticmethod + def _validate_annotation( + user_defined_annotations: dict[str, Any], parameter_name: str + ) -> type[Expr] | type[ScratchVar] | abi.TypeSpec: + ptype = user_defined_annotations.get(parameter_name, None) if ptype is None: # Without a type annotation, `SubroutineDefinition` presumes an implicit `Expr` declaration # rather than these alternatives: @@ -188,40 +212,46 @@ def _validate_parameter_type( # when `Expr` is the only supported annotation type. # * `invoke` type checks provided arguments against parameter types to catch mismatches. return Expr - else: - if ptype in (Expr, ScratchVar): - return ptype - - if SubroutineDefinition.is_abi_annotation(ptype): - return abi.type_spec_from_annotation(ptype) - + if ptype in (Expr, ScratchVar): + return ptype + if SubroutineDefinition._is_abi_annotation(ptype): + return abi.type_spec_from_annotation(ptype) + if not isclass(ptype): raise TealInputError( - f"Function has parameter {parameter_name} of disallowed type {ptype}. " - f"Only the types {(Expr, ScratchVar, 'ABI')} are allowed" + f"Function has parameter {parameter_name} of declared type {ptype} which is not a class" ) + raise TealInputError( + f"Function has parameter {parameter_name} of disallowed type {ptype}. " + f"Only the types {(Expr, ScratchVar, 'ABI')} are allowed" + ) - def getDeclaration(self) -> "SubroutineDeclaration": + def get_declaration(self) -> "SubroutineDeclaration": if self.declaration is None: # lazy evaluate subroutine - self.declaration = evaluateSubroutine(self) + self.declaration = evaluate_subroutine(self) return self.declaration def name(self) -> str: return self.__name - def argumentCount(self) -> int: - return len(self.implementation_params) + def argument_count(self) -> int: + return len(self.arguments()) - def arguments(self) -> List[str]: - return list(self.implementation_params.keys()) + def arguments(self) -> list[str]: + syntax_args = list(self.implementation_params.keys()) + syntax_args = [ + arg_name for arg_name in syntax_args if arg_name not in self.output_kwarg + ] + return syntax_args def invoke( - self, args: List[Union[Expr, ScratchVar, abi.BaseType]] + self, + args: list[Expr | ScratchVar | abi.BaseType], ) -> "SubroutineCall": - if len(args) != self.argumentCount(): + if len(args) != self.argument_count(): raise TealInputError( f"Incorrect number of arguments for subroutine call. " - f"Expected {self.argumentCount()} arguments, got {len(args)}" + f"Expected {self.arguments()} arguments, got {len(args)} arguments" ) for i, arg in enumerate(args): @@ -245,7 +275,9 @@ def invoke( f"should have ABI typespec {arg_type} but got {arg.type_spec()}" ) - return SubroutineCall(self, args) + return SubroutineCall( + self, args, output_kwarg=OutputKwArgInfo.from_dict(self.output_kwarg) + ) def __str__(self): return f"subroutine#{self.id}" @@ -263,10 +295,16 @@ def __hash__(self): class SubroutineDeclaration(Expr): - def __init__(self, subroutine: SubroutineDefinition, body: Expr) -> None: + def __init__( + self, + subroutine: SubroutineDefinition, + body: Expr, + deferred_expr: Optional[Expr] = None, + ) -> None: super().__init__() self.subroutine = subroutine self.body = body + self.deferred_expr = deferred_expr def __teal__(self, options: "CompileOptions"): return self.body.__teal__(options) @@ -284,15 +322,39 @@ def has_return(self): SubroutineDeclaration.__module__ = "pyteal" +@dataclass +class OutputKwArgInfo: + name: str + abi_type: abi.TypeSpec + + @staticmethod + def from_dict(kwarg_info: dict[str, abi.TypeSpec]) -> Optional["OutputKwArgInfo"]: + match list(kwarg_info.keys()): + case []: + return None + case [k]: + return OutputKwArgInfo(k, kwarg_info[k]) + case _: + raise TealInputError( + f"illegal conversion kwarg_info length {len(kwarg_info)}." + ) + + +OutputKwArgInfo.__module__ = "pyteal" + + class SubroutineCall(Expr): def __init__( self, subroutine: SubroutineDefinition, - args: List[Union[Expr, ScratchVar, abi.BaseType]], + args: list[Expr | ScratchVar | abi.BaseType], + *, + output_kwarg: OutputKwArgInfo = None, ) -> None: super().__init__() self.subroutine = subroutine self.args = args + self.output_kwarg = output_kwarg for i, arg in enumerate(args): if isinstance(arg, Expr): @@ -315,7 +377,7 @@ def __teal__(self, options: "CompileOptions"): """ Generate the subroutine's start and end teal blocks. The subroutine's arguments are pushed on the stack to be picked up into local scratch variables. - There are 2 cases to consider for the pushed arg expression: + There are 4 cases to consider for the pushed arg expression: 1. (by-value) In the case of typical arguments of type Expr, the expression ITSELF is evaluated for the stack and will be stored in a local ScratchVar for subroutine evaluation @@ -325,6 +387,10 @@ def __teal__(self, options: "CompileOptions"): 3. (ABI, or a special case in by-value) In this case, the storage of an ABI value are loaded to the stack and will be stored in a local ABI value for subroutine evaluation + + 4. (ABI output keyword argument) In this case, we do not place ABI values (encoding) on the stack. + This is an *output-only* argument: in `evaluate_subroutine` an ABI typed instance for subroutine evaluation + will be generated, and gets in to construct the subroutine implementation. """ verifyTealVersion( Op.callsub.min_version, @@ -332,7 +398,7 @@ def __teal__(self, options: "CompileOptions"): "TEAL version too low to use SubroutineCall expression", ) - def handle_arg(arg: Union[Expr, ScratchVar, abi.BaseType]) -> Expr: + def handle_arg(arg: Expr | ScratchVar | abi.BaseType) -> Expr: if isinstance(arg, ScratchVar): return arg.index() elif isinstance(arg, Expr): @@ -345,14 +411,15 @@ def handle_arg(arg: Union[Expr, ScratchVar, abi.BaseType]) -> Expr: ) op = TealOp(self, Op.callsub, self.subroutine) - return TealBlock.FromOp(options, op, *(handle_arg(x) for x in self.args)) + return TealBlock.FromOp(options, op, *[handle_arg(x) for x in self.args]) def __str__(self): - ret_str = '(SubroutineCall "' + self.subroutine.name() + '" (' - for a in self.args: - ret_str += " " + a.__str__() - ret_str += "))" - return ret_str + arg_str_list = list(map(str, self.args)) + if self.output_kwarg: + arg_str_list.append( + f"{self.output_kwarg.name}={str(self.output_kwarg.abi_type)}" + ) + return f'(SubroutineCall {self.subroutine.name()} ({" ".join(arg_str_list)}))' def type_of(self): return self.subroutine.return_type @@ -377,7 +444,7 @@ def __init__( name_str=name, ) - def __call__(self, *args: Expr | ScratchVar | abi.BaseType, **kwargs) -> Expr: + def __call__(self, *args: Expr | ScratchVar | abi.BaseType, **kwargs: Any) -> Expr: if len(kwargs) != 0: raise TealInputError( f"Subroutine cannot be called with keyword arguments. " @@ -389,15 +456,137 @@ def name(self) -> str: return self.subroutine.name() def type_of(self): - return self.subroutine.getDeclaration().type_of() + return self.subroutine.get_declaration().type_of() def has_return(self): - return self.subroutine.getDeclaration().has_return() + return self.subroutine.get_declaration().has_return() SubroutineFnWrapper.__module__ = "pyteal" +class ABIReturnSubroutine: + """Used to create a PyTeal Subroutine (returning an ABI value) from a python function. + + This class is meant to be used as a function decorator. For example: + + .. code-block:: python + + @ABIReturnSubroutine + def abi_sum(toSum: abi.DynamicArray[abi.Uint64], *, output: abi.Uint64) -> Expr: + i = ScratchVar(TealType.uint64) + valueAtIndex = abi.Uint64() + return Seq( + output.set(0), + For(i.store(Int(0)), i.load() < toSum.length(), i.store(i.load() + Int(1))).Do( + Seq( + toSum[i.load()].store_into(valueAtIndex), + output.set(output.get() + valueAtIndex.get()), + ) + ), + ) + + program = Seq( + (to_sum_arr := abi.make(abi.DynamicArray[abi.Uint64])).decode( + Txn.application_args[1] + ), + (res := abi.Uint64()).set(abi_sum(to_sum_arr)), + abi.MethodReturn(res), + Int(1), + ) + """ + + OUTPUT_ARG_NAME: Final[str] = "output" + + def __init__( + self, + fn_implementation: Callable[..., Expr], + ) -> None: + self.output_kwarg_info: Optional[OutputKwArgInfo] = self._get_output_kwarg_info( + fn_implementation + ) + self.subroutine = SubroutineDefinition( + fn_implementation, + return_type=TealType.none, + has_abi_output=self.output_kwarg_info is not None, + ) + + @classmethod + def _get_output_kwarg_info( + cls, fn_implementation: Callable[..., Expr] + ) -> Optional[OutputKwArgInfo]: + if not callable(fn_implementation): + raise TealInputError("Input to ABIReturnSubroutine is not callable") + sig = signature(fn_implementation) + fn_annotations = get_annotations(fn_implementation) + + potential_abi_arg_names = [ + k for k, v in sig.parameters.items() if v.kind == Parameter.KEYWORD_ONLY + ] + + match potential_abi_arg_names: + case []: + return None + case [name]: + if name != cls.OUTPUT_ARG_NAME: + raise TealInputError( + f"ABI return subroutine output-kwarg name must be `output` at this moment, " + f"while {name} is the keyword." + ) + annotation = fn_annotations.get(name, None) + if annotation is None: + raise TealInputError( + f"ABI return subroutine output-kwarg {name} must specify ABI type" + ) + type_spec = abi.type_spec_from_annotation(annotation) + return OutputKwArgInfo(name, type_spec) + case _: + raise TealInputError( + f"multiple output arguments ({len(potential_abi_arg_names)}) " + f"with type annotations {potential_abi_arg_names}" + ) + + def __call__( + self, *args: Expr | ScratchVar | abi.BaseType, **kwargs + ) -> abi.ReturnedValue | Expr: + if len(kwargs) != 0: + raise TealInputError( + f"Subroutine cannot be called with keyword arguments. " + f"Received keyword arguments: {', '.join(kwargs.keys())}" + ) + + invoked = self.subroutine.invoke(list(args)) + if self.output_kwarg_info is None: + if invoked.type_of() != TealType.none: + raise TealInputError( + "ABI subroutine with void type should be evaluated to TealType.none" + ) + return invoked + + self.subroutine.get_declaration() + + return abi.ReturnedValue( + self.output_kwarg_info.abi_type, + invoked, + ) + + def name(self) -> str: + return self.subroutine.name() + + def type_of(self) -> str | abi.TypeSpec: + return ( + "void" + if self.output_kwarg_info is None + else self.output_kwarg_info.abi_type + ) + + def is_registrable(self) -> bool: + return len(self.subroutine.abi_args) == self.subroutine.argument_count() + + +ABIReturnSubroutine.__module__ = "pyteal" + + class Subroutine: """Used to create a PyTeal subroutine from a Python function. @@ -436,7 +625,7 @@ def __call__(self, fn_implementation: Callable[..., Expr]) -> SubroutineFnWrappe Subroutine.__module__ = "pyteal" -def evaluateSubroutine(subroutine: SubroutineDefinition) -> SubroutineDeclaration: +def evaluate_subroutine(subroutine: SubroutineDefinition) -> SubroutineDeclaration: """ Puts together the data necessary to define the code for a subroutine. "evaluate" is used here to connote evaluating the PyTEAL AST into a SubroutineDeclaration, @@ -446,7 +635,7 @@ def evaluateSubroutine(subroutine: SubroutineDefinition) -> SubroutineDeclaratio 2 Argument Usages / Code-Paths - -------- ------ ---------- - Usage (A) for run-time: "argumentVars" --reverse--> "bodyOps" + Usage (A) for run-time: "argumentVars" --reverse--> "body_ops" These are "store" expressions that pick up parameters that have been pre-placed on the stack prior to subroutine invocation. The argumentVars are stored into local scratch space to be used by the TEAL subroutine. @@ -461,53 +650,89 @@ def evaluateSubroutine(subroutine: SubroutineDefinition) -> SubroutineDeclaratio Type 1 (by-value): these have python type Expr Type 2 (by-reference): these have python type ScratchVar Type 3 (ABI): these are ABI typed variables with scratch space storage, and still pass by value + Type 4 (ABI-output-arg): ABI typed variables with scratch space, a new ABI instance is generated inside function body, + not one of the cases in the previous three options Usage (A) "argumentVars" - Storing pre-placed stack variables into local scratch space: Type 1. (by-value) use ScratchVar.store() to pick the actual value into a local scratch space Type 2. (by-reference) ALSO use ScratchVar.store() to pick up from the stack NOTE: SubroutineCall.__teal__() has placed the _SLOT INDEX_ on the stack so this is stored into the local scratch space Type 3. (ABI) abi_value.stored_value.store() to pick from the stack + Type 4. (ABI-output-arg) it is not really used here, since it is only generated internal of the subroutine Usage (B) "loadedArgs" - Passing through to an invoked PyTEAL subroutine AST: Type 1. (by-value) use ScratchVar.load() to have an Expr that can be compiled in python by the PyTEAL subroutine Type 2. (by-reference) use a DynamicScratchVar as the user will have written the PyTEAL in a way that satisfies the ScratchVar API. I.e., the user will write `x.load()` and `x.store(val)` as opposed to just `x`. Type 3. (ABI) use abi_value itself after storing stack value into scratch space. + Type 4. (ABI-output-arg) generates a new instance of the ABI value, + and appends a return expression of stored value of the ABI keyword value. """ def var_n_loaded( param: str, - ) -> Tuple[ScratchVar, Union[ScratchVar, abi.BaseType, Expr]]: - loaded: Union[ScratchVar, abi.BaseType, Expr] - argVar: ScratchVar + ) -> tuple[ScratchVar, ScratchVar | abi.BaseType | Expr]: + loaded_var: ScratchVar | abi.BaseType | Expr + argument_var: ScratchVar if param in subroutine.by_ref_args: - argVar = DynamicScratchVar(TealType.anytype) - loaded = argVar + argument_var = DynamicScratchVar(TealType.anytype) + loaded_var = argument_var elif param in subroutine.abi_args: internal_abi_var = subroutine.abi_args[param].new_instance() - argVar = internal_abi_var.stored_value - loaded = internal_abi_var + argument_var = internal_abi_var.stored_value + loaded_var = internal_abi_var else: - argVar = ScratchVar(TealType.anytype) - loaded = argVar.load() + argument_var = ScratchVar(TealType.anytype) + loaded_var = argument_var.load() - return argVar, loaded + return argument_var, loaded_var + + if len(subroutine.output_kwarg) > 1: + raise TealInputError( + f"ABI keyword argument num: {len(subroutine.output_kwarg)}. " + f"Exceeding abi output keyword argument max number 1." + ) args = subroutine.arguments() - argumentVars, loadedArgs = zip(*map(var_n_loaded, args)) if args else ([], []) + + arg_vars: list[ScratchVar] = [] + loaded_args: list[ScratchVar | Expr | abi.BaseType] = [] + for arg in args: + arg_var, loaded_arg = var_n_loaded(arg) + arg_vars.append(arg_var) + loaded_args.append(loaded_arg) + + abi_output_kwargs: dict[str, abi.BaseType] = {} + output_kwarg_info = OutputKwArgInfo.from_dict(subroutine.output_kwarg) + output_carrying_abi: Optional[abi.BaseType] = None + + if output_kwarg_info: + output_carrying_abi = output_kwarg_info.abi_type.new_instance() + abi_output_kwargs[output_kwarg_info.name] = output_carrying_abi # Arg usage "B" supplied to build an AST from the user-defined PyTEAL function: - subroutineBody = subroutine.implementation(*loadedArgs) + subroutine_body = subroutine.implementation(*loaded_args, **abi_output_kwargs) - if not isinstance(subroutineBody, Expr): + if not isinstance(subroutine_body, Expr): raise TealInputError( - f"Subroutine function does not return a PyTeal expression. Got type {type(subroutineBody)}" + f"Subroutine function does not return a PyTeal expression. Got type {type(subroutine_body)}." ) + deferred_expr: Optional[Expr] = None + + # if there is an output keyword argument for ABI, place the storing on the stack + if output_carrying_abi: + if subroutine_body.type_of() != TealType.none: + raise TealInputError( + f"ABI returning subroutine definition should evaluate to TealType.none, " + f"while evaluate to {subroutine_body.type_of()}." + ) + deferred_expr = output_carrying_abi.stored_value.load() + # Arg usage "A" to be pick up and store in scratch parameters that have been placed on the stack # need to reverse order of argumentVars because the last argument will be on top of the stack - bodyOps = [var.slot.store() for var in argumentVars[::-1]] - bodyOps.append(subroutineBody) + body_ops = [var.slot.store() for var in arg_vars[::-1]] + body_ops.append(subroutine_body) - return SubroutineDeclaration(subroutine, Seq(bodyOps)) + return SubroutineDeclaration(subroutine, Seq(body_ops), deferred_expr) diff --git a/pyteal/ast/subroutine_test.py b/pyteal/ast/subroutine_test.py index 77ba08922..0a90f33ee 100644 --- a/pyteal/ast/subroutine_test.py +++ b/pyteal/ast/subroutine_test.py @@ -2,11 +2,12 @@ from typing import List, Literal import pytest +from dataclasses import dataclass import pyteal as pt -from pyteal.ast.subroutine import evaluateSubroutine +from pyteal.ast.subroutine import evaluate_subroutine -options = pt.CompileOptions(version=4) +options = pt.CompileOptions(version=5) def test_subroutine_definition(): @@ -58,7 +59,7 @@ def fnWithPartialExprAnnotations(a, b: pt.Expr) -> pt.Expr: for (fn, numArgs, name) in cases: definition = pt.SubroutineDefinition(fn, pt.TealType.none) - assert definition.argumentCount() == numArgs + assert definition.argument_count() == numArgs assert definition.name() == name if numArgs > 0: @@ -79,15 +80,130 @@ def fnWithPartialExprAnnotations(a, b: pt.Expr) -> pt.Expr: assert invocation.args == args +@dataclass +class ABISubroutineTC: + definition: pt.ABIReturnSubroutine + arg_instances: list[pt.Expr | pt.abi.BaseType] + name: str + ret_type: str | pt.abi.TypeSpec + + +def test_abi_subroutine_definition(): + @pt.ABIReturnSubroutine + def fn_0arg_0ret() -> pt.Expr: + return pt.Return() + + @pt.ABIReturnSubroutine + def fn_0arg_uint64_ret(*, output: pt.abi.Uint64) -> pt.Expr: + return output.set(1) + + @pt.ABIReturnSubroutine + def fn_1arg_0ret(a: pt.abi.Uint64) -> pt.Expr: + return pt.Return() + + @pt.ABIReturnSubroutine + def fn_1arg_1ret(a: pt.abi.Uint64, *, output: pt.abi.Uint64) -> pt.Expr: + return output.set(a) + + @pt.ABIReturnSubroutine + def fn_2arg_0ret( + a: pt.abi.Uint64, b: pt.abi.StaticArray[pt.abi.Byte, Literal[10]] + ) -> pt.Expr: + return pt.Return() + + @pt.ABIReturnSubroutine + def fn_2arg_1ret( + a: pt.abi.Uint64, + b: pt.abi.StaticArray[pt.abi.Byte, Literal[10]], + *, + output: pt.abi.Byte, + ) -> pt.Expr: + return output.set(b[a.get() % pt.Int(10)]) + + @pt.ABIReturnSubroutine + def fn_2arg_1ret_with_expr( + a: pt.Expr, + b: pt.abi.StaticArray[pt.abi.Byte, Literal[10]], + *, + output: pt.abi.Byte, + ) -> pt.Expr: + return output.set(b[a % pt.Int(10)]) + + cases = ( + ABISubroutineTC(fn_0arg_0ret, [], "fn_0arg_0ret", "void"), + ABISubroutineTC( + fn_0arg_uint64_ret, [], "fn_0arg_uint64_ret", pt.abi.Uint64TypeSpec() + ), + ABISubroutineTC(fn_1arg_0ret, [pt.abi.Uint64()], "fn_1arg_0ret", "void"), + ABISubroutineTC( + fn_1arg_1ret, [pt.abi.Uint64()], "fn_1arg_1ret", pt.abi.Uint64TypeSpec() + ), + ABISubroutineTC( + fn_2arg_0ret, + [ + pt.abi.Uint64(), + pt.abi.StaticArray( + pt.abi.StaticArrayTypeSpec(pt.abi.ByteTypeSpec(), 10) + ), + ], + "fn_2arg_0ret", + "void", + ), + ABISubroutineTC( + fn_2arg_1ret, + [ + pt.abi.Uint64(), + pt.abi.StaticArray( + pt.abi.StaticArrayTypeSpec(pt.abi.ByteTypeSpec(), 10) + ), + ], + "fn_2arg_1ret", + pt.abi.ByteTypeSpec(), + ), + ABISubroutineTC( + fn_2arg_1ret_with_expr, + [ + pt.Int(5), + pt.abi.StaticArray( + pt.abi.StaticArrayTypeSpec(pt.abi.ByteTypeSpec(), 10) + ), + ], + "fn_2arg_1ret_with_expr", + pt.abi.ByteTypeSpec(), + ), + ) + + for case in cases: + assert case.definition.subroutine.argument_count() == len(case.arg_instances) + assert case.definition.name() == case.name + + if len(case.arg_instances) > 0: + with pytest.raises(pt.TealInputError): + case.definition(*case.arg_instances[:-1]) + + with pytest.raises(pt.TealInputError): + case.definition(*(case.arg_instances + [pt.abi.Uint64()])) + + assert case.definition.type_of() == case.ret_type + invoked = case.definition(*case.arg_instances) + assert isinstance( + invoked, (pt.Expr if case.ret_type == "void" else pt.abi.ReturnedValue) + ) + assert case.definition.is_registrable() == all( + map(lambda x: isinstance(x, pt.abi.BaseType), case.arg_instances) + ) + + def test_subroutine_definition_validate(): """ DFS through SubroutineDefinition.validate()'s logic """ - def mock_subroutine_definition(implementation): + def mock_subroutine_definition(implementation, has_abi_output=False): mock = pt.SubroutineDefinition(lambda: pt.Return(pt.Int(1)), pt.TealType.uint64) mock._validate() # haven't failed with dummy implementation mock.implementation = implementation + mock.has_abi_output = has_abi_output return mock not_callable = mock_subroutine_definition("I'm not callable") @@ -98,6 +214,8 @@ def mock_subroutine_definition(implementation): "Input to SubroutineDefinition is not callable" ) + # input_types: + three_params = mock_subroutine_definition(lambda x, y, z: pt.Return(pt.Int(1))) two_inputs = [pt.TealType.uint64, pt.TealType.bytes] with pytest.raises(pt.TealInputError) as tie: @@ -107,14 +225,22 @@ def mock_subroutine_definition(implementation): "Provided number of input_types (2) does not match detected number of parameters (3)" ) - params, anns, arg_types, byrefs, abis = three_params._validate() + three_inputs_with_a_wrong_type = [pt.TealType.uint64, pt.Expr, pt.TealType.bytes] + + with pytest.raises(pt.TealInputError) as tie: + three_params._validate(input_types=three_inputs_with_a_wrong_type) + + assert tie.value == pt.TealInputError( + "Function has input type for parameter y which is not a TealType" + ) + + params, anns, arg_types, byrefs, abi_args, output_kwarg = three_params._validate() assert len(params) == 3 assert anns == {} assert all(at is pt.Expr for at in arg_types) assert byrefs == set() - assert abis == {} - - # return validation: + assert abi_args == {} + assert output_kwarg == {} def bad_return_impl() -> str: return pt.Return(pt.Int(1)) # type: ignore @@ -127,25 +253,31 @@ def bad_return_impl() -> str: "Function has return of disallowed type . Only Expr is allowed" ) - # param validation: + # now we iterate through the implementation params validating each as we go - var_positional_or_kw = three_params - params, anns, arg_types, byrefs, abis = var_positional_or_kw._validate() - assert len(params) == 3 - assert anns == {} - assert all(at is pt.Expr for at in arg_types) - assert byrefs == set() - assert abis == {} + def var_abi_output_impl(*, output: pt.abi.Uint16): + pt.Return(pt.Int(1)) # this is wrong but ignored - var_positional_only = mock_subroutine_definition( - lambda x, y, /, z: pt.Return(pt.Int(1)) + # raises without abi_output_arg_name: + var_abi_output_noname = mock_subroutine_definition(var_abi_output_impl) + with pytest.raises(pt.TealInputError) as tie: + var_abi_output_noname._validate() + + assert tie.value == pt.TealInputError( + "Function has a parameter type that is not allowed in a subroutine: parameter output with type KEYWORD_ONLY" ) - params, anns, arg_types, byrefs, abis = var_positional_only._validate() - assert len(params) == 3 - assert anns == {} + + # copacetic abi output: + var_abi_output = mock_subroutine_definition( + var_abi_output_impl, has_abi_output=True + ) + params, anns, arg_types, byrefs, abi_args, output_kwarg = var_abi_output._validate() + assert len(params) == 1 + assert anns == {"output": pt.abi.Uint16} assert all(at is pt.Expr for at in arg_types) assert byrefs == set() - assert abis == {} + assert abi_args == {} + assert output_kwarg == {"output": pt.abi.Uint16TypeSpec()} var_positional = mock_subroutine_definition(lambda *args: pt.Return(pt.Int(1))) with pytest.raises(pt.TealInputError) as tie: @@ -188,38 +320,44 @@ def bad_return_impl() -> str: "Function has input type for parameter y which is not a TealType" ) - # Now we get to _validate_parameter_type(): + # Now we get to _validate_annotation(): one_vanilla = mock_subroutine_definition(lambda x: pt.Return(pt.Int(1))) - params, anns, arg_types, byrefs, abis = one_vanilla._validate() + params, anns, arg_types, byrefs, abi_args, output_kwarg = one_vanilla._validate() assert len(params) == 1 assert anns == {} assert all(at is pt.Expr for at in arg_types) assert byrefs == set() - assert abis == {} + assert abi_args == {} + assert output_kwarg == {} def one_expr_impl(x: pt.Expr): return pt.Return(pt.Int(1)) one_expr = mock_subroutine_definition(one_expr_impl) - params, anns, arg_types, byrefs, abis = one_expr._validate() + params, anns, arg_types, byrefs, abi_args, output_kwarg = one_expr._validate() assert len(params) == 1 assert anns == {"x": pt.Expr} assert all(at is pt.Expr for at in arg_types) assert byrefs == set() - assert abis == {} + assert abi_args == {} + assert output_kwarg == {} def one_scratchvar_impl(x: pt.ScratchVar): return pt.Return(pt.Int(1)) one_scratchvar = mock_subroutine_definition(one_scratchvar_impl) - params, anns, arg_types, byrefs, abis = one_scratchvar._validate() + params, anns, arg_types, byrefs, abi_args, output_kwarg = one_scratchvar._validate() assert len(params) == 1 assert anns == {"x": pt.ScratchVar} assert all(at is pt.ScratchVar for at in arg_types) assert byrefs == {"x"} - assert abis == {} + assert abi_args == {} + assert output_kwarg == {} + + # for _is_abi_annotation() cf. copacetic x,y,z product below + # not is_class() def one_nontype_impl(x: "blahBlah"): # type: ignore # noqa: F821 return pt.Return(pt.Int(1)) @@ -228,7 +366,7 @@ def one_nontype_impl(x: "blahBlah"): # type: ignore # noqa: F821 one_nontype._validate() assert tie.value == pt.TealInputError( - "Function has parameter x of disallowed type blahBlah. Only the types (, , 'ABI') are allowed" + "Function has parameter x of declared type blahBlah which is not a class" ) def one_dynscratchvar_impl(x: pt.DynamicScratchVar): @@ -243,17 +381,21 @@ def one_dynscratchvar_impl(x: pt.DynamicScratchVar): ) # Now we're back to validate() and everything should be copacetic - - # input type handling: for x, y, z in product(pt.TealType, pt.TealType, pt.TealType): - params, anns, arg_types, byrefs, abis = three_params._validate( - input_types=[x, y, z] - ) + ( + params, + anns, + arg_types, + byrefs, + abi_args, + output_kwarg, + ) = three_params._validate(input_types=[x, y, z]) assert len(params) == 3 assert anns == {} assert all(at is pt.Expr for at in arg_types) assert byrefs == set() - assert abis == {} + assert abi_args == {} + assert output_kwarg == {} # annotation / abi type handling: abi_annotation_examples = { @@ -283,9 +425,9 @@ def mocker_impl(x: x_ann, y, z: z_ann): return pt.Return(pt.Int(1)) mocker = mock_subroutine_definition(mocker_impl) - params, anns, arg_types, byrefs, abis = mocker._validate() + params, anns, arg_types, byrefs, abis, output_kwarg = mocker._validate() print( - f"{x_ann=}, {z_ann=}, {params=}, {anns=}, {arg_types=}, {byrefs=}, {abis=}" + f"{x_ann=}, {z_ann=}, {params=}, {anns=}, {arg_types=}, {byrefs=}, {abis=}, {output_kwarg=}" ) assert len(params) == 3 @@ -451,7 +593,7 @@ def fnWithMixedAnns4(a: pt.ScratchVar, b, c: pt.abi.Uint16) -> pt.Expr: ] for case_name, fn, args, err in cases: definition = pt.SubroutineDefinition(fn, pt.TealType.none) - assert definition.argumentCount() == len(args), case_name + assert definition.argument_count() == len(args), case_name assert definition.name() == fn.__name__, case_name if err is None: @@ -475,11 +617,220 @@ def fnWithMixedAnns4(a: pt.ScratchVar, b, c: pt.abi.Uint16) -> pt.Expr: ), f"EXPECTED ERROR of type {err}. encountered unexpected error during invocation case <{case_name}>: {e}" +def test_abi_subroutine_calling_param_types(): + @pt.ABIReturnSubroutine + def fn_log_add(a: pt.abi.Uint64, b: pt.abi.Uint32) -> pt.Expr: + return pt.Seq(pt.Log(pt.Itob(a.get() + b.get())), pt.Return()) + + @pt.ABIReturnSubroutine + def fn_ret_add( + a: pt.abi.Uint64, b: pt.abi.Uint32, *, output: pt.abi.Uint64 + ) -> pt.Expr: + return output.set(a.get() + b.get() + pt.Int(0xA190)) + + @pt.ABIReturnSubroutine + def fn_abi_annotations_0( + a: pt.abi.Byte, + b: pt.abi.StaticArray[pt.abi.Uint32, Literal[10]], + c: pt.abi.DynamicArray[pt.abi.Bool], + ) -> pt.Expr: + return pt.Return() + + @pt.ABIReturnSubroutine + def fn_abi_annotations_0_with_ret( + a: pt.abi.Byte, + b: pt.abi.StaticArray[pt.abi.Uint32, Literal[10]], + c: pt.abi.DynamicArray[pt.abi.Bool], + *, + output: pt.abi.Byte, + ): + return output.set(a) + + @pt.ABIReturnSubroutine + def fn_mixed_annotations_0(a: pt.ScratchVar, b: pt.Expr, c: pt.abi.Byte) -> pt.Expr: + return pt.Seq( + a.store(c.get() * pt.Int(0x0FF1CE) * b), + pt.Return(), + ) + + @pt.ABIReturnSubroutine + def fn_mixed_annotations_0_with_ret( + a: pt.ScratchVar, b: pt.Expr, c: pt.abi.Byte, *, output: pt.abi.Uint64 + ) -> pt.Expr: + return pt.Seq( + a.store(c.get() * pt.Int(0x0FF1CE) * b), + output.set(a.load()), + ) + + @pt.ABIReturnSubroutine + def fn_mixed_annotation_1( + a: pt.ScratchVar, b: pt.abi.StaticArray[pt.abi.Uint32, Literal[10]] + ) -> pt.Expr: + return pt.Seq( + (intermediate := pt.abi.Uint32()).set(b[a.load() % pt.Int(10)]), + a.store(intermediate.get()), + pt.Return(), + ) + + @pt.ABIReturnSubroutine + def fn_mixed_annotation_1_with_ret( + a: pt.ScratchVar, b: pt.abi.Uint64, *, output: pt.abi.Bool + ) -> pt.Expr: + return output.set((a.load() + b.get()) % pt.Int(2)) + + abi_u64 = pt.abi.Uint64() + abi_u32 = pt.abi.Uint32() + abi_byte = pt.abi.Byte() + abi_static_u32_10 = pt.abi.StaticArray( + pt.abi.StaticArrayTypeSpec(pt.abi.Uint32TypeSpec(), 10) + ) + abi_dynamic_bool = pt.abi.DynamicArray( + pt.abi.DynamicArrayTypeSpec(pt.abi.BoolTypeSpec()) + ) + sv = pt.ScratchVar() + expr_int = pt.Int(1) + + cases = [ + ("vanilla 1", fn_log_add, [abi_u64, abi_u32], "void", None), + ( + "vanilla 1 with wrong ABI type", + fn_log_add, + [abi_u64, abi_u64], + None, + pt.TealInputError, + ), + ( + "vanilla 1 with ABI return", + fn_ret_add, + [abi_u64, abi_u32], + pt.abi.Uint64TypeSpec(), + None, + ), + ( + "vanilla 1 with ABI return wrong typed", + fn_ret_add, + [abi_u32, abi_u64], + None, + pt.TealInputError, + ), + ( + "full ABI annotations no return", + fn_abi_annotations_0, + [abi_byte, abi_static_u32_10, abi_dynamic_bool], + "void", + None, + ), + ( + "full ABI annotations wrong input 0", + fn_abi_annotations_0, + [abi_u64, abi_static_u32_10, abi_dynamic_bool], + None, + pt.TealInputError, + ), + ( + "full ABI annotations with ABI return", + fn_abi_annotations_0_with_ret, + [abi_byte, abi_static_u32_10, abi_dynamic_bool], + pt.abi.ByteTypeSpec(), + None, + ), + ( + "full ABI annotations with ABI return wrong inputs", + fn_abi_annotations_0_with_ret, + [abi_byte, abi_dynamic_bool, abi_static_u32_10], + None, + pt.TealInputError, + ), + ( + "mixed with ABI annotations 0", + fn_mixed_annotations_0, + [sv, expr_int, abi_byte], + "void", + None, + ), + ( + "mixed with ABI annotations 0 wrong inputs", + fn_mixed_annotations_0, + [abi_u64, expr_int, abi_byte], + None, + pt.TealInputError, + ), + ( + "mixed with ABI annotations 0 with ABI return", + fn_mixed_annotations_0_with_ret, + [sv, expr_int, abi_byte], + pt.abi.Uint64TypeSpec(), + None, + ), + ( + "mixed with ABI annotations 0 with ABI return wrong inputs", + fn_mixed_annotations_0_with_ret, + [sv, expr_int, sv], + None, + pt.TealInputError, + ), + ( + "mixed with ABI annotations 1", + fn_mixed_annotation_1, + [sv, abi_static_u32_10], + "void", + None, + ), + ( + "mixed with ABI annotations 1 with ABI return", + fn_mixed_annotation_1_with_ret, + [sv, abi_u64], + pt.abi.BoolTypeSpec(), + None, + ), + ( + "mixed with ABI annotations 1 with ABI return wrong inputs", + fn_mixed_annotation_1_with_ret, + [expr_int, abi_static_u32_10], + None, + pt.TealInputError, + ), + ] + + for case_name, definition, args, ret_type, err in cases: + assert definition.subroutine.argument_count() == len(args), case_name + assert ( + definition.name() == definition.subroutine.implementation.__name__ + ), case_name + + if err is None: + invocation = definition(*args) + if ret_type == "void": + assert isinstance(invocation, pt.SubroutineCall), case_name + assert not invocation.has_return(), case_name + assert invocation.args == args, case_name + else: + assert isinstance(invocation, pt.abi.ReturnedValue), case_name + assert invocation.type_spec == ret_type + assert isinstance(invocation.computation, pt.SubroutineCall), case_name + assert not invocation.computation.has_return(), case_name + assert invocation.computation.args == args, case_name + else: + try: + with pytest.raises(err): + definition(*args) + except Exception as e: + assert ( + not e + ), f"EXPECTED ERROR of type {err}. encountered unexpected error during invocation case <{case_name}>: {e}" + + def test_subroutine_definition_invalid(): def fnWithDefaults(a, b=None): return pt.Return() - def fnWithKeywordArgs(a, *, b): + def fnWithKeywordArgs(a, *, output): + return pt.Return() + + def fnWithKeywordArgsWrongKWName(a, *, b: pt.abi.Uint64): + return pt.Return() + + def fnWithMultipleABIKeywordArgs(a, *, b: pt.abi.Byte, c: pt.abi.Bool): return pt.Return() def fnWithVariableArgs(a, *b): @@ -513,57 +864,92 @@ def fnWithMixedAnnsABIRet2( return pt.abi.Uint64() cases = ( - (1, "TealInputError('Input to SubroutineDefinition is not callable'"), - (None, "TealInputError('Input to SubroutineDefinition is not callable'"), + ( + 1, + "TealInputError('Input to SubroutineDefinition is not callable'", + "TealInputError('Input to ABIReturnSubroutine is not callable'", + ), + ( + None, + "TealInputError('Input to SubroutineDefinition is not callable'", + "TealInputError('Input to ABIReturnSubroutine is not callable'", + ), ( fnWithDefaults, "TealInputError('Function has a parameter with a default value, which is not allowed in a subroutine: b'", + "TealInputError('Function has a parameter with a default value, which is not allowed in a subroutine: b'", ), ( fnWithKeywordArgs, + "TealInputError('Function has a parameter type that is not allowed in a subroutine: parameter output with type", + "TealInputError('ABI return subroutine output-kwarg output must specify ABI type')", + ), + ( + fnWithKeywordArgsWrongKWName, + "TealInputError('Function has a parameter type that is not allowed in a subroutine: parameter b with type", + "TealInputError('ABI return subroutine output-kwarg name must be `output` at this moment", + ), + ( + fnWithMultipleABIKeywordArgs, "TealInputError('Function has a parameter type that is not allowed in a subroutine: parameter b with type", + "multiple output arguments (2) with type annotations", ), ( fnWithVariableArgs, "TealInputError('Function has a parameter type that is not allowed in a subroutine: parameter b with type", + "Function has a parameter type that is not allowed in a subroutine: parameter b with type VAR_POSITIONAL", ), ( fnWithNonExprReturnAnnotation, "Function has return of disallowed type TealType.uint64. Only Expr is allowed", + "Function has return of disallowed type TealType.uint64. Only Expr is allowed", ), ( fnWithNonExprParamAnnotation, - "Function has parameter b of disallowed type TealType.uint64. Only the types", + "Function has parameter b of declared type TealType.uint64 which is not a class", + "Function has parameter b of declared type TealType.uint64 which is not a class", ), ( fnWithScratchVarSubclass, "Function has parameter b of disallowed type ", + "Function has parameter b of disallowed type ", ), ( fnReturningExprSubclass, "Function has return of disallowed type ", + "Function has return of disallowed type . Only Expr is allowed", ), ( fnWithMixedAnns4AndBytesReturn, "Function has return of disallowed type ", + "Function has return of disallowed type . Only Expr is allowed", ), ( fnWithMixedAnnsABIRet1, "Function has return of disallowed type pyteal.StaticArray[pyteal.Uint32, typing.Literal[10]]. " "Only Expr is allowed", + "Function has return of disallowed type pyteal.StaticArray[pyteal.Uint32, typing.Literal[10]]. " + "Only Expr is allowed", ), ( fnWithMixedAnnsABIRet2, "Function has return of disallowed type . Only Expr is allowed", + "Function has return of disallowed type . Only Expr is allowed", ), ) - for fn, msg in cases: + for fn, sub_def_msg, abi_sub_def_msg in cases: with pytest.raises(pt.TealInputError) as e: - print(f"case=[{msg}]") + print(f"case=[{sub_def_msg}]") pt.SubroutineDefinition(fn, pt.TealType.none) - assert msg in str(e), "failed for case [{}]".format(fn.__name__) + assert sub_def_msg in str(e), f"failed for case [{fn.__name__}]" + + with pytest.raises(pt.TealInputError) as e: + print(f"case=[{abi_sub_def_msg}]") + pt.ABIReturnSubroutine(fn) + + assert abi_sub_def_msg in str(e), f"failed for case[{fn.__name__}]" def test_subroutine_declaration(): @@ -670,8 +1056,8 @@ def mySubroutine(): return returnValue definition = pt.SubroutineDefinition(mySubroutine, returnType) + declaration = evaluate_subroutine(definition) - declaration = evaluateSubroutine(definition) assert isinstance(declaration, pt.SubroutineDeclaration) assert declaration.subroutine is definition @@ -704,8 +1090,8 @@ def mySubroutine(a1): return returnValue definition = pt.SubroutineDefinition(mySubroutine, returnType) + declaration = evaluate_subroutine(definition) - declaration = evaluateSubroutine(definition) assert isinstance(declaration, pt.SubroutineDeclaration) assert declaration.subroutine is definition @@ -748,7 +1134,8 @@ def mySubroutine(a1, a2): definition = pt.SubroutineDefinition(mySubroutine, returnType) - declaration = evaluateSubroutine(definition) + declaration = evaluate_subroutine(definition) + assert isinstance(declaration, pt.SubroutineDeclaration) assert declaration.subroutine is definition @@ -793,8 +1180,8 @@ def mySubroutine(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10): return returnValue definition = pt.SubroutineDefinition(mySubroutine, returnType) + declaration = evaluate_subroutine(definition) - declaration = evaluateSubroutine(definition) assert isinstance(declaration, pt.SubroutineDeclaration) assert declaration.subroutine is definition diff --git a/pyteal/compiler/compiler.py b/pyteal/compiler/compiler.py index d5a0174b2..f09eb2c95 100644 --- a/pyteal/compiler/compiler.py +++ b/pyteal/compiler/compiler.py @@ -10,7 +10,7 @@ SubroutineDefinition, SubroutineDeclaration, ) -from pyteal.ir import Mode, TealComponent, TealOp, TealBlock, TealSimpleBlock +from pyteal.ir import Mode, Op, TealComponent, TealOp, TealBlock, TealSimpleBlock from pyteal.errors import TealInputError, TealInternalError from pyteal.compiler.sort import sortBlocks @@ -133,26 +133,60 @@ def compileSubroutine( ast = Return(ast) options.setSubroutine(currentSubroutine) + start, end = ast.__teal__(options) start.addIncoming() start.validateTree() - start = TealBlock.NormalizeBlocks(start) - start.validateTree() + if ( + currentSubroutine is not None + and currentSubroutine.get_declaration().deferred_expr is not None + ): + # this represents code that should be inserted before each retsub op + deferred_expr = cast(Expr, currentSubroutine.get_declaration().deferred_expr) + + for block in TealBlock.Iterate(start): + if not any(op.getOp() == Op.retsub for op in block.ops): + continue + + if len(block.ops) != 1: + # we expect all retsub ops to be in their own block at this point since + # TealBlock.NormalizeBlocks has not yet been used + raise TealInternalError( + f"Expected retsub to be the only op in the block, but there are {len(block.ops)} ops" + ) - order = sortBlocks(start, end) - teal = flattenBlocks(order) + # we invoke __teal__ here and not outside of this loop because the same block cannot be + # added in multiple places to the control flow graph + deferred_start, deferred_end = deferred_expr.__teal__(options) + deferred_start.addIncoming() + deferred_start.validateTree() - verifyOpsForVersion(teal, options.version) - verifyOpsForMode(teal, options.mode) + # insert deferred blocks between the previous block(s) and this one + deferred_start.incoming = block.incoming + block.incoming = [deferred_end] + deferred_end.nextBlock = block + + for prev in deferred_start.incoming: + prev.replaceOutgoing(block, deferred_start) + + if block is start: + # this is the start block, replace start + start = deferred_start + + start.validateTree() + + start = TealBlock.NormalizeBlocks(start) + start.validateTree() subroutine_start_blocks[currentSubroutine] = start subroutine_end_blocks[currentSubroutine] = end referencedSubroutines: Set[SubroutineDefinition] = set() - for stmt in teal: - for subroutine in stmt.getSubroutines(): - referencedSubroutines.add(subroutine) + for block in TealBlock.Iterate(start): + for stmt in block.ops: + for subroutine in stmt.getSubroutines(): + referencedSubroutines.add(subroutine) if currentSubroutine is not None: subroutineGraph[currentSubroutine] = referencedSubroutines @@ -160,7 +194,7 @@ def compileSubroutine( newSubroutines = referencedSubroutines - subroutine_start_blocks.keys() for subroutine in sorted(newSubroutines, key=lambda subroutine: subroutine.id): compileSubroutine( - subroutine.getDeclaration(), + subroutine.get_declaration(), options, subroutineGraph, subroutine_start_blocks, @@ -256,6 +290,9 @@ def compileTeal( subroutineLabels = resolveSubroutines(subroutineMapping) teal = flattenSubroutines(subroutineMapping, subroutineLabels) + verifyOpsForVersion(teal, options.version) + verifyOpsForMode(teal, options.mode) + if assembleConstants: if version < 3: raise TealInternalError( diff --git a/pyteal/compiler/compiler_test.py b/pyteal/compiler/compiler_test.py index 0d6fdbe09..a89f075fc 100644 --- a/pyteal/compiler/compiler_test.py +++ b/pyteal/compiler/compiler_test.py @@ -1659,6 +1659,194 @@ def storeValue(key: pt.Expr, t1: pt.Expr, t2: pt.Expr, t3: pt.Expr) -> pt.Expr: assert actual == expected +def test_compile_subroutine_deferred_expr(): + @pt.Subroutine(pt.TealType.none) + def deferredExample(value: pt.Expr) -> pt.Expr: + return pt.Seq( + pt.If(value == pt.Int(0)).Then(pt.Return()), + pt.If(value == pt.Int(1)).Then(pt.Approve()), + pt.If(value == pt.Int(2)).Then(pt.Reject()), + pt.If(value == pt.Int(3)).Then(pt.Err()), + ) + + program = pt.Seq(deferredExample(pt.Int(10)), pt.Approve()) + + expected_no_deferred = """#pragma version 6 +int 10 +callsub deferredExample_0 +int 1 +return + +// deferredExample +deferredExample_0: +store 0 +load 0 +int 0 +== +bnz deferredExample_0_l7 +load 0 +int 1 +== +bnz deferredExample_0_l6 +load 0 +int 2 +== +bnz deferredExample_0_l5 +load 0 +int 3 +== +bz deferredExample_0_l8 +err +deferredExample_0_l5: +int 0 +return +deferredExample_0_l6: +int 1 +return +deferredExample_0_l7: +retsub +deferredExample_0_l8: +retsub + """.strip() + actual_no_deferred = pt.compileTeal( + program, pt.Mode.Application, version=6, assembleConstants=False + ) + assert actual_no_deferred == expected_no_deferred + + # manually add deferred expression to SubroutineDefinition + declaration = deferredExample.subroutine.get_declaration() + declaration.deferred_expr = pt.Pop(pt.Bytes("deferred")) + + expected_deferred = """#pragma version 6 +int 10 +callsub deferredExample_0 +int 1 +return + +// deferredExample +deferredExample_0: +store 0 +load 0 +int 0 +== +bnz deferredExample_0_l7 +load 0 +int 1 +== +bnz deferredExample_0_l6 +load 0 +int 2 +== +bnz deferredExample_0_l5 +load 0 +int 3 +== +bz deferredExample_0_l8 +err +deferredExample_0_l5: +int 0 +return +deferredExample_0_l6: +int 1 +return +deferredExample_0_l7: +byte "deferred" +pop +retsub +deferredExample_0_l8: +byte "deferred" +pop +retsub + """.strip() + actual_deferred = pt.compileTeal( + program, pt.Mode.Application, version=6, assembleConstants=False + ) + assert actual_deferred == expected_deferred + + +def test_compile_subroutine_deferred_expr_empty(): + @pt.Subroutine(pt.TealType.none) + def empty() -> pt.Expr: + return pt.Return() + + program = pt.Seq(empty(), pt.Approve()) + + expected_no_deferred = """#pragma version 6 +callsub empty_0 +int 1 +return + +// empty +empty_0: +retsub + """.strip() + actual_no_deferred = pt.compileTeal( + program, pt.Mode.Application, version=6, assembleConstants=False + ) + assert actual_no_deferred == expected_no_deferred + + # manually add deferred expression to SubroutineDefinition + declaration = empty.subroutine.get_declaration() + declaration.deferred_expr = pt.Pop(pt.Bytes("deferred")) + + expected_deferred = """#pragma version 6 +callsub empty_0 +int 1 +return + +// empty +empty_0: +byte "deferred" +pop +retsub + """.strip() + actual_deferred = pt.compileTeal( + program, pt.Mode.Application, version=6, assembleConstants=False + ) + assert actual_deferred == expected_deferred + + +def test_compileSubroutine_deferred_block_malformed(): + class BadRetsub(pt.Expr): + def type_of(self) -> pt.TealType: + return pt.TealType.none + + def has_return(self) -> bool: + return True + + def __str__(self) -> str: + return "(BadRetsub)" + + def __teal__( + self, options: pt.CompileOptions + ) -> tuple[pt.TealBlock, pt.TealSimpleBlock]: + block = pt.TealSimpleBlock( + [ + pt.TealOp(self, pt.Op.int, 1), + pt.TealOp(self, pt.Op.pop), + pt.TealOp(self, pt.Op.retsub), + ] + ) + + return block, block + + @pt.Subroutine(pt.TealType.none) + def bad() -> pt.Expr: + return BadRetsub() + + program = pt.Seq(bad(), pt.Approve()) + + # manually add deferred expression to SubroutineDefinition + declaration = bad.subroutine.get_declaration() + declaration.deferred_expr = pt.Pop(pt.Bytes("deferred")) + + with pytest.raises( + pt.TealInternalError, + match=r"^Expected retsub to be the only op in the block, but there are 3 ops$", + ): + pt.compileTeal(program, pt.Mode.Application, version=6, assembleConstants=False) + + def test_compile_wide_ratio(): cases = ( ( @@ -1816,3 +2004,192 @@ def test_compile_wide_ratio(): program, pt.Mode.Application, version=5, assembleConstants=False ) assert actual == expected.strip() + + +def test_compile_abi_subroutine_return(): + @pt.ABIReturnSubroutine + def abi_sum( + toSum: pt.abi.DynamicArray[pt.abi.Uint64], *, output: pt.abi.Uint64 + ) -> pt.Expr: + i = pt.ScratchVar(pt.TealType.uint64) + valueAtIndex = pt.abi.Uint64() + return pt.Seq( + output.set(0), + pt.For( + i.store(pt.Int(0)), + i.load() < toSum.length(), + i.store(i.load() + pt.Int(1)), + ).Do( + pt.Seq( + toSum[i.load()].store_into(valueAtIndex), + output.set(output.get() + valueAtIndex.get()), + ) + ), + ) + + program = pt.Seq( + (to_sum_arr := pt.abi.make(pt.abi.DynamicArray[pt.abi.Uint64])).decode( + pt.Txn.application_args[1] + ), + (res := pt.abi.Uint64()).set(abi_sum(to_sum_arr)), + pt.abi.MethodReturn(res), + pt.Approve(), + ) + + expected_sum = """#pragma version 6 +txna ApplicationArgs 1 +store 0 +load 0 +callsub abisum_0 +store 1 +byte 0x151F7C75 +load 1 +itob +concat +log +int 1 +return + +// abi_sum +abisum_0: +store 2 +int 0 +store 3 +int 0 +store 4 +abisum_0_l1: +load 4 +load 2 +int 0 +extract_uint16 +store 6 +load 6 +< +bz abisum_0_l3 +load 2 +int 8 +load 4 +* +int 2 ++ +extract_uint64 +store 5 +load 3 +load 5 ++ +store 3 +load 4 +int 1 ++ +store 4 +b abisum_0_l1 +abisum_0_l3: +load 3 +retsub + """.strip() + + actual_sum = pt.compileTeal(program, pt.Mode.Application, version=6) + assert expected_sum == actual_sum + + @pt.ABIReturnSubroutine + def conditional_factorial( + _factor: pt.abi.Uint64, *, output: pt.abi.Uint64 + ) -> pt.Expr: + i = pt.ScratchVar(pt.TealType.uint64) + + return pt.Seq( + output.set(1), + pt.If(_factor.get() <= pt.Int(1)) + .Then(pt.Return()) + .Else( + pt.For( + i.store(_factor.get()), + i.load() > pt.Int(1), + i.store(i.load() - pt.Int(1)), + ).Do(output.set(output.get() * i.load())), + ), + ) + + program_cond_factorial = pt.Seq( + (factor := pt.abi.Uint64()).decode(pt.Txn.application_args[1]), + (res := pt.abi.Uint64()).set(conditional_factorial(factor)), + pt.abi.MethodReturn(res), + pt.Approve(), + ) + + expected_conditional_factorial = """#pragma version 6 +txna ApplicationArgs 1 +btoi +store 0 +load 0 +callsub conditionalfactorial_0 +store 1 +byte 0x151F7C75 +load 1 +itob +concat +log +int 1 +return + +// conditional_factorial +conditionalfactorial_0: +store 2 +int 1 +store 3 +load 2 +int 1 +<= +bnz conditionalfactorial_0_l4 +load 2 +store 4 +conditionalfactorial_0_l2: +load 4 +int 1 +> +bz conditionalfactorial_0_l5 +load 3 +load 4 +* +store 3 +load 4 +int 1 +- +store 4 +b conditionalfactorial_0_l2 +conditionalfactorial_0_l4: +load 3 +retsub +conditionalfactorial_0_l5: +load 3 +retsub + """.strip() + + actual_conditional_factorial = pt.compileTeal( + program_cond_factorial, pt.Mode.Application, version=6 + ) + assert actual_conditional_factorial == expected_conditional_factorial + + @pt.ABIReturnSubroutine + def load_b4_set(*, output: pt.abi.Bool): + return pt.Return() + + program_load_b4_set_broken = pt.Seq( + (_ := pt.abi.Bool()).set(load_b4_set()), pt.Approve() + ) + + with pytest.raises(pt.TealInternalError): + pt.compileTeal(program_load_b4_set_broken, pt.Mode.Application, version=6) + + @pt.ABIReturnSubroutine + def access_b4_store(magic_num: pt.abi.Uint64, *, output: pt.abi.Uint64): + return pt.Seq(output.set(output.get() ^ magic_num.get())) + + program_access_b4_store_broken = pt.Seq( + (other_party_magic := pt.abi.Uint64()).decode(pt.Txn.application_args[1]), + (_ := pt.abi.Uint64()).set(access_b4_store(other_party_magic)), + pt.Approve(), + ) + + with pytest.raises(pt.TealInternalError): + pt.compileTeal(program_access_b4_store_broken, pt.Mode.Application, version=6) diff --git a/pyteal/compiler/subroutines.py b/pyteal/compiler/subroutines.py index bc3ee74e7..8ea3f4ddb 100644 --- a/pyteal/compiler/subroutines.py +++ b/pyteal/compiler/subroutines.py @@ -167,7 +167,7 @@ def spillLocalSlotsDuringRecursion( # reentrySubroutineCalls should have a length of 1, since calledSubroutines has a # maximum length of 1 reentrySubroutineCall = reentrySubroutineCalls[0] - numArgs = reentrySubroutineCall.argumentCount() + numArgs = reentrySubroutineCall.argument_count() digArgs = True coverSpilledSlots = False