diff --git a/src/nanotron/mod/__init__.py b/src/nanotron/mod/__init__.py
new file mode 100644
index 00000000..9b571df5
--- /dev/null
+++ b/src/nanotron/mod/__init__.py
@@ -0,0 +1 @@
+from nanotron.mod.mod import MixtureOfDepth, Router
diff --git a/src/nanotron/mod/llama.py b/src/nanotron/mod/llama.py
new file mode 100644
index 00000000..869c34c3
--- /dev/null
+++ b/src/nanotron/mod/llama.py
@@ -0,0 +1,146 @@
+from typing import Dict, Optional, Union, List
+
+import torch
+from torch import nn
+
+import torch.distributed as dist
+from nanotron.config import LlamaConfig, ParallelismArgs
+from nanotron.nn.layer_norm import TritonRMSNorm
+from nanotron.parallel import ParallelContext
+from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer
+from nanotron.parallel.pipeline_parallel.p2p import P2P
+from nanotron.parallel.tensor_parallel.nn import (
+    TensorParallelColumnLinear,
+    TensorParallelLinearMode,
+)
+from nanotron.models.llama import LlamaModel, Embedding, LlamaDecoderLayer, CausalSelfAttention, MLP
+from nanotron.mod.mod import MixtureOfDepth, Router
+
+
+# class LlamaDecoderLayer(nn.Module):
+#     def __init__(
+#         self,
+#         config: LlamaConfig,
+#         parallel_config: Optional[ParallelismArgs],
+#         tp_pg: dist.ProcessGroup,
+#         layer_idx: int,
+#     ):
+#         super().__init__()
+#         self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+#         self.attn = CausalSelfAttention(
+#             config=config,
+#             parallel_config=parallel_config,
+#             tp_pg=tp_pg,
+#             layer_idx=layer_idx,
+#         )
+
+#         self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+#         self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
+#         self.router = Router(seq_len=1024, top_k=10)
+
+#     def forward(
+#         self,
+#         hidden_states: Union[torch.Tensor, TensorPointer],
+#         sequence_mask: Union[torch.Tensor, TensorPointer],
+#     ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
+#         residual = hidden_states
+#         hidden_states = self.input_layernorm(hidden_states)
+
+#         output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
+#         hidden_states = output["hidden_states"]
+#         hidden_states = hidden_states + residual
+
+#         residual = hidden_states
+#         hidden_states = self.post_attention_layernorm(hidden_states)
+#         hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
+#         hidden_states = hidden_states + residual
+
+#         return {
+#             "hidden_states": hidden_states,
+#             "sequence_mask": output["sequence_mask"],
+#         }
+
+class MoDLlamaModel(LlamaModel):
+    """Build pipeline graph"""
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        parallel_context: ParallelContext,
+        parallel_config: Optional[ParallelismArgs],
+    ):
+        nn.Module.__init__(self)  # NOTE: bypass LlamaModel.__init__; the pipeline graph is rebuilt from scratch below
+
+        # Declare all the nodes
+        self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
+        self.config = config
+        self.parallel_config = parallel_config
+        self.parallel_context = parallel_context
+        self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
+        tp_linear_async_communication = (
+            parallel_config.tp_linear_async_communication if parallel_config is not None else False
+        )
+
+        self.token_position_embeddings = PipelineBlock(
+            p2p=self.p2p,
+            module_builder=Embedding,
+            module_kwargs={
+                "tp_pg": parallel_context.tp_pg,
+                "config": config,
+                "parallel_config": parallel_config,
+            },
+            module_input_keys={"input_ids", "input_mask"},
+            module_output_keys={"input_embeds"},
+        )
+
+        self.decoder = nn.ModuleList(
+            [
+                PipelineBlock(
+                    p2p=self.p2p,
+                    module_builder=LlamaDecoderLayer,
+                    module_kwargs={
+                        "config": config,
+                        "parallel_config": parallel_config,
+                        "tp_pg": parallel_context.tp_pg,
+                        "layer_idx": layer_idx,
+                    },
+                    module_input_keys={"hidden_states", "sequence_mask"},
+                    module_output_keys={"hidden_states", "sequence_mask"},
+                )
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
+
+        self.final_layer_norm = PipelineBlock(
+            p2p=self.p2p,
+            module_builder=TritonRMSNorm,
+            module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
+            module_input_keys={"input"},
+            module_output_keys={"hidden_states"},
+        )
+
+        self.lm_head = PipelineBlock(
+            p2p=self.p2p,
+            # Understand that this means that we return sharded logits that are going to need to be gathered
+            module_builder=TensorParallelColumnLinear,
+            module_kwargs={
+                "in_features": config.hidden_size,
+                "out_features": config.vocab_size,
+                "pg": parallel_context.tp_pg,
+                "bias": False,
+                # TODO @thomasw21: refactor so that we store that default in a single place.
+                "mode": self.tp_mode,
+                "async_communication": tp_linear_async_communication,
+            },
+            module_input_keys={"x"},
+            module_output_keys={"logits"},
+        )
+
+        self.cast_to_fp32 = PipelineBlock(
+            p2p=self.p2p,
+            module_builder=lambda: lambda x: x.float(),
+            module_kwargs={},
+            module_input_keys={"x"},
+            module_output_keys={"output"},
+        )
diff --git a/src/nanotron/mod/mod.py b/src/nanotron/mod/mod.py
new file mode 100644
index 00000000..f6e534f6
--- /dev/null
+++ b/src/nanotron/mod/mod.py
@@ -0,0 +1,76 @@
+from typing import Dict, Union
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import nn
+from torchtyping import TensorType
+
+from nanotron.parallel.pipeline_parallel.block import TensorPointer
+
+
+class MixtureOfDepth(nn.Module):
+    def __init__(self, capacity: int, d_model: int, block: nn.Module):
+        super().__init__()
+        self.router = Router(capacity, d_model)
+        self.block = block
+
+    def forward(
+        self,
+        hidden_states: Union[TensorType["seq_len", "batch_size", "d_model"], TensorPointer],
+        sequence_mask: Union[TensorType["batch_size", "seq_len"], TensorPointer],
+    ) -> Dict[
+        str,
+        Union[TensorType["seq_len", "batch_size", "d_model"], TensorPointer],
+    ]:
+        # NOTE: nanotron passes hidden_states sequence-first; the router indexes batch-first
+        hidden_states = rearrange(hidden_states, "seq_len batch_size d_model -> batch_size seq_len d_model")
+        selected_idxs = self.router(hidden_states)
+        assert selected_idxs.shape == (hidden_states.size(0), self.router.capacity)
+        selected_hidden_states = hidden_states[torch.arange(hidden_states.size(0)).unsqueeze(1), selected_idxs]
+        selected_sequence_mask = sequence_mask[torch.arange(sequence_mask.size(0)).unsqueeze(1), selected_idxs]
+
+        selected_hidden_states = rearrange(
+            selected_hidden_states, "batch_size seq_len d_model -> seq_len batch_size d_model"
+        )
+        outputs_of_selected_inputs = self.block(selected_hidden_states, selected_sequence_mask)
+        # NOTE: keep the new representations of the selected tokens and write them back over the original inputs
+        hidden_states[torch.arange(hidden_states.size(0)).unsqueeze(1), selected_idxs] = rearrange(
+            outputs_of_selected_inputs["hidden_states"], "seq_len batch_size d_model -> batch_size seq_len d_model"
+        )
+        hidden_states = rearrange(hidden_states, "batch_size seq_len d_model -> seq_len batch_size d_model")
+        return {"hidden_states": hidden_states, "sequence_mask": sequence_mask}
+
+
+class Router(nn.Module):
+    def __init__(
+        self,
+        capacity: int,
+        d_model: int,
+        # tp_pg: dist.ProcessGroup,
+        # parallel_config: Optional[ParallelismArgs]
+    ):
+        super().__init__()
+        self.capacity = capacity
+        self.gate = nn.Linear(d_model, 1)
+
+        # TODO(xrsrke): deduplicate this
+        # tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
+        # tp_linear_async_communication = (
+        #     parallel_config.tp_linear_async_communication if parallel_config is not None else False
+        # )
+
+        # self.gate = TensorParallelRowLinear(
+        #     d_model,
+        #     1,
+        #     pg=tp_pg,
+        #     mode=TensorParallelLinearMode.REDUCE_SCATTER,
+        #     bias=False,
+        #     async_communication=True,
+        #     # contiguous_chunks=gate_up_contiguous_chunks,
+        # )
+
+    def forward(self, inputs: TensorType["batch_size", "seq_len", "d_model"]) -> TensorType["batch_size", "capacity"]:
+        probs = F.softmax(self.gate(inputs), dim=1).view(-1, inputs.size(1))
+        _, top_k_indices = torch.topk(probs, self.capacity)
+        return top_k_indices
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 32aab9cd..dc38f195 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """PyTorch LLaMa model."""
 
-from typing import Dict, Optional, Union, List
+from typing import Dict, Optional, Union
 
 import torch
 from torch import nn
@@ -25,6 +25,7 @@
 from nanotron.config.models_config import RandomInit, SpectralMupInit
 from nanotron.generation.generate_store import AttachableStore
 from nanotron.logging import log_rank
+from nanotron.mod.mod import MixtureOfDepth
 from nanotron.models import NanotronModel
 from nanotron.nn.activations import ACT2FN
 from nanotron.nn.layer_norm import TritonRMSNorm
@@ -45,6 +46,15 @@
 
 logger = logging.get_logger(__name__)
 
+CAPACITY = 50
+
+
+def build_mod_block(*args, **kwargs):
+    block = LlamaDecoderLayer(*args, **kwargs)
+    # NOTE: the router's gate must operate on the decoder's hidden size
+    mod = MixtureOfDepth(CAPACITY, kwargs["config"].hidden_size, block)
+    return mod
+
 
 class RotaryEmbedding(nn.Module):
     def __init__(self, dim: int, end: int, theta: float = 10000.0):
@@ -704,11 +714,20 @@ def __init__(
             module_output_keys={"input_embeds"},
         )
 
+        # def build_mod_block(module_kwargs, block):
+        #     block = self.module_builder(**self.module_kwargs)
+
+        # NOTE: how to make the MixtureOfDepth block wrap around these blocks?
+
+        # CAPACITY = 50
+        # D_MODEL = config.hidden_size
+
         self.decoder = nn.ModuleList(
             [
                 PipelineBlock(
                     p2p=self.p2p,
-                    module_builder=LlamaDecoderLayer,
+                    # module_builder=LlamaDecoderLayer,
+                    module_builder=build_mod_block,
                     module_kwargs={
                         "config": config,
                         "parallel_config": parallel_config,
@@ -755,6 +774,8 @@
             module_output_keys={"output"},
         )
 
+        # self.mod_blocks = []
+
     def forward(
         self,
         input_ids: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
@@ -795,6 +816,8 @@ def get_block_compute_costs(self):
            # CausalSelfAttention (qkv proj + attn out) + MLP
            LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
            + 3 * d_ff * model_config.hidden_size,
+            build_mod_block: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
+            + 3 * d_ff * model_config.hidden_size,
            # This is the last lm_head
            TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
        }
diff --git a/tests/test_mod.py b/tests/test_mod.py
new file mode 100644
index 00000000..b368f13a
--- /dev/null
+++ b/tests/test_mod.py
@@ -0,0 +1,48 @@
+import pytest
+import torch
+from torch import nn
+
+from nanotron.mod import MixtureOfDepth, Router
+
+
+class DummyBlock(nn.Module):
+    """Minimal stand-in for a decoder block: takes (hidden_states, sequence_mask) and returns a dict."""
+
+    def __init__(self, d_model: int):
+        super().__init__()
+        self.linear = nn.Linear(d_model, d_model)
+
+    def forward(self, hidden_states, sequence_mask):
+        return {"hidden_states": self.linear(hidden_states), "sequence_mask": sequence_mask}
+
+
+@pytest.mark.parametrize("seq_len, top_k", [(1, 1), (10, 5), (10, 10)])
+def test_mod(seq_len, top_k):
+    BATCH_SIZE = 15
+    D_MODEL = 1024
+
+    block = MixtureOfDepth(top_k, D_MODEL, DummyBlock(D_MODEL))
+
+    # NOTE: MixtureOfDepth expects hidden_states in [seq_len, batch_size, d_model] layout
+    inputs = torch.randn(seq_len, BATCH_SIZE, D_MODEL)
+    sequence_mask = torch.ones(BATCH_SIZE, seq_len, dtype=torch.bool)
+    ref_inputs = inputs.clone()
+    outputs = block(inputs, sequence_mask)["hidden_states"]
+
+    expected_num_tokens_not_changed = (seq_len - top_k) * BATCH_SIZE
+    num_tokens_not_changed = torch.eq(outputs.reshape(-1, D_MODEL), ref_inputs.reshape(-1, D_MODEL)).all(dim=1).sum().item()
+
+    assert outputs.shape == ref_inputs.shape
+    assert num_tokens_not_changed == expected_num_tokens_not_changed, f"num_tokens_not_changed: {num_tokens_not_changed}, expected: {expected_num_tokens_not_changed}"
+
+
+@pytest.mark.parametrize("capacity, d_model", [(1, 64), (10, 64)])
+def test_router(capacity, d_model):
+    BATCH_SIZE, SEQ_LEN = 5, 10
+    inputs = torch.randn(BATCH_SIZE, SEQ_LEN, d_model)
+
+    router = Router(capacity, d_model)
+    selected_idxs = router(inputs)
+
+    assert selected_idxs.shape == (BATCH_SIZE, capacity)
+    assert selected_idxs.dtype == torch.int64
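
A minimal standalone sketch of the routing mechanism exercised by this patch (independent of nanotron; the sizes and the `* 2.0` stand-in for the wrapped decoder block are illustrative only, not part of the diff):

import torch
import torch.nn.functional as F

batch_size, seq_len, d_model, capacity = 2, 6, 4, 3
hidden_states = torch.randn(batch_size, seq_len, d_model)
gate = torch.nn.Linear(d_model, 1)

# score every position, softmax over the sequence dimension, keep the top-k positions per sequence
probs = F.softmax(gate(hidden_states), dim=1).view(-1, seq_len)  # [batch_size, seq_len]
_, selected_idxs = torch.topk(probs, capacity)                   # [batch_size, capacity]

# advanced indexing: row i keeps only its own top-k positions
rows = torch.arange(batch_size).unsqueeze(1)                     # [batch_size, 1]
selected = hidden_states[rows, selected_idxs]                    # [batch_size, capacity, d_model]

# the wrapped block processes only the selected tokens; its output is written back in place
hidden_states[rows, selected_idxs] = selected * 2.0              # stand-in for the block output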