-
Notifications
You must be signed in to change notification settings - Fork 315
AWQModifier fast resolve mappings, better logging, MoE support #1444
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
4ec91c1
squashed/rebased
brian-dellabetta 8dc118f
test fixes
brian-dellabetta c001953
fast resolution, still failing on Qwen 3 MoE
brian-dellabetta 1eeac0c
stylefixes
brian-dellabetta a294646
skip if no activations
brian-dellabetta 665e554
working with Qwen MoE
brian-dellabetta 4976ac3
update get_lowest_common_parent
brian-dellabetta 9ab1ca1
mappings reorg
brian-dellabetta 3f364fe
include awq Qwen MoE example
brian-dellabetta 2eac1c3
cleanup
brian-dellabetta e31074f
unit tests
brian-dellabetta 75a1602
test updates
brian-dellabetta 96d5f59
codreview updates
brian-dellabetta 4c91f4f
Merge branch 'main' into bdellabe/awq-fast-resolve-mappings
kylesayrs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| from datasets import load_dataset | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| from llmcompressor import oneshot | ||
| from llmcompressor.modifiers.awq import AWQModifier | ||
|
|
||
| # Select model and load it. | ||
| MODEL_ID = "Qwen/Qwen3-30B-A3B" | ||
|
|
||
| model = AutoModelForCausalLM.from_pretrained( | ||
| MODEL_ID, device_map="auto", torch_dtype="auto" | ||
| ) | ||
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) | ||
|
|
||
| # Select calibration dataset. | ||
| DATASET_ID = "mit-han-lab/pile-val-backup" | ||
| DATASET_SPLIT = "validation" | ||
|
|
||
| # Select number of samples. 256 samples is a good place to start. | ||
| # Increasing the number of samples can improve accuracy. | ||
| NUM_CALIBRATION_SAMPLES = 256 | ||
| MAX_SEQUENCE_LENGTH = 512 | ||
|
|
||
| # Load dataset and preprocess. | ||
| ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") | ||
| ds = ds.shuffle(seed=42) | ||
|
|
||
|
|
||
| def preprocess(example): | ||
| return { | ||
| "text": tokenizer.apply_chat_template( | ||
| [{"role": "user", "content": example["text"]}], | ||
| tokenize=False, | ||
| ) | ||
| } | ||
|
|
||
|
|
||
| ds = ds.map(preprocess) | ||
|
|
||
|
|
||
| # Tokenize inputs. | ||
| def tokenize(sample): | ||
| return tokenizer( | ||
| sample["text"], | ||
| padding=False, | ||
| max_length=MAX_SEQUENCE_LENGTH, | ||
| truncation=True, | ||
| add_special_tokens=False, | ||
| ) | ||
|
|
||
|
|
||
| # Configure the quantization algorithm to run. | ||
| # NOTE: vllm currently does not support asym MoE, using symmetric here | ||
| recipe = [ | ||
| AWQModifier( | ||
| ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"], | ||
| scheme="W4A16", | ||
| targets=["Linear"], | ||
| ), | ||
| ] | ||
|
|
||
| # Apply algorithms. | ||
| oneshot( | ||
| model=model, | ||
| dataset=ds, | ||
| recipe=recipe, | ||
| max_seq_length=MAX_SEQUENCE_LENGTH, | ||
| num_calibration_samples=NUM_CALIBRATION_SAMPLES, | ||
| ) | ||
|
|
||
| # Confirm generations of the quantized model look sane. | ||
| print("\n\n") | ||
| print("========== SAMPLE GENERATION ==============") | ||
| input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") | ||
| output = model.generate(input_ids, max_new_tokens=100) | ||
| print(tokenizer.decode(output[0])) | ||
| print("==========================================\n\n") | ||
|
|
||
| # Save to disk compressed. | ||
| SAVE_DIR = MODEL_ID.split("/")[-1] + "-awq-sym" | ||
| model.save_pretrained(SAVE_DIR, save_compressed=True) | ||
| tokenizer.save_pretrained(SAVE_DIR) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.