diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 72a00f3ebe68d8..fa658d9e05782c 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -500,6 +500,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["GPTJBlock"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_param_buffer_assignment = False
 
     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)