Commit c1e66e8

apply style, ignore visual on qwen

kylesayrs committed Dec 19, 2024
1 parent 4123636 · commit c1e66e8
Showing 6 changed files with 8 additions and 9 deletions.
examples/multimodal_vision/mllama.py (2 changes: 0 additions & 2 deletions)

@@ -1,10 +1,8 @@
 import os

-import torch
 from transformers import AutoProcessor

 from llmcompressor.modifiers.quantization import GPTQModifier
-
 from llmcompressor.transformers import oneshot
 from llmcompressor.transformers.tracing import TracableMllamaForConditionalGeneration
 from llmcompressor.transformers.utils.data_collator import mllama_data_collator

examples/multimodal_vision/pixtral.py (1 change: 0 additions & 1 deletion)

@@ -1,6 +1,5 @@
 import os

-import torch
 from transformers import AutoProcessor

 from llmcompressor.modifiers.quantization import GPTQModifier

examples/multimodal_vision/qwen2_vl.py (3 changes: 1 addition & 2 deletions)

@@ -1,6 +1,5 @@
 import os

-import torch
 from compressed_tensors.quantization import (
     QuantizationArgs,
     QuantizationScheme,

@@ -43,7 +42,7 @@
             ),
         ),
     },
-    ignore=["re:.*lm_head"],
+    ignore=["re:visual.*", "re:.*lm_head"],
     dampening_frac=0.5,
 )

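The substantive change in this commit is the new "re:visual.*" entry, which leaves the Qwen2-VL vision tower unquantized and applies GPTQ only to the language model. As a hedged sketch, the same ignore list can be used with GPTQModifier's simpler targets/scheme interface; the targets and scheme values below are assumed for illustration and are not taken from this example:

from llmcompressor.modifiers.quantization import GPTQModifier

# Modules whose names match an entry in `ignore` are skipped during
# quantization; the "re:" prefix marks an entry as a regular expression.
recipe = GPTQModifier(
    targets="Linear",                        # quantize Linear layers (assumed)
    scheme="W4A16",                          # to 4-bit weights (assumed scheme)
    ignore=["re:visual.*", "re:.*lm_head"],  # skip vision tower and LM head
    dampening_frac=0.5,
)
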
src/llmcompressor/modifiers/utils/pytorch_helpers.py (2 changes: 1 addition & 1 deletion)

@@ -1,5 +1,5 @@
 from itertools import cycle
-from typing import Callable, Dict, List, Optional, Tuple, Any
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import torch
 from torch.nn import Module

src/llmcompressor/pipelines/basic/pipeline.py (4 changes: 2 additions & 2 deletions)

@@ -1,18 +1,18 @@
 import torch
 import torch.utils.data.dataloader
 import tqdm
+from compressed_tensors.utils import get_execution_device

 from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
 from llmcompressor.pytorch.utils.helpers import tensors_to_device
 from llmcompressor.utils.helpers import calibration_forward_context
-from compressed_tensors.utils import get_execution_device

 __all__ = ["run_pipeline"]


 def run_pipeline(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader):
     model_device = get_execution_device(model)

     with calibration_forward_context(model):
         for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
             batch = apply_pad_mask_to_batch(batch)
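For orientation, the helpers imported above outline the calibration flow: get_execution_device finds the device the model executes on, and each batch is moved there before the forward pass. Below is an illustrative sketch of a loop of this shape, not the file's actual continuation (which is truncated in this view):

import torch
import tqdm


def calibrate(model: torch.nn.Module, dataloader, model_device: torch.device):
    # Illustrative only: run each calibration batch through the model on
    # its execution device; calibration needs no gradients.
    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader, desc="Calibrating"):
            batch = {key: value.to(model_device) for key, value in batch.items()}
            model(**batch)
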
src/llmcompressor/transformers/utils/data_collator.py (5 changes: 4 additions & 1 deletion)

@@ -2,6 +2,7 @@

 __all__ = ["mllama_data_collator", "pixtral_data_collator"]

+
 def mllama_data_collator(batch):
     assert len(batch) == 1
     return {

@@ -13,6 +14,7 @@ def mllama_data_collator(batch):
         "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]),
     }

+
 def pixtral_data_collator(batch):
     assert len(batch) == 1
     return {

@@ -21,11 +23,12 @@ def pixtral_data_collator(batch):
         "pixel_values": torch.tensor(batch[0]["pixel_values"])[0],
     }

+
 def qwen2_vl_data_collator(batch):
     assert len(batch) == 1
     return {
         "input_ids": torch.LongTensor(batch[0]["input_ids"]),
         "attention_mask": torch.tensor(batch[0]["attention_mask"]),
         "pixel_values": torch.tensor(batch[0]["pixel_values"]),
         "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]),
-    }
\ No newline at end of file
+    }

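Each collator assumes a batch size of one (hence the assert) and converts the sample's fields back into tensors. A small usage sketch for the Qwen2-VL collator; the field values are placeholders, not real processor output:

from llmcompressor.transformers.utils.data_collator import qwen2_vl_data_collator

# A toy single-sample batch shaped like the processor output the collator
# expects; values are placeholders, not real image features.
sample = {
    "input_ids": [[1, 2, 3]],
    "attention_mask": [[1, 1, 1]],
    "pixel_values": [[0.0] * 8],
    "image_grid_thw": [[1, 2, 4]],
}
batch = qwen2_vl_data_collator([sample])
print({key: value.shape for key, value in batch.items()})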