diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
index 09de21502c27..2706d4474515 100644
--- a/deepspeed/runtime/zero/parameter_offload.py
+++ b/deepspeed/runtime/zero/parameter_offload.py
@@ -341,39 +341,6 @@ def _bwd_hook_unexpected_inputs_msg(value):
 
         def _pre_backward_module_hook(module, inputs, output):
-            if not hasattr(module, "pre_bwd_fn"):
-
-                @instrument_w_nvtx
-                def _run_before_backward_function(sub_module):
-                    # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
-                    # before doing backwards, so each backward will need a pre-fetch - using reference
-                    # counting to support this scenario
-                    #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
-                    if sub_module.applied_pre_backward_ref_cnt > 0:
-                        self.pre_sub_module_backward_function(sub_module)
-                        sub_module.applied_pre_backward_ref_cnt -= 1
-                    #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
-
-                class PreBackwardFunctionForModule(torch.autograd.Function):
-
-                    @staticmethod
-                    def forward(ctx, outputs):
-                        # Capture `module` and _run_before_backward_function
-                        ctx.module = module
-                        ctx.pre_backward_function = _run_before_backward_function
-                        if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
-                            ctx.module.applied_pre_backward_ref_cnt = 0
-                        ctx.module.applied_pre_backward_ref_cnt += 1
-                        outputs = outputs.detach()
-                        return outputs
-
-                    @staticmethod
-                    def backward(ctx, *args):
-                        ctx.pre_backward_function(ctx.module)
-                        return args
-
-                module.pre_bwd_fn = PreBackwardFunctionForModule
-
             return apply_to_tensors_only(module.pre_bwd_fn.apply,
                                          output,
                                          warning_msg_fn=_bwd_hook_unexpected_inputs_msg)
 
@@ -402,41 +369,6 @@ def _post_backward_module_hook(module, inputs):
             if not hasattr(module, "ds_grads_remaining"):
                 module.ds_grads_remaining = 0
-            if not hasattr(module, "post_bwd_fn"):
-
-                @instrument_w_nvtx
-                def _run_after_backward_function(sub_module):
-                    if sub_module.ds_grads_remaining == 0:
-                        self.post_sub_module_backward_function(sub_module)
-
-                class PostBackwardFunctionModule(torch.autograd.Function):
-
-                    @staticmethod
-                    def forward(ctx, output):
-                        ctx.module = module
-                        if output.requires_grad:
-                            #TODO SOME TIMES post backward does not seem to be triggered debug in detail
-                            #Should only cause increase in memory not correctness issue
-                            #if output.grad_fn.__class__.__name__ == 'ViewBackward':
-                            #    ctx.view=True
-                            #    print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
-                            #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
-                            #if module.ds_grads_remaining == 0:
-                            #    print(f"Before Forward: {ctx.module.__class__.__name__}")
-                            module.ds_grads_remaining += 1
-                            ctx.post_backward_function = _run_after_backward_function
-                        output = output.detach()
-                        return output
-
-                    @staticmethod
-                    def backward(ctx, *args):
-                        ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
-                        if ctx.module.ds_grads_remaining == 0:
-                            ctx.post_backward_function(ctx.module)
-                        return args
-
-                module.post_bwd_fn = PostBackwardFunctionModule
-
             return apply_to_tensors_only(module.post_bwd_fn.apply,
                                          inputs,
                                          warning_msg_fn=_bwd_hook_unexpected_inputs_msg)
 
@@ -448,9 +380,77 @@ def backward(ctx, *args):
         self.forward_hooks.append(module.register_forward_hook(_post_forward_module_hook))
 
         # Pre backward hook
+        if not hasattr(module, "pre_bwd_fn"):
+
+            @instrument_w_nvtx
+            def _run_before_backward_function(sub_module):
+                # some models (e.g. Albert) may run multiple forwards on the same layer in a loop
+                # before doing backwards, so each backward will need a pre-fetch - using reference
+                # counting to support this scenario
+                #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}")
+                if sub_module.applied_pre_backward_ref_cnt > 0:
+                    self.pre_sub_module_backward_function(sub_module)
+                    sub_module.applied_pre_backward_ref_cnt -= 1
+                #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}")
+
+            class PreBackwardFunctionForModule(torch.autograd.Function):
+
+                @staticmethod
+                def forward(ctx, outputs):
+                    # Capture `module` and _run_before_backward_function
+                    ctx.module = module
+                    ctx.pre_backward_function = _run_before_backward_function
+                    if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
+                        ctx.module.applied_pre_backward_ref_cnt = 0
+                    ctx.module.applied_pre_backward_ref_cnt += 1
+                    outputs = outputs.detach()
+                    return outputs
+
+                @staticmethod
+                def backward(ctx, *args):
+                    ctx.pre_backward_function(ctx.module)
+                    return args
+
+            module.pre_bwd_fn = PreBackwardFunctionForModule
+
         self.backward_hooks.append(module.register_forward_hook(_pre_backward_module_hook))
 
         # post backward hook
+        if not hasattr(module, "post_bwd_fn"):
+
+            @instrument_w_nvtx
+            def _run_after_backward_function(sub_module):
+                if sub_module.ds_grads_remaining == 0:
+                    self.post_sub_module_backward_function(sub_module)
+
+            class PostBackwardFunctionModule(torch.autograd.Function):
+
+                @staticmethod
+                def forward(ctx, output):
+                    ctx.module = module
+                    if output.requires_grad:
+                        #TODO SOME TIMES post backward does not seem to be triggered debug in detail
+                        #Should only cause increase in memory not correctness issue
+                        #if output.grad_fn.__class__.__name__ == 'ViewBackward':
+                        #    ctx.view=True
+                        #    print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly")
+                        #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors."
+                        #if module.ds_grads_remaining == 0:
+                        #    print(f"Before Forward: {ctx.module.__class__.__name__}")
+                        module.ds_grads_remaining += 1
+                        ctx.post_backward_function = _run_after_backward_function
+                    output = output.detach()
+                    return output
+
+                @staticmethod
+                def backward(ctx, *args):
+                    ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1
+                    if ctx.module.ds_grads_remaining == 0:
+                        ctx.post_backward_function(ctx.module)
+                    return args
+
+            module.post_bwd_fn = PostBackwardFunctionModule
+
         self.backward_hooks.append(module.register_forward_pre_hook(_post_backward_module_hook))
 
     @torch.no_grad()
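
Context for reviewers: the diff does not change behavior. It moves the `PreBackwardFunctionForModule` / `PostBackwardFunctionModule` definitions out of the per-forward `_pre_backward_module_hook` / `_post_backward_module_hook` bodies and into the one-time hook-registration path, so the `torch.autograd.Function` subclasses are defined once per module rather than being re-declared on every forward call. The sketch below is a simplified, standalone illustration of the underlying mechanism, not DeepSpeed code: a forward hook wraps the module's output so a callback fires just before that module's backward, and a forward pre-hook wraps the input so another callback fires just after it. The helper name `attach_backward_callbacks` and the single-tensor assumption are mine, and the reference counting (`applied_pre_backward_ref_cnt`) and `ds_grads_remaining` bookkeeping of the real implementation are omitted.

```python
# Minimal sketch of the pre/post-backward hook pattern (assumption: not the
# DeepSpeed implementation; names below are hypothetical, single tensor only).
import torch
import torch.nn as nn


def attach_backward_callbacks(module, pre_backward_fn, post_backward_fn):

    class PreBackwardFn(torch.autograd.Function):

        @staticmethod
        def forward(ctx, output):
            # detach, as the DeepSpeed hooks do, so this Function becomes the grad_fn
            return output.detach()

        @staticmethod
        def backward(ctx, *grads):
            # gradients are just arriving at the module's outputs:
            # the "pre backward" moment for `module`
            pre_backward_fn(module)
            return grads

    class PostBackwardFn(torch.autograd.Function):

        @staticmethod
        def forward(ctx, inp):
            return inp.detach()

        @staticmethod
        def backward(ctx, *grads):
            # gradients have reached the module's inputs:
            # the module's own backward is finished ("post backward")
            post_backward_fn(module)
            return grads

    # forward hook (runs after forward) wraps the output -> pre-backward trigger
    module.register_forward_hook(lambda m, inp, out: PreBackwardFn.apply(out))
    # forward pre-hook (runs before forward) wraps the input -> post-backward trigger
    module.register_forward_pre_hook(lambda m, inp: (PostBackwardFn.apply(inp[0]),))
    return module


layer = attach_backward_callbacks(
    nn.Linear(4, 4),
    pre_backward_fn=lambda m: print("pre-backward ", type(m).__name__),
    post_backward_fn=lambda m: print("post-backward", type(m).__name__),
)
layer(torch.randn(2, 4, requires_grad=True)).sum().backward()
# prints "pre-backward  Linear" followed by "post-backward Linear"
```

In the relocated DeepSpeed code the same two wrappers are applied by `_pre_backward_module_hook` (via `register_forward_hook`) and `_post_backward_module_hook` (via `register_forward_pre_hook`); only where the Function classes get defined changes.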