Skip to content

Commit

Permalink
Add API for sharded serialization to a digest (#198)
Browse files Browse the repository at this point in the history
* Add API for sharded serialization to a digest.

This is what used to be `serialize_v1`.

Additionally, in this change we rename `serializing` to `serialization`
to be grammatically correct. We expose `shard_size` and a new
`digest_size` method from all hashing engines. We also make imports be
more consistent.

Signed-off-by: Mihai Maruseac <mihaimaruseac@google.com>

* Change TODOs to link to issues

Signed-off-by: Mihai Maruseac <mihaimaruseac@google.com>

* Add test for fifo

Signed-off-by: Mihai Maruseac <mihaimaruseac@google.com>

* Test root as pipe too

Signed-off-by: Mihai Maruseac <mihaimaruseac@google.com>

* Merge `_get_sizes` and `_build_tasks`

Signed-off-by: Mihai Maruseac <mihaimaruseac@google.com>

---------

Signed-off-by: Mihai Maruseac <mihaimaruseac@google.com>
  • Loading branch information
mihaimaruseac authored Jun 6, 2024
1 parent 185a5fd commit 24d68a4
Show file tree
Hide file tree
Showing 14 changed files with 1,043 additions and 442 deletions.
14 changes: 10 additions & 4 deletions model_signing/hashing/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,12 @@ def compute(self) -> hashing.Digest:
digest = self._content_hasher.compute()
return hashing.Digest(self.digest_name, digest.digest_value)

@override
@property
def digest_size(self) -> int:
    """The size, in bytes, of the digests produced by the engine.

    Delegates to the wrapped in-memory content hasher, so the file
    engine always reports the same size as the underlying algorithm.
    """
    return self._content_hasher.digest_size


class ShardedFileHasher(FileHasher):
"""File hash engine that can be invoked in parallel.
Expand Down Expand Up @@ -168,7 +174,7 @@ def __init__(
raise ValueError(
f"Shard size must be strictly positive, got {shard_size}."
)
self._shard_size = shard_size
self.shard_size = shard_size

self.set_shard(start=start, end=end)

Expand All @@ -184,9 +190,9 @@ def set_shard(self, *, start: int, end: int) -> None:
f" got {start=}, {end=}."
)
read_length = end - start
if read_length > self._shard_size:
if read_length > self.shard_size:
raise ValueError(
f"Must not read more than shard_size={self._shard_size}, got"
f"Must not read more than shard_size={self.shard_size}, got"
f" {read_length}."
)

Expand Down Expand Up @@ -219,4 +225,4 @@ def compute(self) -> hashing.Digest:
def digest_name(self) -> str:
if self._digest_name_override is not None:
return self._digest_name_override
return f"file-{self._content_hasher.digest_name}-{self._shard_size}"
return f"file-{self._content_hasher.digest_name}-{self.shard_size}"
66 changes: 65 additions & 1 deletion model_signing/hashing/file_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ def test_hash_of_known_file_small_chunk(self, sample_file, expected_digest):
digest = hasher.compute()
assert digest.digest_hex == expected_digest

def test_hash_of_known_file_large_chunk(self, sample_file, expected_digest):
    """A chunk size larger than the whole file must not change the digest."""
    oversized_chunk = 2 * len(_FULL_CONTENT)
    hasher = file.FileHasher(
        sample_file, memory.SHA256(), chunk_size=oversized_chunk
    )
    assert hasher.compute().digest_hex == expected_digest

def test_hash_file_twice(self, sample_file):
hasher1 = file.FileHasher(sample_file, memory.SHA256())
digest1 = hasher1.compute()
Expand Down Expand Up @@ -113,7 +119,7 @@ def test_set_file(self, sample_file, sample_file_content_only):
assert digest1.digest_value == digest2.digest_value

def test_default_digest_name(self):
hasher = file.FileHasher("unused", memory.SHA256(), chunk_size=10)
hasher = file.FileHasher("unused", memory.SHA256())
assert hasher.digest_name == "file-sha256"

def test_override_digest_name(self):
Expand All @@ -130,6 +136,11 @@ def test_digest_algorithm_is_digest_name(self, sample_file):
digest = hasher.compute()
assert digest.algorithm == hasher.digest_name

def test_digest_size(self, sample_file):
    """The file engine reports the same digest size as the wrapped hasher.

    Fix: `sample_file` is a pytest fixture and must be declared as a test
    parameter; the original body referenced it without requesting it, so
    the test failed with a NameError instead of exercising `digest_size`.
    """
    memory_hasher = memory.SHA256()
    hasher = file.FileHasher(sample_file, memory_hasher)
    assert hasher.digest_size == memory_hasher.digest_size


class TestShardedFileHasher:

Expand Down Expand Up @@ -304,6 +315,54 @@ def test_hash_of_known_file_small_chunk(
digest2 = hasher2.compute()
assert digest2.digest_hex == expected_content_digest

def test_hash_of_known_file_large_chunk(
    self, sample_file, expected_header_digest, expected_content_digest
):
    """A chunk size exceeding the file length must not affect shard digests."""
    oversized_chunk = 2 * len(_FULL_CONTENT)
    expected_digests = [expected_header_digest, expected_content_digest]
    # Hash each of the two consecutive shards and compare against the
    # corresponding expected digest, in order.
    for shard_index, expected in enumerate(expected_digests):
        hasher = file.ShardedFileHasher(
            sample_file,
            memory.SHA256(),
            start=shard_index * _SHARD_SIZE,
            end=(shard_index + 1) * _SHARD_SIZE,
            chunk_size=oversized_chunk,
        )
        assert hasher.compute().digest_hex == expected

def test_hash_of_known_file_large_shard(
    self, sample_file, expected_header_digest, expected_content_digest
):
    """A shard_size exceeding the file length must not affect shard digests."""
    oversized_shard = 2 * len(_FULL_CONTENT)
    expected_digests = [expected_header_digest, expected_content_digest]
    # Same shard boundaries as the other known-file tests; only the
    # declared shard_size is oversized here.
    for shard_index, expected in enumerate(expected_digests):
        hasher = file.ShardedFileHasher(
            sample_file,
            memory.SHA256(),
            start=shard_index * _SHARD_SIZE,
            end=(shard_index + 1) * _SHARD_SIZE,
            shard_size=oversized_shard,
        )
        assert hasher.compute().digest_hex == expected

def test_default_digest_name(self):
hasher = file.ShardedFileHasher(
"unused", memory.SHA256(), start=0, end=2, shard_size=10
Expand Down Expand Up @@ -332,3 +391,8 @@ def test_digest_algorithm_is_digest_name(self, sample_file):
)
digest = hasher.compute()
assert digest.algorithm == hasher.digest_name

def test_digest_size(self, sample_file):
    """The sharded engine reports the same digest size as the wrapped hasher.

    Fixes two defects in the original test:
    - `sample_file` is a pytest fixture and must be declared as a parameter;
      referencing it without requesting it raised NameError.
    - This test lives in `TestShardedFileHasher` but instantiated
      `file.FileHasher`, verbatim-duplicating the FileHasher test and
      leaving `ShardedFileHasher.digest_size` uncovered.
    """
    memory_hasher = memory.SHA256()
    # start/end values mirror the other small-shard tests in this class;
    # any valid shard suffices since digest_size is shard-independent.
    hasher = file.ShardedFileHasher(
        sample_file, memory_hasher, start=0, end=2
    )
    assert hasher.digest_size == memory_hasher.digest_size
27 changes: 19 additions & 8 deletions model_signing/hashing/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
specify the algorithm and the digest value.
"""

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
import abc
import dataclasses
from typing import Protocol


@dataclass(frozen=True)
@dataclasses.dataclass(frozen=True)
class Digest:
"""A digest computed by a `HashEngine`."""

Expand All @@ -38,17 +38,22 @@ def digest_hex(self) -> str:
"""Hexadecimal, human readable, equivalent of `digest`."""
return self.digest_value.hex()

@property
def digest_size(self) -> int:
    """The size, in bytes, of the digest."""
    # Derived from the raw bytes rather than stored separately, so it can
    # never disagree with `digest_value`.
    return len(self.digest_value)


class HashEngine(metaclass=ABCMeta):
class HashEngine(metaclass=abc.ABCMeta):
"""Generic hash engine."""

@abstractmethod
@abc.abstractmethod
def compute(self) -> Digest:
"""Computes the digest of data passed to the engine."""
pass

@property
@abstractmethod
@abc.abstractmethod
def digest_name(self) -> str:
"""The canonical name of the algorithm used to compute the hash.
Expand All @@ -60,16 +65,22 @@ def digest_name(self) -> str:
"""
pass

@property
@abc.abstractmethod
def digest_size(self) -> int:
"""The size, in bytes, of the digests produced by the engine."""
pass


class Streaming(Protocol):
"""A protocol to support streaming data to `HashEngine` objects."""

@abstractmethod
@abc.abstractmethod
def update(self, data: bytes) -> None:
"""Appends additional bytes to the data to be hashed."""
pass

@abstractmethod
@abc.abstractmethod
def reset(self, data: bytes = b"") -> None:
"""Resets the data to be hashed to the passed argument."""
pass
Expand Down
6 changes: 6 additions & 0 deletions model_signing/hashing/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,9 @@ def compute(self) -> hashing.Digest:
@property
def digest_name(self) -> str:
return "sha256"

@override
@property
def digest_size(self) -> int:
    """The size, in bytes, of the digests produced by the engine."""
    # SHA-256 always produces 256-bit (32-byte) digests, so the size can
    # be returned as a constant without consulting the hash object.
    return 32
7 changes: 7 additions & 0 deletions model_signing/hashing/memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,10 @@ def test_update_after_reset(self):

assert digest1.digest_hex == digest2.digest_hex
assert digest1.digest_value == digest2.digest_value

def test_digest_size(self):
    """SHA-256 reports 32 bytes on both the engine and its computed digest."""
    expected_bytes = 32
    engine = memory.SHA256(b"Test string")
    assert engine.digest_size == expected_bytes
    assert engine.compute().digest_size == expected_bytes
10 changes: 8 additions & 2 deletions model_signing/hashing/precomputed.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@
```
"""

from dataclasses import dataclass
import dataclasses
from typing_extensions import override

from model_signing.hashing import hashing


@dataclass(frozen=True)
@dataclasses.dataclass(frozen=True)
class PrecomputedDigest(hashing.HashEngine):
"""A wrapper around digests computed by external tooling."""

Expand All @@ -49,3 +49,9 @@ def compute(self) -> hashing.Digest:
@property
def digest_name(self) -> str:
return self._digest_type

@override
@property
def digest_size(self) -> int:
    """The size, in bytes, of the digests produced by the engine."""
    # The digest is supplied precomputed, so its size is simply the length
    # of the stored value.
    return len(self._digest_value)
5 changes: 5 additions & 0 deletions model_signing/hashing/precomputed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,8 @@ def test_expected_hash_type(self):
assert hasher.digest_name == "test"
digest = hasher.compute()
assert digest.algorithm == "test"

def test_digest_size(self):
    """digest_size reflects the length of the wrapped precomputed digest."""
    raw_digest = b"abcd"
    hasher = precomputed.PrecomputedDigest("test", raw_digest)
    assert hasher.digest_size == len(raw_digest)
8 changes: 4 additions & 4 deletions model_signing/manifest/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@
soon.
"""

from abc import ABCMeta
from dataclasses import dataclass
import abc
import dataclasses

from model_signing.hashing import hashing


class Manifest(metaclass=ABCMeta):
class Manifest(metaclass=abc.ABCMeta):
"""Generic manifest file to represent a model."""

pass


@dataclass
@dataclasses.dataclass
class DigestManifest(Manifest):
"""A manifest that is just a hash."""

Expand Down
File renamed without changes.
Loading

0 comments on commit 24d68a4

Please sign in to comment.