Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

streamable for merkle blob serialization #19154

Draft
wants to merge 2 commits into
base: long_lived/datalayer_merkle_blob
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 80 additions & 88 deletions chia/_tests/core/data_layer/test_merkle_blob.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import hashlib
import struct
from dataclasses import astuple, dataclass
import itertools
from dataclasses import dataclass
from random import Random
from typing import Generic, Protocol, TypeVar, final

Expand All @@ -26,14 +26,14 @@
TreeIndex,
data_size,
metadata_size,
null_parent,
pack_raw_node,
raw_node_classes,
raw_node_type_to_class,
spacing,
unpack_raw_node,
)
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import int64, uint32

# Module-level pytest mark: applies the `data_layer` marker to every test collected from this module.
pytestmark = pytest.mark.data_layer

Expand Down Expand Up @@ -62,23 +62,6 @@ def raw_node_class_fixture(request: SubRequest) -> RawMerkleNodeProtocol:
return request.param # type: ignore[no-any-return]


# Lookup from each class that exposes a `struct` attribute (NodeMetadata plus every
# entry of raw_node_classes) to that attribute; supplies the params and ids of the
# `class_struct` fixture.
# NOTE(review): the enclosing hunk header reads "-62,23 +62,6", so this definition
# appears to be on the removed side of the diff -- confirm before relying on it.
class_to_structs: dict[type[object], struct.Struct] = {
NodeMetadata: NodeMetadata.struct,
**{cls: cls.struct for cls in raw_node_classes},
}


# Session-scoped fixture parametrized over the values of class_to_structs; each
# parameter's test id is the __name__ of the class that owns the struct.
# NOTE(review): the parameters are the struct.Struct values of class_to_structs, but
# the return annotation says RawMerkleNodeProtocol -- confirm which is intended.
# NOTE(review): appears to be on the removed side of the diff (hunk "-62,23 +62,6").
@pytest.fixture(
name="class_struct",
scope="session",
params=class_to_structs.values(),
ids=[cls.__name__ for cls in class_to_structs.keys()],
)
def class_struct_fixture(request: SubRequest) -> RawMerkleNodeProtocol:
# https://github.com/pytest-dev/pytest/issues/8763
return request.param # type: ignore[no-any-return]


def test_raw_node_class_types_are_unique() -> None:
    """The type-to-class lookup must hold exactly one entry per raw node class."""
    expected_entries = len(raw_node_classes)
    assert len(raw_node_type_to_class) == expected_entries

Expand All @@ -88,32 +71,46 @@ def test_metadata_size_not_changed() -> None:


# Pins the value of data_size (per the test name, to catch an unintended change).
# NOTE(review): diff residue -- this span holds two asserts for this test (52 directly
# below, 53 at the end of the span); the 52 line and the two tests in between look
# like the removed side of the hunk. Confirm before reusing any of it.
def test_data_size_not_changed() -> None:
assert data_size == 52


# Checks that each raw node class's `struct` has a packed size equal to data_size.
def test_raw_node_struct_sizes(raw_node_class: RawMerkleNodeProtocol) -> None:
assert raw_node_class.struct.size == data_size


# Checks that the struct format string starts with ">", the struct-module prefix
# for big-endian byte order.
def test_all_big_endian(class_struct: struct.Struct) -> None:
assert class_struct.format.startswith(">")
assert data_size == 53


# TODO: check all struct types against attribute types

# Type variable bounded by RawMerkleNodeProtocol; RawNodeFromBlobCase is generic over it.
RawMerkleNodeT = TypeVar("RawMerkleNodeT", bound=RawMerkleNodeProtocol)


# Hand-built packed bytes for the internal-node reference case.
# NOTE(review): diff residue -- `reference_blob` on the next line looks like the
# pre-change reference (a plain 0..data_size-1 byte ramp) that the per-type blobs
# replace; confirm.
reference_blob = bytes(range(data_size))
# One shared counter yields consecutive byte values 0, 1, 2, ... so each field's
# bytes continue where the previous field stopped.
counter = itertools.count()
# hash
# 32 bytes: 0x00..0x1F
internal_reference_blob = bytes([next(counter) for _ in range(32)])
# optional parent
# a single byte of value 1, then the 4 parent bytes 0x20..0x23
# (presumably the byte is the "value present" marker of the optional encoding -- TODO confirm)
internal_reference_blob += bytes([1])
internal_reference_blob += bytes([next(counter) for _ in range(4)])
# left
# 4 bytes: 0x24..0x27
internal_reference_blob += bytes([next(counter) for _ in range(4)])
# right
# 4 bytes: 0x28..0x2B
internal_reference_blob += bytes([next(counter) for _ in range(4)])
# 45 bytes have been written at this point; zero-fill the remainder up to data_size.
internal_reference_blob += bytes(0 for _ in range(data_size - len(internal_reference_blob)))
assert len(internal_reference_blob) == data_size

# Hand-built packed bytes for the leaf-node reference case. The counter is restarted
# so the byte values begin at 0 again.
counter = itertools.count()
# hash
# 32 bytes: 0x00..0x1F
leaf_reference_blob = bytes([next(counter) for _ in range(32)])
# optional parent
# a single byte of value 1, then the 4 parent bytes 0x20..0x23
# (presumably the byte is the "value present" marker of the optional encoding -- TODO confirm)
leaf_reference_blob += bytes([1])
leaf_reference_blob += bytes([next(counter) for _ in range(4)])
# key
# 8 bytes: 0x24..0x2B
leaf_reference_blob += bytes([next(counter) for _ in range(8)])
# value
# 8 bytes: 0x2C..0x33
leaf_reference_blob += bytes([next(counter) for _ in range(8)])
# 53 bytes have been written at this point, so this zero-fill adds nothing when
# data_size is 53; it only pads if data_size is larger.
leaf_reference_blob += bytes(0 for _ in range(data_size - len(leaf_reference_blob)))
assert len(leaf_reference_blob) == data_size


@final
@dataclass
class RawNodeFromBlobCase(Generic[RawMerkleNodeT]):
raw: RawMerkleNodeT
blob_to_unpack: bytes = reference_blob
packed_blob_reference_leaf: bytes = reference_blob
packed_blob_reference_internal: bytes = reference_blob[:44] + bytes([0] * 8)
packed: bytes

marks: Marks = ()

Expand All @@ -125,101 +122,104 @@ def id(self) -> str:
# Reference cases pairing a raw node object with its expected packed bytes; consumed
# by the pack and unpack tests through @datacases.
# The field values mirror the byte ramps of the reference blobs, e.g. parent
# 0x20212223 is the 4 bytes 0x20..0x23.
# NOTE(review): diff residue -- inside each node the first group of keyword lines
# (plain bytes/int arguments) looks like the removed side and the second group
# (bytes32/uint32/int64-wrapped) like the added side; confirm.
reference_raw_nodes: list[DataCase] = [
RawNodeFromBlobCase(
raw=RawInternalMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0x20212223),
left=TreeIndex(0x24252627),
right=TreeIndex(0x28292A2B),
hash=bytes32(range(32)),
parent=TreeIndex(uint32(0x20212223)),
left=TreeIndex(uint32(0x24252627)),
right=TreeIndex(uint32(0x28292A2B)),
),
packed=internal_reference_blob,
),
RawNodeFromBlobCase(
raw=RawLeafMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0x20212223),
key=KVId(0x2425262728292A2B),
value=KVId(0x2C2D2E2F30313233),
hash=bytes32(range(32)),
parent=TreeIndex(uint32(0x20212223)),
key=KVId(int64(0x2425262728292A2B)),
value=KVId(int64(0x2C2D2E2F30313233)),
),
packed=leaf_reference_blob,
),
]


# Unpacks each reference case's bytes with unpack_raw_node (index 0, non-dirty
# metadata of the case's node type) and expects the result to equal the case's raw node.
# NOTE(review): diff residue -- `index=TreeIndex(0)` and `data=case.blob_to_unpack`
# look like the pre-change arguments, each followed by its replacement; confirm.
@datacases(*reference_raw_nodes)
def test_raw_node_from_blob(case: RawNodeFromBlobCase[RawMerkleNodeProtocol]) -> None:
node = unpack_raw_node(
index=TreeIndex(0),
index=TreeIndex(uint32(0)),
metadata=NodeMetadata(type=case.raw.type, dirty=False),
data=case.blob_to_unpack,
data=case.packed,
)
assert node == case.raw


# Packs each reference case's raw node with pack_raw_node and compares the bytes
# with the case's expected packed form.
# NOTE(review): diff residue -- the `expected_blob` selection and the
# `assert blob == expected_blob` line look like the removed side; the final
# `assert blob == case.packed` looks like their replacement. Confirm.
@datacases(*reference_raw_nodes)
def test_raw_node_to_blob(case: RawNodeFromBlobCase[RawMerkleNodeProtocol]) -> None:
blob = pack_raw_node(case.raw)
expected_blob = (
case.packed_blob_reference_leaf
if isinstance(case.raw, RawLeafMerkleNode)
else case.packed_blob_reference_internal
)

assert blob == expected_blob
assert blob == case.packed


# Builds a blob holding one parentless leaf (packed metadata followed by the packed
# node), loads it into a MerkleBlob and expects index 0 to read back as that leaf.
# NOTE(review): diff residue -- the first hash/parent/key/value group and the first
# get_raw_node assert look like the removed side (null_parent sentinel, bare ints);
# the second of each looks like the added side (parent=None, uint32/int64 wrappers). Confirm.
def test_merkle_blob_one_leaf_loads() -> None:
# TODO: need to persist reference data
leaf = RawLeafMerkleNode(
hash=bytes(range(32)),
parent=null_parent,
key=KVId(0x0405060708090A0B),
value=KVId(0x0405060708090A1B),
hash=bytes32(range(32)),
parent=None,
key=KVId(int64(0x0405060708090A0B)),
value=KVId(int64(0x0405060708090A1B)),
)
blob = bytearray(NodeMetadata(type=NodeType.leaf, dirty=False).pack() + pack_raw_node(leaf))

merkle_blob = MerkleBlob(blob=blob)
assert merkle_blob.get_raw_node(TreeIndex(0)) == leaf
assert merkle_blob.get_raw_node(TreeIndex(uint32(0))) == leaf


# Builds a three-node blob (internal root at index 0, leaves at indexes 1 and 2),
# loads it into a MerkleBlob and checks node lookup, parent links, lineage by index
# and lineage by key id.
# NOTE(review): diff residue -- wherever two near-identical lines or keyword groups
# follow one another, the first looks like the removed side (bare ints, bytes,
# null_parent) and the second like the added side (uint32/int64/bytes32, None). Confirm.
def test_merkle_blob_two_leafs_loads() -> None:
# TODO: break this test down into some reusable data and multiple tests
# TODO: need to persist reference data
root = RawInternalMerkleNode(
hash=bytes(range(32)),
parent=null_parent,
left=TreeIndex(1),
right=TreeIndex(2),
hash=bytes32(range(32)),
parent=None,
left=TreeIndex(uint32(1)),
right=TreeIndex(uint32(2)),
)
left_leaf = RawLeafMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0),
key=KVId(0x0405060708090A0B),
value=KVId(0x0405060708090A1B),
hash=bytes32(range(32)),
parent=TreeIndex(uint32(0)),
key=KVId(int64(0x0405060708090A0B)),
value=KVId(int64(0x0405060708090A1B)),
)
right_leaf = RawLeafMerkleNode(
hash=bytes(range(32)),
parent=TreeIndex(0),
key=KVId(0x1415161718191A1B),
value=KVId(0x1415161718191A2B),
hash=bytes32(range(32)),
parent=TreeIndex(uint32(0)),
key=KVId(int64(0x1415161718191A1B)),
value=KVId(int64(0x1415161718191A2B)),
)
# Each node is stored as its packed metadata followed by its packed body; the
# append order matches the indexes the root refers to (root, left, right).
# The root is written with dirty=True, the leaves with dirty=False.
blob = bytearray()
blob.extend(NodeMetadata(type=NodeType.internal, dirty=True).pack() + pack_raw_node(root))
blob.extend(NodeMetadata(type=NodeType.leaf, dirty=False).pack() + pack_raw_node(left_leaf))
blob.extend(NodeMetadata(type=NodeType.leaf, dirty=False).pack() + pack_raw_node(right_leaf))

merkle_blob = MerkleBlob(blob=blob)
assert merkle_blob.get_raw_node(TreeIndex(0)) == root
assert merkle_blob.get_raw_node(TreeIndex(uint32(0))) == root
assert merkle_blob.get_raw_node(root.left) == left_leaf
assert merkle_blob.get_raw_node(root.right) == right_leaf
# The `is not None` asserts come before passing the optional parent to get_raw_node
# (presumably to narrow the type for the checker -- confirm).
assert left_leaf.parent is not None
assert merkle_blob.get_raw_node(left_leaf.parent) == root
assert right_leaf.parent is not None
assert merkle_blob.get_raw_node(right_leaf.parent) == root

# The expected lineage lists (index, node) pairs starting at the queried node and ending at the root.
assert merkle_blob.get_lineage_with_indexes(TreeIndex(0)) == [(0, root)]
assert merkle_blob.get_lineage_with_indexes(root.left) == [(1, left_leaf), (0, root)]
assert merkle_blob.get_lineage_with_indexes(TreeIndex(uint32(0))) == [(0, root)]
expected: list[tuple[TreeIndex, RawMerkleNodeProtocol]] = [
(TreeIndex(uint32(1)), left_leaf),
(TreeIndex(uint32(0)), root),
]
assert merkle_blob.get_lineage_with_indexes(root.left) == expected

# Both leaves carry the same hash (bytes 0..31), so after calculate_lazy_hashes the
# expected root is internal_hash(son_hash, son_hash) and each key's lineage is that
# single InternalNode.
# (presumably calculate_lazy_hashes recomputes the dirty root's hash -- confirm)
merkle_blob.calculate_lazy_hashes()
son_hash = bytes32(range(32))
root_hash = internal_hash(son_hash, son_hash)
expected_node = InternalNode(root_hash, son_hash, son_hash)
assert merkle_blob.get_lineage_by_key_id(KVId(0x0405060708090A0B)) == [expected_node]
assert merkle_blob.get_lineage_by_key_id(KVId(0x1415161718191A1B)) == [expected_node]
assert merkle_blob.get_lineage_by_key_id(KVId(int64(0x0405060708090A0B))) == [expected_node]
assert merkle_blob.get_lineage_by_key_id(KVId(int64(0x1415161718191A1B))) == [expected_node]


def generate_kvid(seed: int) -> tuple[KVId, KVId]:
Expand All @@ -228,16 +228,16 @@ def generate_kvid(seed: int) -> tuple[KVId, KVId]:
for offset in range(2):
seed_bytes = (2 * seed + offset).to_bytes(8, byteorder="big", signed=True)
hash_obj = hashlib.sha256(seed_bytes)
hash_int = int.from_bytes(hash_obj.digest()[:8], byteorder="big", signed=True)
hash_int = int64.from_bytes(hash_obj.digest()[:8])
kv_ids.append(KVId(hash_int))

return kv_ids[0], kv_ids[1]


# Derives a deterministic test hash from an integer seed: the seed is encoded as
# 8 big-endian signed bytes and hashed with SHA-256.
# NOTE(review): diff residue -- the first `def` line and the first `return` line look
# like the removed side (plain bytes); the second of each looks like the added side
# (result wrapped in bytes32). Confirm.
def generate_hash(seed: int) -> bytes:
def generate_hash(seed: int) -> bytes32:
seed_bytes = seed.to_bytes(8, byteorder="big", signed=True)
hash_obj = hashlib.sha256(seed_bytes)
return hash_obj.digest()
return bytes32(hash_obj.digest())


def test_insert_delete_loads_all_keys() -> None:
Expand Down Expand Up @@ -337,7 +337,7 @@ def test_proof_of_inclusion_merkle_blob() -> None:
num_deletes = 1 + repeats * 10

kv_ids: list[tuple[KVId, KVId]] = []
hashes: list[bytes] = []
hashes: list[bytes32] = []
for _ in range(num_inserts):
seed += 1
key, value = generate_kvid(seed)
Expand Down Expand Up @@ -379,10 +379,10 @@ def test_proof_of_inclusion_merkle_blob() -> None:
assert proof_of_inclusion.valid()


@pytest.mark.parametrize(argnames="index", argvalues=[TreeIndex(-1), TreeIndex(1), TreeIndex(null_parent)])
@pytest.mark.parametrize(argnames="index", argvalues=[-1, 1, None])
def test_get_raw_node_raises_for_invalid_indexes(index: TreeIndex) -> None:
merkle_blob = MerkleBlob(blob=bytearray())
merkle_blob.insert(KVId(0x1415161718191A1B), KVId(0x1415161718191A1B), bytes(range(12, data_size)))
merkle_blob.insert(KVId(int64(0x1415161718191A1B)), KVId(int64(0x1415161718191A1B)), bytes32(range(12, 12 + 32)))

with pytest.raises(InvalidIndexError):
merkle_blob.get_raw_node(index)
Expand All @@ -391,14 +391,6 @@ def test_get_raw_node_raises_for_invalid_indexes(index: TreeIndex) -> None:
merkle_blob.get_metadata(index)


# For every raw node class: builds an instance by unpacking cls.struct.size seeded
# random bytes, then checks that as_tuple() equals dataclasses.astuple() on it.
# NOTE(review): the enclosing hunk header reads "-391,14 +391,6", so this test
# appears to be on the removed side of the diff -- confirm.
@pytest.mark.parametrize(argnames="cls", argvalues=raw_node_classes)
def test_as_tuple_matches_dataclasses_astuple(cls: type[RawMerkleNodeProtocol], seeded_random: Random) -> None:
raw_bytes = bytes(seeded_random.getrandbits(8) for _ in range(cls.struct.size))
raw_node = cls(*cls.struct.unpack(raw_bytes))
# TODO: try again to indicate that the RawMerkleNodeProtocol requires the dataclass interface
assert raw_node.as_tuple() == astuple(raw_node) # type: ignore[call-overload]


def test_helper_methods(merkle_blob_type: MerkleBlobCallable) -> None:
merkle_blob = merkle_blob_type(blob=bytearray())
assert merkle_blob.empty()
Expand All @@ -409,7 +401,7 @@ def test_helper_methods(merkle_blob_type: MerkleBlobCallable) -> None:
merkle_blob.insert(key, value, hash)
assert not merkle_blob.empty()
assert merkle_blob.get_root_hash() is not None
assert merkle_blob.get_root_hash() == merkle_blob.get_hash_at_index(TreeIndex(0))
assert merkle_blob.get_root_hash() == merkle_blob.get_hash_at_index(TreeIndex(uint32(0)))

merkle_blob.delete(key)
assert merkle_blob.empty()
Expand Down Expand Up @@ -483,7 +475,7 @@ def test_get_nodes(merkle_blob_type: MerkleBlobCallable) -> None:


def test_just_insert_a_bunch(merkle_blob_type: MerkleBlobCallable) -> None:
HASH = bytes(range(12, 44))
HASH = bytes32(range(12, 12 + 32))

import pathlib

Expand All @@ -497,6 +489,6 @@ def test_just_insert_a_bunch(merkle_blob_type: MerkleBlobCallable) -> None:
total_time = 0.0
for i in range(100000):
start = time.monotonic()
merkle_blob.insert(KVId(i), KVId(i), HASH)
merkle_blob.insert(KVId(int64(i)), KVId(int64(i)), HASH)
end = time.monotonic()
total_time += end - start
2 changes: 1 addition & 1 deletion chia/data_layer/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,7 @@ async def insert_batch(
last_action[hash] = change["action"]

batch_keys_values: list[tuple[KVId, KVId]] = []
batch_hashes: list[bytes] = []
batch_hashes: list[bytes32] = []

for change in changelist:
if change["action"] == "insert":
Expand Down
Loading
Loading