Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizer CPU offload for single GPU training #584

Merged
merged 31 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3cd42d2
initial commit
gau-nernst Aug 1, 2024
c044d88
use fused=True by default for PyTorch adam
gau-nernst Aug 1, 2024
d85b172
detach param
gau-nernst Aug 1, 2024
d468e6f
try overlap D2H grad copy with backward
gau-nernst Aug 1, 2024
d7a07eb
add customizable profile num steps
gau-nernst Aug 1, 2024
fe653e9
add v2
gau-nernst Aug 2, 2024
8ae42c3
fix various bugs
gau-nernst Aug 2, 2024
b2c00e5
fix v1 impl
gau-nernst Aug 2, 2024
68835e3
add full BF16 option
gau-nernst Aug 2, 2024
b5393cb
change n_profile_steps to 5
gau-nernst Aug 2, 2024
3069b23
add v3
gau-nernst Aug 3, 2024
7af8518
fix gradient accumulation
gau-nernst Aug 3, 2024
5ff2e5a
add note
gau-nernst Aug 3, 2024
a8a7b5a
add deepspeed offload
gau-nernst Aug 3, 2024
40aea0c
update deepspeed config
gau-nernst Aug 3, 2024
dff2f9c
add some notes
gau-nernst Aug 3, 2024
bd8db68
update instructions. make some packages optional. change to AdamW
gau-nernst Aug 3, 2024
c3883ce
add last updated ordered dict
gau-nernst Aug 3, 2024
0e9235c
update deepspeed version
gau-nernst Aug 3, 2024
b6e4c6a
remove old versions
gau-nernst Aug 4, 2024
c514dba
update docs
gau-nernst Aug 4, 2024
cfdfe5d
say deepspeed is untuned
gau-nernst Aug 4, 2024
c4ea68b
add test
gau-nernst Aug 4, 2024
6478be9
add test for offload_gradients. update benchmark script
gau-nernst Aug 4, 2024
03cf0ad
update benchmark results. fix test. fix benchmark script
gau-nernst Aug 4, 2024
a144b22
fix language
gau-nernst Aug 4, 2024
d344817
add save and load
gau-nernst Aug 4, 2024
fc358b1
Merge branch 'pytorch:main' into optim_cpu_offload
gau-nernst Aug 4, 2024
5a5253e
Merge branch 'main' into optim_cpu_offload
gau-nernst Aug 5, 2024
7aa31eb
pre-allocate CPU params. add note about gradient clipping
gau-nernst Aug 5, 2024
231a6ef
update README and remove unused imports
gau-nernst Aug 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion benchmarks/benchmark_low_bit_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

# lpmm doesn't have Adam, only AdamW
OPTIM_MAP = dict(
Adam=torch.optim.Adam,
Adam=partial(torch.optim.Adam, fused=True),
msaroufim marked this conversation as resolved.
Show resolved Hide resolved
Adam8bitBnb=bnb.optim.Adam8bit,
Adam8bitAo=low_bit_optim.Adam8bit,
AdamFp8Ao=low_bit_optim.AdamFp8,
Expand Down Expand Up @@ -90,6 +90,7 @@ def get_parser():
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=0)
parser.add_argument("--cosine_lr_scheduler", action="store_true")
parser.add_argument("--optim_cpu_offload", action="store_true")

parser.add_argument("--project")
parser.add_argument("--run_name", default="debug")
Expand Down Expand Up @@ -177,6 +178,8 @@ def evaluate_model(model, args):
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
if args.optim_cpu_offload:
optim = low_bit_optim.CPUOffloadOptimizer(optim)
lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)

grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")
Expand Down
1 change: 1 addition & 0 deletions torchao/prototype/low_bit_optim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .adam import Adam8bit, Adam4bit, AdamFp8
from .adamw import AdamW8bit, AdamW4bit, AdamWFp8
from .cpu_offload import CPUOffloadOptimizer
34 changes: 34 additions & 0 deletions torchao/prototype/low_bit_optim/cpu_offload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch
from torch.optim.optimizer import Optimizer


class CPUOffloadOptimizer:
def __init__(self, base_optimizer: Optimizer) -> None:
self.optim = base_optimizer
self.param_cpu2cuda_map = dict()

msaroufim marked this conversation as resolved.
Show resolved Hide resolved
# swap param in param_groups with CPU param
for param_group in base_optimizer.param_groups:
for i, p in enumerate(param_group["params"]):
p_cpu = p.detach().cpu().pin_memory()
param_group["params"][i] = p_cpu
self.param_cpu2cuda_map[p_cpu] = p

@torch.no_grad()
def step(self, closure=None):
# copy gradients from CUDA to CPU
for p_cpu, p_cuda in self.param_cpu2cuda_map.items():
if p_cuda.grad is not None:
p_cpu.grad = p_cuda.grad.to("cpu", non_blocking=True)
msaroufim marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To check that we on the same page, the non_blocking=True here means that the host (CPU) is not blocked on this D2H copy. However, there is nothing for these D2H copies to overlap with, so the main benefit you are getting here is that copying D2H with non_blocking=True will copy directly to pinned memory.

Otherwise, the CPU side should look like issuing D2H copy for each gradient and then blocking via the torch.cuda.synchronize() for all D2H copies to finish.

p_cuda.grad = None
torch.cuda.synchronize()

self.optim.step(closure)

# copy updated param from CPU to CUDA
for p_cpu, p_cuda in self.param_cpu2cuda_map.items():
p_cuda.copy_(p_cpu, non_blocking=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For these H2D copies, the non_blocking=True here only means that the CPU will not be blocked. The p_cpu is already in pinned memory, so there is no further pinned memory consideration.

Calling non_blocking=True allows the CPU to proceed into the next logic whether that is logging, the next iteration data loading, etc. or whatever.

However, subsequent CUDA kernels issued in the default stream will still serialize with the H2D copies.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will still mention that this non_blocking is still benefiicial as it allows the cpu to enqueue all the copies and much better saturate the bw even if there is no overlap with compute.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@albanD I wanted to understand this point better.

If you call non_blocking=False, then there is a cudaDeviceSynchronize after each copy, blocking the CPU until the copy finishes. After that, the CPU will proceed to issue the next copy, so there may be some slight gaps between each H2D copy.

The part that I am not clear on is, are you suggesting that these gaps are exactly what would hurt the overall copy bandwidth, or do you mean that if you issue back-to-back H2D Memcpys, then there is some kind of batching effect across copies that improves bandwidth? (The latter would be non-intuitive to me, so I wanted to check.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess for non_blocking=False, the additional cudaDeviceSynchronize is coupled with having to copy to paged memory as well, so that also is slower than copying to pinned memory.


msaroufim marked this conversation as resolved.
Show resolved Hide resolved
# redirect calls to base optimizer
def __getattr__(self, name: str):
return getattr(self.optim, name)
Loading