MixtralSparseMoeBlock: add gate jitter #29865

Merged · 1 commit · Mar 27, 2024
4 changes: 4 additions & 0 deletions src/transformers/models/mixtral/configuration_mixtral.py
@@ -93,6 +93,8 @@ class MixtralConfig(PretrainedConfig):
allow the model to output the auxiliary loss. See [here]() for more details
router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
The aux loss factor for the total loss.
router_jitter_noise (`float`, *optional*, defaults to 0.0):
Amount of noise to add to the router.

```python
>>> from transformers import MixtralModel, MixtralConfig
@@ -134,6 +136,7 @@ def __init__(
num_local_experts=8,
output_router_logits=False,
router_aux_loss_coef=0.001,
router_jitter_noise=0.0,
**kwargs,
):
self.vocab_size = vocab_size
@@ -160,6 +163,7 @@ def __init__(
self.num_local_experts = num_local_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.router_jitter_noise = router_jitter_noise
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
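With the configuration change above, the new field can be passed like any other `MixtralConfig` argument. A minimal sketch, assuming a transformers version that includes this change; the value `0.1` is only an example, and `0.0` remains the default that disables jitter:

```python
from transformers import MixtralConfig

# router_jitter_noise defaults to 0.0 (no jitter); 0.1 here is an illustrative value.
config = MixtralConfig(router_jitter_noise=0.1)
print(config.router_jitter_noise)  # 0.1
```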
5 changes: 5 additions & 0 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -837,9 +837,14 @@ def __init__(self, config):

self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

# Jitter parameters
self.jitter_noise = config.router_jitter_noise

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
""" """
batch_size, sequence_length, hidden_dim = hidden_states.shape
if self.training and self.jitter_noise > 0:
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
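The forward-pass change scales the hidden states by a factor drawn uniformly from [1 - jitter_noise, 1 + jitter_noise] before routing, and only while the module is in training mode. Below is a standalone sketch of just that step; the helper name `apply_router_jitter` is illustrative, and it multiplies out of place rather than in place as the diff does:

```python
import torch


def apply_router_jitter(hidden_states: torch.Tensor, jitter_noise: float, training: bool) -> torch.Tensor:
    """Scale each activation by a factor drawn uniformly from
    [1 - jitter_noise, 1 + jitter_noise], applied only during training."""
    if training and jitter_noise > 0:
        noise = torch.empty_like(hidden_states).uniform_(1.0 - jitter_noise, 1.0 + jitter_noise)
        hidden_states = hidden_states * noise
    return hidden_states


# (batch, seq_len, hidden_dim) activations perturbed by up to ±10%.
x = torch.randn(2, 4, 8)
print(apply_router_jitter(x, jitter_noise=0.1, training=True).shape)  # torch.Size([2, 4, 8])
```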
4 changes: 3 additions & 1 deletion tests/models/mixtral/test_modeling_mixtral.py
@@ -42,7 +42,6 @@


class MixtralModelTester:
# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.__init__
def __init__(
self,
parent,
@@ -69,6 +68,7 @@ def __init__(
num_choices=4,
pad_token_id=0,
scope=None,
router_jitter_noise=0.1,
):
self.parent = parent
self.batch_size = batch_size
@@ -94,6 +94,7 @@ def __init__(
self.num_choices = num_choices
self.pad_token_id = pad_token_id
self.scope = scope
self.router_jitter_noise = router_jitter_noise

# Copied from tests.models.mistral.test_modeling_mistral.MistralModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
@@ -137,6 +138,7 @@ def get_config(self):
pad_token_id=self.pad_token_id,
num_experts_per_tok=2,
num_local_experts=2,
router_jitter_noise=self.router_jitter_noise,
)

# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester.create_and_check_model with Llama->Mixtral
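The tester now passes `router_jitter_noise=0.1` into the config used by the shared model tests, so the jitter code path is exercised. Since the noise is applied only in training mode, eval-mode outputs should stay deterministic even with jitter configured. A hypothetical smoke check along those lines, not the PR's actual test; the tiny config values are illustrative:

```python
import torch
from transformers import MixtralConfig, MixtralModel

# With jitter configured, eval-mode forward passes remain deterministic
# because the noise is only applied while the module is in training mode.
config = MixtralConfig(
    vocab_size=100, hidden_size=32, intermediate_size=64,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=2,
    num_local_experts=2, num_experts_per_tok=2,
    router_jitter_noise=0.1,
)
model = MixtralModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out1 = model(input_ids).last_hidden_state
    out2 = model(input_ids).last_hidden_state
assert torch.allclose(out1, out2)
```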