From 3d614901245103541f7215cc346ae47d4b180c0a Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 9 Apr 2024 11:53:40 +0200 Subject: [PATCH 1/2] Fix data collator --- src/transformers/trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 1cdd8623e58c98..494491e797d3b7 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -492,7 +492,11 @@ def __init__( ): self.place_model_on_device = False - default_collator = DataCollatorWithPadding(tokenizer) if tokenizer is not None else default_data_collator + default_collator = ( + DataCollatorWithPadding(tokenizer) + if tokenizer is not None and isinstance(tokenizer, PreTrainedTokenizerBase) + else default_data_collator + ) self.data_collator = data_collator if data_collator is not None else default_collator self.train_dataset = train_dataset self.eval_dataset = eval_dataset From 044c14a04c935e9353cc87a10e43bd51ba419021 Mon Sep 17 00:00:00 2001 From: Niels Date: Tue, 9 Apr 2024 11:58:36 +0200 Subject: [PATCH 2/2] Support feature extractors as well --- src/transformers/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 494491e797d3b7..844e464c6eb270 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -58,6 +58,7 @@ from .configuration_utils import PretrainedConfig from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from .debug_utils import DebugOption, DebugUnderflowOverflow +from .feature_extraction_sequence_utils import SequenceFeatureExtractor from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend from .image_processing_utils import BaseImageProcessor from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available @@ -494,7 +495,7 @@ def __init__( default_collator = ( DataCollatorWithPadding(tokenizer) - if tokenizer is not None and isinstance(tokenizer, PreTrainedTokenizerBase) + if tokenizer is not None and isinstance(tokenizer, (PreTrainedTokenizerBase, SequenceFeatureExtractor)) else default_data_collator ) self.data_collator = data_collator if data_collator is not None else default_collator