Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Qwen2 support? #3166

Open
HadXu opened this issue Sep 24, 2024 · 0 comments
Open

Qwen2 support? #3166

HadXu opened this issue Sep 24, 2024 · 0 comments

Comments

@HadXu
Copy link

HadXu commented Sep 24, 2024

I followed the guide at https://huggingface.co/docs/transformers/en/quantization/fbgemm_fp8 and ran the example successfully.

However, when I run it with a Qwen2 model, it fails with the error "RuntimeError: Invalid datatype. input must be BF16".

[Screenshot: img_v3_02f1_3f046408-68c8-465f-bb54-71c26a139eag]


RuntimeError Traceback (most recent call last)
Cell In[4], line 11
8 input_text = "What are we having for dinner?"
9 input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
---> 11 output = quantized_model.generate(**input_ids, max_new_tokens=10)
12 print(tokenizer.decode(output[0], skip_special_tokens=True))

File /usr/local/lib/python3.11/dist-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
2016 input_ids, model_kwargs = self._expand_inputs_for_generation(
2017 input_ids=input_ids,
2018 expand_size=generation_config.num_return_sequences,
2019 is_encoder_decoder=self.config.is_encoder_decoder,
2020 **model_kwargs,
2021 )
2023 # 13. run sample (it degenerates to greedy search when generation_config.do_sample=False)
-> 2024 result = self._sample(
2025 input_ids,
2026 logits_processor=prepared_logits_processor,
2027 logits_warper=prepared_logits_warper,
2028 stopping_criteria=prepared_stopping_criteria,
2029 generation_config=generation_config,
2030 synced_gpus=synced_gpus,
2031 streamer=streamer,
2032 **model_kwargs,
2033 )
2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
2036 # 11. prepare logits warper
2037 prepared_logits_warper = (
2038 self._get_logits_warper(generation_config, device=input_ids.device)
2039 if generation_config.do_sample
2040 else None
2041 )

File /usr/local/lib/python3.11/dist-packages/transformers/generation/utils.py:2982, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
2979 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
2981 # forward pass to get next token
-> 2982 outputs = self(**model_inputs, return_dict=True)
2984 if synced_gpus and this_peer_finished:
2985 continue # don't waste resources running the code we don't need

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None

File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:1104, in Qwen2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1101 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1103 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1104 outputs = self.model(
1105 input_ids=input_ids,
1106 attention_mask=attention_mask,
1107 position_ids=position_ids,
1108 past_key_values=past_key_values,
1109 inputs_embeds=inputs_embeds,
1110 use_cache=use_cache,
1111 output_attentions=output_attentions,
1112 output_hidden_states=output_hidden_states,
1113 return_dict=return_dict,
1114 cache_position=cache_position,
1115 )
1117 hidden_states = outputs[0]
1118 logits = self.lm_head(hidden_states)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None

File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:915, in Qwen2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
904 layer_outputs = self._gradient_checkpointing_func(
905 decoder_layer.__call__,
906 hidden_states,
(...)
912 cache_position,
913 )
914 else:
--> 915 layer_outputs = decoder_layer(
916 hidden_states,
917 attention_mask=causal_mask,
918 position_ids=position_ids,
919 past_key_value=past_key_values,
920 output_attentions=output_attentions,
921 use_cache=use_cache,
922 cache_position=cache_position,
923 )
925 hidden_states = layer_outputs[0]
927 if use_cache:

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None

File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:655, in Qwen2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
652 hidden_states = self.input_layernorm(hidden_states)
654 # Self Attention
--> 655 hidden_states, self_attn_weights, present_key_value = self.self_attn(
656 hidden_states=hidden_states,
657 attention_mask=attention_mask,
658 position_ids=position_ids,
659 past_key_value=past_key_value,
660 output_attentions=output_attentions,
661 use_cache=use_cache,
662 cache_position=cache_position,
663 )
664 hidden_states = residual + hidden_states
666 # Fully Connected

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None

File /usr/local/lib/python3.11/dist-packages/transformers/models/qwen2/modeling_qwen2.py:592, in Qwen2SdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
589 attn_output = attn_output.transpose(1, 2).contiguous()
590 attn_output = attn_output.view(bsz, q_len, self.hidden_size)
--> 592 attn_output = self.o_proj(attn_output)
594 return attn_output, None, past_key_value

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None

File /usr/local/lib/python3.11/dist-packages/transformers/integrations/fbgemm_fp8.py:50, in FbgemmFp8Linear.forward(self, x)
47 num_tokens = None
48 # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
49 # FBGEMM/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu at e08af8539c391437f447173863df0f3f6f
---> 50 x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
51 x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
52 )
53 # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
54 # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
55
56 # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
57 output = torch.ops.fbgemm.f8f8bf16_rowwise(
58 x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
59 )

File /usr/local/lib/python3.11/dist-packages/torch/_ops.py:1061, in OpOverloadPacket.__call__(self_, *args, **kwargs)
1059 if self_._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
1060 return _call_overload_packet_from_python(self_, args, kwargs)
-> 1061 return self_._op(*args, **(kwargs or {}))

RuntimeError: Invalid datatype. input must be BF16

However, when I compare Qwen2 and Llama 3 8B, both models' dtypes are bf16.

[Screenshot: img_v3_02f1_fac20cb7-6896-4763-9100-7fc92ff76f6g]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant