Skip to content

Commit

Permalink
Support max_shard_size as string in `split_state_dict_into_shards_f…
Browse files Browse the repository at this point in the history
…actory` (#2286)

* fix max-shard-size

* for torch also

* add tests

* Fix styling + do not support KiB

---------

Co-authored-by: Lucain Pouget <lucainp@gmail.com>
  • Loading branch information
SunMarc and Wauplin authored May 24, 2024
1 parent a6fbf16 commit 3900efc
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 9 deletions.
48 changes: 46 additions & 2 deletions src/huggingface_hub/serialization/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Contains helpers to split tensors into shards."""

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

from .. import logging

Expand Down Expand Up @@ -46,7 +46,7 @@ def split_state_dict_into_shards_factory(
get_tensor_size: TensorSizeFn_T,
get_storage_id: StorageIDFn_T = lambda tensor: None,
filename_pattern: str = FILENAME_PATTERN,
max_shard_size: int = MAX_SHARD_SIZE,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
"""
Split a model state dictionary in shards so that each shard is smaller than a given size.
Expand Down Expand Up @@ -89,6 +89,9 @@ def split_state_dict_into_shards_factory(
current_shard_size = 0
total_size = 0

if isinstance(max_shard_size, str):
max_shard_size = parse_size_to_int(max_shard_size)

for key, tensor in state_dict.items():
# when bnb serialization is used the weights in the state dict can be strings
# check: https://github.com/huggingface/transformers/pull/24416 for more details
Expand Down Expand Up @@ -167,3 +170,44 @@ def split_state_dict_into_shards_factory(
filename_to_tensors=filename_to_tensors,
tensor_to_filename=tensor_name_to_filename,
)


# Decimal (SI) multipliers, in bytes, keyed by upper-case two-letter unit suffix.
# Binary units such as "KiB" are not part of this table and are therefore rejected.
SIZE_UNITS = {
    "TB": 10**12,
    "GB": 10**9,
    "MB": 10**6,
    "KB": 10**3,
}


def parse_size_to_int(size_as_str: str) -> int:
    """
    Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).

    Supported units are "TB", "GB", "MB", "KB" (case-insensitive). Leading/trailing whitespace and
    whitespace between the value and the unit are ignored. Fractional values are truncated toward
    zero once converted to bytes.

    Args:
        size_as_str (`str`):
            The size to convert. Will be directly returned if an `int`.

    Returns:
        `int`: The size in bytes.

    Raises:
        `ValueError`: If the unit is not supported, or if the value is not a finite number.

    Example:
    ```py
    >>> parse_size_to_int("5MB")
    5000000
    ```
    """
    # Function-scope imports keep this helper self-contained.
    import math
    from decimal import Decimal, InvalidOperation

    # Honor the documented contract: an integer is already a size in bytes.
    if isinstance(size_as_str, int):
        return size_as_str

    size_as_str = size_as_str.strip()

    # Parse unit: always the last two characters, compared case-insensitively.
    unit = size_as_str[-2:].upper()
    if unit not in SIZE_UNITS:
        raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.")
    multiplier = SIZE_UNITS[unit]

    # Parse value: `float` defines which spellings are accepted (and provides the error detail).
    number_part = size_as_str[:-2].strip()
    try:
        approx_value = float(number_part)
    except ValueError as e:
        raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e

    # "nan" / "inf" are valid floats but cannot be turned into a number of bytes.
    if not math.isfinite(approx_value):
        raise ValueError(f"Could not parse the size value from '{size_as_str}': value must be a finite number.")

    # Do the multiplication in exact decimal arithmetic: with binary floats, values such as
    # "1.023KB" evaluate to 1022.999... and would be truncated one byte short.
    try:
        return int(Decimal(number_part) * multiplier)
    except InvalidOperation:
        # Spelling accepted by `float` but not by `Decimal`: fall back to float arithmetic.
        return int(approx_value * multiplier)
4 changes: 2 additions & 2 deletions src/huggingface_hub/serialization/_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Contains numpy-specific helpers."""

from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, Union

from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory

Expand All @@ -26,7 +26,7 @@ def split_numpy_state_dict_into_shards(
state_dict: Dict[str, "np.ndarray"],
*,
filename_pattern: str = FILENAME_PATTERN,
max_shard_size: int = MAX_SHARD_SIZE,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
"""
Split a model state dictionary in shards so that each shard is smaller than a given size.
Expand Down
4 changes: 2 additions & 2 deletions src/huggingface_hub/serialization/_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import math
import re
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, Union

from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory

Expand All @@ -28,7 +28,7 @@ def split_tf_state_dict_into_shards(
state_dict: Dict[str, "tf.Tensor"],
*,
filename_pattern: str = "tf_model{suffix}.h5",
max_shard_size: int = MAX_SHARD_SIZE,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
"""
Split a model state dictionary in shards so that each shard is smaller than a given size.
Expand Down
6 changes: 3 additions & 3 deletions src/huggingface_hub/serialization/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import importlib
from functools import lru_cache
from typing import TYPE_CHECKING, Dict, Tuple
from typing import TYPE_CHECKING, Dict, Tuple, Union

from ._base import FILENAME_PATTERN, MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory

Expand All @@ -28,7 +28,7 @@ def split_torch_state_dict_into_shards(
state_dict: Dict[str, "torch.Tensor"],
*,
filename_pattern: str = FILENAME_PATTERN,
max_shard_size: int = MAX_SHARD_SIZE,
max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
"""
Split a model state dictionary in shards so that each shard is smaller than a given size.
Expand Down Expand Up @@ -67,7 +67,7 @@ def split_torch_state_dict_into_shards(
>>> def save_state_dict(state_dict: Dict[str, torch.Tensor], save_directory: str):
... state_dict_split = split_torch_state_dict_into_shards(state_dict)
... for filename, tensors in state_dict_split.filename_to_tensors.values():
... for filename, tensors in state_dict_split.filename_to_tensors.items():
... shard = {tensor: state_dict[tensor] for tensor in tensors}
... safe_save_file(
... shard,
Expand Down
17 changes: 17 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import pytest

from huggingface_hub.serialization import split_state_dict_into_shards_factory
from huggingface_hub.serialization._base import parse_size_to_int
from huggingface_hub.serialization._numpy import get_tensor_size as get_tensor_size_numpy
from huggingface_hub.serialization._tensorflow import get_tensor_size as get_tensor_size_tensorflow
from huggingface_hub.serialization._torch import get_tensor_size as get_tensor_size_torch
Expand Down Expand Up @@ -123,3 +126,17 @@ def test_get_tensor_size_torch():

assert get_tensor_size_torch(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)) == 5 * 8
assert get_tensor_size_torch(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2


def test_parse_size_to_int():
    """Check that `parse_size_to_int` converts valid size strings and rejects malformed ones."""
    # Valid inputs paired with the number of bytes they must resolve to.
    expected_by_input = [
        ("1KB", 1 * 10**3),
        ("2MB", 2 * 10**6),
        ("3GB", 3 * 10**9),
        (" 10 KB ", 10 * 10**3),  # ok with whitespace
        ("20mb", 20 * 10**6),  # ok with lowercase
    ]
    for raw_size, expected_bytes in expected_by_input:
        assert parse_size_to_int(raw_size) == expected_bytes

    # Invalid inputs paired with the error message each one must trigger.
    error_by_input = [
        ("1KiB", "Unit 'IB' not supported"),  # not a valid unit
        ("1ooKB", "Could not parse the size value"),  # not a float
    ]
    for raw_size, message in error_by_input:
        with pytest.raises(ValueError, match=message):
            parse_size_to_int(raw_size)

0 comments on commit 3900efc

Please sign in to comment.