Onboarding Mistral3.1_24B #358

Draft · wants to merge 1 commit into base: main
12 changes: 12 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -57,6 +57,10 @@
MistralModel,
MistralRMSNorm,
)
from transformers.models.mistral3.modeling_mistral3 import (
Mistral3ForConditionalGeneration,
Mistral3RMSNorm,
)
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralDecoderLayer,
@@ -69,6 +73,7 @@
from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
from transformers.models.phi.modeling_phi import PhiAttention, PhiForCausalLM, PhiModel
from transformers.models.phi3.modeling_phi3 import Phi3Attention, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm
from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2ForCausalLM, Qwen2Model, Qwen2RMSNorm
from transformers.models.starcoder2.modeling_starcoder2 import (
Starcoder2Attention,
Expand All @@ -87,6 +92,7 @@
)

from QEfficient.customop import CustomRMSNormAIC
from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration

from .models.codegen.modeling_codegen import (
QEffCodeGenAttention,
@@ -177,6 +183,7 @@
GPTBigCodeForCausalLM.__name__,
MllamaForCausalLM.__name__,
WhisperForConditionalGeneration.__name__,
Mistral3ForConditionalGeneration.__name__,
]
)

@@ -226,6 +233,9 @@
MistralModel: QEffMistralModel,
MistralForCausalLM: QEffMistralForCausalLM,
MistralRMSNorm: CustomRMSNormAIC,
# Mistral3 model layers
Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration,
Mistral3RMSNorm: CustomRMSNormAIC,
# Mixtral model layers
MixtralAttention: QEffMixtralAttention,
MixtralDecoderLayer: QeffMixtralDecoderLayer,
@@ -242,6 +252,8 @@
PhiAttention: QEffPhiAttention,
PhiModel: QEffPhiModel,
PhiForCausalLM: QEffPhiForCausalLM,
# Pixtral model layers
PixtralRMSNorm: CustomRMSNormAIC,
# Falcon model layers
FalconAttention: QEffFalconAttention,
FalconForCausalLM: QEffFalconForCausalLM,
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/mistral3/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
237 changes: 237 additions & 0 deletions QEfficient/transformers/models/mistral3/modeling_mistral3.py
@@ -0,0 +1,237 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration

from QEfficient.utils import constants
from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config

BS = 1  # default batch size
NUM_CHANNEL = 3  # RGB input channels
SEQ_LEN = 3072  # default prefill sequence length (fallback in get_specializations)
CTX_LEN = 4096  # default context length (fallback in get_specializations)


class QEFFMistral3EncoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
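        # Alias vision_tower as vision_model, the attribute name other
        # QEfficient VLM wrappers expose to the export flow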
self.model.vision_model = self.model.vision_tower

def forward(self, pixel_values):
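        # Recover (height, width) from the runtime tensor shape so image_sizes
        # always matches the compiled image resolution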
image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]])
image_features = self.model.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=self.model.config.vision_feature_layer,
image_sizes=image_sizes,
)
return image_features


class QEFFMistral3DecoderWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.config = self.model.config
self.language_model = self.model.language_model

def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
inputs_embeds = self.model.get_input_embeddings()(input_ids)
vit_embeds = vit_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
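        # Scatter vision embeddings into the text sequence: cumsum over the
        # image-token mask maps the n-th placeholder token to row n of
        # vit_embeds. Non-image positions end up with stale row indices, but
        # torch.where keeps the original text embeddings there, so only
        # placeholder positions are overwritten.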
mask = input_ids == self.model.config.image_token_index
indices1 = mask.to(torch.int64).cumsum(1) - 1
indices0 = torch.arange(mask.shape[0]).view(-1, 1)
image_features_expanded = vit_embeds.unsqueeze(0)[indices0, indices1]
inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
outputs = self.model.language_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
past_key_values=past_key_values,
)

return outputs.logits, vit_embeds, outputs.past_key_values


class QEffMistral3ForConditionalGeneration(Mistral3ForConditionalGeneration):
def get_qeff_vision_encoder(self):
return QEFFMistral3EncoderWrapper(self)

def get_qeff_language_decoder(self):
return QEFFMistral3DecoderWrapper(self)

def forward(self, pixel_values, input_ids, position_ids, past_key_values):
inputs_embeds = self.get_input_embeddings()(input_ids)
# Image features
image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]])
image_features = self.get_image_features(
pixel_values=pixel_values,
vision_feature_layer=self.config.vision_feature_layer,
image_sizes=image_sizes,
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
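        # Same placeholder-token scatter as in QEFFMistral3DecoderWrapper.forward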
mask = input_ids == self.config.image_token_index
indices1 = mask.to(torch.int64).cumsum(1) - 1
indices0 = torch.arange(mask.shape[0]).view(-1, 1)
image_features_expanded = image_features.unsqueeze(0)[indices0, indices1]
inputs_embeds = torch.where(mask.unsqueeze(-1), image_features_expanded, inputs_embeds)
outputs = self.language_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
past_key_values=past_key_values,
)
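        # pixel_values is echoed back so the fused export retains it across steps
        # (surfaced as pixel_values_RetainedState in get_output_names)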
return outputs.logits, pixel_values, outputs.past_key_values

def get_dummy_inputs(self, kv_offload: bool = False, **kwargs):
inputs_shapes = {}
inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
inputs_shapes["vit_embeds"] = (
constants.MISTRAL3_FEATURE_SIZE,
self.language_model.config.hidden_size,
)
inputs_shapes["position_ids"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)
inputs_shapes["pixel_values"] = (
constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
constants.MISTRAL3_NUM_CHANNELS,
constants.MISTRAL3_HEIGHT,
constants.MISTRAL3_WIDTH,
)

# Define inputs
vision_inputs = {}
lang_inputs = {}
vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32)
lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64)
lang_inputs["vit_embeds"] = torch.zeros((inputs_shapes["vit_embeds"]), dtype=torch.float32)
lang_inputs["position_ids"] = (
torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64)
.view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
.repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1)
)

# Add data for KV
kv_cache_shape = get_padding_shape_from_config(
config=self.language_model.config,
batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE,
seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN,
)

lang_inputs["past_key_values"] = [[] for _ in range(self.language_model.config.num_hidden_layers)]
for i in range(self.language_model.config.num_hidden_layers):
for kv in ["key", "value"]:
lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))

inputs = {}
if kv_offload:
inputs["vision"] = vision_inputs
inputs["lang"] = lang_inputs
else:
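            # The fused single-QPC path consumes pixel_values directly, so the
            # precomputed vision embeddings are not a separate input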
lang_inputs.pop("vit_embeds")
inputs = {**vision_inputs, **lang_inputs}

return inputs

def get_specializations(
self,
batch_size: int,
prefill_seq_len: int,
ctx_len: int,
img_size: int,
kv_offload: bool = False,
**compiler_options,
):
prefill_seq_len = prefill_seq_len if prefill_seq_len else SEQ_LEN
ctx_len = ctx_len if ctx_len else CTX_LEN
height = constants.MISTRAL3_HEIGHT
width = constants.MISTRAL3_WIDTH

vision = [
{
"batch_size": batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"height": height,
"width": width,
}
]
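        # Two language specializations: prefill (full seq_len) and decode (one token per step)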
lang = [
{
"batch_size": batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"height": height,
"width": width,
},
{
"batch_size": batch_size,
"seq_len": "1",
"ctx_len": ctx_len,
"height": height,
"width": width,
},
]
specializations = {}

if kv_offload:
specializations["vision"] = vision
specializations["lang"] = lang
return specializations, compiler_options
else:
return lang, compiler_options

def get_onnx_dynamic_axes(self, kv_offload: bool = False):
# Define dynamic axes
num_layers = self.config.text_config.num_hidden_layers

vision_dynamic_axes = {
"pixel_values": {0: "batch_size", 2: "height", 3: "width"},
}
lang_dynamic_axes = {
"input_ids": {0: "batch_size", 1: "seq_len"},
"position_ids": {0: "batch_size", 1: "seq_len"},
}

for i in range(num_layers):
lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"}
lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"}

dynamic_axes = {}
if kv_offload:
dynamic_axes["vision"] = vision_dynamic_axes
dynamic_axes["lang"] = lang_dynamic_axes
else:
dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes}
return dynamic_axes

def get_output_names(self, kv_offload: bool = False):
vision_output_names = ["vit_embeds"]
lang_output_names = ["logits"]
for i in range(self.language_model.config.num_hidden_layers):
for kv in ["key", "value"]:
lang_output_names.append(f"past_{kv}.{i}_RetainedState")

output_names = {}
if kv_offload:
lang_output_names.insert(1, "vit_embeds_RetainedState")
output_names["vision"] = vision_output_names
output_names["lang"] = lang_output_names
else:
lang_output_names.insert(1, "pixel_values_RetainedState")
return lang_output_names
return output_names

def get_inputs_info(self):
return [
IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")),
IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")),
IOInfo(name="pixel_values", datatype=torch.float32, shape=("batch_size", 3, "height", "width")),
]
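A minimal usage sketch for reviewers, assuming the QEFFAutoModelForImageTextToText entry point used by other VLMs in this repo also picks up this mapping; the compile arguments below are illustrative and not confirmed by this diff:

# Hypothetical usage sketch: the entry point and arguments mirror other
# QEfficient VLM flows and are assumptions, not part of this PR.
from transformers import AutoProcessor

from QEfficient import QEFFAutoModelForImageTextToText

model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
processor = AutoProcessor.from_pretrained(model_id)

# kv_offload=True exports separate vision and language QPCs, matching the
# "vision"/"lang" specializations defined above; False exports one fused QPC.
model = QEFFAutoModelForImageTextToText.from_pretrained(model_id, kv_offload=True)
model.compile(prefill_seq_len=3072, ctx_len=4096, num_cores=16)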
10 changes: 10 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -67,6 +67,10 @@
MistralModel,
MistralRMSNorm,
)
from transformers.models.mistral3.modeling_mistral3 import (
Mistral3ForConditionalGeneration,
Mistral3RMSNorm,
)
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralDecoderLayer,
@@ -96,6 +100,7 @@
Phi3Model,
Phi3RMSNorm,
)
from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2DecoderLayer,
@@ -188,6 +193,7 @@
QEffMistralForCausalLM,
QEffMistralModel,
)
from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration
from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import (
QEffMixtralAttention,
QeffMixtralDecoderLayer,
@@ -255,11 +261,13 @@ class CustomOpsTransform(ModuleMappingTransform):
Gemma2RMSNorm: GemmaCustomRMSNormAIC,
LlamaRMSNorm: CustomRMSNormAIC,
MistralRMSNorm: CustomRMSNormAIC,
Mistral3RMSNorm: CustomRMSNormAIC,
MixtralRMSNorm: CustomRMSNormAIC,
Phi3RMSNorm: CustomRMSNormAIC,
Qwen2RMSNorm: CustomRMSNormAIC,
MllamaTextRMSNorm: CustomRMSNormAIC,
GraniteRMSNorm: CustomRMSNormAIC,
PixtralRMSNorm: CustomRMSNormAIC,
}


@@ -321,6 +329,8 @@ class KVCacheTransform(ModuleMappingTransform):
MistralDecoderLayer: QEffMistralDecoderLayer,
MistralModel: QEffMistralModel,
MistralForCausalLM: QEffMistralForCausalLM,
# Mistral3
Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration,
# Mixtral
MixtralAttention: QEffMixtralAttention,
MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock,
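The mappings registered above take effect because ModuleMappingTransform swaps a module's class in place wherever its type appears in _module_mapping. A simplified sketch of that mechanism, not the library's exact code:

import torch.nn as nn

# Illustrative re-implementation of a module-mapping transform; the real
# ModuleMappingTransform in QEfficient may differ in details.
def apply_mapping(model: nn.Module, mapping: dict) -> tuple[nn.Module, bool]:
    transformed = False
    for module in model.modules():
        repl = mapping.get(type(module))
        if repl is not None:
            # Swapping __class__ redirects forward() to the QEff implementation
            # while leaving parameters and buffers untouched.
            module.__class__ = repl
            transformed = True
    return model, transformed

With this in place, Mistral3RMSNorm modules run the CustomRMSNormAIC forward after the transform, and Mistral3ForConditionalGeneration dispatches to the QEff wrapper defined in this PR.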
7 changes: 7 additions & 0 deletions QEfficient/utils/constants.py
@@ -74,6 +74,13 @@ def get_models_dir():
INTERN_NUM_CHANNELS = 3
INTERN_IMG_CONTEXT_TOKEN = 151667

# MISTRAL3 Constants
# Feature size fixed with reference to mistralai/Mistral-Small-3.1-24B-Instruct-2503
MISTRAL3_FEATURE_SIZE = 2255
MISTRAL3_NUM_CHANNELS = 3
MISTRAL3_HEIGHT = 1540
MISTRAL3_WIDTH = 1162


class Constants:
# Export Constants.