diff --git a/fast_llm/functional/triton/normalization.py b/fast_llm/functional/triton/normalization.py index 96d1663f7..00b8891e8 100644 --- a/fast_llm/functional/triton/normalization.py +++ b/fast_llm/functional/triton/normalization.py @@ -1,3 +1,4 @@ +import math import typing import torch @@ -306,3 +307,340 @@ def rms_norm(input_: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Te input_dtype = input_.dtype input_ = input_.to(torch.float32) return (weight * input_ * torch.rsqrt(input_.pow(2).mean(dim=-1, keepdim=True) + eps)).to(dtype=input_dtype) + + +# from mamba2 +@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + if HAS_Z: + Z += row * stride_z_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 
65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M, ngroups) + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) + return out, mean, rstd + + +@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) +@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) +@triton.jit +def _layer_norm_bwd_kernel( + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Y, # pointer to the output to be recomputed + DY, # pointer to the output gradient + DX, # pointer to the input gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + DZ, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_z_row, + stride_y_row, + stride_dy_row, + stride_dx_row, + stride_dz_row, + stride_dw_row, + stride_db_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + rows_per_program, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # Map the program id to the elements of X, DX, and DY it should compute. 
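+ # Each program handles `rows_per_program` consecutive rows of one group: it writes dx (and dz) per row, optionally recomputes y, and accumulates partial dw/db sums into DW/DB[row_block_id], which the Python wrapper later reduces with `.sum(0)`.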
+ row_block_id = tl.program_id(0) + group = tl.program_id(1) + row_start = row_block_id * rows_per_program + cols = tl.arange(0, BLOCK_N) + mask = cols < N + X += row_start * stride_x_row + group * N + if HAS_Z: + Z += row_start * stride_z_row + group * N + DZ += row_start * stride_dz_row + group * N + DY += row_start * stride_dy_row + group * N + DX += row_start * stride_dx_row + group * N + if RECOMPUTE_OUTPUT: + Y += row_start * stride_y_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS: + B += group * N + b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + if HAS_BIAS: + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + row_end = min((row_block_id + 1) * rows_per_program, M) + for row in range(row_start, row_end): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.load(Mean + row) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask, other=0.0).to(tl.float32) + x_og = x + x = x_og * z * tl.sigmoid(z) + rstd = tl.load(Rstd + row) + # Compute dx + xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + xhat = tl.where(mask, xhat, 0.0) + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask, other=0.0).to(tl.float32) + z_sigmoid = tl.sigmoid(z) + y = xhat * w + b if HAS_BIAS else xhat * w + if RECOMPUTE_OUTPUT: + tl.store(Y + cols, y * z * z_sigmoid, mask=mask) + dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(DZ + cols, dz, mask=mask) + dy *= z * z_sigmoid + else: + if RECOMPUTE_OUTPUT: + y = xhat * w + b if HAS_BIAS else xhat * w + tl.store(Y + cols, y, mask=mask) + wdy = w * dy + c1 = tl.sum(xhat * wdy, axis=0) / N + if not IS_RMS_NORM: + c2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * c1 + c2)) * rstd + else: + dx = (wdy - xhat * c1) * rstd + dw += dy * xhat + if HAS_BIAS: + db += dy + if HAS_Z and not NORM_BEFORE_GATE: + z_sigmoid = tl.sigmoid(z) + dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(DZ + cols, dz, mask=mask) + dx *= z * z_sigmoid + # Write dx + tl.store(DX + cols, dx, mask=mask) + + X += stride_x_row + if HAS_Z: + Z += stride_z_row + DZ += stride_dz_row + if RECOMPUTE_OUTPUT: + Y += stride_y_row + DY += stride_dy_row + DX += stride_dx_row + tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask) + if HAS_BIAS: + tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask) + + +def _layer_norm_bwd( + dy, + x, + weight, + bias, + eps, + mean, + rstd, + z=None, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, + recompute_output=False, + dz=None, + out=None, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + assert dy.stride(-1) == 1 + assert dy.shape == (M, N) + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + dx = torch.empty_like(x) + if dz is not None: + assert z is not None + assert dz.shape == z.shape + assert dz.stride(-1) == 1 + else: + dz = torch.empty_like(z) if z is not None else None + if recompute_output: + if out is None: + out = 
torch.empty_like(x) + assert out.shape == x.shape + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs + # would limit the occupancy. + nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups) + _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device) + _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None + rows_per_program = math.ceil(M / nrow_groups) + grid = (nrow_groups, ngroups) + with torch.cuda.device(x.device.index): + _layer_norm_bwd_kernel[grid]( + x, + weight, + bias, + z, + out if recompute_output else None, + dy, + dx, + _dw, + _db, + dz, + mean, + rstd, + x.stride(0), + z.stride(0) if z is not None else 0, + 0 if not recompute_output else out.stride(0), + dy.stride(0), + dx.stride(0), + dz.stride(0) if dz is not None else 0, + _dw.stride(0), + _db.stride(0) if _db is not None else 0, + M, + group_size, + eps, + rows_per_program, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) + dw = _dw.sum(0).to(weight.dtype) + db = _db.sum(0).to(bias.dtype) if bias is not None else None + return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 07dadbc22..94bc10d5c 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -38,6 +38,16 @@ class NormalizationImplementation(str, enum.Enum): triton = "triton" +class TPRMSNormImplementation(str, enum.Enum): + """ + An enum for the available implementations of rms norm. 
+ """ + + fused_redtensor = "fused_redtensor" + autograd_redstats = "autograd_redstats" + torch_comp_redstats = "torch_comp_redstats" + + @config_class(registry=True) class NormalizationConfig(BaseModelConfig): pass diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index bccc1d627..37612154b 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -1,10 +1,11 @@ import torch +from torch.distributed import ProcessGroup, ReduceOp, all_reduce # noqa from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton.normalization import triton_normalization_autograd -from fast_llm.layers.common.config import NormalizationImplementation +from fast_llm.functional.triton.normalization import _layer_norm_bwd, _layer_norm_fwd, triton_normalization_autograd +from fast_llm.layers.common.config import NormalizationImplementation, TPRMSNormImplementation from fast_llm.tensor import ParameterMeta, accumulate_gradient, init_ones_, init_zeros_ from fast_llm.utils import Assert @@ -243,6 +244,7 @@ def __init__( ): super().__init__() assert not hidden_dim.is_parallel + self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: @@ -288,3 +290,107 @@ def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: return torch.rms_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self._eps) + + +class MambaRMSNormGated(torch.nn.Module): + def __init__( + self, + hidden_dim: TensorDim, + *, + eps=1e-5, + implementation: NormalizationImplementation = TPRMSNormImplementation.autograd_redstats, + weight_init_method=None, + zero_centered: bool = False, + lr_scale: float | None = None, + group_size: int | None = None, + norm_before_gate: bool = True, + ): + """ + RMS Norm per group of size group_size + """ + super().__init__() + self.hidden_dim = hidden_dim + self.group_size = group_size + self.norm_before_gate = norm_before_gate + + self._eps = eps + self._zero_centered = zero_centered + + if weight_init_method is None: + weight_init_method = init_zeros_ if self._zero_centered else init_ones_ + # self._forward = self._forward_distributed + + self.weight = ParameterMeta.from_dims( # local weights + (hidden_dim,), + init_method=weight_init_method, + weight_decay=False, + auto_grad_accumulation=True, + lr_scale=lr_scale, + ) + self.normalized_shape = self.weight.shape + + # def forward(self, input_: torch.Tensor) -> torch.Tensor: + # return self._forward(input_) + + def forward(self, input_: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor: + return LayerNormFn.apply( + input_, self.weight, None, gate, self._eps, self.group_size, self.norm_before_gate, True + ) + + +class LayerNormFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: 
+ bias = bias.contiguous() + y, mean, rstd = _layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + ctx.save_for_backward(x, weight, bias, mean, rstd, z) + ctx.x_shape_og = x_shape_og + ctx.eps = eps + ctx.group_size = group_size + ctx.norm_before_gate = norm_before_gate + ctx.is_rms_norm = is_rms_norm + return y.reshape(x_shape_og) + + @staticmethod + def backward(ctx, dy): + x, weight, bias, mean, rstd, z = ctx.saved_tensors + dy = dy.reshape(-1, dy.shape[-1]) + if dy.stride(-1) != 1: + dy = dy.contiguous() + assert dy.shape == x.shape + dx, dw, db, dz = _layer_norm_bwd( + dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size, ctx.norm_before_gate, ctx.is_rms_norm + ) + return ( + dx.reshape(ctx.x_shape_og), + dw, + db, + dz.reshape(ctx.x_shape_og) if dz is not None else None, + None, + None, + None, + None, + ) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 194063a26..b1ae657d2 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -38,6 +38,7 @@ class SSMDimNames: head_dim = "ssm_head_dim" head_groups = "ssm_head_groups" group_heads = "ssm_group_heads" + conv1d_dim = "ssm_conv1d_dim" # Mamba 2 x_proj_dim_2 = "x_proj_dim_2" # d_xb @@ -48,7 +49,10 @@ class SSMDimNames: # Composite dimensions composite_heads = "ssm_composite_heads" composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" + composite_heads_and_head_dim_nontp = "ssm_composite_heads_and_head_dim_nontp" + composite_heads_and_state_dim = "ssm_composite_heads_and_state_dim" composite_head_groups_and_state = "ssm_composite_head_groups_and_state" + composite_head_groups_and_head = "ssm_composite_head_groups_and_head" # Concatenated dimensions concatenated_convolution = "ssm_concatenated_convolution" @@ -65,6 +69,7 @@ class SSMBlockType(enum.StrEnum): mamba2_discrete = "m2d" mamba2 = "m2" transformer = "t" + nemotron_h_mamba2 = "nm2" def get_mixer_class(self): if self == SSMBlockType.mamba: @@ -79,6 +84,10 @@ def get_mixer_class(self): from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 return DiscreteMamba2 + elif self == SSMBlockType.nemotron_h_mamba2: + from fast_llm.layers.ssm.mamba2 import NemotronHMamba2 + + return NemotronHMamba2 else: raise NotImplementedError(self) @@ -226,6 +235,27 @@ class SSMConfig(LLMBlockConfig): valid=check_field(Assert.gt, 0), ) + # Nemotron H Mamba2 (the real mamba2 actually) + # here instead of setting d_inner, we set head dim. and number of heads + # Note: we do not implement n_groups for Mamba2, because, sicne we do MiL init, we do not want to share B and C parameters accross heads. + # Instead, we mimic the GQA behaviour (x -> v, B -> k, C -> q), where x and B are shared accross heads. So this is the same as having n_groups = n_heads? + n_groups: int = Field( + default=128, + desc="Number of groups for Mamba2. 
Allows sharing B and C parameters accross heads.", + hint=FieldHint.architecture, + ) + head_dim: int = Field( + default=128, + desc="Head dimension for Nemotron H", + hint=FieldHint.architecture, + ) + + norm_before_gate: bool = Field( + default=False, + desc="Whether to apply the norm before the gate in Mamba2 blocks.", + hint=FieldHint.architecture, + ) + def _validate(self) -> None: with self._set_implicit_default(): if self.activation_type is None: @@ -243,6 +273,11 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType elif block_type == SSMBlockType.mamba2: num_heads = div(self.d_inner, self.state_size) num_head_groups = div(self.d_xb, self.state_size) + elif block_type == SSMBlockType.nemotron_h_mamba2: + # head dim and state size are not the same + Assert.eq(self.head_dim, self.state_size) # for MIL init + num_heads = self.n_groups # div(self.d_inner, self.head_dim) + num_head_groups = div(self.d_xb, self.head_dim) elif block_type == SSMBlockType.mamba2_discrete: # TODO: Use different variables? num_heads = self.n_v_heads @@ -253,6 +288,8 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size)) if block_type == SSMBlockType.mamba2_discrete: tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads))) + elif block_type == SSMBlockType.nemotron_h_mamba2: + tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, self.head_dim)) else: head_dim = state @@ -261,14 +298,16 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType tensor_space.add_tensor_dim( heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) ) + # full d_inner or intermediate_size (e.g. for z gate, also the d_inner size for C in mamba2) tensor_space.add_tensor_dim( heads_and_head_dim := CompositeTensorDim( SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim) ) ) + # d_xb tensor_space.add_tensor_dim( - head_groups_and_state := CompositeTensorDim( - SSMDimNames.composite_head_groups_and_state, (head_groups, state) + head_groups_and_head := CompositeTensorDim( + SSMDimNames.composite_head_groups_and_head, (head_groups, head_dim) ) ) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension)) @@ -292,7 +331,41 @@ def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType tensor_space.add_tensor_dim( ConcatenatedTensorDim( SSMDimNames.concatenated_inner_projection, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim), + (heads_and_head_dim, head_groups_and_head, head_groups_and_head, heads_and_head_dim), + ) + ) + elif block_type == SSMBlockType.nemotron_h_mamba2: + # for the norm + tensor_space.add_tensor_dim( + TensorDim( + SSMDimNames.composite_heads_and_head_dim_nontp, num_head_groups * group_heads.size * head_dim.size + ) + ) + # state and head dim are not the same + # C: for each head, size of state + tensor_space.add_tensor_dim( + heads_and_state_dim := CompositeTensorDim( + SSMDimNames.composite_heads_and_state_dim, (head_groups, group_heads, state) + ) + ) + # B: for each head group, size of state + tensor_space.add_tensor_dim( + head_groups_and_state := CompositeTensorDim( + SSMDimNames.composite_head_groups_and_state, (head_groups, state) + ) + ) + # here we apply depthwise conv. layer to xBC, so the dim. 
is x (d_xb) x B (d_bb) x C + tensor_space.add_tensor_dim( + conv1d_dim := ConcatenatedTensorDim( + SSMDimNames.conv1d_dim, (heads_and_state_dim, head_groups_and_head, head_groups_and_state) + ) + ) + + # inner projection dimension: also includes z (gate), which has size d_inner (heads_and_head_dim) + tensor_space.add_tensor_dim( + ConcatenatedTensorDim( + SSMDimNames.concatenated_inner_projection, + (conv1d_dim, heads_and_head_dim), ) ) elif block_type == SSMBlockType.mamba2_discrete: diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 5ed689a73..0e1b4d6b6 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -2,21 +2,24 @@ import logging import typing +import einops import torch from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.functional.config import ActivationType from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear +from fast_llm.layers.common.normalization import MambaRMSNormGated from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames, SSMKwargs from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ +from fast_llm.tensor import LambdaInitializer, ParameterMeta, init_kaiming_, init_ones_, init_uniform_centered_ from fast_llm.utils import Assert, div, get_lr_scale _mamba_varlen = False try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined _mamba_available = True sig = inspect.signature(selective_scan_fn) @@ -75,7 +78,7 @@ def __init__( lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] + xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_head] hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] dt_rank_dim = tensor_space[SSMDimNames.dt_rank] @@ -176,7 +179,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ dt = dt.transpose(0, 1) sequence_length = inner_projection.size(1) - + # is this like Mamba1, where the conv is only applied to x? z, x, b, c = torch.split( inner_projection, [self._local_inner_size, self._local_xb_size, self._local_xb_size, self._local_inner_size], @@ -188,6 +191,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) x = x.transpose(1, 2) + # x: (batch, local_heads * state, sequence) -> (batch, local_head_groups, state, sequence) if self._config.repeat_kv_before_conv: x = ( x.unflatten(1, (self._local_head_groups, self._config.state_size)) @@ -274,3 +278,240 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ # (batch/sequence, sequence/batch, local_heads * state) # -> (batch/local_sequence, local_sequence/batch, hidden) return self.out_proj(y) + + +class NemotronHMamba2(Mixer): + """ + This is the actual Mamba2, called NemotronHMamba2 for historical reasons. + Decouples d_state and head_dim.
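+ Mimics GQA: x and B are produced per head group and repeated across the heads of a group (x -> v, B -> k, C -> q), so C is per head while x and B are shared.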
+ A larger head dimension means we project the hidden state into a larger space (more channel mixing). + A larger state size means more temporal memory. + + This code is adapted from https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K/blob/main/modeling_nemotron_h.py + """ + + _mixer_name: typing.ClassVar[str] = "nm_mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads_and_head_dim, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.composite_heads, + SSMDimNames.state, + TransformerDimNames.sequence_q, + ) + + def __init__( + self, + config: SSMConfig, + tensor_space: TensorSpace, + block_index: int, + transformer_config: TransformerConfig, + ): + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) + self._config: SSMConfig = config + Assert.eq(self._config.activation_type, ActivationType.silu) + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None + lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + + inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + inner_dim_non_tp: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim_nontp] + c_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_state_dim] + xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_head] + bb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] + hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] + + self._head_dim_size: int = tensor_space[SSMDimNames.head_dim].size + self._local_heads = tensor_space[SSMDimNames.composite_heads].size + self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + self._group_heads = div(self._local_heads, self._local_head_groups) + Assert.eq(self._local_heads, self._local_head_groups * self._group_heads) + + self._local_inner_size = inner_dim.size + self._local_c_size = c_dim.size + + Assert.eq(self._local_inner_size, self._head_dim_size * self._local_heads) + self._local_xb_size = xb_dim.size # x has head dim and is for each head group + self._local_bb_size = bb_dim.size # b has state dim and is for each head group + Assert.eq(self._local_xb_size, self._head_dim_size * self._local_head_groups) + Assert.eq(self._local_bb_size, self._config.state_size * self._local_head_groups) + + conv1d_dim = tensor_space[SSMDimNames.conv1d_dim] # applied to xBC, so d_xb + d_bb + c_dim + self.conv1d_weight = ParameterMeta.from_dims( + ( + conv1d_dim, + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], + ), + init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), + lr_scale=lr_scale, + ) + self.conv1d_bias = ParameterMeta.from_dims( + (conv1d_dim,), + init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), + lr_scale=lr_scale, + ) + self.in_proj = OutputParallelLinear( + hidden_dim, + tensor_space[SSMDimNames.concatenated_inner_projection], + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(transformer_config.hidden_size), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, + ) + + # project a single dt value per head + self.dt_in_proj = OutputParallelLinear( + hidden_dim, + tensor_space[SSMDimNames.composite_heads], + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(transformer_config.hidden_size), +
sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, + ) + + self.dt_bias = ParameterMeta.from_dims( + (tensor_space[SSMDimNames.composite_heads],), + init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), + lr_scale=lr_scale, + ) + + def init_A_uniform(A_init_range: tuple[float, float] = (1, 16)) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.uniform_(*A_init_range).log_() + + return LambdaInitializer(init_, requires_global_initialization=True) + + self.A_log = ParameterMeta.from_dims( + (tensor_space[SSMDimNames.composite_heads],), + init_method=init_A_uniform(A_init_range=(1, 16)), + lr_scale=lr_scale, + weight_decay=False, + ) + self.D = ParameterMeta.from_dims( + (tensor_space[SSMDimNames.composite_heads],), # can also be nheads x headim + weight_decay=False, + init_method=init_ones_, + lr_scale=lr_scale, + ) + self.out_proj = InputParallelLinear( + inner_dim, + hidden_dim, + bias=config.add_bias_linear, + weight_init_method=init_kaiming_(inner_dim.size), + sequence_parallel=self._sequence_parallel, + lr_scale=lr_scale, + ) + # no need to comunicate stats accross ranks, since we do norm per local heaed here. Would not work if we have a single head, since then we would really need to distribute. + self.norm = MambaRMSNormGated( + inner_dim, + eps=1e-5, + group_size=inner_dim.size // self._local_heads, + lr_scale=lr_scale, + norm_before_gate=config.norm_before_gate, + ) + + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ """ + assert _mamba_available + assert _causal_conv1d_available + cu_seqlens = kwargs.get(SSMKwargs.cu_seqlens) # kwargs[SSMKwargs.cu_seqlens] + seq_idx = kwargs.get(SSMKwargs.seq_idx) # kwargs[SSMKwargs.seq_idx] + + # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) + # -> (batch/sequence, sequence/batch, inner_projection) + inner_projection = self.in_proj(input_) + dt = self.dt_in_proj(input_) # bs, seq, heads + # Standardize to (batch, sequence, inner_projection) + if kwargs[TransformerKwargs.sequence_first]: + inner_projection = inner_projection.transpose(0, 1) + dt = dt.transpose(0, 1) + # note: self.in_proj gathers full sequence length here + sequence_length = inner_projection.size(1) + + z, xBC = torch.split( + inner_projection, + [self._local_inner_size, self._local_xb_size + self._local_bb_size + self._local_c_size], + dim=2, + ) + + if cu_seqlens is not None: + xBC = _causal_conv1d_fn( + xBC.transpose(1, 2), + weight=self.conv1d_weight.squeeze(1), + bias=self.conv1d_bias, + seq_idx=seq_idx, + activation="silu", + ).transpose(1, 2) + else: + xBC = _causal_conv1d_fn( + x=xBC.transpose(1, 2), weight=self.conv1d_weight.squeeze(1), bias=self.conv1d_bias, activation="silu" + ).transpose(1, 2) + + x, b, c = torch.split(xBC, [self._local_xb_size, self._local_bb_size, self._local_c_size], dim=-1) + # simulate GQA by repeating heads in x,b, x -> v, B -> k, C -> q + x = einops.rearrange( + x, "b l (local_head_groups head_dim) -> b local_head_groups l head_dim", head_dim=self._head_dim_size + ) # x is b x local_head_groups x l x head_dim + b = einops.rearrange( + b, + "b l (local_head_groups state_size) -> b local_head_groups l state_size", + state_size=self._config.state_size, + ) # b is b x local_head_groups x l x state_size + batch, num_key_value_heads, slen, head_dim = x.shape + x = x[:, :, None, :, :].expand(batch, num_key_value_heads, 
self._group_heads, slen, head_dim) + x = x.reshape(batch, num_key_value_heads * self._group_heads, slen, head_dim) + b = b[:, :, None, :, :].expand(batch, num_key_value_heads, self._group_heads, slen, self._config.state_size) + b = b.reshape(batch, num_key_value_heads * self._group_heads, slen, self._config.state_size) + + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(b, "b", self._BC_DIMS, kwargs) + self._debug_log(c, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + + dt_limit_kwargs = ( + {} + ) # can be used to set time-step limit as in https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K/blob/main/modeling_nemotron_h.py#L424 + # c is b x seq x (heads * state) + # b is b x heads x seq x state) + # x is b x heads x seq x head_dim + # note, we could used mamba_split_conv1d_scan_combined directly for training, however because of the GQA, we need to use the chunked version. + y = mamba_chunk_scan_combined( + einops.rearrange(x, "b g l p -> b l g p"), + dt, + A=-torch.exp(self.A_log.float()), + B=einops.rearrange(b, "b g l n -> b l g n"), + C=einops.rearrange(c, "b l (g n) -> b l g n", g=self._local_heads), + chunk_size=self._config.chunk_size, + D=self.D, + z=None, + dt_bias=self.dt_bias, + dt_softplus=True, + seq_idx=seq_idx, # assume this is used for packing + cu_seqlens=cu_seqlens, # assume this is used for packing, but maybe not needed at training + **dt_limit_kwargs, + return_final_states=False, + return_varlen_states=False, + ) + + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + + # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) + y = y.view(batch, sequence_length, -1) + y = self.norm(y, z) + + if kwargs[TransformerKwargs.sequence_first]: + # TODO: Is contiguous needed? 
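+ # convert back to (sequence, batch, ...) so the layout matches what out_proj expects when sequence_first is set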
+ y = y.transpose(0, 1).contiguous() + + # y *= torch.nn.functional.silu(z) # gate after the norm + # (batch/sequence, sequence/batch, local_heads * state) + # -> (batch/local_sequence, local_sequence/batch, hidden) + out = self.out_proj(y) + return out diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 64afbea06..27cab0195 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -227,6 +227,16 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), export_value=lambda x: x.value, ), + RenameParamConverter( + fast_llm_names=(("ssm", "norm_before_gate"),), + export_names=( + ( + "ssm_cfg", + "norm_before_gate", + ), + ), + default_value=False, + ), ] def _create_weight_converters( @@ -296,6 +306,9 @@ def _create_weight_converters( converters += self._get_weight_and_bias_converters( f"layers.{offset+i+1}.mixer.dt_proj", f"{hf_base_prefix}model.layers.{i}.mixer.dt_proj", False ) + converters += self._get_weight_and_bias_converters( + f"layers.{offset+i+1}.mixer.norm", f"{hf_base_prefix}model.layers.{i}.mixer.norm", True + ) # bias is treated separately in Mamba2 and must always exist (https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py) converters.append( WeightConverter( @@ -304,6 +317,14 @@ def _create_weight_converters( self._model.config.base_model, ) ) + # for nemotron mamba2, bias is a seperate parameter + converters.append( + WeightConverter( + f"layers.{offset+i+1}.mixer.dt_bias", + f"{hf_base_prefix}model.layers.{i}.mixer.dt_bias", + self._model.config.base_model, + ) + ) converters.append( WeightConverter( @@ -789,6 +810,14 @@ def _create_config_converters(cls) -> list[ParamConverter]: fast_llm_names=(("ssm", "d_inner"),), export_names=(("ssm_cfg", "d_inner"),), ), + RenameParamConverter( + fast_llm_names=(("ssm", "head_dim"),), + export_names=(("ssm_cfg", "head_dim"),), + ), + RenameParamConverter( + fast_llm_names=(("ssm", "n_groups"),), + export_names=(("ssm_cfg", "n_groups"),), + ), IgnoreImportParamConverter(export_names=(("sliding_window",),), ignore_export_value=None), ] diff --git a/fast_llm/models/ssm/external/15B_hybrid.ipynb b/fast_llm/models/ssm/external/15B_hybrid.ipynb index a8f0c33b7..21cb78322 100644 --- a/fast_llm/models/ssm/external/15B_hybrid.ipynb +++ b/fast_llm/models/ssm/external/15B_hybrid.ipynb @@ -745,7 +745,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -846,22 +846,30 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "path_thinker = \"/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker\"\n", - "n_ssm = 25\n", + "# path_hybrid = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-hyb25distsftvrlm2-bs64-lr5e-06-lrs1-1-1-1-sl16384_ti60000_aprsft/export/apriel_ssm_thinker_hybrid/23000\"\n", + "path_hybrid=\"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-h27distsftvrlm2f145-bs64-lr5e-06-lrs1-1-1-1-sl16384_ti60000_aprsft/export/apriel_ssm_thinker_hybrid/3500\"\n", "\n", + "n_ssm = 25\n", "\n", + "config_hybrid = AprielSSMHybridConfig.from_pretrained(path_hybrid)\n", "config_thinker = AutoConfig.from_pretrained(path_thinker)\n", "# config_thinker.num_hidden_layers = 5\n", - "hybrid_block_layout = [\"t\"] * config_thinker.num_hidden_layers\n", + "# hybrid_block_layout = [\"t\"] * 
config_thinker.num_hidden_layers\n", "# hybrid_block_layout[3] = \"m2\"\n", + "hybrid_block_layout = config_hybrid.hybrid_block_layout\n", + "# hybrid_block_layout[7] = \"m2\"\n", + "hybrid_block_layout[6] = \"m2\"\n", + "hybrid_block_layout[8] = \"m2\"\n", "\n", - "for i in range(n_ssm):\n", - " hybrid_block_layout[layer_importance[i]] = \"m2\"\n", + "\n", + "# for i in range(n_ssm):\n", + "# hybrid_block_layout[layer_importance[i]] = \"m2\"\n", "\n", "# group_size = 10 # 2nd layer importance is missing\n", "# for i in range(0, len(layer_importance), group_size):\n", @@ -922,7 +930,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -931,12 +939,12 @@ "['t',\n", " 't',\n", " 't',\n", + " 'm2',\n", " 't',\n", + " 'm2',\n", + " 'm2',\n", " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", - " 't',\n", + " 'm2',\n", " 't',\n", " 't',\n", " 't',\n", @@ -980,7 +988,7 @@ " 'm2']" ] }, - "execution_count": 10, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -991,25 +999,32 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 7/7 [00:04<00:00, 1.58it/s]\n" + "Loading checkpoint shards: 29%|██▊ | 2/7 [00:01<00:04, 1.13it/s]\n" ] }, { - "data": { - "text/plain": [ - "_IncompatibleKeys(missing_keys=['model.layers.3.mixer.A_log', 'model.layers.3.mixer.D', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.dt_proj.weight', 'model.layers.3.mixer.dt_proj.bias', 'model.layers.3.mixer.out_proj.weight'], unexpected_keys=['model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.mlp.down_proj.weight', 
'model.layers.9.input_layernorm.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.16.mlp.gate_proj.weight', 'model.layers.16.mlp.up_proj.weight', 'model.layers.16.mlp.down_proj.weight', 'model.layers.16.input_layernorm.weight', 'model.layers.16.post_attention_layernorm.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.17.mlp.gate_proj.weight', 'model.layers.17.mlp.up_proj.weight', 'model.layers.17.mlp.down_proj.weight', 'model.layers.17.input_layernorm.weight', 'model.layers.17.post_attention_layernorm.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.18.mlp.gate_proj.weight', 'model.layers.18.mlp.up_proj.weight', 'model.layers.18.mlp.down_proj.weight', 'model.layers.18.input_layernorm.weight', 'model.layers.18.post_attention_layernorm.weight', 
'model.layers.19.self_attn.q_proj.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.v_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.layers.19.mlp.gate_proj.weight', 'model.layers.19.mlp.up_proj.weight', 'model.layers.19.mlp.down_proj.weight', 'model.layers.19.input_layernorm.weight', 'model.layers.19.post_attention_layernorm.weight', 'model.layers.20.self_attn.q_proj.weight', 'model.layers.20.self_attn.k_proj.weight', 'model.layers.20.self_attn.v_proj.weight', 'model.layers.20.self_attn.o_proj.weight', 'model.layers.20.mlp.gate_proj.weight', 'model.layers.20.mlp.up_proj.weight', 'model.layers.20.mlp.down_proj.weight', 'model.layers.20.input_layernorm.weight', 'model.layers.20.post_attention_layernorm.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.21.mlp.gate_proj.weight', 'model.layers.21.mlp.up_proj.weight', 'model.layers.21.mlp.down_proj.weight', 'model.layers.21.input_layernorm.weight', 'model.layers.21.post_attention_layernorm.weight', 'model.layers.22.self_attn.q_proj.weight', 'model.layers.22.self_attn.k_proj.weight', 'model.layers.22.self_attn.v_proj.weight', 'model.layers.22.self_attn.o_proj.weight', 'model.layers.22.mlp.gate_proj.weight', 'model.layers.22.mlp.up_proj.weight', 'model.layers.22.mlp.down_proj.weight', 'model.layers.22.input_layernorm.weight', 'model.layers.22.post_attention_layernorm.weight', 'model.layers.23.self_attn.q_proj.weight', 'model.layers.23.self_attn.k_proj.weight', 'model.layers.23.self_attn.v_proj.weight', 'model.layers.23.self_attn.o_proj.weight', 'model.layers.23.mlp.gate_proj.weight', 'model.layers.23.mlp.up_proj.weight', 'model.layers.23.mlp.down_proj.weight', 'model.layers.23.input_layernorm.weight', 'model.layers.23.post_attention_layernorm.weight', 'model.layers.24.self_attn.q_proj.weight', 'model.layers.24.self_attn.k_proj.weight', 'model.layers.24.self_attn.v_proj.weight', 'model.layers.24.self_attn.o_proj.weight', 'model.layers.24.mlp.gate_proj.weight', 'model.layers.24.mlp.up_proj.weight', 'model.layers.24.mlp.down_proj.weight', 'model.layers.24.input_layernorm.weight', 'model.layers.24.post_attention_layernorm.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight', 'model.layers.25.mlp.gate_proj.weight', 'model.layers.25.mlp.up_proj.weight', 'model.layers.25.mlp.down_proj.weight', 'model.layers.25.input_layernorm.weight', 'model.layers.25.post_attention_layernorm.weight', 'model.layers.26.self_attn.q_proj.weight', 'model.layers.26.self_attn.k_proj.weight', 'model.layers.26.self_attn.v_proj.weight', 'model.layers.26.self_attn.o_proj.weight', 'model.layers.26.mlp.gate_proj.weight', 'model.layers.26.mlp.up_proj.weight', 'model.layers.26.mlp.down_proj.weight', 'model.layers.26.input_layernorm.weight', 'model.layers.26.post_attention_layernorm.weight', 'model.layers.27.self_attn.q_proj.weight', 'model.layers.27.self_attn.k_proj.weight', 'model.layers.27.self_attn.v_proj.weight', 'model.layers.27.self_attn.o_proj.weight', 'model.layers.27.mlp.gate_proj.weight', 'model.layers.27.mlp.up_proj.weight', 'model.layers.27.mlp.down_proj.weight', 'model.layers.27.input_layernorm.weight', 'model.layers.27.post_attention_layernorm.weight', 'model.layers.28.self_attn.q_proj.weight', 'model.layers.28.self_attn.k_proj.weight', 
'model.layers.28.self_attn.v_proj.weight', 'model.layers.28.self_attn.o_proj.weight', 'model.layers.28.mlp.gate_proj.weight', 'model.layers.28.mlp.up_proj.weight', 'model.layers.28.mlp.down_proj.weight', 'model.layers.28.input_layernorm.weight', 'model.layers.28.post_attention_layernorm.weight', 'model.layers.29.self_attn.q_proj.weight', 'model.layers.29.self_attn.k_proj.weight', 'model.layers.29.self_attn.v_proj.weight', 'model.layers.29.self_attn.o_proj.weight', 'model.layers.29.mlp.gate_proj.weight', 'model.layers.29.mlp.up_proj.weight', 'model.layers.29.mlp.down_proj.weight', 'model.layers.29.input_layernorm.weight', 'model.layers.29.post_attention_layernorm.weight', 'model.layers.30.self_attn.q_proj.weight', 'model.layers.30.self_attn.k_proj.weight', 'model.layers.30.self_attn.v_proj.weight', 'model.layers.30.self_attn.o_proj.weight', 'model.layers.30.mlp.gate_proj.weight', 'model.layers.30.mlp.up_proj.weight', 'model.layers.30.mlp.down_proj.weight', 'model.layers.30.input_layernorm.weight', 'model.layers.30.post_attention_layernorm.weight', 'model.layers.31.self_attn.q_proj.weight', 'model.layers.31.self_attn.k_proj.weight', 'model.layers.31.self_attn.v_proj.weight', 'model.layers.31.self_attn.o_proj.weight', 'model.layers.31.mlp.gate_proj.weight', 'model.layers.31.mlp.up_proj.weight', 'model.layers.31.mlp.down_proj.weight', 'model.layers.31.input_layernorm.weight', 'model.layers.31.post_attention_layernorm.weight', 'model.layers.32.self_attn.q_proj.weight', 'model.layers.32.self_attn.k_proj.weight', 'model.layers.32.self_attn.v_proj.weight', 'model.layers.32.self_attn.o_proj.weight', 'model.layers.32.mlp.gate_proj.weight', 'model.layers.32.mlp.up_proj.weight', 'model.layers.32.mlp.down_proj.weight', 'model.layers.32.input_layernorm.weight', 'model.layers.32.post_attention_layernorm.weight', 'model.layers.33.self_attn.q_proj.weight', 'model.layers.33.self_attn.k_proj.weight', 'model.layers.33.self_attn.v_proj.weight', 'model.layers.33.self_attn.o_proj.weight', 'model.layers.33.mlp.gate_proj.weight', 'model.layers.33.mlp.up_proj.weight', 'model.layers.33.mlp.down_proj.weight', 'model.layers.33.input_layernorm.weight', 'model.layers.33.post_attention_layernorm.weight', 'model.layers.34.self_attn.q_proj.weight', 'model.layers.34.self_attn.k_proj.weight', 'model.layers.34.self_attn.v_proj.weight', 'model.layers.34.self_attn.o_proj.weight', 'model.layers.34.mlp.gate_proj.weight', 'model.layers.34.mlp.up_proj.weight', 'model.layers.34.mlp.down_proj.weight', 'model.layers.34.input_layernorm.weight', 'model.layers.34.post_attention_layernorm.weight', 'model.layers.35.self_attn.q_proj.weight', 'model.layers.35.self_attn.k_proj.weight', 'model.layers.35.self_attn.v_proj.weight', 'model.layers.35.self_attn.o_proj.weight', 'model.layers.35.mlp.gate_proj.weight', 'model.layers.35.mlp.up_proj.weight', 'model.layers.35.mlp.down_proj.weight', 'model.layers.35.input_layernorm.weight', 'model.layers.35.post_attention_layernorm.weight', 'model.layers.36.self_attn.q_proj.weight', 'model.layers.36.self_attn.k_proj.weight', 'model.layers.36.self_attn.v_proj.weight', 'model.layers.36.self_attn.o_proj.weight', 'model.layers.36.mlp.gate_proj.weight', 'model.layers.36.mlp.up_proj.weight', 'model.layers.36.mlp.down_proj.weight', 'model.layers.36.input_layernorm.weight', 'model.layers.36.post_attention_layernorm.weight', 'model.layers.37.self_attn.q_proj.weight', 'model.layers.37.self_attn.k_proj.weight', 'model.layers.37.self_attn.v_proj.weight', 'model.layers.37.self_attn.o_proj.weight', 
'model.layers.37.mlp.gate_proj.weight', 'model.layers.37.mlp.up_proj.weight', 'model.layers.37.mlp.down_proj.weight', 'model.layers.37.input_layernorm.weight', 'model.layers.37.post_attention_layernorm.weight', 'model.layers.38.self_attn.q_proj.weight', 'model.layers.38.self_attn.k_proj.weight', 'model.layers.38.self_attn.v_proj.weight', 'model.layers.38.self_attn.o_proj.weight', 'model.layers.38.mlp.gate_proj.weight', 'model.layers.38.mlp.up_proj.weight', 'model.layers.38.mlp.down_proj.weight', 'model.layers.38.input_layernorm.weight', 'model.layers.38.post_attention_layernorm.weight', 'model.layers.39.self_attn.q_proj.weight', 'model.layers.39.self_attn.k_proj.weight', 'model.layers.39.self_attn.v_proj.weight', 'model.layers.39.self_attn.o_proj.weight', 'model.layers.39.mlp.gate_proj.weight', 'model.layers.39.mlp.up_proj.weight', 'model.layers.39.mlp.down_proj.weight', 'model.layers.39.input_layernorm.weight', 'model.layers.39.post_attention_layernorm.weight', 'model.layers.40.self_attn.q_proj.weight', 'model.layers.40.self_attn.k_proj.weight', 'model.layers.40.self_attn.v_proj.weight', 'model.layers.40.self_attn.o_proj.weight', 'model.layers.40.mlp.gate_proj.weight', 'model.layers.40.mlp.up_proj.weight', 'model.layers.40.mlp.down_proj.weight', 'model.layers.40.input_layernorm.weight', 'model.layers.40.post_attention_layernorm.weight', 'model.layers.41.self_attn.q_proj.weight', 'model.layers.41.self_attn.k_proj.weight', 'model.layers.41.self_attn.v_proj.weight', 'model.layers.41.self_attn.o_proj.weight', 'model.layers.41.mlp.gate_proj.weight', 'model.layers.41.mlp.up_proj.weight', 'model.layers.41.mlp.down_proj.weight', 'model.layers.41.input_layernorm.weight', 'model.layers.41.post_attention_layernorm.weight', 'model.layers.42.self_attn.q_proj.weight', 'model.layers.42.self_attn.k_proj.weight', 'model.layers.42.self_attn.v_proj.weight', 'model.layers.42.self_attn.o_proj.weight', 'model.layers.42.mlp.gate_proj.weight', 'model.layers.42.mlp.up_proj.weight', 'model.layers.42.mlp.down_proj.weight', 'model.layers.42.input_layernorm.weight', 'model.layers.42.post_attention_layernorm.weight', 'model.layers.43.self_attn.q_proj.weight', 'model.layers.43.self_attn.k_proj.weight', 'model.layers.43.self_attn.v_proj.weight', 'model.layers.43.self_attn.o_proj.weight', 'model.layers.43.mlp.gate_proj.weight', 'model.layers.43.mlp.up_proj.weight', 'model.layers.43.mlp.down_proj.weight', 'model.layers.43.input_layernorm.weight', 'model.layers.43.post_attention_layernorm.weight', 'model.layers.44.self_attn.q_proj.weight', 'model.layers.44.self_attn.k_proj.weight', 'model.layers.44.self_attn.v_proj.weight', 'model.layers.44.self_attn.o_proj.weight', 'model.layers.44.mlp.gate_proj.weight', 'model.layers.44.mlp.up_proj.weight', 'model.layers.44.mlp.down_proj.weight', 'model.layers.44.input_layernorm.weight', 'model.layers.44.post_attention_layernorm.weight', 'model.layers.45.self_attn.q_proj.weight', 'model.layers.45.self_attn.k_proj.weight', 'model.layers.45.self_attn.v_proj.weight', 'model.layers.45.self_attn.o_proj.weight', 'model.layers.45.mlp.gate_proj.weight', 'model.layers.45.mlp.up_proj.weight', 'model.layers.45.mlp.down_proj.weight', 'model.layers.45.input_layernorm.weight', 'model.layers.45.post_attention_layernorm.weight', 'model.layers.46.self_attn.q_proj.weight', 'model.layers.46.self_attn.k_proj.weight', 'model.layers.46.self_attn.v_proj.weight', 'model.layers.46.self_attn.o_proj.weight', 'model.layers.46.mlp.gate_proj.weight', 'model.layers.46.mlp.up_proj.weight', 
'model.layers.46.mlp.down_proj.weight', 'model.layers.46.input_layernorm.weight', 'model.layers.46.post_attention_layernorm.weight', 'model.layers.47.self_attn.q_proj.weight', 'model.layers.47.self_attn.k_proj.weight', 'model.layers.47.self_attn.v_proj.weight', 'model.layers.47.self_attn.o_proj.weight', 'model.layers.47.mlp.gate_proj.weight', 'model.layers.47.mlp.up_proj.weight', 'model.layers.47.mlp.down_proj.weight', 'model.layers.47.input_layernorm.weight', 'model.layers.47.post_attention_layernorm.weight', 'model.layers.48.self_attn.q_proj.weight', 'model.layers.48.self_attn.k_proj.weight', 'model.layers.48.self_attn.v_proj.weight', 'model.layers.48.self_attn.o_proj.weight', 'model.layers.48.mlp.gate_proj.weight', 'model.layers.48.mlp.up_proj.weight', 'model.layers.48.mlp.down_proj.weight', 'model.layers.48.input_layernorm.weight', 'model.layers.48.post_attention_layernorm.weight', 'model.layers.49.self_attn.q_proj.weight', 'model.layers.49.self_attn.k_proj.weight', 'model.layers.49.self_attn.v_proj.weight', 'model.layers.49.self_attn.o_proj.weight', 'model.layers.49.mlp.gate_proj.weight', 'model.layers.49.mlp.up_proj.weight', 'model.layers.49.mlp.down_proj.weight', 'model.layers.49.input_layernorm.weight', 'model.layers.49.post_attention_layernorm.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight'])" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[6], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Load state dict into hybrid model from Thinker\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m model_base \u001b[38;5;241m=\u001b[39m \u001b[43mMistralForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath_thinker\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m model_hybrid\u001b[38;5;241m.\u001b[39mload_state_dict(model_base\u001b[38;5;241m.\u001b[39mstate_dict(), strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:311\u001b[0m, in \u001b[0;36mrestore_default_torch_dtype.._wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 309\u001b[0m old_dtype \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mget_default_dtype()\n\u001b[1;32m 310\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 311\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 313\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_default_dtype(old_dtype)\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:4839\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, 
force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 4829\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype_orig \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 4830\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_default_dtype(dtype_orig)\n\u001b[1;32m 4832\u001b[0m (\n\u001b[1;32m 4833\u001b[0m model,\n\u001b[1;32m 4834\u001b[0m missing_keys,\n\u001b[1;32m 4835\u001b[0m unexpected_keys,\n\u001b[1;32m 4836\u001b[0m mismatched_keys,\n\u001b[1;32m 4837\u001b[0m offload_index,\n\u001b[1;32m 4838\u001b[0m error_msgs,\n\u001b[0;32m-> 4839\u001b[0m ) \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_load_pretrained_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 4840\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4841\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4842\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheckpoint_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4843\u001b[0m \u001b[43m \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4844\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_mismatched_sizes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_mismatched_sizes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4845\u001b[0m \u001b[43m \u001b[49m\u001b[43msharded_metadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msharded_metadata\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4846\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4847\u001b[0m \u001b[43m \u001b[49m\u001b[43mdisk_offload_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moffload_folder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4848\u001b[0m \u001b[43m \u001b[49m\u001b[43moffload_state_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moffload_state_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4849\u001b[0m \u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch_dtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4850\u001b[0m \u001b[43m \u001b[49m\u001b[43mhf_quantizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhf_quantizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4851\u001b[0m \u001b[43m \u001b[49m\u001b[43mkeep_in_fp32_regex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_in_fp32_regex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4852\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_mesh\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_mesh\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4853\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey_mapping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkey_mapping\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4854\u001b[0m \u001b[43m \u001b[49m\u001b[43mweights_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweights_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 4855\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4857\u001b[0m \u001b[38;5;66;03m# record tp degree the model sharded to\u001b[39;00m\n\u001b[1;32m 4858\u001b[0m model\u001b[38;5;241m.\u001b[39m_tp_size \u001b[38;5;241m=\u001b[39m tp_size\n", + "File 
\u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:5302\u001b[0m, in \u001b[0;36mPreTrainedModel._load_pretrained_model\u001b[0;34m(cls, model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, device_map, disk_offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_regex, device_mesh, key_mapping, weights_only)\u001b[0m\n\u001b[1;32m 5299\u001b[0m args_list \u001b[38;5;241m=\u001b[39m logging\u001b[38;5;241m.\u001b[39mtqdm(args_list, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLoading checkpoint shards\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 5301\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m args \u001b[38;5;129;01min\u001b[39;00m args_list:\n\u001b[0;32m-> 5302\u001b[0m _error_msgs, disk_offload_index, cpu_offload_index \u001b[38;5;241m=\u001b[39m \u001b[43mload_shard_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5303\u001b[0m error_msgs \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m _error_msgs\n\u001b[1;32m 5305\u001b[0m \u001b[38;5;66;03m# Adjust offloaded weights name and save if needed\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:933\u001b[0m, in \u001b[0;36mload_shard_file\u001b[0;34m(args)\u001b[0m\n\u001b[1;32m 931\u001b[0m \u001b[38;5;66;03m# Skip it with fsdp on ranks other than 0\u001b[39;00m\n\u001b[1;32m 932\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (is_fsdp_enabled() \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_local_dist_rank_0() \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_quantized):\n\u001b[0;32m--> 933\u001b[0m disk_offload_index, cpu_offload_index \u001b[38;5;241m=\u001b[39m \u001b[43m_load_state_dict_into_meta_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 934\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_to_load\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 935\u001b[0m \u001b[43m \u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 936\u001b[0m \u001b[43m \u001b[49m\u001b[43mshard_file\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 937\u001b[0m \u001b[43m \u001b[49m\u001b[43mexpected_keys\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 938\u001b[0m \u001b[43m \u001b[49m\u001b[43mreverse_key_renaming_mapping\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 939\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 940\u001b[0m \u001b[43m \u001b[49m\u001b[43mdisk_offload_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdisk_offload_folder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 941\u001b[0m \u001b[43m \u001b[49m\u001b[43mdisk_offload_index\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdisk_offload_index\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 942\u001b[0m \u001b[43m \u001b[49m\u001b[43mcpu_offload_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcpu_offload_folder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 943\u001b[0m \u001b[43m \u001b[49m\u001b[43mcpu_offload_index\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcpu_offload_index\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 944\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mhf_quantizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhf_quantizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 945\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_safetensors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_offloaded_safetensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 946\u001b[0m \u001b[43m \u001b[49m\u001b[43mkeep_in_fp32_regex\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_in_fp32_regex\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 947\u001b[0m \u001b[43m \u001b[49m\u001b[43munexpected_keys\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43munexpected_keys\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 948\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice_mesh\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_mesh\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 949\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 951\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m error_msgs, disk_offload_index, cpu_offload_index\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/fast_llm/lib/python3.12/site-packages/transformers/modeling_utils.py:810\u001b[0m, in \u001b[0;36m_load_state_dict_into_meta_model\u001b[0;34m(model, state_dict, shard_file, expected_keys, reverse_renaming_mapping, device_map, disk_offload_folder, disk_offload_index, cpu_offload_folder, cpu_offload_index, hf_quantizer, is_safetensors, keep_in_fp32_regex, unexpected_keys, device_mesh)\u001b[0m\n\u001b[1;32m 808\u001b[0m param \u001b[38;5;241m=\u001b[39m param[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]\n\u001b[1;32m 809\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m casting_dtype \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 810\u001b[0m param \u001b[38;5;241m=\u001b[39m \u001b[43mparam\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcasting_dtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 811\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m to_contiguous:\n\u001b[1;32m 812\u001b[0m param \u001b[38;5;241m=\u001b[39m param\u001b[38;5;241m.\u001b[39mcontiguous()\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] } ], "source": [ @@ -1050,7 +1065,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -1125,14 +1140,14 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 
100%|██████████| 7/7 [00:05<00:00, 1.22it/s]\n" + "Loading checkpoint shards: 100%|██████████| 7/7 [00:07<00:00, 1.01s/it]\n" ] }, { @@ -1146,17 +1161,21 @@ "Converting layer %d... 2\n", "Skipping transformer layer 2...\n", "Converting layer %d... 3\n", - "Skipping transformer layer 3...\n", + "Converting layer 3...\n", + "Init Mamba using Attention\n", "Converting layer %d... 4\n", "Skipping transformer layer 4...\n", "Converting layer %d... 5\n", - "Skipping transformer layer 5...\n", + "Converting layer 5...\n", + "Init Mamba using Attention\n", "Converting layer %d... 6\n", - "Skipping transformer layer 6...\n", + "Converting layer 6...\n", + "Init Mamba using Attention\n", "Converting layer %d... 7\n", "Skipping transformer layer 7...\n", "Converting layer %d... 8\n", - "Skipping transformer layer 8...\n", + "Converting layer 8...\n", + "Init Mamba using Attention\n", "Converting layer %d... 9\n", "Skipping transformer layer 9...\n", "Converting layer %d... 10\n", @@ -1277,7 +1296,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -1286,7 +1305,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -1306,12 +1325,12 @@ " \"t\",\n", " \"t\",\n", " \"t\",\n", + " \"m2\",\n", " \"t\",\n", + " \"m2\",\n", + " \"m2\",\n", " \"t\",\n", - " \"t\",\n", - " \"t\",\n", - " \"t\",\n", - " \"t\",\n", + " \"m2\",\n", " \"t\",\n", " \"t\",\n", " \"t\",\n", @@ -1380,6 +1399,8 @@ " \"dt_rank\": \"auto\",\n", " \"dt_scale\": 1.0,\n", " \"expand\": 1,\n", + " \"head_dim\": 128,\n", + " \"layer_norm_epsilon\": 1e-05,\n", " \"n_qk_heads\": 32,\n", " \"n_v_heads\": 32\n", " },\n", @@ -1391,7 +1412,7 @@ "}" ] }, - "execution_count": 14, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -1402,7 +1423,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -1411,26 +1432,40 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 427.77it/s]\n" + "Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 37.80it/s]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['model.layers.6.mixer.A_log', 'model.layers.6.mixer.D', 'model.layers.6.mixer.conv1d.weight', 'model.layers.6.mixer.conv1d.bias', 'model.layers.6.mixer.in_proj.weight', 'model.layers.6.mixer.dt_in_proj.weight', 'model.layers.6.mixer.dt_proj.weight', 'model.layers.6.mixer.dt_proj.bias', 'model.layers.6.mixer.out_proj.weight', 'model.layers.8.mixer.A_log', 'model.layers.8.mixer.D', 'model.layers.8.mixer.conv1d.weight', 'model.layers.8.mixer.conv1d.bias', 'model.layers.8.mixer.in_proj.weight', 'model.layers.8.mixer.dt_in_proj.weight', 'model.layers.8.mixer.dt_proj.weight', 'model.layers.8.mixer.dt_proj.bias', 'model.layers.8.mixer.out_proj.weight']\n", + "['model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight']\n" ] } ], "source": [ "# load state dict from existing 
pretrained SSM?\n", - "path_25hyb = \"/mnt/checkpoints/ssm/apriel_ssm_thinker5l_hybrid_1ssm_init_rand_debug_tpformat\" #\"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-oshyb25lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6/export/apriel_ssm_thinker_hybrid/5000_new\"\n", + "path_25hyb = path_hybrid #\"/mnt/checkpoints/ssm/apriel_ssm_thinker5l_hybrid_1ssm_init_rand_debug_tpformat\" #\"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-oshyb25lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6/export/apriel_ssm_thinker_hybrid/5000_new\"\n", "model = AprielThinkerSSMHybridForCausalLM.from_pretrained(path_25hyb)\n", "state_dict = model.state_dict()\n", - "\n", - "# missing, unexpected = transformer.load_state_dict(state_dict, strict=False)\n", - "# print(missing)\n", - "# print(unexpected)\n", + "missing, unexpected = transformer.load_state_dict(state_dict, strict=False)\n", + "print(missing)\n", + "print(unexpected)\n", "\n", "\n", "\n", @@ -1450,7 +1485,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1458,7 +1493,10 @@ "# mamba2, state 16, expand 1, i.e. same as M1, but with discrete mamba2 and MIL\n", "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_1ssm_leastimportant_m2_16hexp1_init_mil\") # 1 ssm\n", "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_25ssm_leastimportant_m2_16hexp1_init_mil\") # 25 ssm\n", - "transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_25ssm_leastimportant_m2_16hexp1_init_mil_tpformat\") # 25 ssm\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_25ssm_leastimportant_m2_16hexp1_init_mil_tpformat\") # 25 ssm\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_27ssm_leastimportant_m2_16hexp1_init_hyb25distsftvrlm223k_mil\") # 25 ssm\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_30ssm_leastimportant_m2_16hexp1_init_hyb27distsftvrlm2_3p5ksteps_mil\") # 30 ssm\n", + "transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_29ssm_leastimportant_m2_16hexp1_init_hyb27distsftvrlm2_3p5ksteps_mil\") # 29 ssm\n", "\n", "\n", "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_40ssm_leastimportant_m2_16hexp1_init_mil_uniform_from_25h5000lm6\") # 40 ssm" @@ -1482,60 +1520,1046 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Data mixing" + "# Nemotron-h mamba layer" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "torch.Size([])\n", - "KL (global, F.kl_div) = 0.738795\n", - "KL (sum of shards, manual) = 0.738795\n" + "The autoreload extension is already loaded. 
To reload it, use:\n", + " %reload_ext autoreload\n" ] } ], - "source": [] + "source": [ + "import gc\n", + "\n", + "import click\n", + "import torch\n", + "from transformers import AutoConfig, AutoModelForCausalLM\n", + "from transformers import MistralForCausalLM\n", + "\n", + "from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig\n", + "from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielThinkerSSMHybridForCausalLM, AprielSSMM2DecoderLayer, AprielSSMDecoderLayer, AprielSSMNemotronHM2DecoderLayer\n", + "from fast_llm.models.ssm.external.nemotron.modeling import NemotronHMamba2Mixer as NemotronHMamba2Mixer_original\n", + "from fast_llm.models.ssm.external.nemotron.config import NemotronHConfig\n", + "from transformers.models.mistral.modeling_mistral import MistralDecoderLayer\n", + "\n", + "# enable file reload \n", + "%load_ext autoreload\n", + "%autoreload 2" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_nmtrhnm2_5l_debug\")\n", + "# model = AprielThinkerSSMHybridForCausalLM.from_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_nmtrhnm2_5l_debug\")" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.embed_tokens.weight\n", + "model.layers.0.self_attn.q_proj.weight\n", + "model.layers.0.self_attn.k_proj.weight\n", + "model.layers.0.self_attn.v_proj.weight\n", + "model.layers.0.self_attn.o_proj.weight\n", + "model.layers.0.mlp.gate_proj.weight\n", + "model.layers.0.mlp.up_proj.weight\n", + "model.layers.0.mlp.down_proj.weight\n", + "model.layers.0.input_layernorm.weight\n", + "model.layers.0.post_attention_layernorm.weight\n", + "model.layers.1.self_attn.q_proj.weight\n", + "model.layers.1.self_attn.k_proj.weight\n", + "model.layers.1.self_attn.v_proj.weight\n", + "model.layers.1.self_attn.o_proj.weight\n", + "model.layers.1.mlp.gate_proj.weight\n", + "model.layers.1.mlp.up_proj.weight\n", + "model.layers.1.mlp.down_proj.weight\n", + "model.layers.1.input_layernorm.weight\n", + "model.layers.1.post_attention_layernorm.weight\n", + "model.layers.2.mixer.dt_bias\n", + "model.layers.2.mixer.A_log\n", + "model.layers.2.mixer.D\n", + "model.layers.2.mixer.conv1d.weight\n", + "model.layers.2.mixer.conv1d.bias\n", + "model.layers.2.mixer.in_proj.weight\n", + "model.layers.2.mixer.dt_in_proj.weight\n", + "model.layers.2.mixer.norm.weight\n", + "model.layers.2.mixer.out_proj.weight\n", + "model.layers.2.mlp.gate_proj.weight\n", + "model.layers.2.mlp.up_proj.weight\n", + "model.layers.2.mlp.down_proj.weight\n", + "model.layers.2.input_layernorm.weight\n", + "model.layers.2.post_attention_layernorm.weight\n", + "model.layers.3.mixer.dt_bias\n", + "model.layers.3.mixer.A_log\n", + "model.layers.3.mixer.D\n", + "model.layers.3.mixer.conv1d.weight\n", + "model.layers.3.mixer.conv1d.bias\n", + "model.layers.3.mixer.in_proj.weight\n", + "model.layers.3.mixer.dt_in_proj.weight\n", + "model.layers.3.mixer.norm.weight\n", + "model.layers.3.mixer.out_proj.weight\n", + "model.layers.3.mlp.gate_proj.weight\n", + "model.layers.3.mlp.up_proj.weight\n", + "model.layers.3.mlp.down_proj.weight\n", + "model.layers.3.input_layernorm.weight\n", + 
"model.layers.3.post_attention_layernorm.weight\n", + "model.layers.4.self_attn.q_proj.weight\n", + "model.layers.4.self_attn.k_proj.weight\n", + "model.layers.4.self_attn.v_proj.weight\n", + "model.layers.4.self_attn.o_proj.weight\n", + "model.layers.4.mlp.gate_proj.weight\n", + "model.layers.4.mlp.up_proj.weight\n", + "model.layers.4.mlp.down_proj.weight\n", + "model.layers.4.input_layernorm.weight\n", + "model.layers.4.post_attention_layernorm.weight\n", + "model.norm.weight\n", + "lm_head.weight\n" + ] + } + ], + "source": [ + "for k,v in model.state_dict().items():\n", + " print(k)" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 36, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Instantiating Mamba with num_heads: 32, head_dim: 128, \n", + " intermediate_size: 4096, \n", + " d_xb: 1024, \n", + " number_xb_heads: 8, \n", + " number_heads: 32, \n", + " repeat_groups: 4, \n", + " d_state: 128, \n", + " n_groups: 32\n", + "Instantiating Mamba with num_heads: 32, head_dim: 128, \n", + " intermediate_size: 4096, \n", + " d_xb: 1024, \n", + " number_xb_heads: 8, \n", + " number_heads: 32, \n", + " repeat_groups: 4, \n", + " d_state: 128, \n", + " n_groups: 32\n" + ] + } + ], + "source": [ + "path_hybrid = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-hyb25distsftvrlm2-bs64-lr5e-06-lrs1-1-1-1-sl16384_ti60000_aprsft/export/apriel_ssm_thinker_hybrid/23000\"\n", + "path_thinker = \"/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker\"\n", + "\n", + "config_thinker = AutoConfig.from_pretrained(path_thinker)\n", + "\n", + "# config_hybrid = AprielSSMHybridConfig.from_pretrained(path_hybrid)\n", + "config_thinker.num_hidden_layers = 5\n", + "hybrid_block_layout = [\"t\"] * config_thinker.num_hidden_layers\n", + "# debug\n", + "hybrid_block_layout[2] = \"nm2\"\n", + "hybrid_block_layout[3] = \"nm2\"\n", + "\n", + "# 25/50\n", + "# hybrid_block_layout = [\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"t\",\n", + "# \"nm2\",\n", + "# \"t\",\n", + "# \"t\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"t\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"nm2\",\n", + "# \"t\",\n", + "# \"nm2\"\n", + "# ]\n", + " \n", + "ssm_config = {\n", + " # \"d_state\": 16,\n", + " # \"d_xb\": 1024,\n", + " # \"expand\": 1,\n", + " # \"d_conv\": 4,\n", + " # \"d_inner\": 4096,\n", + " # \"conv_bias\": True,\n", + " # \"bias\": False,\n", + " # \"head_dim\": 16, # 4096/16 = 32 heads, 1024/128 = 8 KVheads and 4 repeat groups\n", + " \"d_state\": 128,\n", + " \"d_xb\": 1024, # if we have 8 kv heads, 128 * 8 = 1024\n", + " \"d_conv\": 4,\n", + " \"conv_bias\": True,\n", + " \"bias\": False,\n", + " \"head_dim\": 128,\n", + " # \"num_heads\": 32, # intermediate_size/d_inner = 32 * 128 = 4096\n", + " \"n_groups\": 32 # just ofor C\n", + " \n", + "}\n", + "config_thinker.hybrid_block_layout = 
hybrid_block_layout\n", + "# config_thinker.ssm_cfg = ssm_config\n", + "model = AprielThinkerSSMHybridForCausalLM(config_hybrid)\n", + "# mixer = NemotronHMamba2Mixer(d_model=4096, **ssm_config)\n", + "\n", + "\n", + "\n", + "config_hybrid = AprielSSMHybridConfig(\n", + " **config_thinker.to_dict(),\n", + " ssm_cfg=ssm_config\n", + ")\n" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 32, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "data": { + "text/plain": [ + "5" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(config_hybrid.hybrid_block_layout)" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 33, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "\n", + "def convert_layers(transformer, mamba_config, hybrid_block_layout, init_with_kqvo, attn_bias, torch_dtype):\n", + " config = transformer.config\n", + " embed_dim = config.hidden_size\n", + " num_heads = config.num_attention_heads\n", + " num_heads_kv = config.num_key_value_heads\n", + " head_dim = embed_dim // num_heads\n", + " q_dim = head_dim * num_heads\n", + " kv_dim = head_dim * num_heads_kv\n", + "\n", + " for layer_idx, type in enumerate(hybrid_block_layout):\n", + " print(\"Converting layer %d...\", layer_idx)\n", + " # Fetch the layer module for easier access\n", + " layer_module = transformer.model.layers._modules[f\"{layer_idx}\"]\n", + " if type == \"t\":\n", + " print(\"Skipping transformer layer %d...\" % layer_idx)\n", + " elif type == \"nm2\":\n", + " print(\"Converting layer %d...\" % layer_idx)\n", + " # Use MambaDecoderLayer for the remaining layers\n", + " mamba_encoder = AprielSSMNemotronHM2DecoderLayer(\n", + " mamba_config,\n", + " layer_idx,\n", + " device=\"cpu\",\n", + " dtype=torch_dtype,\n", + " )\n", + " \n", + " mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict())\n", + " mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict())\n", + " mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict())\n", + " mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict())\n", + "\n", + " num_xb_heads = mamba_config.ssm_cfg[\"d_xb\"] // mamba_config.ssm_cfg[\"head_dim\"]\n", + " num_heads = mamba_config.ssm_cfg[\"d_inner\"] // mamba_config.ssm_cfg[\"head_dim\"]\n", + "\n", + " if init_with_kqvo:\n", + " # # Copy weights: [z, x, B, C, dt], x -> v, B -> k, C -> q\n", + " # mamba_encoder.mixer.in_proj.weight.data[\n", + " # mamba_config.ssm_cfg[\"d_inner\"] : mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"], :\n", + " # ].copy_(layer_module.self_attn.v_proj.weight.data)\n", + " # mamba_encoder.mixer.in_proj.weight.data[\n", + " # mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] : mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] + (num_xb_heads * mamba_config.ssm_cfg[\"d_state\"]), :\n", + " # ].copy_(layer_module.self_attn.k_proj.weight.data)\n", + " # mamba_encoder.mixer.in_proj.weight.data[\n", + " # mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] + (num_xb_heads * mamba_config.ssm_cfg[\"d_state\"]) : mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] + (num_xb_heads * mamba_config.ssm_cfg[\"d_state\"]) + (num_heads * mamba_config.ssm_cfg[\"d_state\"]), :\n", + " # ].copy_(layer_module.self_attn.q_proj.weight.data)\n", + " 
mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] : mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"], :\n", + " ].copy_(layer_module.self_attn.v_proj.weight.data)\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] : mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] + mamba_config.ssm_cfg[\"d_xb\"], :\n", + " ].copy_(layer_module.self_attn.k_proj.weight.data)\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] + mamba_config.ssm_cfg[\"d_xb\"] : mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] + mamba_config.ssm_cfg[\"d_xb\"] + (num_heads * mamba_config.ssm_cfg[\"d_state\"]), :\n", + " ].copy_(layer_module.self_attn.q_proj.weight.data)\n", + "\n", + " print(\"Init Mamba using Attention\")\n", + "\n", + " transformer.model.layers[layer_idx] = mamba_encoder\n", + "\n", + " # elif type == \"m2d\":\n", + " # print(\"Converting layer %d...\" % layer_idx)\n", + " # mamba_encoder = AprielSSMDecoderLayer(\n", + " # mamba_config,\n", + " # layer_idx,\n", + " # device=\"cpu\",\n", + " # dtype=torch_dtype,\n", + " # )\n", + " # mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict())\n", + " # mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict())\n", + " # mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict())\n", + " # mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict())\n", + "\n", + " # if init_with_kqvo:\n", + " \n", + "\n", + "\n", + " \n", + " else:\n", + " raise ValueError(f\"Invalid layer type: {type}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 0%| | 0/7 [00:00)}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x_dummy = torch.randn(1, 1024, 5120).to(\"cuda\")\n", + "cache = HybridMambaAttentionDynamicCache(apriel_config, 1,)\n", + "mamba2(x_dummy, cache)" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [], + "source": [ + "config = NemotronHConfig(\n", + " num_hidden_layers=6,\n", + " hybrid_override_pattern=\"M-M-*-\",\n", + " ssm_state_size=128,\n", + " mamba_num_heads=128,\n", + " # mamba_n_groups=32,\n", + " n_groups=8,\n", + " mamba_head_dim=64,\n", + " mamba_d_conv=4,\n", + " mamba_expand=2,\n", + " mamba_hidden_act=\"silu\",\n", + " mamba_dt_min=0.001,\n", + " hidden_size=5120,\n", + " mlp_hidden_act=\"silu\",\n", + " intermediate_size=14336\n", + ")\n", + "# config = NemotronHConfig(\n", + "# num_hidden_layers=6,\n", + "# hybrid_override_pattern=\"M-M-*-\",\n", + "# ssm_state_size=16,\n", + "# mamba_num_heads=32,\n", + "# n_groups=32,\n", + "# mamba_head_dim=128,\n", + "# mamba_d_conv=4,\n", + "# mamba_expand=1,\n", + "# mamba_hidden_act=\"silu\",\n", + "# mamba_dt_min=0.001,\n", + "# hidden_size=5120,\n", + "# mlp_hidden_act=\"silu\",\n", + "# intermediate_size=14336\n", + "# )\n", + "nemotron_h = NemotronHForCausalLM(config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "NemotronHForCausalLM(\n", + " (backbone): NemotronHModel(\n", + " (embeddings): Embedding(131072, 5120)\n", + " (layers): 
ModuleList(\n", + " (0): NemotronHBlock(\n", + " (norm): NemotronHRMSNorm()\n", + " (mixer): NemotronHMamba2Mixer(\n", + " (act): SiLU()\n", + " (conv1d): Conv1d(10240, 10240, kernel_size=(4,), stride=(1,), padding=(3,), groups=10240)\n", + " (in_proj): Linear(in_features=5120, out_features=18560, bias=False)\n", + " (norm): MambaRMSNormGated()\n", + " (out_proj): Linear(in_features=8192, out_features=5120, bias=False)\n", + " )\n", + " )\n", + " (1): NemotronHBlock(\n", + " (norm): NemotronHRMSNorm()\n", + " (mixer): NemotronHMLP(\n", + " (up_proj): Linear(in_features=5120, out_features=14336, bias=False)\n", + " (down_proj): Linear(in_features=14336, out_features=5120, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " )\n", + " (2): NemotronHBlock(\n", + " (norm): NemotronHRMSNorm()\n", + " (mixer): NemotronHMamba2Mixer(\n", + " (act): SiLU()\n", + " (conv1d): Conv1d(10240, 10240, kernel_size=(4,), stride=(1,), padding=(3,), groups=10240)\n", + " (in_proj): Linear(in_features=5120, out_features=18560, bias=False)\n", + " (norm): MambaRMSNormGated()\n", + " (out_proj): Linear(in_features=8192, out_features=5120, bias=False)\n", + " )\n", + " )\n", + " (3): NemotronHBlock(\n", + " (norm): NemotronHRMSNorm()\n", + " (mixer): NemotronHMLP(\n", + " (up_proj): Linear(in_features=5120, out_features=14336, bias=False)\n", + " (down_proj): Linear(in_features=14336, out_features=5120, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " )\n", + " (4): NemotronHBlock(\n", + " (norm): NemotronHRMSNorm()\n", + " (mixer): NemotronHAttention(\n", + " (q_proj): Linear(in_features=5120, out_features=4096, bias=False)\n", + " (k_proj): Linear(in_features=5120, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=5120, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=4096, out_features=5120, bias=False)\n", + " )\n", + " )\n", + " (5): NemotronHBlock(\n", + " (norm): NemotronHRMSNorm()\n", + " (mixer): NemotronHMLP(\n", + " (up_proj): Linear(in_features=5120, out_features=14336, bias=False)\n", + " (down_proj): Linear(in_features=14336, out_features=5120, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " )\n", + " )\n", + " (norm_f): NemotronHRMSNorm()\n", + " )\n", + " (lm_head): Linear(in_features=5120, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nemotron_h" + ] } ], "metadata": { diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py index 98d2fc28d..c179caa20 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/configuration_ssm_hybrid_apriel15b.py @@ -23,6 +23,12 @@ "dt_scale": 1.0, "dt_init_floor": 1e-4, "conv_bias": True, + # nemotron mamba2 + "head_dim": 128, + "n_groups": 128, + "norm_before_gate": True, + # "num_heads": 128, + # "layer_norm_epsilon": 1e-5, } @@ -31,6 +37,10 @@ class AprielSSMHybridConfig(MistralConfig): def __init__(self, hybrid_block_layout=["m2d"], ssm_cfg=None, **kwargs): super().__init__(**kwargs) + kwargs["sliding_window"] = None + logger.warning( + "Note! 
Sliding_window is not supported for AprielSSMHybridConfig (15B thinker does not use sliding window), setting to None" + ) self.hybrid_block_layout = hybrid_block_layout self.head_dim = self.head_dim or self.hidden_size // self.num_attention_heads # as in transformers 4.51.3 self.ssm_cfg = ssm_cfg or ssm_config_default diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index 9f4588a29..9b8052551 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -8,6 +8,7 @@ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update from einops import rearrange, repeat from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined from torch import nn @@ -202,14 +203,14 @@ def update_conv_state( self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False ) -> torch.Tensor: if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[0].device) else: self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) return self.conv_states[layer_idx] def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[0].device) return self.ssm_states[layer_idx] def reset(self): @@ -217,7 +218,8 @@ def reset(self): self.ssm_states.zero_() -# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +# Copied from https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K/blob/main/modeling_nemotron_h.py whichis taken from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py class HybridMambaAttentionDynamicCache(DynamicCache): """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache @@ -242,9 +244,9 @@ def __init__(self, config: AprielSSMHybridConfig, batch_size, dtype=torch.float1 else config.ssm_cfg["expand"] * config.hidden_size ) ssm_state_size = config.ssm_cfg["d_state"] - conv_kernel_size = config.ssm_cfg["d_conv"] + self.conv_kernel_size = conv_kernel_size = config.ssm_cfg["d_conv"] self.n_qk_heads = config.ssm_cfg["n_qk_heads"] - self.num_C_head = intermediate_size // ssm_state_size # mamba2 + self.num_C_head = intermediate_size // ssm_state_size # Mamba assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads" self.head_d = intermediate_size // self.n_qk_heads self.conv_states = [] @@ -341,14 +343,14 @@ def update_conv_state( self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False ) -> torch.Tensor: if cache_init: - self.conv_states[layer_idx] = new_conv_state.to(self.conv_states.device) + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[0].device) else: self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) - 
self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states.device) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[0].device) return self.conv_states[layer_idx] def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): - self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device) + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[0].device) return self.ssm_states[layer_idx] def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: @@ -783,7 +785,437 @@ def convolutional_step(self, xBC, conv_state): return xBC, conv_state +def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +class MambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, group_size, eps=1e-5, norm_before_gate=True): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.group_size = group_size + self.norm_before_gate = norm_before_gate + + # jan28b version + def forward(self, hidden_states, gate=None): + return rmsnorm_fn( + x=hidden_states, + weight=self.weight, + bias=None, # No bias + z=gate, + eps=self.variance_epsilon, + group_size=self.group_size, + norm_before_gate=self.norm_before_gate, + ) + + class Mamba2(nn.Module): + """ + Mix of https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K/blob/main/modeling_nemotron_h.py. and https://github.com/jxiw/M1/blob/main/mamba2/hybrid_mamba_layer.py + + Note: we assume num_heads=n_groups here, hence we do not share B or C over heads. + Why: we use MIL init from a GQA model like this: x -> V, B -> K, C -> Q + """ + + def __init__( + self, + d_model, + d_xb=None, # to mimic GQA, i.e. if we have e.g. 32 heads and 8 kv heads, i.e. 
4 GQA groups, then d_xb should be head_dim * 8 (so that we can MIL init B and C from Ks and Vs)
+        d_state=128,
+        d_conv=4,
+        head_dim=128,
+        layer_norm_epsilon=1e-5,
+        conv_bias=True,
+        chunk_size=128,
+        bias=False,
+        layer_idx=None,
+        n_groups=1,
+        norm_before_gate=False,
+        # num_heads=128,
+        # device=None,
+        # dtype=None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.hidden_size = d_model
+        self.ssm_state_size = d_state
+        self.conv_kernel_size = d_conv
+        # self.expand = expand
+        # self.intermediate_size = (
+        #     d_inner if d_inner is not None else d_model * expand
+        # )  #
+
+        assert head_dim == d_state, "head_dim must be equal to d_state"
+
+        self.d_xb = d_xb if d_xb is not None else self.intermediate_size
+        self.layer_idx = layer_idx
+        self.use_conv_bias = conv_bias
+        self.activation = "silu"
+        self.act = nn.SiLU()
+        self.head_dim = head_dim
+        self.num_heads = n_groups  # to make sure B and C are same shape, so we have separate B and C per head #self.intermediate_size // self.head_dim
+        self.intermediate_size = self.num_heads * head_dim
+        assert self.intermediate_size % self.head_dim == 0, "intermediate_size must be divisible by head_dim"
+
+        # for GQA simulation, where we repeat x and B for each GQA group
+        self.num_xb_heads = self.d_xb // self.head_dim  # how many heads inside one xb
+        assert self.num_heads % self.num_xb_heads == 0, "num_heads must be divisible by num_xb_heads"
+        self.repeat_xb_groups = self.num_heads // self.num_xb_heads  # num. GQA groups
+        if self.d_xb == self.intermediate_size:
+            assert self.repeat_xb_groups == 1
+            logger.warning(
+                f"d_xb == intermediate_size, d_xb: {self.d_xb}, intermediate_size: {self.intermediate_size}, repeat_groups: {self.repeat_xb_groups}"
+            )
+
+        self.layer_norm_epsilon = layer_norm_epsilon
+        # nemotron allows for any num_groups, we use the same as num_heads for now, otherwise it becomes too complicated with GQA simulation
+        # this would mean we share B and C
+        self.n_groups = (
+            # self.num_heads
+            n_groups
+        )
+        self.chunk_size = chunk_size
+
+        logger.warning(
+            f"Instantiating Mamba with num_heads: {self.num_heads}, head_dim: {self.head_dim}, \n \
+            intermediate_size: {self.intermediate_size}, \n \
+            d_xb: {self.d_xb}, \n \
+            number_xb_heads: {self.num_xb_heads}, \n \
+            number_heads: {self.num_heads}, \n \
+            repeat_groups: {self.repeat_xb_groups}, \n \
+            d_state: {self.ssm_state_size}, \n \
+            n_groups: {self.n_groups}"
+        )
+
+        self.time_step_limit = (0.0, float("inf"))  # hard coded
+        # conv is over xBC -- d_xb (head_dim), d_bb (state_dim), d_c (state_dim)
+        # self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
+        # for B, d_state is shared within each group of heads;
+        # with MIL we need to copy k_proj weights into B proj., and B is ngroups x dstate in M2
+        # so k_proj is 5120 x 1024 (1024/8 = 128 per head) in apriel_thinker
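+        # NOTE (editorial comment, added for clarity; not part of the original code): with the settings used in the
+        # notebook for the Apriel-15B Thinker hybrid (hidden_size=5120, d_xb=1024, head_dim=d_state=128, n_groups=32),
+        # this gives num_xb_heads = 1024 // 128 = 8, num_heads = n_groups = 32, repeat_xb_groups = 32 // 8 = 4 (the
+        # simulated GQA factor), intermediate_size = 32 * 128 = 4096, conv_dim below = 1024 + 8 * 128 + 32 * 128 = 6144,
+        # and in_proj therefore maps 5120 -> 4096 (gate) + 6144 (xBC) = 10240 features, while dt_in_proj maps 5120 -> 32.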
+        self.conv_dim = (
+            self.d_xb  # self.num_xb_heads x head_dim # x, 1024 for Apriel_thinker, initialized from v
+            + self.num_xb_heads
+            * self.ssm_state_size  # we disable ngroups in favour of MIL init. So we will act as if we had separate B per head;
+            + self.n_groups
+            * self.ssm_state_size  # C # C can be shared across groups, if group is 1 it's shared across all heads
+        )
+        self.conv1d = nn.Conv1d(
+            in_channels=self.conv_dim,
+            out_channels=self.conv_dim,
+            bias=self.use_conv_bias,
+            kernel_size=self.conv_kernel_size,
+            groups=self.conv_dim,
+            padding=self.conv_kernel_size - 1,
+        )
+
+        # projection of the input hidden states
+        # intermediate_size for the gate, the rest for x, B and C
+        projection_size = self.intermediate_size + self.conv_dim  # + self.num_heads
+        self.in_proj = nn.Linear(
+            self.hidden_size,
+            projection_size,
+            bias=bias,
+        )
+        self.dt_in_proj = nn.Linear(
+            self.hidden_size,
+            self.num_heads,
+            bias=bias,
+        )
+        # selective projection used to make dt, B and C input dependent
+        # time step projection (discretization)
+        # instantiate once and copy inv_dt in init_weights of PretrainedModel
+        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
+
+        # S4D real initialization. These are not discretized!
+        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+        A = torch.arange(1, self.num_heads + 1)
+        self.A_log = nn.Parameter(torch.log(A))
+        self.A_log._no_weight_decay = True
+        # self.norm = nn.RMSNorm(self.intermediate_size, eps=self.layer_norm_epsilon)
+        self.norm = MambaRMSNormGated(
+            self.intermediate_size,
+            eps=self.layer_norm_epsilon,
+            group_size=self.intermediate_size
+            // self.n_groups,  # this norm should actually work per head, leave it here just to eval a checkpoint trained like this
+            norm_before_gate=norm_before_gate,
+        )
+        self.D = nn.Parameter(torch.ones(self.num_heads))
+        self.D._no_weight_decay = True
+
+        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
+        self.use_bias = bias
+
+        if not is_fast_path_available:
+            logger.warning_once(
+                "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
+                " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
+                " https://github.com/Dao-AILab/causal-conv1d"
+            )
+
+    def cuda_kernels_forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_params: Optional[HybridMambaAttentionDynamicCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        outputs = {}
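+        # NOTE (editorial comment, added for readability; not part of the original code): in_proj produces
+        # [gate | x | B | C], where x plays the role of V, B of K and C of Q of the attention layer it replaces
+        # (the MIL initialization in the notebook copies v_proj/k_proj/q_proj into those slices). x and B only
+        # carry num_xb_heads "KV heads" and are repeated repeat_xb_groups times to match num_heads, mimicking
+        # grouped-query attention; dt comes from the separate dt_in_proj. The first branch below is the cached
+        # single-token decoding step, the second is the full-sequence chunked scan.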
+        # 1. Gated MLP's linear projection
+        # Apply_mask_to_padding_states is not used in nemotron,
+        # because attention_mask is not passed, see https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K/blob/main/modeling_nemotron_h.py#L774
+        attention_mask = None  # so apply_mask_to_padding_states does nothing
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
+        projected_states = self.in_proj(hidden_states)
+
+        # Set up dimensions for reshapes later
+        batch_size, seq_len, _ = hidden_states.shape
+        self.repeat_xb_groups * self.n_groups * self.ssm_state_size
+        groups_time_state_size = self.n_groups * self.ssm_state_size
+        # C dim, note we keep number of groups same as number of heads here
+        # head_time_state_size = self.num_heads * self.ssm_state_size
+        # num_xb_heads_time_state_size = self.num_xb_heads * self.ssm_state_size
+
+        # d_mlp = (
+        #     projected_states.shape[-1]
+        #     - 2 * self.intermediate_size
+        #     - 2 * self.n_groups * self.ssm_state_size
+        #     - self.num_heads
+        # ) // 2
+
+        # Single step calculations via cache
+        if cache_params is not None and cache_position is not None and cache_position[0] > 0:
+            gate, hidden_states_B_C = projected_states.squeeze(1).split(
+                [self.intermediate_size, self.conv_dim], dim=-1
+            )
+            dt = self.dt_in_proj(hidden_states).squeeze(1)
+
+            # 2. Convolution sequence transformation
+            hidden_states_B_C = causal_conv1d_update(
+                hidden_states_B_C,
+                cache_params.conv_states[self.layer_idx],
+                self.conv1d.weight.squeeze(1),
+                self.conv1d.bias,
+                self.activation,
+            )
+
+            hidden_states, B, C = torch.split(
+                hidden_states_B_C,
+                [self.d_xb, self.d_xb, groups_time_state_size],
+                dim=-1,
+            )
+            # simulate GQA by repeating heads in x,b, x -> v, B -> k, C -> q
+            hidden_states = rearrange(
+                hidden_states,
+                "b (num_xb_heads head_dim) -> b num_xb_heads head_dim",
+                head_dim=self.head_dim,
+                num_xb_heads=self.num_xb_heads,
+            )  # x is b x local_head_groups x l x head_dim
+            B = rearrange(
+                B,
+                "b (xb_groups state_size) -> b xb_groups state_size",
+                state_size=self.ssm_state_size,
+            )  # b is b x local_head_groups x l x state_size
+            batch, num_key_value_heads, head_dim = hidden_states.shape
+            hidden_states = hidden_states[:, :, None, :].expand(
+                batch, num_key_value_heads, self.repeat_xb_groups, head_dim
+            )
+            hidden_states = hidden_states.reshape(batch, num_key_value_heads * self.repeat_xb_groups, head_dim)
+            B = B[:, :, None, :].expand(batch, num_key_value_heads, self.repeat_xb_groups, self.ssm_state_size)
+            B = B.reshape(batch, num_key_value_heads * self.repeat_xb_groups, self.ssm_state_size)
+
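+            # NOTE (editorial comment, added for readability; not part of the original code): the call below advances
+            # the recurrent state of this layer by a single token. selective_state_update reads and updates
+            # cache_params.ssm_states[self.layer_idx] in place, with A, dt, dt_bias and D broadcast from per-head
+            # values to per-(head, head_dim) shapes as expected by the kernel.
+            # 3. 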
SSM transformation z + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + C = rearrange( + C, "b (g n) -> b g n", g=self.n_groups + ) # C.view(batch_size, self.num_heads, self.ssm_state_size) + # B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + # C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + # hidden_states *= torch.nn.functional.silu(gate) + + # 4. Final linear projection + out = self.out_proj(hidden_states)[:, None, ...] + + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + assert False, "Should not have ended here for inference" + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + # we are not using mlp here, leaving it here from nemotron modeling + gate, hidden_states_B_C = projected_states.split([self.intermediate_size, self.conv_dim], dim=-1) + dt = self.dt_in_proj(hidden_states) + + # 2. 
Convolution sequence transformation + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.update_conv_state( + layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True + ) + + if self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + ) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + hidden_states_B_C = apply_mask_to_padding_states( + hidden_states_B_C, attention_mask + ) # this does not seem to do anything in nemotron + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.d_xb, self.d_xb, groups_time_state_size], + dim=-1, + ) + # simulate GQA by repeating heads in x,b, x -> v, B -> k, C -> q + hidden_states = rearrange( + hidden_states, + "b l (num_xb_heads head_dim) -> b num_xb_heads l head_dim", + l=seq_len, + head_dim=self.head_dim, + num_xb_heads=self.num_xb_heads, + ) # x is b x local_head_groups x l x head_dim + B = rearrange( + B, + "b l (xb_groups state_size) -> b xb_groups l state_size", + state_size=self.ssm_state_size, + ) # b is b x local_head_groups x l x state_size + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, self.repeat_xb_groups, slen, head_dim + ) + hidden_states = hidden_states.reshape( + batch, num_key_value_heads * self.repeat_xb_groups, slen, head_dim + ) + B = B[:, :, None, :, :].expand( + batch, num_key_value_heads, self.repeat_xb_groups, slen, self.ssm_state_size + ) + B = B.reshape(batch, num_key_value_heads * self.repeat_xb_groups, slen, self.ssm_state_size) + hidden_states = hidden_states.transpose(1, 2).contiguous() + B = B.transpose(1, 2).contiguous() + + # 3. SSM transformation + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), # (b, s, h, d) + dt, # (b, s, h) + A, # (h) + B.view(batch_size, seq_len, -1, self.ssm_state_size), # (b, s, n_groups, state) + C.view(batch_size, seq_len, -1, self.ssm_state_size), # (b, s, n_groups, state) + chunk_size=self.chunk_size, + D=self.D, + z=None, + dt_bias=self.dt_bias, + seq_idx=None, + return_final_states=True, + dt_softplus=True, + **dt_limit_kwargs, + ) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + + scan_output = scan_output.view(batch_size, seq_len, -1) + + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + # scan_output *= torch.nn.functional.silu(gate) + + # 4. 
Final linear projection
+                out = self.out_proj(scan_output)
+        outputs["hidden_states"] = out[:, :seq_len, :]
+        return outputs
+
+    def torch_forward(self, *args, **kwargs):
+        assert False, "Should not have ended here for inference, make sure all necessary kernels are installed"
+        # see implementation in nemotron modeling https://huggingface.co/nvidia/Nemotron-H-8B-Base-8K/blob/main/modeling_nemotron_h.py
+
+    def forward(
+        self,
+        hidden_states,
+        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        cache_params = past_key_value
+        if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
+            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+        dtype = hidden_states.dtype
+        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
+            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+
+        return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
+
+
+class Mamba(nn.Module):
+    """
+    From https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py
+    """
+
     def __init__(
         self,
         d_model,
@@ -1188,7 +1620,7 @@ def forward(
 
 
 class AprielSSMM2DecoderLayer(AprielSSMDecoderLayer):
-    _mixer_class = Mamba2
+    _mixer_class = Mamba
 
 
 class AprielHybridIdentity(nn.Module):
@@ -1200,6 +1632,10 @@ def forward(self, hidden_states: torch.Tensor, **kwargs):
         return (hidden_states,)
 
 
+class AprielSSMNemotronHM2DecoderLayer(AprielSSMDecoderLayer):
+    _mixer_class = Mamba2
+
+
 class AprielThinkerSSMHybridModel(MistralModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`AprielDecoderLayer`, `AprielSSMDecoderLayer`] @@ -1213,7 +1649,7 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): super().__init__(config_copy, **kwargs) self.config = config blocks = [] - logger.info(f"Loading hyubrid model with the following layout: {config.hybrid_block_layout}") + logger.info(f"Loading hybrid model with the following layout: {config.hybrid_block_layout}") for layer_idx, type in enumerate(config.hybrid_block_layout): if type == "m2d": blocks.append(AprielSSMDecoderLayer(config, layer_idx)) @@ -1223,6 +1659,8 @@ def __init__(self, config: AprielSSMHybridConfig, **kwargs): blocks.append(MistralDecoderLayer(config, layer_idx)) elif type == "i": blocks.append(AprielHybridIdentity(config)) + elif type == "nm2": + blocks.append(AprielSSMNemotronHM2DecoderLayer(config, layer_idx)) else: raise ValueError(f"Invalid block type: {type}") self.layers = nn.ModuleList(blocks) @@ -1246,7 +1684,7 @@ def forward( ) -> BaseModelOutputWithPast: use_cache = use_cache if use_cache is not None else self.config.use_cache if use_cache and past_key_values is None: - # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) + # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) output = super().forward( diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/test_modeling.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/test_modeling.py new file mode 100644 index 000000000..91857d8a0 --- /dev/null +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/test_modeling.py @@ -0,0 +1,53 @@ +import pytest + +from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import NemotronHMamba2Mixer +from fast_llm.models.ssm.external.nemotron.config import NemotronHConfig +from fast_llm.models.ssm.external.nemotron.modeling import NemotronHMamba2Mixer as NemotronHMamba2Mixer_original + + +# in apriel's mamba2 mixer we do not used groups for B and C, but we have the d_xb dim, that simulates GQA +# so in order to reconstruct the original nemotron mixer, we need to set d_xb same as d_inner +@pytest.mark.parametrize( + "apriel_ssm_config, nemotron_h_config", + [ + ( + { + "d_state": 16, + "d_xb": 4096, + "expand": 1, + "d_conv": 4, + "d_inner": 4096, + "conv_bias": True, + "bias": False, + "head_dim": 128, # 4096/128 = 32 heads, 1024/128 = 8 KVheads and 4 repeat groups + }, + NemotronHConfig( + hidden_size=4096, + mamba_num_heads=32, + mamba_head_dim=128, + mamba_n_groups=32, + mamba_d_conv=4, + mamba_expand=1, + ssm_state_size=16, + use_bias=False, + mamba_hidden_act="silu", + ), + ) + ], +) +def test_nemotron_h_mamba2_mixers_identical(apriel_ssm_config: dict, nemotron_h_config: dict): + mixer_apriel = NemotronHMamba2Mixer(d_model=4096, **apriel_ssm_config) + mixer_nemotron_h = NemotronHMamba2Mixer_original(nemotron_h_config, 0) + + for k_a, v_a in mixer_apriel.state_dict().items(): + if k_a == "dt_in_proj.weight": + continue + v_b = mixer_nemotron_h.state_dict()[k_a] + if k_a == "in_proj.weight": + assert [v_a.shape[0], v_a.shape[1]] == [v_b.shape[0] - nemotron_h_config.mamba_num_heads, v_b.shape[1]] + else: + assert v_a.shape == v_b.shape + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git 
a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py index ee2c83e03..ea1d6cd35 100644 --- a/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py +++ b/fast_llm/models/ssm/external/eval/apriel_eval_wrapper.py @@ -250,3 +250,76 @@ def _model_generate(self, context, max_length, stop, **generation_kwargs): use_cache=True, **generation_kwargs, ) + + +@register_model("nemotron_h") +class NemotronHWrapper(HFLM): + """Wrapper for NemotronH model for compatibility with lm-evaluation-harness.""" + + def __init__(self, pretrained, **kwargs) -> None: + if "backend" in kwargs: + assert kwargs["backend"] == "causal" + + super().__init__( + pretrained=pretrained, + backend=kwargs.pop("backend", "causal"), + **kwargs, + ) + + # Override device detection for distributed settings + self._device = _get_device() + + def _get_config(self, pretrained: str, **kwargs) -> None: + """Get the model configuration.""" + from fast_llm.models.ssm.external.nemotron.config import NemotronHConfig + + self._config = NemotronHConfig.from_pretrained(pretrained, trust_remote_code=True) + + def _create_model(self, pretrained: str, dtype: Optional[Union[str, torch.dtype]] = "float16", **kwargs) -> None: + """Create the model.""" + from fast_llm.models.ssm.external.nemotron.modeling import NemotronHForCausalLM + + # Ensure we're using the correct device + device = _get_device() + self._device = device + + self._model = NemotronHForCausalLM.from_pretrained( + pretrained, + # device=device, + torch_dtype=torch.bfloat16 if dtype == "auto" else lm_eval.models.utils.get_dtype(dtype), + config=self._config, + ) + + def _model_generate(self, context, max_length, stop, **generation_kwargs): + # Ensure we're using the correct device + device = _get_device() + + # Ensure context is on the same device as the model + context = context.to(device) + self.model.to(device) + + # Move any tensors in generation_kwargs to the correct device + generation_kwargs = _move_tensors_to_device(generation_kwargs, device) + + stopping_criteria = lm_eval.models.utils.stop_sequences_criteria( + self.tokenizer, + stop, + context.shape[1], + context.shape[0], + ) + + generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) + do_sample = generation_kwargs.get("do_sample", None) + + # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies + if generation_kwargs.get("temperature") == 0.0 and do_sample is None: + generation_kwargs["do_sample"] = do_sample = False + if do_sample is False and generation_kwargs.get("temperature") == 0.0: + generation_kwargs.pop("temperature") + return self.model.generate( + input_ids=context, + max_length=max_length, + stopping_criteria=stopping_criteria, + use_cache=True, + **generation_kwargs, + ) diff --git a/fast_llm/models/ssm/external/nemotron/config.py b/fast_llm/models/ssm/external/nemotron/config.py new file mode 100644 index 000000000..058cd0cbd --- /dev/null +++ b/fast_llm/models/ssm/external/nemotron/config.py @@ -0,0 +1,249 @@ +# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""NemotronH model configuration""" + +import re + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class NemotronHConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`NemotronHModel`]. It is used to instantiate a + NemotronH model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the NemotronH-v0.1 model. + + [todo](todo) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 131072): + Vocabulary size of the NemotronH model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`NemotronHModel`] + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 21504): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 52): + Number of hidden layers in the Transformer encoder. + hybrid_override_pattern (`str`, *optional*, defaults to `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`): + The pattern of the hybrid model. The pattern is a string of characters where each character represents M: Mamba2, *: Attention, -: MLP + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + attention_head_dim (`int`, *optional*, defaults to 128): + Dimension of each attention head. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. + mlp_hidden_act (`str`, *optional*, defaults to "relu2"): + The non-linear activation function in the MLP layers. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in attention layers. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in MLP layers. + use_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the model. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. 
+ residual_in_fp32 (`bool`, *optional*, defaults to `False`): + Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an + integer value, only last `num_logits_to_keep` logits will be calculated. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + sliding_window (`int`, *optional*, defaults to None): + Sliding window attention window size. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + hidden_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the hidden states. + use_mamba_kernels (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and + `causal-conv1d` are installed, and the mamba modules are running on a CUDA device. + ssm_state_size (`int`, *optional*, defaults to 128): + The dimension of the mamba state space latents. + mamba_num_heads (`int`, *optional*, defaults to 128): + Number of heads in Mamba layers. + mamba_n_groups (`int`, *optional*, defaults to 8): + Number of groups in Mamba layers. + mamba_head_dim (`int`, *optional*, defaults to 64): + Dimension of each Mamba head. + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel. + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor used to determine the mamba intermediate size. + mamba_hidden_act (`str`, *optional*, defaults to "silu"): + The non-linear activation function in the Mamba layers. + mamba_dt_min (`float`, *optional*, defaults to 0.001): + Minimum value for the time step in Mamba. + mamba_dt_max (`float`, *optional*, defaults to 0.1): + Maximum value for the time step in Mamba. + mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))): + Limits for the time step in Mamba. + mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4): + Floor value for time step initialization in Mamba. + mamba_conv_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the convolution layer of the mamba mixer block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the input and output projections of the mamba mixer block. + mamba_chunk_size (`int`, *optional*, defaults to 256): + Size of chunks for Mamba processing. + rescale_prenorm_residual (`bool`, *optional*, defaults to `True`): + Whether to rescale the pre-normalization residual connections. 
+ """ + + model_type = "nemotron_h" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=131072, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=21504, + num_hidden_layers=52, + hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-", + num_attention_heads=32, + attention_head_dim=128, + num_key_value_heads=8, # nemo: num_query_groups + mlp_hidden_act="relu2", + attention_bias=False, + mlp_bias=False, + use_bias=False, + initializer_range=0.02, # nemo: init_method_std + layer_norm_epsilon=1e-5, # nemo: layernorm_epsilon + residual_in_fp32=False, # Megatron Core default value + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + sliding_window=None, + max_position_embeddings=4096, + attention_dropout=0.0, + hidden_dropout=0.0, # * ADDED + use_mamba_kernels=True, + ssm_state_size=128, # mamba_state_size + mamba_num_heads=128, + mamba_n_groups=8, # nemo: mamba_ssm_ngroups = num_heads + mamba_head_dim=64, + mamba_d_conv=4, + mamba_expand=2, + mamba_hidden_act="silu", + mamba_dt_min=0.001, + mamba_dt_max=0.1, + mamba_dt_limit=(0.0, float("inf")), + mamba_dt_init_floor=1e-4, + mamba_conv_bias=True, + mamba_proj_bias=False, + mamba_chunk_size=256, + rescale_prenorm_residual=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.tie_word_embeddings = tie_word_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.hybrid_override_pattern = hybrid_override_pattern + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.sliding_window = sliding_window + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + self.hidden_dropout = hidden_dropout + + # Validate hybrid_override_pattern + # M: Mamba2, *: Attention, -: MLP + assert ( + len(self.hybrid_override_pattern) == self.num_hidden_layers + ), "hybrid_override_pattern must have the same length as num_hidden_layers" + assert re.match( + r"^[*-M]+$", self.hybrid_override_pattern + ), "hybrid_override_pattern must only contain characters 'M', '*', or '-'" + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.mlp_hidden_act = mlp_hidden_act + self.attention_bias = attention_bias + self.mlp_bias = mlp_bias + self.use_bias = use_bias + self.initializer_range = initializer_range + self.layer_norm_epsilon = layer_norm_epsilon + self.residual_in_fp32 = residual_in_fp32 + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + self.use_mamba_kernels = use_mamba_kernels + self.n_groups = mamba_n_groups + self.mamba_head_dim = mamba_head_dim + self.ssm_state_size = ssm_state_size + self.mamba_num_heads = mamba_num_heads + self.conv_kernel = mamba_d_conv + self.expand = mamba_expand + self.mamba_hidden_act = mamba_hidden_act + self.time_step_min = mamba_dt_min + self.time_step_max = mamba_dt_max + self.time_step_limit = mamba_dt_limit + self.time_step_floor = mamba_dt_init_floor + self.use_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + self.chunk_size = mamba_chunk_size + self.rescale_prenorm_residual = rescale_prenorm_residual + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + 
@property + def layers_block_type(self): + return [ + ( + "mamba" + if self.hybrid_override_pattern[i] == "M" + else "attention" if self.hybrid_override_pattern[i] == "*" else "mlp" + ) + for i in range(self.num_hidden_layers) + ] diff --git a/fast_llm/models/ssm/external/nemotron/modeling.py b/fast_llm/models/ssm/external/nemotron/modeling.py new file mode 100644 index 000000000..154316c7f --- /dev/null +++ b/fast_llm/models/ssm/external/nemotron/modeling.py @@ -0,0 +1,1628 @@ +# Copyright 2024 HuggingFace Inc. team. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch NemotronH model.""" + +import math +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.activations import ACT2FN +from transformers.cache_utils import DynamicCache # we need __iter__ and __len__ of pkv +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, +) +from transformers.utils.import_utils import ( + is_causal_conv1d_available, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + is_mamba_2_ssm_available, +) + +from fast_llm.models.ssm.external.nemotron.config import NemotronHConfig + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.mamba.modeling_mamba2.modeling_mamba2.py with MAMBA2->NEMOTRONH,Mamba2->NemotronH +# For Mamba2 components Mamba2->NemotronHMamba2 +if is_mamba_2_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined +else: + mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None + +try: + # from mamba_ssm.ops.triton.layernorm_gated import RMSNorm as RMSNormGated + from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn +except ImportError: + raise ImportError("mamba-ssm is required by the Mamba model but cannot be imported") + +if is_causal_conv1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +if is_flash_attn_2_available(): + from transformers.modeling_flash_attention_utils import _flash_attention_forward + +is_fast_path_available = all( + ( + selective_state_update, + mamba_chunk_scan_combined, + mamba_split_conv1d_scan_combined, + causal_conv1d_fn, + causal_conv1d_update, + ) +) + + +_CHECKPOINT_FOR_DOC = "nvidia/Nemotron-H-56B-Base-8K" +_CONFIG_FOR_DOC = "NemotronHConfig" + + +# Helper methods for segment sum computation + + +def pad_tensor_by_size(input_tensor: 
torch.Tensor, pad_size: int): + """ + Padding x tensor with `pad_size` on the seq_len dim (dim=1) + + Assumes that we only have tensors of either size 4 or 3 + """ + pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0) + + return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0) + + +def reshape_into_chunks(input_tensor, pad_size, chunk_size): + """ + Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and + simultaneously splitting it into chunk sequences. + + Assumes that we only have tensors of either size 4 or 3 + """ + # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...] + input_tensor = pad_tensor_by_size(input_tensor, pad_size) + + if len(input_tensor.shape) == 3: + # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads] + return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2]) + else: + # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size] + return input_tensor.reshape( + input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3] + ) + + +def segment_sum(input_tensor): + """ + More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions. + """ + chunk_size = input_tensor.size(-1) + # 1. expand input tensor to have an additional dimension and repeat along that dimension + # [..., chunk_size] -> [..., chunk_size, chunk_size] + input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size) + # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1) + input_tensor = input_tensor.masked_fill(~mask, 0) + # 3. compute actual cumsum + tensor_segsum = torch.cumsum(input_tensor, dim=-2) + + # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time) + mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0) + tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf) + return tensor_segsum + + +def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + +# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py +class HybridMambaAttentionDynamicCache(DynamicCache): + """ + A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache + (which has a constant shape regardless of seq_len). + + This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` + and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). 
+ For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. + """ + + def __init__(self, config, batch_size, dtype=torch.float16, device=None): + super().__init__() + self.dtype = dtype + self.hybrid_override_pattern = config.hybrid_override_pattern + self.has_previous_state = False # only used by mamba + intermediate_size = config.expand * config.hidden_size + ssm_state_size = config.ssm_state_size + self.conv_kernel_size = conv_kernel_size = config.conv_kernel + self.conv_states = [] + self.ssm_states = [] + self.transformer_layers = [] + for i in range(config.num_hidden_layers): + if self.hybrid_override_pattern[i] == "M": + # Mamba layer + self.conv_states += [ + torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + ] + self.ssm_states += [ + torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) + ] + else: + # Attention or MLP layer + self.conv_states += [torch.tensor([[]] * batch_size, device=device)] + self.ssm_states += [torch.tensor([[]] * batch_size, device=device)] + self.transformer_layers.append(i) + + self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Update the cache + if self.key_cache[layer_idx].shape[-1] == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + device = self.conv_states[layer_idx].device + self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device)) + device = self.ssm_states[layer_idx].device + self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device)) + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains a cache and is not an empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]: + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.") + + # Copied from modeling_mamba2.py + def update_conv_state( + self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False + ) -> torch.Tensor: + if cache_init: + self.conv_states[layer_idx] = new_conv_state.to(self.conv_states[0].device) + else: + self.conv_states[layer_idx] = self.conv_states[layer_idx].roll(shifts=-1, dims=-1) + self.conv_states[layer_idx][:, :, -1] = new_conv_state[:, 0, :].to(self.conv_states[layer_idx].device) + return self.conv_states[layer_idx] + + def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor): + self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[0].device) + return self.ssm_states[layer_idx] + + def reset(self): + # conv_states and ssm_states are Python lists of per-layer tensors, so zero them layer by layer + for layer_idx in range(len(self.conv_states)): + self.conv_states[layer_idx].zero_() + self.ssm_states[layer_idx].zero_() + + +class MambaRMSNormGated(torch.nn.Module): + def __init__(self, hidden_size, group_size, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.group_size = group_size + + # jan28b version + def forward(self, hidden_states, gate=None): + return rmsnorm_fn( + x=hidden_states, + weight=self.weight, + bias=None, # No bias + z=gate, + eps=self.variance_epsilon, + group_size=self.group_size, + norm_before_gate=False, + ) + + +class NemotronHMamba2Mixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + """ + + def __init__(self, config: NemotronHConfig, layer_idx: int): + super().__init__() + self.num_heads = config.mamba_num_heads + self.hidden_size = config.hidden_size + self.ssm_state_size = config.ssm_state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.mamba_num_heads * config.mamba_head_dim + self.layer_idx = layer_idx + self.use_conv_bias = config.use_conv_bias + self.activation = config.mamba_hidden_act + self.act = ACT2FN[config.mamba_hidden_act] + + self.layer_norm_epsilon = config.layer_norm_epsilon + + self.n_groups = config.n_groups + self.head_dim = config.mamba_head_dim + self.chunk_size = config.chunk_size + + self.time_step_limit = config.time_step_limit + self.time_step_min = config.time_step_min + self.time_step_max = config.time_step_max + + self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, + groups=self.conv_dim, + padding=config.conv_kernel - 1, + ) + + # projection of the input hidden states + projection_size = self.intermediate_size + self.conv_dim + self.num_heads + self.in_proj = nn.Linear( + self.hidden_size, + projection_size, + bias=config.use_bias, + ) + # selective projection used to make dt, B and C input dependant + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_heads)) + + # S4D real initialization. These are not discretized! + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded + A = torch.arange(1, self.num_heads + 1) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + self.norm = MambaRMSNormGated( + self.intermediate_size, eps=self.layer_norm_epsilon, group_size=self.intermediate_size // self.n_groups + ) + self.D = nn.Parameter(torch.ones(self.num_heads)) + self.D._no_weight_decay = True + + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.use_bias = config.use_bias + + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + + def cuda_kernels_forward( + self, + hidden_states: torch.Tensor, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + # 1. 
Gated MLP's linear projection + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + projected_states = self.in_proj(hidden_states) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + groups_time_state_size = self.n_groups * self.ssm_state_size + d_mlp = ( + projected_states.shape[-1] + - 2 * self.intermediate_size + - 2 * self.n_groups * self.ssm_state_size + - self.num_heads + ) // 2 + + # Single step calculations via cache + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + cache_params.conv_states[self.layer_idx], + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.activation, + ) + + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # (nheads,) + A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups) + C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups) + hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim) + hidden_states = selective_state_update( + cache_params.ssm_states[self.layer_idx], + hidden_states_reshaped, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + ) + hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim) + hidden_states = self.norm(hidden_states, gate) + + # 4. Final linear projection + out = self.out_proj(hidden_states)[:, None, ...] + + # Fused calculations or step by step if no initialized cache is found + else: + A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size) + dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit} + + # 2-4. Fused kernel for conv1d, SSM, and the final projection + if self.training and cache_params is None: + out = mamba_split_conv1d_scan_combined( + projected_states, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + self.dt_bias, + A, + D=self.D, + chunk_size=self.chunk_size, + seq_idx=None, # was seq_idx + activation=self.activation, + rmsnorm_weight=self.norm.weight, + rmsnorm_eps=self.norm.variance_epsilon, + outproj_weight=self.out_proj.weight, + outproj_bias=self.out_proj.bias, + headdim=self.head_dim, + ngroups=self.n_groups, + norm_before_gate=False, + return_final_states=False, + **dt_limit_kwargs, + ) + + else: + _, _, gate, hidden_states_B_C, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. 
Convolution sequence transformation + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, + (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0), + ) + cache_params.update_conv_state( + layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True + ) + + if self.activation not in ["silu", "swish"]: + hidden_states_B_C = self.act( + self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2) + ) + else: + hidden_states_B_C = causal_conv1d_fn( + x=hidden_states_B_C.transpose(1, 2), + weight=self.conv1d.weight.squeeze(1), + bias=self.conv1d.bias, + activation=self.activation, + ).transpose(1, 2) + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, groups_time_state_size, groups_time_state_size], + dim=-1, + ) + + # 3. SSM transformation + scan_output, ssm_state = mamba_chunk_scan_combined( + hidden_states.view(batch_size, seq_len, -1, self.head_dim), + dt, + A, + B.view(batch_size, seq_len, self.n_groups, -1), + C.view(batch_size, seq_len, self.n_groups, -1), + chunk_size=self.chunk_size, + D=self.D, + z=None, + seq_idx=None, + return_final_states=True, + dt_bias=self.dt_bias, + dt_softplus=True, + **dt_limit_kwargs, + ) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + + scan_output = scan_output.view(batch_size, seq_len, -1) + + # Multiply "gate" branch and apply extra normalization layer + scan_output = self.norm(scan_output, gate) + + # 4. Final linear projection + out = self.out_proj(scan_output) + return out + + # fmt: off + def torch_forward(self, input_states, cache_params: Optional[HybridMambaAttentionDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None): + batch_size, seq_len, _ = input_states.shape + dtype = input_states.dtype + + # 1. Gated MLP's linear projection + input_states = apply_mask_to_padding_states(input_states, attention_mask) + projected_states = self.in_proj(input_states) + d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2 + _, _, gate, hidden_states_B_C, dt = projected_states.split( + [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1 + ) + + # 2. 
Convolution sequence transformation + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=hidden_states_B_C, cache_init=False) + + # We need to guarantee that anything regarding the cache is on the same device + conv_states = cache_params.conv_states[self.layer_idx].to(device=self.conv1d.weight.device) + + hidden_states_B_C = torch.sum( + conv_states * self.conv1d.weight.squeeze(1), dim=-1 + ) + if self.use_conv_bias: + hidden_states_B_C = hidden_states_B_C + self.conv1d.bias + hidden_states_B_C = self.act(hidden_states_B_C) + else: + # Init cache + if cache_params is not None: + hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2) + conv_states = nn.functional.pad( + hidden_states_B_C_transposed, (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0) + ) + cache_params.update_conv_state(layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True) + + hidden_states_B_C = self.act(self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)) + + hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask) + hidden_states, B, C = torch.split( + hidden_states_B_C, + [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], + dim=-1 + ) + + # 3. SSM transformation + A = -torch.exp(self.A_log.float()) # [num_heads] + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + # We need to guarantee that anything regarding the cache is on the same device + cache_device = cache_params.ssm_states.device + + # Note: there is no need to pad parameter matrices here, as there is just one new token + # for batched generation + dt = dt[:, 0, :][:, None, ...] 
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim) + # [num_heads] -> [num_heads, head_dim] + dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim) + + dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype)) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32) + # [bsz, num_heads, head_dim, state_size] + dA = (torch.exp(dt[..., None] * A)).to(device=cache_device) + + # Discretize B + # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] -> + # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size] + B = B.reshape(batch_size, self.n_groups, -1)[..., None, :] + B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous() + B = B.reshape(batch_size, -1, B.shape[-1]) + # [bsz, num_heads, head_dim, state_size] + dB = dt[..., None] * B[..., None, :] + + # Discretize x into dB + # [bsz, intermediate_size] -> [bsz, num_heads, head_dim] + hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim) + dBx = (dB * hidden_states[..., None]).to(device=cache_device) + + # State calculation + cache_params.update_ssm_state( + layer_idx=self.layer_idx, + new_ssm_state=cache_params.ssm_states[self.layer_idx] * dA + dBx + ) + + # Subsequent output + # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size] + C = C.reshape(batch_size, self.n_groups, -1)[..., None, :] + C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous() + C = C.reshape(batch_size, -1, C.shape[-1]) + # [bsz, num_heads, head_dim] + + ssm_states = cache_params.ssm_states[self.layer_idx].to(device=C.device, dtype=C.dtype) # Shape: [b, h, d, n] + # Reshape ssm_states to merge the first two dimensions + ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n] + C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1] + y = torch.bmm(ssm_states_reshaped, C_reshaped) + y = y.view(batch_size, self.num_heads, self.head_dim) + + # D skip connection + # [num_heads] -> [num_heads, head_dim] + D = self.D[..., None].expand(self.D.shape[0], self.head_dim) + y = (y + hidden_states * D).to(y.dtype) + + # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size] + y = y.reshape(batch_size, -1)[:, None, ...] 
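+            # To summarize the cached single-step branch above: for each new token, per head and head_dim, + #   dt' = clamp(softplus(dt + dt_bias), *time_step_limit) + #   h_t = exp(dt' * A) * h_{t-1} + (dt' * B) * x_t   (the dA and dBx terms above) + #   y_t = C . h_t + D * x_t                           (the bmm readout plus the skip connection) + # which is the same update that selective_state_update performs on the CUDA fast path.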
+ else: + # begin ssd naive implementation without einsums + dt = nn.functional.softplus(dt + self.dt_bias) + dt = torch.clamp(dt, self.time_step_limit[0], self.time_step_limit[1]) + hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float() + B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float() + B = B.repeat(1, 1, self.num_heads // self.n_groups, 1) + C = C.repeat(1, 1, self.num_heads // self.n_groups, 1) + pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size + + D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size) + + # Discretize x and A + hidden_states = hidden_states * dt[..., None] + A = A.to(hidden_states.dtype) * dt + + # Rearrange into blocks/chunks + hidden_states, A, B, C = (reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)) + + # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size] + A = A.permute(0, 3, 1, 2) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + # This is the analog of a causal mask + L = torch.exp(segment_sum(A)) + + # Contraction of C and B to get G (attention-weights like) + G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :] # shape: (b, c, l, s, h, n) + G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h) + + # Compute M, equivalent to applying attention mask to weights + M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None] + M = M_intermediate.sum(dim=-1) + + # Compute Y_diag (apply to values) + Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None] + states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if cache_params is not None and cache_position is not None and cache_position[0] > 0: + previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...].to(device=states.device) + else: + previous_states = torch.zeros_like(states[:, :1]) + states = torch.cat([previous_states, states], dim=1) + decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0)))) + decay_chunk = decay_chunk.transpose(1, 3) + new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1) + states, ssm_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + C_times_states = (C[..., None, :] * states[:, :, None, ...]) + state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1) + Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None]) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + y = Y_diag + Y_off + # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim] + y = y.reshape(batch_size, -1, self.num_heads, self.head_dim) + + y = y + D_residual + # Cutting off padded chunks + if pad_size > 0: + y = y[:, :seq_len, :, :] + y = y.reshape(batch_size, seq_len, -1) + + # Init cache + if ssm_state is not None and cache_params is not None: + cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state) + + scan_output = self.norm(y, gate) + + # end ssd naive + + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size] + return contextualized_states + # fmt: on + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ): + if is_fast_path_available and "cuda" in self.in_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66 + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + +class NemotronHRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + NemotronHRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + # Weights are in float32 + return (self.weight.to(torch.float32) * hidden_states).to(input_dtype) + + +class NemotronHBlock(nn.Module): + def __init__(self, config, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.residual_in_fp32 = config.residual_in_fp32 + self.norm = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + # M: Mamba2, *: Attention, -: MLP + self.block_type = config.layers_block_type[layer_idx] + if self.block_type == "mamba": + self.mixer = NemotronHMamba2Mixer(config, layer_idx=layer_idx) + elif self.block_type == "attention": + self.mixer = NEMOTRONH_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + elif self.block_type == "mlp": + self.mixer = NemotronHMLP(config, layer_idx=layer_idx) + else: + raise ValueError(f"Invalid layer pattern {config.hybrid_override_pattern[layer_idx]}") + + def forward( + self, + hidden_states, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: 
Optional[torch.Tensor] = None, + ): + with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)): + # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs + residual = hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + if self.block_type == "mamba": + hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position) + elif self.block_type == "attention": + hidden_states = self.mixer(hidden_states, cache_position=cache_position) + hidden_states = hidden_states[0] + elif self.block_type == "mlp": + hidden_states = self.mixer(hidden_states) + else: + raise ValueError(f"Invalid block_type: {self.block_type}") + + hidden_states = residual + hidden_states + return hidden_states + + +# Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH +class NemotronHMLP(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.mlp_hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class NemotronHAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: NemotronHConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + if config.attention_head_dim is not None: + self.head_dim = config.attention_head_dim + else: + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.is_causal = True + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias) + + def forward( + self, + hidden_states: torch.Tensor, + # position_embeddings: Tuple[torch.Tensor, torch.Tensor], #TODO + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + # attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba +# class JambaFlashAttention2(JambaAttention): +class NemotronHFlashAttention2(NemotronHAttention): + """ + Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays + untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=getattr(self.config, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + # attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba +# class JambaSdpaAttention(JambaAttention): +class NemotronHSdpaAttention(NemotronHAttention): + """ + Jamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `JambaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from NemotronHAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "NemotronHModel is using NemotronHSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +NEMOTRONH_ATTENTION_CLASSES = { + "eager": NemotronHAttention, + "flash_attention_2": NemotronHFlashAttention2, + "sdpa": NemotronHSdpaAttention, +} + + +# Copied from transformers.models.mamba.modeling_mamba2.Mamba2PreTrainedModel +class NemotronHPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
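The `NEMOTRONH_ATTENTION_CLASSES` mapping above is what backs the `attn_implementation` argument at load time. A usage sketch, with a placeholder checkpoint path rather than a published model id:

```python
# Usage sketch; "path/to/nemotron-h-checkpoint" is a placeholder, not a real repo id.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/nemotron-h-checkpoint",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2", per NEMOTRONH_ATTENTION_CLASSES
    trust_remote_code=True,  # may be required if the model ships as custom code rather than inside transformers
)
```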
+ """ + + config_class = NemotronHConfig + base_model_prefix = "backbone" + _no_split_modules = ["NemotronHBlock"] + supports_gradient_checkpointing = True + _is_stateful = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, NemotronHMamba2Mixer): + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt = torch.exp( + torch.rand(self.config.mamba_num_heads) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) + ).clamp(min=self.config.time_step_floor) + + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + module.dt_bias.copy_(inv_dt) + module.dt_bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + # TODO: Check + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_hidden_layers) + + +@dataclass +# Copied from transformers.models.mamba.modeling_mamba2.Mamba2Output with MAMBA2->NemotronH,Mamba2->NemotronH +class NemotronHOutput(ModelOutput): + """ + Class for the NemotronH model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + cache_params (`HybridMambaAttentionDynamicCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ """ + + last_hidden_state: Optional[torch.FloatTensor] = None + cache_params: Optional[HybridMambaAttentionDynamicCache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH +class NemotronHCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + cache_params (`HybridMambaAttentionDynamicCache`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + + Includes both the State space model state matrices after the selective scan, and the Convolutional states + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + cache_params: Optional[HybridMambaAttentionDynamicCache] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + + +NEMOTRONH_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`NemotronHConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +NEMOTRONH_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Indices of input sequence tokens in the vocabulary. + + If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 
This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + position_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. + cache_params (`HybridMambaAttentionDynamicCache`, *optional*): + If passed along, the model uses the previous state in all the blocks (which will give the output for the + `input_ids` provided as if the model add `state_input_ids + input_ids` as context). + use_cache (`bool`, *optional*): + If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the current input in the cache. This is used to ensure that the cache is correctly updated. + If `cache_params` is passed, `cache_position` should also be passed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) +""" + + +@add_start_docstrings( + "The bare NemotronH Model transformer outputting raw hidden-states without any specific head on top.", + NEMOTRONH_START_DOCSTRING, +) +class NemotronHModel(NemotronHPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([NemotronHBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) + + self.gradient_checkpointing = False + self.norm_f = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + # Initialize weights and apply final processing + self._register_load_state_dict_pre_hook(self.load_hook) + self.post_init() + + def load_hook(self, state_dict, prefix, *args): + for k in state_dict: + if "embedding." 
in k: + state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k) + break + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=NemotronHOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[tuple, NemotronHOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + # use_cache = use_cache if use_cache is not None else self.config.use_cache + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + # From zamba_modeling.py + if use_cache and cache_params is None: + logger.warning_once( + "NemotronH requires an initialized `NemotronHHybridDynamicCache` to return a cache. None was " + "provided, so no cache will be returned." 
+ ) + + hidden_states = inputs_embeds + + if cache_position is None: + cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + mamba_mask = self._update_mamba_mask(attention_mask, cache_position) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + # Until HERE + + for layer_idx, mixer_block in enumerate(self.layers): + # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention) + if mixer_block.block_type == "mamba": + layer_mask = mamba_mask + elif mixer_block.block_type == "attention": + layer_mask = causal_mask + elif mixer_block.block_type == "mlp": + layer_mask = None + else: + raise ValueError(f"Invalid block_type: {self.block_type}") + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, hidden_states, cache_params, cache_position, layer_mask + ) + else: + hidden_states = mixer_block( + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=layer_mask, + ) + + # TODO: Store attentions + # if output_attentions: + # if layer_outputs[1] is not None: + # # append attentions only of attention layers. Mamba layers return `None` as the attention weights + # all_self_attns += (layer_outputs[1],) + + # TODO (Check): should it happen before the forward pass? + # if output_hidden_states: + # all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.norm_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) + + return NemotronHOutput( + last_hidden_state=hidden_states, + cache_params=cache_params if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask + def _update_causal_mask(self, attention_mask, input_tensor, cache_position): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + target_length = cache_position[-1] + 1 + + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and 
attention_mask.device.type == "cuda" + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + def _update_mamba_mask(self, attention_mask, cache_position): + """ + No need for zeroing states when + 1. Cached forward + 2. Attending to all inputs + """ + mamba_mask = attention_mask + if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)): + mamba_mask = None + return mamba_mask + + +@add_start_docstrings( + """ + The NEMOTRONH Model transformer with a language modeling head on top (linear layer with weights not tied to the input + embeddings). + """, + NEMOTRONH_START_DOCSTRING, +) +class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.backbone = NemotronHModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_decoder(self): + return self.model + + def set_decoder(self, decoder): + self.model = decoder + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + **kwargs, + ): + # Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py + # Overwitten -- uses `cache_params` as opposed to `past_key_values` + empty_past_kv = past_key_values is None + + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case. 
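Before the cache-slicing branches that follow, here is a minimal illustration of the default case: once a cache is populated, only the not-yet-processed tokens selected by `cache_position` are fed to the model.

```python
# Illustration of the default slicing branch below (not part of the model code).
import torch

input_ids = torch.tensor([[11, 12, 13, 14, 15]])  # all tokens provided or generated so far
cache_position = torch.tensor([4])  # index of the next position to compute
new_input_ids = input_ids[:, cache_position]  # keep only the unprocessed token(s)
assert new_input_ids.tolist() == [[15]]
```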
+ # (we can't check exception 3 while compiling) + if not empty_past_kv: + if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]: # Exception 1 # Exception 3 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + else: + past_key_values = HybridMambaAttentionDynamicCache( + self.config, input_ids.shape[0], self.dtype, device=self.device + ) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if not empty_past_kv: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and empty_past_kv: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "logits_to_keep": self.config.num_logits_to_keep, + "cache_position": cache_position, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=NemotronHCausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cache_params: Optional[HybridMambaAttentionDynamicCache] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **kwargs, # for now we need this for generation + ) -> Union[tuple, NemotronHCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + cache_params = cache_params if cache_params is not None else kwargs["past_key_values"] + + nemotron_h_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + use_cache=use_cache, + cache_position=cache_position, + attention_mask=attention_mask, + ) + hidden_states = nemotron_h_outputs[0] + + # TODO: Check zamba_modeling.py: https://github.com/huggingface/transformers/blob/d7188ba600e36d3fd191b12e19f1b3bb81a8404f/src/transformers/models/zamba/modeling_zamba.py#L1284C1-L1286C2 + # logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + nemotron_h_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return NemotronHCausalLMOutput( + loss=loss, + logits=logits, + cache_params=nemotron_h_outputs.cache_params, + hidden_states=nemotron_h_outputs.hidden_states, + attentions=nemotron_h_outputs.attentions, + ) diff --git a/tests/test_ssms.py b/tests/test_ssms.py deleted file mode 100644 index 2a338f1ba..000000000 --- a/tests/test_ssms.py +++ /dev/null @@ -1,349 +0,0 @@ -import inspect -import itertools -import pathlib -from functools import partial - -import pytest -import torch -from mamba2 import Mamba2 - -from fast_llm.config import NoAutoValidate -from fast_llm.engine.checkpoint.config import CheckpointLoadConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.schedule.config import ScheduleConfig -from fast_llm.engine.schedule.runner import ScheduleRunner -from fast_llm.engine.schedule.schedule import Schedule -from fast_llm.layers.ssm.config import SSMConfig -from fast_llm.layers.ssm.llamba_block import SSMBlock -from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs -from fast_llm.models.gpt.config import GPTBatchConfig -from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, LLambaHuggingfaceCheckpointFormat -from fast_llm.models.ssm.model import HybridSSMModel - -_mamba_varlen = False -try: - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa - - _mamba_available = True - sig = inspect.signature(selective_scan_fn) - if "position_indices" in sig.parameters: - _mamba_varlen = True - else: - 
_mamba_varlen = False - # for training with packing install https://github.com/jxiw/varlen_mamba - # see https://github.com/jxiw/M1/blob/main/HYBRID_PACK.md - -except (ImportError, RuntimeError): - _mamba_available = False - - -def get_hybrid_config(hybrid_block_layout=["t", "m2"], prediction_heads=1, default_mtp_type=None): - hidden_size = 512 - config = HybridSSMBaseModelConfig( - transformer=TransformerConfig(num_layers=len(hybrid_block_layout), hidden_size=hidden_size), - ssm=SSMConfig(d_xb=hidden_size, dt_rank=10, d_inner=hidden_size * 2), - hybrid_block_layout=hybrid_block_layout, - prediction_heads=prediction_heads, - default_mtp_type=default_mtp_type, - init_method_std_embed=0.02, - init_method_min_embed=-0.02, - init_method_max_embed=0.02, - use_position_embeddings=True, - tie_word_embeddings=False, - ) - return config - - -@pytest.mark.skip("Disabled due to cartesia_pytorch installation issue") -@pytest.mark.slow -def test_load_from_llamba_checkpoint(): - """ - Test to check whether the of Fast-LLM and Huggingface checkpoint loading for Llamba-1B produce the same results. - """ - import cartesia_pytorch.Llamba.llamba - - vocab_size = 128256 # from https://huggingface.co/cartesia-ai/Llamba-1B/blob/main/config.json - batch_size = 2 - seq_length = 32 - - path = pathlib.Path("/mnt/checkpoints_fml/pretrained_models/Llamba-1B") - format = LLambaHuggingfaceCheckpointFormat - - x = torch.randint(0, vocab_size, (batch_size, seq_length), device="cuda") - - hf_model = cartesia_pytorch.Llamba.llamba.LMHeadModel.from_pretrained(path, strict=True).to("cuda") - parameter_sum_hf = sum(p.detach().sum().cpu().item() for p in hf_model.parameters()) - hf_logits = hf_model(x)["logits"].cpu() - del hf_model - torch.cuda.empty_cache() - - # Create checkpoint load config - checkpoint_config = CheckpointLoadConfig(path=path, format=format, model_weights=True, optimizer_state=False) - # Initialize model - model = HybridSSMModel.from_pretrained(checkpoint_config) - param_sum = 0 - for stage in model.stages: - for fsdp in stage.fsdps: - if hasattr(fsdp, "_weight_shard"): - param_sum += torch.sum(fsdp._weight_shard).item() - assert torch.abs(torch.tensor(param_sum) - parameter_sum_hf) < 1e-1 - - # model = GPTModel.from_pretrained(checkpoint_config) - assert model.config.base_model.vocab_size == vocab_size - schedule_config = ScheduleConfig() - with NoAutoValidate(): - batch_config = GPTBatchConfig(micro_batch_size=batch_size, sequence_length=seq_length) - batch_config.setup(DistributedConfig.from_dict({})) - batch_config.validate() - schedule_runner = ScheduleRunner( - config=schedule_config, - multi_stage=model, - distributed_config=model.distributed.config, - ) - schedule = Schedule( - multi_stage=model, - batch_config=batch_config, - schedule_config=schedule_config, - distributed_config=model.distributed.config, - phase=PhaseType.inference, - ) - schedule_runner.setup(model.distributed, optimizer=None) - - common_kwargs = { - TransformerKwargs.sequence_first: True, - TransformerKwargs.grad_output: False, - } - input_data = [(x, common_kwargs)] - - schedule_runner.run_step(iter([input_data]), schedule, iteration=0, return_metrics=True, preprocessed=True) - - logits = input_data[0][1]["logits"].cpu() - assert torch.allclose(logits, hf_logits, atol=1e-2) - - -@pytest.fixture -def distributed_config(): - return DistributedConfig( - tensor_parallel=1, - pipeline_parallel=1, - sequence_data_parallel=1, - local_world_size=1, - world_size=1, - ) - - -@pytest.fixture -def distributed(distributed_config): - return 
Distributed(config=distributed_config) - - -def materialize_meta_tensors(model, tensor_space): - # Materialize parameters that are on meta device - for name, param in model.named_parameters(): - if param.device.type == "meta": - # Check if the parameter is a custom tensor type - if hasattr(param, "tensor_name") and hasattr(param, "init_parameter"): - param_data = param.new_empty(param.shape, device="cuda") - # Initialize param_data - param.init_parameter(param_data, tensor_space.distributed) - # Replace the parameter in the module - module_path, param_name = name.rsplit(".", 1) if "." in name else (None, name) - module = model - if module_path is not None: - for part in module_path.split("."): - module = getattr(module, part) - param = torch.nn.Parameter(param_data, requires_grad=param.requires_grad) - # TODO: add param_grad_is_zero etc., grad_buffer, etc., see test_mlp_recomputation - param.grad = None - param.grad_buffer = torch.empty_like(param) - param.param_grad_is_zero = True - module._parameters[param_name] = param - return model - - -def unpack(packed_hidden_states, cu_seqlens): - batch_size = packed_hidden_states.shape[0] - package_num = cu_seqlens.shape[0] - 1 - seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - hidden_dim = packed_hidden_states.shape[2] - hidden_states = torch.zeros( - package_num * batch_size, - seq_len, - hidden_dim, - dtype=packed_hidden_states.dtype, - device=packed_hidden_states.device, - ) - for j in range(batch_size): - for i in range(package_num): - line = j * package_num + i - hidden_states[line, : cu_seqlens[i + 1] - cu_seqlens[i], :] = packed_hidden_states[ - j, cu_seqlens[i] : cu_seqlens[i + 1], : - ] - return hidden_states - - -def pack(hidden_states, cu_seqlens, batch_size): - package_num, seq_len, hidden_dim = hidden_states.shape - seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1] - seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2) - indices_3d = ( - torch.arange(seq_len, device=hidden_states.device).unsqueeze(0).unsqueeze(2).repeat(package_num, 1, hidden_dim) - ) - mask_3d = indices_3d < seq_len_list_3d.repeat(batch_size, 1, 1) - packed_hidden_states = hidden_states[mask_3d].view(batch_size, -1, hidden_dim) - return packed_hidden_states - - -def generate_random_cu_seqlens(seq_len, packages_num=2): - if packages_num < 1: - raise ValueError("packages_num must be at least 1") - - # base size of each chunk, and how many get an extra token - base, rem = divmod(seq_len, packages_num) - # lengths: e.g. 
for seq_len=10, packages=3 → [4,3,3] - lengths = [base + 1 if i < rem else base for i in range(packages_num)] - - # split points exclude the final cumulative (seq_len) - split_points = list(itertools.accumulate(lengths))[:-1] - - # cu_seqlens = [0] + split_points + [seq_len] - cu_seqlens = [0] + split_points + [seq_len] - - # index: for each chunk, we emit 0,1,...,length-1 - index = [] - for length in lengths: - index.extend(range(length)) - - # sanity check - assert len(cu_seqlens) - 1 == packages_num - assert sum(lengths) == seq_len - assert len(index) == seq_len - - return cu_seqlens, index - - -# Quick and dirty test for Mamba2 varlen block from https://github.com/jxiw/M1/blob/d92b53faa640f8ebf624d3e9e771fe24648ef014/rl/verl/tests/pack_mamba/test_mamba_layer.py -# TODO: integrate in the testing framework -@pytest.mark.slow -@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA available") -@pytest.mark.skipif(not _mamba_available, reason="Mamba2 is not available") -@pytest.mark.skipif(not _mamba_varlen, reason="Mamba2 varlen is not available") -def test_mamba_varlen_block(distributed_config, distributed): - """ - Compare that the output and grads of packed and unpacked Mamba2 varlen block are the same. - """ - hybrid_config = get_hybrid_config(hybrid_block_layout=["m2", "t"]) - tensor_space = TensorSpace(distributed_config=distributed_config) - tensor_space.setup(distributed) - hybrid_config.setup_tensor_space(tensor_space) - layer_idx = 0 - - mixer_cls = partial(Mamba2, block_index=layer_idx) - block_packed = SSMBlock( - hybrid_config.transformer, - hybrid_config.ssm, - mixer_cls=mixer_cls, - tensor_space=tensor_space, - block_index=layer_idx, - ) - block_ref = SSMBlock( - hybrid_config.transformer, - hybrid_config.ssm, - mixer_cls=mixer_cls, - tensor_space=tensor_space, - block_index=layer_idx, - ) - device = "cuda" - materialize_meta_tensors(block_packed, tensor_space) - materialize_meta_tensors(block_ref, tensor_space) - block_ref.load_state_dict(block_packed.state_dict()) - block_packed.to(device) - block_ref.to(device) - - batch_size = 2 - seq_len = 64 - packages_num = 2 - hidden_dim = hybrid_config.transformer.hidden_size - - cu_seqlens, index = generate_random_cu_seqlens(seq_len, packages_num=packages_num) - cu_seqlens = torch.tensor(cu_seqlens).cuda() - ssm_position_ids = torch.tensor(index, dtype=torch.int32).unsqueeze(0).expand(batch_size, -1).contiguous().cuda() - seq_idx = ( - torch.cat( - [ - torch.full((s,), i, dtype=torch.int32, device=cu_seqlens.device) - for i, s in enumerate(cu_seqlens[1:] - cu_seqlens[:-1]) - ], - dim=0, - ) - .unsqueeze(0) - .repeat(batch_size, 1) - ) - - # Generate packed_hidden_states with random values for testing - hidden_states_list = [ - torch.randn(l, hidden_dim, device=device, dtype=torch.bfloat16, requires_grad=True) - for l in (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - ] - packed_hidden_states = torch.cat(hidden_states_list, dim=0).unsqueeze(0) - packed_hidden_states = packed_hidden_states.expand(batch_size, -1, -1).contiguous() - # hidden_states should be forwarded without cu_seqlens - hidden_states = unpack(packed_hidden_states, cu_seqlens) - - # Check: sum of seq_len of item in hidden_states_list should be equal to seq_len of packed_hidden_states - assert sum([hs.shape[0] for hs in hidden_states_list]) == packed_hidden_states.shape[1] - # Check: max of seq_len of item in hidden_states_list should be equal to seq_len of hidden_states - assert max([hs.shape[0] for hs in hidden_states_list]) == hidden_states.shape[1] - - 
output_states_packed = block_packed( - packed_hidden_states, - {"cu_seqlens": cu_seqlens, "seq_idx": seq_idx, "ssm_position_ids": ssm_position_ids, "sequence_first": False}, - ) - output_states_unpacked = block_ref( - hidden_states.clone(), {"cu_seqlens": None, "seq_idx": None, "ssm_position_ids": None, "sequence_first": False} - ) - tollerance = 1e-4 - assert output_states_packed.shape == packed_hidden_states.shape - assert output_states_unpacked.shape == hidden_states.shape - assert not torch.isnan(hidden_states).any() - assert not torch.isinf(hidden_states).any() - - output_states_unpacked = pack(output_states_unpacked, cu_seqlens, batch_size) - torch.allclose(output_states_packed, output_states_unpacked, atol=tollerance) - - loss = output_states_packed.sum() - loss.backward() - loss_ref = output_states_unpacked.sum() - loss_ref.backward() - assert torch.allclose(block_packed.mixer.conv1d_weight.grad, block_ref.mixer.conv1d_weight.grad, atol=tollerance) - assert torch.allclose(block_packed.mixer.conv1d_bias.grad, block_ref.mixer.conv1d_bias.grad, atol=tollerance) - assert torch.allclose( - block_packed.mixer.in_proj.weight.grad_buffer, block_ref.mixer.in_proj.weight.grad_buffer, atol=tollerance - ) - assert torch.allclose( - block_packed.mixer.out_proj.weight.grad_buffer, block_ref.mixer.out_proj.weight.grad_buffer, atol=tollerance - ) - assert torch.allclose( - block_packed.mixer.dt_in_proj.weight.grad_buffer, - block_ref.mixer.dt_in_proj.weight.grad_buffer, - atol=tollerance, - ) - - assert torch.allclose( - block_packed.mlp.layer_1.weight.grad_buffer, block_ref.mlp.layer_1.weight.grad_buffer, atol=tollerance - ) - assert torch.allclose( - block_packed.mlp.layer_1.bias.grad_buffer, block_ref.mlp.layer_1.bias.grad_buffer, atol=tollerance - ) - assert torch.allclose( - block_packed.mlp.layer_2.weight.grad_buffer, block_ref.mlp.layer_2.weight.grad_buffer, atol=tollerance - ) - assert torch.allclose( - block_packed.mlp.layer_2.bias.grad_buffer, block_ref.mlp.layer_2.bias.grad_buffer, atol=tollerance - ) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 96982e510..e5c9eb129 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -562,6 +562,37 @@ def _update_and_add_testing_config( ), # "pp","dp", "ce","16", "bf", "df", "stp"), ) +_update_and_add_testing_config( + # Tests hybrid Mamba 2. + "llama", + "hybrid_nm2", + model_type="hybrid_ssm", + extra_args=[ + "model.base_model.hybrid_block_layout=['t','nm2']", + "model.base_model.ssm.state_size=8", + "model.base_model.ssm.d_xb=16", + "model.base_model.ssm.head_dim=8", + "model.base_model.ssm.n_groups=4", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}" + ], + megatron_args=None, + checkpoint_format=AprielThinkerSSMHHybridHuggingfaceCheckpointFormat, + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.normal, + ModelTestingGroup.convert: ModelTestingGroupAction.normal, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.normal, + }, + compare_factor=2.0, + # Micro-sequence split not supported. + skip_tests=( + "sdp", + "ms", + ), # "pp","dp", "ce","16", "bf", "df", "stp"), +) + _update_and_add_testing_config( # Tests hybrid discrete Mamba 2.