Skip to content

Commit

Permalink
Skip device placement for past key values in decoder models (huggingf…
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger authored and novice03 committed Jun 23, 2023
1 parent 1bf7aa9 commit dfcb3a1
Show file tree
Hide file tree
Showing 15 changed files with 19 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
main_input_name = "input_ids"
_auto_class = None
_no_split_modules = None
_skip_keys_device_placement = None
_keep_in_fp32_modules = None

# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
Expand Down Expand Up @@ -2887,7 +2888,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

# Dispatch model with hooks on all devices if necessary
if device_map is not None:
dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)
kwargs = {"device_map": device_map, "offload_dir": offload_folder, "offload_index": offload_index}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
kwargs["skip_keys"] = model._skip_keys_device_placement
dispatch_model(model, **kwargs)

if output_loading_info:
if loading_info is None:
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,7 @@ class BartPretrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
_no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
_skip_keys_device_placement = "past_key_values"

def _init_weights(self, module):
std = self.config.init_std
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1597,6 +1597,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
_skip_keys_device_placement = "past_key_values"

def _init_weights(self, module):
std = self.config.init_std
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/blip_2/modeling_blip_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
r"language_model.lm_head.weight",
]
_no_split_modules = ["Blip2Attention", "T5Block", "OPTDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_keep_in_fp32_modules = ["wo"]

def _init_weights(self, module):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/bloom/modeling_bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ class BloomPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["BloomBlock"]
_skip_keys_device_placement = "past_key_values"

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,7 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
base_model_prefix = "bridgetower"
supports_gradient_checkpointing = False
_no_split_modules = ["BridgeTowerSelfAttention", "BridgeTowerResidualAttention"]
_skip_keys_device_placement = "past_key_values"

def _init_weights(self, module):
if isinstance(module, BridgeTowerVisionModel):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/codegen/modeling_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ class CodeGenPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["CodeGenBlock"]
_skip_keys_device_placement = "past_key_values"

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
is_parallelizable = True
supports_gradient_checkpointing = True
_no_split_modules = ["GPT2Block"]
_skip_keys_device_placement = "past_key_values"

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["GPTBigCodeBlock"]
_skip_keys_device_placement = "past_key_values"

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["GPTNeoBlock"]
_skip_keys_device_placement = "past_key_values"

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
base_model_prefix = "gpt_neox"
supports_gradient_checkpointing = True
_no_split_modules = ["GPTNeoXLayer"]
_skip_keys_device_placement = "past_key_values"

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class GPTNeoXJapanesePreTrainedModel(PreTrainedModel):
base_model_prefix = "gpt_neox_japanese"
supports_gradient_checkpointing = True
_no_split_modules = ["GPTNeoXJapaneseLayer"]
_skip_keys_device_placement = "past_key_values"

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gptj/modeling_gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
is_parallelizable = True
supports_gradient_checkpointing = True
_no_split_modules = ["GPTJBlock"]
_skip_keys_device_placement = "past_key_values"

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
base_model_prefix = "gptsan_japanese"
supports_gradient_checkpointing = False
_no_split_modules = ["GPTSanJapaneseBlock"]
_skip_keys_device_placement = "past_key_values"

@property
def dummy_inputs(self):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

def _init_weights(self, module):
Expand Down

0 comments on commit dfcb3a1

Please sign in to comment.