diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index eae08f7f7..4cbb4f173 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -291,7 +291,8 @@ def load_model( if cfg.resize_token_embeddings_to_32x else len(tokenizer) ) - model.resize_token_embeddings(embeddings_len) + if model.get_input_embeddings().num_embeddings < embeddings_len: + model.resize_token_embeddings(embeddings_len) if ( hasattr(model.config, "max_position_embeddings")