[Core] Change 8-bit serialization weight-format storage format (#1164)
* change 8-bit serialization weight-format storage format

* pre-commit

* pre-commit

* fix

* Update bitsandbytes/nn/modules.py

Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/nn/modules.py

Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/utils.py

Co-authored-by: Aarni Koskela <akx@iki.fi>

* address feedback

* lint

---------

Co-authored-by: Aarni Koskela <akx@iki.fi>
younesbelkada and akx authored Apr 10, 2024
1 parent c54053d commit 7449d71
Showing 2 changed files with 29 additions and 4 deletions.
29 changes: 25 additions & 4 deletions bitsandbytes/nn/modules.py
@@ -14,7 +14,11 @@
 from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
 from bitsandbytes.functional import QuantState
 from bitsandbytes.optim import GlobalOptimManager
-from bitsandbytes.utils import OutlierTracer
+from bitsandbytes.utils import (
+    INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
+    LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
+    OutlierTracer,
+)

T = TypeVar("T", bound="torch.nn.Module")

@@ -619,6 +623,16 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
         return
     weight_format = state_dict.pop(f"{prefix}weight_format", "row")
 
+    if isinstance(weight_format, torch.Tensor):
+        weight_format = weight_format.item()
+
+    # For the new weight-format storage type, we explicitly check
+    # whether weight_format is in the mapping
+    if isinstance(weight_format, int) and weight_format not in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
+        raise ValueError(f"Expected supported weight format - got {weight_format}")
+    elif isinstance(weight_format, int) and weight_format in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
+        weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]
+
     if weight_format != "row":
         tile_indices = get_tile_inds(weight_format, weight.device)
         state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
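
For context, a minimal standalone sketch of the load-time normalization this hunk adds. The helper name normalize_weight_format is hypothetical (the real logic lives inline in maybe_rearrange_weight), and the two mappings are the ones introduced in bitsandbytes/utils.py further down in this diff:

import torch

LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}

def normalize_weight_format(stored):
    # New checkpoints store the format as a 0-dim uint8 tensor; unwrap it to a Python int.
    if isinstance(stored, torch.Tensor):
        stored = stored.item()
    # Integer codes are mapped back to format names; unknown codes are rejected.
    if isinstance(stored, int):
        if stored not in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
            raise ValueError(f"Expected supported weight format - got {stored}")
        stored = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[stored]
    # Legacy checkpoints stored the string name directly; it passes through unchanged.
    return stored

assert normalize_weight_format("col_turing") == "col_turing"                         # legacy checkpoint
assert normalize_weight_format(torch.tensor(2, dtype=torch.uint8)) == "col_turing"   # new checkpoint

Either way, the rest of maybe_rearrange_weight sees a plain string, so old and new checkpoints load through the same code path.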
@@ -711,13 +725,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         if not self.state.has_fp16_weights:
             if param_from_weight is not None:
                 destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-                destination[format_name] = "row"
+                destination[format_name] = torch.tensor(0, dtype=torch.uint8)
             elif param_from_state is not None and not layout_reordered:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-                destination[format_name] = "row"
+                destination[format_name] = torch.tensor(0, dtype=torch.uint8)
             elif param_from_state is not None:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-                destination[format_name] = self.state.formatB
+                weights_format = self.state.formatB
+                # At this point `weights_format` is a str
+                if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
+                    raise ValueError(f"Unrecognized weights format {weights_format}")
+
+                weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format]
+
+                destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8)
 
     def _load_from_state_dict(
         self,
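
Taken together, the save path now always writes the weight_format entry as a 0-dim uint8 tensor: "row" is written directly as torch.tensor(0, dtype=torch.uint8), and any other self.state.formatB string goes through the mapping first. A plausible motivation (not stated in the commit message) is that tensor-only checkpoint backends such as safetensors cannot store Python strings. A small sketch of that encoding step, using the hypothetical helper name encode_weight_format:

import torch

LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}

def encode_weight_format(name: str) -> torch.Tensor:
    # Mirror the save path above: reject unknown format names, then store
    # the integer code as a 0-dim uint8 tensor so the state dict holds only tensors.
    if name not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
        raise ValueError(f"Unrecognized weights format {name}")
    return torch.tensor(LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[name], dtype=torch.uint8)

fmt = encode_weight_format("col_ampere")
assert fmt.dtype == torch.uint8 and fmt.item() == 3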
4 changes: 4 additions & 0 deletions bitsandbytes/utils.py
@@ -198,3 +198,7 @@ def unpack_tensor_to_dict(tensor_data):
     unpacked_dict = json.loads(json_str)
 
     return unpacked_dict
+
+
+LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
+INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}
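
As a quick sanity check, the two mappings invert each other and every code fits in uint8, matching the torch.tensor(..., dtype=torch.uint8) values written by _save_to_state_dict (assuming both dicts are importable from bitsandbytes.utils, as this diff makes them):

from bitsandbytes.utils import (
    INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
    LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
)

# Name -> code -> name round trip; codes must fit in a uint8 tensor.
for name, code in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items():
    assert 0 <= code <= 255
    assert INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[code] == name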
