diff --git a/Makefile b/Makefile index ae2186a..823eece 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ VENV_DIR = ./.venvs .PHONY: clean all -all: $(TEST_DIR) $(TEST_DIR)/shuffling $(TEST_DIR)/bls +all: $(TEST_DIR) $(TEST_DIR)/shuffling $(TEST_DIR)/bls $(TEST_DIR)/ssz clean: @@ -42,6 +42,16 @@ $(TEST_DIR)/bls: python $(GENERATOR_DIR)/bls/tgen_bls.py $@/test_bls.yml +$(TEST_DIR)/ssz: + mkdir -p $@ + + python -m venv $(VENV_DIR)/ssz + . $(VENV_DIR)/ssz/bin/activate + pip install -r $(GENERATOR_DIR)/ssz/requirements.txt --user + + python $(GENERATOR_DIR)/ssz/test_generator.py -o $@ + + # Example: # # $(TEST_DIR)/test-test: diff --git a/ssz/__init__.py b/ssz/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ssz/renderers.py b/ssz/renderers.py new file mode 100644 index 0000000..394d783 --- /dev/null +++ b/ssz/renderers.py @@ -0,0 +1,102 @@ +from collections.abc import ( + Mapping, + Sequence, +) + +from eth_utils import ( + encode_hex, + to_dict, +) + +from ssz.sedes import ( + BaseSedes, + Boolean, + Bytes, + BytesN, + Container, + List, + UInt, +) + + +def render_value(value): + if isinstance(value, bool): + return value + elif isinstance(value, int): + return str(value) + elif isinstance(value, bytes): + return encode_hex(value) + elif isinstance(value, Sequence): + return tuple(render_value(element) for element in value) + elif isinstance(value, Mapping): + return render_dict_value(value) + else: + raise ValueError(f"Cannot render value {value}") + + +@to_dict +def render_dict_value(value): + for key, value in value.items(): + yield key, render_value(value) + + +def render_type_definition(sedes): + if isinstance(sedes, Boolean): + return "bool" + + elif isinstance(sedes, UInt): + return f"uint{sedes.length * 8}" + + elif isinstance(sedes, BytesN): + return f"bytes{sedes.length}" + + elif isinstance(sedes, Bytes): + return f"bytes" + + elif isinstance(sedes, List): + return [render_type_definition(sedes.element_sedes)] + + elif isinstance(sedes, Container): + return { + field_name: render_type_definition(field_sedes) + for field_name, field_sedes in sedes.fields + } + + elif isinstance(sedes, BaseSedes): + raise Exception("Unreachable: All sedes types have been checked") + + else: + raise TypeError("Expected BaseSedes") + + +@to_dict +def render_test_case(*, sedes, valid, value=None, serial=None, description=None, tags=None): + value_and_serial_given = value is not None and serial is not None + if valid: + if not value_and_serial_given: + raise ValueError("For valid test cases, both value and ssz must be present") + else: + if value_and_serial_given: + raise ValueError("For invalid test cases, either value or ssz must not be present") + + if tags is None: + tags = [] + + yield "type", render_type_definition(sedes) + yield "valid", valid + if value is not None: + yield "value", render_value(value) + if serial is not None: + yield "ssz", encode_hex(serial) + if description is not None: + yield description + yield "tags", tags + + +@to_dict +def render_test(*, title, summary, version, test_cases): + yield "title", title, + if summary is not None: + yield "summary", summary + yield "version", version + yield "test_cases", test_cases diff --git a/ssz/requirements.txt b/ssz/requirements.txt new file mode 100644 index 0000000..88193a0 --- /dev/null +++ b/ssz/requirements.txt @@ -0,0 +1,2 @@ +ruamel.yaml==0.15.87 +ssz==0.1.0a2 diff --git a/ssz/test_generator.py b/ssz/test_generator.py new file mode 100644 index 0000000..d19ec12 --- /dev/null +++ b/ssz/test_generator.py @@ -0,0 +1,84 @@ +import argparse +import pathlib +import sys + +from ruamel.yaml import ( + YAML, +) + +from uint_test_generators import ( + generate_uint_bounds_test, + generate_uint_random_test, + generate_uint_wrong_length_test, +) + +test_generators = [ + generate_uint_random_test, + generate_uint_wrong_length_test, + generate_uint_bounds_test, +] + + +def make_filename_for_test(test): + title = test["title"] + filename = title.lower().replace(" ", "_") + ".yaml" + return pathlib.Path(filename) + + +def validate_output_dir(path_str): + path = pathlib.Path(path_str) + + if not path.exists(): + raise argparse.ArgumentTypeError("Output directory must exist") + + if not path.is_dir(): + raise argparse.ArgumentTypeError("Output path must lead to a directory") + + return path + + +parser = argparse.ArgumentParser( + prog="gen-ssz-tests", + description="Generate YAML test files for SSZ and tree hashing", +) +parser.add_argument( + "-o", + "--output-dir", + dest="output_dir", + required=True, + type=validate_output_dir, + help="directory into which the generated YAML files will be dumped" +) +parser.add_argument( + "-f", + "--force", + action="store_true", + default=False, + help="if set overwrite test files if they exist", +) + + +if __name__ == "__main__": + args = parser.parse_args() + output_dir = args.output_dir + if not args.force: + file_mode = "x" + else: + file_mode = "w" + + yaml = YAML(pure=True) + + print(f"generating {len(test_generators)} test files...") + for test_generator in test_generators: + test = test_generator() + + filename = make_filename_for_test(test) + path = output_dir / filename + + try: + with path.open(file_mode) as f: + yaml.dump(test, f) + except IOError as e: + sys.exit(f'Error when dumping test "{test["title"]}" ({e})') + + print("done.") diff --git a/ssz/uint_test_generators.py b/ssz/uint_test_generators.py new file mode 100644 index 0000000..a353785 --- /dev/null +++ b/ssz/uint_test_generators.py @@ -0,0 +1,132 @@ +import random + +from eth_utils import ( + to_tuple, +) + +import ssz +from ssz.sedes import ( + UInt, +) +from renderers import ( + render_test, + render_test_case, +) + +random.seed(0) + + +BIT_SIZES = [i for i in range(8, 512 + 1, 8)] +RANDOM_TEST_CASES_PER_BIT_SIZE = 10 +RANDOM_TEST_CASES_PER_LENGTH = 3 + + +def get_random_bytes(length): + return bytes(random.randint(0, 255) for _ in range(length)) + + +def generate_uint_bounds_test(): + test_cases = generate_uint_bounds_test_cases() + generate_uint_out_of_bounds_test_cases() + + return render_test( + title="UInt Bounds", + summary="Integers right at or beyond the bounds of the allowed value range", + version="0.1", + test_cases=test_cases, + ) + + +def generate_uint_random_test(): + test_cases = generate_random_uint_test_cases() + + return render_test( + title="UInt Random", + summary="Random integers chosen uniformly over the allowed value range", + version="0.1", + test_cases=test_cases, + ) + + +def generate_uint_wrong_length_test(): + test_cases = generate_uint_wrong_length_test_cases() + + return render_test( + title="UInt Wrong Length", + summary="Serialized integers that are too short or too long", + version="0.1", + test_cases=test_cases, + ) + + +@to_tuple +def generate_random_uint_test_cases(): + for bit_size in BIT_SIZES: + sedes = UInt(bit_size) + + for _ in range(RANDOM_TEST_CASES_PER_BIT_SIZE): + value = random.randrange(0, 2 ** bit_size) + serial = ssz.encode(value, sedes) + # note that we need to create the tags in each loop cycle, otherwise ruamel will use + # YAML references which makes the resulting file harder to read + tags = tuple(["atomic", "uint", "random"]) + yield render_test_case( + sedes=sedes, + valid=True, + value=value, + serial=serial, + tags=tags, + ) + + +@to_tuple +def generate_uint_wrong_length_test_cases(): + for bit_size in BIT_SIZES: + sedes = UInt(bit_size) + lengths = sorted({ + 0, + sedes.length // 2, + sedes.length - 1, + sedes.length + 1, + sedes.length * 2, + }) + for length in lengths: + for _ in range(RANDOM_TEST_CASES_PER_LENGTH): + tags = tuple(["atomic", "uint", "wrong_length"]) + yield render_test_case( + sedes=sedes, + valid=False, + serial=get_random_bytes(length), + tags=tags, + ) + + +@to_tuple +def generate_uint_bounds_test_cases(): + common_tags = ("atomic", "uint") + for bit_size in BIT_SIZES: + sedes = UInt(bit_size) + + for value, tag in ((0, "uint_lower_bound"), (2 ** bit_size - 1, "uint_upper_bound")): + serial = ssz.encode(value, sedes) + yield render_test_case( + sedes=sedes, + valid=True, + value=value, + serial=serial, + tags=common_tags + (tag,), + ) + + +@to_tuple +def generate_uint_out_of_bounds_test_cases(): + common_tags = ("atomic", "uint") + for bit_size in BIT_SIZES: + sedes = UInt(bit_size) + + for value, tag in ((-1, "uint_underflow"), (2 ** bit_size, "uint_overflow")): + yield render_test_case( + sedes=sedes, + valid=False, + value=value, + tags=common_tags + (tag,), + )