From b7a4acc1c8599f9306b519c9a88c044f1b280a07 Mon Sep 17 00:00:00 2001
From: Wil Kong
Date: Fri, 30 Aug 2024 19:22:47 +0800
Subject: [PATCH] Add Unittest For Distributed Adam With CUDA Graph (#1836)

* Add unittest for distributed adam with cuda graph.

* Fix the distributed adam issue if user passes float LR.

* skip if world_size < 8

---------

Co-authored-by: Masaki Kozuki
---
 .../optimizers/distributed_fused_adam.py      |   6 +-
 .../contrib/test/optimizers/test_dist_adam.py | 175 ++++++++++++------
 2 files changed, 127 insertions(+), 54 deletions(-)

diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py
index 364294aae..ae3590d89 100644
--- a/apex/contrib/optimizers/distributed_fused_adam.py
+++ b/apex/contrib/optimizers/distributed_fused_adam.py
@@ -819,7 +819,11 @@ def __init__(
                 if len(group['params']) == 0:
                     continue
                 for item in ['lr']:
-                    self.param_groups[idx][item] = group[item].to(device=self.device)
+                    if torch.is_tensor(group[item]):
+                        self.param_groups[idx][item] = group[item].to(device=self.device)
+                    else:
+                        self.param_groups[idx][item] = torch.tensor(group[item],
+                                                                    device=self.device)
 
         # For better representation string
         arg_names = inspect.getfullargspec(DistributedFusedAdam.__init__).args
diff --git a/apex/contrib/test/optimizers/test_dist_adam.py b/apex/contrib/test/optimizers/test_dist_adam.py
index cf21a19b7..12ff4c021 100644
--- a/apex/contrib/test/optimizers/test_dist_adam.py
+++ b/apex/contrib/test/optimizers/test_dist_adam.py
@@ -3,6 +3,7 @@
 from typing import Callable, Optional, Tuple
 import unittest
 import warnings
+from contextlib import nullcontext
 
 import torch
 from torch.testing._internal import common_utils
@@ -49,6 +50,7 @@ def make_models(
     store_param_remainders: bool = False,
     with_scaled_states: bool = False,
     nccl_ub: bool = False,
+    with_cuda_graph: bool = False,
 ):
 
     # Construct models with same parameters
@@ -100,6 +102,7 @@ def make_models(
         store_param_remainders=store_param_remainders,
         with_scaled_states=with_scaled_states,
         nccl_ub=nccl_ub,
+        capturable=with_cuda_graph,
         **optim_args,
     )
 
@@ -143,78 +146,130 @@ def test_matches_pytorch(
         with_scaled_states: bool = False,
         nccl_ub: bool = False,
         init_optim_func: Optional[Callable[[DistributedFusedAdam], None]] = None,
+        with_cuda_graph: bool = False,
     ):
 
         torch.manual_seed(self.seed + self.rank)
 
         # Identical models with data-parallel and ZeRO
-        ref_model, ref_optim, dist_model, dist_optim = make_models(
-            num_layers,
-            layer_size,
-            adam_w_mode=adam_w_mode,
-            model_dtype=model_dtype,
-            optim_dtype=optim_dtype,
-            grad_sync_dtype=grad_sync_dtype,
-            param_sync_dtype=param_sync_dtype,
-            device=device,
-            overlap_communication=overlap_communication,
-            bucket_cap_mb=bucket_cap_mb,
-            contiguous_buffers=contiguous_buffers,
-            store_params=store_params,
-            store_param_remainders=store_param_remainders,
-            with_scaled_states=with_scaled_states,
-            nccl_ub=nccl_ub,
-        )
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            ref_model, ref_optim, dist_model, dist_optim = make_models(
+                num_layers,
+                layer_size,
+                adam_w_mode=adam_w_mode,
+                model_dtype=model_dtype,
+                optim_dtype=optim_dtype,
+                grad_sync_dtype=grad_sync_dtype,
+                param_sync_dtype=param_sync_dtype,
+                device=device,
+                overlap_communication=overlap_communication,
+                bucket_cap_mb=bucket_cap_mb,
+                contiguous_buffers=contiguous_buffers,
+                store_params=store_params,
+                store_param_remainders=store_param_remainders,
+                with_scaled_states=with_scaled_states,
+                nccl_ub=nccl_ub,
+                with_cuda_graph=with_cuda_graph,
+            )
 
         # Initialize distributed optimizer
         if init_optim_func is not None:
-            init_optim_func(dist_optim)
+            with torch.cuda.stream(stream):
+                init_optim_func(dist_optim)
 
-        # Training loop
-        for step in range(num_steps):
+        # Static data
+        static_xs, static_dys = [], []
+        ys_ref, grad_xs_ref = [], []
+        ys_dist, grad_xs_dist = [], []
 
-            # Reset gradients
-            ref_optim.zero_grad()
-            dist_optim.zero_grad()
-
-            # Forward and backward passes
-            for micro_step in range(micro_batch_steps):
+        graph = torch.cuda.CUDAGraph() if with_cuda_graph else None
+        CAPTURE_ITERATION = 11
+        if with_cuda_graph:
+            assert num_steps > CAPTURE_ITERATION + 3, \
+                "Not enough iterations for CUDA graph test."
+        # Training loop
+        with torch.cuda.stream(stream):
+            for step in range(num_steps):
 
                 # Synthetic data
-                x = torch.rand(batch_size, layer_size) - 0.5
-                dy = torch.rand_like(x) - 0.5
-                x = x.to(dtype=model_dtype, device=device)
-                dy = dy.to(dtype=model_dtype, device=device)
+                for micro_step in range(micro_batch_steps):
+                    x = torch.rand(batch_size, layer_size) - 0.5
+                    dy = torch.rand_like(x) - 0.5
+                    x = x.to(dtype=model_dtype, device=device)
+                    dy = dy.to(dtype=model_dtype, device=device)
+                    if step == 0:
+                        static_xs.append(x)
+                        static_dys.append(dy)
+                    else:
+                        static_xs[micro_step].copy_(x)
+                        static_dys[micro_step].copy_(dy)
 
                 # Reference implementation
-                x_ref = x.detach().clone().requires_grad_(True)
-                y_ref = ref_model(x_ref)
-                y_ref.backward(dy)
+                ref_optim.zero_grad()
+                for micro_step in range(micro_batch_steps):
+                    x, dy = static_xs[micro_step], static_dys[micro_step]
+
+                    x_ref = x.detach().clone().requires_grad_(True)
+                    y_ref = ref_model(x_ref)
+                    y_ref.backward(dy)
+
+                    if step == 0:
+                        ys_ref.append(y_ref)
+                        grad_xs_ref.append(x_ref.grad)
+                    else:
+                        with torch.no_grad():
+                            ys_ref[micro_step].copy_(y_ref)
+                            grad_xs_ref[micro_step].copy_(x_ref.grad)
+                ref_optim.step()
 
                 # Distributed implementation
-                x_dist = x.detach().clone().requires_grad_(True)
-                y_dist = dist_model(x_dist)
-                backward_context = dummy_context
-                if use_nosync and micro_step < micro_batch_steps-1:
-                    backward_context = dist_optim.no_sync
-                with backward_context():
-                    y_dist.backward(dy)
+                if not with_cuda_graph or step <= CAPTURE_ITERATION:
+                    if with_cuda_graph and step == CAPTURE_ITERATION:
+                        ctx = torch.cuda.graph(graph)
+                        torch.cuda.synchronize()
+                    else:
+                        ctx = nullcontext()
+
+                    with ctx:
+                        dist_optim.zero_grad()
+                        for micro_step in range(micro_batch_steps):
+                            x, dy = static_xs[micro_step], static_dys[micro_step]
+
+                            x_dist = x.detach().clone().requires_grad_(True)
+                            y_dist = dist_model(x_dist)
+                            backward_context = dummy_context
+                            if use_nosync and micro_step < micro_batch_steps-1:
+                                backward_context = dist_optim.no_sync
+                            with backward_context():
+                                y_dist.backward(dy)
+
+                            if step == 0:
+                                ys_dist.append(y_dist)
+                                grad_xs_dist.append(x_dist.grad)
+                            else:
+                                with torch.no_grad():
+                                    ys_dist[micro_step].copy_(y_dist)
+                                    grad_xs_dist[micro_step].copy_(x_dist.grad)
+                        dist_optim.step()
+
+                    if with_cuda_graph and step == CAPTURE_ITERATION:
+                        graph.replay()
+                else:
+                    graph.replay()
 
                 # Check that data tensors match
-                torch.testing.assert_close(
-                    y_dist, y_ref, rtol=rtol, atol=atol)
-                torch.testing.assert_close(
-                    x_dist.grad, x_ref.grad, rtol=rtol, atol=atol)
-
-            # Optimization step
-            ref_optim.step()
-            dist_optim.step()
+                for mbs in range(micro_batch_steps):
+                    torch.testing.assert_close(
+                        ys_dist[mbs], ys_ref[mbs], rtol=rtol, atol=atol)
+                    torch.testing.assert_close(
+                        grad_xs_dist[mbs], grad_xs_ref[mbs], rtol=rtol, atol=atol)
 
-            # Check that parameters match
-            for ref_param, dist_param in zip(ref_model.parameters(),
-                                             dist_model.parameters()):
-                torch.testing.assert_close(
-                    dist_param, ref_param, rtol=rtol, atol=atol)
+                # Check that parameters match
+                for ref_param, dist_param in zip(ref_model.parameters(),
+                                                 dist_model.parameters()):
+                    torch.testing.assert_close(
+                        dist_param, ref_param, rtol=rtol, atol=atol)
 
     def test_matches_pytorch_l2_reg(self):
         self.test_matches_pytorch(adam_w_mode=False)
@@ -797,6 +852,20 @@ def test_bucket_low_utilization_warning(self):
         for w in warns:
             self.assertNotRegex(str(w.message), ".*Consider decreasing the bucket_cap_mb argument.")
 
+    def test_cuda_graph(self):
+        """Test distributed adam with CUDA graph"""
+        if self.world_size < 8:
+            self.skipTest(f"{self.world_size=} is expected to be >= 8")
+        self.test_matches_pytorch(
+            rtol=5e-3,
+            atol=1e-5,
+            num_steps=15,
+            micro_batch_steps=1,
+            model_dtype=torch.float16,
+            optim_dtype=torch.float16,
+            contiguous_buffers=True,
+            with_cuda_graph=True,
+        )
 
 if __name__ == "__main__":
     # Assume script has been run with torchrun
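For reference, the warmup/capture/replay pattern that the new test_cuda_graph case exercises can be sketched in isolation as below. This is an illustrative sketch only and is not part of the patch: the plain torch.optim.Adam(capturable=True), the single Linear layer, the tensor shapes, and the warmup count are stand-ins for DistributedFusedAdam and the test's model.

# Minimal warmup/capture/replay sketch (illustrative; not part of the patch).
import torch

device = torch.device("cuda")
model = torch.nn.Linear(32, 32, device=device)
optim = torch.optim.Adam(model.parameters(), lr=1e-3, capturable=True)

# Static buffers that the captured graph will keep reading from.
static_x = torch.randn(8, 32, device=device)
static_dy = torch.randn(8, 32, device=device)

# Warm up on a side stream so captured work has no dependency on the default stream.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        optim.zero_grad(set_to_none=False)  # keep .grad as static buffers
        model(static_x).backward(static_dy)
        optim.step()
torch.cuda.current_stream().wait_stream(side_stream)

# Capture one full training iteration (zero_grad, forward, backward, step) into a CUDA graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    optim.zero_grad(set_to_none=False)
    model(static_x).backward(static_dy)
    optim.step()

# Subsequent iterations: refill the static inputs and replay the captured work.
for _ in range(10):
    static_x.copy_(torch.rand_like(static_x) - 0.5)
    static_dy.copy_(torch.rand_like(static_dy) - 0.5)
    graph.replay()
torch.cuda.synchronize()

The LR handling change in distributed_fused_adam.py serves the same pattern: a capturable optimizer needs its learning rate in a device tensor, so a plain Python float from the user is promoted with torch.tensor(..., device=self.device) instead of failing on the .to() call that only tensors support.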