diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 0d971a88a..6c4b95b20 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -66,13 +66,23 @@ def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": - if self.parallel_group is not None: + if self.is_parallel: from fast_llm.core.ops import gather_op return gather_op(tensor, self.parallel_group, dim) else: return tensor + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + if self.is_parallel: + output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) + output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) + return output.flatten(dim, dim + 1) + else: + return tensor + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": return ( tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] @@ -85,7 +95,7 @@ class CompositeTensorDim(TensorDim): def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): parallel_dim = None for dim, tensor_dim in enumerate(tensor_dims): - if tensor_dim.is_parallel: + if tensor_dim.parallel_dim is not None: # TODO: Allow more than one parallel subdim? assert parallel_dim is None parallel_dim = tensor_dim.parallel_dim @@ -111,6 +121,15 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global_partial(tensor, dim + i) + + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): @@ -157,6 +176,27 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor else tensor ) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global_partial(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": if self.is_parallel and expand: raise NotImplementedError() @@ -223,8 +263,5 @@ def add_tensor_dim(self, tensor_dim: TensorDim) -> None: ) self._tensor_dims[tensor_dim.name] = tensor_dim - def get_tensor_dim(self, name: str) -> TensorDim: + def __getitem__(self, name: str) -> TensorDim: return self._tensor_dims[name] - - # TODO: Replace uses - __getitem__ = get_tensor_dim diff --git 
a/fast_llm/engine/multi_stage/config.py b/fast_llm/engine/multi_stage/config.py index 6ac157dfe..719088057 100644 --- a/fast_llm/engine/multi_stage/config.py +++ b/fast_llm/engine/multi_stage/config.py @@ -31,6 +31,7 @@ if typing.TYPE_CHECKING: from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM + from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel logger = logging.getLogger(__name__) @@ -241,6 +242,10 @@ def get_checkpoint_handler_class(cls, format: type[CheckpointFormat] | str) -> t def get_model_class(cls) -> type["FastLLMModel"]: raise NotImplementedError + @classmethod + def get_inference_runner_class(cls) -> type["InferenceRunner"]: + raise NotImplementedError + @classmethod def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceBaseModelForCausalLM"]: raise NotImplementedError diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 5b44bf14b..be15cd37a 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight( where it is located in the shard if it exists, or -1 if it's not in the shard. Used to determine the location of each entry in a different distributed configuration. """ - - # Create an empty index for the global parameter. - index = torch.full( - parameter_meta.global_shape, - -1, - dtype=torch.int64, - device=device, - ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard begin, end = self._get_parameter_range_in_shard(parameter_name) - buffer_index = parameter_meta.global_to_local(index, expand=True) - # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible. - # In that case, we work with a separate tensor to be copied back into `buffer_index`. - try: - buffer_index_flat = buffer_index.view(-1) - is_view = True - except RuntimeError: - buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1) - is_view = False - - # Copy the shard indices at their respective positions in the flat buffer index. - buffer_index_flat[ + # Create an empty local index to hold the local shard indices. + buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device) + + # Copy the shard indices at their respective positions in the buffer index. + buffer_index.flatten()[ self._index_buffer_to_param( self._fsdp_dim.rank * self._shard_size, parameter_name ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) - # If needed, copy the flat buffer index back into the index. - if not is_view: - buffer_index.copy_(buffer_index_flat.view_as(buffer_index)) - - return index + # Create a global index from the local one. 
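+        # `local_to_global_partial` places this rank's shard indices at their position in the
+        # global parameter shape and fills every other entry with -1 (i.e. "not in this shard").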
+ return parameter_meta.local_to_global_partial(buffer_index, -1) def copy_shard_overlaps( self, diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 3dbec5348..9372ad7fb 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -32,7 +32,6 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.training.trainer import Trainer, TrainingEvaluator @@ -403,10 +402,6 @@ def _setup(self): def get_trainer_class(cls) -> type["Trainer"]: raise NotImplementedError - @classmethod - def get_inference_runner_class(cls) -> type["InferenceRunner"]: - raise NotImplementedError - def _get_runnable(self) -> typing.Callable[[], None]: from fast_llm.engine.distributed.distributed import Distributed diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 5f5511a15..ec3c4cebe 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -142,7 +142,7 @@ def __init__(self, config: TrainerConfig): self._reference_models = {} for name, reference_config in self._config.reference_models.items(): log_main_rank(f"Creating `{name} reference model...") - self._reference_models[name] = self._config.get_inference_runner_class()( + self._reference_models[name] = reference_config.model.get_inference_runner_class()( reference_config.model.get_model_class()(reference_config.model) ) self._multi_stage.base_model.add_reference_model(name, self._reference_models[name]) diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py index eaeaa0d18..b18a9ec0b 100644 --- a/fast_llm/functional/cross_entropy.py +++ b/fast_llm/functional/cross_entropy.py @@ -49,6 +49,19 @@ def _torch_cross_entropy_forward_backward( return loss.detach_(), grad +def distributed_log_softmax(logits: torch.Tensor, group: ProcessGroup, dim: int = -1): + logits = logits.float() + local_max = logits.max(dim=dim, keepdim=True)[0] + all_reduce(local_max, op=ReduceOp.MAX, group=group) + + logits_shifted = logits - local_max + exp_logits = torch.exp(logits_shifted) + sum_exp = exp_logits.sum(dim=dim, keepdim=True) + all_reduce(sum_exp, op=ReduceOp.SUM, group=group) + + return logits_shifted - sum_exp.log() # log_softmax + + @torch.compile def _fused_softmax_base( logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 @@ -214,21 +227,21 @@ def cross_entropy_forward_backward( ) -def _torch_reverse_kl_forward_backward( +def _torch_reverse_kl_forward_backward_vocab_parallel( logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, - logits_scale_factor: float, target_format: TargetFormat, group: ProcessGroup | None = None, - teacher_softmax_temperature: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Reverse KL using PyTorch's native kl_div function. + This is used for TP version where we split accross vocab dimantion. This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. 
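+    With vocab parallelism each rank only holds a slice of the vocabulary, so `distributed_log_softmax`
+    all-reduces the running max and the sum of exponentials across the group before taking the log,
+    and the per-shard KL sums are all-reduced and normalized by the batch size.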
""" + # TODO: merge into single function _torch_reverse_kl_forward_backward Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") Assert.eq(target.shape, logits.shape) assert target.dtype.is_floating_point, target.dtype @@ -236,16 +249,66 @@ def _torch_reverse_kl_forward_backward( Assert.eq(loss_mask.shape, logits.shape[:-1]) # Compute log probabilities - let _fused_softmax handle scaling internally - # teacher_probs = _fused_softmax(target, logits_scale_factor * (1 / teacher_softmax_temperature), group) - # # teacher_log_probs = torch.log(teacher_probs + 1e-8) # log(p) - # teacher_probs = torch.clamp(teacher_probs, min=1e-7) # or even 1e-6 - # teacher_log_probs = torch.log(teacher_probs) + teacher_log_probs = distributed_log_softmax(target, group=group) + batch_size = logits.shape[0] + with torch.enable_grad(): + logits_ = logits.detach().requires_grad_(grad_output is not None) + student_log_probs = distributed_log_softmax(logits_, group=group) + + # Reverse KL: input=teacher_log_probs, target=student_probs + if loss_mask is None: + loss = torch.nn.functional.kl_div( + teacher_log_probs, # input = log(p) + student_log_probs, # target = log(q) + reduction="sum", + log_target=True, + ) + else: + # Apply loss mask - this requires some reshaping + raise NotImplementedError("Loss mask not implemented with TP for reverse KL , it must be doublechecked") + loss_per_sample = torch.nn.functional.kl_div( + teacher_log_probs, student_log_probs, reduction="none", log_target=True + ).sum(dim=-1) + loss = (loss_per_sample * loss_mask).sum() + + if group is not None and target_format != TargetFormat.labels: + all_reduce(loss, op=ReduceOp.SUM, group=group) + loss /= batch_size + + if grad_output is not None: + loss.backward(torch.full_like(loss, grad_output)) + grad = logits_.grad.to(logits.dtype) + else: + grad = None + return loss.detach_(), grad + + +def _torch_reverse_kl_forward_backward( + logits: torch.Tensor, + target: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + logits_scale_factor: float, + target_format: TargetFormat, + group: ProcessGroup | None = None, + teacher_softmax_temperature: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Reverse KL using PyTorch's native kl_div function. + This works with sequence-tensor-parallel (distributing over the sequence dimention) as well as a non-TP case. + In sequence-tensor-parallel, where we split along sequence dim., we compute per split loss and then average the loss. 
+ """ + Assert.eq(target_format, TargetFormat.logits, msg="Reverse KL only supports logits format") + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) # Scale target logits more carefully scaled_target = target * (logits_scale_factor / teacher_softmax_temperature) + # Clamp to prevent extreme values that cause NaNs in log_softmax + scaled_target = torch.clamp(scaled_target, min=-100.0, max=100.0) - # Clamp to prevent extreme values before log_softmax - scaled_target = torch.clamp(scaled_target, min=-50, max=50) teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) # For reverse KL: KL(q||p) = Σ q * log(q/p) = Σ q * (log(q) - log(p)) @@ -256,9 +319,10 @@ def _torch_reverse_kl_forward_backward( logits_ = logits.detach().requires_grad_(grad_output is not None) scaled_logits = logits_ * logits_scale_factor - scaled_logits = torch.clamp(scaled_logits, min=-50, max=50) + # Clamp to prevent extreme values that cause NaNs in log_softmax + scaled_logits = torch.clamp(scaled_logits, min=-100.0, max=100.0) student_log_probs = torch.log_softmax(scaled_logits, dim=-1) - + # Reverse KL: input=teacher_log_probs, target=student_probs if loss_mask is None: loss = torch.nn.functional.kl_div( @@ -279,6 +343,7 @@ def _torch_reverse_kl_forward_backward( loss /= group.size() if grad_output is not None: + # note, we never get here in TP over seq. dim. loss.backward(torch.full_like(loss, grad_output)) grad = logits_.grad.to(logits.dtype) else: @@ -344,6 +409,14 @@ def reverse_kl_forward_backward( Assert.eq(teacher_softmax_temperature, 1) Assert.eq(logits_scale_factor, 1) raise NotImplementedError("Vocab parallel reverse KL is not implemented yet.") + return _torch_reverse_kl_forward_backward_vocab_parallel( + logits, + target, + loss_mask, + grad_output, + target_format, + group, + ) else: return _torch_reverse_kl_forward_backward( logits, diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 7036a1e97..f6f43d199 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -46,10 +46,10 @@ def __init__( self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - vocab_dim = tensor_space.get_tensor_dim( + hidden_dim = tensor_space[TransformerDimNames.hidden] + vocab_dim = tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size @@ -66,7 +66,7 @@ def __init__( ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim), + (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index eed2d134f..24c06d5cc 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -61,7 +61,7 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = 
self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] self._loss_coefficient = ( config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 @@ -108,9 +108,9 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: if self._tie_word_embeddings or self._prediction_distance > 0: return # untie embedding weights - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), init_method=init_normal_( @@ -237,7 +237,6 @@ def _get_targets( ).flatten() else: lm_target = None - targets = (dpo_target, lm_target, distillation_target) # If we do distillation, no need to split it here as it has already been split in the embedding layer! # if we do CPT/language modeling, we need to split the targets here! @@ -350,9 +349,9 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._logits_scale_factor, ) if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) + ] dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) dims[sequence_index] = ( diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index d719bef3d..c8d53a789 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -28,7 +28,7 @@ def __init__( assert config.use_absolute_position_embeddings self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: @@ -76,7 +76,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index c06d85148..3b21ca698 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -1,13 +1,16 @@ import enum +import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import LLMBlockConfig, NormalizationConfig -from fast_llm.tensor import Initializer from fast_llm.utils import Assert, div +if typing.TYPE_CHECKING: + from fast_llm.tensor import Initializer + class SSMDimNames: # TODO: Use separate tensor space for different mixers so there is no 
risk of name conflict. @@ -16,6 +19,8 @@ class SSMDimNames: head_groups = "ssm_head_groups" group_heads = "ssm_group_heads" + # Mamba 2 + x_proj_dim_2 = "x_proj_dim_2" # d_xb convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers dt_rank = "ssm_dt_rank" @@ -62,7 +67,7 @@ class DTInitType(enum.StrEnum): constant = "constant" random = "random" - def get_init_method(self, scale: float) -> Initializer: + def get_init_method(self, scale: float) -> "Initializer": from fast_llm.tensor import init_fill_, init_uniform_centered_ return init_fill_(scale) if self == DTInitType.constant else init_uniform_centered_(scale) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 64377b93c..c9d555de9 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -49,25 +49,25 @@ def __init__( layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - conv1d_dim = tensor_space.get_tensor_dim(SSMDimNames.concatenated_convolution) - heads_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads) + inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + hidden_dim = tensor_space[TransformerDimNames.hidden] + conv1d_dim = tensor_space[SSMDimNames.concatenated_convolution] + heads_dim = tensor_space[SSMDimNames.composite_heads] # local_head_groups = head_groups / TP - self._local_head_groups = tensor_space.get_tensor_dim(SSMDimNames.head_groups).size + self._local_head_groups = tensor_space[SSMDimNames.head_groups].size # local_heads = local_head_groups * group_heads self._local_heads = heads_dim.size # local_inner_size = local_heads * head_size self._local_inner_size = inner_dim.size # local_bc_size = local_head_groups * state - self._local_bc_size = tensor_space.get_tensor_dim(SSMDimNames.composite_head_groups_and_state).size + self._local_bc_size = tensor_space[SSMDimNames.composite_head_groups_and_state].size # TODO: double check initializations # Projections self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), + tensor_space[SSMDimNames.concatenated_inner_projection], bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -83,8 +83,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, - tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 1ae25e44c..77c1b3869 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -62,13 +62,13 @@ def __init__( layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale: float | tuple[float | None, ...] 
| None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads_and_head_dim) - xb_dim = tensor_space.get_tensor_dim(name=SSMDimNames.composite_head_groups_and_state) - hidden_dim: TensorDim = tensor_space.get_tensor_dim(name=TransformerDimNames.hidden) - dt_rank_dim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) + inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] + hidden_dim: TensorDim = tensor_space[TransformerDimNames.hidden] + dt_rank_dim = tensor_space[SSMDimNames.dt_rank] - self._local_heads = tensor_space.get_tensor_dim(name=SSMDimNames.composite_heads).size - self._local_head_groups = tensor_space.get_tensor_dim(name=SSMDimNames.head_groups).size + self._local_heads = tensor_space[SSMDimNames.composite_heads].size + self._local_head_groups = tensor_space[SSMDimNames.head_groups].size self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size @@ -77,8 +77,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, - tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(name=SSMDimNames.convolution_kernel), + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -90,7 +90,7 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space.get_tensor_dim(name=SSMDimNames.concatenated_inner_projection), + tensor_space[SSMDimNames.concatenated_inner_projection], bias=config.add_bias_linear, weight_init_method=init_kaiming_(transformer_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -122,7 +122,7 @@ def __init__( lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(name=SSMDimNames.state)), + (inner_dim, tensor_space[SSMDimNames.state]), init_method=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, weight_decay=False, diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 64c8227fc..9343ef1b8 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -69,8 +69,8 @@ def __init__( Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - inner_dim = tensor_space.get_tensor_dim(SSMDimNames.composite_heads_and_head_dim) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) + inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] + hidden_dim = tensor_space[TransformerDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) @@ -78,7 +78,7 @@ def __init__( # TODO: lr_scale? 
self.in_proj = Linear( hidden_dim, - tensor_space.get_tensor_dim(SSMDimNames.concatenated_inner_projection), + tensor_space[SSMDimNames.concatenated_inner_projection], bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) @@ -86,8 +86,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, - tensor_space.get_tensor_dim(DefaultDimNames.scalar), - tensor_space.get_tensor_dim(SSMDimNames.convolution_kernel), + tensor_space[DefaultDimNames.scalar], + tensor_space[SSMDimNames.convolution_kernel], ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -95,7 +95,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space.get_tensor_dim(SSMDimNames.concatenated_x_projection), + tensor_space[SSMDimNames.concatenated_x_projection], weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, @@ -104,7 +104,7 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.dt_rank)), + (inner_dim, tensor_space[SSMDimNames.dt_rank]), init_method=init_kaiming_(self._config.dt_rank), lr_scale=lr_scale, ) @@ -116,7 +116,7 @@ def __init__( ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space.get_tensor_dim(SSMDimNames.state)), + (inner_dim, tensor_space[SSMDimNames.state]), weight_decay=False, init_method=init_A(self._config.state_size, inner_dim.size), lr_scale=lr_scale, diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 7b8bc98c8..c03aeed8e 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -72,14 +72,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(self._transformer_dim_names.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(self._transformer_dim_names.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(self._transformer_dim_names.group_heads).size + self._kv_channels = self._tensor_space[self._transformer_dim_names.kv_channels].size + self._head_groups = self._tensor_space[self._transformer_dim_names.head_groups].global_size + self._local_head_groups = self._tensor_space[self._transformer_dim_names.head_groups].size + self._local_heads_per_group = self._tensor_space[self._transformer_dim_names.group_heads].size self._local_heads = self._local_head_groups * self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) + hidden_dim = self._tensor_space[self._transformer_dim_names.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -87,7 +87,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_query), + self._tensor_space[self._transformer_dim_names.composite_query], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -96,7 +96,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_key_value), + self._tensor_space[self._transformer_dim_names.composite_key_value], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -110,7 +110,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # Output. self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(self._transformer_dim_names.composite_dense), + self._tensor_space[self._transformer_dim_names.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 73f83ccf5..4fd2844d5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -63,8 +63,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space.get_tensor_dim(TransformerDimNames.hidden), - tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), + tensor_space[TransformerDimNames.hidden], + tensor_space[TransformerDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max @@ -255,7 +255,7 @@ def _debug_log( def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space.get_tensor_dim(dim_name),), + kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index ecf2c3fea..5dee4e077 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -33,8 +33,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(self._transformer_dim_names.composite_expert_mlp) + hidden_dim = tensor_space[self._transformer_dim_names.hidden] + self._intermediate_dim = tensor_space[self._transformer_dim_names.composite_expert_mlp] self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -49,7 +49,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - 
tensor_space.get_tensor_dim(self._transformer_dim_names.composite_gated_expert_mlp), + tensor_space[self._transformer_dim_names.composite_gated_expert_mlp], bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index cb64ccf06..ee30112d7 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -30,7 +30,7 @@ def __init__( self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py index cc83dae02..c357411b6 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -25,8 +25,8 @@ def __init__( self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index b2c69dd8d..6b4b81415 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -84,8 +84,8 @@ def __init__( super().__init__(config, tensor_space) self._tensor_space = tensor_space if self._tensor_space is not None: - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index d2f3bfba8..9289dccfb 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -48,7 +48,7 @@ def _get_meta( } return TensorMeta.from_dims( tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] for dim_name in dim_names ), tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", @@ -99,8 +99,7 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - - hidden_dim = self._tensor_space.get_tensor_dim(self._transformer_dim_names.hidden) + hidden_dim = 
self._tensor_space[self._transformer_dim_names.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index bc64821f2..182ad1712 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -184,6 +184,12 @@ def get_model_class(cls) -> type["GPTModel"]: return GPTModel + @classmethod + def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: + from fast_llm.models.gpt.model import GPTInferenceRunner + + return GPTInferenceRunner + @classmethod def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceGPTModelForCausalLM"]: from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelForCausalLM @@ -289,9 +295,3 @@ def get_trainer_class(cls) -> type["GPTTrainer"]: from fast_llm.models.gpt.trainer import GPTTrainer return GPTTrainer - - @classmethod - def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: - from fast_llm.models.gpt.model import GPTInferenceRunner - - return GPTInferenceRunner diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index e7379e61e..20ed8e828 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -14,8 +14,8 @@ def get_init_megatron( meta: "ParameterMeta", config: TransformerConfig -) -> typing.Callable[["torch.Tensor", "Distributed"], "torch.Tensor"]: - def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): +) -> typing.Callable[["torch.Tensor", "Distributed"], None]: + def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) if "bias" in meta.tensor_name: # Generator unused. @@ -29,11 +29,11 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): elif config.num_experts > 1 and "mlp.layer_" in meta.tensor_name: tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) elif "mlp.layer_2" in meta.tensor_name: - tensor_ = _init_transposed_mlp_weight_megatron(config, meta, tensor, distributed) + tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) else: # Word embedding (override generator), layer norm (generator unused), other mlp weights. return meta.param_init_method(meta, tensor, distributed.tp_init_generator) - return tensor.copy_(tensor_.reshape_as(tensor)) + tensor.copy_(tensor_.reshape_as(tensor)) return init_megatron @@ -58,9 +58,9 @@ def _init_attention_megatron( generator = distributed.tp_init_generator state = generator.get_state() # Initialize a mock dense layer to advance the random state - dense_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + dense_tensor_ := tensor.new_empty( config.kv_channels * config.num_attention_heads, config.hidden_size, ), @@ -68,9 +68,9 @@ def _init_attention_megatron( ) # QKV is split differently. (Assuming no tensor-parallel.) heads_per_group = div(config.num_attention_heads, config.head_groups) - qkv_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + qkv_tensor_ := tensor.new_empty( config.head_groups, heads_per_group + 2, config.kv_channels, @@ -110,18 +110,19 @@ def _init_position_embeddings_megatron( # Megatron initializes the position embeddings on cpu twice. 
assert meta.param_init_method is not None generator = distributed.default_cpu_generator - tensor_ = meta.param_init_method(meta, torch.empty(tensor.shape, dtype=tensor.dtype), generator) - return meta.param_init_method(meta, tensor_, generator) + meta.param_init_method(meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), generator) + meta.param_init_method(meta, tensor_, generator) + return tensor_ def _init_transposed_mlp_weight_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": import torch # Megatron never transposes the mlp layer 2 weight. assert meta.param_init_method is not None - tensor_ = meta.param_init_method(meta, torch.empty_like(tensor), distributed.tp_init_generator) + meta.param_init_method(meta, tensor_ := torch.empty_like(tensor), distributed.tp_init_generator) return tensor_.view(meta.size(1), meta.size(0)).t() @@ -132,8 +133,8 @@ def _init_moe_router_megatron( # Megatron initializes the router on cpu. assert meta.param_init_method is not None - tensor_ = meta.param_init_method( - meta, torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator + meta.param_init_method( + meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator ) return tensor_ diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 7b4d165c5..ebf84fc58 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -219,7 +219,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 886fa7a32..471e6d06c 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -4,19 +4,23 @@ from fast_llm.config import Field, FieldHint, FieldUpdate, config_class from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler +from fast_llm.engine.checkpoint.config import CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import ( + GPTBaseModelConfig, + GPTBatchConfig, + GPTHuggingfaceCheckpointFormat, + PretrainedGPTModelConfig, +) from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.models.gpt.model import GPTInferenceRunner from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM - from fast_llm.models.ssm.model import HybridSSMModel + from fast_llm.models.ssm.model import HybridSSMInferenceRunner, HybridSSMModel from fast_llm.models.ssm.trainer import HybridSSMTrainer logger = logging.getLogger(__name__) @@ -80,8 +84,7 @@ def _validate(self): self.ssm_block_type = ssm_block_types.pop() if ssm_block_types 
else None -class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class LLambaHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "llamba" @classmethod @@ -91,8 +94,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return LLambaHuggingfaceCheckpointHandler -class AprielSSMHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielSSMHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm" @classmethod @@ -102,8 +104,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHuggingfaceCheckpointHandler -class AprielSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm_hybrid" @classmethod @@ -113,8 +114,7 @@ def get_handler_class(cls) -> type[CheckpointHandler]: return AprielSSMHHybridHuggingfaceCheckpointHandler -class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(CheckpointFormat): - support_optimizer: typing.ClassVar[bool] = False +class AprielThinkerSSMHHybridHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat): name: typing.ClassVar[str] = "apriel_ssm_thinker_hybrid" trust_remote_code: typing.ClassVar[bool] = True @@ -165,6 +165,16 @@ def get_model_class(cls) -> type["HybridSSMModel"]: return HybridSSMModel + @classmethod + def get_inference_runner_class(cls) -> type["HybridSSMInferenceRunner"]: + from fast_llm.models.ssm.model import HybridSSMInferenceRunner + + logger.warning( + "HybridSSMInferenceRunner only supports training-style forward pass. Use generate with cache disabled." + ) + + return HybridSSMInferenceRunner + @classmethod def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceHybridSSMModelForCausalLM"]: from fast_llm.models.ssm.huggingface import HuggingfaceHybridSSMModelForCausalLM @@ -213,14 +223,3 @@ def _validate(self) -> None: Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) Assert.geq(reference_model.model.base_model.prediction_heads, self.model.base_model.prediction_heads) - - @classmethod - def get_inference_runner_class(cls) -> type["GPTInferenceRunner"]: - from fast_llm.models.gpt.model import GPTInferenceRunner - - # TODO: we dont have inference runner for SSM/Hybrid yet, should return None? - logger.warning( - "No inference runner for SSM/Hybrid yet, using GPTInferenceRunner for now, which does not support SSM/Hybrid" - ) - - return GPTInferenceRunner diff --git a/fast_llm/models/ssm/external/15B_hybrid.ipynb b/fast_llm/models/ssm/external/15B_hybrid.ipynb new file mode 100644 index 000000000..a8f0c33b7 --- /dev/null +++ b/fast_llm/models/ssm/external/15B_hybrid.ipynb @@ -0,0 +1,1562 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "from transformers import AutoConfig, AutoModelForCausalLM\n", + "# from transformers import MistralForCausalLM\n", + "# from fast_llm.models.ssm.external.apriel_15b_hybrid.configuration_ssm_hybrid_apriel15b import AprielSSMHybridConfig\n", + "# from fast_llm.models.ssm.external.apriel_15b_hybrid.modeling_ssm_hybrid_apriel15b import AprielSSMHybridForCausalLM\n", + "# autoreload changes to the code\n", + "%reload_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# model_path = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15bch-ifrhyb20l32h-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2/export/apriel_ssm_thinker_hybrid/1000\"\n", + "# AutoConfig.from_pretrained(model_path, trust_remote_code=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# model_path = \"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15bch-ifrhyb20l32h-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti1000_lm2/export/apriel_ssm_thinker_hybrid/1000\"\n", + "# m = AutoModelForCausalLM.from_pretrained(\n", + "# model_path, trust_remote_code=True,\n", + "# config=AutoConfig.from_pretrained(model_path, trust_remote_code=True),\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Slam 15B upcycled" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " Lead the weights of https://huggingface.co/ServiceNow-AI/Slam-15B-Upcycled/ into Thiked modeling, it shoudl work" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"/home/toolkit/dev/fml-ops/__oo_playground\")\n", + "from results_analysis.results_loader import ResultsLoader\n", + "layer_importance_path = \"/mnt/evaluations/training_evaluation/model_runs/lm_eval_runner/apriel_ssm_importance/\"\n", + "results_loader = ResultsLoader(layer_importance_path)\n", + "\n", + "results_loader.deserialize_results()\n", + "results_df = results_loader.to_df()\n", + "results_df[\"layer_index\"] = results_df.apply(lambda row: int(row[\"model_name_sanitized\"].split(\"_\")[-1] if \"layers_\" in row[\"model_name_sanitized\"] else -1), axis=1)\n", + "results_df = results_df[results_df[\"metric\"] == \"acc_norm\"]\n", + "columns_to_keep = [\"layer_index\", \"metric_value\"]\n", + "results_df = results_df[columns_to_keep]\n", + "layer_importance = results_df.groupby(\"layer_index\").mean()\n", + "layer_importance = layer_importance.sort_values(by=\"metric_value\", ascending=False).reset_index()\n", + "layer_importance = layer_importance[layer_importance[\"layer_index\"]!= -1]\n", + "layer_importance = list(layer_importance[\"layer_index\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[22,\n", + " 25,\n", + " 20,\n", + " 31,\n", + " 29,\n", + " 46,\n", + " 23,\n", + " 26,\n", + " 33,\n", + " 24,\n", + " 47,\n", + " 27,\n", + " 21,\n", + " 41,\n", + " 17,\n", + " 18,\n", + " 34,\n", + " 42,\n", + " 44,\n", + " 30,\n", + " 16,\n", + " 8,\n", + " 43,\n", + " 35,\n", + " 19,\n", + " 38,\n", + " 15,\n", + " 28,\n", + " 32,\n", + " 45,\n", + " 37,\n", + " 
40,\n", + " 7,\n", + " 36,\n", + " 13,\n", + " 10,\n", + " 5,\n", + " 39,\n", + " 6,\n", + " 14,\n", + " 4,\n", + " 12,\n", + " 9,\n", + " 48,\n", + " 1,\n", + " 3,\n", + " 11,\n", + " 49,\n", + " 0]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer_importance" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "path_thinker = \"/mnt/checkpoints/upstream/Apriel-Nemotron-15b-Thinker\"\n", + "n_ssm = 25\n", + "\n", + "config_thinker = AutoConfig.from_pretrained(path_thinker)\n", + "hybrid_block_layout = [\"t\"] * config_thinker.num_hidden_layers\n", + "\n", + "for i in range(n_ssm):\n", + " hybrid_block_layout[layer_importance[i]] = \"m2d\"\n", + "\n", + "config_hybrid = AprielSSMHybridConfig(\n", + " **config_thinker.to_dict(),\n", + " hybrid_block_layout=hybrid_block_layout,\n", + " ssm_cfg = {\n", + " \"d_state\": 64,\n", + " \"n_v_heads\": 32,\n", + " \"n_qk_heads\": 32,\n", + " \"expand\": 1,\n", + " \"chunk_size\": 128,\n", + " \"activation\": \"identity\",\n", + " \"bias\": False,\n", + " \"d_conv\": 4,\n", + " \"d_inner\": 32 * 128\n", + " }\n", + ")\n", + "model_hybrid = AprielSSMHybridForCausalLM(config_hybrid)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['t',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 'm2d',\n", + " 'm2d',\n", + " 't',\n", + " 't']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrid_block_layout" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You are using a model of type llama to instantiate a model of type mistral. 
This is not supported for all configurations of models and can yield errors.\n", + "Loading checkpoint shards: 0%| | 0/4 [00:00 v, B -> k, C -> q\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] : mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"], :\n", + " ].copy_(layer_module.self_attn.v_proj.weight.data)\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] + mamba_config.ssm_cfg[\"d_xb\"] : mamba_config.ssm_cfg[\"d_inner\"] + 2 * mamba_config.ssm_cfg[\"d_xb\"], :\n", + " ].copy_(layer_module.self_attn.k_proj.weight.data)\n", + " mamba_encoder.mixer.in_proj.weight.data[\n", + " mamba_config.ssm_cfg[\"d_inner\"] + 2 * mamba_config.ssm_cfg[\"d_xb\"] : 2 * mamba_config.ssm_cfg[\"d_inner\"] + 2 * mamba_config.ssm_cfg[\"d_xb\"], :\n", + " ].copy_(layer_module.self_attn.q_proj.weight.data)\n", + "\n", + " print(\"Init Mamba using Attention\")\n", + "\n", + " transformer.model.layers[layer_idx] = mamba_encoder\n", + "\n", + " # elif type == \"m2d\":\n", + " # print(\"Converting layer %d...\" % layer_idx)\n", + " # mamba_encoder = AprielSSMDecoderLayer(\n", + " # mamba_config,\n", + " # layer_idx,\n", + " # device=\"cpu\",\n", + " # dtype=torch_dtype,\n", + " # )\n", + " # mamba_encoder.mlp.load_state_dict(layer_module.mlp.state_dict())\n", + " # mamba_encoder.input_layernorm.load_state_dict(layer_module.input_layernorm.state_dict())\n", + " # mamba_encoder.post_attention_layernorm.load_state_dict(layer_module.post_attention_layernorm.state_dict())\n", + " # mamba_encoder.mixer.out_proj.load_state_dict(layer_module.self_attn.o_proj.state_dict())\n", + "\n", + " # if init_with_kqvo:\n", + " \n", + "\n", + "\n", + " \n", + " else:\n", + " raise ValueError(f\"Invalid layer type: {type}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 7/7 [00:05<00:00, 1.22it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Converting layer %d... 0\n", + "Skipping transformer layer 0...\n", + "Converting layer %d... 1\n", + "Skipping transformer layer 1...\n", + "Converting layer %d... 2\n", + "Skipping transformer layer 2...\n", + "Converting layer %d... 3\n", + "Skipping transformer layer 3...\n", + "Converting layer %d... 4\n", + "Skipping transformer layer 4...\n", + "Converting layer %d... 5\n", + "Skipping transformer layer 5...\n", + "Converting layer %d... 6\n", + "Skipping transformer layer 6...\n", + "Converting layer %d... 7\n", + "Skipping transformer layer 7...\n", + "Converting layer %d... 8\n", + "Skipping transformer layer 8...\n", + "Converting layer %d... 9\n", + "Skipping transformer layer 9...\n", + "Converting layer %d... 10\n", + "Skipping transformer layer 10...\n", + "Converting layer %d... 11\n", + "Skipping transformer layer 11...\n", + "Converting layer %d... 12\n", + "Skipping transformer layer 12...\n", + "Converting layer %d... 13\n", + "Skipping transformer layer 13...\n", + "Converting layer %d... 14\n", + "Skipping transformer layer 14...\n", + "Converting layer %d... 15\n", + "Skipping transformer layer 15...\n", + "Converting layer %d... 16\n", + "Skipping transformer layer 16...\n", + "Converting layer %d... 17\n", + "Skipping transformer layer 17...\n", + "Converting layer %d... 18\n", + "Skipping transformer layer 18...\n", + "Converting layer %d... 
19\n", + "Skipping transformer layer 19...\n", + "Converting layer %d... 20\n", + "Converting layer 20...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 21\n", + "Converting layer 21...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 22\n", + "Converting layer 22...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 23\n", + "Converting layer 23...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 24\n", + "Converting layer 24...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 25\n", + "Converting layer 25...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 26\n", + "Skipping transformer layer 26...\n", + "Converting layer %d... 27\n", + "Converting layer 27...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 28\n", + "Skipping transformer layer 28...\n", + "Converting layer %d... 29\n", + "Skipping transformer layer 29...\n", + "Converting layer %d... 30\n", + "Converting layer 30...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 31\n", + "Converting layer 31...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 32\n", + "Converting layer 32...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 33\n", + "Converting layer 33...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 34\n", + "Skipping transformer layer 34...\n", + "Converting layer %d... 35\n", + "Converting layer 35...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 36\n", + "Converting layer 36...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 37\n", + "Converting layer 37...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 38\n", + "Converting layer 38...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 39\n", + "Converting layer 39...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 40\n", + "Converting layer 40...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 41\n", + "Converting layer 41...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 42\n", + "Converting layer 42...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 43\n", + "Converting layer 43...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 44\n", + "Converting layer 44...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 45\n", + "Converting layer 45...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 46\n", + "Converting layer 46...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 47\n", + "Converting layer 47...\n", + "Init Mamba using Attention\n", + "Converting layer %d... 48\n", + "Skipping transformer layer 48...\n", + "Converting layer %d... 
49\n", + "Converting layer 49...\n", + "Init Mamba using Attention\n" + ] + } + ], + "source": [ + "transformer = AutoModelForCausalLM.from_pretrained(path_thinker)\n", + "init_with_kqvo = True\n", + "torch_dtype = torch.bfloat16\n", + "attn_bias = True\n", + "convert_layers(transformer, config_hybrid, hybrid_block_layout, init_with_kqvo, attn_bias, torch_dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "transformer.config = config_hybrid" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMHybridConfig {\n", + " \"architectures\": [\n", + " \"MistralForCausalLM\"\n", + " ],\n", + " \"attention_dropout\": 0.0,\n", + " \"bos_token_id\": 1,\n", + " \"eos_token_id\": 2,\n", + " \"head_dim\": 128,\n", + " \"hidden_act\": \"silu\",\n", + " \"hidden_size\": 5120,\n", + " \"hybrid_block_layout\": [\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"t\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"t\",\n", + " \"m2\",\n", + " \"t\",\n", + " \"t\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"t\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"m2\",\n", + " \"t\",\n", + " \"m2\"\n", + " ],\n", + " \"initializer_range\": 0.02,\n", + " \"intermediate_size\": 14336,\n", + " \"max_position_embeddings\": 65536,\n", + " \"model_type\": \"apriel_ssm_thinker_hybrid\",\n", + " \"num_attention_heads\": 32,\n", + " \"num_hidden_layers\": 50,\n", + " \"num_key_value_heads\": 8,\n", + " \"rms_norm_eps\": 1e-05,\n", + " \"rope_theta\": 1000000.0,\n", + " \"sliding_window\": null,\n", + " \"ssm_cfg\": {\n", + " \"activation\": \"identity\",\n", + " \"bias\": false,\n", + " \"chunk_size\": 128,\n", + " \"conv_bias\": true,\n", + " \"d_conv\": 4,\n", + " \"d_inner\": 4096,\n", + " \"d_state\": 16,\n", + " \"d_xb\": 1024,\n", + " \"dt_init\": \"random\",\n", + " \"dt_init_floor\": 0.0001,\n", + " \"dt_max\": 0.1,\n", + " \"dt_min\": 0.001,\n", + " \"dt_rank\": \"auto\",\n", + " \"dt_scale\": 1.0,\n", + " \"expand\": 1,\n", + " \"n_qk_heads\": 32,\n", + " \"n_v_heads\": 32\n", + " },\n", + " \"tie_word_embeddings\": false,\n", + " \"torch_dtype\": \"bfloat16\",\n", + " \"transformers_version\": \"4.53.2\",\n", + " \"use_cache\": true,\n", + " \"vocab_size\": 131072\n", + "}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "transformer.config" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "transformer.config.architectures=[\"AprielThinkerSSMHybridForCausalLM\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 427.77it/s]\n" + ] + } + ], + "source": [ + "# load state dict from existing pretrained SSM?\n", + "path_25hyb = 
\"/mnt/checkpoints/ssm/apriel_ssm_thinker5l_hybrid_1ssm_init_rand_debug_tpformat\" #\"/mnt/checkpoints/fast_llm_exp/slam_ssm_distill/15b-oshyb25lmil-bs768-lr0.0003-lrs0-0-0-0-sl4096_ti5000_lm6/export/apriel_ssm_thinker_hybrid/5000_new\"\n", + "model = AprielThinkerSSMHybridForCausalLM.from_pretrained(path_25hyb)\n", + "state_dict = model.state_dict()\n", + "\n", + "# missing, unexpected = transformer.load_state_dict(state_dict, strict=False)\n", + "# print(missing)\n", + "# print(unexpected)\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Note: saving as transformer wilkl still keep architectures[\"Mistral....\"]. So currently need to manually update the checkpoints architectures list to have AprielThinkerSSMHybridForCausalLM" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# mamba2, state 16, expand 1, i.e. same as M1, but with discrete mamba2 and MIL\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_1ssm_leastimportant_m2_16hexp1_init_mil\") # 1 ssm\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_25ssm_leastimportant_m2_16hexp1_init_mil\") # 25 ssm\n", + "transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_25ssm_leastimportant_m2_16hexp1_init_mil_tpformat\") # 25 ssm\n", + "\n", + "\n", + "# transformer.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_thinker15b_hybrid_40ssm_leastimportant_m2_16hexp1_init_mil_uniform_from_25h5000lm6\") # 40 ssm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Data mixing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([])\n", + "KL (global, F.kl_div) = 0.738795\n", + "KL (sum of shards, manual) = 0.738795\n" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fast_llm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/fast_llm/models/ssm/external/5B_hybrid.ipynb b/fast_llm/models/ssm/external/5B_hybrid.ipynb new file mode 100644 index 000000000..9a33f577e --- /dev/null +++ b/fast_llm/models/ssm/external/5B_hybrid.ipynb @@ -0,0 +1,416 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": 
[ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/toolkit/.local/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "\n", + "import torch\n", + "import random\n", + "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "fast_llm_path = \"/home/toolkit/dev/Fast-LLM\"\n", + "\n", + "# add fast_llm to the python path\n", + "import sys\n", + "sys.path.append(fast_llm_path)\n", + "from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridConfig\n", + "from fast_llm.models.ssm.external.apriel_hybrid.modeling_ssm_hybrid_apriel import AprielSSMHybridModel, AprielSSMDecoderLayer, AprielSSMHybridForCausalLM\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "base = 0.612615\n", + "layer_scores = {\n", + " \"22\": 0.607389,\n", + " \"24\": 0.603498,\n", + " \"19\": 0.597907,\n", + " \"27\": 0.597173,\n", + " \"20\": 0.590442,\n", + " \"5\": 0.578949,\n", + " \"4\": 0.576852,\n", + " \"9\": 0.576484,\n", + " \"23\": 0.574833,\n", + " \"7\": 0.571860,\n", + " \"8\": 0.571790,\n", + " \"6\": 0.571614,\n", + " \"2\": 0.571330,\n", + " \"26\": 0.570205,\n", + " \"11\": 0.567128,\n", + " \"14\": 0.566175,\n", + " \"15\": 0.566076,\n", + " \"3\": 0.562861,\n", + " \"1\": 0.560154,\n", + " \"13\": 0.559304,\n", + " \"16\": 0.559017,\n", + " \"10\": 0.558789,\n", + " \"12\": 0.555186,\n", + " \"17\": 0.554236,\n", + " \"25\": 0.549215,\n", + " \"18\": 0.537257,\n", + " \"0\": 0.233085,\n", + "}\n", + "layer_scores = {k: base - v for k, v in layer_scores.items()}\n", + "layer_importanfce = sorted(layer_scores.items(), key=lambda x: x[1])\n", + "layer_importanfce_rand = random.sample(layer_importanfce, len(layer_importanfce))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('22', 0.005226000000000064),\n", + " ('24', 0.009117000000000042),\n", + " ('19', 0.014708000000000054),\n", + " ('27', 0.015442000000000067),\n", + " ('20', 0.022173),\n", + " ('5', 0.033665999999999974),\n", + " ('4', 0.03576299999999999),\n", + " ('9', 0.036131000000000024),\n", + " ('23', 0.03778199999999998),\n", + " ('7', 0.040754999999999986),\n", + " ('8', 0.040825),\n", + " ('6', 0.041001000000000065),\n", + " ('2', 0.041285000000000016),\n", + " ('26', 0.04241000000000006),\n", + " ('11', 0.045487000000000055),\n", + " ('14', 0.04644000000000004),\n", + " ('15', 0.046539),\n", + " ('3', 0.049754000000000076),\n", + " ('1', 0.05246099999999998),\n", + " ('13', 0.053311),\n", + " ('16', 0.053598000000000035),\n", + " ('10', 0.05382600000000004),\n", + " ('12', 0.05742900000000006),\n", + " ('17', 0.05837900000000007),\n", + " ('25', 0.06340000000000001),\n", + " ('18', 0.07535800000000004),\n", + " ('0', 0.37953000000000003)]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "layer_importanfce" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "layer_importanfce = layer_importanfce_rand" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create hybrid with any number of SSM 
layers" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "checkpoint = \"ServiceNow-AI/Apriel-5B-Instruct\"\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "device = \"cuda\"\n", + "n_hybrid = 0\n", + "\n", + "index_swaped = []\n", + "hybrid_block_layout = [\"t\"] * config.num_hidden_layers\n", + "for i in range(n_hybrid):\n", + " hybrid_block_layout[int(layer_importanfce[i][0])] = \"m2d\"\n", + " index_swaped.append(int(layer_importanfce[i][0]))\n", + "\n", + "hybrdif_apriel_config = AprielSSMHybridConfig(**config.to_dict(),\n", + " hybrid_block_layout=hybrid_block_layout,\n", + " ssm_cfg={\n", + " \"d_state\": 64,\n", + " \"n_v_heads\": 24,\n", + " \"n_qk_heads\": 24,\n", + " \"expand\": 1,\n", + " \"chunk_size\": 128,\n", + " \"activation\": \"identity\",\n", + " \"bias\": False,\n", + " \"d_inner\": 24 * 128, # num_heads * head_dim\n", + " })" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['t',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't',\n", + " 't']" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrdif_apriel_config.hybrid_block_layout" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AprielSSMHybridForCausalLM(\n", + " (model): AprielSSMHybridModel(\n", + " (embed_tokens): Embedding(131072, 4096)\n", + " (layers): ModuleList(\n", + " (0-27): 28 x AprielDecoderLayer(\n", + " (self_attn): AprielAttention(\n", + " (q_proj): Linear(in_features=4096, out_features=3072, bias=False)\n", + " (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n", + " (o_proj): Linear(in_features=3072, out_features=4096, bias=False)\n", + " )\n", + " (mlp): AprielMLP(\n", + " (gate_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (up_proj): Linear(in_features=4096, out_features=8192, bias=False)\n", + " (down_proj): Linear(in_features=8192, out_features=4096, bias=False)\n", + " (act_fn): SiLU()\n", + " )\n", + " (input_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (post_attention_layernorm): AprielRMSNorm((4096,), eps=1e-05)\n", + " )\n", + " )\n", + " (norm): AprielRMSNorm((4096,), eps=1e-05)\n", + " (rotary_emb): AprielRotaryEmbedding()\n", + " )\n", + " (lm_head): Linear(in_features=4096, out_features=131072, bias=False)\n", + ")" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hybrid_apriel_model = AprielSSMHybridForCausalLM(hybrdif_apriel_config)\n", + "hybrid_apriel_model.to(dtype=torch.bfloat16)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 2.22it/s]\n" + ] + } + ], + "source": [ + "\n", + "config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)\n", + "apriel_model = AutoModelForCausalLM.from_pretrained(checkpoint, 
torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "missing, unexpected = hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing keys: []\n", + "Unexpected keys: []\n" + ] + } + ], + "source": [ + "# unexpected will contain keys from the SSM layers we added\n", + "print(\"Missing keys:\", missing)\n", + "# unexpected will contain keys from the transformer layers we replaced\n", + "print(\"Unexpected keys:\", unexpected)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 5/5 [00:04<00:00, 1.22it/s]\n" + ] + } + ], + "source": [ + "from fast_llm.models.ssm.external.apriel_ssm.modeling_ssm_apriel import AprielSSMModel, AprielSSMForCausalLM\n", + "\n", + "mohawk_path = \"/mnt/checkpoints/ssm/mohawk_distributed_stage2_apriel_8GPU_16ksteps_lr0.0_layernorm/final\"\n", + "# config = AutoConfig.from_pretrained(mohawk_path, trust_remote_code=True)\n", + "apriel_model = AprielSSMForCausalLM.from_pretrained(mohawk_path, torch_dtype=torch.bfloat16, trust_remote_code=True)\n", + "apriel_state_dict = apriel_model.state_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "missing, unexpected = hybrid_apriel_model.load_state_dict(apriel_state_dict, strict=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing keys: ['model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.16.self_attn.q_proj.weight', 'model.layers.16.self_attn.k_proj.weight', 
'model.layers.16.self_attn.v_proj.weight', 'model.layers.16.self_attn.o_proj.weight', 'model.layers.17.self_attn.q_proj.weight', 'model.layers.17.self_attn.k_proj.weight', 'model.layers.17.self_attn.v_proj.weight', 'model.layers.17.self_attn.o_proj.weight', 'model.layers.18.self_attn.q_proj.weight', 'model.layers.18.self_attn.k_proj.weight', 'model.layers.18.self_attn.v_proj.weight', 'model.layers.18.self_attn.o_proj.weight', 'model.layers.21.self_attn.q_proj.weight', 'model.layers.21.self_attn.k_proj.weight', 'model.layers.21.self_attn.v_proj.weight', 'model.layers.21.self_attn.o_proj.weight', 'model.layers.25.self_attn.q_proj.weight', 'model.layers.25.self_attn.k_proj.weight', 'model.layers.25.self_attn.v_proj.weight', 'model.layers.25.self_attn.o_proj.weight']\n", + "Unexpected keys: ['model.layers.0.mixer.z_bias', 'model.layers.0.mixer.D', 'model.layers.0.mixer.in_proj.weight', 'model.layers.0.mixer.conv1d.weight', 'model.layers.0.mixer.conv1d.bias', 'model.layers.0.mixer.out_proj.weight', 'model.layers.1.mixer.z_bias', 'model.layers.1.mixer.D', 'model.layers.1.mixer.in_proj.weight', 'model.layers.1.mixer.conv1d.weight', 'model.layers.1.mixer.conv1d.bias', 'model.layers.1.mixer.out_proj.weight', 'model.layers.3.mixer.z_bias', 'model.layers.3.mixer.D', 'model.layers.3.mixer.in_proj.weight', 'model.layers.3.mixer.conv1d.weight', 'model.layers.3.mixer.conv1d.bias', 'model.layers.3.mixer.out_proj.weight', 'model.layers.10.mixer.z_bias', 'model.layers.10.mixer.D', 'model.layers.10.mixer.in_proj.weight', 'model.layers.10.mixer.conv1d.weight', 'model.layers.10.mixer.conv1d.bias', 'model.layers.10.mixer.out_proj.weight', 'model.layers.11.mixer.z_bias', 'model.layers.11.mixer.D', 'model.layers.11.mixer.in_proj.weight', 'model.layers.11.mixer.conv1d.weight', 'model.layers.11.mixer.conv1d.bias', 'model.layers.11.mixer.out_proj.weight', 'model.layers.12.mixer.z_bias', 'model.layers.12.mixer.D', 'model.layers.12.mixer.in_proj.weight', 'model.layers.12.mixer.conv1d.weight', 'model.layers.12.mixer.conv1d.bias', 'model.layers.12.mixer.out_proj.weight', 'model.layers.13.mixer.z_bias', 'model.layers.13.mixer.D', 'model.layers.13.mixer.in_proj.weight', 'model.layers.13.mixer.conv1d.weight', 'model.layers.13.mixer.conv1d.bias', 'model.layers.13.mixer.out_proj.weight', 'model.layers.14.mixer.z_bias', 'model.layers.14.mixer.D', 'model.layers.14.mixer.in_proj.weight', 'model.layers.14.mixer.conv1d.weight', 'model.layers.14.mixer.conv1d.bias', 'model.layers.14.mixer.out_proj.weight', 'model.layers.15.mixer.z_bias', 'model.layers.15.mixer.D', 'model.layers.15.mixer.in_proj.weight', 'model.layers.15.mixer.conv1d.weight', 'model.layers.15.mixer.conv1d.bias', 'model.layers.15.mixer.out_proj.weight', 'model.layers.16.mixer.z_bias', 'model.layers.16.mixer.D', 'model.layers.16.mixer.in_proj.weight', 'model.layers.16.mixer.conv1d.weight', 'model.layers.16.mixer.conv1d.bias', 'model.layers.16.mixer.out_proj.weight', 'model.layers.17.mixer.z_bias', 'model.layers.17.mixer.D', 'model.layers.17.mixer.in_proj.weight', 'model.layers.17.mixer.conv1d.weight', 'model.layers.17.mixer.conv1d.bias', 'model.layers.17.mixer.out_proj.weight', 'model.layers.18.mixer.z_bias', 'model.layers.18.mixer.D', 'model.layers.18.mixer.in_proj.weight', 'model.layers.18.mixer.conv1d.weight', 'model.layers.18.mixer.conv1d.bias', 'model.layers.18.mixer.out_proj.weight', 'model.layers.21.mixer.z_bias', 'model.layers.21.mixer.D', 'model.layers.21.mixer.in_proj.weight', 'model.layers.21.mixer.conv1d.weight', 'model.layers.21.mixer.conv1d.bias', 
'model.layers.21.mixer.out_proj.weight', 'model.layers.25.mixer.z_bias', 'model.layers.25.mixer.D', 'model.layers.25.mixer.in_proj.weight', 'model.layers.25.mixer.conv1d.weight', 'model.layers.25.mixer.conv1d.bias', 'model.layers.25.mixer.out_proj.weight']\n" + ] + } + ], + "source": [ + "# unexpected will contain keys from the SSM layers we added\n", + "print(\"Missing keys:\", missing)\n", + "# unexpected will contain keys from the transformer layers we replaced\n", + "print(\"Unexpected keys:\", unexpected)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# hybrid_apriel_model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_14ssm_leastimportant_init_MOHAWK\")\n", + "# hybrid_apriel_model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_20ssm_leastimportant_init_rand\")\n", + "# hybrid_apriel_model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_14ssm_randplacement_init_rand\")\n", + "hybrid_apriel_model.save_pretrained(\"/mnt/checkpoints/ssm/apriel_ssm_instruct_hybrid_0ssm_full_transformer_debug\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# save the hybrid model\n", + "output_path = \"/mnt/checkpoints/ssm/iterative_hybrids_5b\"\n", + "assert len(index_swaped) == 1\n", + "layer_swaped = index_swaped[0]\n", + "hybrid_apriel_model.save_pretrained(\n", + " f\"{output_path}/apriel_ssm_instruct5b_hybrid_{layer_swaped+1}ssm_leastimportant_32h_init_rand\"\n", + " )\n", + "print(f\"Hybrid model saved to {output_path}/apriel_ssm_instruct5b_hybrid_{layer_swaped+1}ssm_leastimportant_32h_init_rand\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fast_llm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index 35e9b6885..9f4588a29 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -893,7 +893,7 @@ def forward( self, hidden_states: torch.Tensor, past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, - attention_mask: Optional[torch.Tensor] = None, + mamba_mask: Optional[torch.Tensor] = None, return_mixer_matrix=False, **kwargs, ): @@ -905,6 +905,10 @@ def forward( assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda" cache_position = kwargs.get("cache_position", None) batch, seqlen, dim = hidden_states.shape + # mamba_mask = ( + # None if seqlen == 1 else mamba_mask + # ) # prevent that hidden_states are expanded to mask's seq. dimention., i.e. 
we do not need apply_mask_to_padding_states when generating single token at a time + # hidden_states = apply_mask_to_padding_states(hidden_states, mamba_mask) ssm_state, conv_state = None, None use_precomputed_states = False @@ -985,7 +989,7 @@ def forward( # Update state (B D W) conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0))) if causal_conv1d_fn is None: - x = self.act(self.conv1d(x)[..., :seqlen]) + x = self.act(self.conv1d(x)[..., :seqlen]).transpose(1, 2) else: assert self.activation in ["silu", "swish"] x = causal_conv1d_fn( @@ -993,7 +997,10 @@ def forward( weight=rearrange(self.conv1d.weight, "d 1 w -> d w"), bias=self.conv1d.bias, activation=self.activation, - ) + ) # .transpose(1, 2) + # x = apply_mask_to_padding_states(x, mamba_mask).transpose( + # 1, 2 + # ) # zero out everything that comes from padding tokens if not self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state) @@ -1048,14 +1055,14 @@ def step(self, hidden_states, conv_state, ssm_state): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states_input) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states_input) + z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1) B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group) C = rearrange(C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states_input)) # B, d_inner if self.repeat_kv_before_conv: x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state) @@ -1239,9 +1246,10 @@ def forward( ) -> BaseModelOutputWithPast: use_cache = use_cache if use_cache is not None else self.config.use_cache if use_cache and past_key_values is None: + # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test) batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device) - return super().forward( + output = super().forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1253,6 +1261,10 @@ def forward( cache_position=cache_position, **flash_attn_kwargs, ) + past_key_values: HybridMambaAttentionDynamicCache = output.past_key_values + if past_key_values and not past_key_values.has_previous_state: + past_key_values.has_previous_state = True + return output class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... 
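The `step` hunk above changes the single-token decode path: `in_proj` now emits only `[z, x, B, C]`, and `dt` is produced by a separate `dt_in_proj` followed by `dt_proj`. A minimal sketch of the new split, using toy sizes as stand-ins for the real `ssm_cfg` values (`d_inner`, `d_xb`, `dt_rank`); this only illustrates the tensor bookkeeping, not the SSM update itself:

```python
import torch
import torch.nn as nn

# Toy sizes, stand-ins for the real ssm_cfg values (e.g. d_inner=4096, d_xb=1024 above).
d_model, d_inner, d_xb, dt_rank = 16, 32, 8, 4

# After the change, in_proj emits only [z, x, B, C]; dt gets its own projection path.
in_proj = nn.Linear(d_model, 2 * d_inner + 2 * d_xb, bias=False)
dt_in_proj = nn.Linear(d_model, dt_rank, bias=False)
dt_proj = nn.Linear(dt_rank, d_inner, bias=True)

hidden_states = torch.randn(2, d_model)  # (batch, d_model): a single decoding step

zxbc = in_proj(hidden_states)
z, x, B, C = torch.split(zxbc, [d_inner, d_xb, d_xb, d_inner], dim=-1)
dt = dt_proj(dt_in_proj(hidden_states))  # (batch, d_inner), no longer sliced out of in_proj

print(z.shape, x.shape, B.shape, C.shape, dt.shape)
```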
@@ -1435,6 +1447,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, + mamba_mask=attention_mask, # non-expended mask **kwargs, ) diff --git a/fast_llm/models/ssm/huggingface.py b/fast_llm/models/ssm/huggingface.py index 02f472076..1ece10edf 100644 --- a/fast_llm/models/ssm/huggingface.py +++ b/fast_llm/models/ssm/huggingface.py @@ -20,4 +20,5 @@ class HuggingfaceHybridSSMModelForCausalLM(HuggingfaceGPTModelForCausalLM): config: HuggingfaceSSMModelConfig runner_class: typing.ClassVar[type[HybridSSMInferenceRunner]] = HybridSSMInferenceRunner model_class = HybridSSMModel + runner_class: typing.ClassVar[type[HybridSSMInferenceRunner]] = HybridSSMInferenceRunner _fast_llm_model: HybridSSMModel diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b89ed4a04..d080e6a1e 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,5 +1,6 @@ import abc import functools +import logging import math import typing @@ -13,6 +14,8 @@ from fast_llm.functional.triton.pointwise import triton_add, triton_copy from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class _SafeTensorSliceMeta(type): def __instancecheck__(self, instance) -> bool: @@ -147,7 +150,7 @@ def from_tensor_space( reductions: tuple[tuple[str, ReduceOp], ...] = (), **kwargs: typing.Any, ) -> typing.Self: - dims = tuple(tensor_space.get_tensor_dim(dim_name) for dim_name in dim_names) + dims = tuple(tensor_space[dim_name] for dim_name in dim_names) if reductions: # kwarg not available for ParameterMeta, so we only provide if necessary. kwargs["reductions"] = tuple( @@ -159,12 +162,11 @@ def from_tensor_space( def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global( - self, - tensor: torch.Tensor, - *, - distributed: Distributed, - ) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + """ + Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ if tensor.ndim == 0: tensor = tensor[None] Assert.eq(tensor.shape, self.shape) @@ -188,14 +190,28 @@ def local_to_global( Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank - def global_to_local( - self, - tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. TODO: Rework. - expand: bool = False, - ) -> torch.Tensor: + def local_to_global_partial(self, tensor: torch.Tensor, fill_value: float | int = -1) -> torch.Tensor: + """ + Construct a tensor of shape `self.global_shape` that contains its local slice at the appropriate location, + i.e. for which `self.global_to_local(self.local_to_global_partial(tensor)) == tensor`. + Other entries are filled with `fill_value`. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) + assert not self._reductions + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global_partial(tensor, dim, fill_value) + + Assert.eq(tensor.shape, self.global_shape) + return tensor + + def global_to_local(self, tensor: torch.Tensor | SafeTensorSlice) -> torch.Tensor: """ - Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. 
+ Select the local slice of a global tensor. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. """ # Take a trivial slice to convert safetensor slices. tensor = tensor[:] @@ -205,9 +221,9 @@ def global_to_local( Assert.eq(tensor.shape, self.global_shape) for dim, tensor_dim in reversed(list(enumerate(self.dims))): - tensor = tensor_dim.global_to_local(tensor, dim, expand) - if not expand: - Assert.eq(tensor.shape, self.shape) + tensor = tensor_dim.global_to_local(tensor, dim) + + Assert.eq(tensor.shape, self.shape) return tensor @classmethod @@ -302,7 +318,11 @@ def __repr__(self, *, tensor_contents=()) -> str: def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None: assert self.param_init_method is not None - if distributed.config.tensor_parallel == 1 or distributed.config.reproducible_init: + if ( + distributed.config.tensor_parallel == 1 + or distributed.config.reproducible_init + or self.param_init_method.requires_global_initialization + ): generator = distributed.pp_init_generator else: generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator diff --git a/setup.cfg b/setup.cfg index baa6e4adc..6ea98610c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,20 +41,20 @@ OPTIONAL = # Huggingface tools HUGGINGFACE = - transformers>=4.52.4 + transformers==4.53.2 hf-transfer>=0.1.9 datasets>=3.6.0 huggingface-hub>=0.32.6 # Required to run SSMs # To install on cpu environment (ex. for IDE support): -# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation +# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 cartesia_pytorch>=0.0.2 -GENERATION = - lm_eval>=0.4.9 +# GENERATION = +# lm_eval>=0.4.9 # Required for supporting vision inputs diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 9a878c494..6d00d05ba 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -23,12 +23,10 @@ def _reverse_kl_loss( ): scaled_target = target / teacher_softmax_temperature - scaled_target = torch.clamp(target, min=-50, max=50) teacher_log_probs = torch.log_softmax(scaled_target, dim=-1) with torch.enable_grad(): # Use log_softmax for consistency instead of _fused_softmax - logits = torch.clamp(logits, min=-50, max=50) student_log_probs = torch.log_softmax(logits, dim=-1) if loss_mask is None: loss = torch.nn.functional.kl_div(
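The `_reverse_kl_loss` hunk is cut off above, but the intent is visible: the ad-hoc `clamp` calls are dropped and both teacher and student go through plain `log_softmax` before `F.kl_div`. As a hedged, standalone sketch of that reverse-KL convention (KL(student || teacher), summed over the vocabulary) — tensor names and sizes here are illustrative, not the test's actual arguments:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
teacher_softmax_temperature = 1.0          # kept from the hunk above; 1.0 means no scaling
teacher_logits = torch.randn(4, 32)        # (tokens, vocab); toy sizes, illustrative only
student_logits = torch.randn(4, 32)

teacher_log_probs = torch.log_softmax(teacher_logits / teacher_softmax_temperature, dim=-1)
student_log_probs = torch.log_softmax(student_logits, dim=-1)

# Reverse KL, i.e. KL(student || teacher): with log_target=True, F.kl_div(input, target)
# computes sum(exp(target) * (target - input)), so passing the teacher log-probs as `input`
# and the student log-probs as `target` yields sum(p_student * log(p_student / p_teacher)).
reverse_kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)

# KL is a sum of independent per-element terms, so accumulating it over vocabulary shards
# gives the same total -- the additivity the "KL (sum of shards, manual)" check in the
# notebook output above relies on.
shard_sum = sum(
    F.kl_div(t, s, reduction="sum", log_target=True)
    for t, s in zip(teacher_log_probs.chunk(4, dim=-1), student_log_probs.chunk(4, dim=-1))
)
print(f"{reverse_kl.item():.6f} == {shard_sum.item():.6f}")
```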