Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

♻️ refactor(gentest,base_types): Improved architecture, implement EthereumTestBaseModel, EthereumTestRootModel #901

Merged
merged 6 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 5 additions & 26 deletions src/cli/gentest/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,11 @@
from typing import TextIO

import click
import jinja2

from ethereum_test_base_types import Hash

from .request_manager import RPCRequest
from .test_providers import BlockchainTestProvider

template_loader = jinja2.PackageLoader("cli.gentest")
template_env = jinja2.Environment(loader=template_loader, keep_trailing_newline=True)
from .source_code_generator import get_test_source
from .test_context_providers import BlockchainTestContextProvider


@click.command()
danceratopz marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -31,26 +27,9 @@ def generate(transaction_hash: str, output_file: TextIO):

OUTPUT_FILE is the path to the output python script.
"""
request = RPCRequest()

print(
"Perform tx request: eth_get_transaction_by_hash(" + f"{transaction_hash}" + ")",
file=stderr,
)
transaction = request.eth_get_transaction_by_hash(Hash(transaction_hash))

print("Perform debug_trace_call", file=stderr)
state = request.debug_trace_call(transaction)

print("Perform eth_get_block_by_number", file=stderr)
block = request.eth_get_block_by_number(transaction.block_number)

print("Generate py test", file=stderr)
context = BlockchainTestProvider(
block=block, transaction=transaction, state=state
).get_context()
provider = BlockchainTestContextProvider(transaction_hash=Hash(transaction_hash))

template = template_env.get_template("blockchain_test/transaction.py.j2")
output_file.write(template.render(context))
source = get_test_source(provider=provider, template_path="blockchain_test/transaction.py.j2")
output_file.write(source)

print("Finished", file=stderr)
4 changes: 2 additions & 2 deletions src/cli/gentest/request_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pydantic import BaseModel

from config import EnvConfig
from ethereum_test_base_types import Account, Address, Hash, HexNumber
from ethereum_test_base_types import Hash, HexNumber
from ethereum_test_rpc import BlockNumberType, DebugRPC, EthRPC
from ethereum_test_types import Transaction

Expand Down Expand Up @@ -100,7 +100,7 @@ def eth_get_block_by_number(self, block_number: BlockNumberType) -> RemoteBlock:
timestamp=res["timestamp"],
)

def debug_trace_call(self, transaction: RemoteTransaction) -> Dict[Address, Account]:
def debug_trace_call(self, transaction: RemoteTransaction) -> Dict[str, dict]:
"""
Get pre-state required for transaction
"""
Expand Down
83 changes: 83 additions & 0 deletions src/cli/gentest/source_code_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
Pytest source code generator.

This module maps a test provider instance to pytest source code.
"""

import subprocess
import sys
import tempfile
from pathlib import Path

import jinja2

from .test_context_providers import Provider

template_loader = jinja2.PackageLoader("cli.gentest")
template_env = jinja2.Environment(loader=template_loader, keep_trailing_newline=True)

# This filter maps python objects to string
template_env.filters["stringify"] = lambda input: repr(input)
danceratopz marked this conversation as resolved.
Show resolved Hide resolved


# generates a formatted pytest source code by writing provided data on a given template.
def get_test_source(provider: Provider, template_path: str) -> str:
"""
Generates formatted pytest source code by rendering a template with provided data.

This function uses the given template path to create a pytest-compatible source
code string. It retrieves context data from the specified provider and applies
it to the template.

Args:
provider: An object that provides the necessary context for rendering the template.
template_path (str): The path to the Jinja2 template file used to generate tests.

Returns:
str: The formatted pytest source code.
"""
template = template_env.get_template(template_path)
rendered_template = template.render(provider.get_context())
# return rendered_template
return format_code(rendered_template)


def format_code(code: str) -> str:
danceratopz marked this conversation as resolved.
Show resolved Hide resolved
"""
Formats the provided Python code using the Black code formatter.

This function writes the given code to a temporary Python file, formats it using
the Black formatter, and returns the formatted code as a string.

Args:
code (str): The Python code to be formatted.

Returns:
str: The formatted Python code.
"""
# Create a temporary python file
with tempfile.NamedTemporaryFile(suffix=".py") as temp_file:
# Write the code to the temporary file
temp_file.write(code.encode("utf-8"))
# Ensure the file is written
temp_file.flush()

# Create a Path object for the input file
input_file_path = Path(temp_file.name)

# Get the path to the black executable in the virtual environment
if sys.platform.startswith("win"):
black_path = Path(sys.prefix) / "Scripts" / "black.exe"
else:
black_path = Path(sys.prefix) / "bin" / "black"

# Call black to format the file
config_path = Path(sys.prefix).parent / "pyproject.toml"

subprocess.run(
[str(black_path), str(input_file_path), "--quiet", "--config", str(config_path)],
check=True,
)

# Return the formatted source code
return input_file_path.read_text()
31 changes: 15 additions & 16 deletions src/cli/gentest/templates/blockchain_test/transaction.py.j2
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
gentest autogenerated test with debug_traceCall of tx.hash
Gentest autogenerated test from `tx.hash`:
{{ tx_hash }}
https://etherscan.io/tx/{{tx_hash}}
"""
Expand All @@ -8,37 +8,36 @@ from typing import Dict

import pytest

from ethereum_test_tools import Account, Block, BlockchainTestFiller, Environment, Transaction
from ethereum_test_tools import (
Account,
Block,
BlockchainTestFiller,
Environment,
Storage,
Transaction,
)

REFERENCE_SPEC_GIT_PATH = "N/A"
REFERENCE_SPEC_VERSION = "N/A"


@pytest.fixture
def env(): # noqa: D103
return Environment(
{{ environment_kwargs }}
)

return {{ environment | stringify }}

@pytest.mark.valid_from("Paris")
marioevz marked this conversation as resolved.
Show resolved Hide resolved
def test_transaction_{{ tx_hash }}( # noqa: SC200, E501
env: Environment,
blockchain_test: BlockchainTestFiller,
):
"""
gentest autogenerated test for tx.hash
Gentest autogenerated test for tx.hash:
{{ tx_hash }}
"""
pre = {
{{ pre_state_items }}
}
pre = {{ pre_state | stringify }}

post: Dict = {
}
post: Dict = {}

tx = Transaction(
{{ transaction_items }}
)
tx = {{ transaction | stringify }}

blockchain_test(genesis_environment=env, pre=pre, post=post, blocks=[Block(txs=[tx])])
blockchain_test(genesis_environment=env, pre=pre, post=post, blocks=[Block(txs=[tx])])
104 changes: 104 additions & 0 deletions src/cli/gentest/test_context_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""
This module contains various providers which generates context required to create test scripts.

Classes:
- Provider: An provider generates required context for creating a test.
- BlockchainTestProvider: The BlockchainTestProvider takes a transaction hash and creates
required context to create a test.

Example:
provider = BlockchainTestContextProvider(transaction=transaction)
context = provider.get_context()
"""

from abc import ABC, abstractmethod
from sys import stderr
from typing import Any, Dict, Optional

from pydantic import BaseModel

from ethereum_test_base_types import Account, Hash
from ethereum_test_tools import Environment, Transaction

from .request_manager import RPCRequest


class Provider(ABC, BaseModel):
"""
An provider generates required context for creating a test.
"""

@abstractmethod
def get_context(self) -> Dict:
"""
Get the context for generating a test.
"""

pass


class BlockchainTestContextProvider(Provider):
"""
Provides context required to generate a `blockchain_test` using pytest.
"""

transaction_hash: Hash
block: Optional[RPCRequest.RemoteBlock] = None
transaction: Optional[RPCRequest.RemoteTransaction] = None
state: Optional[Dict[str, Dict]] = None

def _make_rpc_calls(self):
request = RPCRequest()
print(
f"Perform tx request: eth_get_transaction_by_hash({self.transaction_hash})",
file=stderr,
)
self.transaction = request.eth_get_transaction_by_hash(self.transaction_hash)

print("Perform debug_trace_call", file=stderr)
self.state = request.debug_trace_call(self.transaction)

print("Perform eth_get_block_by_number", file=stderr)
self.block = request.eth_get_block_by_number(self.transaction.block_number)

print("Generate py test", file=stderr)

def _get_environment(self) -> Environment:
assert self.block is not None
return Environment(**self.block.model_dump())

def _get_pre_state(self) -> Dict[str, Account]:
assert self.state is not None
assert self.transaction is not None

pre_state: Dict[str, Account] = {}
for address, account_data in self.state.items():

# TODO: Check if this is required. Ideally,
# the pre-state tracer should have the correct
# values without requiring any additional modifications.
if address == self.transaction.sender:
account_data["nonce"] = self.transaction.nonce

pre_state[address] = Account(**account_data)
return pre_state

def _get_transaction(self) -> Transaction:
assert self.transaction is not None
return Transaction(**self.transaction.model_dump())

def get_context(self) -> Dict[str, Any]:
"""
Get the context for generating a blockchain test.

Returns:
Dict[str, Any]: A dictionary containing environment,
pre-state, a transaction and its hash.
"""
self._make_rpc_calls()
return {
"environment": self._get_environment(),
"pre_state": self._get_pre_state(),
"transaction": self._get_transaction(),
"tx_hash": self.transaction_hash,
}
Loading