
VLM Support via GPTQ Hooks and Sequential Data Pipeline #914

Draft · wants to merge 267 commits into base: main
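
Per the commit history below ("Implement HooksMixin", "add hooksmixin to modifier", and the "integrate with ..." commits), this PR routes all modifier calibration hooks through a shared HooksMixin so that GPTQ, SmoothQuant, WANDA, magnitude/constant pruning, and SparseGPT register and tear down their PyTorch hooks through one interface. As a rough illustration of that pattern only (the actual class in this PR may use different names), here is a minimal plain-PyTorch sketch; register_hook, remove_hooks, and all internals are assumptions:

from typing import Callable, List

import torch
from torch.utils.hooks import RemovableHandle


class HooksMixin:
    """Track every hook a modifier registers so all can be removed at once.

    Minimal sketch only; the PR's real HooksMixin may differ.
    """

    def __init__(self) -> None:
        self._hooks: List[RemovableHandle] = []

    def register_hook(
        self, module: torch.nn.Module, func: Callable, hook_type: str = "forward"
    ) -> RemovableHandle:
        # dispatch to the matching torch registration call, remember the handle
        if hook_type == "forward":
            handle = module.register_forward_hook(func)
        elif hook_type == "pre_forward":
            handle = module.register_forward_pre_hook(func)
        else:
            raise ValueError(f"unknown hook type: {hook_type}")
        self._hooks.append(handle)
        return handle

    def remove_hooks(self) -> None:
        # called once calibration finishes so no hooks leak into inference
        for handle in self._hooks:
            handle.remove()
        self._hooks.clear()

A GPTQ-style modifier would register one hook per targeted Linear during calibration and call remove_hooks() once compression completes.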
Commits (267)
83a5762
wip
kylesayrs Oct 16, 2024
7f49ab4
runnable
kylesayrs Oct 16, 2024
4539052
add ability to pass model class to support non-traditional (vision) m…
kylesayrs Oct 18, 2024
483744c
Merge remote-tracking branch 'origin' into kylesayrs/calculate_offloa…
kylesayrs Oct 18, 2024
9ae998b
update docstring
kylesayrs Oct 18, 2024
ac0d926
batching
kylesayrs Oct 21, 2024
6304973
calibration forward context
kylesayrs Oct 21, 2024
868a480
fix stuff
kylesayrs Oct 21, 2024
86c8a06
wip
kylesayrs Oct 21, 2024
1305173
use hooks list
kylesayrs Oct 21, 2024
e6adc5a
layer compressor
kylesayrs Oct 22, 2024
f65f832
style
kylesayrs Oct 22, 2024
1e22569
use layer compressor
kylesayrs Oct 22, 2024
9324695
replicate dtypes
kylesayrs Oct 22, 2024
eef4fb6
write weight changes
kylesayrs Oct 22, 2024
485813a
revert example
kylesayrs Oct 22, 2024
6006155
organization
kylesayrs Oct 22, 2024
c10d2ee
add create_single_batch_dataloader
kylesayrs Oct 22, 2024
6371193
add back empty_cache until I can justify removing it
kylesayrs Oct 22, 2024
92315a5
better type hinting, faster mask applying
kylesayrs Oct 22, 2024
8903fbf
Merge remote-tracking branch 'origin' into kylesayrs/gptq-hooks
kylesayrs Oct 22, 2024
8a25c68
remove breakpoint
kylesayrs Oct 22, 2024
6cd0d6c
apply style, add true_sequential docstring
kylesayrs Oct 22, 2024
0e0c586
update docstring
kylesayrs Oct 22, 2024
d23aabb
use private attrs
kylesayrs Oct 22, 2024
355074b
more docstring
kylesayrs Oct 23, 2024
bf2184d
docstrings
kylesayrs Oct 23, 2024
0b418c7
docstrings
kylesayrs Oct 23, 2024
56cceea
docstrings
kylesayrs Oct 23, 2024
7c7e3bc
move hooksmixin to separate file
kylesayrs Oct 23, 2024
2d52183
docstrings
kylesayrs Oct 23, 2024
d6ff46a
Merge branch 'main' into kylesayrs/gptq-hooks
kylesayrs Oct 23, 2024
9081f12
fix docstring, better arguments grouping
kylesayrs Oct 23, 2024
96e9496
use LayerCompressorMixin
kylesayrs Oct 24, 2024
7fbf8b1
docstrings
kylesayrs Oct 24, 2024
3d3af2a
add back hessian hook to support bs1
kylesayrs Oct 24, 2024
b3021ab
wip
kylesayrs Oct 25, 2024
8508b63
accumulate
kylesayrs Oct 25, 2024
3ff271d
virtualize batches for layers
kylesayrs Oct 25, 2024
d6c6dc3
maybe works, but padding is wrong
kylesayrs Oct 25, 2024
400fa08
WIP
kylesayrs Oct 29, 2024
c4d2dde
revert weird batching, support image text datasets
kylesayrs Oct 29, 2024
670b35e
remove breakpoint
kylesayrs Oct 29, 2024
3892b90
add example script
kylesayrs Oct 29, 2024
03515f0
remove hessian
kylesayrs Oct 29, 2024
6e37f64
allocated original weight
kylesayrs Oct 29, 2024
09dae14
proper clone
kylesayrs Oct 29, 2024
944601e
remove breakpoint
kylesayrs Oct 29, 2024
adbcee8
naive_update option
kylesayrs Oct 29, 2024
f4acab2
remove true sequential
kylesayrs Oct 29, 2024
151f566
allow update_offload_parameter to not require data
kylesayrs Oct 29, 2024
76ebc86
bugfix
kylesayrs Oct 29, 2024
3480d6b
ba
kylesayrs Oct 29, 2024
7c55fc5
delete parameter
kylesayrs Oct 29, 2024
0a8004b
sensible generations for small calibration size
kylesayrs Oct 30, 2024
d234b32
remove unnecessary variables
kylesayrs Oct 30, 2024
eeb5c83
remove non-naive updating stuff to focus on naive updating
kylesayrs Oct 30, 2024
99a2d97
Merge remote-tracking branch 'origin' into kylesayrs/gptq-steps
kylesayrs Nov 1, 2024
c7c8d04
use observer to calculate qparams
kylesayrs Nov 1, 2024
2beb59a
remove tokenizer args
kylesayrs Nov 5, 2024
4a336fe
fix shapes
kylesayrs Nov 5, 2024
f137347
complete, more or less
kylesayrs Nov 5, 2024
593d4fd
support vision datasets
kylesayrs Nov 5, 2024
0bdf98a
use pixtral
kylesayrs Nov 8, 2024
9f43b5d
better stopping
kylesayrs Nov 8, 2024
3d224db
implement partitioned model
kylesayrs Nov 13, 2024
4872242
working, although still a little higher memory usage than expected
kylesayrs Nov 14, 2024
7fa5c3c
offload intermediates
kylesayrs Nov 14, 2024
1c45963
cleanup
kylesayrs Nov 14, 2024
65b3e5b
better comments, support sending non-tensors to device
kylesayrs Nov 14, 2024
53d0601
remove breakpoint, fix move_tensors_to_device
kylesayrs Nov 14, 2024
4da451b
woof
kylesayrs Nov 14, 2024
c77a7fc
fix thing
kylesayrs Nov 14, 2024
9249434
remove LayerCompressorMixin, add hooks tests
kylesayrs Nov 14, 2024
2690e10
Implement HooksMixin
kylesayrs Nov 14, 2024
004f5c7
add docstring
kylesayrs Nov 15, 2024
d3058f0
integrate with smoothquant
kylesayrs Nov 15, 2024
1ae3ce0
integrate with QuantizationModifier
kylesayrs Nov 15, 2024
fc2488f
update hooks in tests
kylesayrs Nov 15, 2024
d0dc807
integrate with wanda
kylesayrs Nov 15, 2024
55f69d6
integrate with magnitude and constant
kylesayrs Nov 15, 2024
59ffe44
integrate with SparseGPTModifier
kylesayrs Nov 15, 2024
21fe61b
add hooksmixin to modifier
kylesayrs Nov 15, 2024
ba01137
Merge remote-tracking branch 'origin' into kylesayrs/HooksMixin
kylesayrs Nov 15, 2024
3771a89
Merge remote-tracking branch 'origin' into kylesayrs/HooksMixin
kylesayrs Nov 18, 2024
ccc5458
Merge branch 'kylesayrs/HooksMixin' into kylesayrs/gptq-partition
kylesayrs Nov 19, 2024
a5635a1
merge
kylesayrs Nov 19, 2024
83ed409
small updates
kylesayrs Nov 19, 2024
7fd142b
Merge branch 'main' into kylesayrs/HooksMixin
kylesayrs Nov 19, 2024
d104282
WIP
kylesayrs Nov 20, 2024
236a47a
WIP
kylesayrs Nov 21, 2024
188896e
able to run without hooks
kylesayrs Nov 21, 2024
8ef9c23
issue with different sizes
kylesayrs Nov 21, 2024
1362ca2
able to run through pixtral without issue and using real proxy tensor…
kylesayrs Nov 21, 2024
0539df7
nits
kylesayrs Nov 25, 2024
a734393
Merge remote-tracking branch 'origin' into kylesayrs/HooksMixin
kylesayrs Nov 25, 2024
ea10aed
Merge branch 'kylesayrs/HooksMixin' into kylesayrs/gptq-partition
kylesayrs Nov 25, 2024
ed96ee4
fix all variable
kylesayrs Nov 25, 2024
5f26711
tmp
kylesayrs Nov 25, 2024
ebc2c41
wip
kylesayrs Nov 26, 2024
922b407
wip
kylesayrs Nov 26, 2024
0577f36
testing with lots of models
kylesayrs Nov 26, 2024
3830696
preliminary data pipeline
kylesayrs Nov 26, 2024
1ecaa39
WIP
kylesayrs Nov 26, 2024
9aa9679
delete unnecessary files
kylesayrs Nov 26, 2024
7e6fe17
Merge remote-tracking branch 'origin' into kylesayrs/gptq-partition
kylesayrs Nov 26, 2024
034c0b1
Merge branch 'kylesayrs/gptq-hooks' into kylesayrs/gptq-partition
kylesayrs Nov 26, 2024
a62617c
clean up CustomDataset
kylesayrs Nov 28, 2024
57b5e02
chchchchanges
kylesayrs Nov 29, 2024
fa317fd
wip: use rename to processor, going through tests
kylesayrs Dec 2, 2024
f3f5875
remove labels from calibration dataset rather than assuming that all …
kylesayrs Dec 2, 2024
58c3afe
cleanup
kylesayrs Dec 2, 2024
72aecfc
cleanup, etc
kylesayrs Dec 2, 2024
77217fb
Merge remote-tracking branch 'origin' into kylesayrs/cleanup-custom-d…
kylesayrs Dec 2, 2024
4461a3e
fix typehinting
kylesayrs Dec 2, 2024
fb33001
add typechecking imports
kylesayrs Dec 2, 2024
bf4744a
remove sparseml utilities
kylesayrs Dec 3, 2024
62ae31d
Merge branch 'kylesayrs/remove-sparseml-utilities' into kylesayrs/cle…
kylesayrs Dec 3, 2024
7e516c1
use in model_load
kylesayrs Dec 3, 2024
d69106e
Merge branch 'main' into kylesayrs/calculate_offload_default_gpus
kylesayrs Dec 3, 2024
9e33641
remove use of RECIPE FILE NAME
kylesayrs Dec 3, 2024
58c0fba
rename to RECIPE_FILE_NAME, avoid circular import
kylesayrs Dec 3, 2024
b28aaae
Merge branch 'kylesayrs/remove-sparseml-utilities' into kylesayrs/cle…
kylesayrs Dec 3, 2024
8d13013
image dataset collation
kylesayrs Dec 3, 2024
17cf9f3
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 3, 2024
163ee8f
cleanup, do not handle case where processor is None
kylesayrs Dec 3, 2024
1180b34
remove qa ignore
kylesayrs Dec 3, 2024
ad20ae7
Merge branch 'kylesayrs/remove-sparseml-utilities' into kylesayrs/cle…
kylesayrs Dec 3, 2024
c431958
add documentation
kylesayrs Dec 3, 2024
b48d55d
add data collator arg
kylesayrs Dec 3, 2024
2d201e0
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 3, 2024
0ed5c2c
use default factor
kylesayrs Dec 3, 2024
ca61e90
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 3, 2024
41dd463
wip mllama
kylesayrs Dec 4, 2024
8527e0e
cleanup
kylesayrs Dec 4, 2024
0a8a03f
merge-implement hessian offloading
kylesayrs Dec 4, 2024
fc044e2
better concrete arg handling
kylesayrs Dec 4, 2024
4576712
validate flickr
kylesayrs Dec 4, 2024
5276c58
discover bug, tests and multimodal working
kylesayrs Dec 4, 2024
dffcbc3
dataset split fallbacks
kylesayrs Dec 4, 2024
b3cb229
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 4, 2024
779c9a2
Merge branch 'kylesayrs/dataset-split-fallbacks' into kylesayrs/clean…
kylesayrs Dec 4, 2024
85e3f59
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 4, 2024
e9f150d
move typing
kylesayrs Dec 4, 2024
d061567
cleanup, depreciate remove_columns argument
kylesayrs Dec 4, 2024
55a31ca
silently assign tokenizer to processor
kylesayrs Dec 5, 2024
c14e40e
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 5, 2024
1aba16d
replace tokenizer with processor
kylesayrs Dec 5, 2024
135e459
Merge branch 'kylesayrs/processor-replaces-tokenizer' into kylesayrs/…
kylesayrs Dec 5, 2024
dde2fa7
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 5, 2024
89bda30
defer data collator changes
kylesayrs Dec 5, 2024
0fa4102
reduce warnings
kylesayrs Dec 5, 2024
bc505bf
typehinting, add not-implemented error
kylesayrs Dec 5, 2024
c91ba77
remove todos
kylesayrs Dec 5, 2024
e916936
Delete mllama.py
kylesayrs Dec 5, 2024
0a573a1
update dataset manager api in tests
kylesayrs Dec 5, 2024
853c0a8
typehinting, add not-implemented error
kylesayrs Dec 5, 2024
234ef79
remove todos
kylesayrs Dec 5, 2024
8972dd5
update dataset manager api in tests
kylesayrs Dec 5, 2024
acb1a18
Delete examples/multimodal_vision/qwen_vl2.py
kylesayrs Dec 5, 2024
56b5d12
Delete examples/multimodal_vision/mllama.py
kylesayrs Dec 5, 2024
57c293e
WIP: add pixtral
kylesayrs Dec 5, 2024
537c5ab
pixtral working
kylesayrs Dec 5, 2024
15b3508
move to data pipeline
kylesayrs Dec 6, 2024
42b5fc0
disable_hf_hook context
kylesayrs Dec 6, 2024
bc33e8e
woof
kylesayrs Dec 6, 2024
ca72bbb
change desc
kylesayrs Dec 6, 2024
293640a
fix docstring
kylesayrs Dec 6, 2024
17b3a70
rely on compressed tensors, support offloading
kylesayrs Dec 6, 2024
5e185f2
sequential targets
kylesayrs Dec 6, 2024
4d82180
support match_layers_params
kylesayrs Dec 6, 2024
6a1b2c2
make _update_size private and inferred
kylesayrs Dec 6, 2024
f9ab6fc
make a module
kylesayrs Dec 6, 2024
0dc74dd
fallback
kylesayrs Dec 6, 2024
9e07188
implement basic pipeline
kylesayrs Dec 6, 2024
ed099ef
balance between gpus
kylesayrs Dec 6, 2024
4bbbc49
add proper ignore list
kylesayrs Dec 6, 2024
ae74f45
treat offloaded modules as leaves, treat ignore as sequential target
kylesayrs Dec 7, 2024
31eeb8c
redisable piecewise for vision datasets
kylesayrs Dec 7, 2024
1b24090
implement pipeline fallback
kylesayrs Dec 9, 2024
d97ef2b
Merge remote-tracking branch 'origin' into kylesayrs/processor-replac…
kylesayrs Dec 9, 2024
e87e019
remove subbatch event
kylesayrs Dec 9, 2024
d5c08fb
input device inference
kylesayrs Dec 9, 2024
39ed8ca
do not disable hf hook during tracing
kylesayrs Dec 9, 2024
47ca742
Merge remote-tracking branch 'origin' into kylesayrs/gptq-partition
kylesayrs Dec 9, 2024
c1f5cb2
Merge remote-tracking branch 'origin' into kylesayrs/cleanup-custom-d…
kylesayrs Dec 9, 2024
4711e9f
remove import
kylesayrs Dec 9, 2024
e468197
use find_nodes
kylesayrs Dec 9, 2024
f8591ca
rename piecewise to sequential
kylesayrs Dec 9, 2024
cea02d2
add docstring
kylesayrs Dec 9, 2024
f1f6c0f
begin sequential pipeline testing
kylesayrs Dec 9, 2024
3b0b49f
remove todos, add tests for sequential pipeline
kylesayrs Dec 10, 2024
2c035b3
move function placement
kylesayrs Dec 10, 2024
b93868d
slight partition algorithm change
kylesayrs Dec 10, 2024
146e4be
revert llama3 example
kylesayrs Dec 10, 2024
0e4d8f3
Merge branch 'main' into kylesayrs/dataset-split-fallbacks
kylesayrs Dec 10, 2024
b8e867d
Merge branch 'main' into kylesayrs/processor-replaces-tokenizer
kylesayrs Dec 10, 2024
ccb007f
remove test, fix default in order to fix tests
kylesayrs Dec 10, 2024
e1055b0
bump memory requirements
kylesayrs Dec 11, 2024
70421ed
fix memory and offloading issues
kylesayrs Dec 12, 2024
b102bf5
add missing cache file
kylesayrs Dec 12, 2024
229d3ae
make mllama tracable
kylesayrs Dec 12, 2024
4e0b118
write using comprehesion
kylesayrs Dec 12, 2024
7dc4d2a
fix hessian requirements
kylesayrs Dec 12, 2024
377b2a4
implement offloading for tuple
kylesayrs Dec 12, 2024
adb1627
add save
kylesayrs Dec 12, 2024
ab3fc81
change num samples
kylesayrs Dec 12, 2024
1bf683e
implement intermediates offloading for dataclasses
kylesayrs Dec 12, 2024
8918917
Merge branch 'main' into kylesayrs/processor-replaces-tokenizer
kylesayrs Dec 12, 2024
b75fe15
wrap ignore but do not treat as sequential target
kylesayrs Dec 13, 2024
aa4a23d
tracable pixtral/mistral
kylesayrs Dec 13, 2024
aa532b5
remove double saving
kylesayrs Dec 13, 2024
19e4f97
revert dampening frac
kylesayrs Dec 13, 2024
f95b77f
do not cache model outputs to save memory
kylesayrs Dec 13, 2024
2d890db
fix dataclass case, add tests
kylesayrs Dec 13, 2024
7e69b9d
Merge remote-tracking branch 'origin' into kylesayrs/gptq-partition
kylesayrs Dec 13, 2024
4a22032
Remove docstring
kylesayrs Dec 13, 2024
8d72269
Merge branch 'main' into kylesayrs/processor-replaces-tokenizer
kylesayrs Dec 14, 2024
a71352a
move IntermediatesCache location
kylesayrs Dec 14, 2024
2d249a2
add fake_sequential
kylesayrs Dec 14, 2024
995cb2d
rename fake_sequential to layer_sequential
kylesayrs Dec 14, 2024
e4bca34
pipeline inference
kylesayrs Dec 14, 2024
4a046a5
update docstrings
kylesayrs Dec 14, 2024
f24a2af
fix last layer bug
kylesayrs Dec 14, 2024
691bac4
better inference
kylesayrs Dec 14, 2024
1e15d3e
even better inference
kylesayrs Dec 14, 2024
a4744d9
do now throw warning for calibration with training
kylesayrs Dec 16, 2024
9617e53
add information about how to silence warning
kylesayrs Dec 16, 2024
3b4cac1
nice
kylesayrs Dec 16, 2024
f53a3dd
remove unnecessary warning silencing
kylesayrs Dec 16, 2024
f45d0fa
Merge branch 'kylesayrs/processor-replaces-tokenizer', remote-trackin…
kylesayrs Dec 16, 2024
70a2811
Merge branch 'kylesayrs/dataset-split-fallbacks' into kylesayrs/gptq-…
kylesayrs Dec 16, 2024
fd151e4
add unmerged thing
kylesayrs Dec 16, 2024
d1d42de
fix deleted columns
kylesayrs Dec 16, 2024
92151a1
handle dataset dict case
kylesayrs Dec 17, 2024
4c049db
support torch.nn.Conv2d, silently ignore embeddings
kylesayrs Dec 17, 2024
7667998
handle columns better
kylesayrs Dec 17, 2024
f0eb640
fix tokenizer args
kylesayrs Dec 18, 2024
af86f45
filter_tokenizer_args
kylesayrs Dec 18, 2024
5567a90
Merge remote-tracking branch 'origin' into kylesayrs/gptq-partition
kylesayrs Dec 18, 2024
0438e17
Merge remote-tracking branch 'origin' into kylesayrs/cleanup-custom-d…
kylesayrs Dec 18, 2024
9b61145
update docstring
kylesayrs Dec 18, 2024
2f65d01
remove unused util
kylesayrs Dec 18, 2024
338d1cb
remove debug
kylesayrs Dec 18, 2024
f4fa9c3
more tests
kylesayrs Dec 18, 2024
6bd1721
Merge remote-tracking branch 'origin' into kylesayrs/cleanup-custom-d…
kylesayrs Dec 18, 2024
e757e61
remove duplicate file
kylesayrs Dec 18, 2024
bdfa3d4
better help texts
kylesayrs Dec 18, 2024
cd9dd21
Merge branch 'kylesayrs/cleanup-custom-dataset' into kylesayrs/gptq-p…
kylesayrs Dec 18, 2024
f674579
Merge branch 'kylesayrs/calculate_offload_default_gpus' into kylesayr…
kylesayrs Dec 18, 2024
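
The back half of the commit history (from "implement partitioned model" and "preliminary data pipeline" through "rename piecewise to sequential" and "implement pipeline fallback") builds the sequential data pipeline: the model is split at each sequential target (for example, MistralDecoderLayer), all calibration samples are run through one partition at a time, and intermediate activations are offloaded between partitions so only one partition stays resident on GPU, with a layer-sequential fallback for models that cannot be traced. The sketch below illustrates only the layer-sequential idea, under assumed names (layer_sequential_pipeline, compress); it is not this PR's implementation:

from typing import Callable, Dict, List

import torch


def layer_sequential_pipeline(
    layers: List[torch.nn.Module],
    samples: List[Dict[str, torch.Tensor]],
    compress: Callable[[torch.nn.Module], None],
    device: str = "cuda",
) -> None:
    # `samples` holds the kwargs entering the first layer (including a
    # "hidden_states" entry), one dict per calibration sample, stored on CPU.
    cache = samples
    for layer in layers:
        layer.to(device)
        outputs = []
        for sample in cache:
            kwargs = {k: v.to(device) for k, v in sample.items()}
            with torch.no_grad():
                hidden = layer(**kwargs)[0]  # decoder layers return a tuple
            # offload each output so intermediates never accumulate on GPU
            outputs.append({**sample, "hidden_states": hidden.to("cpu")})
        compress(layer)  # e.g. quantize this layer's Linears with GPTQ
        layer.to("cpu")
        cache = outputs

In the traced sequential pipeline, the same loop runs over torch.fx graph partitions instead of raw decoder layers, which is what lets GPTQ handle multimodal models whose forward passes do not reduce to a flat layer list.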
72 changes: 72 additions & 0 deletions examples/multimodal_vision/mllama.py
@@ -0,0 +1,72 @@
import os

import torch
from transformers import AutoProcessor

from llmcompressor.modifiers.quantization import GPTQModifier

# from llmcompressor.pytorch.data_collator import DataCollator
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import TracableMllamaForConditionalGeneration

# Load model.
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = TracableMllamaForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
DATASET_ID = "flickr30k"
DATASET_SPLIT = "test[:512]"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048


# TODO: define real collators in utils
def data_collator(batch):
    # calibration runs with batch size 1, so each batch is a single sample
    assert len(batch) == 1
    return {
        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
        "pixel_values": torch.tensor(batch[0]["pixel_values"]),
        "aspect_ratio_ids": torch.tensor(batch[0]["aspect_ratio_ids"]),
        "aspect_ratio_mask": torch.tensor(batch[0]["aspect_ratio_mask"]),
        "cross_attention_mask": torch.tensor(batch[0]["cross_attention_mask"]),
    }


# Recipe
recipe = [
    # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore),
    GPTQModifier(
        targets="Linear",
        scheme="W8A8",
        ignore=["re:.*lm_head", "re:multi_modal_projector.*", "re:vision_model.*"],
    ),
]

# Perform oneshot
save_name = model_id.split("/")[1] + "-W8A8"
save_path = os.path.join("./my_test/", save_name)
print("Starting quantization")
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    output_dir=save_path,
    data_collator=data_collator,
    # data_collator=DataCollator(),
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
70 changes: 70 additions & 0 deletions examples/multimodal_vision/pixtral.py
@@ -0,0 +1,70 @@
import os

import torch
from transformers import AutoProcessor

from llmcompressor.modifiers.quantization import GPTQModifier

# from llmcompressor.pytorch.data_collator import DataCollator
from llmcompressor.transformers import oneshot
from llmcompressor.transformers.tracing import TracableLlavaForConditionalGeneration

# Load model.
model_id = "mgoin/pixtral-12b"
model = TracableLlavaForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
DATASET_ID = "flickr30k"
DATASET_SPLIT = "test[:512]"
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048


# TODO: define real collators in utils
def data_collator(batch):
    # calibration runs with batch size 1, so each batch is a single sample
    assert len(batch) == 1
    return {
        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
        # index [0] unwraps the sample's single image
        "pixel_values": torch.tensor(batch[0]["pixel_values"])[0],
    }


# Recipe
recipe = [
    # SmoothQuantModifier(smoothing_strength=0.8, ignore=ignore),
    GPTQModifier(
        targets="Linear",
        scheme="W8A8",
        ignore=["re:.*lm_head", "re:vision_tower.*", "re:multi_modal_projector.*"],
        sequential_targets=["MistralDecoderLayer"],
    ),
]

# Perform oneshot
save_name = model_id.split("/")[1] + "-W8A8"
save_path = os.path.join("./my_test/", save_name)
print("Starting quantization")
oneshot(
    model=model,
    tokenizer=model_id,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    output_dir=save_path,
    data_collator=data_collator,
    # data_collator=DataCollator(),
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
6 changes: 6 additions & 0 deletions examples/multimodal_vision/pixtral_large.py
@@ -0,0 +1,6 @@
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "mistral-community/Pixtral-Large-Instruct-2411"
)
processor = AutoProcessor.from_pretrained("mgoin/pixtral-12b")
88 changes: 88 additions & 0 deletions examples/multimodal_vision/qwen_vl2.py
@@ -0,0 +1,88 @@
import os

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
    QuantizationType,
)
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor.transformers import oneshot

# Load model.
model_id = "Qwen/Qwen2-VL-2B-Instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype="auto"
)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Oneshot arguments
DATASET_ID = "flickr30k"
DATASET_SPLIT = "test[:3]"
NUM_CALIBRATION_SAMPLES = 1
MAX_SEQUENCE_LENGTH = 2048


# TODO: define real collators in utils
def data_collator(batch):
    # calibration runs with batch size 1, so each batch is a single sample
    assert len(batch) == 1
    return {
        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
        "pixel_values": torch.tensor(
            batch[0]["pixel_values"]
        ),  # torch.Size([14308, 1176])
        "image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]),
    }


# Recipe
recipe = GPTQModifier(
    targets="Linear",
    config_groups={
        "config_group": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(
                num_bits=4,
                type=QuantizationType.INT,
                strategy=QuantizationStrategy.GROUP,
                group_size=128,
                symmetric=True,
                dynamic=False,
                actorder="dynamic",
            ),
        ),
    },
    ignore=["re:.*lm_head"],
    dampening_frac=0.5,
)

# Perform oneshot
save_name = model_id.split("/")[1] + "-W8A8"
save_path = os.path.join("./my_test/", save_name)
print("Starting quantization")
oneshot(
    model=model,
    tokenizer=model_id,
    # dataset=ds,
    dataset=DATASET_ID,
    splits=DATASET_SPLIT,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    trust_remote_code_model=True,
    output_dir=save_path,
    data_collator=data_collator,
)

processor.save_pretrained(save_path)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")
input_ids = processor(text="Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(processor.decode(output[0]))
print("==========================================")
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@ profile = "black"
files = "src/guidellm"

[tool.ruff]
exclude = ["build", "dist", "env", ".venv"]
exclude = ["build", "dist", "env", ".venv", "src/llmcompressor/pytorch/tracing"]
lint.select = ["E", "F", "W"]

[tool.flake8]