Skip to content

Commit

Permalink
improve name resolution of type aliases (#2061)
Browse files Browse the repository at this point in the history
* BREAKING CHANGE: Renamed user_defined_types to type_aliases so it's less confusing with what we call UserDefinedType.
* Added type aliased at the Contract level so now at the file scope there are only top level aliasing and fully qualified contract aliasing like C.myAlias.
* Fix #1809
  • Loading branch information
smonicas authored Sep 8, 2023
1 parent 53c97f9 commit 8b07fe5
Show file tree
Hide file tree
Showing 14 changed files with 101 additions and 39 deletions.
6 changes: 3 additions & 3 deletions slither/core/compilation_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, core: "SlitherCore", crytic_compilation_unit: CompilationUnit
self._pragma_directives: List[Pragma] = []
self._import_directives: List[Import] = []
self._custom_errors: List[CustomErrorTopLevel] = []
self._user_defined_value_types: Dict[str, TypeAliasTopLevel] = {}
self._type_aliases: Dict[str, TypeAliasTopLevel] = {}

self._all_functions: Set[Function] = set()
self._all_modifiers: Set[Modifier] = set()
Expand Down Expand Up @@ -220,8 +220,8 @@ def custom_errors(self) -> List[CustomErrorTopLevel]:
return self._custom_errors

@property
def user_defined_value_types(self) -> Dict[str, TypeAliasTopLevel]:
return self._user_defined_value_types
def type_aliases(self) -> Dict[str, TypeAliasTopLevel]:
return self._type_aliases

# endregion
###################################################################################
Expand Down
34 changes: 34 additions & 0 deletions slither/core/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.scope.scope import FileScope
from slither.core.cfg.node import Node
from slither.core.solidity_types import TypeAliasContract


LOGGER = logging.getLogger("Contract")
Expand Down Expand Up @@ -81,6 +82,7 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope
self._functions: Dict[str, "FunctionContract"] = {}
self._linearizedBaseContracts: List[int] = []
self._custom_errors: Dict[str, "CustomErrorContract"] = {}
self._type_aliases: Dict[str, "TypeAliasContract"] = {}

# The only str is "*"
self._using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = {}
Expand Down Expand Up @@ -364,6 +366,38 @@ def custom_errors_declared(self) -> List["CustomErrorContract"]:
def custom_errors_as_dict(self) -> Dict[str, "CustomErrorContract"]:
return self._custom_errors

# endregion
###################################################################################
###################################################################################
# region Custom Errors
###################################################################################
###################################################################################

@property
def type_aliases(self) -> List["TypeAliasContract"]:
"""
list(TypeAliasContract): List of the contract's custom errors
"""
return list(self._type_aliases.values())

@property
def type_aliases_inherited(self) -> List["TypeAliasContract"]:
"""
list(TypeAliasContract): List of the inherited custom errors
"""
return [s for s in self.type_aliases if s.contract != self]

@property
def type_aliases_declared(self) -> List["TypeAliasContract"]:
"""
list(TypeAliasContract): List of the custom errors declared within the contract (not inherited)
"""
return [s for s in self.type_aliases if s.contract == self]

@property
def type_aliases_as_dict(self) -> Dict[str, "TypeAliasContract"]:
return self._type_aliases

# endregion
###################################################################################
###################################################################################
Expand Down
6 changes: 3 additions & 3 deletions slither/core/scope/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, filename: Filename) -> None:

# User defined types
# Name -> type alias
self.user_defined_types: Dict[str, TypeAlias] = {}
self.type_aliases: Dict[str, TypeAlias] = {}

def add_accesible_scopes(self) -> bool:
"""
Expand Down Expand Up @@ -95,8 +95,8 @@ def add_accesible_scopes(self) -> bool:
if not _dict_contain(new_scope.renaming, self.renaming):
self.renaming.update(new_scope.renaming)
learn_something = True
if not _dict_contain(new_scope.user_defined_types, self.user_defined_types):
self.user_defined_types.update(new_scope.user_defined_types)
if not _dict_contain(new_scope.type_aliases, self.type_aliases):
self.type_aliases.update(new_scope.type_aliases)
learn_something = True

return learn_something
Expand Down
8 changes: 4 additions & 4 deletions slither/solc_parsing/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,10 @@ def _parse_type_alias(self, item: Dict) -> None:
alias = item["name"]
alias_canonical = self._contract.name + "." + item["name"]

user_defined_type = TypeAliasContract(original_type, alias, self.underlying_contract)
user_defined_type.set_offset(item["src"], self.compilation_unit)
self._contract.file_scope.user_defined_types[alias] = user_defined_type
self._contract.file_scope.user_defined_types[alias_canonical] = user_defined_type
type_alias = TypeAliasContract(original_type, alias, self.underlying_contract)
type_alias.set_offset(item["src"], self.compilation_unit)
self._contract.type_aliases_as_dict[alias] = type_alias
self._contract.file_scope.type_aliases[alias_canonical] = type_alias

def _parse_struct(self, struct: Dict) -> None:

Expand Down
2 changes: 1 addition & 1 deletion slither/solc_parsing/declarations/using_for_top_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _propagate_global(self, type_name: Union[TypeAliasTopLevel, UserDefinedType]
if self._global:
for scope in self.compilation_unit.scopes.values():
if isinstance(type_name, TypeAliasTopLevel):
for alias in scope.user_defined_types.values():
for alias in scope.type_aliases.values():
if alias == type_name:
scope.using_for_directives.add(self._using_for)
elif isinstance(type_name, UserDefinedType):
Expand Down
9 changes: 6 additions & 3 deletions slither/solc_parsing/expressions/find_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def find_top_level(
:return:
:rtype:
"""
if var_name in scope.type_aliases:
return scope.type_aliases[var_name], False

if var_name in scope.structures:
return scope.structures[var_name], False
Expand Down Expand Up @@ -205,6 +207,10 @@ def _find_in_contract(
if sig == var_name:
return modifier

type_aliases = contract.type_aliases_as_dict
if var_name in type_aliases:
return type_aliases[var_name]

# structures are looked on the contract declarer
structures = contract.structures_as_dict
if var_name in structures:
Expand Down Expand Up @@ -362,9 +368,6 @@ def find_variable(
if var_name in current_scope.renaming:
var_name = current_scope.renaming[var_name]

if var_name in current_scope.user_defined_types:
return current_scope.user_defined_types[var_name], False

# Use ret0/ret1 to help mypy
ret0 = _find_variable_from_ref_declaration(
referenced_declaration, direct_contracts, direct_functions
Expand Down
8 changes: 4 additions & 4 deletions slither/solc_parsing/slither_compilation_unit_solc.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,10 +344,10 @@ def parse_top_level_from_loaded_json(self, data_loaded: Dict, filename: str) ->

original_type = ElementaryType(underlying_type["name"])

user_defined_type = TypeAliasTopLevel(original_type, alias, scope)
user_defined_type.set_offset(top_level_data["src"], self._compilation_unit)
self._compilation_unit.user_defined_value_types[alias] = user_defined_type
scope.user_defined_types[alias] = user_defined_type
type_alias = TypeAliasTopLevel(original_type, alias, scope)
type_alias.set_offset(top_level_data["src"], self._compilation_unit)
self._compilation_unit.type_aliases[alias] = type_alias
scope.type_aliases[alias] = type_alias

else:
raise SlitherException(f"Top level {top_level_data[self.get_key()]} not supported")
Expand Down
30 changes: 15 additions & 15 deletions slither/solc_parsing/solidity_types/type_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def parse_type(

sl: "SlitherCompilationUnit"
renaming: Dict[str, str]
user_defined_types: Dict[str, TypeAlias]
type_aliases: Dict[str, TypeAlias]
enums_direct_access: List["Enum"] = []
# Note: for convenicence top level functions use the same parser than function in contract
# but contract_parser is set to None
Expand All @@ -247,13 +247,13 @@ def parse_type(
sl = caller_context.compilation_unit
next_context = caller_context
renaming = {}
user_defined_types = sl.user_defined_value_types
type_aliases = sl.type_aliases
else:
assert isinstance(caller_context, FunctionSolc)
sl = caller_context.underlying_function.compilation_unit
next_context = caller_context.slither_parser
renaming = caller_context.underlying_function.file_scope.renaming
user_defined_types = caller_context.underlying_function.file_scope.user_defined_types
type_aliases = caller_context.underlying_function.file_scope.type_aliases
structures_direct_access = sl.structures_top_level
all_structuress = [c.structures for c in sl.contracts]
all_structures = [item for sublist in all_structuress for item in sublist]
Expand Down Expand Up @@ -299,7 +299,7 @@ def parse_type(
functions = list(scope.functions)

renaming = scope.renaming
user_defined_types = scope.user_defined_types
type_aliases = scope.type_aliases
elif isinstance(caller_context, (ContractSolc, FunctionSolc)):
sl = caller_context.compilation_unit
if isinstance(caller_context, FunctionSolc):
Expand Down Expand Up @@ -329,7 +329,7 @@ def parse_type(
functions = contract.functions + contract.modifiers

renaming = scope.renaming
user_defined_types = scope.user_defined_types
type_aliases = scope.type_aliases
else:
raise ParsingError(f"Incorrect caller context: {type(caller_context)}")

Expand All @@ -343,8 +343,8 @@ def parse_type(
name = t.name
if name in renaming:
name = renaming[name]
if name in user_defined_types:
return user_defined_types[name]
if name in type_aliases:
return type_aliases[name]
return _find_from_type_name(
name,
functions,
Expand All @@ -365,9 +365,9 @@ def parse_type(
name = t["typeDescriptions"]["typeString"]
if name in renaming:
name = renaming[name]
if name in user_defined_types:
_add_type_references(user_defined_types[name], t["src"], sl)
return user_defined_types[name]
if name in type_aliases:
_add_type_references(type_aliases[name], t["src"], sl)
return type_aliases[name]
type_found = _find_from_type_name(
name,
functions,
Expand All @@ -386,9 +386,9 @@ def parse_type(
name = t["attributes"][type_name_key]
if name in renaming:
name = renaming[name]
if name in user_defined_types:
_add_type_references(user_defined_types[name], t["src"], sl)
return user_defined_types[name]
if name in type_aliases:
_add_type_references(type_aliases[name], t["src"], sl)
return type_aliases[name]
type_found = _find_from_type_name(
name,
functions,
Expand All @@ -407,8 +407,8 @@ def parse_type(
name = t["name"]
if name in renaming:
name = renaming[name]
if name in user_defined_types:
return user_defined_types[name]
if name in type_aliases:
return type_aliases[name]
type_found = _find_from_type_name(
name,
functions,
Expand Down
4 changes: 2 additions & 2 deletions slither/visitors/slithir/expression_to_slithir.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,8 +516,8 @@ def _post_member_access(self, expression: MemberAccess) -> None:
# contract A { type MyInt is int}
# contract B { function f() public{ A.MyInt test = A.MyInt.wrap(1);}}
# The logic is handled by _post_call_expression
if expression.member_name in expr.file_scope.user_defined_types:
set_val(expression, expr.file_scope.user_defined_types[expression.member_name])
if expression.member_name in expr.file_scope.type_aliases:
set_val(expression, expr.file_scope.type_aliases[expression.member_name])
return
# Lookup errors referred to as member of contract e.g. Test.myError.selector
if expression.member_name in expr.custom_errors_as_dict:
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/solc_parsing/test_ast_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ def make_version(minor: int, patch_min: int, patch_max: int) -> List[str]:
["0.6.9", "0.7.6", "0.8.16"],
),
Test("user_defined_operators-0.8.19.sol", ["0.8.19"]),
Test("type-aliases.sol", ["0.8.19"]),
]
# create the output folder if needed
try:
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"OtherTest": {
"myfunc()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n}\n"
},
"DeleteTest": {}
}
20 changes: 20 additions & 0 deletions tests/e2e/solc_parsing/test_data/type-aliases.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

struct Z {
int x;
int y;
}

contract OtherTest {
struct Z {
int x;
int y;
}

function myfunc() external {
Z memory z = Z(2,3);
}
}

contract DeleteTest {
type Z is int;
}
6 changes: 2 additions & 4 deletions tests/unit/core/test_source_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,13 @@ def test_references_user_defined_aliases(solc_binary_path):
file = Path(SRC_MAPPING_TEST_ROOT, "ReferencesUserDefinedAliases.sol").as_posix()
slither = Slither(file, solc=solc_path)

alias_top_level = slither.compilation_units[0].user_defined_value_types["aliasTopLevel"]
alias_top_level = slither.compilation_units[0].type_aliases["aliasTopLevel"]
assert len(alias_top_level.references) == 2
lines = _sort_references_lines(alias_top_level.references)
assert lines == [12, 16]

alias_contract_level = (
slither.compilation_units[0]
.contracts[0]
.file_scope.user_defined_types["C.aliasContractLevel"]
slither.compilation_units[0].contracts[0].file_scope.type_aliases["C.aliasContractLevel"]
)
assert len(alias_contract_level.references) == 2
lines = _sort_references_lines(alias_contract_level.references)
Expand Down

0 comments on commit 8b07fe5

Please sign in to comment.