diff --git a/trl/models/utils.py b/trl/models/utils.py index 2ec88845c3..849ac19020 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -76,5 +76,10 @@ def setup_chat_format( model.resize_token_embeddings( len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None ) + # Make sure to update the generation config to use the new eos & bos token + if getattr(model, "generation_config", None) is not None: + model.generation_config.bos_token_id = tokenizer.bos_token_id + model.generation_config.eos_token_id = tokenizer.eos_token_id + model.generation_config.pad_token_id = tokenizer.pad_token_id return model, tokenizer