
Commit

enabling FP8 by layer and customizing the FP8 recipe per layer
xrsrke committed Jul 8, 2024
1 parent 257751d commit 7557757
Showing 15 changed files with 433 additions and 278 deletions.
52 changes: 43 additions & 9 deletions examples/fp8/ablations/configs/sanity_fp8.yaml
@@ -166,34 +166,68 @@ tokens:

fp8:
  model:
-    - module_name: attn.qkv_proj
+    - module_name: model.decoder.0.mlp.down_proj
      accum_dtype: KFLOAT16
      input:
        dtype: float16
      weight:
        dtype: FP8E4M3
-        margin: 1
+        margin: 0
        interval: 1
      bias:
        dtype: float16
      input_grad:
-        dtype: FP8E4M3
-        margin: 1
+        dtype: FP8E5M2
+        margin: 0
        interval: 1
      weight_grad:
-        dtype: FP8E4M3
-        margin: 1
+        dtype: FP8E5M2
+        margin: 0
        interval: 1
+      output_grad:
+        dtype: FP8E5M2
+        margin: 0
+        interval: 1
+      split_accumulator:
+        output: true
+        input_grad: true
+        weight_grad: true
+      accumulate:
+        output: true
+        input_grad: true
+        weight_grad: true
    - module_name: model.decoder.1.attn.qkv_proj
      accum_dtype: KFLOAT16
      input:
        dtype: float16
      weight:
        dtype: FP8E4M3
-        margin: 1
+        margin: 0
        interval: 1
+      bias:
+        dtype: float16
+      input_grad:
+        dtype: FP8E5M2
+        margin: 0
+        interval: 1
+      weight_grad:
+        dtype: FP8E5M2
+        margin: 0
+        interval: 1
+      output_grad:
+        dtype: FP8E5M2
+        margin: 0
+        interval: 1
+      split_accumulator:
+        output: true
+        input_grad: true
+        weight_grad: true
+      accumulate:
+        output: true
+        input_grad: true
+        weight_grad: true
  optim:
    master_weight_dtype: KFLOAT32
    accum_dtype: KFLOAT32
-    exp_avg_dtype: FP8E4M3
-    exp_avg_sq_dtype: FP8E4M3
+    exp_avg_dtype: KFLOAT32
+    exp_avg_sq_dtype: KFLOAT32
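
Note: each entry under fp8.model overrides the FP8 recipe for one module, matched by its fully qualified name; modules without an entry keep the default recipe. A minimal sketch of how such a lookup could work once the YAML is parsed — the helper name and dictionary layout below are illustrative, not nanotron's actual API:

from typing import Any, Dict, List, Optional

def find_layer_fp8_args(entries: List[Dict[str, Any]], fq_name: str) -> Optional[Dict[str, Any]]:
    """Return the fp8.model entry whose module_name matches a module's fully qualified name."""
    for entry in entries:
        if entry["module_name"] == fq_name:
            return entry
    return None  # unlisted modules fall back to the default (FP8-LM) recipe

# Toy parsed config with the two per-layer overrides from the YAML above.
fp8_model_entries = [
    {"module_name": "model.decoder.0.mlp.down_proj", "accum_dtype": "KFLOAT16"},
    {"module_name": "model.decoder.1.attn.qkv_proj", "accum_dtype": "KFLOAT16"},
]
assert find_layer_fp8_args(fp8_model_entries, "model.decoder.0.mlp.down_proj") is not None
assert find_layer_fp8_args(fp8_model_entries, "model.decoder.0.attn.qkv_proj") is None
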
4 changes: 3 additions & 1 deletion src/nanotron/config/fp8_config.py
@@ -1,7 +1,8 @@
from dataclasses import dataclass
from typing import List, Literal, Union

-from nanotron.fp8.constants import DTypes, FP8OptimRecipe, FP8SplitAccumulator, FP8TensorRecipe
+from nanotron.fp8.constants import DTypes
+from nanotron.fp8.recipe import FP8Accumulate, FP8OptimRecipe, FP8SplitAccumulator, FP8TensorRecipe
from nanotron.logging import get_logger

logger = get_logger(__name__)
@@ -27,6 +28,7 @@ class FP8LayerArgs:
    weight_grad: Union[FP8TensorRecipe, TorchDtype]
    output_grad: Union[FP8TensorRecipe, TorchDtype]
    split_accumulator: FP8SplitAccumulator
+    accumulate: FP8Accumulate


@dataclass
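
The new accumulate field mirrors split_accumulator: one boolean per GEMM in a linear layer (forward output, input gradient, weight gradient), matching the accumulate block in the YAML above. A hedged construction sketch, assuming FP8Accumulate takes the same keyword arguments as FP8SplitAccumulator:

from nanotron.fp8.recipe import FP8Accumulate, FP8SplitAccumulator

# Whether each of the three GEMMs uses TE's split accumulator, and whether it
# accumulates into the existing output buffer rather than overwriting it.
split_accumulator = FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True)
accumulate = FP8Accumulate(output=True, input_grad=True, weight_grad=True)
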
1 change: 1 addition & 0 deletions src/nanotron/constants.py
@@ -16,6 +16,7 @@
IS_FP8: bool = True

NN_STATES = None
+CONFIG = None

TRACKING_FP8_PARAM = {}

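
CONFIG is a plain module-level slot; presumably it is assigned once during setup so that low-level FP8 code can consult the parsed config (including the per-layer fp8.model entries) without threading it through every call. The assignment site is not among the files shown here; the pattern would look roughly like this, with hypothetical helper names:

from nanotron import constants

def install_global_config(config) -> None:
    # Stash the parsed training config in the module-level slot added by this commit.
    constants.CONFIG = config

def fp8_layer_entries():
    # Later readers fetch it back lazily instead of receiving it as an argument.
    assert constants.CONFIG is not None, "global config has not been installed"
    return constants.CONFIG.fp8.model  # assumes the config object mirrors the YAML structure above
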
41 changes: 24 additions & 17 deletions src/nanotron/fp8/constants.py
@@ -42,6 +42,28 @@

# TODO(xrsrke): differentiate the precision that you initializes model weight
# and the accumulation precision in FP8 recipe

+FP8LM_LINEAR_RECIPE = FP8LinearRecipe(
+    accum_dtype=DTypes.KFLOAT16,
+    input=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=16),
+    weight=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=1),
+    bias=FP8TensorRecipe(dtype=DTypes.KFLOAT16, margin=0, interval=16),
+    # NOTE: these are the dtypes for the gradients
+    input_grad=FP8TensorRecipe(dtype=DTypes.FP8E5M2, margin=0, interval=16),  # NOTE: this is output_grad
+    weight_grad=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=1),
+    output_grad=FP8TensorRecipe(dtype=DTypes.FP8E5M2, margin=0, interval=16),
+    split_accumulator=FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True),
+    # NOTE: tested, and it works
+    # accumulate=FP8SplitAccumulator(output=False, input_grad=False, weight_grad=False),
+    accumulate=FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True),
+)
+FP8LM_OPTIM_RECIPE = FP8OptimRecipe(
+    accum_dtype=DTypes.KFLOAT32,
+    master_weight_dtype=DTypes.KFLOAT16,
+    exp_avg_dtype=DTypes.FP8E4M3,
+    exp_avg_sq_dtype=DTypes.KFLOAT16,
+)

FP8LM_RECIPE = FP8TrainingRecipe(
    # linear=FP8LinearRecipe(
    #     accum_dtype=DTypes.KFLOAT16,
@@ -67,17 +89,7 @@
    #     split_accumulator=FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True),
    # ),
    # # NOTE: FP8-LM recipe
-    linear=FP8LinearRecipe(
-        accum_dtype=DTypes.KFLOAT16,
-        input=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=16),
-        weight=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=1),
-        bias=FP8TensorRecipe(dtype=DTypes.KFLOAT16, margin=0, interval=16),
-        # NOTE: these are the dtypes for the gradients
-        input_grad=FP8TensorRecipe(dtype=DTypes.FP8E5M2, margin=0, interval=16),  # NOTE: this is output_grad
-        weight_grad=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=1),
-        output_grad=FP8TensorRecipe(dtype=DTypes.FP8E5M2, margin=0, interval=16),
-        split_accumulator=FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True),
-    ),
+    linear=FP8LM_LINEAR_RECIPE,
    # NOTE: works for 8B
    # linear=FP8LinearRecipe(
    #     accum_dtype=DTypes.KFLOAT16,
@@ -91,12 +103,7 @@
    #     # split_accumulator=FP8SplitAccumulator(output=False, input_grad=True, weight_grad=True),  # NOTE: msamp use this
    #     split_accumulator=FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True),
    # ),
-    optim=FP8OptimRecipe(
-        accum_dtype=DTypes.KFLOAT32,
-        master_weight_dtype=DTypes.KFLOAT16,
-        exp_avg_dtype=DTypes.FP8E4M3,
-        exp_avg_sq_dtype=DTypes.KFLOAT16,
-    ),
+    optim=FP8LM_OPTIM_RECIPE,
)

### FOR DYNAMIC LOSS SCALING ###
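
With the FP8-LM defaults now exposed as named constants, a per-layer variant can be derived from FP8LM_LINEAR_RECIPE instead of re-declaring the whole recipe inline. A hedged sketch, assuming FP8LinearRecipe is a dataclass so dataclasses.replace applies:

from dataclasses import replace

from nanotron.fp8.constants import FP8LM_LINEAR_RECIPE
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.recipe import FP8TensorRecipe

# Keep everything from the FP8-LM defaults but give one layer a more conservative
# weight quantization (larger margin, longer amax-update interval).
custom_linear_recipe = replace(
    FP8LM_LINEAR_RECIPE,
    weight=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=1, interval=16),
)
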
11 changes: 7 additions & 4 deletions src/nanotron/fp8/functional.py
@@ -5,6 +5,7 @@
from nanotron.fp8.constants import QTYPE_TO_DTYPE
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.linear import FP8LinearMeta
+from nanotron.fp8.recipe import FP8LinearRecipe
from nanotron.fp8.tensor import FP8Tensor


@@ -72,12 +73,14 @@ def linear(
    input: torch.Tensor,
    weight: FP8Tensor,
    bias: Optional[torch.Tensor] = None,
-    accum_qtype: DTypes = None,
+    # accum_qtype: DTypes = None,
    metadatas: FP8LinearMeta = None,
+    recipe: FP8LinearRecipe = None,
    name: Optional[str] = None,
):
-    assert accum_qtype is not None, "accum_qtype must be specified"
+    # assert accum_qtype is not None, "accum_qtype must be specified"
    assert metadatas is not None, "metadatas must be specified"
+    assert recipe is not None, "recipe must be specified"
    assert input.device != torch.device("cpu"), "FP8Linear only supports CUDA tensors"
    # return addmm(input=bias, mat1=input, mat2=weight.transpose_fp8(), output=output, accum_qtype=accum_qtype, metadatas=metadatas)

@@ -101,8 +104,8 @@ def linear(
    # because weight and bias's requires_grad are set to False
    # so that we can compute the gradients using the fp8 kernels by ourselves
    phony = torch.empty(0, device=input.device, requires_grad=True)
-    output = torch.zeros(input.shape[0], weight.shape[0], device="cuda", dtype=QTYPE_TO_DTYPE[accum_qtype])
-    output, _ = _FP8Matmul.apply(input, weight, output, phony, metadatas, accum_qtype, name)
+    output = torch.zeros(input.shape[0], weight.shape[0], device="cuda", dtype=QTYPE_TO_DTYPE[recipe.accum_dtype])
+    output, _ = _FP8Matmul.apply(input, weight, output, phony, metadatas, recipe, name)

    # TODO(xrsrke): add support for adding bias in fp8
    # TODO(xrsrke): support return an fp8 tensor as output
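
Callers of the functional API now pass the whole recipe rather than a bare accum_qtype; the accumulation dtype is read from recipe.accum_dtype inside. A hedged usage sketch (weight_fp8 is assumed to already be an FP8Tensor on the GPU, and FP8LinearMeta is assumed to be default-constructible):

import torch

from nanotron.fp8 import functional as fp8_functional
from nanotron.fp8.constants import FP8LM_LINEAR_RECIPE
from nanotron.fp8.linear import FP8LinearMeta

def run_fp8_linear(input_fp16: torch.Tensor, weight_fp8) -> torch.Tensor:
    metadatas = FP8LinearMeta()  # holds the per-tensor scaling metadata for this layer
    return fp8_functional.linear(
        input=input_fp16.cuda(),
        weight=weight_fp8,
        bias=None,
        metadatas=metadatas,
        recipe=FP8LM_LINEAR_RECIPE,  # a per-layer recipe from the YAML config could be passed instead
        name="model.decoder.0.mlp.down_proj",
    )
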
55 changes: 31 additions & 24 deletions src/nanotron/fp8/kernel.py
@@ -15,9 +15,11 @@ def fp8_matmul_kernel(
    transpose_b: bool,
    output,
    use_split_accumulator: bool,
+    accumulate: bool,
    accum_qtype: DTypes,
    # TODO(xrsrke): remove this flag
    is_backward: bool = False,
+    recipe=None,
) -> torch.Tensor:
    assert (
        mat_a.device != "cpu" and mat_b.device != "cpu"
@@ -45,7 +47,7 @@
    # output = torch.empty(mat_b.shape[0], mat_a.shape[0], device=device, dtype=out_torch_dtype)

    workspace = torch.empty(33_554_432, dtype=torch.int8, device=device)
-    accumulate = False
+    # accumulate = False

    # NOTE: currently TE don't support adding bias in FP8
    # along with matmul, it only takes an empty bias
@@ -98,28 +100,33 @@
    # output = torch.empty(mat_b.shape[0], mat_a.shape[0], device=device, dtype=out_torch_dtype)
    # # output = torch.empty(mat_a.shape[-1], mat_b.shape[-1], device=device, dtype=out_torch_dtype)

-    tex.te_gemm(
-        mat_a,
-        mat_a_fp8_meta.inverse_scale,
-        mat_a_fp8_meta.te_dtype,
-        TE_CONFIG_TRANSPOSE_A,
-        mat_b,
-        mat_b_fp8_meta.inverse_scale,
-        mat_b_fp8_meta.te_dtype,
-        TE_CONFIG_TRANSPOSE_B,
-        output,
-        SCALE,
-        out_dtype,
-        AMAX,
-        bias,
-        out_dtype,
-        _empty_tensor,
-        TE_CONFIG_TRANSPOSE_BIAS,
-        workspace,
-        workspace.shape[0],
-        accumulate,
-        use_split_accumulator,
-        0,
-    )
+    try:
+        tex.te_gemm(
+            mat_a,
+            mat_a_fp8_meta.inverse_scale,
+            mat_a_fp8_meta.te_dtype,
+            TE_CONFIG_TRANSPOSE_A,
+            mat_b,
+            mat_b_fp8_meta.inverse_scale,
+            mat_b_fp8_meta.te_dtype,
+            TE_CONFIG_TRANSPOSE_B,
+            output,
+            SCALE,
+            out_dtype,
+            AMAX,
+            bias,
+            out_dtype,
+            _empty_tensor,
+            TE_CONFIG_TRANSPOSE_BIAS,
+            workspace,
+            workspace.shape[0],
+            accumulate,
+            use_split_accumulator,
+            0,
+        )
+    except RuntimeError:
+        raise RuntimeError(
+            f"mat_a_fp8_meta.te_dtype: {mat_a_fp8_meta.te_dtype}, mat_b_fp8_meta.te_dtype: {mat_b_fp8_meta.te_dtype}, out_dtype: {out_dtype}, recipe: {recipe}"
+        )

    return output
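
The kernel now receives accumulate alongside use_split_accumulator instead of hard-coding accumulate = False. The call sites are collapsed in this view; a hedged sketch of how a caller such as _FP8Matmul might derive both flags from the per-layer recipe (the helper name is illustrative):

from typing import Tuple

def gemm_flags(recipe, which: str) -> Tuple[bool, bool]:
    """Return (use_split_accumulator, accumulate) for one of the three GEMMs: 'output', 'input_grad', or 'weight_grad'."""
    use_split_accumulator = getattr(recipe.split_accumulator, which)
    accumulate = getattr(recipe.accumulate, which)
    return use_split_accumulator, accumulate

# Forward GEMM:   use_split_accumulator, accumulate = gemm_flags(recipe, "output")
# Backward GEMMs: gemm_flags(recipe, "input_grad") and gemm_flags(recipe, "weight_grad")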