Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better struct handling in code generation util #2068

Merged
merged 5 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 100 additions & 12 deletions slither/utils/code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@
MappingType,
ArrayType,
ElementaryType,
TypeAlias,
)
from slither.core.declarations import Structure, Enum, Contract
from slither.core.declarations import Structure, StructureContract, Enum, Contract

if TYPE_CHECKING:
from slither.core.declarations import FunctionContract, CustomErrorContract
from slither.core.variables.state_variable import StateVariable
from slither.core.variables.local_variable import LocalVariable
from slither.core.variables.structure_variable import StructureVariable


# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments,too-many-locals,too-many-branches
def generate_interface(
contract: "Contract",
unroll_structs: bool = True,
Expand Down Expand Up @@ -56,12 +58,47 @@ def generate_interface(
for enum in contract.enums:
interface += f" enum {enum.name} {{ {', '.join(enum.values)} }}\n"
if include_structs:
for struct in contract.structures:
# Include structures defined in this contract and at the top level
structs = contract.structures + contract.compilation_unit.structures_top_level
# Function signatures may reference other structures as well
# Include structures defined in libraries used for them
for _for in contract.using_for.keys():
if (
isinstance(_for, UserDefinedType)
and isinstance(_for.type, StructureContract)
and _for.type not in structs
):
structs.append(_for.type)
# Include any other structures used as function arguments/returns
for func in contract.functions_entry_points:
for arg in func.parameters + func.returns:
_type = arg.type
if isinstance(_type, ArrayType):
_type = _type.type
while isinstance(_type, MappingType):
_type = _type.type_to
if isinstance(_type, UserDefinedType):
_type = _type.type
if isinstance(_type, Structure) and _type not in structs:
structs.append(_type)
for struct in structs:
interface += generate_struct_interface_str(struct, indent=4)
for elem in struct.elems_ordered:
if (
isinstance(elem.type, UserDefinedType)
and isinstance(elem.type.type, StructureContract)
and elem.type.type not in structs
):
structs.append(elem.type.type)
for var in contract.state_variables_entry_points:
interface += f" function {generate_interface_variable_signature(var, unroll_structs)};\n"
# if any(func.name == var.name for func in contract.functions_entry_points):
# # ignore public variables that override a public function
# continue
var_sig = generate_interface_variable_signature(var, unroll_structs)
if var_sig is not None and var_sig != "":
interface += f" function {var_sig};\n"
for func in contract.functions_entry_points:
if func.is_constructor or func.is_fallback or func.is_receive:
if func.is_constructor or func.is_fallback or func.is_receive or not func.is_implemented:
continue
interface += (
f" function {generate_interface_function_signature(func, unroll_structs)};\n"
Expand All @@ -75,6 +112,10 @@ def generate_interface_variable_signature(
) -> Optional[str]:
if var.visibility in ["private", "internal"]:
return None
if isinstance(var.type, UserDefinedType) and isinstance(var.type.type, Structure):
for elem in var.type.type.elems_ordered:
if isinstance(elem.type, MappingType):
return ""
if unroll_structs:
params = [
convert_type_for_solidity_signature_to_string(x).replace("(", "").replace(")", "")
Expand All @@ -93,6 +134,11 @@ def generate_interface_variable_signature(
_type = _type.type_to
while isinstance(_type, (ArrayType, UserDefinedType)):
_type = _type.type
if isinstance(_type, TypeAlias):
_type = _type.type
if isinstance(_type, Structure):
if any(isinstance(elem.type, MappingType) for elem in _type.elems_ordered):
return ""
ret = str(_type)
if isinstance(_type, Structure) or (isinstance(_type, Type) and _type.is_dynamic):
ret += " memory"
Expand Down Expand Up @@ -125,6 +171,8 @@ def format_var(var: "LocalVariable", unroll: bool) -> str:
.replace("(", "")
.replace(")", "")
)
if var.type.is_dynamic:
return f"{_handle_dynamic_struct_elem(var.type)} {var.location}"
if isinstance(var.type, ArrayType) and isinstance(
var.type.type, (UserDefinedType, ElementaryType)
):
Expand All @@ -135,12 +183,14 @@ def format_var(var: "LocalVariable", unroll: bool) -> str:
+ f" {var.location}"
)
if isinstance(var.type, UserDefinedType):
if isinstance(var.type.type, (Structure, Enum)):
if isinstance(var.type.type, Structure):
return f"{str(var.type.type)} memory"
if isinstance(var.type.type, Enum):
return str(var.type.type)
if isinstance(var.type.type, Contract):
return "address"
if var.type.is_dynamic:
return f"{var.type} {var.location}"
if isinstance(var.type, TypeAlias):
return str(var.type.type)
return str(var.type)

name, _, _ = func.signature
Expand All @@ -154,6 +204,12 @@ def format_var(var: "LocalVariable", unroll: bool) -> str:
view = " view" if func.view and not func.pure else ""
pure = " pure" if func.pure else ""
payable = " payable" if func.payable else ""
# Make sure the function doesn't return a struct with nested mappings
for ret in func.returns:
if isinstance(ret.type, UserDefinedType) and isinstance(ret.type.type, Structure):
for elem in ret.type.type.elems_ordered:
if isinstance(elem.type, MappingType):
return ""
returns = [format_var(ret, unroll_structs) for ret in func.returns]
parameters = [format_var(param, unroll_structs) for param in func.parameters]
_interface_signature_str = (
Expand Down Expand Up @@ -184,17 +240,49 @@ def generate_struct_interface_str(struct: "Structure", indent: int = 0) -> str:
spaces += " "
definition = f"{spaces}struct {struct.name} {{\n"
for elem in struct.elems_ordered:
if isinstance(elem.type, UserDefinedType):
if isinstance(elem.type.type, (Structure, Enum)):
if elem.type.is_dynamic:
definition += f"{spaces} {_handle_dynamic_struct_elem(elem.type)} {elem.name};\n"
elif isinstance(elem.type, UserDefinedType):
if isinstance(elem.type.type, Structure):
definition += f"{spaces} {elem.type.type} {elem.name};\n"
elif isinstance(elem.type.type, Contract):
definition += f"{spaces} address {elem.name};\n"
else:
definition += f"{spaces} {convert_type_for_solidity_signature_to_string(elem.type)} {elem.name};\n"
elif isinstance(elem.type, TypeAlias):
definition += f"{spaces} {elem.type.type} {elem.name};\n"
else:
definition += f"{spaces} {elem.type} {elem.name};\n"
definition += f"{spaces}}}\n"
return definition


def _handle_dynamic_struct_elem(elem_type: Type) -> str:
assert elem_type.is_dynamic
if isinstance(elem_type, ElementaryType):
return f"{elem_type}"
if isinstance(elem_type, ArrayType):
base_type = elem_type.type
if isinstance(base_type, UserDefinedType):
if isinstance(base_type.type, Contract):
return "address[]"
if isinstance(base_type.type, Enum):
return convert_type_for_solidity_signature_to_string(elem_type)
return f"{base_type.type.name}[]"
return f"{base_type}[]"
if isinstance(elem_type, MappingType):
type_to = elem_type.type_to
type_from = elem_type.type_from
if isinstance(type_from, UserDefinedType) and isinstance(type_from.type, Contract):
type_from = ElementaryType("address")
if isinstance(type_to, MappingType):
return f"mapping({type_from} => {_handle_dynamic_struct_elem(type_to)})"
if isinstance(type_to, UserDefinedType):
if isinstance(type_to.type, Contract):
return f"mapping({type_from} => address)"
return f"mapping({type_from} => {type_to.type.name})"
return f"{elem_type}"
return ""


def generate_custom_error_interface(
error: "CustomErrorContract", unroll_structs: bool = True
) -> str:
Expand Down
13 changes: 11 additions & 2 deletions tests/unit/utils/test_data/code_generation/CodeGeneration.sol
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
pragma solidity ^0.8.4;

import "./IFee.sol";

interface I {
enum SomeEnum { ONE, TWO, THREE }
error ErrorWithEnum(SomeEnum e);
enum SomeEnum { ONE, TWO, THREE }
error ErrorWithEnum(SomeEnum e);
}

contract TestContract is I {
Expand Down Expand Up @@ -62,4 +65,10 @@ contract TestContract is I {
function setOtherI(I _i) public {
otherI = _i;
}

function newFee(uint128 fee) public returns (IFee.Fee memory) {
IFee.Fee memory _fee;
_fee.fee = fee;
return _fee;
}
}
5 changes: 5 additions & 0 deletions tests/unit/utils/test_data/code_generation/IFee.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
interface IFee {
struct Fee {
uint128 fee;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ interface ITestContract {
struct Nested {
St st;
}
struct Fee {
uint128 fee;
}
function stateA() external returns (uint256);
function owner() external returns (address);
function structsMap(address,uint256) external returns (uint256);
Expand All @@ -26,5 +29,6 @@ interface ITestContract {
function getSt(uint256) external view returns (uint256);
function removeSt(uint256) external;
function setOtherI(address) external;
function newFee(uint128) external returns (uint128);
}

Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ interface ITestContract {
struct Nested {
St st;
}
struct Fee {
uint128 fee;
}
function stateA() external returns (uint256);
function owner() external returns (address);
function structsMap(address,uint256) external returns (St memory);
Expand All @@ -26,5 +29,6 @@ interface ITestContract {
function getSt(uint256) external view returns (St memory);
function removeSt(St memory) external;
function setOtherI(address) external;
function newFee(uint128) external returns (Fee memory);
}