
Add keep_in_fp32_modules support #20683

Merged: 20 commits, merged on Dec 13, 2022
Changes from 11 commits
59 changes: 55 additions & 4 deletions src/transformers/modeling_utils.py
@@ -561,6 +561,7 @@ def _load_state_dict_into_meta_model(
dtype=None,
load_in_8bit=False,
is_safetensors=False,
keep_in_fp32_modules=None,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -611,7 +612,14 @@ def _load_state_dict_into_meta_model(
# We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params
# in int/uint/bool and not cast them.
if dtype is not None and torch.is_floating_point(param):
param = param.to(dtype)
if (
keep_in_fp32_modules is not None
and any(module_to_keep_in_fp32 in param_name for module_to_keep_in_fp32 in keep_in_fp32_modules)
and dtype == torch.float16
):
param = param.to(torch.float32)
else:
param = param.to(dtype)

# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
if dtype is None:
@@ -964,6 +972,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
main_input_name = "input_ids"
_auto_class = None
_no_split_modules = None
_keep_in_fp32_modules = None

# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -2259,6 +2268,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
dtype_orig = None

if torch_dtype is not None:
if isinstance(torch_dtype, str):
if torch_dtype == "auto":
@@ -2276,11 +2286,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
)
dtype_orig = cls._set_default_torch_dtype(torch_dtype)

# Check if `_keep_in_fp32_modules` is not None
# but not force users to have `accelerate`
if cls._keep_in_fp32_modules is not None:
if not is_accelerate_available() and torch_dtype == torch.float16:
use_keep_in_fp32_modules = False
logger.warning(
" `_keep_in_fp32_modules` is not set to `None` and you don't have `accelerate` installed"
" it is recommended to have `accelerate` installed in this case `pip install accelerate`.",
)
else:
use_keep_in_fp32_modules = True
else:
use_keep_in_fp32_modules = False

if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else:
loaded_state_dict_keys = [k for k in state_dict.keys()]
if low_cpu_mem_usage:
if low_cpu_mem_usage or use_keep_in_fp32_modules:
state_dict = None

config.name_or_path = pretrained_model_name_or_path
@@ -2299,6 +2323,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
with ContextManagers(init_contexts):
model = cls(config, *model_args, **model_kwargs)

if use_keep_in_fp32_modules:
low_cpu_mem_usage = True
keep_in_fp32_modules = model._keep_in_fp32_modules
Collaborator:
Let's set it to [] here if it's not None, so that we don't have to check again later on.

Contributor Author:
This should be addressed in cb89c42

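A minimal sketch of the normalization suggested above, assuming the goal is to avoid repeated `is not None` checks later in `from_pretrained` (the helper name is hypothetical, not part of the PR):

```python
def resolve_keep_in_fp32_modules(model_cls) -> list:
    # Hypothetical helper: read the class attribute and fall back to an empty
    # list, so callers can always iterate without a separate None check.
    modules = getattr(model_cls, "_keep_in_fp32_modules", None)
    return list(modules) if modules is not None else []
```

With an empty list, guards such as `any(m in key for m in keep_in_fp32_modules)` simply evaluate to `False`, so the later `keep_in_fp32_modules is not None` checks could be dropped.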

if load_in_8bit:
from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear

@@ -2309,6 +2337,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
modules_to_not_convert = get_keys_to_not_convert(model)
else:
modules_to_not_convert = load_in_8bit_skip_modules

if keep_in_fp32_modules is not None and isinstance(keep_in_fp32_modules, list):
modules_to_not_convert.extend(keep_in_fp32_modules)

model = replace_8bit_linear(
model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
)
@@ -2415,6 +2447,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
load_in_8bit=load_in_8bit,
keep_in_fp32_modules=keep_in_fp32_modules,
)

model.is_loaded_in_8bit = load_in_8bit
@@ -2458,6 +2491,7 @@ def _load_pretrained_model(
offload_state_dict=None,
dtype=None,
load_in_8bit=False,
keep_in_fp32_modules=None,
):
is_safetensors = False
if load_in_8bit:
@@ -2534,11 +2568,27 @@ def _fix_key(key):
if key.startswith(prefix):
key = ".".join(key.split(".")[1:])
param = model_state_dict[key]

# upcast in fp32 if any
target_dtype = dtype
if (
keep_in_fp32_modules is not None
and dtype == torch.float16
and any(module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules)
):
target_dtype = torch.float32

if param.device == torch.device("meta"):
if not load_in_8bit:
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype))
else:
set_module_8bit_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=dtype))
set_module_8bit_tensor_to_device(
model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)
)
elif keep_in_fp32_modules is not None and state_dict is not None:
for key in state_dict:
if any(module_to_keep_in_fp32 in key for module_to_keep_in_fp32 in keep_in_fp32_modules):
state_dict[key] = state_dict[key].to(torch.float32)
Collaborator:
This is not useful: with torch's `load_state_dict`, the weights are converted to the dtype already present inside the model, so it's the model dtype that you should fix here.

Also this removes the necessity for the Accelerate warning above, no?

Contributor Author:
Yes! Should be addressed in cb89c42

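A minimal sketch of what the reviewer appears to suggest, assuming the fix is to upcast the matching parameters on the model itself before the checkpoint is loaded (the function name and structure are illustrative, not the PR's final code):

```python
import torch
from torch import nn


def upcast_keep_in_fp32_params(model: nn.Module, keep_in_fp32_modules: list) -> None:
    # `load_state_dict` casts incoming weights to each parameter's existing dtype,
    # so parameters that must stay in fp32 are upcast on the model before loading.
    for name, param in model.named_parameters():
        if any(module_name in name for module_name in keep_in_fp32_modules):
            param.data = param.data.to(torch.float32)
```

Because the cast happens on the model rather than on the state dict, it would also cover the plain (non-Accelerate) loading path, which is the reviewer's second point.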

# retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights.
if _fast_init:
@@ -2681,6 +2731,7 @@ def _find_mismatched_keys(
dtype=dtype,
load_in_8bit=load_in_8bit,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
else:
1 change: 1 addition & 0 deletions src/transformers/models/t5/modeling_t5.py
@@ -757,6 +757,7 @@ class T5PreTrainedModel(PreTrainedModel):
is_parallelizable = True
supports_gradient_checkpointing = True
_no_split_modules = ["T5Block"]
_keep_in_fp32_modules = ["wo"]

@property
def dummy_inputs(self):
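A short usage sketch of what the `_keep_in_fp32_modules = ["wo"]` attribute added above enables for T5, mirroring the assertions in the tests added below (this requires the changes in this PR and downloads `t5-small`):

```python
import torch
from transformers import T5ForConditionalGeneration

# With `_keep_in_fp32_modules = ["wo"]`, loading in fp16 keeps the `wo`
# projection of the feed-forward layer in fp32 while the rest of the model
# is cast to fp16.
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
assert model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32
assert model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16
```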
7 changes: 7 additions & 0 deletions tests/mixed_int8/test_mixed_int8.py
@@ -155,6 +155,13 @@ def test_device_and_dtype_assignment(self):
# Check this does not throw an error
_ = self.model_fp16.float()

def test_fp32_int8_conversion(self):
r"""
Test whether it is possible to mix both `int8` and `fp32` weights when using `keep_in_fp32_modules` correctly.
"""
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_8bit=True, device_map="auto")
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)


class MixedInt8ModelClassesTest(BaseMixedInt8Test):
def setUp(self):
53 changes: 52 additions & 1 deletion tests/models/t5/test_modeling_t5.py
@@ -19,7 +19,14 @@
import unittest

from transformers import T5Config, is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from transformers.testing_utils import (
require_accelerate,
require_sentencepiece,
require_tokenizers,
require_torch,
slow,
torch_device,
)
from transformers.utils import cached_property

from ...generation.test_utils import GenerationTesterMixin
@@ -820,6 +827,50 @@ def use_task_specific_params(model, task):
model.config.update(model.config.task_specific_params[task])


@require_torch
@require_accelerate
@require_tokenizers
@slow
class T5ModelFp16Tests(unittest.TestCase):
def test_fp16_fp32_conversion(self):
r"""
A test to check whether the argument `keep_in_fp32_modules` correctly does its job
"""
# Load without using `accelerate`
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)

# Load without in bf16
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)

# Load using `accelerate` in bf16
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.bfloat16, device_map="auto")
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)

# Load using `accelerate` in bf16
model = T5ForConditionalGeneration.from_pretrained(
"t5-small", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.bfloat16)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.bfloat16)

# Load without using `accelerate`
model = T5ForConditionalGeneration.from_pretrained(
"t5-small", torch_dtype=torch.float16, low_cpu_mem_usage=True
)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)

# Load using `accelerate`
model = T5ForConditionalGeneration.from_pretrained("t5-small", torch_dtype=torch.float16, device_map="auto")
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wi.weight.dtype == torch.float16)


@require_torch
@require_sentencepiece
@require_tokenizers