fix nan in fwd pass
xrsrke committed Dec 18, 2024 · 1 parent 79341ea · commit b440408
Showing 10 changed files with 1,853 additions and 16 deletions.

Large diffs are not rendered by default; the diff for one of the ten changed files is omitted below.

4 changes: 4 additions & 0 deletions src/nanotron/config/config.py
@@ -444,6 +444,9 @@ def get_config_from_dict(
for k, v in config_dict.items()
if v is not None
}

from nanotron.fp8.dtypes import DTypes

return from_dict(
data_class=config_class,
data=config_dict,
@@ -455,6 +458,7 @@ def get_config_from_dict(
TensorParallelLinearMode: lambda x: TensorParallelLinearMode[x.upper()],
RecomputeGranularity: lambda x: RecomputeGranularity[x.upper()],
SamplerType: lambda x: SamplerType[x.upper()],
DTypes: lambda x: DTypes[x.upper()],  # parse fp8 dtype names from config strings
},
# strict_unions_match=True,
strict=True,
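A minimal sketch (not part of this commit) of what the new DTypes type hook does at config-parse time; the enum members below are stand-ins for nanotron.fp8.dtypes.DTypes, not the real definition:

from enum import Enum, auto

class DTypes(Enum):
    # stand-in members; the real enum lives in nanotron.fp8.dtypes
    FP8E4M3 = auto()
    FP8E5M2 = auto()
    KFLOAT16 = auto()

def dtype_hook(x: str) -> DTypes:
    # mirrors the lambda registered in type_hooks above
    return DTypes[x.upper()]

assert dtype_hook("fp8e4m3") is DTypes.FP8E4M3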
1 change: 1 addition & 0 deletions src/nanotron/constants.py
@@ -22,3 +22,4 @@

# TODO(xrsrke): refactor
CPU_WEIGHTS = {}
ACCUM_GRADS = {}
8 changes: 7 additions & 1 deletion src/nanotron/fp8/functional.py
@@ -5,6 +5,7 @@
from nanotron.fp8.linear import FP8LinearMeta
from nanotron.fp8.recipe import FP8LinearRecipe
from nanotron.fp8.tensor import FP8Tensor
from nanotron.fp8.utils import is_overflow_underflow_nan
from nanotron.parallel.parameters import NanotronParameter


@@ -74,8 +75,13 @@ def linear(
# because weight and bias's requires_grad are set to False
# so that we can compute the gradients using the fp8 kernels by ourselves
phony = torch.empty(0, device=input.device, requires_grad=True)
output = torch.empty(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype)
# NOTE: initializing the output buffer with torch.empty leads to NaNs in the matmul,
# so use torch.zeros instead
# output = torch.empty(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype)
output = torch.zeros(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype)
output, _ = _FP8Matmul.apply(input, weight, output, phony, metadatas, recipe, name)
if is_overflow_underflow_nan(output) is True:
assert 1 == 1  # no-op branch kept as a breakpoint anchor while debugging NaNs

# TODO(xrsrke): add support for adding bias in fp8
# TODO(xrsrke): support return an fp8 tensor as output
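The torch.empty → torch.zeros change above avoids reading uninitialized memory when the GEMM accumulates into the output buffer. A rough sketch of the failure mode, using a plain accumulating matmul rather than the FP8 kernel:

import torch

a = torch.randn(4, 8)
b = torch.randn(8, 16)

# torch.empty(4, 16) would hand back whatever bits sit in memory; if the kernel
# computes out += a @ b (accumulate=True), that garbage flows into the result.
out = torch.zeros(4, 16)  # a zero-filled accumulator is always safe to add into
out.addmm_(a, b)          # out = out + a @ b
assert torch.isfinite(out).all()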
18 changes: 13 additions & 5 deletions src/nanotron/fp8/linear.py
@@ -173,6 +173,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[

from nanotron import constants
from nanotron.config.fp8_config import FP8Args
from nanotron.fp8.utils import is_overflow_underflow_nan

# pydevd.settrace(suspend=False, trace_only_current_thread=True)
if (
@@ -200,6 +201,8 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
fp8_input = cast(FP8Tensor, fp8_input)
fp8_weight = cast(FP8Tensor, fp8_weight)

assert is_overflow_underflow_nan(grad_output) is False, f"name: {ctx.name}"

ctx.metadatas = cast(FP8LinearMeta, ctx.metadatas)
if ctx.metadatas.input_grad is None:
fp8_grad_output = FP8Tensor(
@@ -214,7 +217,8 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[

if ctx.is_input_require_grad:
transposed_fp8_weight = fp8_weight.transpose_fp8()
grad_input_temp = torch.empty(
# NOTE: same reason as output buffer in .forward
grad_input_temp = torch.zeros(
fp8_grad_output.shape[0],
transposed_fp8_weight.shape[0],
device="cuda",
@@ -232,11 +236,14 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
else:
grad_input = None

assert is_overflow_underflow_nan(grad_input) is False

# TODO(xrsrke): fuse cast and transpose
transposed_fp8_grad_output = fp8_grad_output.transpose_fp8()
transposed_fp8_input = fp8_input.transpose_fp8()

grad_weight_temp = torch.empty(
# NOTE: same reason as output buffer in .forward
grad_weight_temp = torch.zeros(
transposed_fp8_input.shape[0],
transposed_fp8_grad_output.shape[0],
device="cuda",
@@ -250,6 +257,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
accumulate=recipe.accumulate.weight_grad,
accum_qtype=recipe.accum_dtype,
)
assert is_overflow_underflow_nan(grad_weight) is False

if ctx.is_input_require_grad:
assert grad_input.dtype == recipe.accum_dtype
@@ -272,8 +280,8 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
# File "/fsx/phuc/temp/temp3_env_for_fp8/env/lib/python3.10/site-packages/torch/_tensor.py", line 1386, in __torch_function__
# ret = func(*args, **kwargs)
# RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'unsigned char'. Please ensure that the gradient and the tensor have the same dtype
fp8_weight.__accum_grad = grad_weight
assert fp8_weight.__accum_grad.dtype in [torch.float16, torch.bfloat16, torch.float32]
# fp8_weight.__accum_grad = grad_weight
# assert fp8_weight.__accum_grad.dtype in [torch.float16, torch.bfloat16, torch.float32]
# constants.ACCUM_GRADS[ctx.name] = grad_weight
set_accum_grad(ctx.name, grad_weight)
else:
@@ -295,4 +303,4 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
# NOTE: sanity check
assert isinstance(fp8_weight_param.grad, FP8Tensor)

return grad_input, fp8_weight_grad, None, None, None, None, None
return grad_input, None, None, None, None, None, None
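The new asserts in backward rely on nanotron.fp8.utils.is_overflow_underflow_nan; a hedged sketch of what such a check typically looks like (the real implementation may differ):

import torch

def is_overflow_underflow_nan(tensor):
    # True if the tensor contains any inf or NaN values; None-safe in this
    # sketch so that optional gradients (e.g. grad_input = None) pass through
    if tensor is None:
        return False
    return bool(torch.isinf(tensor).any() or torch.isnan(tensor).any())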
22 changes: 14 additions & 8 deletions src/nanotron/fp8/utils.py
@@ -239,18 +239,23 @@ def find_fp8_config_by_module_name(target_module_name: str, config: FP8Args) ->

if config.model is not None:
for layer_args in config.model:
if layer_args.module_name == target_module_name:
if layer_args.module_name == target_module_name.replace("pp_block.", "").replace("module.", ""):
return layer_args
# elif config.is_quant_all_except_first_and_last:
else:

def match_layer_pattern(name, layer_idxs):
# patterns = [
# "model.decoder.{}.pp_block.attn.qkv_proj",
# "model.decoder.{}.pp_block.attn.o_proj",
# "model.decoder.{}.pp_block.mlp.down_proj",
# "model.decoder.{}.pp_block.mlp.gate_up_proj",
# ]
patterns = [
"model.decoder.{}.pp_block.attn.qkv_proj",
"model.decoder.{}.pp_block.attn.o_proj",
"model.decoder.{}.pp_block.mlp.down_proj",
# "model.decoder.{}.mlp.up_proj",
"model.decoder.{}.pp_block.mlp.gate_up_proj",
"model.decoder.{}.attn.qkv_proj",
"model.decoder.{}.attn.o_proj",
"model.decoder.{}.mlp.down_proj",
"model.decoder.{}.mlp.gate_up_proj",
]

for idx in layer_idxs:
@@ -267,12 +272,13 @@ def match_layer_pattern(name, layer_idxs):
# assert config.fp8_linear_config_temp is not None

quant_layer_idxs = list(range(1, num_layers - 1))
if match_layer_pattern(target_module_name, quant_layer_idxs) is True:
# NOTE: remove ".pp_block" from module name
if match_layer_pattern(target_module_name.replace(".pp_block", ""), quant_layer_idxs) is True:
from copy import deepcopy

# config_temp = deepcopy(config.fp8_linear_config_temp)
config_temp = deepcopy(FP8LM_LINEAR_RECIPE)
# config_temp.module_name = target_module_name
config_temp.module_name = target_module_name
return config_temp
# else:
# from nanotron.fp8.constant_recipe import MODULE_NAMES_THAT_NOT_FP8
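The pattern list above now omits .pp_block, and the caller strips it from the incoming module name instead. A sketch of how such a matcher can work (the body of match_layer_pattern is truncated in this diff, so this is an assumption):

def match_layer_pattern(name, layer_idxs):
    patterns = [
        "model.decoder.{}.attn.qkv_proj",
        "model.decoder.{}.attn.o_proj",
        "model.decoder.{}.mlp.down_proj",
        "model.decoder.{}.mlp.gate_up_proj",
    ]
    return any(pattern.format(idx) == name for idx in layer_idxs for pattern in patterns)

# the caller normalizes the name before matching:
name = "model.decoder.3.pp_block.attn.qkv_proj".replace(".pp_block", "")
assert match_layer_pattern(name, layer_idxs=range(1, 31)) is True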
16 changes: 16 additions & 0 deletions src/nanotron/helpers.py
@@ -737,3 +737,19 @@ def get_consumed_train_samples_of_a_data_stage_from_ckp(
(s.consumed_train_samples for s in metadata.data_stages if s.start_training_step == start_training_step),
None,
)


def get_accum_grad(param_name):
from nanotron import constants

assert "bias" not in param_name
# return constants.ACCUM_GRADS[param_name.replace("weight", "")]
return constants.ACCUM_GRADS[param_name.replace(".weight", "").replace(".pp_block", "")]


def set_accum_grad(param_name, value):
from nanotron import constants

assert "bias" not in param_name
# constants.ACCUM_GRADS[param_name.replace("weight", "")] = value
constants.ACCUM_GRADS[param_name.replace(".weight", "").replace(".pp_block", "")] = value
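The ACCUM_GRADS helpers key the registry by a normalized module name. A small usage sketch of that normalization (the parameter name is illustrative):

param_name = "model.decoder.3.pp_block.attn.qkv_proj.weight"
key = param_name.replace(".weight", "").replace(".pp_block", "")
assert key == "model.decoder.3.attn.qkv_proj"
# set_accum_grad / get_accum_grad then store and fetch the gradient under this key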
39 changes: 39 additions & 0 deletions src/nanotron/models/llama.py
@@ -24,6 +24,7 @@
from nanotron import logging
from nanotron.config import Config, LlamaConfig, ParallelismArgs
from nanotron.config.models_config import RandomInit, SpectralMupInit
from nanotron.fp8.utils import is_overflow_underflow_nan
from nanotron.generation.generate_store import AttachableStore
from nanotron.logging import log_rank
from nanotron.models import NanotronModel
@@ -431,6 +432,7 @@ def __init__(
parallel_config=parallel_config,
layer_idx=layer_idx,
)
self.layer_idx = layer_idx

self.prefill_kv_len = (
config.max_position_embeddings
@@ -452,6 +454,8 @@ def forward(
) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
q_length, batch_size, _ = qkv_states.shape

assert is_overflow_underflow_nan(qkv_states) is False, f"layer_idx: {self.layer_idx}"

if self.is_gqa:
query_states, key_states, value_states = torch.split(
qkv_states,
@@ -661,6 +665,10 @@ def forward(
key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0)
# [batch_size, seq_length, 2, num_heads, d_qk]
key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous()

assert is_overflow_underflow_nan(query_states) is False, f"layer_idx: {self.layer_idx}"
assert is_overflow_underflow_nan(key_value_states) is False, f"layer_idx: {self.layer_idx}"

query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states)
# [batch_size, seq_length, num_heads, d_qk]
key_states, value_states = torch.split(key_value_states, 1, dim=2)
@@ -685,10 +693,19 @@ def forward(
# NOTE: even though in some cases we accumulate the fp8 gemm in bfloat16,
# the layer norms are in float32, so the resulting output is float32,
# and flash attention doesn't support float32 qkv, so we cast back to bfloat16

assert is_overflow_underflow_nan(query_states) is False, f"layer_idx: {self.layer_idx}"
assert is_overflow_underflow_nan(key_states) is False, f"layer_idx: {self.layer_idx}"
assert is_overflow_underflow_nan(value_states) is False, f"layer_idx: {self.layer_idx}"

query_states = query_states.to(torch.bfloat16)
key_states = key_states.to(torch.bfloat16)
value_states = value_states.to(torch.bfloat16)

assert is_overflow_underflow_nan(query_states) is False, f"layer_idx: {self.layer_idx}"
assert is_overflow_underflow_nan(key_states) is False, f"layer_idx: {self.layer_idx}"
assert is_overflow_underflow_nan(value_states) is False, f"layer_idx: {self.layer_idx}"

attention_output = self.attention(
query_states=query_states,
key_states=key_states,
@@ -700,6 +717,14 @@ def forward(
attention_output = (
attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
)
from nanotron import constants

if attention_output.dtype != constants.CONFIG.fp8.resid_dtype:
assert is_overflow_underflow_nan(attention_output) is False, f"layer_idx: {self.layer_idx}"
attention_output = attention_output.to(constants.CONFIG.fp8.resid_dtype)
assert is_overflow_underflow_nan(attention_output) is False, f"layer_idx: {self.layer_idx}"

assert is_overflow_underflow_nan(attention_output) is False, f"layer_idx: {self.layer_idx}"
output = self.o_proj(attention_output)

return {"hidden_states": output, "sequence_mask": sequence_mask}
@@ -730,6 +755,7 @@ def __init__(
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx)

self.recompute_layer = parallel_config.recompute_layer
self.layer_idx = layer_idx

def _core_forward(
self,
@@ -738,15 +764,22 @@ def _core_forward(
) -> List[Union[torch.Tensor, TensorPointer]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}"

output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
hidden_states = output["hidden_states"]
assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}"
hidden_states = hidden_states + residual
assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}"

residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}"

hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}"
hidden_states = hidden_states + residual
assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}"

return hidden_states, output["sequence_mask"]

@@ -920,14 +953,20 @@ def forward_with_hidden_states(
"hidden_states": output["input_embeds"],
"sequence_mask": input_mask,
}
assert is_overflow_underflow_nan(hidden_encoder_states["hidden_states"]) is False

for encoder_block in self.decoder:
hidden_encoder_states = encoder_block(**hidden_encoder_states)
assert is_overflow_underflow_nan(hidden_encoder_states["hidden_states"]) is False

hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
assert is_overflow_underflow_nan(hidden_states) is False

sharded_logits = self.lm_head(x=hidden_states)["logits"]
assert is_overflow_underflow_nan(sharded_logits) is False

fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
assert is_overflow_underflow_nan(fp32_sharded_logits) is False

return fp32_sharded_logits, hidden_states

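The per-layer asserts above bisect the forward pass to find where NaNs first appear. A hedged sketch of the same idea as a reusable forward hook (not part of this commit):

import torch

def install_nan_hooks(model):
    # raise on the first module whose output contains NaN/inf
    def check(module, _inputs, output, name):
        outs = output.values() if isinstance(output, dict) else [output]
        for t in outs:
            if torch.is_tensor(t) and t.is_floating_point() and not torch.isfinite(t).all():
                raise RuntimeError(f"non-finite output from {name} ({type(module).__name__})")

    for name, module in model.named_modules():
        module.register_forward_hook(lambda m, i, o, name=name: check(m, i, o, name))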
3 changes: 2 additions & 1 deletion src/nanotron/optim/gradient_accumulator.py
@@ -294,7 +294,7 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:

from nanotron.fp8.utils import is_overflow_underflow_nan

assert is_overflow_underflow_nan(grad) is False
assert is_overflow_underflow_nan(grad) is False, f"name: {name}"

fp32_grad = self.get_grad_buffer(name=name)

@@ -324,6 +324,7 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:
else:
grad = fp32_grad
fp32_param.grad = grad
assert is_overflow_underflow_nan(fp32_param.grad) is False

@contextmanager
def no_sync(self):
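A minimal sketch of the accumulate-into-fp32 pattern that the strengthened asserts guard (the buffer names are illustrative, not the accumulator's actual attributes):

import torch

half_grad = torch.randn(1024, dtype=torch.bfloat16)  # gradient from the half-precision param
fp32_grad = torch.zeros(1024, dtype=torch.float32)   # persistent master-gradient buffer

fp32_grad.add_(half_grad.to(torch.float32))          # accumulate across micro-batches
assert torch.isfinite(fp32_grad).all()               # same condition the new asserts check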
2 changes: 1 addition & 1 deletion src/nanotron/trainer.py
@@ -210,7 +210,7 @@ def __init__(
# from nanotron import constants
for n, p in self.model.named_parameters():
if hasattr(p, "_is_future_fp8") and p._is_future_fp8 is True:
constants.CPU_WEIGHTS[n] = p.data.cpu().clone()
constants.CPU_WEIGHTS[n.replace("module.", "")] = p.data.cpu().clone()

# NOTE: sanity check all hash are different
param_hash = []
