Skip to content

Commit

Permalink
FP8 plugin recipes (#10208) (#10807)
Browse files Browse the repository at this point in the history
* create mixed precision plugin recipes



* cleanup



* fix annotation



* full definition instead of attach



* update docstrings



* update fp8



* Apply isort and black reformatting



* refactor import



---------

Signed-off-by: Maanu Grover <maanug@nvidia.com>
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
Signed-off-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>
Co-authored-by: JimmyZhang12 <67203904+JimmyZhang12@users.noreply.github.com>
Co-authored-by: Jimmy Zhang <jiemingz@nvidia.com>
Co-authored-by: JimmyZhang12 <JimmyZhang12@users.noreply.github.com>
  • Loading branch information
4 people authored Oct 8, 2024
1 parent 8eea169 commit a14a6a0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
22 changes: 22 additions & 0 deletions nemo/collections/llm/recipes/precision/mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,25 @@ def fp16_mixed() -> run.Config[MegatronMixedPrecision]:
autocast_enabled=False,
grad_reduce_in_fp32=False,
)


def bf16_with_fp8_mixed() -> run.Config[MegatronMixedPrecision]:
"""FP8 recipes are experimental and have not been tested for training convergence."""
cfg = bf16_mixed()
cfg.fp8 = 'hybrid'
cfg.fp8_margin = 0
cfg.fp8_amax_history_len = 1024
cfg.fp8_amax_compute_algo = "max"
cfg.fp8_params = True
return cfg


def fp16_with_fp8_mixed() -> run.Config[MegatronMixedPrecision]:
"""FP8 recipes are experimental and have not been tested for training convergence."""
cfg = fp16_mixed()
cfg.fp8 = 'hybrid'
cfg.fp8_margin = 0
cfg.fp8_amax_history_len = 1024
cfg.fp8_amax_compute_algo = "max"
cfg.fp8_params = True
return cfg
15 changes: 12 additions & 3 deletions nemo/lightning/pytorch/plugins/mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torch.optim import Optimizer

from nemo.utils import logging
from nemo.utils.import_utils import safe_import

AnyT = TypeVar("AnyT")

Expand Down Expand Up @@ -54,12 +55,12 @@ class DtypeConfig:
# fp8 related
fp8: str = None
fp8_margin: int = 0
fp8_interval: int = 1
fp8_amax_history_len: int = 1
fp8_amax_compute_algo: str = "most_recent"
fp8_wgrad: bool = True
fp8_dot_product_attention: bool = False
fp8_multi_head_attention: bool = False
fp8_param_gather: bool = True
# FP16 Loss scaling
loss_scale: float = (None,)
initial_loss_scale: float = (None,)
Expand All @@ -80,12 +81,12 @@ def __init__(
# fp8 related,
fp8: str = None,
fp8_margin: int = 0,
fp8_interval: int = 1,
fp8_amax_history_len: int = 1,
fp8_amax_compute_algo: str = "most_recent",
fp8_wgrad: bool = True,
fp8_dot_product_attention: bool = False,
fp8_multi_head_attention: bool = False,
fp8_params: bool = False,
fp16_loss_scale: float = None,
fp16_initial_loss_scale: float = 4294967296,
fp16_min_loss_scale: float = 1.0,
Expand All @@ -96,6 +97,14 @@ def __init__(
if isinstance(precision, int):
precision = str(precision)

fp8_param_gather = False
if fp8 is not None:
te_fp8, HAVE_TE = safe_import("transformer_engine.pytorch.fp8")
assert HAVE_TE, "FP8 precision requires transformer engine."
if fp8_params:
te_fp8.FP8GlobalStateManager.FP8_PARAMETERS = True
fp8_param_gather = True

dtype = torch.bfloat16 if precision in ['bf16', 'bf16-mixed'] else torch.float32
self.dtype_config = DtypeConfig(
fp32=precision in ['fp32', '32'],
Expand All @@ -108,12 +117,12 @@ def __init__(
grad_reduce_in_fp32=grad_reduce_in_fp32,
fp8=fp8,
fp8_margin=fp8_margin,
fp8_interval=fp8_interval,
fp8_amax_history_len=fp8_amax_history_len,
fp8_amax_compute_algo=fp8_amax_compute_algo,
fp8_wgrad=fp8_wgrad,
fp8_dot_product_attention=fp8_dot_product_attention,
fp8_multi_head_attention=fp8_multi_head_attention,
fp8_param_gather=fp8_param_gather,
# fp16 loss scale
loss_scale=fp16_loss_scale,
initial_loss_scale=fp16_initial_loss_scale,
Expand Down

0 comments on commit a14a6a0

Please sign in to comment.