Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add uintx quant to generate and eval #811

Merged
merged 1 commit into from
Sep 5, 2024

Conversation

jerryzh168
Copy link
Contributor

@jerryzh168 jerryzh168 commented Sep 5, 2024

Summary:
att

Also rerun the benchmarks/eval for llama2/llama3 to get most recent perf/acc data

Test Plan:
torchao/_models/llama/generate.py
torchao/_models/llama/eval.py

llama2:

# torch.uint4, group_size = 64
python generate.py --compile --precision bfloat16 --quantization uintx-4-64
Average tokens/sec: 48.25
Average Bandwidth: 189.32 GB/s
Peak Memory Usage: 6.29 GB
Model Size: 3.92 GB

wikitext: {'word_perplexity,none': 12.890544846479484, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.612969956510788, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.6897195668279897, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}

# torch.uint2, group_size = 8
python generate.py --compile --precision bfloat16 --quantization uintx-2-8
Average tokens/sec: 36.11
Average Bandwidth: 238.58 GB/s
Peak Memory Usage: 9.26 GB
Model Size: 6.61 GB

python eval.py --compile --precision bfloat16 --quantization uintx-2-8
wikitext: {'word_perplexity,none': 28.766343716897, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.8742120465648264, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.9062841873734042, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}

llama3:

# torch.uint4, group_size = 64
python generate.py --compile --precision bfloat16 --checkpoint_path=../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --quantization uintx-4-64
Average tokens/sec: 47.77
Average Bandwidth: 212.90 GB/s
Peak Memory Usage: 11.85 GB
Model Size: 4.46 GB

wikitext: {'word_perplexity,none': 8.112931736704462, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.479179221121259, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.5647968636325521, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}


# torch.uint2, group_size = 8
python generate.py --compile --precision bfloat16 --checkpoint_path=../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --quantization uintx-2-8
Average tokens/sec: 33.21
Average Bandwidth: 249.22 GB/s
Peak Memory Usage: 15.04 GB
Model Size: 7.51 GB

wikitext: {'word_perplexity,none': 39.36764348732592, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.98746296691363, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.9909279784106695, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}






Reviewers:

Subscribers:

Tasks:

Tags:

Copy link

pytorch-bot bot commented Sep 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/811

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d5ebc0e with merge base 317392d (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 5, 2024
@HDCharles
Copy link
Contributor

i would put the generate/eval results in a table somewhere, if you want to add them to the standard benchmarks you can add them to benchmarks.sh

also i would rebase on mine or you will have merge issues

@HDCharles
Copy link
Contributor

if eval is broken for you, can you send me the error?

@jerryzh168
Copy link
Contributor Author

if eval is broken for you, can you send me the error?

seems to be fine, it seems that int8wo and bfloat16 are just very close, I thought they were exactly the same before, but there is actually a slight difference

Summary:
att

Also rerun the benchmarks/eval for llama2/llama3 to get most recent perf/acc data

Test Plan:
torchao/_models/llama/generate.py
torchao/_models/llama/eval.py

Reviewers:

Subscribers:

Tasks:

Tags:
@jerryzh168
Copy link
Contributor Author

right now these are slow, we can add to benchmarks.sh later when the perf is better I think

@jerryzh168 jerryzh168 merged commit e05635e into pytorch:main Sep 5, 2024
17 checks passed
@jerryzh168 jerryzh168 deleted the benchmarks branch September 5, 2024 16:46
HDCharles pushed a commit that referenced this pull request Sep 9, 2024
Summary:
att

Also rerun the benchmarks/eval for llama2/llama3 to get most recent perf/acc data

Test Plan:
torchao/_models/llama/generate.py
torchao/_models/llama/eval.py

Reviewers:

Subscribers:

Tasks:

Tags:
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants