
Commit 83dbebc

[trainer] ensure special tokens in model configs are aligned with tokenizer at train time (#38441)

* tmp commit
* add test
* make fixup
* reset warns/info in test

1 parent 9977cf1 commit 83dbebc

File tree: 4 files changed, +134 -9 lines

  src/transformers/generation/configuration_utils.py
  src/transformers/trainer.py
  tests/generation/test_configuration_utils.py
  tests/trainer/test_trainer.py
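
The change targets a common fine-tuning workflow: a new chat-style special token (e.g. an "end of turn" token) is registered on the tokenizer, but the model's `config` and `generation_config` keep the old special token ids unless the user syncs them by hand. Below is a minimal sketch of that manual sync (the checkpoint name is taken from the new test in this commit; the token choices are illustrative). `Trainer.train()` now performs an equivalent alignment automatically, and additionally preserves any pre-existing EOS ids in `generation_config`.

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # tiny checkpoint, as used in the new test
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Register a chat-style end-of-turn token and use it as EOS/PAD for fine-tuning.
tokenizer.add_special_tokens({"additional_special_tokens": ["<|im_end|>"]})
tokenizer.eos_token = "<|im_end|>"
tokenizer.pad_token = "<|im_end|>"
model.resize_token_embeddings(len(tokenizer))

# Manual alignment that was previously left to the user; the new
# Trainer._align_special_tokens() does the equivalent at train time.
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
if getattr(model, "generation_config", None) is not None:
    model.generation_config.eos_token_id = tokenizer.eos_token_id
    model.generation_config.pad_token_id = tokenizer.pad_token_id
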

src/transformers/generation/configuration_utils.py
Lines changed: 2 additions & 2 deletions

@@ -792,8 +792,8 @@ def validate(self, strict=False):
             )
             if logging.get_verbosity() >= logging.WARNING:
                 warning_message += " Set `TRANSFORMERS_VERBOSITY=info` for more details."
-            logger.warning(warning_message)
-            logger.info(info_message)
+            logger.warning_once(warning_message)
+            logger.info_once(info_message)
 
     def save_pretrained(
         self,
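
For context on the `cache_clear()` calls added to the tests further down: `warning_once`/`info_once` deduplicate on the message, so a repeated validation warning is emitted only once per process unless the cache is reset. A rough illustration of that mechanism with `functools.lru_cache` (a sketch of the general pattern, not the exact transformers implementation):

import functools
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("example")

@functools.lru_cache(None)
def warning_once(message: str) -> None:
    # Repeated identical messages hit the cache and never reach the logger again.
    logger.warning(message)

warning_once("`temperature` is set but `do_sample` is False.")  # emitted
warning_once("`temperature` is set but `do_sample` is False.")  # suppressed by the cache
warning_once.cache_clear()  # what the updated tests do between assertions
warning_once("`temperature` is set but `do_sample` is False.")  # emitted again
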

src/transformers/trainer.py
Lines changed: 74 additions & 0 deletions

@@ -905,6 +905,76 @@ def _move_model_to_device(self, model, device):
         if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
             model.tie_weights()
 
+    def _align_special_tokens(self):
+        """
+        Aligns the special tokens of the tokenizer with the model configs.
+
+        New tokens may be defined in the tokenizer for fine-tuning purposes, e.g. an "end of turn" token may be
+        added on chat models. In that case, we want the model configs to be aligned with the tokenizer, so that all
+        downstream uses work as expected. This alignment should happen before training, to ensure the prediction step
+        uses the new tokens as well.
+        """
+        if isinstance(self.processing_class, ProcessorMixin):
+            tokenizer = self.processing_class.tokenizer
+        else:
+            tokenizer = self.tokenizer
+        model_has_generation_config = (
+            hasattr(self.model, "generation_config") and self.model.generation_config is not None
+        )
+        updated_tokens = {}
+
+        # 1 - Align EOS token. EOS is more complex than the others, as `generation_config` may hold more than one EOS
+        # token.
+        tokenizer_has_new_eos = tokenizer.eos_token_id != self.model.config.eos_token_id
+        if model_has_generation_config:
+            # `generation_config.eos_token_id` is None: direct comparison
+            if self.model.generation_config.eos_token_id is None:
+                tokenizer_has_new_eos |= tokenizer.eos_token_id != self.model.generation_config.eos_token_id
+            else:
+                # `generation_config.eos_token_id` is an `int`: convert it to a list (and continue below)
+                if isinstance(self.model.generation_config.eos_token_id, int):
+                    self.model.generation_config.eos_token_id = [self.model.generation_config.eos_token_id]
+                # `generation_config.eos_token_id` is a `list`: check if the tokenizer's EOS token is in the list
+                tokenizer_has_new_eos |= tokenizer.eos_token_id not in self.model.generation_config.eos_token_id
+
+        if tokenizer_has_new_eos:
+            updated_tokens["eos_token_id"] = tokenizer.eos_token_id
+            self.model.config.eos_token_id = tokenizer.eos_token_id
+            # The generation config may hold more than one EOS token. We preserve the original EOS tokens: any of the
+            # EOS tokens defined here will halt generation.
+            if model_has_generation_config:
+                all_eos_tokens = [tokenizer.eos_token_id] + list(self.model.generation_config.eos_token_id)
+                self.model.generation_config.eos_token_id = [token for token in all_eos_tokens if token is not None]
+
+        # 2 - Align BOS
+        tokenizer_has_new_bos = tokenizer.bos_token_id != self.model.config.bos_token_id
+        if model_has_generation_config:
+            tokenizer_has_new_bos |= tokenizer.bos_token_id != self.model.generation_config.bos_token_id
+
+        if tokenizer_has_new_bos:
+            updated_tokens["bos_token_id"] = tokenizer.bos_token_id
+            self.model.config.bos_token_id = tokenizer.bos_token_id
+            if model_has_generation_config:
+                self.model.generation_config.bos_token_id = tokenizer.bos_token_id
+
+        # 3 - Align PAD
+        tokenizer_has_new_pad = tokenizer.pad_token_id != self.model.config.pad_token_id
+        if model_has_generation_config:
+            tokenizer_has_new_pad |= tokenizer.pad_token_id != self.model.generation_config.pad_token_id
+
+        if tokenizer_has_new_pad:
+            updated_tokens["pad_token_id"] = tokenizer.pad_token_id
+            self.model.config.pad_token_id = tokenizer.pad_token_id
+            if model_has_generation_config:
+                self.model.generation_config.pad_token_id = tokenizer.pad_token_id
+
+        # 4 - Warn users about the changes
+        if len(updated_tokens) > 0:
+            logger.warning(
+                "The tokenizer has new special tokens that are also defined in the model configs. The model "
+                f"configs were aligned accordingly. Updated tokens: {updated_tokens}"
+            )
+
     def _set_signature_columns_if_needed(self):
         if self._signature_columns is None:
             # Inspect model forward signature to keep only the arguments it accepts.

@@ -2162,6 +2232,10 @@ def train(
 
         self.is_in_train = True
 
+        # If the model uses a tokenizer, it may have new tokens for fine-tuning purposes.
+        if isinstance(self.processing_class, (PreTrainedTokenizerBase, ProcessorMixin)):
+            self._align_special_tokens()
+
         # Attach NEFTune hooks if necessary
         if self.neftune_noise_alpha is not None:
            self.model = self._activate_neftune(self.model)
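
The EOS branch in `_align_special_tokens` is the only non-trivial part: `generation_config.eos_token_id` can be `None`, an `int`, or a `list`, and the original EOS ids are kept so that any of them still halts generation. A standalone sketch of that merge logic with hypothetical token ids:

# Hypothetical ids: the tokenizer's new EOS is 32001; the generation config
# previously stopped on token 2 (it could also be a list such as [2, 128009], or None).
new_eos_token_id = 32001
generation_config_eos = 2

# Normalize to a list, mirroring the int -> list conversion in _align_special_tokens.
if generation_config_eos is None:
    generation_config_eos = []
elif isinstance(generation_config_eos, int):
    generation_config_eos = [generation_config_eos]

# The new EOS goes first, and the original EOS ids are preserved.
all_eos_tokens = [new_eos_token_id] + list(generation_config_eos)
merged = [token for token in all_eos_tokens if token is not None]
print(merged)  # [32001, 2]
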

tests/generation/test_configuration_utils.py
Lines changed: 10 additions & 0 deletions

@@ -153,32 +153,38 @@ def test_validate(self):
         logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
 
         # A correct configuration will not throw any warning
+        logger.warning_once.cache_clear()
         with CaptureLogger(logger) as captured_logs:
             GenerationConfig()
         self.assertEqual(len(captured_logs.out), 0)
 
         # Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
         # parameters with `do_sample=False`). May be escalated to an error in the future.
+        logger.warning_once.cache_clear()
         with CaptureLogger(logger) as captured_logs:
             GenerationConfig(return_dict_in_generate=False, output_scores=True)
         self.assertNotEqual(len(captured_logs.out), 0)
 
+        logger.warning_once.cache_clear()
         with CaptureLogger(logger) as captured_logs:
             generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)  # store for later
         self.assertNotEqual(len(captured_logs.out), 0)
 
         # Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
         # that is done by unsetting the parameter (i.e. setting it to None)
+        logger.warning_once.cache_clear()
         with CaptureLogger(logger) as captured_logs:
             # BAD - 0.9 means it is still set, we should warn
             generation_config_bad_temperature.update(temperature=0.9)
         self.assertNotEqual(len(captured_logs.out), 0)
 
+        logger.warning_once.cache_clear()
         with CaptureLogger(logger) as captured_logs:
             # CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
             generation_config_bad_temperature.update(temperature=1.0)
         self.assertEqual(len(captured_logs.out), 0)
 
+        logger.warning_once.cache_clear()
         with CaptureLogger(logger) as captured_logs:
             # OK - None means it is unset, nothing to warn about
             generation_config_bad_temperature.update(temperature=None)

@@ -198,12 +204,14 @@ def test_validate(self):
             GenerationConfig(logits_processor="foo")
 
         # Model-specific parameters will NOT raise an exception or a warning
+        logger.warning_once.cache_clear()
         with CaptureLogger(logger) as captured_logs:
             GenerationConfig(foo="bar")
         self.assertEqual(len(captured_logs.out), 0)
 
         # By default we throw a short warning. However, we log with INFO level the details.
         # Default: we don't log the incorrect input values, only a short summary. We explain how to get more details.
+        logger.warning_once.cache_clear()
         with LoggingLevel(logging.WARNING):
             with CaptureLogger(logger) as captured_logs:
                 GenerationConfig(do_sample=False, temperature=0.5)

@@ -212,6 +220,8 @@ def test_validate(self):
             self.assertIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)
 
         # INFO level: we share the full deets
+        logger.warning_once.cache_clear()
+        logger.info_once.cache_clear()
         with LoggingLevel(logging.INFO):
             with CaptureLogger(logger) as captured_logs:
                 GenerationConfig(do_sample=False, temperature=0.5)

tests/trainer/test_trainer.py
Lines changed: 48 additions & 7 deletions

@@ -48,6 +48,7 @@
     default_data_collator,
     enable_full_determinism,
     get_polynomial_decay_schedule_with_warmup,
+    is_datasets_available,
     is_torch_available,
     logging,
     set_seed,

@@ -161,6 +162,8 @@
 if is_safetensors_available():
     import safetensors.torch
 
+if is_datasets_available():
+    import datasets
 
 # for version specific tests in TrainerIntegrationTest
 require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")

@@ -519,7 +522,6 @@ def forward(self, input_ids, **kwargs):
         return logits
 
 def create_dummy_dataset_for_text_generation(vocab_size, seq_length, num_samples):
-    import datasets
     import numpy as np
 
     # Create random input sequences

@@ -595,8 +597,6 @@ def get_regression_trainer(
     )
 
 def get_language_model_trainer(**kwargs):
-    import datasets
-
     dataset = datasets.load_dataset("fka/awesome-chatgpt-prompts")
     model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
     tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")

@@ -773,8 +773,6 @@ def test_reproducible_training(self):
         self.check_trained_model(trainer.model, alternate_seed=True)
 
     def test_trainer_with_datasets(self):
-        import datasets
-
         np.random.seed(42)
         x = np.random.normal(size=(64,)).astype(np.float32)
         y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,)).astype(np.float32)

@@ -823,7 +821,6 @@ def test_model_init(self):
     @slow
     def test_gradient_accumulation_loss_alignment_with_model_loss(self):
         set_seed(42)
-        import datasets
 
         model_name = "nickypro/tinyllama-15M"
         dataset_name = "wikitext"

@@ -923,7 +920,6 @@ def tokenize_function(examples):
 
     def test_gradient_accumulation_loss_alignment_with_loss_func(self):
         set_seed(42)
-        import datasets
 
         model_name = "roneneldan/TinyStories-33M"
         dataset_name = "wikitext"

@@ -4960,6 +4956,51 @@ def test_best_model_checkpoint_behavior(self):
 
         assert len(os.listdir(tmpdir)) == trainer.state.global_step // 2
 
+    def test_special_token_aligment(self):
+        """
+        Tests that special token changes in the tokenizer result in model config updates when using the trainer, to
+        ensure special tokens are aligned across configs.
+        """
+
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
+
+        # add new special tokens to the tokenizer, so we can test that the trainer aligns the model configs with it
+        tokenizer.eos_token = "<|im_end|>"
+        tokenizer.pad_token = "<|im_end|>"
+        tokenizer.bos_token = "<|im_start|>"
+        tokenizer.add_special_tokens({"additional_special_tokens": ["<|im_end|>", "<|im_start|>"]})
+
+        # the model needs to have its embedding layer resized accordingly
+        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
+
+        # create a random dataset from the **new** vocab size
+        x = torch.randint(0, len(tokenizer), (64,))
+        dataset = RepeatDataset(x, length=2)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            training_args = TrainingArguments(
+                output_dir=tmpdir, report_to="none", max_steps=1, per_device_train_batch_size=1
+            )
+            trainer = Trainer(
+                model=model,
+                args=training_args,
+                processing_class=tokenizer,
+                train_dataset=dataset,
+            )
+
+            # We haven't started training -> not yet aligned
+            self.assertNotEqual(trainer.model.config.eos_token_id, tokenizer.eos_token_id)
+            self.assertNotEqual(trainer.model.config.pad_token_id, tokenizer.pad_token_id)
+            self.assertNotEqual(trainer.model.config.bos_token_id, tokenizer.bos_token_id)
+
+            trainer.train()
+
+            # Must be aligned as soon as we start training
+            self.assertEqual(trainer.model.config.eos_token_id, tokenizer.eos_token_id)
+            self.assertEqual(trainer.model.config.pad_token_id, tokenizer.pad_token_id)
+            self.assertEqual(trainer.model.config.bos_token_id, tokenizer.bos_token_id)
+
 
 @require_torch
 @is_staging_test
