From 3900efc5442348360f01c6b4b2a424307590808d Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Fri, 24 May 2024 14:46:37 +0200
Subject: [PATCH] Support `max_shard_size` as string in `split_state_dict_into_shards_factory` (#2286)

* fix max-shard-size

* for torch also

* add tests

* Fix styling + do not support KiB

---------

Co-authored-by: Lucain Pouget
---
 src/huggingface_hub/serialization/_base.py  | 48 ++++++++++++++++++-
 src/huggingface_hub/serialization/_numpy.py |  4 +-
 .../serialization/_tensorflow.py            |  4 +-
 src/huggingface_hub/serialization/_torch.py |  6 +--
 tests/test_serialization.py                 | 17 +++++++
 5 files changed, 70 insertions(+), 9 deletions(-)

diff --git a/src/huggingface_hub/serialization/_base.py b/src/huggingface_hub/serialization/_base.py
index a7f7bba892..e16e4a8137 100644
--- a/src/huggingface_hub/serialization/_base.py
+++ b/src/huggingface_hub/serialization/_base.py
@@ -14,7 +14,7 @@
 """Contains helpers to split tensors into shards."""

 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, TypeVar
+from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

 from .. import logging
@@ -46,7 +46,7 @@ def split_state_dict_into_shards_factory(
     get_tensor_size: TensorSizeFn_T,
     get_storage_id: StorageIDFn_T = lambda tensor: None,
     filename_pattern: str = FILENAME_PATTERN,
-    max_shard_size: int = MAX_SHARD_SIZE,
+    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
 ) -> StateDictSplit:
     """
     Split a model state dictionary in shards so that each shard is smaller than a given size.
@@ -89,6 +89,9 @@
     current_shard_size = 0
     total_size = 0

+    if isinstance(max_shard_size, str):
+        max_shard_size = parse_size_to_int(max_shard_size)
+
     for key, tensor in state_dict.items():
         # when bnb serialization is used the weights in the state dict can be strings
         # check: https://github.com/huggingface/transformers/pull/24416 for more details
@@ -167,3 +170,44 @@
         filename_to_tensors=filename_to_tensors,
         tensor_to_filename=tensor_name_to_filename,
     )
+
+
+SIZE_UNITS = {
+    "TB": 10**12,
+    "GB": 10**9,
+    "MB": 10**6,
+    "KB": 10**3,
+}
+
+
+def parse_size_to_int(size_as_str: str) -> int:
+    """
+    Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).
+
+    Supported units are "TB", "GB", "MB", "KB".
+
+    Args:
+        size_as_str (`str`): The size to convert, e.g. `"5MB"`.
+
+    Example:
+
+    ```py
+    >>> parse_size_to_int("5MB")
+    5000000
+    ```
+    """
+    size_as_str = size_as_str.strip()
+
+    # Parse unit
+    unit = size_as_str[-2:].upper()
+    if unit not in SIZE_UNITS:
+        raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.")
+    multiplier = SIZE_UNITS[unit]
+
+    # Parse value
+    try:
+        value = float(size_as_str[:-2].strip())
+    except ValueError as e:
+        raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e
+
+    return int(value * multiplier)
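A quick doctest-style sketch of what the new helper accepts and rejects, derived from the code above (an illustration, not part of the patch):

```py
>>> from huggingface_hub.serialization._base import parse_size_to_int

>>> parse_size_to_int("5GB")  # decimal units only: 5 * 10**9 bytes
5000000000
>>> parse_size_to_int("0.5GB")  # fractional values are truncated by int()
500000000
>>> parse_size_to_int("5GiB")  # binary units are rejected on purpose ("do not support KiB")
Traceback (most recent call last):
    ...
ValueError: Unit 'IB' not supported. Supported units are TB, GB, MB, KB. Got '5GiB'.
```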
"""Contains numpy-specific helpers.""" -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, Union from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory @@ -26,7 +26,7 @@ def split_numpy_state_dict_into_shards( state_dict: Dict[str, "np.ndarray"], *, filename_pattern: str = FILENAME_PATTERN, - max_shard_size: int = MAX_SHARD_SIZE, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, ) -> StateDictSplit: """ Split a model state dictionary in shards so that each shard is smaller than a given size. diff --git a/src/huggingface_hub/serialization/_tensorflow.py b/src/huggingface_hub/serialization/_tensorflow.py index f8d752c083..f3818b0ae3 100644 --- a/src/huggingface_hub/serialization/_tensorflow.py +++ b/src/huggingface_hub/serialization/_tensorflow.py @@ -15,7 +15,7 @@ import math import re -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, Union from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory @@ -28,7 +28,7 @@ def split_tf_state_dict_into_shards( state_dict: Dict[str, "tf.Tensor"], *, filename_pattern: str = "tf_model{suffix}.h5", - max_shard_size: int = MAX_SHARD_SIZE, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, ) -> StateDictSplit: """ Split a model state dictionary in shards so that each shard is smaller than a given size. diff --git a/src/huggingface_hub/serialization/_torch.py b/src/huggingface_hub/serialization/_torch.py index 00ab7e2c80..7ccce3c281 100644 --- a/src/huggingface_hub/serialization/_torch.py +++ b/src/huggingface_hub/serialization/_torch.py @@ -15,7 +15,7 @@ import importlib from functools import lru_cache -from typing import TYPE_CHECKING, Dict, Tuple +from typing import TYPE_CHECKING, Dict, Tuple, Union from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory @@ -28,7 +28,7 @@ def split_torch_state_dict_into_shards( state_dict: Dict[str, "torch.Tensor"], *, filename_pattern: str = FILENAME_PATTERN, - max_shard_size: int = MAX_SHARD_SIZE, + max_shard_size: Union[int, str] = MAX_SHARD_SIZE, ) -> StateDictSplit: """ Split a model state dictionary in shards so that each shard is smaller than a given size. @@ -67,7 +67,7 @@ def split_torch_state_dict_into_shards( >>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str): ... state_dict_split = split_torch_state_dict_into_shards(state_dict) - ... for filename, tensors in state_dict_split.filename_to_tensors.values(): + ... for filename, tensors in state_dict_split.filename_to_tensors.items(): ... shard = {tensor: state_dict[tensor] for tensor in tensors} ... safe_save_file( ... 
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index 004f3d3e8a..47a78d5e2e 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -1,4 +1,7 @@
+import pytest
+
 from huggingface_hub.serialization import split_state_dict_into_shards_factory
+from huggingface_hub.serialization._base import parse_size_to_int
 from huggingface_hub.serialization._numpy import get_tensor_size as get_tensor_size_numpy
 from huggingface_hub.serialization._tensorflow import get_tensor_size as get_tensor_size_tensorflow
 from huggingface_hub.serialization._torch import get_tensor_size as get_tensor_size_torch
@@ -123,3 +126,17 @@ def test_get_tensor_size_torch():
     assert get_tensor_size_torch(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)) == 5 * 8
     assert get_tensor_size_torch(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2
+
+
+def test_parse_size_to_int():
+    assert parse_size_to_int("1KB") == 1 * 10**3
+    assert parse_size_to_int("2MB") == 2 * 10**6
+    assert parse_size_to_int("3GB") == 3 * 10**9
+    assert parse_size_to_int(" 10 KB ") == 10 * 10**3  # ok with whitespace
+    assert parse_size_to_int("20mb") == 20 * 10**6  # ok with lowercase
+
+    with pytest.raises(ValueError, match="Unit 'IB' not supported"):
+        parse_size_to_int("1KiB")  # not a valid unit
+
+    with pytest.raises(ValueError, match="Could not parse the size value"):
+        parse_size_to_int("1ooKB")  # not a float
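The new test covers the parser in isolation; the string form can also be exercised end-to-end through the factory, in the same style as the existing factory tests. A sketch under stated assumptions (the test name and the size-as-value trick are illustrative, not from the patch):

```py
from huggingface_hub.serialization import split_state_dict_into_shards_factory


def test_factory_with_string_max_shard_size():
    # Dummy "tensors": each value doubles as its own byte size.
    state_dict = {"tensor_1": 3_000_000, "tensor_2": 3_000_000, "tensor_3": 3_000_000}

    state_dict_split = split_state_dict_into_shards_factory(
        state_dict,
        get_tensor_size=lambda tensor: tensor,  # the value is its byte size
        filename_pattern="shard{suffix}.bin",
        max_shard_size="5MB",  # parsed internally to 5_000_000 bytes
    )

    # A 5MB cap fits only one 3MB tensor per shard, so three shards are expected.
    assert len(state_dict_split.filename_to_tensors) == 3
```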