Skip to content

Commit

Permalink
[Trainer] Fix default data collator (#30142)
Browse files Browse the repository at this point in the history
* Fix data collator

* Support feature extractors as well
  • Loading branch information
NielsRogge authored and Ita Zaporozhets committed May 14, 2024
1 parent 5d9eede commit 9664243
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/transformers/trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -58,6 +58,7 @@
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .debug_utils import DebugOption, DebugUnderflowOverflow
+ from .feature_extraction_sequence_utils import SequenceFeatureExtractor
from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from .image_processing_utils import BaseImageProcessor
from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
Expand Down Expand Up @@ -492,7 +493,11 @@ def __init__(
):
self.place_model_on_device = False

- default_collator = DataCollatorWithPadding(tokenizer) if tokenizer is not None else default_data_collator
+ default_collator = (
+     DataCollatorWithPadding(tokenizer)
+     if tokenizer is not None and isinstance(tokenizer, (PreTrainedTokenizerBase, SequenceFeatureExtractor))
+     else default_data_collator
+ )
self.data_collator = data_collator if data_collator is not None else default_collator
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
Expand Down

0 comments on commit 9664243

Please sign in to comment.