gradient checkpointing isn't effectively triggered and fails to save memory #3455
Comments
I've also recently been implementing gradient checkpointing on torch xla, and I haven't observed any memory savings from lowering torch.utils.checkpoint to XLA either. Based on some experiments and graph analysis, I have some guesses so far:
I also looked at the jax implementation. I'm not very familiar with jax, and I don't even know how to compile the full graph (if someone can tell me, thanks a lot). But I can still feel the difference between jax and torch xla in how they use XLA. As mentioned above, torch xla's HLO graph is a flat, huge graph with only one computation (before optimization, except for particularly small sub-computations like min/max/add). jax has a much clearer and more structured graph. Doesn't this structure of the computation naturally avoid some of the optimizations that render gradient checkpointing ineffective? Also, I found that even then jax has to circumvent CSE optimizations by using a while/if condition + random (the latest version replaces this with the new optimization barrier op).
I also tested the effect of the HloRematerialization pass, and one of the more serious problems is that it skips custom-call operators. Again, to avoid CSE optimizations, it is recommended in the comments that it be executed as late as possible, but by that time the HLO graph already has a large number of custom-calls (cudnn/cublas etc.). It also skips random operators, which are very common in popular models such as transformers.
@cicirori Thanks for the comments!
This is consistent with what I observed in another case (I'm running on TPUs). The XLA compiler seems to fuse the operations and results in more memory consumption than necessary.
Yeah, I was trying to ask the XLA compiler to free memory by deleting (and garbage collecting) the PyTorch tensors, but even so the memory consumption is a bit mysterious to me, and I'm still learning how the XLA compiler handles it to best suit my memory-bounded programs.
I also think there is some difference in how PT-XLA and JAX use the XLA compiler. So far I'm observing that the same network uses more TPU memory in PT-XLA than in JAX. For JAX, I was told to look into the…
I feel another useful feature from JAX is its…
Thanks for the discussion, I should be able to work on this issue today or tomorrow. If there is a feature you think would be nice to have, please open a feature request and we will prioritize it.
Thanks, @JackCaoG!
I think in my case, I would like to learn more about how the XLA compiler allocates and frees tensors on TPU, and whether there is a way to have fine-grained control over the compiler (such as encouraging it to optimize for memory reduction rather than speed). Do you know if there is any documentation on the XLA compiler I can look at?
By reading the issue description, I think the problem is that we fused the forward and backward pass into one single graph. As you can imagine, this won't save any memory because there is no notion of a separate forward and backward pass within a single graph. On top of that, I think the XLA compiler is smart enough to figure out that the repeated computation you perform in the backward pass is unnecessary when it is executed alongside the forward pass, so the compiler decides to skip the repeated forward computation. My suggestion would be adding a `mark_step` between the forward and the backward pass. When you run the backward pass, the XLA compiler should be able to figure out that the activation tensor is not an output of the graph, hence it should not be stored in HBM after the computation. (I am less sure about this statement and I need to run some experiments.)
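For illustration, a minimal sketch of this suggestion, assuming a standard single-device PT-XLA training step (the helper names here are illustrative):

```python
import torch_xla.core.xla_model as xm

def train_step(model, loss_fn, optimizer, x, target):
    out = model(x)       # forward pass (model may use torch.utils.checkpoint inside)
    loss = loss_fn(out, target)
    xm.mark_step()       # cut the lazy graph here so forward and backward are not
                         # compiled (and optimized away) together as one fused graph
    loss.backward()      # backward pass is traced and executed as a separate graph
    optimizer.step()
    xm.mark_step()
    return loss
```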
@JackCaoG Thanks for the comments!
I think in the case of gradient checkpointing, even if the forward pass and the backward pass are fused into a single graph, as long as the program is executed in the compute graph DAG's order and intermediate tensors are freed when they are no longer used in subsequent parts of the graph, then gradient checkpointing should still be effective (e.g. as in the toy program example below).
Yeah, I think this is the villain. The XLA compiler optimized the whole graph in a way that incurred more memory usage than the original graph before optimization, defeating the whole purpose of gradient checkpointing. I wonder how the XLA compiler would operate in the simple case below. Suppose we ultimately want to compute…
The "gradient checkpointing" version to compute…
I wonder how the XLA compiler would treat this gradient checkpointing version -- would it try to "de-duplicate" the computation and compile the same program as the native version above (that sadly takes more memory)?
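A hypothetical sketch of what the two versions of such a toy case might look like (`f`, `g`, and the shapes are illustrative, not the original toy code):

```python
import torch
import torch.utils.checkpoint

# two stages of a toy computation; names and sizes are illustrative
f = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())
g = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU())
x = torch.randn(4096, 1024, requires_grad=True)

# "native" version: the intermediate activation f(x) stays alive until the backward pass
y_native = g(f(x)).sum()

# "gradient checkpointing" version: f(x) is not kept; it is recomputed during backward
y_ckpt = g(torch.utils.checkpoint.checkpoint(f, x)).sum()
```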
I feel this solution would often be undesirable in practical PyTorch XLA applications, unfortunately, as it is often not applicable to more general checkpointing cases and would result in large mark-step overhead and perhaps also frequent recompilation (e.g. in the case of inference, one may want to explicitly save the memory of the forward pass itself by doing some checkpointing-like operations at several places within the forward pass, like the toy case above).

Another similar use case is FSDP for memory-bounded programs. FSDP with ZeRO-3 (#3431) can be seen as a general variant of gradient checkpointing: it releases the full parameters of a network layer after its forward pass (so that its memory can be used by the next layer) and then rebuilds the full parameters before its backward pass. So if we want to resort to explicitly cutting the graph via mark_step to recycle the memory of each layer, we would have to insert 48 mark_step calls in the forward pass alone to run a 48-layer network with ZeRO-3. In our earlier analyses, a single mark_step call can introduce up to 100 ms overhead (as it triggers synchronization with the XRT server, sets up tensor handles, etc.; the overhead is especially great on TPU pods), so even in the case without recompilation, a few dozen extra mark_step calls could already add seconds of overhead to every training step.

The root cause here seems to be that the XLA compiler unnecessarily removes repeated computation at the cost of holding a tensor longer in the computation graph and incurring more memory consumption. Even if we try to optimize our code in a memory-efficient way, the compiler essentially defeats our optimization, and we have no control over it at this moment.
Let me follow up with the XLA team to see what's the best solution here.
Let me get back to your comment #3455 (comment) later. I did a quick test comparing the two cases, and the issue seems to be that CSE (common-subexpression elimination) detected and removed the duplicated computation in the graph, which essentially undoes the gradient checkpointing. There is a new XLA operator, `optimization_barrier`, that can be used to prevent this kind of optimization.
@JackCaoG Thanks for the follow-up. It's great to have the new operator `optimization_barrier` to address this.
I introduced a new api `xm.optimization_barrier_`; this ensures that all of its inputs (in this case the result of the forward pass) are evaluated before any operators that depend on its outputs. Running a mini repro on my end, without the barrier:

with the barrier:

For some reason I also see a much better ExecuteTime, which is really puzzling... I will clean up the PR and post it. @cicirori I haven't had a chance to test this on GPU; I think this op is not supported on CPU. If you are interested I can give it a try.
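For reference, a rough sketch of how the in-place API can be called (assuming it takes a list of XLA tensors, as the later FSDP notes suggest):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
layer = torch.nn.Linear(1024, 1024).to(device)
x = torch.randn(128, 1024, device=device)

y = layer(x)
# In-place barrier on a list of XLA tensors: the compiler may not fuse or CSE
# computations across this point, so a later recomputation of `y` is preserved.
xm.optimization_barrier_([y])
loss = y.sum()
xm.mark_step()
```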
Thanks for introducing this new API. This is exactly what we need here.
Yeah, the speed is a little puzzling, but the memory consumption looks more reasonable now. I'm going to test it on more cases on my end to further verify.
@JackCaoG Adding optimization_barrier is a great feature! I'm experimenting with this op recently and I found adding a barrier can sometimes help save memory by reducing the overlap among buffer live ranges. This feature may be helpful to improve the performance of large models as we noticed xla has superlinear computational time growth as model size increases. FYI currently I got the following error when using this API on GPU, but it will be automatically solved when torch_xla updates its tensorflow version.
@ymwangg Ah nice. I plan to update the tf version in ~10 days and will test this again after the update.
This API seems to break the gradient flow when used in a naive way, as described in #3486 (comment). For now, it needs to be used in a way that preserves the PyTorch autograd graph (see the sketch below).
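A hypothetical sketch of one way to preserve gradient flow, by wrapping the barrier in a custom autograd function and also applying it to the incoming gradients (this is an illustration, not the exact diff from #3486):

```python
import torch
import torch_xla.core.xla_model as xm

class _Barrier(torch.autograd.Function):
    """Apply an XLA optimization barrier in both the forward and the backward pass."""

    @staticmethod
    def forward(ctx, *tensors):
        outs = [t.clone() for t in tensors]   # clone so the in-place barrier doesn't
        xm.optimization_barrier_(outs)        # touch autograd-tracked inputs
        return tuple(outs)

    @staticmethod
    def backward(ctx, *grads):
        grads = [g.clone() if g is not None else None for g in grads]
        xm.optimization_barrier_([g for g in grads if g is not None])  # barrier the grads too
        return tuple(grads)

def apply_barrier(*tensors):
    # usage: y, = apply_barrier(y)
    return _Barrier.apply(*tensors)
```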
Update: even when preserving the PyTorch autograd as above, the memory saving still doesn't materialize. @JackCaoG I'm thinking maybe the underlying reason is that we need to apply the barrier to the backward pass (the gradients) as well. Steps to reproduce the new issue:
Without gradient checkpointing:
With gradient checkpointing:
It can be seen that in both cases, the free memory is always 11514304. This is still the same as what we observed before the fix in #3455 (comment), suggesting that the gradient checkpointing issue isn't mitigated by the optimization barrier in this setup. Besides, the ZeRO-3 use cases in #3431 still fail under the new API.
This is consistent with the results of my previous experiments. I have experimented with gradient checkpointing on the BERT model; CSE is only one possible cause, and the optimization barrier op (if I understand its implementation correctly) currently only prevents CSE optimization. In addition, I suspect that the two main types of optimization passes, simplification and fusion, also cause the memory not to be saved. Also, as I mentioned above, I'm still not entirely clear on the mechanism of buffer assignment, so there may be engineering implementation issues as well. The above points are not solid, even though I spent a lot of time on this issue; it is very difficult for me to analyze an HLO graph with tens of thousands of instructions. Sorry for that.
@JackCaoG
Since my tests were done on a very dirty code base (torch 1.10 + modified torch xla + 04/07/2022 tensorflow), I don't know if this will affect the final conclusions, so I'm going to build a pure master version of the image and test it again. (By the way, as we discussed before, it would be helpful to have an official image that works for both torch cuda and torch xla cuda to verify this kind of problem.)
Hi @cicirori, thanks for the analyses! I'm also looking at applying the optimization barrier on my end. I wonder if you have compared the solution in #3486 (comment) with the alternative of explicitly calling a `mark_step`?
A pending item for this PR would be to make a PyTorch/XLA version of the checkpointing following the diff in #3486 (comment).
#3524 is merged, I will close this pr for now. |
* Implement Fully Sharded Data Parallel (FSDP) in PyTorch XLA
* move the FSDP module to `torch_xla.distributed`
* adding `mark_step_on_freeing` as a temp workaround to #3455
* check in `__init__` whether the module is already FSDP; fix exception types
* add `optimization_barrier_` (#3493) to avoid fusion of full parameter reconstruction with subsequent freeing
* also apply `xm.optimization_barrier_` to FSDP output's gradients
* deprecate `mark_step_on_freeing` (since we have optimization barrier now)
* add option to run a dummy forward pass in FSDP
* add `_shard_size_multiple` to make sharded parameters a multiple of 128 for efficient all-gather (see #3510 (comment))
* refactor `optimization_barrier_` to separately apply to forward and backward pass `_rebuild_full_params` and `_free_full_params`
* seal off more relevant ops w/ `optimization_barrier_` to avoid undesired fusion
* remove obsolete `mark_step_on_freeing` and `use_all_gather_via_all_reduce` configs; unpin layout for all_reduce; add a wrapper for gradient checkpointing on modules; remove redundant `param_names`
* handle keyword arguments in `checkpoint_module`
* add gradient checkpointing option to MNIST and ImageNet FSDP examples
* refactor `optimization_barrier` and only apply it in forward or backward when specified
* refactor command line tool to consolidate sharded checkpoints
* address reviewers' comments from GitHub
* add more user instructions for checkpoint consolidation
* change `flatten_parameters` default to False since it didn't bring an actual speed up in tests and breaks optimizer groups
* documentation refinement
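For context, a rough usage sketch combining the FSDP wrapper with the gradient-checkpointing helper listed above (the import path and class name are taken from the PR description and may differ across versions):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module

device = xm.xla_device()
# wrap a block with gradient checkpointing, then shard it ZeRO-3 style with FSDP
block = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU()).to(device)
model = FSDP(checkpoint_module(block))
```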
🐛 Bug
Gradient checkpointing (activation checkpointing) via `torch.utils.checkpoint.checkpoint` is supposed to save memory by recomputing the intermediate activations in the backward pass (rather than storing all intermediate activations of the entire computation graph) and has been working well in vanilla PyTorch. However, I find that it fails to work in PyTorch XLA. Specifically, gradient checkpointing via `torch.utils.checkpoint.checkpoint` does not result in any reduced memory consumption when running on TPUs -- the memory consumption is almost identical with and without checkpointing when measured via `xm.get_memory_info`. And the weirder thing is that their runtime speed is also nearly identical, although gradient checkpointing is supposed to make the speed slower by trading compute for memory. (In a typical case, gradient checkpointing would lead to 1.33X slower runtime, as the backward pass usually takes double the time of the forward pass, and gradient checkpointing results in calling the forward pass twice.)

When monitoring the XLA metrics from `met.metrics_report()`, the `DestroyDataHandles` metrics are identical between the two cases with or without gradient checkpointing, but their `DestroyXlaTensor` metrics are notably different (see details below). Similarly, their `CreateDataHandles` metrics are the same but their `CreateXlaTensor` metrics are different. And their `ExecuteTime` metrics are also (nearly) identical.

And in practical use cases in PyTorch XLA on vision transformers, `torch.utils.checkpoint.checkpoint` also doesn't save any TPU memory for a large model or allow using a larger batch size. I suspect the XLA compiler somehow ignored the gradient checkpointing. (A similar issue was previously reported in #1571.)
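For reference, the memory and metric numbers quoted in this report can be obtained roughly like this (a minimal sketch):

```python
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
# ... run some training steps with / without gradient checkpointing ...
print(xm.get_memory_info(device))   # e.g. {'kb_free': ..., 'kb_total': ...}
print(met.metrics_report())         # includes CreateXlaTensor, DestroyXlaTensor, ExecuteTime, ...
```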
To Reproduce
Allocate a TPU VM with the `tpu-vm-pt-1.10` runtime environment, then save and run the test script as `~/test_grad_checkpoint.py`. It runs a dummy 64-layer network where each layer has two Conv2d + ReLU sub-modules, which should be able to benefit from gradient checkpointing in principle (since its inner conv activations no longer need to be stored when using checkpointing).
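A minimal sketch of what such a test script could look like (channel sizes, batch size, and the training loop here are illustrative, not the original script):

```python
# test_grad_checkpoint.py -- illustrative sketch
import torch
import torch.utils.checkpoint
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

USE_CHECKPOINT = True  # flip to compare the two cases


class Block(torch.nn.Module):
    def __init__(self, ch):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(ch, ch, 3, padding=1), torch.nn.ReLU(),
            torch.nn.Conv2d(ch, ch, 3, padding=1), torch.nn.ReLU())

    def forward(self, x):
        if USE_CHECKPOINT:
            # recompute the inner conv activations in the backward pass
            return torch.utils.checkpoint.checkpoint(self.net, x)
        return self.net(x)


device = xm.xla_device()
model = torch.nn.Sequential(*[Block(64) for _ in range(64)]).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# requires_grad on the input so the checkpointed first block still gets gradients
x = torch.randn(16, 64, 56, 56, device=device, requires_grad=True)

for step in range(10):
    optimizer.zero_grad()
    loss = model(x).mean()
    loss.backward()
    optimizer.step()
    xm.mark_step()
    print(step, xm.get_memory_info(device))

print(met.metrics_report())
```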
Without gradient checkpointing:
With gradient checkpointing:
It can be seen that in both cases, the free memory is always 11514048, `ExecuteTime` is (nearly) the same (50% Percentile: 205ms527.677us vs 204ms481.195us), and `CreateDataHandles` and `DestroyDataHandles` are the same. But `CreateXlaTensor` and `DestroyXlaTensor` are largely different.

Expected behavior
Gradient checkpointing (`torch.utils.checkpoint.checkpoint`) should be able to trade off compute for memory in PyTorch XLA, reducing memory consumption at the cost of increasing runtime.

Environment
TPU VM with the `tpu-vm-pt-1.10` runtime environment (this also happens in the nightly 20220308 builds).

Additional context
I also observed that my FSDP implementation in #3431 (which is supposed to free full model parameters after each layer's forward pass and rebuild them before its backward pass, bearing some high-level similarity to gradient checkpointing) does not save TPU memory, unlike what I would expect. I suspect it could be the same underlying issue as what we saw here.
cc: @JackCaoG @ultrons