AWQ Apply Scales Bugfix when smooth layer output length doesn't match balance layer input length (#1451)
### Summary
We are hitting an edge case in AWQ that the initial Llama/Qwen testing models did not exercise. When a smooth layer's number of output features does not match a balance layer's number of input features, the current code errors out with a shape mismatch when updating the smooth layer's weights via `weights.div(scales)`. We are hitting this in #1440 for Phi3 models, which include a mapping between the fused `qkv_proj` smooth layer and the `o_proj` balance layer in AutoAWQ (see
[here](https://github.com/casper-hansen/AutoAWQ/blob/main/awq/models/phi3.py#L51-L57)).
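To illustrate the failure mode with phi-3-mini's shapes (the tensors below are stand-ins for illustration, not the actual modifier code):

```python
import torch

# Fused qkv_proj weight for phi-3-mini: (out_features, in_features) = (9216, 3072)
qkv_weight = torch.randn(9216, 3072)

# Smoothing scales sized to o_proj's 3072 input features
scales = torch.rand(3072) + 0.5

# Raises RuntimeError: the size of tensor a (9216) must match the size
# of tensor b (3072) at non-singleton dimension 0
qkv_weight.div(scales.view(-1, 1))
```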
The resolution in AutoAWQ is to use only the last rows of the smooth layer's weight so that the shapes line up, as shown
[here](https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/scale.py#L123).
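A minimal sketch of that approach, loosely mirroring AutoAWQ's `scale_fc_fc` (the signature and names here are illustrative): only the last `scales.size(0)` rows of the fused smooth layer, i.e. the rows that actually feed the balance layer, get divided:

```python
import torch
from torch import nn


@torch.no_grad()
def scale_fc_fc(smooth: nn.Linear, balance: nn.Linear, scales: torch.Tensor):
    # Scales are sized to the balance layer's input features, which may be
    # fewer than the fused smooth layer's output features (e.g. qkv_proj).
    scales = scales.to(smooth.weight.device)

    # Divide only the trailing rows of the smooth layer so shapes line up.
    smooth.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if smooth.bias is not None:
        smooth.bias[-scales.size(0):].div_(scales.view(-1))

    # Multiply the balance layer's input dimension by the same scales.
    balance.weight.mul_(scales.view(1, -1))
```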
This PR includes that update, and together with #1440 it allows Phi3 models to be quantized with AWQModifier. As with v_proj -> o_proj, mappings whose shapes don't line up are excluded from the resolved mappings. This allows
[phi-3-mini](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/tree/main?show_file_info=model-00001-of-00002.safetensors)
to include the mapping, because `qkv_proj out_features == 3 * o_proj in_features == 9216`, but excludes it for
[phi-3-medium](https://huggingface.co/microsoft/Phi-3-medium-128k-instruct/tree/main?show_file_info=model-00001-of-00006.safetensors),
which has `qkv_proj out_features == 7680` and `o_proj in_features == 5120`. If the mapping is included for phi-3-medium, the model blows up, with wikitext eval perplexities >2000. This implementation was agreed upon with @anmarques.
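The exclusion criterion can be sketched as a simple divisibility check (names hypothetical; this illustrates the rule described above, not the exact code in this PR):

```python
from torch import nn


def shapes_compatible(smooth: nn.Linear, balance: nn.Linear) -> bool:
    """Keep a smooth -> balance mapping only when the shapes can line up.

    phi-3-mini:   9216 == 3 * 3072       -> mapping kept
    phi-3-medium: 7680 % 5120 == 2560    -> mapping excluded
    """
    return smooth.out_features % balance.in_features == 0


# Hypothetical usage while resolving mappings:
# resolved = [(s, b) for (s, b) in candidates if shapes_compatible(s, b)]
```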
PS: I also switched `mul` & `div` to their in-place counterparts `mul_` & `div_`, to avoid unnecessary memory allocation.
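For reference, the in-place variants mutate the tensor's existing storage instead of allocating a new result tensor:

```python
import torch

weights = torch.randn(4096, 4096)
scales = torch.rand(4096) + 0.5

# Out-of-place: allocates a new 4096x4096 tensor for the result
updated = weights.div(scales.view(-1, 1))

# In-place: overwrites `weights` directly, no extra allocation
weights.div_(scales.view(-1, 1))
```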
-------------
### Test Plan
With these changes and with #1440, `examples/awq/llama_example.py` works with `"microsoft/Phi-3-mini-128k-instruct"` and produces similar results with and without the `qkv_proj` -> `o_proj` mapping.
Without mapping:
| Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr|
|--------|------:|------|-----:|---------------|---|------:|---|------|
|wikitext| 2|none | 5|bits_per_byte |↓ | 0.6474|± | N/A|
| | |none | 5|byte_perplexity|↓ | 1.5664|± | N/A|
| | |none | 5|word_perplexity|↓ |11.0201|± | N/A|
With mapping:
| Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr|
|--------|------:|------|-----:|---------------|---|------:|---|------|
|wikitext| 2|none | 5|bits_per_byte |↓ | 0.6482|± | N/A|
| | |none | 5|byte_perplexity|↓ | 1.5672|± | N/A|
| | |none | 5|word_perplexity|↓ |11.0527|± | N/A|
I also confirmed that re-running with `meta-llama/Llama-3.2-3B-Instruct` and
`meta-llama/Llama-2-7b-hf` produces PPL scores that do not deviate from what is
currently on `main`.
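The tables above are in lm-evaluation-harness format; assuming that is how the numbers were produced, something along these lines should reproduce them (the model path is a placeholder):

```python
import lm_eval

# Evaluate a quantized checkpoint on wikitext, 5-shot, matching the tables above
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=/path/to/quantized-model",
    tasks=["wikitext"],
    num_fewshot=5,
)
print(results["results"]["wikitext"])
```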
---------
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>