feat: Support Aggregate mode for phi4-mm #6184


Open
wants to merge 1 commit into base: main from user/williamj/update-phi4-mm

Conversation

@Wanli-Jiang (Collaborator) commented Jul 18, 2025

Summary by CodeRabbit

  • New Features

    • Introduced support for the Phi-4-MM multimodal model, enabling integration of text, image, and audio inputs.
    • Added configurable embedding layers for image and audio modalities.
    • Implemented a dedicated multimodal encoder for seamless processing of multiple input types.
    • Integrated a full PyTorch implementation of the SigLIP model for advanced vision and text feature extraction.
    • Enhanced support for advanced attention mechanisms and positional embeddings.
  • Improvements

    • Streamlined multimodal embedding extraction and processing, improving clarity and maintainability for end-users.

Description

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md.
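For example, an illustrative invocation combining flags documented above (the stage name is reused from the examples):

/bot run --stage-list "A10-PyTorch-1" --disable-fail-fast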

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

@coderabbitai bot (Contributor) commented Jul 18, 2025

Walkthrough

This update introduces comprehensive multimodal support for the Phi-4-MM model within the TensorRT-LLM PyTorch codebase. It adds new configuration, embedding, and utility classes for handling image and audio modalities, implements a full SigLIP vision model, and refactors the input processing and encoder pipeline to integrate these multimodal components.

Changes

File(s) and change summary:

  • tensorrt_llm/_torch/models/configuration_phi4mm.py: Added Phi4MMConfig, a configuration class for the Phi-4-MM model, supporting advanced attention, rotary embeddings, and multimodal parameters, with validation and compatibility logic.
  • tensorrt_llm/_torch/models/modeling_phi4mm.py: Added the HFPhi4MultimodalEncoder class for handling text, image, and audio embeddings; refactored Phi4MMInputProcessor to use this encoder and updated the embedding extraction logic for multimodal inputs.
  • tensorrt_llm/_torch/models/utils_phi4mm.py: Introduced the Phi4MMImageEmbedding and Phi4MMAudioEmbedding classes, implementing modular PyTorch layers for integrating image and audio features into the Phi-4-MM model, with token replacement, projection, and advanced transforms.
  • tensorrt_llm/_torch/models/vision_siglip_navit.py: Added a full PyTorch implementation of the SigLIP multimodal model, including text and vision configs, embedding modules, transformer encoder layers, attention mechanisms, pooling heads, and model integration for contrastive learning.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Phi4MMInputProcessor
    participant HFPhi4MultimodalEncoder
    participant ImageEmbedding
    participant AudioEmbedding

    User->>Phi4MMInputProcessor: Provide input_ids, image/audio embeds, masks
    Phi4MMInputProcessor->>HFPhi4MultimodalEncoder: forward(input_ids, image/audio embeds, masks)
    alt Image embeddings present
        HFPhi4MultimodalEncoder->>ImageEmbedding: Process image tokens/embeds
        ImageEmbedding-->>HFPhi4MultimodalEncoder: Projected image features
    end
    alt Audio embeddings present
        HFPhi4MultimodalEncoder->>AudioEmbedding: Process audio tokens/embeds
        AudioEmbedding-->>HFPhi4MultimodalEncoder: Projected audio features
    end
    HFPhi4MultimodalEncoder-->>Phi4MMInputProcessor: Hidden states (with multimodal features)
    Phi4MMInputProcessor-->>User: Extracted multimodal embeddings
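A minimal sketch of the dispatch the diagram describes, assuming hypothetical method names on the encoder; the keyword arguments mirror the call sites shown later in this PR, but the actual HFPhi4MultimodalEncoder.forward may differ.

def multimodal_forward(encoder, input_ids, wte,
                       input_image_embeds=None, image_sizes=None,
                       image_attention_mask=None,
                       input_audio_embeds=None, audio_attention_mask=None):
    # Default path: plain token embeddings from the word-token embedding table.
    hidden_states = wte(input_ids)
    if input_image_embeds is not None:
        # Image branch: project image patches and write them onto the image
        # placeholder positions in the sequence.
        hidden_states = encoder.image_embed(
            input_ids=input_ids, input_embeds=input_image_embeds,
            image_sizes=image_sizes, wte=wte,
            image_attention_mask=image_attention_mask)
    if input_audio_embeds is not None:
        # Audio branch: same pattern for the audio placeholder positions.
        # (How the two branches merge when both modalities are present is
        # simplified here.)
        hidden_states = encoder.audio_embed(
            input_ids=input_ids, input_embeds=input_audio_embeds, wte=wte,
            audio_attention_mask=audio_attention_mask)
    return hidden_states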

Poem

In a meadow of code, where new features sprout,
Multimodal magic hops in, there's no doubt!
Images and sounds now join the texty song,
SigLIP vision’s strong, embeddings hop along.
With configs and layers, the model’s a delight—
A rabbit’s proud of this multimodal might!
🐰✨


🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai resolve to resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

🧹 Nitpick comments (9)
tensorrt_llm/_torch/models/configuration_phi4mm.py (1)

224-227: Consider splitting long error messages for better readability

The error messages exceed the 120 character line limit. While this doesn't affect functionality, splitting them would improve code readability.

         if not len(rope_scaling_short_factor) == rotary_ndims // 2:
             raise ValueError(
-                f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}"
+                f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, "
+                f"got {len(rope_scaling_short_factor)}"
             )
         if not len(rope_scaling_long_factor) == rotary_ndims // 2:
             raise ValueError(
-                f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}"
+                f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, "
+                f"got {len(rope_scaling_long_factor)}"
             )

Also applies to: 235-238

tensorrt_llm/_torch/models/utils_phi4mm.py (1)

128-128: Improve assertion error message

The assertion could provide a more descriptive error message to help users understand the requirement.

-        assert self.use_hd_transform == self.with_learnable_separator, 'use_hd_transform and with_learnable_separator should have same value'
+        assert self.use_hd_transform == self.with_learnable_separator, \
+            f'use_hd_transform ({self.use_hd_transform}) and with_learnable_separator ({self.with_learnable_separator}) must have the same value'
tensorrt_llm/_torch/models/vision_siglip_navit.py (7)

268-268: Fix capitalization in log message

The log message should start with a capital letter for consistency.

-            logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
+            logger.info("`vision_config` is `None`. Initializing the `SiglipVisionConfig` with default values.")

729-729: Improve comment clarity

The comment uses informal language that could be more professional.

-        self.is_causal = False  # Hack to make sure we don't use a causal mask
+        self.is_causal = False  # Siglip uses bidirectional attention, not causal

758-765: Remove commented-out code

This commented-out code for rotary embeddings and KV cache should be removed if it's not needed for Siglip.

Consider removing these lines entirely if rotary embeddings are not part of the Siglip architecture.


766-767: Address TODO about transpose inefficiency

This TODO indicates a known performance issue with the transposes required for Flash Attention.

Would you like me to create an issue to track this optimization opportunity? The inefficiency stems from Flash Attention requiring a different tensor layout than the standard transformer implementation.


995-995: Use math.sqrt for consistency

The code uses np.sqrt here but math.sqrt elsewhere in the file.

-            nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
+            nn.init.normal_(module.position_embedding.weight, std=1 / math.sqrt(width))

1687-1688: Implement SigLIP loss function

The loss function raises NotImplementedError, which means the model cannot be trained.

The SigLIP loss is a key component for training. Would you like me to help implement the SigLIP loss function or create an issue to track this?


1705-1720: Consider externalizing model configuration

The model configuration is hardcoded in the function. Consider moving it to a configuration file or class constant for better maintainability.

# Define as a class constant or in a separate config file
SIGLIP_VISION_BASE_CONFIG = {
    "hidden_size": 1152,
    "image_size": 448,
    "intermediate_size": 4304,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_hidden_layers": 27,
    "patch_size": 14,
}

def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs):
    model_config = SiglipVisionConfig(**SIGLIP_VISION_BASE_CONFIG, _flash_attn_2_enabled=_flash_attn_2_enabled, **kwargs)
    vision_model = SiglipVisionModel(model_config).vision_model
    return vision_model
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ec2b953 and 4510412.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/models/configuration_phi4mm.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_phi4mm.py (3 hunks)
  • tensorrt_llm/_torch/models/utils_phi4mm.py (1 hunks)
  • tensorrt_llm/_torch/models/vision_siglip_navit.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
tensorrt_llm/_torch/models/configuration_phi4mm.py (1)
tensorrt_llm/models/modeling_utils.py (1)
  • PretrainedConfig (361-562)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/models/configuration_phi4mm.py

30-30: Line too long (121 > 120)

(E501)


226-226: Line too long (129 > 120)

(E501)


237-237: Line too long (127 > 120)

(E501)

tensorrt_llm/_torch/models/utils_phi4mm.py

105-105: Line too long (123 > 120)

(E501)


108-108: Line too long (141 > 120)

(E501)


166-166: Line too long (137 > 120)

(E501)


172-172: Line too long (133 > 120)

(E501)


190-190: Line too long (125 > 120)

(E501)


199-199: Line too long (125 > 120)

(E501)


246-246: Line too long (133 > 120)

(E501)


299-299: Line too long (130 > 120)

(E501)


306-306: Line too long (173 > 120)

(E501)


316-316: Line too long (250 > 120)

(E501)


338-338: Line too long (366 > 120)

(E501)


342-342: Line too long (140 > 120)

(E501)


351-351: Line too long (301 > 120)

(E501)


352-352: Line too long (334 > 120)

(E501)


355-355: Line too long (333 > 120)

(E501)


360-360: Line too long (161 > 120)

(E501)


363-363: Line too long (123 > 120)

(E501)


365-365: Line too long (140 > 120)

(E501)


377-377: Line too long (142 > 120)

(E501)


380-380: Local variable num_img_tokens is assigned to but never used

Remove assignment to unused variable num_img_tokens

(F841)


385-385: Line too long (138 > 120)

(E501)


386-386: Line too long (171 > 120)

(E501)


395-395: Line too long (123 > 120)

(E501)


404-404: Line too long (164 > 120)

(E501)


409-409: Line too long (133 > 120)

(E501)


419-419: Line too long (150 > 120)

(E501)


436-436: Line too long (122 > 120)

(E501)


539-539: Line too long (139 > 120)

(E501)


556-556: Line too long (198 > 120)

(E501)


569-569: Local variable MAX_INPUT_ID is assigned to but never used

Remove assignment to unused variable MAX_INPUT_ID

(F841)


603-603: Line too long (310 > 120)

(E501)


627-627: Line too long (131 > 120)

(E501)

tensorrt_llm/_torch/models/vision_siglip_navit.py

75-75: Line too long (122 > 120)

(E501)


169-169: Line too long (124 > 120)

(E501)


303-303: Module level import not at top of file

(E402)


304-304: Module level import not at top of file

(E402)


305-305: Module level import not at top of file

(E402)


306-306: Module level import not at top of file

(E402)


308-308: Module level import not at top of file

(E402)


309-309: Module level import not at top of file

(E402)


310-310: Module level import not at top of file

(E402)


311-311: Module level import not at top of file

(E402)


312-312: Module level import not at top of file

(E402)


313-313: Module level import not at top of file

(E402)


315-315: Module level import not at top of file

(E402)


316-316: Module level import not at top of file

(E402)


317-317: Module level import not at top of file

(E402)


318-318: Module level import not at top of file

(E402)


319-326: Module level import not at top of file

(E402)


372-372: Ambiguous variable name: l

(E741)


467-467: Line too long (153 > 120)

(E501)


471-471: Line too long (159 > 120)

(E501)


475-475: Line too long (150 > 120)

(E501)


494-494: Line too long (152 > 120)

(E501)


498-498: Line too long (159 > 120)

(E501)


502-502: Line too long (150 > 120)

(E501)


531-531: Line too long (121 > 120)

(E501)


697-697: Line too long (125 > 120)

(E501)


766-766: Line too long (184 > 120)

(E501)


834-834: Line too long (170 > 120)

(E501)


948-948: Line too long (140 > 120)

(E501)

tensorrt_llm/_torch/models/modeling_phi4mm.py

127-127: Line too long (121 > 120)

(E501)


193-193: Line too long (153 > 120)

(E501)


196-196: Line too long (165 > 120)

(E501)

🔇 Additional comments (6)
tensorrt_llm/_torch/models/configuration_phi4mm.py (1)

1-2: Configuration file source is properly documented

Good practice to document the source of copied code with a clear explanation of why it was necessary.

tensorrt_llm/_torch/models/modeling_phi4mm.py (3)

154-160: Good backward compatibility implementation

The token ID remapping ensures compatibility with legacy token ranges while maintaining clear special token IDs.
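
As an illustration of the kind of remapping meant here, a minimal sketch with hypothetical legacy ranges and token-ID values (the real constants and logic live in modeling_phi4mm.py):

import torch

_IMAGE_SPECIAL_TOKEN_ID = 200010  # hypothetical canonical image placeholder ID
_AUDIO_SPECIAL_TOKEN_ID = 200011  # hypothetical canonical audio placeholder ID

def remap_legacy_token_ids(input_ids: torch.Tensor) -> torch.Tensor:
    # Legacy checkpoints encode placeholders as ranges of negative IDs; map
    # every ID in the assumed legacy ranges onto one canonical special token ID.
    ids = input_ids.clone()
    image_mask = (ids < 0) & (ids >= -9999)  # assumed legacy image range
    audio_mask = ids < -9999                 # assumed legacy audio range
    ids[image_mask] = _IMAGE_SPECIAL_TOKEN_ID
    ids[audio_mask] = _AUDIO_SPECIAL_TOKEN_ID
    return ids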


233-245: Clean refactoring to use dedicated multimodal encoder

The change from AutoModelForCausalLM to HFPhi4MultimodalEncoder simplifies the pipeline and improves code clarity. Setting trust_remote_code=False is a good security practice.


1-3: Clear roadmap for AGGREGATE mode implementation

Good to see the implementation plan clearly documented. This PR provides the foundational multimodal support needed before implementing AGGREGATE mode in the next step.

tensorrt_llm/_torch/models/vision_siglip_navit.py (2)

1415-1439: Clarify relationship to Aggregate mode feature

The PR objective mentions "Support Aggregate mode for phi4-mm", but this file doesn't appear to have any explicit aggregate mode implementation. The SiglipMultiheadAttentionPoolingHead performs attention-based pooling, which might be related to the aggregate feature.

Could you clarify:

  1. Is this multihead attention pooling the "Aggregate mode" mentioned in the PR?
  2. If not, where is the aggregate mode implemented?
  3. Should there be a configuration option to enable/disable aggregate mode?

887-890: Clarify cu_seqlens allocation and buffer‐reuse suggestion

Torch’s arange(..., device=…) allocates and fills the tensor directly on the GPU (no host→device memcpy). If profiling shows this per-forward allocation is a bottleneck, consider pre-allocating a maximum‐length cu_seqlens buffer in __init__ (via register_buffer) and slicing it each call.

• File: tensorrt_llm/_torch/models/vision_siglip_navit.py
Lines: 887–890

-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
+            # Direct on-device fill; no host-device memcpy
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )

Example buffer‐reuse pattern:

# in __init__
self.register_buffer(
    "_cu_seqlens_buffer",
    torch.arange(self.max_batch_size + 1, dtype=torch.int32, device=self.device),
)

# in forward
cu_seqlens_q = self._cu_seqlens_buffer[: batch_size + 1]

Please profile this allocation and, if it shows up in your GPU-allocation metrics, switch to the buffer-reuse approach.

@Wanli-Jiang force-pushed the user/williamj/update-phi4-mm branch from 4510412 to 15702ee on July 29, 2025 08:54
if not DISAGG:
    # Forward the multimodal data to HFPhi4MultimodalEncoder in AGGREGATE mode.
    # TODO: any better ways to run data parallel or batching?
    for i in range(len(multimodal_params)):
Collaborator:

Probably you can do torch.cat and pass all the concatenated values to the Encoder's forward().

Collaborator Author (@Wanli-Jiang):

@2ez4bz do you see similar cases? The patches of images or audios have different sizes; any thoughts on how to handle that? Thanks!

Collaborator:

Hm, I don't know the exact details of what goes on for phi4, but supposing the audio/image portions have separate encoders, they should be individually batchable. So perhaps what @yechank-nvidia meant is that the various input_image_embeds can be torch.cat'ed, and so can the input_audio_embeds?
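
A minimal sketch of the per-modality batching being suggested, assuming each entry in multimodal_params exposes input_image_embeds / input_audio_embeds tensors (names taken from this thread):

import torch

def concat_per_modality(multimodal_params):
    # Gather each modality across requests and batch it separately.
    image_embeds = [p.input_image_embeds for p in multimodal_params
                    if getattr(p, "input_image_embeds", None) is not None]
    audio_embeds = [p.input_audio_embeds for p in multimodal_params
                    if getattr(p, "input_audio_embeds", None) is not None]
    # This only works when the per-request tensors share trailing dimensions;
    # as noted later in the thread, the phi4 vision encoder does not guarantee that.
    batched_images = torch.cat(image_embeds, dim=0) if image_embeds else None
    batched_audio = torch.cat(audio_embeds, dim=0) if audio_embeds else None
    return batched_images, batched_audio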

Collaborator:

Oh I think I see what you were asking now - were you asking what I did to take into account that images themselves could be of different shapes? If so, I replicated what happens in the pixtral input processor, but in the model forward here: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/models/modeling_mistral.py#L439

The reason being that the input processor always runs on a single element, so it wouldn't have knowledge of other images, and how to pad the "batch".
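
For context, a minimal sketch of the pixtral-style padding described above (tensor shapes and the helper name are illustrative, not the actual modeling_mistral.py code):

import torch
import torch.nn.functional as F

def pad_and_stack(images):
    # images: list of [num_patches_i, hidden_size] tensors of varying length.
    max_len = max(img.shape[0] for img in images)
    padded, valid = [], []
    for img in images:
        pad = max_len - img.shape[0]
        padded.append(F.pad(img, (0, 0, 0, pad)))  # pad the patch dimension
        valid.append(torch.arange(max_len) < img.shape[0])
    # Returns [batch, max_len, hidden_size] plus a boolean mask of real patches.
    return torch.stack(padded), torch.stack(valid)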

Collaborator Author (@Wanli-Jiang):

After reading the code, the vision_encoder itself requires the same size among batch elements, so we cannot concat them into a batched input.

I left a comment for it.

@Wanli-Jiang force-pushed the user/williamj/update-phi4-mm branch 2 times, most recently from 63bd263 to 3d32c42 on July 30, 2025 06:11
input_embeds=input_image_embeds,
image_sizes=image_sizes,
wte=wte,
image_attention_mask=image_attention_mask,
Collaborator Author (@Wanli-Jiang):

@2ez4bz Hi William, is the Mistral SigLIP using a full attention mask, or a customized attention_mask due to the padding and patching of the original images?

Collaborator:

Are you asking about Mistral3.1 Small's vision encoder? If so, good question - I have not looked at it, but it looks like I might need to replicate it 😅 Perhaps @yechank-nvidia can chime in on whether that is the case?

Collaborator:

Upon further investigation, it seems the HF implementation uses a block diagonal attention mask, which puts me in a bind...

  1. Only FlashInfer seems to support a custom mask
  2. ... but it expects a KV cache manager.
  3. I got segfaults trying to run the advanced quickstart on Llama3.1 8B using VANILLA as a backend, and in Mistral3.1 VLM using the multimodal quickstart, so... yeah... 🫠

Collaborator:

Ok, I think I can answer this more correctly now:

  1. For HF's pixtral implementation (ref), the attention mask is full attention, treating each image as a sequence.
  2. In TRTLLM, it turns out that I had already set the mask to FULL and prepared the attention metadata such that the seq_lens reflected each image's "sequence length".

So the attention mask was set correctly since TRTLLM supports FULL.

TL;DR - I did not need a custom attention mask for mistral 😅
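
A minimal sketch of that idea: with a FULL (bidirectional) mask, treating each image as its own sequence via per-image sequence lengths yields the block-diagonal behavior without a custom mask. The names below are illustrative, not the actual TRT-LLM attention-metadata API.

import torch

def build_image_seq_metadata(image_patch_counts):
    # One "sequence" per image; the attention backend then never mixes patches
    # from different images, which is equivalent to a block-diagonal mask.
    seq_lens = torch.tensor(image_patch_counts, dtype=torch.int32)
    cu_seqlens = torch.zeros(len(image_patch_counts) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
    return seq_lens, cu_seqlens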

Collaborator Author (@Wanli-Jiang):

After investigating a bit, I keep using the HF implementation (flash_attn_2) here.

The reasons are:

  • TRTLLM-attn cannot support a customized attn_mask.
  • flashinfer-attn is quite complicated, and from some searching, flashinfer-attn without a kv_cache_manager behaves just like flash_attn_2.


@Wanli-Jiang force-pushed the user/williamj/update-phi4-mm branch from 3d32c42 to c9e09c2 on July 31, 2025 09:36
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
@Wanli-Jiang force-pushed the user/williamj/update-phi4-mm branch from c9e09c2 to e1aa7b6 on July 31, 2025 09:46
@Wanli-Jiang marked this pull request as ready for review on July 31, 2025 09:48
@Wanli-Jiang requested a review from a team as a code owner on July 31, 2025 09:48
@Wanli-Jiang (Collaborator Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #13653 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator) commented:

PR_Github #13653 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #10249 completed with status: 'FAILURE'
