
Commit c6f2b30

[TRTLLM-6577][feat] Support nano_v2_vlm in pytorch backend
* Update according to reviewers' comments. Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
1 parent 5b7daef commit c6f2b30

3 files changed: 34 additions & 47 deletions

.github/CODEOWNERS

Lines changed: 2 additions & 2 deletions
@@ -99,8 +99,8 @@
 /tests/unittest/_torch/modeling/test_modeling_pixtral.py @NVIDIA/trt-llm-torch-models-vlm-devs @NVIDIA/trt-llm-torch-models-devs

 ### TensorRT-LLM Pytorch - Models - Nemotron
-/tensorrt_llm/_torch/models/modeling_nanov2vlm.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
-/tensorrt_llm/_torch/models/modeling_radio.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_nanov2vlm.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-vlm-devs @NVIDIA/trt-llm-torch-models-devs
+/tensorrt_llm/_torch/models/modeling_radio.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-vlm-devs @NVIDIA/trt-llm-torch-models-devs
 /tensorrt_llm/_torch/models/modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
 /tensorrt_llm/_torch/models/modeling_nemotron_h.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs
 /tensorrt_llm/_torch/models/modeling_nemotron_nas.py @NVIDIA/trt-llm-torch-models-nemotron-devs @NVIDIA/trt-llm-torch-models-devs

tensorrt_llm/_torch/models/modeling_nanov2vlm.py

Lines changed: 5 additions & 19 deletions
@@ -40,21 +40,6 @@ def forward(self, x):
         return torch.pow(torch.nn.functional.relu(x), 2)


-class RMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-5):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.eps = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-        return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
-
-
 class NanoV2VLVisionEncoder(transformers.PreTrainedModel,
                             transformers.generation.GenerationMixin):

@@ -74,8 +59,8 @@ def __init__(self,
         self.vision_projection_hidden_size = config.projector_hidden_size
         self.llm_hidden_size = config.llm_config.hidden_size
         self.mlp1 = nn.Sequential(
-            RMSNorm(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
-                    eps=1e-5),
+            nn.RMSNorm(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
+                       eps=config.llm_config.rms_norm_eps),
             nn.Linear(self.vit_hidden_size * int(1 / self.downsample_ratio)**2,
                       self.vision_projection_hidden_size,
                       bias=False), SquaredReLU(),
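
Editorial note: the deleted hand-written RMSNorm above and torch.nn.RMSNorm compute the same normalization, y = x * rsqrt(mean(x^2) + eps) * weight; the functional difference in this hunk is that eps now comes from config.llm_config.rms_norm_eps instead of a hard-coded 1e-5. A minimal equivalence sketch (illustration only, not part of the commit; assumes PyTorch >= 2.4, where nn.RMSNorm is available):

import torch
import torch.nn as nn

hidden, eps = 64, 1e-5
x = torch.randn(2, 7, hidden)

def ref_rmsnorm(h, weight, eps):
    # The deleted module's math, reproduced only for comparison.
    h32 = h.to(torch.float32)
    variance = h32.pow(2).mean(-1, keepdim=True)
    h32 = h32 * torch.rsqrt(variance + eps)
    return (weight.to(torch.float32) * h32).to(h.dtype)

norm = nn.RMSNorm(hidden, eps=eps)  # built-in module the commit switches to
with torch.no_grad():
    assert torch.allclose(norm(x), ref_rmsnorm(x, norm.weight, eps), atol=1e-5)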
@@ -204,8 +189,7 @@ def get_mm_token_ids(self):
     def get_num_tokens_per_image(
         self,
         *,
-        image_width: int,
-        image_height: int,
+        image: Image.Image,
         **kwargs,
     ):

@@ -256,6 +240,8 @@ def calculate_targets(

             return blocks

+        image_height = image.height
+        image_width = image.width
         target_ratios = get_internvl_target_ratios(1,
                                                    self.processor.max_num_tiles)
         blocks = calculate_targets(image_width, image_height, target_ratios,
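
With this signature change, callers pass the PIL image itself and the method reads image.height and image.width internally before computing the InternVL-style tiling. A hypothetical sketch of the new call shape (VisionProcessor is a stand-in class, not taken from the commit):

from PIL import Image


class VisionProcessor:  # stand-in with the same keyword-only signature
    def get_num_tokens_per_image(self, *, image: Image.Image, **kwargs):
        image_height = image.height  # derived from the image, no longer passed in
        image_width = image.width
        return image_width, image_height  # the real method returns a token count


proc = VisionProcessor()
img = Image.new("RGB", (1024, 768))  # synthetic image, no file needed
print(proc.get_num_tokens_per_image(image=img))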

tensorrt_llm/_torch/models/modeling_radio.py

Lines changed: 27 additions & 26 deletions
@@ -14,6 +14,7 @@
 from transformers import PretrainedConfig, PreTrainedModel

 from tensorrt_llm._torch import model_config as model_config_lib
+from tensorrt_llm._torch.attention_backend import AttentionMetadata
 from tensorrt_llm._torch.attention_backend import \
     interface as attention_interface
 from tensorrt_llm._torch.attention_backend import utils as attention_utils
@@ -540,9 +541,8 @@ def __init__(
             act_layer = nn.GELU

         self.model_config = model_config
-        if self.model_config is not None:
-            self.config = model_config.pretrained_config
-            self.config.num_key_value_heads = num_heads
+        self.config = model_config.pretrained_config
+        self.config.num_key_value_heads = num_heads

         self.num_classes = num_classes
         self.global_pool = global_pool
@@ -622,28 +622,31 @@ def __init__(
         self.patch_size = patch_size
         self.num_cls_tokens = num_cls_tokens
         self.num_registers = self.patch_generator.num_registers
-        if self.model_config is not None:
-            self.metadata_cls = attention_utils.get_attention_backend(
-                model_config.attn_backend).Metadata
-        else:
-            self.metadata_cls = None

-    def prepare_attn_metadata(self, batch_size: int, seq_lengths: List[int]):
+        self.metadata_cls = attention_utils.get_attention_backend(
+            model_config.attn_backend).Metadata
+        self.attn_metadata = self.metadata_cls(
+            max_num_requests=8192,  # TODO: Make this dynamic
+            max_num_tokens=model_config.max_num_tokens,
+            kv_cache_manager=None,
+        )
+
+    def prepare_attn_metadata(self, batch_size: int, seq_lengths: List[int],
+                              attn_metadata: AttentionMetadata):
         """
         To simplify the usage of the model, this function aims to fill the metadata for Attention
         Call this function before forward pass
         """
+        prompt_lens = seq_lengths
+        seq_lens = torch.tensor(seq_lengths, dtype=torch.int, pin_memory=True)
         request_ids = list(range(1, batch_size + 1))
-        attn_metadata = self.metadata_cls(
-            seq_lens=torch.tensor(seq_lengths, dtype=torch.int),
-            num_contexts=batch_size,
-            max_num_requests=batch_size,
-            max_num_tokens=sum(seq_lengths),
-            kv_cache_manager=None,
-            request_ids=request_ids,
-            prompt_lens=seq_lengths,
-        )
-        attn_metadata.max_seq_len = max(seq_lengths)
+
+        attn_metadata.seq_lens = seq_lens
+        attn_metadata.num_contexts = batch_size
+        attn_metadata.request_ids = request_ids
+        attn_metadata.prompt_lens = prompt_lens
+        attn_metadata.max_seq_len = seq_lens.max().item()
+
         attn_metadata.prepare()
         return attn_metadata


@@ -652,13 +655,11 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_generator(x)

         batch_size, seq_len, hidden_size = x.shape
-        if self.model_config is not None:
-            seq_lengths = [seq_len] * batch_size
-            attn_metadata = self.prepare_attn_metadata(batch_size, seq_lengths)
-            # Need flatten batch/seq_len for trtllm attention.
-            x = x.reshape(batch_size * seq_len, hidden_size)
-        else:
-            attn_metadata = None
+        seq_lengths = [seq_len] * batch_size
+        attn_metadata = self.prepare_attn_metadata(batch_size, seq_lengths,
+                                                   self.attn_metadata)
+        # Need flatten batch/seq_len for trtllm attention.
+        x = x.reshape(batch_size * seq_len, hidden_size)
         for block in self.blocks:
             x = block(x, attn_metadata=attn_metadata)
         x = x.reshape(batch_size, seq_len, hidden_size)
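
Taken together, the two hunks above move RADIO's attention metadata from a per-forward construction to a single object built in __init__ (sized by max_num_requests / max_num_tokens) that prepare_attn_metadata re-fills for each batch. A condensed sketch of that reuse pattern with stand-in objects (not the real TensorRT-LLM Metadata class):

from types import SimpleNamespace
from typing import List

import torch

# Stands in for `self.attn_metadata = self.metadata_cls(max_num_requests=8192, ...)`
# allocated once in __init__.
attn_metadata = SimpleNamespace(prepare=lambda: None)


def prepare_attn_metadata(batch_size: int, seq_lengths: List[int], attn_metadata):
    # Mirrors the refactored method: mutate the persistent object, then prepare().
    attn_metadata.seq_lens = torch.tensor(seq_lengths, dtype=torch.int)
    attn_metadata.num_contexts = batch_size
    attn_metadata.request_ids = list(range(1, batch_size + 1))
    attn_metadata.prompt_lens = seq_lengths
    attn_metadata.max_seq_len = max(seq_lengths)
    attn_metadata.prepare()
    return attn_metadata


# One ViT batch of 2 images, 196 patch tokens each (illustrative numbers).
prepare_attn_metadata(2, [196, 196], attn_metadata)

Reusing one metadata object avoids reallocating it on every forward pass; the trade-off is the fixed max_num_requests=8192 cap flagged by the TODO in the diff.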
