Skip to content

Commit

Permalink
fix: correctly handle name collisions with structs in ARC-56
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-makerx authored and tristanmenzel committed Nov 27, 2024
1 parent b48a7d4 commit c6dae28
Show file tree
Hide file tree
Showing 19 changed files with 857 additions and 290 deletions.
4 changes: 2 additions & 2 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
state_totals 65 32 - | 32 16 -
stress_tests/BruteForceRotationSearch 228 163 - | 152 106 -
string_ops 156 154 - | 58 55 -
struct_by_name/Demo 205 161 - | 115 83 -
struct_by_name/Demo 271 217 - | 155 113 -
struct_in_box/Example 242 206 - | 127 99 -
stubs/BigUInt 192 121 - | 126 73 -
stubs/Bytes 944 279 - | 606 153 -
Expand All @@ -135,4 +135,4 @@
unssa/UnSSA 432 368 - | 241 204 -
voting/VotingRoundApp 1580 1475 - | 727 644 -
with_reentrancy/WithReentrancy 245 234 - | 126 117 -
Total 70436 54802 54743 | 33380 22388 22344
Total 70502 54858 54799 | 33420 22418 22374
10 changes: 6 additions & 4 deletions src/puya/arc56.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,15 @@ def _get_source_info(debug_info: DebugInfo) -> Sequence[models.SourceInfo]:

class _StructAliases:
def __init__(self, structs: Iterable[ARC4Struct]) -> None:
self.aliases = dict[str, str]()
alias_to_fullname = dict[str, str]()
for struct in structs:
self.aliases[struct.fullname] = (
alias = (
struct.fullname
if struct.name in self.aliases or struct.name in models.AVMType
if struct.name in alias_to_fullname or struct.name in models.AVMType
else struct.name
)
alias_to_fullname[alias] = struct.fullname
self.aliases = {v: k for k, v in alias_to_fullname.items()}

@typing.overload
def resolve(self, struct: str) -> str: ...
Expand All @@ -203,7 +205,7 @@ def resolve(self, struct: str | None) -> str | None:

def _struct_to_event(structs: _StructAliases, struct: ARC4Struct) -> models.Event:
return models.Event(
name=struct.name,
name=structs.resolve(struct.name),
desc=struct.desc,
args=[
models.EventArg(
Expand Down
15 changes: 11 additions & 4 deletions src/puyapy/awst_build/arc4_client_gen.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import textwrap
from collections.abc import Iterable, Sequence
from pathlib import Path
Expand All @@ -18,6 +19,7 @@

_AUTO_GENERATED_COMMENT = "# This file is auto-generated, do not modify"
_INDENT = " " * 4
_NON_ALPHA_NUMERIC = re.compile(r"\W+")


def write_arc4_client(contract: arc56.Contract, out_dir: Path) -> None:
Expand All @@ -42,7 +44,7 @@ def __init__(self, contract: arc56.Contract):
self.contract = contract
self.python_methods = set[str]()
self.struct_to_class = dict[str, str]()
self.reserved_class_names = {contract.name}
self.reserved_class_names = set[str]()
self.reserved_method_names = set[str]()
self.class_decls = list[str]()

Expand All @@ -53,6 +55,7 @@ def generate(cls, contract: arc56.Contract) -> str:
def _gen(self) -> str:
# generate class definitions for any referenced structs in methods
# don't generate from self.contract.structs as it may contain other struct definitions
client_class = self._unique_class(self.contract.name)
for method in self.contract.methods:
for struct in filter(None, (method.returns.struct, *(a.struct for a in method.args))):
if struct not in self.struct_to_class and (
Expand All @@ -70,7 +73,7 @@ def _gen(self) -> str:
"",
*self.class_decls,
"",
f"class {self.contract.name}(algopy.arc4.ARC4Client, typing.Protocol):",
f"class {client_class}(algopy.arc4.ARC4Client, typing.Protocol):",
*_docstring(self.contract.desc),
*self._gen_methods(),
)
Expand Down Expand Up @@ -110,7 +113,7 @@ def _get_client_type(self, typ: str) -> str:
return str(arc4_to_pytype(typ, None))

def _unique_class(self, name: str) -> str:
base_name = name
base_name = name = _get_python_safe_name(name)
seq = 1
while name in self.reserved_class_names:
seq += 1
Expand All @@ -120,7 +123,7 @@ def _unique_class(self, name: str) -> str:
return name

def _unique_method(self, name: str) -> str:
base_name = name
base_name = name = _get_python_safe_name(name)
seq = 1
while name in self.reserved_method_names:
seq += 1
Expand Down Expand Up @@ -218,3 +221,7 @@ def _indent(lines: Iterable[str] | str) -> str:
if not isinstance(lines, str):
lines = "\n".join(lines)
return textwrap.indent(lines, _INDENT)


def _get_python_safe_name(name: str) -> str:
return _NON_ALPHA_NUMERIC.sub("_", name)
11 changes: 10 additions & 1 deletion test_cases/struct_by_name/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from algopy import ARC4Contract, arc4

from test_cases.struct_by_name.mod import StructTwo as StructThree


class StructOne(typing.NamedTuple):
x: arc4.UInt8
Expand Down Expand Up @@ -34,6 +36,13 @@ def get_two(self) -> StructTwo:
y=arc4.UInt8(1),
)

@arc4.abimethod()
def get_three(self) -> StructThree:
return StructThree(
x=arc4.UInt8(1),
y=arc4.UInt8(1),
)

@arc4.abimethod()
def compare(self) -> bool:
return self.get_one() == self.get_two()
return self.get_one() == self.get_two() and self.get_two() == self.get_three()
8 changes: 8 additions & 0 deletions test_cases/struct_by_name/mod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import typing

from algopy import arc4


class StructTwo(typing.NamedTuple):
x: arc4.UInt8
y: arc4.UInt8
104 changes: 80 additions & 24 deletions test_cases/struct_by_name/out/DemoContract.approval.teal
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,19 @@ test_cases.struct_by_name.contract.DemoContract.approval_program:

// test_cases.struct_by_name.contract.DemoContract.__puya_arc4_router__() -> uint64:
__puya_arc4_router__:
// struct_by_name/contract.py:16
// struct_by_name/contract.py:18
// class DemoContract(ARC4Contract):
proto 0 1
txn NumAppArgs
bz __puya_arc4_router___bare_routing@7
pushbytess 0x3d694b70 0x7fb34e8a 0x46dadea3 // method "get_one()(uint8,uint8)", method "get_two()(uint8,uint8)", method "compare()bool"
bz __puya_arc4_router___bare_routing@8
pushbytess 0x3d694b70 0x7fb34e8a 0x8ba7c4c2 0x46dadea3 // method "get_one()(uint8,uint8)", method "get_two()(uint8,uint8)", method "get_three()(uint8,uint8)", method "compare()bool"
txna ApplicationArgs 0
match __puya_arc4_router___get_one_route@2 __puya_arc4_router___get_two_route@3 __puya_arc4_router___compare_route@4
match __puya_arc4_router___get_one_route@2 __puya_arc4_router___get_two_route@3 __puya_arc4_router___get_three_route@4 __puya_arc4_router___compare_route@5
intc_1 // 0
retsub

__puya_arc4_router___get_one_route@2:
// struct_by_name/contract.py:23
// struct_by_name/contract.py:25
// @arc4.abimethod()
txn OnCompletion
!
Expand All @@ -38,7 +38,7 @@ __puya_arc4_router___get_one_route@2:
retsub

__puya_arc4_router___get_two_route@3:
// struct_by_name/contract.py:30
// struct_by_name/contract.py:32
// @arc4.abimethod()
txn OnCompletion
!
Expand All @@ -54,8 +54,25 @@ __puya_arc4_router___get_two_route@3:
intc_0 // 1
retsub

__puya_arc4_router___compare_route@4:
// struct_by_name/contract.py:37
__puya_arc4_router___get_three_route@4:
// struct_by_name/contract.py:39
// @arc4.abimethod()
txn OnCompletion
!
assert // OnCompletion is not NoOp
txn ApplicationID
assert // can only call when not creating
callsub get_three
concat
bytec_0 // 0x151f7c75
swap
concat
log
intc_0 // 1
retsub

__puya_arc4_router___compare_route@5:
// struct_by_name/contract.py:46
// @arc4.abimethod()
txn OnCompletion
!
Expand All @@ -74,37 +91,37 @@ __puya_arc4_router___compare_route@4:
intc_0 // 1
retsub

__puya_arc4_router___bare_routing@7:
// struct_by_name/contract.py:16
__puya_arc4_router___bare_routing@8:
// struct_by_name/contract.py:18
// class DemoContract(ARC4Contract):
txn OnCompletion
bnz __puya_arc4_router___after_if_else@11
bnz __puya_arc4_router___after_if_else@12
txn ApplicationID
!
assert // can only call when creating
intc_0 // 1
retsub

__puya_arc4_router___after_if_else@11:
// struct_by_name/contract.py:16
__puya_arc4_router___after_if_else@12:
// struct_by_name/contract.py:18
// class DemoContract(ARC4Contract):
intc_1 // 0
retsub


// test_cases.struct_by_name.contract.DemoContract.get_one() -> bytes, bytes:
get_one:
// struct_by_name/contract.py:23-24
// struct_by_name/contract.py:25-26
// @arc4.abimethod()
// def get_one(self) -> StructOne:
proto 0 2
// struct_by_name/contract.py:26
// struct_by_name/contract.py:28
// x=arc4.UInt8(1),
bytec_1 // 0x01
// struct_by_name/contract.py:27
// struct_by_name/contract.py:29
// y=arc4.UInt8(1),
dup
// struct_by_name/contract.py:25-28
// struct_by_name/contract.py:27-30
// return StructOne(
// x=arc4.UInt8(1),
// y=arc4.UInt8(1),
Expand All @@ -114,32 +131,52 @@ get_one:

// test_cases.struct_by_name.contract.DemoContract.get_two() -> bytes, bytes:
get_two:
// struct_by_name/contract.py:30-31
// struct_by_name/contract.py:32-33
// @arc4.abimethod()
// def get_two(self) -> StructTwo:
proto 0 2
// struct_by_name/contract.py:33
// struct_by_name/contract.py:35
// x=arc4.UInt8(1),
bytec_1 // 0x01
// struct_by_name/contract.py:34
// struct_by_name/contract.py:36
// y=arc4.UInt8(1),
dup
// struct_by_name/contract.py:32-35
// struct_by_name/contract.py:34-37
// return StructTwo(
// x=arc4.UInt8(1),
// y=arc4.UInt8(1),
// )
retsub


// test_cases.struct_by_name.contract.DemoContract.get_three() -> bytes, bytes:
get_three:
// struct_by_name/contract.py:39-40
// @arc4.abimethod()
// def get_three(self) -> StructThree:
proto 0 2
// struct_by_name/contract.py:42
// x=arc4.UInt8(1),
bytec_1 // 0x01
// struct_by_name/contract.py:43
// y=arc4.UInt8(1),
dup
// struct_by_name/contract.py:41-44
// return StructThree(
// x=arc4.UInt8(1),
// y=arc4.UInt8(1),
// )
retsub


// test_cases.struct_by_name.contract.DemoContract.compare() -> uint64:
compare:
// struct_by_name/contract.py:37-38
// struct_by_name/contract.py:46-47
// @arc4.abimethod()
// def compare(self) -> bool:
proto 0 1
// struct_by_name/contract.py:39
// return self.get_one() == self.get_two()
// struct_by_name/contract.py:48
// return self.get_one() == self.get_two() and self.get_two() == self.get_three()
callsub get_one
callsub get_two
uncover 3
Expand All @@ -148,4 +185,23 @@ compare:
cover 2
b==
&&
bz compare_bool_false@3
callsub get_two
callsub get_three
uncover 3
uncover 2
b==
cover 2
b==
&&
bz compare_bool_false@3
intc_0 // 1
b compare_bool_merge@4

compare_bool_false@3:
intc_1 // 0

compare_bool_merge@4:
// struct_by_name/contract.py:48
// return self.get_one() == self.get_two() and self.get_two() == self.get_three()
retsub
Loading

0 comments on commit c6dae28

Please sign in to comment.