
Optimizer CPU offload for single GPU training #584

Merged: 31 commits merged into pytorch:main on Aug 6, 2024

Conversation

@gau-nernst (Collaborator) commented Aug 1, 2024

Background

Currently there is no simple way to do optimizer CPU offload for single-GPU training, although such a feature exists for FSDP. DeepSpeed ZeRO-Offload can work with a single GPU, but it requires installing DeepSpeed, which can be complicated, and making major changes to the training loop (it is not convenient to switch between DeepSpeed and non-DeepSpeed).

The optimizer state has the largest memory footprint in a typical training setup (2x model size for plain Adam), so offloading the optimizer to CPU is greatly beneficial.

Below is a copy of the optimizer CPU offload README.

Optimizer CPU Offload

This folder also implements optimizer CPU offload (i.e. ZeRO-Offload) for single GPU training. For multi-GPU training, you can use FSDP's built-in CPU offload.

import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = ...

# only offload optimizer state
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)

# offload optimizer state AND gradients
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, offload_gradients=True, fused=True)

This will reduce GPU memory usage by optimizer state size, and additionally gradient size if offload_gradients=True. CPUOffloadOptimizer can wrap any base optimizer.
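
For reference, a minimal end-to-end usage sketch (with a toy model, and assuming the wrapper exposes the usual step() and zero_grad() methods):

```python
import torch
import torch.nn as nn
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)

for _ in range(10):
    x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
    loss = model(x).mean()
    loss.backward()
    optim.step()       # optimizer step runs on CPU; updated params are copied back to GPU
    optim.zero_grad()  # assumed to behave like torch.optim.Optimizer.zero_grad()
```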

For saving and loading CPUOffloadOptimizer, it is important that you load the model's weights BEFORE creating the optimizer, since we create a CPU copy of the parameters inside CPUOffloadOptimizer.__init__(). (TODO: we might want a method to synchronize CUDA and CPU params in either direction, CPU->CUDA or CUDA->CPU, in case they get out of sync.)

ckpt = torch.load("checkpoint.pth")

model = ...
model.load_state_dict(ckpt["model"])

optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
optim.load_state_dict(ckpt["optim"])
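
For completeness, a matching save sketch (assuming CPUOffloadOptimizer exposes state_dict() as the counterpart of the load_state_dict() call above):

```python
ckpt = {
    "model": model.state_dict(),
    "optim": optim.state_dict(),  # assumed counterpart of load_state_dict()
}
torch.save(ckpt, "checkpoint.pth")
```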

NOTE:

  • Since the optimizer step is done on CPU, it is highly recommended to use a fast CPU optimizer, such as torch.optim.AdamW(fused=True) (requires PyTorch 2.4). For other optimizers, you can try torch.compile() their optimizer step.
  • To minimize the amount of CPU<->GPU data transfer, we keep a copy of parameters and pre-allocate gradients memory on CPU. Therefore, expect your RAM usage to increase by 2x model size + optimizer state (which is 2x model size for Adam).
  • It is recommended NOT to torch.compile() your whole model when CPUOffloadOptimizer is used, as it prevents us from interleaving the gradient device-to-host transfer with the backward pass. To minimize the impact, you can compile parts of your model separately (see the sketch after this list). See #584 for more information.
  • CPU optimizer step is often the bottleneck when optimizer CPU offload is used. To minimize the slowdown, it is recommended to (1) do full BF16 training (instead of AMP), so that parameters, gradients, and optimizer states are in BF16; and (2) give GPU more work per optimizer step (e.g. larger batch size with activation checkpointing, gradient accumulation).
  • offload_gradients=True is not compatible with gradient accumulation, since we clear gradients on GPU every backward pass.
  • Gradient clipping is currently not supported.
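
As a rough illustration of compiling parts of the model separately (a sketch with a toy model; how to split a real model is up to you):

```python
import torch
import torch.nn as nn

model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)]).cuda()

# compile each block on its own instead of torch.compile(model), so that backward
# kernels are launched in groups and grad D2H copies can slot in between them
for i in range(len(model)):
    model[i] = torch.compile(model[i])
```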

Benchmark done with timm/vit_giant_patch14_dinov2.lvd142m (1.1B params), eager mode, full BF16 training, activation checkpointing, batch size 32, on a 4070Ti SUPER (16GB VRAM), Ryzen 5600, and DDR4 RAM. DeepSpeed is untuned.

| Adam offload | Time per step | Max memory |
|---|---|---|
| None | 1.27s/it | 9.82 GB |
| DeepSpeed ZeRO-Offload | 3.13s/it | 6.85 GB |
| ao | 1.52s/it | 5.24 GB |
| ao (offload gradients) | 1.53s/it | 4.01 GB |

Ablations on AMP and torch.compile()

| Training config | Adam offload | Time per step | Max memory |
|---|---|---|---|
| Full BF16, compiled | None | 1.18s/it | 9.90 GB |
| Full BF16, compiled | ao | 1.75s/it | 5.33 GB |
| BF16 AMP, eager | None | OOM | OOM |
| BF16 AMP, eager | ao | 2.18s/it | 9.90 GB |

Implementation details

Keep a copy of params on CPU. After the backward pass, copy gradients from GPU to CPU (optionally deallocating GPU gradients). Do the optimizer step on CPU. Copy the updated parameters from CPU to GPU.

To hide CPU <-> GPU data movement, we interleave grad device-to-host copies with backward, and interleave param host-to-device copies with the CPU optim step. We also start the CPU optim step as soon as the CPU is free (i.e. after launching all backward kernels), which gives some interleaving of backward and the CPU optim step. The following trace illustrates the strategy.

[trace screenshot]
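
A simplified sketch of this strategy (not the exact PR code; param_cuda2cpu, cpu_optim, queue_d2h, and step are placeholder names) using Tensor.register_post_accumulate_grad_hook and a dedicated copy stream:

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024, device="cuda")
d2h_stream = torch.cuda.Stream()
param_cuda2cpu = {}  # CUDA param -> pinned CPU copy with pre-allocated .grad
cpu_optim = {}       # CUDA param -> per-param CPU base optimizer

for p in model.parameters():
    p_cpu = p.detach().cpu().pin_memory()
    p_cpu.grad = torch.zeros_like(p_cpu).pin_memory()
    param_cuda2cpu[p] = p_cpu
    cpu_optim[p] = torch.optim.AdamW([p_cpu], fused=True)

def queue_d2h(p_cuda):
    # runs right after this param's gradient has been fully accumulated
    d2h_stream.wait_stream(torch.cuda.current_stream())  # grad must be ready
    with torch.cuda.stream(d2h_stream):
        param_cuda2cpu[p_cuda].grad.copy_(p_cuda.grad, non_blocking=True)

for p in model.parameters():
    p.register_post_accumulate_grad_hook(queue_d2h)

def step():
    torch.cuda.synchronize()  # wait for all D2H copies before reading CPU grads
    for p_cuda, p_cpu in param_cuda2cpu.items():
        cpu_optim[p_cuda].step()                # optimizer step on CPU
        p_cuda.copy_(p_cpu, non_blocking=True)  # updated param back to GPU
```

The actual PR goes further: it starts each parameter's CPU optimizer step as soon as that parameter's D2H copy has finished (tracked with a queue of CUDA events) instead of synchronizing everything up front.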

Two interesting observations:

1. torch.compile() prevents overlapping grad D2H with backward, probably because the compiled backward launches/queues all backward kernels at once, so waiting on the current stream lasts until all backward kernels finish. Trace with torch.compile():

[trace screenshot]

One way to mitigate this is to compile parts of the model separately, so that on the host side the backward kernels are launched as K groups, and we can start grad D2H in between them. (I haven't tried it; it's just an idea and may still not be possible.) This would also reduce kernel launch overhead, which helps CPU Adam start even earlier.

2. Fused CPU Adam is much faster in BF16 than in FP32. In the trace with BF16 AMP (params, grads, and optimizer states are in FP32), the optim step time increases from ~700ms to ~1200ms.

[trace screenshot]

Time with CUDA optimizer: CUDA forward time + CUDA backward time + CUDA optim time.
Time with CPU offload optimizer: CUDA forward time + CPU time to launch backward + CPU optim time.

If the CPU optim time were negligible (assume zero), the upper bound would be CUDA forward time + CUDA backward time + param H2D time (and we can further hide the H2D time with forward). So optimizer CPU offload MAY even be faster than running the optimizer on GPU.

My setup is a Ryzen 5600 with dual-channel DDR4. CPUs with AVX-512 support (Zen 4 and later, or server CPUs) and DDR5 (or 6-channel DDR4 on servers) would probably be faster.

pytorch-bot commented Aug 1, 2024

See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/584. As of commit 231a6ef with merge base de4a1fb, there are no failures.

@awgu (Contributor) left a comment:

In this approach, it looks like you are okay with having only the optimizer states on CPU and are not targeting any parameter/gradient memory, which is fine. However, you would still see throughput improvements if you overlapped gradient D2H copies with backward and H2D copies with forward, though this leads to some complexity.

For example, with this kind of overlapping (taken from FSDP2 backward), it is possible to mostly hide the copies:
[FSDP2 backward trace screenshot]

# copy gradients from CUDA to CPU
for p_cpu, p_cuda in self.param_cpu2cuda_map.items():
    if p_cuda.grad is not None:
        p_cpu.grad = p_cuda.grad.to("cpu", non_blocking=True)
Contributor commented:

To check that we are on the same page: the non_blocking=True here means that the host (CPU) is not blocked on this D2H copy. However, there is nothing for these D2H copies to overlap with, so the main benefit you are getting here is that copying D2H with non_blocking=True copies directly into pinned memory.

Otherwise, the CPU side should look like issuing a D2H copy for each gradient and then blocking via torch.cuda.synchronize() for all D2H copies to finish.


# copy updated param from CPU to CUDA
for p_cpu, p_cuda in self.param_cpu2cuda_map.items():
    p_cuda.copy_(p_cpu, non_blocking=True)
Contributor commented:

For these H2D copies, the non_blocking=True here only means that the CPU will not be blocked. The p_cpu is already in pinned memory, so there is no further pinned memory consideration.

Calling non_blocking=True allows the CPU to proceed to the next logic, whether that is logging, data loading for the next iteration, or whatever else.

However, subsequent CUDA kernels issued in the default stream will still serialize with the H2D copies.

Comment:

I will still mention that this non_blocking is still beneficial, as it allows the CPU to enqueue all the copies and much better saturate the bandwidth, even if there is no overlap with compute.

Contributor commented:

@albanD I wanted to understand this point better.

If you call non_blocking=False, then there is a cudaDeviceSynchronize after each copy, blocking the CPU until the copy finishes. After that, the CPU will proceed to issue the next copy, so there may be some slight gaps between each H2D copy.

The part that I am not clear on is, are you suggesting that these gaps are exactly what would hurt the overall copy bandwidth, or do you mean that if you issue back-to-back H2D Memcpys, then there is some kind of batching effect across copies that improves bandwidth? (The latter would be non-intuitive to me, so I wanted to check.)

Contributor commented:

I guess for non_blocking=False, the additional cudaDeviceSynchronize is coupled with having to copy to paged memory as well, so that also is slower than copying to pinned memory.

@gau-nernst (Collaborator, Author) commented:

@awgu Thank you for your feedback.

For "overlap gradient D2H copies with backward", it probably cannot be done without intrusive change to the training code? Perhaps something like this can help with the overlapping (i.e. once we finish accumulating gradient for a param, we start moving it to CPU, while still doing backward for other params. Since optim step on CPU is blocking, we can only do optim step once all gradients are copied to CPU?).

For "H2D copies with forward", how can I do this? I have been reading this, and it says that I need to use a separate CUDA stream for data transfer to overlap with computation. So it means that I have to somehow synchronize the CPU->CUDA transfer (which will be in a separate CUDA stream) before that param is needed for forward? Perhaps some kind of forward hook? (again, intrusive changes to the training code T.T)

@awgu (Contributor) commented Aug 1, 2024

I agree once you want to overlap, it becomes quite intrusive 😢 .

The post-accumulate-grad hook could be a good point to run the D2H copy for gradients, but again like you said, without a separate CUDA stream, that copy is not going to overlap. You will mainly be moving the same kernel that would happen in optimizer.step() into the backward.

I do not have a good solution for how to do the overlap in a non-intrusive way. From what I have seen, it is hard to do this kind of overlap without some kind of nn.Module level API since that gives you good points to hook into the forward and backward.

@gau-nernst (Collaborator, Author) commented:

Added a poor man's attempt at interleaving grad D2H with backward. It seems to work (you can check the latest changes); speed improves from 1.2s/it to 1.0s/it.

Before (blue is backward kernels, red is copy kernels)
[trace screenshot]

After
[trace screenshot]

In the 2nd image, one thing that concerns me is that the backward kernels (blue) end after the copy kernels (red). Maybe a bug with the torch profiler? (I'm using torch.profiler.profile; the trace is obtained from the benchmarks/benchmark_low_bit_adam.py script.) The loss curve seems ok.

The profiling trace also reveals that CPU optimizer step is the bottleneck, which we won't be able to hide.

[trace screenshot]

So I think in this workload, I should try to feed more work to GPU (e.g. larger batch size, gradient accumulation...).

@awgu (Contributor) commented Aug 1, 2024

@gau-nernst Is there any way to share the trace file? The backward kernels ending after the D2H copies is pretty interesting 😆 .

@gau-nernst (Collaborator, Author) commented:

The file size is quite big, even after gzip (26MB). I will re-run with a smaller number of training steps (currently 20, maybe reduce to 5). Is sharing it directly here ok? Or do you prefer some other channel, possibly for security reasons? I can also share it via the CUDA-MODE Discord.

@awgu (Contributor) commented Aug 1, 2024

Oh, I did not realize that the profiler was profiling so many steps. I think it is okay to just profile 1 step. I am okay with any way for you to share it.

@gau-nernst (Collaborator, Author) commented:

Here it is

optim_cpu_offload_d2h_overlap.tar.gz

For interleaving the H2D param copy with forward, I'm thinking of using nn.Module.register_forward_pre_hook(). But even with the forward hook, it is tricky to know which params under that module should be synchronized. Maybe we only synchronize the immediate nn.Parameters (there is no direct API for this I think, but using .named_parameters() and checking for no prefix should be ok).
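
A rough sketch of that idea (untested, hypothetical names; note that named_parameters(recurse=False) yields exactly the immediate parameters):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)).cuda()
h2d_stream = torch.cuda.Stream()
param_h2d_event = {}  # hypothetical map: param -> event recorded after its H2D copy

def wait_for_immediate_params(module, args):
    # make the compute stream wait until this module's own params are up to date
    for _, p in module.named_parameters(recurse=False):
        event = param_h2d_event.pop(p, None)
        if event is not None:
            torch.cuda.current_stream().wait_event(event)

for m in model.modules():
    m.register_forward_pre_hook(wait_for_immediate_params)
```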


def copy_grad_hook(p_cuda):
    if p_cuda.grad is not None:
        p_cpu = self.param_cuda2cpu_map[p_cuda]
Contributor commented:

I think we need a self.d2h_grad_stream.wait_stream(torch.cuda.current_stream()) or else the D2H copy may not see the correct values in p_cuda.grad. This should be why your backward kernels are finishing after your D2H copies.

It is interesting that loss looks good 😃

@gau-nernst (Collaborator, Author) commented:

I'm testing with ViT fine-tuning, so I guess it's quite forgiving to bugs 🤣

@gau-nernst marked this pull request as ready for review on August 4, 2024, 08:48
@awgu (Contributor) left a comment:

SGTM!

By the way, in pre-training, people really like to do clip_grad_norm_. If we do that on CPU, I imagine it will also be super slow, so it might be better to leave gradients on GPU, clip on GPU, and then incur an exposed D2H copy to CPU. (just something to think about)

Also, I would be pretty curious to see what DeepSpeed's trace looks like to understand the perf difference.


# deallocate CUDA gradients once D2H transfer finishes.
if offload_gradients:
    p_cuda.grad.record_stream(self.stream)
Contributor commented:

Note that record_stream will have non-deterministic memory behavior (namely, when the CUDA tensor gets freed depends on when its last GPU kernel finishes, which is difficult to reason about precisely).

We really want to move away from using record_stream, but avoiding it can make your implementation more complicated.

@gau-nernst (Collaborator, Author) commented:

Just curious, what are the alternatives?

Contributor commented:

The crux of the issue is that p_cuda.grad is allocated in the default stream but has ops on it in a different stream, so there is a producer/consumer stream relationship.

In such cases, you need to make sure that the consumer stream's kernels (in this case, the D2H copy) finish before any kernels in the producer stream reuse that memory.

The idea then is to hold a reference to p_cuda.grad until the CPU has issued the ops with which you want the D2H copy to overlap with, and then you do torch.cuda.current_stream().wait_event(event) where event was recorded in self.stream right after the D2H copy and current_stream() is the default/producer stream. That way, any subsequent ops in the producer stream will run after the D2H copy has finished and can safely reuse the p_cuda.grad address.

Contributor commented:

The challenge can be that you do not know how many / which ops to overlap with, so it is not convenient to sync back (torch.cuda.current_stream().wait_event(d2h_event)).

However, for cases like FSDP, we do have a good point in time: e.g. the previous reduce-scatter must finish before the next reduce-scatter, so we wait for the previous reduce-scatter (doing this "sync back") right before the next reduce-scatter.

@gau-nernst (Collaborator, Author) commented:

torch.cuda.current_stream().wait_event(event) means that the next backward op cannot overlap with D2H?

@gau-nernst (Collaborator, Author) commented:

I think another option is to delete p_cuda.grad reference inside optim.step(), but it means we only start deallocating CUDA grad when we iterate over self.queue -> might not reduce much peak memory.

Contributor commented:

torch.cuda.current_stream().wait_event(event) means that the next backward op cannot overlap with D2H?

Yes, the next backward op after the wait_event call cannot overlap with the D2H right before the recorded event.

One way to think about it is to think about the actual CUDA address. The GPU gradient must have some address A. We need to make sure that no other op uses A until the D2H finishes. We can reserve A and make sure no other backward ops use it as long as we keep a reference to A (as a PyTorch implementation fact). At some point, we have overlapped enough ops, and we can then free A, requiring the aforementioned sync back.

Note that with record_stream, the address A will be reserved until the D2H copy finishes on GPU, at which point maybe many or even all backward ops were issued (in the most extreme case). In that case, none of the backward ops can actually reuse A. This memory reuse depends on the relative timing of CPU and GPU, which makes it difficult to reason about precisely.
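
For illustration, a small sketch of this event-based pattern (placeholder names, not the PR's code):

```python
import torch

d2h_stream = torch.cuda.Stream()
pending = []  # (grad_ref, event): holding grad_ref reserves the CUDA memory

def offload_grad(p_cuda, p_cpu):
    d2h_stream.wait_stream(torch.cuda.current_stream())  # grad produced in default stream
    with torch.cuda.stream(d2h_stream):
        p_cpu.grad.copy_(p_cuda.grad, non_blocking=True)
        event = torch.cuda.Event()
        event.record()  # recorded on d2h_stream, right after the copy
    pending.append((p_cuda.grad, event))
    p_cuda.grad = None  # autograd no longer holds the grad; `pending` keeps it alive

def release_grads():
    # call this once enough other work has been issued to overlap with the copies;
    # after the wait, the caching allocator can safely reuse the grad memory
    for grad_ref, event in pending:
        torch.cuda.current_stream().wait_event(event)
    pending.clear()
```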

Contributor commented:

I think another option is to delete p_cuda.grad reference inside optim.step(), but it means we only start deallocating CUDA grad when we iterate over self.queue -> might not reduce much peak memory.

I think calling record_stream probably dominates that approach because you will have to block CPU until the D2H finishes anyway.

@gau-nernst (Collaborator, Author) commented:

I see. Thank you for your detailed explanation!

params = param_group.pop("params")

for p_cuda in params:
    p_cpu = p_cuda.detach().cpu().pin_memory()
Contributor commented:

nit: If you want this init to be slightly faster, you can probably pre-allocate the pinned memory and copy to it directly so that you do not have the intermediate copy to CPU paged memory.
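
For reference, the suggested init could look something like this sketch:

```python
import torch

p_cuda = torch.randn(1024, 1024, device="cuda")  # stand-in for a model parameter

# allocate pinned CPU memory up front and copy into it directly, avoiding the
# intermediate pageable-CPU copy made by .cpu().pin_memory()
p_cpu = torch.empty(p_cuda.shape, dtype=p_cuda.dtype, pin_memory=True)
p_cpu.copy_(p_cuda, non_blocking=True)
torch.cuda.synchronize()  # make sure the copy finished before the CPU reads p_cpu
```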


for p_cuda, grad_d2h_event in self.queue.items():
    grad_d2h_event.synchronize()
    self.optim_dict[p_cuda].step()
Contributor commented:

ah, so fused optimizer only fuses vertically? (or is there a potential perf hit here by running per-parameter fused optimizer step?)

@gau-nernst (Collaborator, Author) commented:

There is some perf hit due to calling fused Adam on each parameter separately (my current approach) instead of on all (or some) parameters (650ms -> 750ms IIRC). I couldn't figure out a way to call fused Adam on more than one parameter because of synchronization: in __init__(), we don't know which params will have their grad D2H finish first, so we can't statically schedule and group the params.

Technically it's still possible if we use functional Adam (i.e. wait for a few items in self.queue, then call functional Adam on them), but then it would require writing optim-specific code, instead of treating base optimizer as a black box.

@gau-nernst (Collaborator, Author) commented:

I sent the DeepSpeed trace on Discord. There are two main reasons for the perf difference: (1) DeepSpeed CPU Adam is slower than PyTorch fused Adam; (2) they don't interleave the data transfers and the CPU optim step as well (might be because I didn't set the config correctly).

@gau-nernst (Collaborator, Author) commented Aug 5, 2024

Regarding gradient clipping, I thought about it too. The biggest improvement actually comes from overlapping CPU Adam with backward (i.e. starting CPU Adam as soon as the host finishes launching all backward kernels). We can still move grads D2H during backward (this helps with hiding data transfer + offloading gradients), but CPU Adam can only start once all gradients are present on CPU to do gradient clipping. Even if we do gradient clipping on GPU, CPU Adam still needs to wait for all gradients to be available (i.e. for backward to finish).

Probably good to add a note that this CPU offload optimizer doesn't support gradient clipping at the moment. We can add support for it in a future PR.

Edit: another thing I haven't considered is gradient clipping speed. It is memory-bound, so gradient clipping on CPU would probably be much slower than on GPU.

optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
optim.load_state_dict(ckpt["optim"])
```

## Credits

Credits to Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library, and [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers.
Member commented:

Credit DeepSpeed as well.

NOTE:
- Since the optimizer step is done on CPU, it is highly recommended to use a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` (requires PyTorch 2.4). For other optimizers, you can try `torch.compile()` their optimizer step.
- To minimize the amount of CPU<->GPU data transfer, we keep a copy of parameters and pre-allocate gradients memory on CPU. Therefore, expect your RAM usage to increase by 2x model size + optimizer state (which is 2x model size for Adam).
- It is recommended NOT to `torch.compile()` your whole model when `CPUOffloadOptimizer` is used, as it prevents us from interleaving gradient device-to-host transfer with backward pass. To minimize such impact, you can compile parts of your model separately.
Member commented:

It's not clear to me from the test or benchmark when this specific point is relevant.

@gau-nernst (Collaborator, Author) commented:

Are you referring to the one about torch.compile? (Sorry, not sure which line you are referring to from the GitHub UI.) If so, I can add a benchmark for this config in the README, and also its trace in the PR description (to compare with eager).

- Since the optimizer step is done on CPU, it is highly recommended to use a fast CPU optimizer, such as `torch.optim.AdamW(fused=True)` (requires PyTorch 2.4). For other optimizers, you can try `torch.compile()` their optimizer step.
- To minimize the amount of CPU<->GPU data transfer, we keep a copy of parameters and pre-allocate gradients memory on CPU. Therefore, expect your RAM usage to increase by 2x model size + optimizer state (which is 2x model size for Adam).
- It is recommended NOT to `torch.compile()` your whole model when `CPUOffloadOptimizer` is used, as it prevents us from interleaving gradient device-to-host transfer with backward pass. To minimize such impact, you can compile parts of your model separately.
- CPU optimizer step is often the bottleneck when optimizer CPU offload is used. To minimize the slowdown, it is recommended to (1) do full BF16 training (instead of AMP), so that parameters, gradients, and optimizer states are in BF16; and (2) give GPU more work per optimizer step (e.g. larger batch size with activation checkpointing, gradient accumulation).
Member commented:

Full BF16 training can be tricky FWIW; I believe we'll likely run into convergence issues at larger model sizes, but this is fine for now.

@gau-nernst (Collaborator, Author) commented:

I can add benchmarks for BF16 AMP.

@msaroufim self-requested a review on August 6, 2024, 05:38
@msaroufim (Member) left a comment:

some minor nits but this is very nice

@msaroufim merged commit 1b1e94c into pytorch:main on Aug 6, 2024 (13 checks passed).
@gau-nernst deleted the optim_cpu_offload branch on August 6, 2024, 23:57.
jainapurva pushed a commit that referenced this pull request Aug 7, 2024
* initial commit

* use fused=True by default for PyTorch adam

* detach param

* try overlap D2H grad copy with backward

* add customizable profile num steps

* add v2

* fix various bugs

* fix v1 impl

* add full BF16 option

* change n_profile_steps to 5

* add v3

* fix gradient accumulation

* add note

* add deepspeed offload

* update deepspeed config

* add some notes

* update instructions. make some packages optional. change to AdamW

* add last updated ordered dict

* update deepspeed version

* remove old versions

* update docs

* say deepspeed is untuned

* add test

* add test for offload_gradients. update benchmark script

* update benchmark results. fix test. fix benchmark script

* fix language

* add save and load

* pre-allocate CPU params. add note about gradient clipping

* update README and remove unused imports
@Theodotus1243 commented:

How do I use CPUOffloadOptimizer with an LRScheduler? The LRScheduler constructor has this check:

# Attach optimizer
if not isinstance(optimizer, Optimizer):
    raise TypeError(f'{type(optimizer).__name__} is not an Optimizer')

and CPUOffloadOptimizer does not subclass torch.optim.Optimizer:

class CPUOffloadOptimizer:
    def __init__(

@gau-nernst (Collaborator, Author) commented:

Hi @Theodotus1243, you have to manually set the LR, since PyTorch's built-in LRScheduler enforces that the optimizer is a torch.optim.Optimizer subclass, as you have discovered. Something like this:

lr = lr_schedule.get_lr(step)
for param_group in optim.param_groups:
    if isinstance(param_group["lr"], torch.Tensor):
        param_group["lr"].fill_(lr)
    else:
        param_group["lr"] = lr

The reason I don't want to make CPUOffloadOptimizer a torch.optim.Optimizer subclass is that it doesn't seem right: CPUOffloadOptimizer itself doesn't hold the params and buffers; it delegates to the base optimizer class and only holds a list of base optimizers.

Hope this clarifies the problem. I think we can add this caveat to the docs.
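
A hypothetical convenience helper (not part of torchao) wrapping the snippet above, e.g. with a cosine schedule:

```python
import math
import torch

def set_lr(optim, lr: float) -> None:
    # works with CPUOffloadOptimizer since it exposes param_groups like a regular optimizer
    for param_group in optim.param_groups:
        if isinstance(param_group["lr"], torch.Tensor):
            param_group["lr"].fill_(lr)
        else:
            param_group["lr"] = lr

def cosine_lr(step: int, total_steps: int, base_lr: float = 3e-4) -> float:
    return 0.5 * base_lr * (1 + math.cos(math.pi * step / total_steps))

# inside the training loop:
# set_lr(optim, cosine_lr(step, total_steps))
```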

@bghira commented Sep 26, 2024

I think it should not be referred to as a drop-in replacement then, @gau-nernst. And as it is, having a method that sets the LR isn't too much to ask for, I hope?
