[core] Raise warning on using prepare_model_for_int8_training (#483)
* raise warning on using older method

* Update src/peft/utils/other.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* quality

---------

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
younesbelkada and sgugger authored May 22, 2023
1 parent 0fcc30d commit 3714aa2
Showing 3 changed files with 13 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/peft/__init__.py
@@ -52,6 +52,7 @@
     bloom_model_postprocess_past_key_value,
     get_peft_model_state_dict,
     prepare_model_for_int8_training,
+    prepare_model_for_kbit_training,
     set_peft_model_state_dict,
     shift_tokens_right,
 )
1 change: 1 addition & 0 deletions src/peft/utils/__init__.py
@@ -27,6 +27,7 @@
     _set_trainable,
     bloom_model_postprocess_past_key_value,
     prepare_model_for_int8_training,
+    prepare_model_for_kbit_training,
     shift_tokens_right,
     transpose,
     _get_submodules,
12 changes: 11 additions & 1 deletion src/peft/utils/other.py
@@ -14,6 +14,7 @@
 # limitations under the License.

 import copy
+import warnings

 import torch
@@ -32,7 +33,7 @@ def bloom_model_postprocess_past_key_value(past_key_values):
     return tuple(zip(keys, values))


-def prepare_model_for_int8_training(model, use_gradient_checkpointing=True):
+def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True):
     r"""
     This method wraps the entire protocol for preparing a model before running a training. This includes:
         1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
@@ -70,6 +71,15 @@ def make_inputs_require_grad(module, input, output):
     return model


+# For backward compatibility
+def prepare_model_for_int8_training(*args, **kwargs):
+    warnings.warn(
+        "prepare_model_for_int8_training is deprecated and will be removed in a future version. Use prepare_model_for_kbit_training instead.",
+        FutureWarning,
+    )
+    return prepare_model_for_kbit_training(*args, **kwargs)
+
+
 # copied from transformers.models.bart.modeling_bart
 def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
     """
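After this change, callers should use the new name; the old name is kept as a shim that forwards to it and emits a FutureWarning. A minimal migration sketch follows (not part of the commit: the torch.nn.Linear stand-in and the warning capture are illustrative assumptions; in real use you would pass a quantized transformers model, e.g. one loaded with load_in_8bit=True):

import warnings

import torch
from peft import prepare_model_for_int8_training, prepare_model_for_kbit_training

# Stand-in module for the demo; a real call site would pass a k-bit quantized model.
model = torch.nn.Linear(4, 4)

# Preferred: the renamed entry point introduced by this commit.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# Deprecated: the old name still works but now emits a FutureWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = prepare_model_for_int8_training(model)

assert any(issubclass(w.category, FutureWarning) for w in caught)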
