Skip to content

Commit

Permalink
add llama2 to tensorrt
Browse files Browse the repository at this point in the history
  • Loading branch information
dpower4 committed Nov 21, 2024
1 parent 50452ef commit c869c95
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
31 changes: 30 additions & 1 deletion superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import torch.hub
import torch.onnx
import torchvision.models
from transformers import BertConfig, GPT2Config
from transformers import BertConfig, GPT2Config, LlamaConfig

from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel


class torch2onnxExporter():
Expand Down Expand Up @@ -87,6 +88,34 @@ def __init__(self):
),
self.num_classes,
),
'llama2-7b':
lambda: LlamaBenchmarkModel(
LlamaConfig(
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
),
self.num_classes,
),
'llama2-13b':
lambda: LlamaBenchmarkModel(
LlamaConfig(
hidden_size=5120,
num_hidden_layers=40,
num_attention_heads=40,
),
self.num_classes,
),
'llama2-70b':
lambda: LlamaBenchmarkModel(
LlamaConfig(
hidden_size=8192,
num_hidden_layers=80,
num_attention_heads=64,
num_key_value_heads=8,
),
self.num_classes,
),
}
self._onnx_model_path = Path(torch.hub.get_dir()) / 'onnx'
self._onnx_model_path.mkdir(parents=True, exist_ok=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ def test_tensorrt_inference_params(self):
'seq_length': 128,
'iterations': 256,
},
# Disabled because this case requires ~30 GB of GPU memory.
# {
# 'pytorch_models': ['llama2-7b'],
# 'precision': 'fp16',
# 'batch_size': 1,
# 'seq_length': 32,
# 'iterations': 32,
# },
]
for test_case in test_cases:
with self.subTest(msg='Testing with case', test_case=test_case):
Expand Down

0 comments on commit c869c95

Please sign in to comment.