From 09a2adf55da88d549e56f5a5d62add86fb6481e0 Mon Sep 17 00:00:00 2001
From: Lorenzo Verardo
Date: Wed, 27 Mar 2024 12:22:24 +0100
Subject: [PATCH] MixtralSparseMoeBlock: add gate jitter

This commit adds gate jitter to MixtralSparseMoeBlock's input data before
passing it through the MoE layer, if turned on.
---
 src/transformers/models/mixtral/configuration_mixtral.py | 4 ++++
 src/transformers/models/mixtral/modeling_mixtral.py      | 5 +++++
 tests/models/mixtral/test_modeling_mixtral.py            | 4 +++-
 3 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/mixtral/configuration_mixtral.py b/src/transformers/models/mixtral/configuration_mixtral.py
index ac2dbed16e10cb..f63d1f88cbe052 100644
--- a/src/transformers/models/mixtral/configuration_mixtral.py
+++ b/src/transformers/models/mixtral/configuration_mixtral.py
@@ -93,6 +93,8 @@ class MixtralConfig(PretrainedConfig):
             allow the model to output the auxiliary loss. See [here]() for more details
         router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
             The aux loss factor for the total loss.
+        router_jitter_noise (`float`, *optional*, defaults to 0.0):
+            Amount of noise to add to the router.
 
     ```python
     >>> from transformers import MixtralModel, MixtralConfig
@@ -134,6 +136,7 @@ def __init__(
         num_local_experts=8,
         output_router_logits=False,
         router_aux_loss_coef=0.001,
+        router_jitter_noise=0.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -160,6 +163,7 @@ def __init__(
         self.num_local_experts = num_local_experts
         self.output_router_logits = output_router_logits
         self.router_aux_loss_coef = router_aux_loss_coef
+        self.router_jitter_noise = router_jitter_noise
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 4c4c44bd2297d8..e9e801bb71670b 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -837,9 +837,14 @@ def __init__(self, config):
 
         self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
 
+        # Jitter parameters
+        self.jitter_noise = config.router_jitter_noise
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """ """
         batch_size, sequence_length, hidden_dim = hidden_states.shape
+        if self.training and self.jitter_noise > 0:
+            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py
index df31ec0050d08b..b6dcee093e4db4 100644
--- a/tests/models/mixtral/test_modeling_mixtral.py
+++ b/tests/models/mixtral/test_modeling_mixtral.py
@@ -42,7 +42,6 @@
 
 
 class MixtralModelTester:
-    # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.__init__
     def __init__(
         self,
         parent,
@@ -69,6 +68,7 @@ def __init__(
         num_choices=4,
         pad_token_id=0,
         scope=None,
+        router_jitter_noise=0.1,
     ):
         self.parent = parent
         self.batch_size = batch_size
@@ -94,6 +94,7 @@ def __init__(
         self.num_choices = num_choices
         self.pad_token_id = pad_token_id
         self.scope = scope
+        self.router_jitter_noise = router_jitter_noise
 
     # Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
     def prepare_config_and_inputs(self):
@@ -137,6 +138,7 @@ def get_config(self):
             pad_token_id=self.pad_token_id,
             num_experts_per_tok=2,
             num_local_experts=2,
+            router_jitter_noise=self.router_jitter_noise,
         )
 
     # Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mixtral
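
Usage note (not part of the patch): a minimal sketch of how the new `router_jitter_noise` option behaves once the change above is applied. The tiny config dimensions and the jitter value of 0.1 are illustrative assumptions, not library defaults; the jitter is only applied while the module is in training mode.

```python
# Minimal sketch, assuming the patch above is applied; the config sizes and the
# 0.1 jitter value are illustrative only.
import torch

from transformers import MixtralConfig, MixtralModel

config = MixtralConfig(
    vocab_size=128,
    hidden_size=32,
    intermediate_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=4,
    num_experts_per_tok=2,
    router_jitter_noise=0.1,  # multiplicative jitter drawn from U(0.9, 1.1)
)
model = MixtralModel(config)
input_ids = torch.randint(0, config.vocab_size, (1, 8))

# In training mode each MoE block scales its input by the random jitter before
# routing, so repeated forward passes differ slightly.
model.train()
hidden_train = model(input_ids).last_hidden_state

# In eval mode the jitter branch is skipped and the forward pass is deterministic.
model.eval()
with torch.no_grad():
    hidden_eval = model(input_ids).last_hidden_state
```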