fix sharded_ddp mode
kamo-naoyuki committed Feb 25, 2021
1 parent 7b78b65 commit f626a38
Showing 1 changed file with 6 additions and 2 deletions.
espnet2/train/trainer.py: 6 additions and 2 deletions
@@ -289,6 +289,7 @@ def run(
                     scaler=scaler,
                     summary_writer=summary_writer,
                     options=trainer_options,
+                    distributed_option=distributed_option,
                 )
 
             with reporter.observe("valid") as sub_reporter:
@@ -297,6 +298,7 @@
                     iterator=valid_iter_factory.build_iter(iepoch),
                     reporter=sub_reporter,
                     options=trainer_options,
+                    distributed_option=distributed_option,
                 )
 
             if not distributed_option.distributed or distributed_option.dist_rank == 0:
@@ -435,6 +437,7 @@ def train_one_epoch(
         reporter: SubReporter,
         summary_writer: Optional[SummaryWriter],
         options: TrainerOptions,
+        distributed_option: DistributedOption,
     ) -> bool:
         assert check_argument_types()
 
@@ -446,7 +449,7 @@
         no_forward_run = options.no_forward_run
         ngpu = options.ngpu
         use_wandb = options.use_wandb
-        distributed = isinstance(model, torch.nn.parallel.DistributedDataParallel)
+        distributed = distributed_option.distributed
 
         if log_interval is None:
             try:
@@ -650,11 +653,12 @@ def validate_one_epoch(
         iterator: Iterable[Dict[str, torch.Tensor]],
         reporter: SubReporter,
         options: TrainerOptions,
+        distributed_option: DistributedOption,
     ) -> None:
         assert check_argument_types()
         ngpu = options.ngpu
         no_forward_run = options.no_forward_run
-        distributed = isinstance(model, torch.nn.parallel.DistributedDataParallel)
+        distributed = distributed_option.distributed
 
         model.eval()
 
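For context, a minimal sketch of what this commit changes, not code taken from the repository: under sharded_ddp the model is presumably wrapped by a class other than torch.nn.parallel.DistributedDataParallel (e.g. fairscale's ShardedDataParallel), so the old isinstance check reported a non-distributed run and the distributed-only branches (such as cross-rank synchronization) were skipped. The fix passes DistributedOption into both epoch loops and reads its distributed flag instead. All helper names below are illustrative.

# Illustrative sketch only; helper names are not from the repository.
import dataclasses
import torch


@dataclasses.dataclass
class DistributedOptionSketch:
    # Stands in for espnet2's DistributedOption: the trainer already knows
    # from its configuration whether the run is distributed.
    distributed: bool = False


def is_distributed_via_wrapper(model: torch.nn.Module) -> bool:
    # Old behaviour: infer distributed mode from the wrapper class.
    # This returns False for sharded_ddp, whose wrapper is not a subclass of
    # torch.nn.parallel.DistributedDataParallel.
    return isinstance(model, torch.nn.parallel.DistributedDataParallel)


def is_distributed_via_option(distributed_option: DistributedOptionSketch) -> bool:
    # New behaviour: trust the configuration flag, which holds for both
    # plain DDP and sharded_ddp.
    return distributed_option.distributed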
