gradient checkpointing isn't effectively triggered and fails to save memory #3455

Closed
ronghanghu opened this issue Mar 29, 2022 · 21 comments · Fixed by #3482

@ronghanghu (Collaborator) commented Mar 29, 2022

🐛 Bug

Gradient checkpointing (activation checkpointing) via torch.utils.checkpoint.checkpoint is supposed to save memory by recomputing the intermediate activations in the backward pass (rather than storing all intermediate activations of the entire computation graph) and has been working well in the vanilla PyTorch. However, I find that it fails to work in PyTorch XLA.

Specifically, gradient checkpointing via torch.utils.checkpoint.checkpoint does not result in any reduced memory consumption when running on TPUs -- the memory consumption is almost identical with and without checkpointing when measured via xm.get_memory_info. Even more puzzling, the runtime speed is also nearly identical, even though gradient checkpointing is supposed to be slower because it trades compute for memory. (In a typical case, gradient checkpointing leads to roughly 1.33X slower runtime, since the backward pass usually takes about twice as long as the forward pass and checkpointing effectively runs the forward pass twice.)

When monitoring the XLA metrics from met.metrics_report(), the DestroyDataHandles metrics are identical between the two cases with or without gradient checkpointing, but their DestroyXlaTensor metrics are notably different (see details below). Similarly, their CreateDataHandles metrics are the same but their CreateXlaTensor metrics are different. And their ExecuteTime metrics are also (nearly) identical.

And in practical use cases in PyTorch XLA on vision transformers, torch.utils.checkpoint.checkpoint also doesn't save any TPU memory for a large model or allow using a larger batch size. I suspect the XLA compiler somehow ignored the gradient checkpointing.

(A similar issue was previously reported in #1571.)

To Reproduce

  1. Allocate a v3-8 TPU VM with tpu-vm-pt-1.10 runtime environment.
  2. Save the following content to a file (e.g. ~/test_grad_checkpoint.py). It runs a dummy 64-layer network where each layer has two Conv2d + ReLU sub-modules, which should be able to benefit from gradient checkpointing in principle (since its inner conv activations no longer need to be stored when using checkpointing).
import argparse
import torch
import torch.utils.checkpoint
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def run(grad_checkpoint):
    device = xm.xla_device()
    model = torch.nn.ModuleList(
        [
            torch.nn.Sequential(
                torch.nn.Conv2d(1024, 1024, 1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(1024, 1024, 1),
                torch.nn.ReLU(),
            )
            for _ in range(64)
        ]
    ).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0)

    for step in range(200):
        dummy_data = torch.zeros(64, 1024, 14, 14, device=device)
        optimizer.zero_grad()
        x = dummy_data
        for n_l, layer in enumerate(model):
            if n_l > 0 and grad_checkpoint:
                x = torch.utils.checkpoint.checkpoint(layer, x)
            else:
                x = layer(x)
        dummy_loss = x.sum()
        dummy_loss.backward()
        optimizer.step()
        xm.mark_step()
        print(f"step {step}, free memory = {xm.get_memory_info(device)['kb_free']}")

    print(met.metrics_report())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--grad_checkpoint", type=int, required=True)
    args = parser.parse_args()
    run(args.grad_checkpoint)
  3. Run this file without or with gradient checkpointing:
# without gradient checkpointing
python3 -u ~/test_grad_checkpoint.py --grad_checkpoint 0

# with gradient checkpointing
python3 -u ~/test_grad_checkpoint.py --grad_checkpoint 1
  4. Compare their final outputs (pasted below).

Without gradient checkpointing:

step 199, free memory = 11514048
...
Metric: ExecuteTime
  TotalSamples: 199
  Accumulator: 41s287ms669.394us
  ValueRate: 762ms791.627us / second
  Rate: 3.6718 / second
  Percentiles: 1%=204ms731.038us; 5%=204ms033.989us; 10%=204ms098.950us; 20%=204ms223.762us; 50%=205ms527.677us; 80%=205ms389.712us; 90%=208ms694.179us; 95%=228ms651.128us; 99%=280ms986.356us
...
Counter: CreateDataHandles
  Value: 102741
Counter: CreateXlaTensor
  Value: 256800
Counter: DestroyDataHandles
  Value: 102485
Counter: DestroyXlaTensor
  Value: 256285
...

With gradient checkpointing:

step 199, free memory = 11514048
...
Metric: ExecuteTime
  TotalSamples: 199
  Accumulator: 41s185ms445.635us
  ValueRate: 776ms545.968us / second
  Rate: 3.74729 / second
  Percentiles: 1%=204ms971.108us; 5%=204ms047.071us; 10%=204ms119.401us; 20%=204ms223.839us; 50%=204ms481.195us; 80%=206ms626.759us; 90%=213ms479.502us; 95%=228ms529.960us; 99%=253ms493.759us
...
Counter: CreateDataHandles
  Value: 102741
Counter: CreateXlaTensor
  Value: 307200
Counter: DestroyDataHandles
  Value: 102485
Counter: DestroyXlaTensor
  Value: 306685
...

It can be seen that in both cases the free memory is always 11514048, ExecuteTime is nearly the same (50% percentile: 205ms527.677us vs 204ms481.195us), and CreateDataHandles and DestroyDataHandles are identical, but CreateXlaTensor and DestroyXlaTensor differ substantially.

Expected behavior

Gradient checkpointing (torch.utils.checkpoint.checkpoint) should be able to trade off compute for memory in PyTorch XLA, reducing memory consumption at the cost of increasing runtime.

Environment

  • Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM (tpu-vm-pt-1.10 runtime environment; and this also happens in nightly 20220308 builds)
  • torch_xla version: 1.10

Additional context

I also observed that my FSDP implementation in #3431 (which is supposed to free the full model parameters after each layer's forward pass and rebuild them before its backward pass, bearing some high-level similarity to gradient checkpointing) does not save TPU memory as I would expect. I suspect it could be the same underlying issue as what we see here.

cc: @JackCaoG @ultrons

@cicirori (Contributor) commented Mar 30, 2022

I've also recently been implementing gradient checkpointing on torch xla.

I haven't observed any memory savings when lowering torch.utils.checkpoint to XLA either.

Based on some experiments and graph analysis I have some guesses so far:

  1. The forward computation and the backward re-computation are fused or merged together. I verified this on a very simple linear model written only for testing: after adding a mark step after the forward pass, the expected memory savings were observed.
  2. However, on the BERT model, no memory reduction is observed even when mark steps are added after both the forward and the backward pass, so I guess there are some dependencies between the backward re-computations. Since these re-computations are very similar, there is no hint telling HLO not to optimize them jointly.
  3. The HLO graph traced by torch xla is one large flat computation, and I wonder if this means that buffers inside a computation cannot be freed until the whole computation finishes. From my observation of the dumped buffer assignment, the total buffer size is larger than the sum of the buffer sizes of all instructions at the maximum HLO live range. I'm not very familiar with the internals of XLA, so I'm not sure what this means.

I also looked at the JAX implementation. I'm not familiar with JAX either, and I don't even know how to compile the full graph (if someone can tell me, thanks a lot). But I can still see the difference between how JAX and torch xla use XLA. As mentioned above, torch xla's HLO graph is one flat, huge graph with a single computation (before optimization, except for particularly small subcomputations like min/max/add...).

JAX has a much clearer and more structured graph. Doesn't this computation structure naturally avoid some of the optimizations that make gradient checkpointing ineffective?

Also, I found that even so, JAX has to circumvent CSE optimizations by using a while/if condition plus a random value (the latest version replaces this with the new optimization barrier op).

@cicirori (Contributor) commented Mar 30, 2022

I also tested the effect of the HloRematerialization pass, and one of the more serious problems is that it skips custom-call operators. Again, to avoid CSE optimizations, the comments recommend running this pass as late as possible, but by that point the HLO graph already contains a large number of custom-calls (cudnn/cublas, etc.). It also skips random operators, which are very common in popular models such as transformers.

@ronghanghu (Collaborator, Author):

@cicirori Thanks for the comments!

> adding a mark step after the forward, the expected memory savings were observed.

This is consistent with what I observed in another case (I'm running on TPUs): the XLA compiler seems to fuse the operations, resulting in more memory consumption than necessary.

> torch xla's traced HLO graph is a large flat computation, and I wonder if this means buffers in one computation aren't freed until the whole computation finishes.

Yeah, I was trying to get the XLA compiler to free memory by deleting (and garbage collecting) the PyTorch torch.Tensor object on the Python side. However, this often seems unsuccessful unless I explicitly insert a mark_step afterwards.

And even when using mark_step afterwards, sometimes it still doesn't free all the memory I would expect. For example, in #3330 (comment) it ended up with only 3.5 GB of free memory left (out of the 16 GB of TPU memory), even though nothing should be taking up memory after the computation is done.

So far the memory consumption is a bit mysterious to me and I'm still learning how the XLA compiler handles it to best suit my memory-bounded programs.

> But I can still feel the difference between jax and torch xla in using xla. As mentioned above, torch xla's HLO graph is a flat, huge graph with only one computation (before optimization, except for particularly small subcomputations like min/max/add...).

I also think there is some difference in how PT-XLA and JAX use the XLA compiler. So far I'm observing that the same network uses more TPU memory in PT-XLA than in JAX. For JAX, I was told to look into the donate argument to pmap/jit and try to "donate" the model params, optimizer state, or anything else that lets XLA update in place if possible. (I wonder if there is something similar to the donate argument in PT-XLA.)
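
For reference, here is a minimal JAX sketch of that donation idea (the sgd_step function and learning rate below are illustrative placeholders, not something from this thread):

import jax

def sgd_step(params, grads, lr=1e-3):
    # Returns updated parameters with the same pytree structure as `params`.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# donate_argnums=(0,) tells XLA it may reuse the buffers of the incoming
# `params` for the updated output, so the old and new parameters don't have
# to coexist in device memory.
sgd_step = jax.jit(sgd_step, donate_argnums=(0,))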

> jax has a much clearer and more structured graph.

I feel another useful feature in JAX is its scan construct, which compiles a block once and then re-uses it. This seems quite helpful for reducing compilation time (and perhaps also for making the graph more structured). Currently, compiling a 48-layer network (ViT-G in https://arxiv.org/pdf/2106.04560.pdf) takes 30+ minutes in PT-XLA, so it would be quite helpful if one could compile a single network block and then re-use it to build the whole network.
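
As a rough illustration of that scan idea, here is a JAX sketch with made-up shapes (not code from this thread):

import jax
import jax.numpy as jnp

def layer(x, w):
    return jnp.tanh(x @ w)

def forward(x, stacked_weights):
    # `body` is traced and compiled once and re-used for every layer,
    # instead of unrolling all layers into one flat graph.
    def body(carry, w):
        return layer(carry, w), None
    out, _ = jax.lax.scan(body, x, stacked_weights)
    return out

# stacked_weights has shape (num_layers, d, d): one weight matrix per layer.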

@JackCaoG (Collaborator):

Thanks for the discussion; I should be able to work on this issue today or tomorrow. If there is a feature you think would be nice to have, please open a feature request and we will prioritize it.

@ronghanghu (Collaborator, Author):

Thanks, @JackCaoG!

> If there is a feature you think would be nice to have, please open a feature request and we will prioritize it.

I think in my case, I would like to learn more about how the XLA compiler allocates and frees tensors on TPU, and if there is a way to have fine-grained control over the compiler (such as encouraging it to optimize for memory reduction rather than speed). Do you know if there is any documentation on the XLA compiler I can look at?

@JackCaoG (Collaborator) commented Apr 1, 2022

From the issue description, I think the problem is that we fuse the forward and backward passes into one single graph. As you can imagine, this won't save any memory because there is no notion of forward and backward when everything is in one graph.

On top of that, I think the XLA compiler is smart enough to figure out that the repeated computation you perform in the backward pass is unnecessary when it is executed alongside the forward pass, so it decides to skip the repeated forward computation. My suggestion would be to add a mark_step after the forward pass to explicitly cut the graph. I need to confirm, but I think the activation tensors won't reside in HBM.

When you run the backward pass, the XLA compiler should be able to figure out that the activation tensors are not outputs of the graph and hence should not be kept in HBM after the computation. (I am less sure about this statement and need to run some experiments.)
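
A minimal sketch of that suggestion applied to the repro script above (only the lines between the forward loop and the backward call change; the wait_device_ops/print lines are there just to read memory at that point):

        # ... forward loop over `model` as in the repro script ...
        dummy_loss = x.sum()
        xm.mark_step()        # cut the graph: the forward pass is executed here
        xm.wait_device_ops()  # optional, only so the memory reading below is accurate
        print(f"free memory after forward = {xm.get_memory_info(device)['kb_free']}")
        dummy_loss.backward() # backward (with recomputation) becomes a second graph
        optimizer.step()
        xm.mark_step()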

@ronghanghu (Collaborator, Author) commented Apr 1, 2022

@JackCaoG Thanks for the comments!

> I think the problem is that we fused the forward and backward pass into one single graph. As you can imagine, this won't save any memory because there is no notion of forward and backward if it is in one graph.

I think in the case of gradient checkpointing, even if the forward pass and the backward pass are fused into a single graph, as long as the program is executed in the compute graph DAG's order and intermediate tensors are freed when they are no longer used in subsequent parts of the graph, then gradient checkpointing should still be effective (e.g. as in the toy program example below).

> On top of that I think the XLA compiler is smart enough to figure out the repeated computation you performed in the backward pass is unnecessary when it is executed alongside the forward pass.

Yeah, I think this is the culprit: the XLA compiler optimized the whole graph in a way that incurs more memory usage than the original graph before optimization, defeating the whole purpose of gradient checkpointing.


I wonder in the simple case below, how would the XLA compiler operate?

Suppose we ultimately want to compute e below. A naive (computation-efficient but memory-inefficient) graph is as follows, which would require at least 12 GB of memory to execute:

# need 4 GB to hold variable a
a = torch.sin(torch.ones(1024**3, requires_grad=False))

# need 8 GB to hold variable a, b
b = torch.sin(a)

# need 12 GB to hold variable a, b, c
c = torch.sin(b)

# need 8 GB to hold variable a, d
d = torch.sin(c) + torch.cos(b); del b, c

# need 4 GB to hold variable e
e = torch.sin(d) + torch.cos(a); del a, d

The "gradient checkpointing" version to compute e (computation-inefficient but memory-efficient) would only need 8 GB of memory to execute by trading off compute for memory, effectively "checkpointing" variables a and b:

# need 4 GB to hold variable a
a = torch.sin(torch.ones(1024**3, requires_grad=False))

# need 4 GB to hold variable b
b = torch.sin(a); del a

# need 4 GB to hold variable c
c = torch.sin(b); del b

######## recompute b via checkpointing ########
# need 8 GB to hold variable a, c
a = torch.sin(torch.ones(1024**3, requires_grad=False))

# need 8 GB to hold variable b, c
b = torch.sin(a); del a
###############################################

# need 4 GB to hold variable d
d = torch.sin(c) + torch.cos(b); del b, c

######## recompute a via checkpointing ########
# need 8 GB to hold variable a, d
a = torch.sin(torch.ones(1024**3, requires_grad=False))
###############################################

# need 4 GB to hold variable e
e = torch.sin(d) + torch.cos(a); del a, d

I wonder how the XLA compiler would treat the gradient checkpointing version above -- would it try to "de-duplicate" the computation and compile the same program as the naive version (which sadly takes more memory)?


> My suggestion would be adding a mark_step after the forward pass to explicitly cut the graph.

I feel this solution would often be undesirable in practical PyTorch XLA applications, unfortunately, as it is often not applicable to more general checkpointing cases and would result in significant mark-step overhead and perhaps also frequent recompilation (e.g. for inference, one may want to explicitly save memory within the forward pass itself by doing checkpointing-like operations at several places, as in the toy case above).

Another similar use case is FSDP for memory-bounded programs. FSDP with ZeRO-3 (#3431) can be seen as a general variant of gradient checkpointing by releasing the full parameters of a network layer after its forward pass (so that its memory can be used by the next layer) and then rebuilding the full parameters before its backward pass. So if we want to resort to explicitly cutting the graph via mark_step to recycle the memory of each layer, we would have to insert 48 mark_step calls in the forward pass alone to run a 48-layer network with ZeRO-3.

In our earlier analyses, a single mark_step call can introduce up to 100 ms of overhead (as it triggers synchronization with the XRT server, sets up tensor handles, etc.; the overhead is especially large on TPU pods). So even without recompilation, a few mark_step calls in the graph (as would be needed in general if we want to checkpoint at multiple locations in the program) would make the PyTorch XLA code prohibitively slow.

The root cause here seems to be that the XLA compiler unnecessarily removes repeated computation at the cost of holding a tensor longer in the computation graph and incurring more memory consumption. Even if we try to optimize our code in a memory-efficient way, the compiler essentially defeats our optimization and we have no control over it at this moment.

  • I wonder if there is a way we can stop the compiler from de-duplicating computations (and incurring more memory cost) in cases when it is undesired?
  • Also, is there a way to insert a compilation barrier in the graph (maybe something like xm.compiler_barrier) that doesn't trigger a mark_step but tells the compiler to separately compile each sub-graph (segmented by this barrier) and then stitch the compiled sub-graphs together, but without fusion across the barrier? I think the "scan" feature in JAX would do something similar for this purpose. Is it possible to bring it into PyTorch XLA?

@JackCaoG (Collaborator) commented Apr 1, 2022

Let me follow up with the XLA team to see what's the best solution here.

@JackCaoG (Collaborator) commented Apr 6, 2022

Let me get back to your comment #3455 (comment) later. I did a quick test by adding xm.mark_step() and xm.wait_device_ops() after the forward pass, and I observed the expected memory savings.

with grad_checkpoint 0 I saw

step 0, free memory after forward = 9199248
step 0, free memory = 8525728
step 1, free memory after forward = 8484320
step 1, free memory = 7851056
....
step 18, free memory after forward = 7851056
step 18, free memory = 7851056
step 19, free memory after forward = 7851056
step 19, free memory = 7851056
Metric: ExecuteTime
  TotalSamples: 40
  Accumulator: 05s654ms088.723us
  ValueRate: 191ms566.291us / second
  Rate: 1.63784 / second
  Percentiles: 1%=066ms746.006us; 5%=066ms860.618us; 10%=066ms955.470us; 20%=066ms394.789us; 50%=149ms759.926us; 80%=150ms024.912us; 90%=151ms890.630us; 95%=193ms887.565us; 99%=198ms158.770us
Counter: CreateDataHandles
  Value: 18216
Counter: CreateXlaTensor
  Value: 25680
Counter: DestroyDataHandles
  Value: 17550
Counter: DestroyXlaTensor
  Value: 25165

with grad_checkpoint 1 I saw

step 0, free memory after forward = 12377440
step 0, free memory = 11359456
step 1, free memory after forward = 11335136
step 1, free memory = 10910112
....
step 19, free memory after forward = 10910112
step 19, free memory = 10910112
Metric: ExecuteTime
  TotalSamples: 40
  Accumulator: 06s387ms363.482us
  ValueRate: 178ms427.421us / second
  Rate: 1.11738 / second
  Percentiles: 1%=090ms514.084us; 5%=090ms140.683us; 10%=090ms209.193us; 20%=091ms070.170us; 50%=211ms850.579us; 80%=213ms701.990us; 90%=214ms573.371us; 95%=261ms522.963us; 99%=270ms042.494us
Counter: CreateDataHandles
  Value: 16956
Counter: CreateXlaTensor
  Value: 30720
Counter: DestroyDataHandles
  Value: 16441
Counter: DestroyXlaTensor
  Value: 30205

So the issue seems to be that CSE (common-subexpression elimination) detected and removed the duplicated computation in the graph, which essentially undoes the gradient checkpointing. There is a new XLA operator, optimizationbarrier, introduced to fix this exact issue. Let me try to bring this op to pytorch/xla and see if that fixes our issue here.

@ronghanghu (Collaborator, Author):

> So the issue seems to be that CSE (common-subexpression elimination) detects and removes the duplicated computation in the graph, which essentially undoes the grad checkpointing. There is a new XLA operator, optimizationbarrier, introduced to fix this exact issue. Let me try to bring this op to pytorch/xla and see if that fixes our issue here.

@JackCaoG Thanks for the follow-up. It's great to have the new operator optimizationbarrier introduced in PyTorch/XLA. Really looking forward to it!

@JackCaoG (Collaborator) commented Apr 8, 2022

I introduced a new API, xm.optimization_barrier. The usage is pretty straightforward: you call this API on the output of the forward pass. In the example posted above, you simply need to do


        x = xm.optimization_barrier(x)

        dummy_loss = x.sum()
        dummy_loss.backward()
        optimizer.step()

This ensures that all of its inputs (in this case the result of the forward pass) are evaluated before any operators that depend on the optimization_barrier's outputs, and that XLA is forbidden from eliding or moving non-trivial computations across optimization_barrier operators.

Running a mini repro on my end, without optimization_barrier, I see

step 198, free memory after forward = 11514304
step 198, free memory = 11514304
step 199, free memory after forward = 11514304
step 199, free memory = 11514304
Metric: ExecuteTime
  TotalSamples: 200
  Accumulator: 41s329ms857.374us
  ValueRate: 590ms547.571us / second
  Rate: 2.85296 / second
  Percentiles: 1%=204ms379.962us; 5%=205ms526.820us; 10%=205ms608.515us; 20%=205ms785.190us; 50%=205ms188.236us; 80%=206ms065.475us; 90%=207ms667.218us; 95%=218ms647.765us; 99%=259ms730.178us

with optimization barrier I saw

step 198, free memory after forward = 15487456
step 198, free memory = 15487456
step 199, free memory after forward = 15487456
step 199, free memory = 15487456
Metric: ExecuteTime
  TotalSamples: 200
  Accumulator: 12s261ms783.098us
  ValueRate: 763ms711.836us / second
  Rate: 12.4415 / second
  Percentiles: 1%=061ms649.505us; 5%=061ms756.333us; 10%=061ms832.865us; 20%=061ms923.435us; 50%=061ms077.979us; 80%=061ms220.387us; 90%=061ms284.215us; 95%=061ms348.149us; 99%=081ms290.276us

For some reason I also see a much better ExecuteTime, which is really puzzling. I will clean up the PR and post it.

@cicirori I haven't had a chance to test this on GPU, and I think this op is not supported on CPU. If you are interested I can give it a try.

@ronghanghu (Collaborator, Author):

Thanks for introducing this new API. This is exactly what we need here.

> For some reason I also see a much better ExecuteTime, which is really puzzling. I will clean up the PR and post it.

Yeah, the speed is a little puzzling, but the memory consumption looks more reasonable now. I'm going to test it on more cases on my end to further verify.

@ymwangg (Contributor) commented Apr 8, 2022

@JackCaoG Adding optimization_barrier is a great feature! I've been experimenting with this op recently, and I found that adding a barrier can sometimes help save memory by reducing the overlap among buffer live ranges. This feature may also help improve the performance of large models, as we noticed XLA has superlinear computation-time growth as model size increases.

FYI, I currently get the following error when using this API on GPU, but it should be resolved automatically when torch_xla updates its TensorFlow version.

RuntimeError: INTERNAL: From /job:localservice/replica:0/task:0:
2 root error(s) found.
  (0) INTERNAL: LHLO opcode opt-barrier is not supported.
	 [[{{node XRTCompile}}]]
	 [[XRTCompile_G3]]
  (1) INTERNAL: LHLO opcode opt-barrier is not supported.
	 [[{{node XRTCompile}}]]
0 successful operations.
0 derived errors ignored.
Recent warning and error logs:
  OP_REQUIRES failed at xrt_compile_ops.cc:221 : INTERNAL: LHLO opcode opt-barrier is not supported.

@JackCaoG (Collaborator) commented Apr 8, 2022

@ymwangg Ah nice. I plan to update the TF version in ~10 days and will also test it after we update.

@ronghanghu (Collaborator, Author) commented Apr 8, 2022

I can confirm that gradient checkpointing is working well in ViT-G in https://arxiv.org/abs/2106.04560. I'm testing it in more models and also checking whether it fixes my previous FSDP ZeRO-3 cases in #3431.

Another interesting/weird observation I had is that on ViT-G xm.optimization_barrier not only fixes my gradient checkpointing (i.e. reduces memory consumption) but also increases my training speed. This is consistent with @JackCaoG's observation above and indicates that there are some problems with how the current XLA compiler works.

This API seems to break the gradient flow when used in the naive way described in #3486 (comment). For now, it needs to be used as

x.data, = xm.optimization_barrier([x])

@ronghanghu (Collaborator, Author) commented Apr 9, 2022

Update: even when preserving the PyTorch autograd via x.data, = xm.optimization_barrier([x]) as in #3486 (comment), it seems that gradient checkpointing still doesn't work with this new API.

@JackCaoG I'm thinking maybe the underlying reason is that we need to apply xm.optimization_barrier to more tensors than x alone? Since this API is also used in TF and JAX, I wonder how TF and JAX handle this case and accomplish gradient checkpointing? Anyway, I guess we should re-open this issue.


Steps to reproduce the new issue

  1. Create an updated file ~/test_grad_checkpoint_with_optimization_barrier.py with x.data, = xm.optimization_barrier([x]) between the forward and the backward pass, as follows:
import argparse
import torch
import torch.utils.checkpoint
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def run(grad_checkpoint):
    device = xm.xla_device()
    model = torch.nn.ModuleList(
        [
            torch.nn.Sequential(
                torch.nn.Conv2d(1024, 1024, 1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(1024, 1024, 1),
                torch.nn.ReLU(),
            )
            for _ in range(64)
        ]
    ).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0)

    for step in range(200):
        dummy_data = torch.zeros(64, 1024, 14, 14, device=device)
        optimizer.zero_grad()
        x = dummy_data
        for n_l, layer in enumerate(model):
            if n_l > 0 and grad_checkpoint:
                x = torch.utils.checkpoint.checkpoint(layer, x)
            else:
                x = layer(x)
        # need to assign to .data to avoid breaking autograd (see #3486)
        x.data, = xm.optimization_barrier([x])
        dummy_loss = x.sum()
        dummy_loss.backward()
        optimizer.step()
        xm.mark_step()
        print(f"step {step}, free memory = {xm.get_memory_info(device)['kb_free']}")

    print(met.metrics_report())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--grad_checkpoint", type=int, required=True)
    args = parser.parse_args()
    run(args.grad_checkpoint)
  2. Run this file without or with gradient checkpointing (using a v3-8 TPU VM with nightly 20220408 builds):
# without gradient checkpointing
python3 -u ~/test_grad_checkpoint_with_optimization_barrier.py --grad_checkpoint 0

# with gradient checkpointing
python3 -u ~/test_grad_checkpoint_with_optimization_barrier.py --grad_checkpoint 1
  3. Compare their final outputs (pasted below).

Without gradient checkpointing:

step 199, free memory = 11503264
...
Metric: ExecuteTime     
  TotalSamples: 199     
  Accumulator: 41s260ms369.778us                                                                                                                                                                                                                   
  ValueRate: 769ms432.987us / second
  Rate: 3.711 / second  
  Percentiles: 1%=204ms331.204us; 5%=204ms440.245us; 10%=204ms493.746us; 20%=205ms631.073us; 50%=205ms935.018us; 80%=206ms846.295us; 90%=215ms659.551us; 95%=221ms119.383us; 99%=273ms743.510us 
...
Counter: CreateDataHandles
  Value: 102741            
Counter: CreateXlaTensor 
  Value: 257000                                                                                                                                                                                                                                    
Counter: DestroyDataHandles
  Value: 102485    
Counter: DestroyXlaTensor       
  Value: 256485
...

With gradient checkpointing:

step 199, free memory = 11514304
...
Metric: ExecuteTime     
  TotalSamples: 199     
  Accumulator: 41s397ms317.994us                                                                                                                                                                                                                   
  ValueRate: 770ms061.803us / second
  Rate: 3.70174 / second
  Percentiles: 1%=204ms277.648us; 5%=204ms480.609us; 10%=205ms664.033us; 20%=205ms811.664us; 50%=205ms416.216us; 80%=206ms415.858us; 90%=215ms461.984us; 95%=221ms528.006us; 99%=259ms486.881us
...
​​Counter: CreateDataHandles
  Value: 102741                                                                                                                                                                                                                                    
Counter: CreateXlaTensor 
  Value: 307400                                                                                                                                                                                                                                    
Counter: DestroyDataHandles     
  Value: 102485                     
Counter: DestroyXlaTensor       
  Value: 306885
...

It can be seen that in both cases the free memory is nearly identical (11503264 vs 11514304), ExecuteTime is essentially the same (50% percentile: 205ms935.018us vs 205ms416.216us), and CreateDataHandles and DestroyDataHandles are the same, but CreateXlaTensor and DestroyXlaTensor differ substantially.

This is still the same as what we observed before the fix in #3455 (comment), suggesting that the gradient checkpointing issue isn't mitigated by optimization_barrier.

Besides, the ZeRO-3 use cases in #3431 still fail under the new optimization_barrier API. I suspect it is due to the same underlying issue as the gradient checkpointing failure here.

@cicirori (Contributor) commented Apr 10, 2022

> [Quotes @ronghanghu's previous comment in full: even with x.data, = xm.optimization_barrier([x]) between the forward and backward pass, the repro shows no memory savings from gradient checkpointing, and the ZeRO-3 use cases in #3431 still fail.]

This is consistent with the results of my previous experiments.

I have experimented with gradient checkpointing on the BERT model. CSE is only one possible cause, and the optimization barrier op (if I understand its implementation correctly) currently only prevents CSE optimization.

In addition, I suspect that the two main types of optimization passes, simplification and fusion, also prevent the memory from being saved.

Also, as I mentioned above, I'm still not entirely clear on the mechanism of buffer assignment, so there may be engineering implementation issues as well.

The above points are not solid, even though I spent a lot of time on this issue; it is very difficult for me to analyze an HLO graph with tens of thousands of instructions. Sorry about that.

@JackCaoG reopened this Apr 11, 2022
ronghanghu added a commit to ronghanghu/xla that referenced this issue Apr 12, 2022
@cicirori (Contributor):

@JackCaoG
I have tested the latest solution #3486 on a BERT-class model. I do see memory savings, which is exciting, but there is still a big gap compared with native torch.

| V100 16GB | bs=256, no gc | bs=256, with gc | bs=1536, with gc |
| --- | --- | --- | --- |
| torch xla | 5.934 GiB | 4.560 GiB | 24.83 GiB (OOM) |
| torch native | 6.561 GiB | 2.560 GiB | 13.192 GiB |

Since my tests were done on a very dirty code base (torch 1.10 + modified torch xla + 04/07/2022 tensorflow), I don't know if this will affect the final conclusions. So I'm going to build a pure master version of the image and test it again.

(By the way, as we discussed before, it would be helpful to have an official image that works for both torch CUDA and torch xla CUDA to verify this kind of problem.)

@ronghanghu (Collaborator, Author):

Hi @cicirori, thanks for the analyses! I'm also looking at applying xm.optimization_barrier_ in another use case (#3431).

I wonder if you have compared the solution in #3486 (comment) with the alternative of explicitly calling xm.mark_step between the forward and the backward pass? Does the explicit mark-step case match native torch's memory usage?

hjm-aws referenced this issue Apr 18, 2022
@JackCaoG (Collaborator):

The pending item for this PR would be to make a pytorch/xla version of the checkpointing utility, following the diff in #3486 (comment).
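
As a rough sketch of what such a wrapper could look like (only an illustration built around the in-place xm.optimization_barrier_ API; the xla_checkpoint helper below is hypothetical and not the implementation merged in #3524):

import torch
import torch.utils.checkpoint
import torch_xla.core.xla_model as xm

def xla_checkpoint(layer, x):
    # Barrier on the segment input so XLA cannot CSE the backward-time
    # recomputation against the original forward computation.
    xm.optimization_barrier_([x])
    out = torch.utils.checkpoint.checkpoint(layer, x)
    # Barrier on the segment output so downstream ops cannot be fused back
    # across the checkpoint boundary.
    xm.optimization_barrier_([out])
    return out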

@JackCaoG (Collaborator):

#3524 is merged; I will close this issue for now.

ronghanghu added a commit to ronghanghu/xla that referenced this issue Apr 28, 2022
ronghanghu added a commit to ronghanghu/xla that referenced this issue May 6, 2022
miladm pushed a commit that referenced this issue May 9, 2022
* Implement Fully Sharded Data Parallel (FSDP) in PyTorch XLA

* move the FSDP module to `torch_xla.distributed`

* adding `mark_step_on_freeing` as a temp workaround to #3455

* check in __init__ whether the module is already FSDP; fix exception types

* add `optimization_barrier_` (#3493) to avoid fusion of full parameter reconstruction with subsequent freeing

* also apply `xm.optimization_barrier_` to FSDP output's gradients

* deprecate `mark_step_on_freeing` (since we have optimization barrier now)

* add option to run a dummy forward pass in FSDP

* add `_shard_size_multiple` to make sharded parameters a multiple of 128 for efficient all-gather (see #3510 (comment))

* refactor optimization_barrier_ to separately apply to forward and backward pass `_rebuild_full_params` and `_free_full_params`

* seal off more relevant ops w/ optimization_barrier_ to avoid undesired fusion

* remove obsolete `mark_step_on_freeing` and `use_all_gather_via_all_reduce` configs; unpin layout for all_reduce; add a wrapper for gradient checkpointing on modules; remove redundant `param_names`

* handle keyword arguments in `checkpoint_module`

* add gradient checkpointing option to MNIST and ImageNet FSDP examples

* refactor `optimization_barrier` and only apply it in forward or backward when specified

* refactor command line tool to consolidate sharded checkpoints

* address reviewers' comments from GitHub

* add more user instructions for checkpoint consolidation

* change `flatten_parameters` default to False since it didn't bring an actual speed up in tests and breaks optimizer groups

* documentation refinement