Adding Llama to TorchAO (#276)
* Adding Llama to TorchAO for a stable source of testing, benchmarking, and demos

Summary: Added models/llama with code for the model and tokenizer.

This replicates functionality that previously lived in the tests.

Test Plan:

python test/quantization/test_quant_api.py
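
For reference, a minimal sketch of how the relocated pieces are pulled in by the updated tests below (paths and call signatures are taken from this diff; the package may expose more than shown here):

```python
# Sketch only: mirrors the usage in test/quantization/test_quant_api.py below.
from pathlib import Path

import torch

from torchao._models.llama.model import Transformer, prepare_inputs_for_model
from torchao._models.llama.tokenizer import get_tokenizer

checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
state_dict = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(state_dict, assign=True)
model = model.to(dtype=torch.bfloat16, device="cpu").eval()

# get_tokenizer takes the tokenizer file plus the model name so it can pick the
# right backend (SentencePiece for Llama 2, tiktoken for Llama 3).
tokenizer = get_tokenizer(checkpoint_path.parent / "tokenizer.model", "Llama-2-7b-chat-hf")

# prepare_inputs_for_model is handed to the eval wrapper (see _models/_eval.py)
# to turn recorded token batches into model inputs.
```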

Reviewers:

Subscribers:

Tasks:

Tags:

* Adding benchmarking, generate, and Llama 3

Summary: Adding stable benchmarks and a generate script for Llama 3 (the full
command set is in benchmarks.sh below). Llama 3 eval still has some issues that
need investigation.

Test Plan: sh benchmarks.sh

Reviewers:

Subscribers:

Tasks:

Tags:

* Adding a workaround so tiktoken doesn't error

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
HDCharles authored Jun 11, 2024
1 parent 71cf0a5 commit 61fef69
Showing 10 changed files with 610 additions and 49 deletions.
78 changes: 59 additions & 19 deletions test/quantization/test_quant_api.py
@@ -18,6 +18,7 @@
get_symmetric_quantization_config,
)

import torchao
from torchao.dtypes import (
to_aq,
AffineQuantizedTensor,
@@ -49,8 +50,8 @@
TORCH_VERSION_AFTER_2_4,
)
from pathlib import Path
from sentencepiece import SentencePieceProcessor
from model import Transformer, prepare_inputs_for_model
from torchao._models.llama.tokenizer import get_tokenizer
from torchao._models.llama.model import Transformer, prepare_inputs_for_model
import copy


@@ -241,7 +242,8 @@ def test_8da4w_quantizer(self):
@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_8da4w_gptq_quantizer(self):
from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
from torchao._eval import InputRecorder, TransformerEvalWrapper
from torchao._models._eval import InputRecorder, TransformerEvalWrapper
torchao._models.llama.model.use_index_put_for_kv_cache = True
# should be similar to TorchCompileDynamicQuantizer
precision = torch.bfloat16
device = "cpu"
@@ -253,8 +255,9 @@
model.eval()
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
model_file=str(tokenizer_path)
tokenizer = get_tokenizer( # pyre-ignore[28]
tokenizer_path,
"Llama-2-7b-chat-hf",
)
blocksize = 128
percdamp = 0.01
@@ -303,7 +306,7 @@ def test_8da4w_gptq_quantizer(self):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_8da4w_quantizer_eval(self):
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
from torchao._eval import TransformerEvalWrapper
from torchao._models._eval import TransformerEvalWrapper

precision = torch.bfloat16
device = "cpu"
@@ -315,8 +318,9 @@
model.eval()
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
model_file=str(tokenizer_path)
tokenizer = get_tokenizer( # pyre-ignore[28]
tokenizer_path,
"Llama-2-7b-chat-hf",
)

quantizer = Int8DynActInt4WeightQuantizer(groupsize=128, precision=precision)
@@ -338,7 +342,8 @@ def test_8da4w_quantizer_eval(self):
@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq_quantizer_int4wo(self):
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao._eval import InputRecorder, TransformerEvalWrapper
from torchao._models._eval import InputRecorder, TransformerEvalWrapper
torchao._models.llama.model.use_index_put_for_kv_cache = True
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
@@ -349,12 +354,13 @@ def test_gptq_quantizer_int4wo(self):
model.eval()
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
model_file=str(tokenizer_path)
tokenizer = get_tokenizer( # pyre-ignore[28]
tokenizer_path,
"Llama-2-7b-chat-hf",
)
blocksize = 128
percdamp = 0.01
groupsize = 128
groupsize = 64
calibration_tasks = ["wikitext"]
calibration_limit = 1
calibration_seq_length = 100
@@ -398,7 +404,7 @@ def test_gptq_quantizer_int4wo(self):
@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_quantizer_int4wo(self):
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
from torchao._eval import TransformerEvalWrapper
from torchao._models._eval import TransformerEvalWrapper
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
@@ -409,10 +415,11 @@ def test_quantizer_int4wo(self):
model.eval()
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
model_file=str(tokenizer_path)
tokenizer = get_tokenizer( # pyre-ignore[28]
tokenizer_path,
"Llama-2-7b-chat-hf",
)
groupsize = 128
groupsize = 64
quantizer = Int4WeightOnlyQuantizer(
groupsize,
)
@@ -433,7 +440,7 @@ def test_quantizer_int4wo(self):

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_eval_wrapper(self):
from torchao._eval import TransformerEvalWrapper
from torchao._models._eval import TransformerEvalWrapper
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
@@ -444,8 +451,9 @@ def test_eval_wrapper(self):
model.eval()
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
model_file=str(tokenizer_path)
tokenizer = get_tokenizer( # pyre-ignore[28]
tokenizer_path,
"Llama-2-7b-chat-hf",
)
result=TransformerEvalWrapper(
model,
@@ -461,6 +469,38 @@ def test_eval_wrapper(self):
f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
)

# EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY
@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_eval_wrapper_llama3(self):
from torchao._models._eval import TransformerEvalWrapper
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path(".../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device=device)
model.eval()
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = get_tokenizer( # pyre-ignore[28]
tokenizer_path,
"Meta-Llama-3-8B",
)
result = TransformerEvalWrapper(
model,
tokenizer,
model.config.block_size,
prepare_inputs_for_model,
device,
).run_eval(
["wikitext"],
1,
)
assert result['results']['wikitext']['word_perplexity,none'] < 8.24, (
f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
)

# TODO: move to a separate test file
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_8da4w(self):
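
The tests above stop constructing SentencePieceProcessor directly and instead call get_tokenizer(tokenizer_path, model_name). The real helper lives in torchao/_models/llama/tokenizer.py, which is not shown in this diff; the following is only a rough sketch of the kind of name-based dispatch it implies, assuming Llama 3 checkpoints need a tiktoken-backed tokenizer while Llama 2 keeps SentencePiece:

```python
# Hypothetical sketch of a name-based tokenizer dispatcher; the actual
# implementation in torchao/_models/llama/tokenizer.py may differ.
from pathlib import Path


def get_tokenizer(tokenizer_path: Path, model_name: str):
    if "Llama-3" in model_name:
        # Llama 3 ships a tiktoken-style tokenizer; the real helper wraps it so
        # encode/decode look like the SentencePiece interface. The wrapping
        # details are not part of this diff, so they are omitted here.
        raise NotImplementedError("tiktoken-backed tokenizer not sketched here")
    # Llama 2 checkpoints use a SentencePiece tokenizer.model file.
    from sentencepiece import SentencePieceProcessor
    return SentencePieceProcessor(model_file=str(tokenizer_path))
```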
Empty file added torchao/_models/__init__.py
Empty file.
6 changes: 4 additions & 2 deletions torchao/_eval.py → torchao/_models/_eval.py
@@ -9,10 +9,12 @@
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from quantization.utils import _lm_eval_available, _MultiInput
from torchao.quantization.utils import _lm_eval_available, _MultiInput

if _lm_eval_available:
import lm_eval
try: # lm_eval version 0.4
from lm_eval.evaluator import evaluate # pyre-ignore[21]
from lm_eval.models.huggingface import HFLM as eval_wrapper # pyre-ignore[21]
@@ -200,7 +202,7 @@ def _model_call(self, inps):
# TODO: make batches work
input = self.input_prep_func(inps)

max_seq_length = min(inps.size(1), self.max_length)
max_seq_length = min(max(inps.size()), self.max_length)
with torch.device(self._device):
self._model.setup_caches(self.batch_size, max_seq_length)
logits = self._model(*input)
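
The max_seq_length line above changes min(inps.size(1), self.max_length) to min(max(inps.size()), self.max_length). A minimal illustration of the difference, assuming the intent is to also cope with unbatched 1-D input tensors:

```python
import torch

inps_2d = torch.zeros(1, 100, dtype=torch.int)  # (batch, seq_len)
inps_1d = torch.zeros(100, dtype=torch.int)     # unbatched token ids

print(max(inps_2d.size()))  # 100
print(max(inps_1d.size()))  # 100
# inps_1d.size(1) would raise IndexError, so max(inps.size()) covers both shapes.
```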
Empty file.
8 changes: 8 additions & 0 deletions torchao/_models/llama/benchmark_results.txt
@@ -0,0 +1,8 @@
20240610164534, tok/s= 94.91, mem/s=1424.58 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240610164738, tok/s=179.41, mem/s= 757.45 GB/s, peak_mem=23.44 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240610164952, tok/s=136.75, mem/s=1028.38 GB/s, peak_mem=19.16 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240610165423, tok/s= 8.41, mem/s= 63.23 GB/s, peak_mem=19.16 GB, model_size= 7.52 GB quant: int8dq, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240610165618, tok/s=105.02, mem/s=1387.78 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240610165808, tok/s=199.81, mem/s= 746.45 GB/s, peak_mem=15.92 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240610170005, tok/s=147.03, mem/s= 973.54 GB/s, peak_mem=14.50 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240610170408, tok/s= 9.40, mem/s= 62.26 GB/s, peak_mem=14.50 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
21 changes: 21 additions & 0 deletions torchao/_models/llama/benchmarks.sh
@@ -0,0 +1,21 @@
export CHECKPOINT_PATH=../../../../gpt-fast/checkpoints # path to checkpoints folder

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt

export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt
