diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 4f75b72641d2..8e4bc401f785 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -65,7 +65,7 @@ from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES from .optimization import Adafactor, get_scheduler -from .pytorch_utils import ALL_LAYERNORM_LAYERS +from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_less_than_1_11 from .tokenization_utils_base import PreTrainedTokenizerBase from .trainer_callback import ( CallbackHandler, @@ -85,6 +85,7 @@ distributed_broadcast_scalars, distributed_concat, find_batch_size, + get_dataloader_sampler, get_model_param_count, get_module_class_from_name, get_parameter_names, @@ -219,6 +220,7 @@ if TYPE_CHECKING: import optuna + logger = logging.get_logger(__name__) @@ -1783,8 +1785,17 @@ def _inner_training_loop( # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. if not args.ignore_data_skip: for epoch in range(epochs_trained): - for _ in train_dataloader: - break + sampler = get_dataloader_sampler(train_dataloader) + is_random_sampler = isinstance(sampler, RandomSampler) + if is_torch_less_than_1_11 or not is_random_sampler: + # We just need to begin an iteration to create the randomization of the sampler. + for _ in train_dataloader: + break + else: + # Otherwise we need to call the whooooole sampler cause there is some random operation added + # AT THE VERY END! + sampler = sampler if sampler is not None else [] + _ = list(sampler) total_batched_samples = 0 for epoch in range(epochs_trained, num_train_epochs): diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index b8c4080c2d54..cb6249f19a93 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -55,6 +55,13 @@ logger = logging.get_logger(__name__) +def get_dataloader_sampler(dataloader): + if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None: + return get_dataloader_sampler(dataloader.batch_sampler) + elif hasattr(dataloader, "sampler"): + return dataloader.sampler + + def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]): if isinstance(tensor_or_array, torch.Tensor): if hasattr(torch, "atleast_1d"):