diff --git a/README.md b/README.md index 9d3ae00..d3833d0 100644 --- a/README.md +++ b/README.md @@ -109,6 +109,15 @@ Import the voting contract types like this: import voting.ballot as ballot ``` +### Decimals + +To use decimals on Vyper 0.4, use the following config: + +```yaml +vyper: + enable_decimals: true +``` + ### Pragmas Ape-Vyper supports Vyper 0.3.10's [new pragma formats](https://github.com/vyperlang/vyper/pull/3493) diff --git a/tests/ape-config.yaml b/ape-config.yaml similarity index 75% rename from tests/ape-config.yaml rename to ape-config.yaml index cccbac6..4bac66d 100644 --- a/tests/ape-config.yaml +++ b/ape-config.yaml @@ -1,10 +1,10 @@ # Allows compiling to work from the project-level. -contracts_folder: contracts/passing_contracts +contracts_folder: tests/contracts/passing_contracts # Specify a dependency to use in Vyper imports. dependencies: - name: exampledependency - local: ./ExampleDependency + local: ./tests/ExampleDependency # NOTE: Snekmate does not need to be listed here since # it is installed in site-packages. However, we include it @@ -12,3 +12,6 @@ dependencies: - python: snekmate config_override: contracts_folder: . + +vyper: + enable_decimals: true diff --git a/ape_vyper/compiler/_versions/base.py b/ape_vyper/compiler/_versions/base.py index 96ef244..1c98581 100644 --- a/ape_vyper/compiler/_versions/base.py +++ b/ape_vyper/compiler/_versions/base.py @@ -190,8 +190,7 @@ def get_settings( optimization = False selection_dict = self._get_selection_dictionary(selection, project=pm) - search_paths = [*getsitepackages()] - search_paths.append(".") + search_paths = [*getsitepackages(), "."] version_settings[settings_key] = { "optimize": optimization, diff --git a/ape_vyper/compiler/_versions/vyper_04.py b/ape_vyper/compiler/_versions/vyper_04.py index 5e8cefd..9d6632a 100644 --- a/ape_vyper/compiler/_versions/vyper_04.py +++ b/ape_vyper/compiler/_versions/vyper_04.py @@ -23,6 +23,25 @@ def get_import_remapping(self, project: Optional[ProjectManager] = None) -> dict # You always import via module or package name. return {} + def get_settings( + self, + version: Version, + source_paths: Iterable[Path], + compiler_data: dict, + project: Optional[ProjectManager] = None, + ) -> dict: + pm = project or self.local_project + + enable_decimals = self.api.get_config(project=pm).enable_decimals + if enable_decimals is None: + enable_decimals = False + + settings = super().get_settings(version, source_paths, compiler_data, project=pm) + for settings_set in settings.values(): + settings_set["enable_decimals"] = enable_decimals + + return settings + def _get_sources_dictionary( self, source_ids: Iterable[str], project: Optional[ProjectManager] = None, **kwargs ) -> dict[str, dict]: diff --git a/ape_vyper/compiler/api.py b/ape_vyper/compiler/api.py index ff8cce9..4748347 100644 --- a/ape_vyper/compiler/api.py +++ b/ape_vyper/compiler/api.py @@ -247,8 +247,17 @@ def compile( settings: Optional[dict] = None, ) -> Iterator[ContractType]: pm = project or self.local_project - + original_settings = self.compiler_settings self.compiler_settings = {**self.compiler_settings, **(settings or {})} + try: + yield from self._compile(contract_filepaths, project=pm) + finally: + self.compiler_settings = original_settings + + def _compile( + self, contract_filepaths: Iterable[Path], project: Optional[ProjectManager] = None + ): + pm = project or self.local_project contract_types: list[ContractType] = [] import_map = self._import_resolver.get_imports(pm, contract_filepaths) config = self.get_config(pm) @@ -514,12 +523,11 @@ def init_coverage_profile( def enrich_error(self, err: ContractLogicError) -> ContractLogicError: return enrich_error(err) + # TODO: In 0.9, make sure project is a kwarg here. def trace_source( self, contract_source: ContractSource, trace: TraceAPI, calldata: HexBytes ) -> SourceTraceback: - frames = trace.get_raw_frames() - tracer = SourceTracer(contract_source, frames, calldata) - return tracer.trace() + return SourceTracer.trace(trace.get_raw_frames(), contract_source, calldata) def _get_compiler_arguments( self, diff --git a/ape_vyper/config.py b/ape_vyper/config.py index 8f715a5..65080e1 100644 --- a/ape_vyper/config.py +++ b/ape_vyper/config.py @@ -32,6 +32,13 @@ class VyperConfig(PluginConfig): """ + enable_decimals: Optional[bool] = None + """ + On Vyper 0.4, to use decimal types, you must enable it. + Defaults to ``None`` to avoid misleading that ``False`` + means you cannot use decimals on a lower version. + """ + @field_validator("version", mode="before") def validate_version(cls, value): return pragma_str_to_specifier_set(value) if isinstance(value, str) else value diff --git a/ape_vyper/flattener.py b/ape_vyper/flattener.py index 5a47400..7b25c6c 100644 --- a/ape_vyper/flattener.py +++ b/ape_vyper/flattener.py @@ -4,7 +4,7 @@ from ape.logging import logger from ape.managers import ProjectManager -from ape.utils import ManagerAccessMixin +from ape.utils import ManagerAccessMixin, get_relative_path from ethpm_types.source import Content from ape_vyper._utils import get_version_pragma_spec @@ -65,7 +65,10 @@ def _flatten_source( flattened_modules = "" modules_prefixes: set[str] = set() - for import_path in sorted(imports): + # Source by source ID for greater consistency.. + for import_path in sorted( + imports, key=lambda p: f"{get_relative_path(p.absolute(), pm.path)}" + ): import_info = imports[import_path] # Vyper imported interface names come from their file names diff --git a/ape_vyper/imports.py b/ape_vyper/imports.py index c18f2bc..167deb8 100644 --- a/ape_vyper/imports.py +++ b/ape_vyper/imports.py @@ -264,7 +264,7 @@ def __init__(self, project: ProjectManager, paths: list[Path]): # Even though we build up mappings of all sources, as may be referenced # later on and that prevents re-calculating over again, we only # "show" the items requested. - self._request_view: list[Path] = paths + self.paths: list[Path] = paths def __getitem__(self, item: Union[str, Path], *args, **kwargs) -> list[Import]: if isinstance(item, str) or not item.is_absolute(): @@ -294,7 +294,7 @@ def keys(self) -> list[Path]: # type: ignore result = [] keys = sorted(list(super().keys())) for path in keys: - if path not in self._request_view: + if path not in self.paths: continue result.append(path) @@ -311,7 +311,7 @@ def values(self) -> list[list[Import]]: # type: ignore def items(self) -> list[tuple[Path, list[Import]]]: # type: ignore result = [] for path in self.keys(): # sorted - if path not in self._request_view: + if path not in self.paths: continue result.append((path, self[path])) @@ -328,30 +328,16 @@ class ImportResolver(ManagerAccessMixin): _projects: dict[str, ImportMap] = {} _dependency_attempted_compile: set[str] = set() - def get_imports( - self, - project: ProjectManager, - contract_filepaths: Iterable[Path], - ) -> ImportMap: + def get_imports(self, project: ProjectManager, contract_filepaths: Iterable[Path]) -> ImportMap: paths = list(contract_filepaths) - reset_view = None if project.project_id not in self._projects: self._projects[project.project_id] = ImportMap(project, paths) - else: - # Change the items we "view". Some (or all) may need to be added as well. - reset_view = self._projects[project.project_id]._request_view - self._projects[project.project_id]._request_view = paths - try: - import_map = self._get_imports(paths, project) - finally: - if reset_view is not None: - self._projects[project.project_id]._request_view = reset_view - - return import_map + return self._get_imports(paths, project) def _get_imports(self, paths: list[Path], project: ProjectManager) -> ImportMap: import_map = self._projects[project.project_id] + import_map.paths = list({*import_map.paths, *paths}) for path in paths: if path in import_map: # Already handled. diff --git a/ape_vyper/traceback.py b/ape_vyper/traceback.py index 1a4acf4..88ad14c 100644 --- a/ape_vyper/traceback.py +++ b/ape_vyper/traceback.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Optional, cast +from ape.managers import ProjectManager from ape.types import SourceTraceback from ape.utils import ManagerAccessMixin, get_full_extension from eth_pydantic_types import HexBytes @@ -20,50 +21,49 @@ class SourceTracer(ManagerAccessMixin): Use EVM data to create a trace of Vyper source lines. """ - def __init__(self, contract_source: ContractSource, frames: Iterator[dict], calldata: HexBytes): - self.contract_source = contract_source - self.frames = frames - self.calldata = calldata - + @classmethod def trace( - self, - contract: Optional[ContractSource] = None, - calldata: Optional[HexBytes] = None, + cls, + frames: Iterator[dict], + contract: ContractSource, + calldata: HexBytes, previous_depth: Optional[int] = None, + project: Optional[ProjectManager] = None, ) -> SourceTraceback: - contract_source = self.contract_source if contract is None else contract - calldata = self.calldata if calldata is None else calldata + pm = project or cls.local_project method_id = HexBytes(calldata[:4]) traceback = SourceTraceback.model_validate([]) completed = False pcmap = PCMap.model_validate({}) - for frame in self.frames: + for frame in frames: if frame["op"] in [c.value for c in CALL_OPCODES]: start_depth = frame["depth"] - called_contract, sub_calldata = self._create_contract_from_call(frame) + called_contract, sub_calldata = cls._create_contract_from_call(frame, project=pm) if called_contract: ext = get_full_extension(Path(called_contract.source_id)) if ext in [x for x in FileType]: # Called another Vyper contract. - sub_trace = self.trace( - contract=called_contract, - calldata=sub_calldata, + sub_trace = cls.trace( + frames, + called_contract, + sub_calldata, previous_depth=frame["depth"], + project=pm, ) traceback.extend(sub_trace) else: # Not a Vyper contract! - compiler = self.compiler_manager.registered_compilers[ext] + compiler = cls.compiler_manager.registered_compilers[ext] try: sub_trace = compiler.trace_source( - called_contract.contract_type, self.frames, sub_calldata + called_contract.contract_type, frames, sub_calldata ) traceback.extend(sub_trace) except NotImplementedError: # Compiler not supported. Fast forward out of this call. - for fr in self.frames: + for fr in frames: if fr["depth"] <= start_depth: break @@ -71,7 +71,7 @@ def trace( else: # Contract not found. Fast forward out of this call. - for fr in self.frames: + for fr in frames: if fr["depth"] <= start_depth: break @@ -83,14 +83,14 @@ def trace( completed = previous_depth is not None pcs_to_try_adding = set() - if "PUSH" in frame["op"] and frame["pc"] in contract_source.pcmap: + if "PUSH" in frame["op"] and frame["pc"] in contract.pcmap: # Check if next op is SSTORE to properly use AST from push op. next_frame: Optional[dict] = frame - loc = contract_source.pcmap[frame["pc"]] + loc = contract.pcmap[frame["pc"]] pcs_to_try_adding.add(frame["pc"]) while next_frame and "PUSH" in next_frame["op"]: - next_frame = next(self.frames, None) + next_frame = next(frames, None) if next_frame and "PUSH" in next_frame["op"]: pcs_to_try_adding.add(next_frame["pc"]) @@ -103,7 +103,7 @@ def trace( completed = True else: - pcmap = contract_source.pcmap + pcmap = contract.pcmap dev_val = str((loc.get("dev") or "")).replace("dev: ", "") is_non_payable_hit = dev_val == RuntimeErrorType.NONPAYABLE_CHECK.value @@ -111,7 +111,7 @@ def trace( frame = next_frame else: - pcmap = contract_source.pcmap + pcmap = contract.pcmap pcs_to_try_adding.add(frame["pc"]) pcs_to_try_adding = {pc for pc in pcs_to_try_adding if pc in pcmap} @@ -147,7 +147,7 @@ def trace( # New group. pc_groups.append([location, {pc}, dev]) - dev_messages = contract_source.contract_type.dev_messages or {} + dev_messages = contract.contract_type.dev_messages or {} for location, pcs, dev in pc_groups: if dev in [m.value for m in RuntimeErrorType if m != RuntimeErrorType.USER_ASSERT]: error_type = RuntimeErrorType(dev) @@ -160,9 +160,9 @@ def trace( name = traceback.last.closure.name full_name = traceback.last.closure.full_name - elif method_id in contract_source.contract_type.methods: + elif method_id in contract.contract_type.methods: # For non-payable checks, they should hit here. - method_checked = contract_source.contract_type.methods[method_id] + method_checked = contract.contract_type.methods[method_id] name = method_checked.name full_name = method_checked.selector @@ -186,7 +186,7 @@ def trace( f"dev: {dev}", full_name=full_name, pcs=pcs, - source_path=contract_source.source_path, + source_path=contract.source_path, ) continue @@ -194,7 +194,7 @@ def trace( # Unknown. continue - if not (function := contract_source.lookup_function(location, method_id=method_id)): + if not (function := contract.lookup_function(location, method_id=method_id)): continue if ( @@ -213,7 +213,7 @@ def trace( function, depth, pcs=pcs, - source_path=contract_source.source_path, + source_path=contract.source_path, ) else: traceback.extend_last(location, pcs=pcs) @@ -235,7 +235,11 @@ def trace( return traceback - def _create_contract_from_call(self, frame: dict) -> tuple[Optional[ContractSource], HexBytes]: + @classmethod + def _create_contract_from_call( + cls, frame: dict, project: Optional[ProjectManager] = None + ) -> tuple[Optional[ContractSource], HexBytes]: + pm = project or cls.local_project evm_frame = TraceFrame(**frame) data = create_call_node_data(evm_frame) calldata = data.get("calldata", HexBytes("")) @@ -243,12 +247,12 @@ def _create_contract_from_call(self, frame: dict) -> tuple[Optional[ContractSour return None, calldata try: - address = self.provider.network.ecosystem.decode_address(address) + address = cls.provider.network.ecosystem.decode_address(address) except Exception: return None, calldata - if address not in self.chain_manager.contracts: + if address not in cls.chain_manager.contracts: return None, calldata - called_contract = self.chain_manager.contracts[address] - return self.local_project._create_contract_source(called_contract), calldata + called_contract = cls.chain_manager.contracts[address] + return pm._create_contract_source(called_contract), calldata diff --git a/tests/conftest.py b/tests/conftest.py index e2ce182..803f615 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,11 @@ import os -import shutil from contextlib import contextmanager from pathlib import Path import ape import pytest import vvm # type: ignore -from ape.contracts import ContractContainer +from ape.contracts import ContractContainer, ContractInstance from ape.utils import create_tempdir from click.testing import CliRunner @@ -41,20 +40,6 @@ } -@pytest.fixture(scope="session", autouse=True) -def from_tests_dir(): - # Makes default project correct. - here = Path(__file__).parent - orig = Path.cwd() - if orig != here: - os.chdir(f"{here}") - - yield - - if Path.cwd() != orig: - os.chdir(f"{orig}") - - @pytest.fixture(scope="session", autouse=True) def config(): with ape.config.isolate_data_folder(): @@ -150,15 +135,7 @@ def compiler(compiler_manager): @pytest.fixture(scope="session", autouse=True) def project(config): - project_source_dir = Path(__file__).parent - - # Delete build / .cache that may exist pre-copy - cache = project_source_dir / ".build" - shutil.rmtree(cache, ignore_errors=True) - - root_project = ape.Project(project_source_dir) - with root_project.isolate_in_tempdir() as tmp_project: - yield tmp_project + return config.local_project @pytest.fixture @@ -207,7 +184,7 @@ def cli_runner(): return CliRunner() -def _get_tb_contract(version: str, project, account): +def _get_tb_contract(version: str, project, account) -> ContractInstance: project.load_contracts() registry_type = project.get_contract(f"registry_{version}") diff --git a/tests/contracts/passing_contracts/subdir/zero_four_in_subdir.vy b/tests/contracts/passing_contracts/subdir/zero_four_in_subdir.vy index b3091c9..980bcaf 100644 --- a/tests/contracts/passing_contracts/subdir/zero_four_in_subdir.vy +++ b/tests/contracts/passing_contracts/subdir/zero_four_in_subdir.vy @@ -1,5 +1,5 @@ # Show we can import from the root of the project w/o needing relative imports -from contracts.passing_contracts import zero_four_module as zero_four_module +from tests.contracts.passing_contracts import zero_four_module as zero_four_module @external def callModuleFunctionFromSubdir(role: bytes32) -> bool: diff --git a/tests/contracts/passing_contracts/use_iface.vy b/tests/contracts/passing_contracts/use_iface.vy index b29c6ff..c2903df 100644 --- a/tests/contracts/passing_contracts/use_iface.vy +++ b/tests/contracts/passing_contracts/use_iface.vy @@ -10,7 +10,7 @@ import exampledependency.Dependency as Dep from .interfaces import IFace2 as IFace2 # Also use IFaceNested to show we can use nested interfaces. -from contracts.passing_contracts.interfaces.nested import IFaceNested as IFaceNested +from tests.contracts.passing_contracts.interfaces.nested import IFaceNested as IFaceNested @external diff --git a/tests/functional/test_compiler.py b/tests/functional/test_compiler.py index 07fd984..04c26b4 100644 --- a/tests/functional/test_compiler.py +++ b/tests/functional/test_compiler.py @@ -12,6 +12,7 @@ from packaging.version import Version from vvm.exceptions import VyperError # type: ignore +from ape_vyper._utils import EVM_VERSION_DEFAULT from ape_vyper.exceptions import ( FallbackNotDefinedError, IntegerOverflowError, @@ -390,7 +391,7 @@ def test_get_imports(compiler, project): ] actual = compiler.get_imports(vyper_files, project=project) - prefix = "contracts/passing_contracts" + prefix = "tests/contracts/passing_contracts" builtin_import = "vyper/interfaces/ERC20.json" local_import = "IFace.vy" local_from_import = "IFace2.vy" @@ -585,8 +586,26 @@ def revert_type(self) -> Optional[str]: assert isinstance(new_error, NonPayableError) +def test_trace_source(geth_provider, project, traceback_contract, account, compiler): + receipt = traceback_contract.addBalance(123, sender=account) + contract = project._create_contract_source(traceback_contract.contract_type) + trace = receipt.trace + actual = compiler.trace_source(contract, trace, receipt.data) + base_folder = Path(__file__).parent.parent / "contracts" / "passing_contracts" + contract_name = traceback_contract.contract_type.name + expected = rf""" +Traceback (most recent call last) + File {base_folder}/{contract_name}.vy, in addBalance + 32 if i != num: + 33 continue + 34 + --> 35 return self._balance + """.strip() + assert str(actual) == expected + + @pytest.mark.parametrize("arguments", [(), (123,), (123, 321)]) -def test_trace_source(account, geth_provider, project, traceback_contract, arguments): +def test_trace_source_from_receipt(account, geth_provider, project, traceback_contract, arguments): receipt = traceback_contract.addBalance(*arguments, sender=account) actual = receipt.source_traceback base_folder = Path(__file__).parent.parent / "contracts" / "passing_contracts" @@ -628,7 +647,7 @@ def check(name: str, tb): check("addBalance(uint256,uint256)", both_args_tb) -def test_trace_err_source(account, geth_provider, project, traceback_contract): +def test_trace_source_when_err(account, geth_provider, project, traceback_contract): txn = traceback_contract.addBalance_f.as_transaction(123) try: account.call(txn) @@ -769,3 +788,48 @@ def test_get_import_remapping(project, compiler): dependency.load_contracts() actual = compiler.get_import_remapping(project=project) assert "exampledependency/Dependency.json" in actual + + +def test_get_compiler_settings(project, compiler): + vyper2_path = project.contracts_folder / "older_version.vy" + vyper3_path = project.contracts_folder / "non_payable_default.vy" + vyper4_path = project.contracts_folder / "zero_four.vy" + vyper2_settings = compiler.get_compiler_settings((vyper2_path,), project=project) + vyper3_settings = compiler.get_compiler_settings((vyper3_path,), project=project) + vyper4_settings = compiler.get_compiler_settings((vyper4_path,), project=project) + + v2_version_used = next(iter(vyper2_settings.keys())) + assert v2_version_used >= Version("0.2.16"), f"version={v2_version_used}" + assert vyper2_settings[v2_version_used]["true%berlin"]["optimize"] is True + assert vyper2_settings[v2_version_used]["true%berlin"]["evmVersion"] == "berlin" + assert vyper2_settings[v2_version_used]["true%berlin"]["outputSelection"] == { + "tests/contracts/passing_contracts/older_version.vy": ["*"] + } + assert "enable_decimals" not in vyper2_settings[v2_version_used]["true%berlin"] + + v3_version_used = next(iter(vyper3_settings.keys())) + settings_key = next(iter(vyper3_settings[v3_version_used].keys())) + valid_evm_versions = [ + v for k, v in EVM_VERSION_DEFAULT.items() if Version(k) >= Version("0.3.0") + ] + pattern = rf"true%({'|'.join(valid_evm_versions)})" + assert re.match(pattern, settings_key) + assert v3_version_used >= Version("0.3.0"), f"version={v3_version_used}" + assert vyper3_settings[v3_version_used][settings_key]["optimize"] is True + assert vyper3_settings[v3_version_used][settings_key]["evmVersion"] in valid_evm_versions + assert vyper3_settings[v3_version_used][settings_key]["outputSelection"] == { + "tests/contracts/passing_contracts/non_payable_default.vy": ["*"] + } + assert "enable_decimals" not in vyper3_settings[v3_version_used][settings_key] + + assert len(vyper4_settings) == 1, f"extra keys={''.join([f'{x}' for x in vyper4_settings])}" + v4_version_used = next(iter(vyper4_settings.keys())) + assert v4_version_used >= Version( + "0.4.0" + ), f"version={v4_version_used} full_data={vyper4_settings}" + assert vyper4_settings[v4_version_used]["gas%shanghai"]["enable_decimals"] is True + assert vyper4_settings[v4_version_used]["gas%shanghai"]["optimize"] == "gas" + assert vyper4_settings[v4_version_used]["gas%shanghai"]["outputSelection"] == { + "tests/contracts/passing_contracts/zero_four.vy": ["*"] + } + assert vyper4_settings[v4_version_used]["gas%shanghai"]["evmVersion"] == "shanghai"