[plugin] support get_grad_norm (#6115)
ver217 authored Nov 5, 2024
1 parent 13ffa08 commit a15ab13
Showing 8 changed files with 40 additions and 2 deletions.
7 changes: 6 additions & 1 deletion colossalai/amp/naive_amp/mixed_precision_optimizer.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor, inf
@@ -84,6 +84,7 @@ def __init__(
self.master_to_working_map[master_p] = p
master_params.append(master_p)
group["params"] = master_params
self._current_grad_norm: Optional[float] = None

def backward(self, loss: Tensor, *args, **kwargs):
loss = self.mixed_precision.pre_backward(loss)
@@ -187,6 +188,7 @@ def step(self, *args, **kwargs):
if p.grad is not None
]
total_norm = self._compute_grad_norm(param_gradient_pairs)
self._current_grad_norm = total_norm
self._unscale_and_clip_grads(total_norm)

self.optim.step(*args, **kwargs)
@@ -212,3 +214,6 @@ def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:

def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}

def get_grad_norm(self, norm_type=2, **kwargs):
return self._current_grad_norm
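
Each optimizer wrapper touched by this commit follows the same pattern as the file above: compute the total gradient norm inside step(), cache it on _current_grad_norm, and have get_grad_norm() return the cached value. A minimal sketch of that pattern, assuming a plain torch.optim.Optimizer underneath (the class and its simplified norm computation are illustrative stand-ins, not ColossalAI code):

from typing import Optional

import torch


class NormCachingOptimizerSketch:
    """Illustrative only: mirrors the caching pattern added across the optimizers in this commit."""

    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim
        self._current_grad_norm: Optional[float] = None  # stays None until step() has run

    def _compute_grad_norm(self) -> float:
        # Simplified stand-in for the per-plugin norm computation
        # (_compute_grad_norm / _calc_global_norm in the files of this diff).
        grads = [p.grad for group in self.optim.param_groups for p in group["params"] if p.grad is not None]
        if len(grads) == 0:
            return 0.0
        return torch.norm(torch.stack([g.detach().norm(2) for g in grads]), 2).item()

    def step(self, *args, **kwargs):
        self._current_grad_norm = self._compute_grad_norm()  # cache before the parameter update
        self.optim.step(*args, **kwargs)

    def get_grad_norm(self, norm_type=2, **kwargs):
        # Return the norm cached by the most recent step(); None if step() has not run yet.
        return self._current_grad_norm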
5 changes: 5 additions & 0 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -293,6 +293,7 @@ def __init__(
self.pp_pg = pp_process_group
self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
self._current_grad_norm: Optional[float] = None
super().__init__(optim)

def backward(self, loss: Tensor, *args, **kwargs):
@@ -364,6 +365,7 @@ def step(self, *args, **kwargs):
(p, p.grad) for group in self.optim.param_groups for p in group["params"] if p.grad is not None
]
total_norm = self._compute_grad_norm(param_gradient_pairs)
self._current_grad_norm = total_norm

# Clip the gradients to prevent exploding gradients.
self._clip_grad_norm(total_norm)
@@ -477,6 +479,9 @@ def get_working_to_master_map(self):
def get_master_to_working_map(self):
return None

def get_grad_norm(self, norm_type=2, **kwargs):
return self._current_grad_norm


class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
def __init__(
12 changes: 12 additions & 0 deletions colossalai/interface/optimizer.py
@@ -135,6 +135,18 @@ def unwrap(self):
"""
return self.optim

def get_grad_norm(self, norm_type: Union[float, int] = 2.0, **kwargs) -> Optional[float]:
"""
Returns the gradient norm of an iterable of parameters. This method should be called after optimizer.step().

Args:
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.

Returns:
Optional[float]: Total norm of the gradients (viewed as a single vector). If there are no valid gradients, returns None.
"""
raise NotImplementedError("The method get_grad_norm is not implemented yet.")


class DistributedOptim(Optimizer):
def setup_distributed(
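
The interface above is what each plugin optimizer in this diff implements; the norm is only meaningful after optimizer.step(), as the tests further down check. A hedged usage sketch, assuming a CUDA device, an already-initialized distributed environment (e.g. via colossalai.launch_from_torch), and a toy model with synthetic data; the plugin choice and its arguments are illustrative, not prescribed by this commit:

import torch
import torch.nn as nn

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Toy setup; assumes colossalai.launch_from_torch(...) has already been called.
model = nn.Linear(8, 2)
criterion = nn.CrossEntropyLoss()
base_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

booster = Booster(plugin=LowLevelZeroPlugin(stage=2))
model, optimizer, criterion, _, _ = booster.boost(model, base_optimizer, criterion)

inputs = torch.randn(4, 8).cuda()
labels = torch.randint(0, 2, (4,)).cuda()

loss = criterion(model(inputs), labels)
booster.backward(loss, optimizer)
optimizer.step()

# New in this commit: total gradient norm from the last step(),
# or None if no norm was computed (e.g. when gradient clipping is disabled).
grad_norm = optimizer.get_grad_norm()
assert grad_norm is None or isinstance(grad_norm, float)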
7 changes: 6 additions & 1 deletion colossalai/zero/gemini/gemini_optimizer.py
@@ -1,7 +1,7 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
import math
from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
from typing import Any, Dict, Iterator, Optional, OrderedDict, Set, Tuple, Union

import torch
import torch.distributed as dist
@@ -195,6 +195,7 @@ def __init__(
self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])

self._register_states = disposable(self._register_states_)
self._current_grad_norm: Optional[float] = None

def _set_grad_ptr(self):
for group in self.param_groups:
@@ -255,6 +256,7 @@ def _get_combined_scale(self):

if self.clipping_flag:
total_norm = self._calc_global_norm()
self._current_grad_norm = total_norm
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
div_scale = clip * div_scale
@@ -846,6 +848,9 @@ def clip_grad_by_norm(
f"Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm", ranks=[0]
)

def get_grad_norm(self, norm_type=2, **kwargs):
return self._current_grad_norm


class GeminiAdamOptimizer(GeminiOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
5 changes: 5 additions & 0 deletions colossalai/zero/low_level/low_level_optim.py
@@ -218,6 +218,7 @@ def __init__(
)
elif self._dtype is torch.bfloat16:
self.mixed_precision_mixin = BF16MixedPrecisionMixin()
self._current_grad_norm: Optional[float] = None

def __del__(self):
for hook in self.grad_handles:
@@ -551,6 +552,7 @@ def step(self, closure=None):

# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._current_grad_norm = global_norm
self._unscale_and_clip_grads(grad_partition_groups, global_norm)

# update the parameters
@@ -934,3 +936,6 @@ def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) ->
def _force_wait_all_gather(self):
for param in self._working_param_to_padded_working_param.keys():
wait_all_gather_handle(param)

def get_grad_norm(self, norm_type=2, **kwargs):
return self._current_grad_norm
2 changes: 2 additions & 0 deletions tests/test_booster/test_plugin/test_3d_plugin.py
@@ -76,6 +76,8 @@ def _criterion(outputs, inputs):

booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True)
optimizer.step()
grad_norm = optimizer.get_grad_norm()
assert grad_norm is None or isinstance(grad_norm, float)

except Exception as e:
return repr(e)
2 changes: 2 additions & 0 deletions tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -54,6 +54,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, t

booster.backward(loss, optimizer)
optimizer.step()
grad_norm = optimizer.get_grad_norm()
assert grad_norm is None or isinstance(grad_norm, float)

except NotImplementedError:
print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.")
2 changes: 2 additions & 0 deletions tests/test_booster/test_plugin/test_low_level_zero_plugin.py
@@ -50,6 +50,8 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None)

booster.backward(loss, optimizer)
optimizer.step()
grad_norm = optimizer.get_grad_norm()
assert grad_norm is None or isinstance(grad_norm, float)

except Exception as e:
return repr(e)
