start using training_args.parallel_mode (#8882)
stas00 authored Dec 1, 2020
1 parent b08843c commit 379005c
Showing 2 changed files with 4 additions and 2 deletions.
examples/seq2seq/finetune_trainer.py (2 additions, 1 deletion)

@@ -11,6 +11,7 @@
 from seq2seq_training_args import Seq2SeqTrainingArguments
 from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, HfArgumentParser, MBartTokenizer, set_seed
 from transformers.trainer_utils import EvaluationStrategy, is_main_process
+from transformers.training_args import ParallelMode
 from utils import (
     Seq2SeqDataCollator,
     Seq2SeqDataset,
@@ -132,7 +133,7 @@ def main():
         training_args.local_rank,
         training_args.device,
         training_args.n_gpu,
-        bool(training_args.local_rank != -1),
+        bool(training_args.parallel_mode == ParallelMode.DISTRIBUTED),
         training_args.fp16,
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
examples/seq2seq/seq2seq_trainer.py (2 additions, 1 deletion)

@@ -18,6 +18,7 @@
     get_polynomial_decay_schedule_with_warmup,
 )
 from transformers.trainer_pt_utils import get_tpu_sampler
+from transformers.training_args import ParallelMode


 logger = logging.get_logger(__name__)
@@ -123,7 +124,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
         if self.args.sortish_sampler:
             self.train_dataset.make_sortish_sampler(
                 self.args.per_device_train_batch_size,
-                distributed=(self.args.local_rank != -1),
+                distributed=(self.args.parallel_mode == ParallelMode.DISTRIBUTED),
             )

         return (
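In both files the change swaps the launcher-specific heuristic (local_rank != -1) for an explicit query of the TrainingArguments parallelism state. Below is a minimal sketch of the two checks side by side; the output_dir value is a placeholder, not part of this commit:

    from transformers import TrainingArguments
    from transformers.training_args import ParallelMode

    args = TrainingArguments(output_dir="tmp_out")  # "tmp_out" is a placeholder

    # Old check: treats any local_rank other than -1 as distributed training,
    # which couples the code to the torch.distributed launcher convention.
    is_distributed_old = args.local_rank != -1

    # New check: asks TrainingArguments for its parallelism mode directly.
    is_distributed_new = args.parallel_mode == ParallelMode.DISTRIBUTED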
