feat: support non-cuda devices for vision models
This commit adds support for non-CUDA PyTorch backend devices
in the vision models. The change was verified with the
Llama3.2-11B-Vision-Instruct model for:
* the "cuda" device type on an NVIDIA A10 GPU
* the "cpu" device type
* the "xpu" device type on an Intel Data Center GPU Max Series (PVC)

Note that this commit requires a fix on the PyTorch side for the
gloo torch.distributed backend to restore TLS on gloo worker threads.

Requires: pytorch/pytorch#142184
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
dvrogozh committed Dec 6, 2024
1 parent a2cae08 commit 563e2a1
Showing 2 changed files with 4 additions and 6 deletions.
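
For context, and not part of the commit itself: a minimal, hypothetical sketch of how a
caller might select one of the backends listed in the commit message, so that code which
reads torch.get_default_dtype() and derives devices from existing tensors runs unchanged
on cuda, cpu, or xpu. The selection logic below is an assumption for illustration only,
not the reference implementation's actual entry point.

    import torch

    # Hypothetical backend selection covering the device types the commit was verified on.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device = torch.device("xpu")
    else:
        device = torch.device("cpu")

    # Model code that calls torch.get_default_dtype() picks this up
    # instead of a hardcoded torch.bfloat16.
    torch.set_default_dtype(torch.bfloat16)

    x = torch.randn(2, 4, device=device)
    y = torch.zeros_like(x)              # inherits x's dtype and device
    assert y.device.type == device.type
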
models/llama3/reference_impl/generation.py (2 changes: 1 addition & 1 deletion)
@@ -170,7 +170,7 @@ def build(
             from .multimodal.model import CrossAttentionTransformer
 
             model = CrossAttentionTransformer(model_args)
-            model.setup_cache(model_args.max_batch_size, torch.bfloat16)
+            model.setup_cache(model_args.max_batch_size, torch.get_default_dtype())
         else:
             model = Transformer(model_args)
         model.load_state_dict(checkpoint, strict=True)
models/llama3/reference_impl/multimodal/model.py (8 changes: 3 additions & 5 deletions)
@@ -1113,7 +1113,7 @@ def forward(
         # aspect_ratios: (B, T)
         # h: (B, T, D)
         vision_tokens = self.vision_encoder(
-            images.to(dtype=torch.bfloat16), aspect_ratios
+            images.to(dtype=torch.get_default_dtype()), aspect_ratios
         )
 
         vision_tokens = F.linear(
@@ -1407,8 +1407,6 @@ def compute_vision_tokens_masks(
         else:
             vision_tokens = self.vision_model(stacked_images, aspect_ratios)
 
-        vision_tokens = vision_tokens.to("cuda")
-
         bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
         xattn_caches = torch.stack(
             [
@@ -1428,7 +1426,7 @@
         cross_attention_masks, full_text_row_masked_out_mask = (
             self.text_model._get_xattn_mask(
                 num_tokens=total_len,
-                text_device="cuda",
+                text_device=vision_tokens.device.type,
                 text_dtype=next(self.text_model.parameters()).dtype,
                 vision_tokens=vision_tokens,
                 cross_attention_masks=padded_masks,
@@ -1495,7 +1493,7 @@ def _pad_masks(
     total_len: int,
     max_num_chunks: int,
 ) -> torch.Tensor:
-    dtype = torch.bfloat16
+    dtype = torch.get_default_dtype()
     inf_value = get_negative_inf_value(dtype)
 
     bsz = len(all_masks)

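In summary, the diff takes the dtype from torch.get_default_dtype() and the device from
an existing tensor instead of hardcoding torch.bfloat16 and "cuda". A minimal sketch of
that pattern, using hypothetical names rather than the reference implementation's actual
classes:

    import torch

    def build_xattn_mask(vision_tokens: torch.Tensor, total_len: int) -> torch.Tensor:
        # Hypothetical helper illustrating the pattern applied by this diff.
        dtype = torch.get_default_dtype()      # was: torch.bfloat16
        device = vision_tokens.device          # was: "cuda"
        neg_inf = torch.finfo(dtype).min
        return torch.full((total_len, total_len), neg_inf, dtype=dtype, device=device)

    torch.set_default_dtype(torch.bfloat16)
    tokens = torch.randn(1, 4, 1024, device="cpu")   # any backend: cpu, cuda, xpu
    mask = build_xattn_mask(tokens, total_len=4)
    assert mask.dtype == torch.bfloat16 and mask.device == tokens.device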