
Mixtral inference breaks when output_router_logits=True #29087

Closed · 2 of 4 tasks
LeonardoEmili opened this issue Feb 18, 2024 · 5 comments · Fixed by #29249

@LeonardoEmili (Contributor)

System Info

  • transformers version: 4.38.0.dev0
  • Platform: Linux-5.15.0-1038-oracle-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.26.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes, 8x A100 80GBs
  • Using distributed or parallel set-up in script?: yes, using device_map="auto"

Who can help?

@ArthurZucker @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The snippet

from transformers import AutoTokenizer, MixtralForCausalLM
import torch

model = MixtralForCausalLM.from_pretrained(<path_to_finetuned_Mixtral-8x7B-v0.1>)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(<path_to_finetuned_Mixtral-8x7B-v0.1>)
prompts = ['Pu', 'Av', 'Il', 'Please', 'access']

batch = tokenizer(prompts, padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        **batch, max_new_tokens=400, do_sample=True, top_p=0.9, temperature=0.1, min_length=None, use_cache=True, top_k=50,
        bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id,
    )

produces

Traceback (most recent call last):
  File "/job_workspace/axolotl/scripts/custom_modules/checkpoint_selection/evaluate_model_checkpoint.py", line 283, in main
    hypothesis = llm.batch_translate(batch["prompts"], batch["tl_names"])
  File "/job_workspace/axolotl/scripts/custom_modules/checkpoint_selection/evaluate_model_checkpoint.py", line 140, in batch_translate
    outputs = self._model.generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1525, in generate
    return self.sample(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2598, in sample
    outputs = self(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/modeling_mixtral.py", line 1392, in forward
    aux_loss = load_balancing_loss_func(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/mixtral/modeling_mixtral.py", line 132, in load_balancing_loss_func
    tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
RuntimeError: The size of tensor a (160) must match the size of tensor b (0) at non-singleton dimension 0

Important details:

  • the model is a QLoRA fine-tuned version of Mixtral-8x7B-v0.1 trained with axolotl (different weights, but the same shapes)
  • the task is instruction following, where the model learns to translate the English prompt into Italian
  • the issue only arises for some inputs, likely very short ones (see the snippet above)

Expected behavior

  • Doubt/clarification: it seems that Mixtral should not output router logits at inference time (see official docs) and that output_router_logits should only be used during training (as described here). I believe the flag was set during training and then stored in the checkpoint's config; disabling it in the config produces the expected results (a minimal sketch of this workaround follows this list).
  • Proposal: should we always override this configuration to False when model.eval() is called?
  • Expected outcome: a completion is consistently returned regardless of the prompt length.
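
For reference, a minimal sketch of the workaround described above, using the same placeholder checkpoint path as the reproduction snippet: it checks whether output_router_logits was persisted into the saved config and disables it before calling generate. This is an illustrative sketch, not the fix that was merged.

from transformers import AutoConfig, AutoTokenizer, MixtralForCausalLM
import torch

path_to_checkpoint = "<path_to_finetuned_Mixtral-8x7B-v0.1>"  # placeholder, as in the snippet above

# Confirm that the flag was persisted into the saved config during fine-tuning.
config = AutoConfig.from_pretrained(path_to_checkpoint)
print(config.output_router_logits)  # reportedly True for the affected checkpoint

model = MixtralForCausalLM.from_pretrained(path_to_checkpoint)
model.eval()
# Router logits are only needed for the training-time load-balancing loss.
model.config.output_router_logits = False

tokenizer = AutoTokenizer.from_pretrained(path_to_checkpoint)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # the tokenizer may not define a pad token

batch = tokenizer(['Pu', 'Av', 'Il', 'Please', 'access'], padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**batch, max_new_tokens=400, do_sample=True, top_p=0.9,
                             temperature=0.1, pad_token_id=tokenizer.pad_token_id)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))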
LeonardoEmili changed the title from "Mixtral inference should not output_router_logits" to "Mixtral inference breaks when output_router_logits=True" on Feb 18, 2024
@ArthurZucker (Collaborator) commented Feb 20, 2024

When running inference you should set model.config.output_router_logits=False.

@LeonardoEmili (Contributor, Author)

Thanks @ArthurZucker. I believe the correct behaviour is a bit hard to spot from the docs, so I was wondering: is it always the case that inference requires turning this config off? If so, maybe it should be enforced when model.eval() is called?

@ArthurZucker (Collaborator)

Actually this should be enforced when calling prepare_inputs_for_generation! Would you like to open a PR for Mixtral?
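
For illustration, a sketch of one way such an override could work at the user level while waiting for the fix. The subclass name MixtralForGeneration is hypothetical, and this is not the change that was merged in #29249; it only shows the idea of forcing the flag off in prepare_inputs_for_generation.

from transformers import MixtralForCausalLM

class MixtralForGeneration(MixtralForCausalLM):
    # Hypothetical subclass, for illustration only: delegate to the parent
    # implementation and then force router logits off for every generation step.
    def prepare_inputs_for_generation(self, *args, **kwargs):
        model_inputs = super().prepare_inputs_for_generation(*args, **kwargs)
        model_inputs["output_router_logits"] = False
        return model_inputs

Delegating to the parent implementation keeps cache and attention-mask handling untouched; only the router-logits flag is overridden for the forward calls made by generate().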

@LeonardoEmili (Contributor, Author)

Sounds good, I'll happily take care of it @ArthurZucker.

Just to make sure: do you think it is better to raise an assertion when Mixtral is used for inference with that configuration, or to raise a warning and ignore the flag (even if the user set it to True)? I believe that at that stage the first option should be preferred, and the second scenario should be handled earlier (maybe when setting the model to inference mode?).

@ArthurZucker (Collaborator)

No, I think we should always set it; this is the expected API for output_attentions, for example. 🤗
