Skip to content

Commit

Permalink
Fix float8_e4m3fn in modeling_utils (#32193)
Browse files Browse the repository at this point in the history
* Fix float8_e4m3fn in modeling_utils

* style

* fix

* comment
  • Loading branch information
SunMarc authored Jul 24, 2024
1 parent 1392a68 commit af0e4b7
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,8 @@ def _load_state_dict_into_meta_model(
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)

is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")

for param_name, param in state_dict.items():
# First part of the test is always true as load_state_dict_keys always contains state_dict keys.
if param_name not in loaded_state_dict_keys or param_name not in expected_keys:
Expand All @@ -866,9 +868,10 @@ def _load_state_dict_into_meta_model(
module_name = param_name
set_module_kwargs = {}

# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
if dtype is not None and torch.is_floating_point(param) and param.dtype != torch.float8_e4m3fn:
is_param_float8_e4m3fn = is_torch_e4m3fn_available and param.dtype == torch.float8_e4m3fn
if dtype is not None and torch.is_floating_point(param) and not is_param_float8_e4m3fn:
if (
keep_in_fp32_modules is not None
and any(
Expand Down

0 comments on commit af0e4b7

Please sign in to comment.