
Conversation


@malkomes malkomes commented Apr 16, 2025

WIP for Phase 2 Qwen2.5-VL optimization.
Details will be provided once ready.

--
Co-authored-by: Gustavo Malkomes gustavo.malkomes@intel.com
Co-authored-by: Jimin Ha jimin.ha@intel.com
Co-authored-by: Sayantan Sarkar sayantan.sarkar@intel.com
Co-authored-by: Iman Gohari s.m.iman.gohari@intel.com

imangohari1 and others added 20 commits April 8, 2025 21:57
split warmup into text-only and image-only

force input_positions in text to be (3, seq_len)
full_attention_mask doesn't need to be created for each full-attention
layer; create it once and reuse it. This saves memory and time.
profile_run takes a maximum tensor size of 65K. To support it, we need to
significantly reduce memory usage with the changes below.

- Set disable_tensor_cache=True for vision model as well
- Add additional mark_step to split the graphs
- Move einsum operation to CPU for bigger tensors (due to a GC error)
- Run FusedSDPA for longer sequences as well
@imangohari1 imangohari1 changed the title Ig/qwen2 5 vl vision transformer [Gaudi][Model] Qwen2.5-VL optimization Apr 17, 2025
ssarkar2 and others added 2 commits April 18, 2025 02:30
- fix use_graph to detect multimodal bucket correctly
- pass the right pixel size for execution
- change multimodal buckets to align with resize
- remove multimodal warmup for Decode
    #self.multimodal_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544, 16384, 26500, 40000, 65536]
    self.multimodal_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544]
else:
    self.multimodal_buckets = [int(i) for i in envvar.split(',')]


Add a sorted() on this; the way get_multimodal_bucket works assumes the list is sorted.


done
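For reference, a minimal sketch of the resulting parsing with sorted() applied, plus the kind of ascending lookup that get_multimodal_bucket implies (the lookup body here is illustrative, not the actual implementation):

import os

# Parse VLLM_MULTIMODAL_BUCKETS and keep the list ascending, as the lookup assumes.
envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "")
if envvar == "":
    multimodal_buckets = [1600, 3136, 4096, 6400, 7744, 9216, 12544]
else:
    multimodal_buckets = sorted(int(i) for i in envvar.split(','))

def get_multimodal_bucket(num_patches: int) -> int:
    # Return the smallest bucket that fits; correct only because the list is sorted.
    for bucket in multimodal_buckets:
        if bucket >= num_patches:
            return bucket
    return multimodal_buckets[-1]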

range_to_max_for_each_img = torch.arange(maxsize, device=indices.device).unsqueeze(0).repeat(indices.shape[0]-1,1)
yy = range_to_max_for_each_img < indices[1:].unsqueeze(1)
zz = range_to_max_for_each_img >= indices[:-1].unsqueeze(1)
xx = torch.logical_and(yy, zz).float()


Change the variable names.


done
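For reference, a hedged sketch of what the rename could look like (names are illustrative; the tensor math is unchanged from the snippet above):

import torch

def per_image_membership_mask(indices: torch.Tensor, maxsize: int) -> torch.Tensor:
    # indices holds cumulative patch offsets; row i marks which positions in
    # [0, maxsize) belong to image i, i.e. indices[i] <= pos < indices[i + 1].
    positions = torch.arange(maxsize, device=indices.device)
    positions = positions.unsqueeze(0).repeat(indices.shape[0] - 1, 1)
    below_image_end = positions < indices[1:].unsqueeze(1)
    at_or_after_image_start = positions >= indices[:-1].unsqueeze(1)
    return torch.logical_and(below_image_end, at_or_after_image_start).float()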

@imangohari1

This PR, up to the following commit, has been tested with pytests, online tests, and offline tests.
a03181d

f"[MM_BUCKETING] Padding current number pixel {pixel_values.shape[0]} to {desired_number_of_pixels}"
)
# needs to make sure padding_len is even
assert padding_len % 64 == 0, '[testing version] padding needs to be multiple of 64'


what does [testing version] mean?


done
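A hedged sketch of the padding step under discussion (the helper name is an assumption; the multiple-of-64 check mirrors the assert above):

import torch

def pad_pixel_values_to_bucket(pixel_values: torch.Tensor,
                               desired_number_of_pixels: int) -> torch.Tensor:
    # Pad the patch dimension up to the chosen multimodal bucket with zeros.
    padding_len = desired_number_of_pixels - pixel_values.shape[0]
    assert padding_len >= 0 and padding_len % 64 == 0, \
        'padding needs to be a non-negative multiple of 64'
    if padding_len == 0:
        return pixel_values
    pad = torch.zeros(padding_len, pixel_values.shape[1],
                      dtype=pixel_values.dtype, device=pixel_values.device)
    return torch.cat([pixel_values, pad], dim=0)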


if video_input is not None:
    if is_hpu:
        print("Video inputs have not been enabled/verified yet, ignoring video inputs")


use logger.warning


done
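A minimal sketch of the requested change, using vLLM's standard logger setup (the surrounding variables are as in the snippet above):

from vllm.logger import init_logger

logger = init_logger(__name__)

if video_input is not None:
    if is_hpu:
        logger.warning(
            "Video inputs have not been enabled/verified yet; ignoring video inputs")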

_PAD_SLOT_ID = 0
_PAD_BLOCK_ID = 0

_UNSET_NUM_PATCHES = 9999999


I suppose we use this because None means something else? If not, could we use None?
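A small sketch of the alternative this question suggests, assuming nothing else relies on the numeric sentinel (the helper and its largest-bucket fallback are illustrative):

from typing import Optional

def resolve_num_patches(num_patches: Optional[int],
                        multimodal_buckets: list[int]) -> int:
    # None plays the role of _UNSET_NUM_PATCHES; fall back to the largest bucket,
    # matching the "profile with the largest bucket" behaviour mentioned below.
    if num_patches is None:
        return max(multimodal_buckets)
    return num_patches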

def __init__(self):
    envvar = os.environ.get('VLLM_MULTIMODAL_BUCKETS', "")
    if envvar == "":
        #TODO: with profile_run, the bucket of 65536 is added, so the pixel values


This statement is no longer true, I think; we profile with the largest bucket in this class.


@imangohari1 imangohari1 left a comment


@sayantan-nervana
I have gone through the PR as of here, made some suggestions and comments.
I hope these help.


from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
# from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLImageProcessorForceAlignment


Should we remove this line?

expected_pixels_shape_one = 1176
expected_toks_per_img = expected_pixels_shape_zero // 4
mm_processor_kwargs = {}
#mm_processor_kwargs = {"force_alignment": True}


to be removed?

q_len = q.size(-2)
assert q_len % q_block_size == 0
q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
#q_padding = q_tiles * q_block_size - q_len


shall we remove these commented lines?
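If the commented line goes away, the tiling setup could shrink to something like this (a sketch; with the assert in place, the ceil fallback is unreachable):

q_len = q.size(-2)
assert q_len % q_block_size == 0
q_tiles = q_len // q_block_size  # exact, given the assert above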

row_mask = mask[:, :, s:e, :]
attn_output[:, :, s:e, :] = FusedSDPA.apply(row_q, k, v, row_mask, 0.0, False, None)
#TODO: markstep every 10th layer, didn't experiment which one is optimal number.
#10,50,100 shows simliar result, without this, we see the program hangs for multiple prompts(with larger images)


Suggested change
#10,50,100 shows simliar result, without this, we see the program hangs for multiple prompts(with larger images)
#INFO: %10, 50, 100 show similar results. Without the mark_step here, the model hangs for multiple prompts and/or larger images
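For context, a hedged sketch of the periodic mark_step this comment describes (the block loop, the call signature, and the period of 10 are illustrative; htcore.mark_step() is the Gaudi graph-break call):

import habana_frameworks.torch.core as htcore

MARK_STEP_EVERY_N_LAYERS = 10  # 10, 50 and 100 reportedly behave similarly

for layer_idx, blk in enumerate(self.blocks):
    hidden_states = blk(hidden_states, attention_mask=attention_mask)
    if layer_idx % MARK_STEP_EVERY_N_LAYERS == 0:
        # Split the accumulated graph; without this, runs with multiple prompts
        # and larger images have been observed to hang.
        htcore.mark_step()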

return (batch_size, seq_len, is_prompt) in self.graphed_buckets
if not num_patches:
return (batch_size, seq_len, is_prompt) in self.graphed_buckets
#TODO: We might need to check both language bucket and multimodal bucket


still needed?

lora_request):
assert self.model_is_mrope, "Warmup compatible with Qwen2vl models"
if num_patches == _UNSET_NUM_PATCHES:
# # only half of the total number of tokens should be from image


can this section be cleaned up?


image_h = int(math.sqrt(num_patches))
image_grid_thw = torch.tensor([1, image_h, image_h])
pixel_values = torch.randn(image_grid_thw.prod(), 1176) # TODO: figure out the variable name


Suggested change
pixel_values = torch.randn(image_grid_thw.prod(), 1176) # TODO: figure out the variable name
pixel_values = torch.randn(image_grid_thw.prod(), 1176)

#TODO: einsum with tensor dimension too big doesn't work. Register max size error.
#We can always move to CPU for all einsum without shape checking if perf impact is minimal.
if range_indices.shape[-1] > 40000:
print("einsum running on CPU : ", range_indices.shape)


Suggested change
print("einsum running on CPU : ", range_indices.shape)
logger.info("einsum running on CPU: %s", range_indices.shape)
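A sketch of the size-gated CPU fallback with correct logger usage (the threshold of 40000 comes from the diff; the helper name and einsum equation are placeholders, and logging a tensor shape needs %-style arguments rather than print-style ones):

import torch
from vllm.logger import init_logger

logger = init_logger(__name__)

def einsum_with_cpu_fallback(eq: str, a: torch.Tensor, b: torch.Tensor,
                             threshold: int = 40000) -> torch.Tensor:
    # Run the einsum on CPU when the last dim is too large for the device,
    # then move the result back to the original device.
    if a.shape[-1] > threshold:
        logger.info("einsum running on CPU: %s", a.shape)
        return torch.einsum(eq, a.cpu(), b.cpu()).to(a.device)
    return torch.einsum(eq, a, b)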


@sayantan-nervana sayantan-nervana Apr 23, 2025


Bug:

Let's say we want to warm up for bucket = 42688.

Code:

image_h = int(math.sqrt(num_patches))
image_grid_thw = torch.tensor([1, image_h, image_h])

With num_patches = 42688, sqrt gives 206.6, which int() truncates to 206, so the grid thw becomes 206 x 206 = 42436, and 42436 % 64 != 0.

Proposed change, something like: image_grid_thw = torch.tensor([1, image_h, num_patches/image_h])
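One hedged way to go further and keep the grid product exactly equal to the bucket (a sketch, not the committed fix): pick the largest height at or below sqrt(num_patches) that divides it evenly.

import math
import torch

def grid_thw_for_bucket(num_patches: int) -> torch.Tensor:
    # Choose the largest height <= sqrt(num_patches) that divides it exactly, so
    # height * width == num_patches and the bucket's multiple-of-64 property survives.
    # For 42688 this yields 184 x 232 = 42688. Degenerates to 1 x num_patches for primes.
    image_h = int(math.sqrt(num_patches))
    while image_h > 1 and num_patches % image_h != 0:
        image_h -= 1
    return torch.tensor([1, image_h, num_patches // image_h])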

malkomes and others added 5 commits April 23, 2025 23:10
Co-authored-by: Iman Gohari <s.m.iman.gohari@intel.com>
Co-authored-by: Iman Gohari <s.m.iman.gohari@intel.com>
@malkomes malkomes force-pushed the ig/qwen2_5-vl_visionTransformer branch from 63e356f to 7c9cf4c Compare April 23, 2025 23:25
@malkomes malkomes changed the title [Gaudi][Model] Qwen2.5-VL optimization [Draft][Gaudi][Model] Qwen2.5-VL optimization Apr 29, 2025
@wenbinc-Bin wenbinc-Bin mentioned this pull request May 19, 2025