Skip to content

Commit

Permalink
save tensors in context of memory_efficient_linear (#3413)
Browse files · Browse the repository at this point in the history
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
  • Loading branch information
tohtana and tjruwase authored May 1, 2023
1 parent b4b63f5 commit 42858a9
Showing 1 changed file with 2 additions and 15 deletions.
17 changes: 2 additions & 15 deletions deepspeed/runtime/zero/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
from deepspeed import comm as dist
from deepspeed.accelerator import get_accelerator

tensor_map = {}


def print_rank_0(message, debug=False, force=False):
if dist.get_rank() == 0 and (debug or force):
Expand All @@ -50,14 +48,7 @@ class LinearFunctionForZeroStage3(torch.autograd.Function):
# bias is an optional argument
def forward(ctx, input, weight, bias=None):

weight_id = id(weight)
bias_id = id(bias)

#ctx.save_for_backward(input, weight, bias)
ctx.save_for_backward(input, torch.tensor(weight_id), torch.tensor(bias_id))

tensor_map[weight_id] = weight
tensor_map[bias_id] = bias
ctx.save_for_backward(input, weight, bias)

if input.dim() == 2 and bias is not None:
# fused op is marginally faster
Expand All @@ -79,11 +70,7 @@ def backward(ctx, grad_output):
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
#input, weight, bias = ctx.saved_tensors

input, weight_id, bias_id = ctx.saved_tensors
weight = tensor_map[weight_id.item()]
bias = tensor_map[bias_id.item()]
input, weight, bias = ctx.saved_tensors

grad_input = grad_weight = grad_bias = None

Expand Down

0 comments on commit 42858a9

Please sign in to comment.