From e44474187328fcfada278eceb67e112cc73aa114 Mon Sep 17 00:00:00 2001
From: Rajeev Goel <44844911+rajeevgl01@users.noreply.github.com>
Date: Fri, 12 Jan 2024 13:09:33 -0700
Subject: [PATCH] Using torch.bfloat16 to prevent overflow instead of default fp16 in AMP

---
 main.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/main.py b/main.py
index 84230ea75..88eec76f9 100644
--- a/main.py
+++ b/main.py
@@ -184,7 +184,8 @@ def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mix
         if mixup_fn is not None:
             samples, targets = mixup_fn(samples, targets)
 
-        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
+        # Use torch.bfloat16 to prevent overflow. Float16 has three fewer exponent bits than bfloat16, which causes NaN loss and NaN grad norms during AMP training.
+        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE, dtype=torch.bfloat16):
             outputs = model(samples)
             loss = criterion(outputs, targets)
             loss = loss / config.TRAIN.ACCUMULATION_STEPS
@@ -241,7 +242,8 @@ def validate(config, data_loader, model):
         target = target.cuda(non_blocking=True)
 
         # compute output
-        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
+        # Use torch.bfloat16 to prevent overflow. Float16 has three fewer exponent bits than bfloat16, which causes NaN loss and NaN grad norms during AMP training.
+        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE, dtype=torch.bfloat16):
             output = model(images)
 
         # measure accuracy and record loss
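
The sketch below (not part of the patch) shows the same autocast usage in isolation: a single training step run under `torch.cuda.amp.autocast` with `dtype=torch.bfloat16`. The model, optimizer, and tensor shapes are illustrative placeholders, not names from Swin's `main.py`, and a CUDA device is assumed. Because bfloat16 keeps fp32's 8 exponent bits (float16 has only 5), its dynamic range matches fp32 and loss scaling via `GradScaler` is generally unnecessary.

```python
# Minimal sketch, assuming a CUDA device; model/data names are illustrative.
import torch
import torch.nn as nn

model = nn.Linear(128, 10).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

samples = torch.randn(32, 128, device="cuda")
targets = torch.randint(0, 10, (32,), device="cuda")

# bfloat16 shares fp32's exponent width, so activations/loss do not overflow
# the way they can in float16, and no GradScaler is needed here.
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
    outputs = model(samples)            # matmuls run in bfloat16
    loss = criterion(outputs, targets)  # loss computed under autocast

loss.backward()       # gradients accumulate into fp32 master weights
optimizer.step()
optimizer.zero_grad()
```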