From 21fc1b2dc2067a4d38f96c7e1df7bb6c1a487ab8 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 11 Jan 2025 23:28:56 +0100 Subject: [PATCH] Fix trainer and Qwen2-VL (#179) * fix arange (qwen2-vl) * fix trainer prepare inputs error --- mlx_vlm/models/qwen2_vl/vision.py | 2 +- mlx_vlm/tests/test_trainer.py | 14 ++++++------- mlx_vlm/trainer/trainer.py | 33 +++++++++++-------------------- 3 files changed, 19 insertions(+), 30 deletions(-) diff --git a/mlx_vlm/models/qwen2_vl/vision.py b/mlx_vlm/models/qwen2_vl/vision.py index bd699a4..bd32514 100644 --- a/mlx_vlm/models/qwen2_vl/vision.py +++ b/mlx_vlm/models/qwen2_vl/vision.py @@ -88,7 +88,7 @@ def __call__(self, seqlen: int) -> mx.array: inv_freq = 1.0 / ( self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim) ) - seq = mx.arange(seqlen, dtype=inv_freq.dtype) + seq = mx.arange(seqlen.tolist(), dtype=inv_freq.dtype) freqs = mx.outer(seq, inv_freq) return freqs diff --git a/mlx_vlm/tests/test_trainer.py b/mlx_vlm/tests/test_trainer.py index 97339e7..70dd370 100644 --- a/mlx_vlm/tests/test_trainer.py +++ b/mlx_vlm/tests/test_trainer.py @@ -47,15 +47,15 @@ def test_dataset_getitem(self, mock_prepare_inputs, mock_get_prompt): mock_get_prompt.return_value = "Mocked prompt" - mock_prepare_inputs.return_value = ( - mx.array([1, 2, 3]), # input_ids - mx.array( + mock_prepare_inputs.return_value = { + "input_ids": mx.array([1, 2, 3]), # input_ids + "pixel_values": mx.array( [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] ), # pixel_values - mx.array([1, 1, 1]), # mask - (1, 1, 1), # image_grid_thw - [224, 224], # image_sizes - ) + "attention_mask": mx.array([1, 1, 1]), # mask + "image_grid_thw": (1, 1, 1), # image_grid_thw + "image_sizes": [224, 224], # image_sizes + } result = dataset[0] diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py index 213b0ef..22f61f2 100644 --- a/mlx_vlm/trainer/trainer.py +++ b/mlx_vlm/trainer/trainer.py @@ -89,27 +89,21 @@ def __getitem__(self, idx): image_token_index = self.config["image_token_index"] inputs = prepare_inputs( - self.image_processor, self.processor, images, prompts, image_token_index, self.image_resize_shape, ) - input_ids, pixel_values, mask = inputs[:3] + input_ids = inputs["input_ids"] + pixel_values = inputs["pixel_values"] + mask = inputs["attention_mask"] kwargs = { k: v - for k, v in zip( - [ - "image_grid_thw", - "image_sizes", - "aspect_ratio_ids", - "aspect_ratio_mask", - "cross_attention_mask", - ], - inputs[3:], - ) + for k, v in inputs.items() + if k not in ["input_ids", "pixel_values", "attention_mask"] } + if mask is None: mask = mx.ones_like(input_ids) @@ -226,16 +220,11 @@ def loss_fn(self, model, batch): input_ids = input_ids[:, :-1] - kwargs = {} - image_keys = [ - "image_grid_thw", - "image_sizes", - "aspect_ratio_ids", - "aspect_ratio_mask", - "cross_attention_mask", - ] - if any(key in batch for key in image_keys): - kwargs = {key: batch[key] for key in image_keys if key in batch} + kwargs = { + k: v + for k, v in batch.items() + if k not in ["input_ids", "pixel_values", "attention_mask"] + } # Forward pass outputs = model(input_ids, pixel_values, attention_mask, **kwargs)