
Commit

!87 support inf nan mode and overflow check
Merge pull request !87 from Jializheng/adaptor
郭鹏 authored and gitee-org committed Oct 9, 2023
2 parents 5c7c899 + 7f4a626 commit c134c39
Showing 9 changed files with 196 additions and 117 deletions.
4 changes: 4 additions & 0 deletions deepspeed_npu/__init__.py
@@ -1,6 +1,10 @@
 import torch
 import torch_npu
 from torch_npu.contrib import transfer_to_npu
+
+global FLAG_SUPPORT_INF_NAN
+FLAG_SUPPORT_INF_NAN = hasattr(torch_npu.npu.utils, 'is_support_inf_nan') and torch_npu.npu.utils.is_support_inf_nan()
+
 from . import adaptor_utils
 from . import adaptor_launcher_runner
 from . import adaptor_moe_shared_moe
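The flag probed above is what every other hunk in this commit branches on. A condensed sketch of how the adaptor modules consume it (illustrative only, assuming a torch_npu environment; not a verbatim excerpt of any single file):

    import torch_npu
    from deepspeed_npu import FLAG_SUPPORT_INF_NAN

    def scaled_backward(scaled_loss):
        # Legacy mode: overflow is tracked in a global hardware flag, so clear any
        # stale state before backward. With native inf/nan support the call is
        # skipped and overflow shows up directly as inf/nan in the gradients.
        if not FLAG_SUPPORT_INF_NAN:
            torch_npu.npu.clear_npu_overflow_flag()
        scaled_loss.backward()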
deepspeed_npu/adaptor_runtime_activation_checkpointing_checkpointing.py
@@ -7,6 +7,7 @@
 from deepspeed.runtime.activation_checkpointing.checkpointing import gather_partitioned_activations, detach_variable, \
     merge_tensors, get_cuda_rng_tracker, is_activation_to_checkpoint, extract_tensors
 from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank
+from . import FLAG_SUPPORT_INF_NAN
 
 CKPT_INIT_FLAG = False
 CKPT_OVERFLOW_FLAG = False
@@ -78,19 +79,19 @@ def backward(ctx, *grads):
 
         see_memory_usage("In backward checkpointing code before forward", force=False)
 
-        global CKPT_INIT_FLAG, CKPT_OVERFLOW_FLAG, CKPT_CONST_VAR
-        if not CKPT_INIT_FLAG:
-            CKPT_INIT_FLAG = True
-            CKPT_CONST_VAR = torch.tensor([65504.], dtype=torch.float16).npu()
-
-        CKPT_OVERFLOW_FLAG = torch_npu.npu.get_npu_overflow_flag()
-        with torch.enable_grad():
-            outputs = ctx.run_function(*detached_inputs)
-        torch.npu.clear_npu_overflow_flag()
-
-
+        if not FLAG_SUPPORT_INF_NAN:
+            global CKPT_INIT_FLAG, CKPT_OVERFLOW_FLAG, CKPT_CONST_VAR
+            if not CKPT_INIT_FLAG:
+                CKPT_INIT_FLAG = True
+                CKPT_CONST_VAR = torch.tensor([65504.], dtype=torch.float16).npu()
+
+            CKPT_OVERFLOW_FLAG = torch_npu.npu.get_npu_overflow_flag()
+            with torch.enable_grad():
+                outputs = ctx.run_function(*detached_inputs)
+            torch.npu.clear_npu_overflow_flag()
+        else:
+            with torch.enable_grad():
+                outputs = ctx.run_function(*detached_inputs)
 
         see_memory_usage("In backward checkpointing code after forward", force=False)
         # Set the states back to what it was at the start of this function.
@@ -137,9 +138,10 @@ def backward(ctx, *grads):
             else:
                 ret_list.append(None)
 
-        temp = torch_npu.npu.get_npu_overflow_flag()
-        CKPT_OVERFLOW_FLAG = CKPT_OVERFLOW_FLAG or temp
-        CKPT_CONST_VAR + CKPT_OVERFLOW_FLAG * 10000
+        if not FLAG_SUPPORT_INF_NAN:
+            temp = torch_npu.npu.get_npu_overflow_flag()
+            CKPT_OVERFLOW_FLAG = CKPT_OVERFLOW_FLAG or temp
+            CKPT_CONST_VAR + CKPT_OVERFLOW_FLAG * 10000
 
         return tuple(ret_list)
 
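The CKPT_CONST_VAR + CKPT_OVERFLOW_FLAG * 10000 expression kept on the legacy path deliberately performs a float16 operation that overflows whenever an overflow was recorded during recomputation: 65504 is the largest finite float16 value, so adding anything on top of it overflows (in the NPU's legacy saturation mode this trips the hardware overflow flag), which is what lets the overflow check later in the step observe the condition. The same saturation behaviour can be reproduced with plain CPU tensors (illustrative only; no NPU required):

    import torch

    fp16_max = torch.tensor([65504.], dtype=torch.float16)    # largest finite float16
    print(fp16_max + 10000)                                    # tensor([inf], dtype=torch.float16)
    print(torch.isinf(fp16_max + 10000).any().item())          # True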
11 changes: 6 additions & 5 deletions deepspeed_npu/adaptor_runtime_fp16_fused_optimizer.py
@@ -8,6 +8,7 @@
 from deepspeed.utils import logger
 from apex.contrib.combine_tensors import combine_npu
 from deepspeed_npu.adaptor_ops_adam_fused_adam import FusedAdamNPU
+from . import FLAG_SUPPORT_INF_NAN
 
 
 # fused_optimizer============
@@ -127,7 +128,8 @@ def initialize_optimizer_states(self):
 
 
 def Fp16OptimizerBackward(self, loss, create_graph=False, retain_graph=False):
-    torch.npu.clear_npu_overflow_flag()
+    if not FLAG_SUPPORT_INF_NAN:
+        torch.npu.clear_npu_overflow_flag()
     scaled_loss = (loss.float()) * self.cur_scale
     scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph)
 
@@ -238,10 +240,9 @@ def get_combine_weight_norm(self, parameters, norm_type=2, mpu=None):
                                          group=mpu.get_model_parallel_group())
         total_norm = total_norm_npu[0].item()**(1. / norm_type)
 
-    if torch_npu.__version__ >= "2.1":
-        overflow = torch_npu._amp_foreach_non_finite_check([total_norm_npu])
-    else:
-        overflow = torch_npu._amp_foreach_non_finite_check_([total_norm_npu])
+    overflow = False
+    if not FLAG_SUPPORT_INF_NAN:
+        overflow = torch_npu.npu.utils.npu_check_overflow([total_norm_npu])
 
     if overflow or total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
         total_norm = -1
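On devices with native inf/nan support, npu_check_overflow is no longer consulted here and the guard relies on the norm itself being non-finite; the final condition uses the standard x != x idiom to catch NaN. A CPU-only restatement of that check (the helper name is invented for illustration):

    def norm_is_invalid(total_norm: float) -> bool:
        # inf, -inf, or NaN (NaN is the only value that compares unequal to itself)
        return total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm

    print(norm_is_invalid(3.7))            # False
    print(norm_is_invalid(float('nan')))   # True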
19 changes: 15 additions & 4 deletions deepspeed_npu/adaptor_runtime_fp16_loss_scaler.py
@@ -2,16 +2,27 @@
 import torch_npu
 from torch_npu.npu import clear_npu_overflow_flag
 from deepspeed.runtime.fp16 import loss_scaler, unfused_optimizer
+from . import FLAG_SUPPORT_INF_NAN
 
 
 # loss_scaler============
 def backward(self, loss, retain_graph=False):
-    clear_npu_overflow_flag()
+    if not FLAG_SUPPORT_INF_NAN:
+        clear_npu_overflow_flag()
     scaled_loss = loss * self.loss_scale
     scaled_loss.backward(retain_graph=retain_graph)
 
 
 def has_overflow_serial(self, params):
-    grads = [p.grad.data for p in params if p.grad is not None]
-    return torch_npu._amp_foreach_non_finite_check_(grads)
+    if not FLAG_SUPPORT_INF_NAN:
+        grads = [p.grad.data for p in params if p.grad is not None]
+        return torch_npu.npu.utils.npu_check_overflow(grads)
+
+    for p in params:
+        if p.grad is not None and self._has_inf_or_nan(p.grad.data):
+            return True
+    return False
 
 
 loss_scaler.LossScalerBase.backward = backward
-loss_scaler.DynamicLossScaler.has_overflow_serial = has_overflow_serial
+loss_scaler.DynamicLossScaler.has_overflow_serial = has_overflow_serial
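On the native inf/nan path, has_overflow_serial now scans each gradient with DeepSpeed's _has_inf_or_nan instead of querying the NPU overflow flag. A self-contained CPU sketch of that fallback (the _has_inf_or_nan body below is a stand-in written with plain torch ops, not DeepSpeed's exact implementation):

    import torch

    def _has_inf_or_nan(x: torch.Tensor) -> bool:
        # Stand-in check: any inf or NaN anywhere in the tensor.
        return bool(torch.isinf(x).any() or torch.isnan(x).any())

    def has_overflow_serial(params) -> bool:
        for p in params:
            if p.grad is not None and _has_inf_or_nan(p.grad.data):
                return True
        return False

    p = torch.nn.Parameter(torch.ones(4))
    p.grad = torch.tensor([1.0, float('inf'), 0.0, 2.0])
    print(has_overflow_serial([p]))   # True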
28 changes: 16 additions & 12 deletions deepspeed_npu/adaptor_runtime_fp16_unfused_optimizer.py
@@ -1,6 +1,8 @@
 import sys
 from torch_npu.npu import clear_npu_overflow_flag
 from deepspeed.runtime.fp16 import unfused_optimizer
+from . import FLAG_SUPPORT_INF_NAN
+
 
 # unfused_optimizer============
 class Fp16UnfusedOptimizerNpu(unfused_optimizer.FP16_UnfusedOptimizer):
@@ -15,15 +17,15 @@ def __init__(self,
                  clip_grad=0.0,
                  fused_lamb_legacy=False):
         super().__init__(init_optimizer,
-                        deepspeed,
-                        static_loss_scale,
-                        dynamic_loss_scale,
-                        dynamic_loss_args,
-                        verbose,
-                        mpu,
-                        clip_grad,
-                        fused_lamb_legacy)
+                         deepspeed,
+                         static_loss_scale,
+                         dynamic_loss_scale,
+                         dynamic_loss_args,
+                         verbose,
+                         mpu,
+                         clip_grad,
+                         fused_lamb_legacy)
 
     def step(self, closure=None):
         self.fused_lamb_legacy = False
         super().step(closure)
@@ -36,12 +38,14 @@ def backward(self, loss, create_graph=False, retain_graph=False):
         2. scaled_loss = fp32_loss*loss_scale
         3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
         """
-        clear_npu_overflow_flag()
+        if not FLAG_SUPPORT_INF_NAN:
+            clear_npu_overflow_flag()
         scaled_loss = (loss.float()) * self.cur_scale
 
         scaled_loss.backward(create_graph=create_graph, retain_graph=retain_graph)
 
 
 
 unfused_optimizer.FP16_UnfusedOptimizer = Fp16UnfusedOptimizerNpu
 for k, v in sys.modules.items():
     if 'deepspeed' in k and hasattr(v, 'FP16_UnfusedOptimizer'):
-        setattr(v, 'FP16_UnfusedOptimizer', Fp16UnfusedOptimizerNpu)
+        setattr(v, 'FP16_UnfusedOptimizer', Fp16UnfusedOptimizerNpu)
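The sys.modules sweep at the bottom of this file exists because any DeepSpeed module that did "from ... import FP16_UnfusedOptimizer" holds its own reference to the class, so rebinding the name only in its home module would not reach those importers. A minimal toy-module illustration of that import semantics (the names below are invented, not from the repo):

    import sys
    import types

    source = types.ModuleType('demo_source')
    source.FP16Demo = type('FP16Demo', (), {})
    user = types.ModuleType('demo_user')
    user.FP16Demo = source.FP16Demo            # simulates "from demo_source import FP16Demo"
    sys.modules['demo_source'], sys.modules['demo_user'] = source, user

    Patched = type('Patched', (), {})
    source.FP16Demo = Patched                  # rebinding only the defining module...
    print(user.FP16Demo is Patched)            # False: the importer still sees the old class

    for k, v in sys.modules.items():           # the sweep pattern used in the diff
        if hasattr(v, 'FP16Demo'):
            setattr(v, 'FP16Demo', Patched)
    print(user.FP16Demo is Patched)            # True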
5 changes: 4 additions & 1 deletion deepspeed_npu/adaptor_runtime_pipe_engine.py
@@ -5,6 +5,8 @@
 from deepspeed.runtime.pipe.engine import _tensor_bytes, PipelineEngine
 from deepspeed.runtime.pipe import p2p, schedule
 from deepspeed.runtime.utils import PartitionedTensor
+from . import FLAG_SUPPORT_INF_NAN
+
 
 class PipelineEngineNPU(PipelineEngine):
     ID_TO_DTYPE = [
@@ -22,7 +24,8 @@ class PipelineEngineNPU(PipelineEngine):
     ]
 
     def _exec_backward_pass(self, buffer_id):
-        clear_npu_overflow_flag()
+        if not FLAG_SUPPORT_INF_NAN:
+            clear_npu_overflow_flag()
         super()._exec_backward_pass(buffer_id)
 
     def _exec_load_micro_batch(self, buffer_id):
51 changes: 31 additions & 20 deletions deepspeed_npu/adaptor_runtime_utils.py
@@ -11,6 +11,8 @@
 from deepspeed.utils import groups
 from deepspeed.runtime.utils import is_model_parallel_parameter
 import deepspeed_npu.adaptor_runtime_activation_checkpointing_checkpointing as checkpointing
+from . import FLAG_SUPPORT_INF_NAN
+
 
 def check_using_norm(self, norm_group, reduce_overflow=True):
     # TODO: I don't think reduce_overflow is needed if mpu is None
@@ -38,13 +40,18 @@ def check_using_norm(self, norm_group, reduce_overflow=True):
         overflow = overflow_npu[0].item()
     return bool(overflow)
 
 
 def has_overflow_serial(self, params):
-    grads = [p.grad.data for p in params if p.grad is not None]
-    if torch_npu.__version__ >= "2.1":
-        res = torch_npu._amp_foreach_non_finite_check(grads)
-    else:
-        res = torch_npu._amp_foreach_non_finite_check_(grads)
-    return res
+    if not FLAG_SUPPORT_INF_NAN:
+        grads = [p.grad.data for p in params if p.grad is not None]
+        res = torch_npu.npu.utils.npu_check_overflow(grads)
+        return res
+
+    for i, p in enumerate(params):
+        if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
+            return True
+    return False
 
 
 def has_overflow(self, params, has_moe_params=None):
     if has_moe_params is None:
@@ -146,13 +153,15 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
                                          group=mpu.get_model_parallel_group())
         total_norm = total_norm_npu[0].item()**(1. / norm_type)
 
-    overflow = torch_npu._amp_foreach_non_finite_check_([total_norm_npu])
-    if mpu is not None:
-        overflow_npu = torch.npu.IntTensor([overflow])
-        torch.distributed.all_reduce(overflow_npu,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=mpu.get_model_parallel_group())
-        overflow = overflow_npu.item()
+    overflow = False
+    if not FLAG_SUPPORT_INF_NAN:
+        overflow = torch_npu.npu.utils.npu_check_overflow([total_norm_npu])
+        if mpu is not None:
+            overflow_npu = torch.npu.IntTensor([overflow])
+            torch.distributed.all_reduce(overflow_npu,
+                                         op=torch.distributed.ReduceOp.MAX,
+                                         group=mpu.get_model_parallel_group())
+            overflow = overflow_npu.item()
 
     if overflow or total_norm == float('inf') or \
         total_norm == -float('inf') or total_norm != total_norm:
@@ -214,13 +223,15 @@ def get_weight_norm(parameters, norm_type=2, mpu=None):
                                          group=mpu.get_model_parallel_group())
         total_norm = total_norm_npu[0].item()**(1. / norm_type)
 
-    overflow = torch_npu._amp_foreach_non_finite_check_([total_norm_npu])
-    if mpu is not None:
-        overflow_npu = torch.npu.IntTensor([overflow])
-        torch.distributed.all_reduce(overflow_npu,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=mpu.get_model_parallel_group())
-        overflow = overflow_npu.item()
+    overflow = False
+    if not FLAG_SUPPORT_INF_NAN:
+        overflow = torch_npu.npu.utils.npu_check_overflow([total_norm_npu])
+        if mpu is not None:
+            overflow_npu = torch.npu.IntTensor([overflow])
+            torch.distributed.all_reduce(overflow_npu,
+                                         op=torch.distributed.ReduceOp.MAX,
+                                         group=mpu.get_model_parallel_group())
+            overflow = overflow_npu.item()
 
     if overflow or total_norm == float('inf') or \
         total_norm == -float('inf') or total_norm != total_norm:
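When the legacy path does run under model parallelism, the local overflow verdict is combined across ranks with an all-reduce using ReduceOp.MAX, which over {0, 1} values behaves as a logical OR: if any rank overflowed, every rank ends up with 1. A single-process gloo sketch of that idiom (illustrative; the address and port are placeholders, and no NPU or model-parallel group is required):

    import os
    import torch
    import torch.distributed as dist

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    local_overflow = True                                    # pretend this rank overflowed
    flag = torch.tensor([int(local_overflow)], dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)              # MAX over {0,1} == logical OR
    print(bool(flag.item()))                                 # True if any rank reported overflow

    dist.destroy_process_group()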
(Diffs for the remaining 2 changed files are not shown here.)

