Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mixed precision training to TorchEngine #1322

Merged
merged 1 commit into from
May 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions returnn/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@

from __future__ import annotations
from typing import Optional, Union, Callable, Dict
from contextlib import nullcontext

import os
import torch
import torch.utils.data.datapipes as dp
from torch import autocast
from torch.cuda import amp
from torchdata.dataloader2 import DataLoader2
from random import random

Expand Down Expand Up @@ -53,6 +56,9 @@ def __init__(self, config: Config):
self._save_model_epoch_interval = 1
self._updater = None # type: Optional[Updater]

self._amp_dtype = None # type: Optional[str]
self._grad_scaler = None # type: Optional[amp.GradScaler]

self._device = _get_device_from_config(config)
print("Using device:", self._device, file=log.v2)

Expand Down Expand Up @@ -104,6 +110,12 @@ def init_train_from_config(
self._train_step_func = self.config.typed_value("train_step")
assert self._train_step_func, "train_step not defined"

amp_options = self.config.typed_value("torch_amp_options")
if amp_options is not None:
assert isinstance(amp_options, dict)
self._amp_dtype = amp_options.get("dtype")
self._grad_scaler = amp.GradScaler()

def train(self):
"""
Main training loop.
Expand Down Expand Up @@ -144,7 +156,9 @@ def train_epoch(self):
accumulated_inv_norm_factors_dict = NumbersDict()
step_idx = 0
for data in self._train_dataloader:
self._run_step(data, train_flag=True)
self._updater.get_optimizer().zero_grad()
with autocast(device_type=self._device, dtype=self._amp_dtype) if self._amp_dtype else nullcontext():
self._run_step(data, train_flag=True)

train_ctx = rf.get_run_ctx()
total_loss = train_ctx.total_loss()
Expand All @@ -158,9 +172,13 @@ def train_epoch(self):
{name: float(_to_raw(loss.get_inv_norm_factor())) for name, loss in train_ctx.losses.items()}
)

self._updater.get_optimizer().zero_grad()
total_loss.raw_tensor.backward()
self._updater.get_optimizer().step()
if self._amp_dtype:
self._grad_scaler.scale(total_loss).backward()
self._grad_scaler.step(self._updater.get_optimizer())
self._grad_scaler.update()
else:
total_loss.raw_tensor.backward()
self._updater.get_optimizer().step()

accumulated_losses_dict += losses_dict
accumulated_inv_norm_factors_dict += inv_norm_factors_dict
Expand Down Expand Up @@ -211,7 +229,10 @@ def eval_model(self):
with torch.no_grad():
for data in data_loader:

self._run_step(data)
with autocast(
device_type=self._device, dtype=self._amp_dtype
) if self._amp_dtype else nullcontext():
self._run_step(data)
train_ctx = rf.get_run_ctx()

if score_keys is None:
Expand Down