Add mixed precision training to TorchEngine (#1322)
Uses torch_amp_options as a config dict with a "dtype" option. Adds a GradScaler to the engine and applies autocast and the scaler during training if AMP is enabled.
JackTemaki authored May 15, 2023
1 parent 88127d1 commit 003b2e8
Showing 1 changed file with 26 additions and 5 deletions.
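
As a usage sketch (not part of this commit's diff): the new option would be set in the RETURNN config roughly as below. The option name and the "dtype" key come from this change; the concrete dtype value is an example, and since it is handed directly to autocast it presumably needs to be a torch.dtype rather than a string.

import torch

# Enable automatic mixed precision in the Torch engine (hypothetical config excerpt).
torch_amp_options = {"dtype": torch.float16}
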
31 changes: 26 additions & 5 deletions returnn/torch/engine.py
@@ -4,10 +4,13 @@
 
 from __future__ import annotations
 from typing import Optional, Union, Callable, Dict
+from contextlib import nullcontext
 
 import os
 import torch
 import torch.utils.data.datapipes as dp
+from torch import autocast
+from torch.cuda import amp
 from torchdata.dataloader2 import DataLoader2
 from random import random
 
@@ -53,6 +56,9 @@ def __init__(self, config: Config):
         self._save_model_epoch_interval = 1
         self._updater = None  # type: Optional[Updater]
 
+        self._amp_dtype = None  # type: Optional[str]
+        self._grad_scaler = None  # type: Optional[amp.GradScaler]
+
         self._device = _get_device_from_config(config)
         print("Using device:", self._device, file=log.v2)
 
@@ -104,6 +110,12 @@ def init_train_from_config(
         self._train_step_func = self.config.typed_value("train_step")
         assert self._train_step_func, "train_step not defined"
 
+        amp_options = self.config.typed_value("torch_amp_options")
+        if amp_options is not None:
+            assert isinstance(amp_options, dict)
+            self._amp_dtype = amp_options.get("dtype")
+            self._grad_scaler = amp.GradScaler()
+
     def train(self):
         """
         Main training loop.
@@ -144,7 +156,9 @@ def train_epoch(self):
         accumulated_inv_norm_factors_dict = NumbersDict()
         step_idx = 0
         for data in self._train_dataloader:
-            self._run_step(data, train_flag=True)
+            self._updater.get_optimizer().zero_grad()
+            with autocast(device_type=self._device, dtype=self._amp_dtype) if self._amp_dtype else nullcontext():
+                self._run_step(data, train_flag=True)
 
             train_ctx = rf.get_run_ctx()
             total_loss = train_ctx.total_loss()
@@ -158,9 +172,13 @@
                 {name: float(_to_raw(loss.get_inv_norm_factor())) for name, loss in train_ctx.losses.items()}
             )
 
-            self._updater.get_optimizer().zero_grad()
-            total_loss.raw_tensor.backward()
-            self._updater.get_optimizer().step()
+            if self._amp_dtype:
+                self._grad_scaler.scale(total_loss.raw_tensor).backward()
+                self._grad_scaler.step(self._updater.get_optimizer())
+                self._grad_scaler.update()
+            else:
+                total_loss.raw_tensor.backward()
+                self._updater.get_optimizer().step()
 
             accumulated_losses_dict += losses_dict
             accumulated_inv_norm_factors_dict += inv_norm_factors_dict
@@ -211,7 +229,10 @@ def eval_model(self):
             with torch.no_grad():
                 for data in data_loader:
 
-                    self._run_step(data)
+                    with autocast(
+                        device_type=self._device, dtype=self._amp_dtype
+                    ) if self._amp_dtype else nullcontext():
+                        self._run_step(data)
                     train_ctx = rf.get_run_ctx()
 
                     if score_keys is None:
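
For reference, the training-step change follows the usual torch.cuda.amp recipe: run the forward pass under autocast, scale the loss before backward, then let the GradScaler unscale the gradients, apply the optimizer step, and update its scale factor. A minimal standalone sketch of that pattern (not RETURNN code; the model, optimizer, and data are placeholders, and a CUDA device is assumed):

import torch
from torch import autocast
from torch.cuda import amp

# Placeholder model, optimizer, and data just to exercise the AMP step pattern.
model = torch.nn.Linear(8, 1).to("cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = amp.GradScaler()
inputs = torch.randn(4, 8, device="cuda")
targets = torch.randn(4, 1, device="cuda")

optimizer.zero_grad()
with autocast(device_type="cuda", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
scaler.scale(loss).backward()  # backprop on the scaled loss
scaler.step(optimizer)  # unscales gradients, skips the step if they contain inf/NaN
scaler.update()  # adjusts the loss scale for the next step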
