Skip to content

Benchmark intel xpu #1259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
def device_sync(device):
if "cuda" in device:
torch.cuda.synchronize(device)
elif "xpu" in device:
torch.xpu.synchronize(device)
elif ("cpu" in device) or ("mps" in device):
pass
else:
print(f"device={device} is not yet suppported")

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
default_device = 'cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu'

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
Expand Down Expand Up @@ -415,10 +417,13 @@ def main(
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

if memory_profile:
if device != "cuda":
print("Memory profiling only works on CUDA")
else:
if device == "cuda":
torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
elif device == "xpu":
torch.xpu.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
else:
print("Memory profiling only works on CUDA or XPU devices")

aggregate_metrics = {
'tokens_per_sec': [],
}
Expand All @@ -428,6 +433,8 @@ def main(
if i==0:
if device == "cuda":
torch.cuda.reset_peak_memory_stats() # MKG
elif device == "xpu":
torch.xpu.reset_peak_memory_stats() # MKG
device_sync(device=device) # MKG
if i >= 0 and interactive:
prompt = input("What is your prompt? ")
Expand Down Expand Up @@ -495,24 +502,29 @@ def callback(x):
print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")

if memory_profile and i==0:
if device != "cuda":
print("Memory profiling only works on CUDA")
else:
if device == "cuda":
snapshot = torch.cuda.memory._snapshot()
with open(f"{memory_profile}.pickle", 'wb') as f:
from pickle import dump
dump(snapshot, f)
print(
f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
"python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
)
break

elif device == "xpu":
snapshot = torch.xpu.memory._snapshot()
else:
print("Memory profiling only works on CUDA or XPU devices")

with open(f"{memory_profile}.pickle", 'wb') as f:
from pickle import dump
dump(snapshot, f)
print(
f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
"python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
)
break
print("==========")

tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
bandwidth = model_size * tokpersec
mem = torch.cuda.max_memory_reserved() /1e9
if device == "cuda":
mem = torch.cuda.max_memory_reserved() /1e9
elif device == "xpu":
mem = torch.xpu.max_memory_reserved() /1e9
print(f"Average tokens/sec: {tokpersec:.2f}")
if batch_size > 1:
print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}")
Expand Down
20 changes: 18 additions & 2 deletions torchao/quantization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Typically quantization algorithms will have different schemes for how the activa

## Benchmarks
Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data. The models used were meta-llama/Llama-2-7b-chat-hf and meta-llama/Meta-Llama-3-8B.

### CUDA backend
| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
| Llama-2-7B | Base (bfloat16) | 12.212 | 107.38 | 1418.93 | 13.88 | 13.21 |
Expand All @@ -20,9 +20,16 @@ Benchmarks and evaluation are run on a machine with a single NVIDIA-A100-80GB GP
| | int4wo-64 | 8.316 | 180.80 | 763.33 | 6.88 | 4.22 |
| | int4wo-64-GPTQ | 7.921 | 180.80 | 763.33 | 6.88 | 4.22 |
| | autoquant-int4hqq | 8.110 | 188.41 | 800.58 | 7.14 | 4.25 |
### XPU backend
| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
| Llama-2-7B | Base (bfloat16) | NA | 42.20 | 557.71 | 13.89 | 13.21 |
| | int8dq | NA | 9.87 | 65.35 | 14.60 | 6.62 |
| | int8wo | NA | 66.24 | 438.61 | 14.60 | 6.62


Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data.

### CUDA backend
| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
| Llama-3.1-8B | Base (bfloat16) | 7.54 | 126.90 | 1904.75 | 16.75 | 15.01 |
Expand All @@ -31,6 +38,15 @@ Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a ma
| | float8wo | 7.60 | 178.46 | 1339.93 | 12.09 | 7.51 |
| | float8dq (PerTensor) | 7.62 | 116.40 | 873.58 | 11.14 | 7.51 |
| | float8dq (Per Row) | 7.61 | 154.63 | 1161.47 | 11.14 | 7.51 |
### XPU backend
| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
| ----------- | ----------------------- | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
| Llama-3-8.1B | Base (bfloat16) | 7.441 | 40.36 | 605.77 | 16.35 | 15.01 |
| | int8dq | 7.581 | 13.60 | 102.28 | 18.69 | 7.52 |
| | int8wo | 7.447 | 59.49 | 447.27 | 18.60 | 7.52


Benchmarks and evaluation for model meta-llama/Meta-Llama-3.1-8B are run on a machine with a single NVIDIA-H100 GPU or Intel-Max1100 using the scripts for [generation](../_models/llama/generate.py) and [eval](../_models/llama/eval.py). Evaluation was done using the lm_eval library for tasks/data.

note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance.

Expand Down
Loading