Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmark intel xpu #1259

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
def device_sync(device):
if "cuda" in device:
torch.cuda.synchronize(device)
elif "xpu" in device:
torch.xpu.synchronize(device)
elif ("cpu" in device) or ("mps" in device):
pass
else:
print(f"device={device} is not yet suppported")

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
default_device = 'cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu'

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
Expand Down Expand Up @@ -333,15 +335,21 @@ def main(
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

if memory_profile:
torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
if device == "cuda":
torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
elif device == "xpu":
torch.xpu.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
aggregate_metrics = {
'tokens_per_sec': [],
}
start = -1 if compile else 0

for i in range(start, num_samples):
if i==0:
torch.cuda.reset_peak_memory_stats()
if device == "cuda":
torch.cuda.reset_peak_memory_stats()
elif device == "xpu":
torch.xpu.reset_peak_memory_stats()
device_sync(device=device) # MKG
if i >= 0 and interactive:
prompt = input("What is your prompt? ")
Expand Down Expand Up @@ -409,7 +417,10 @@ def callback(x):
print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")

if memory_profile and i==0:
snapshot = torch.cuda.memory._snapshot()
if device == "cuda":
snapshot = torch.cuda.memory._snapshot()
elif device == "xpu":
snapshot = torch.xpu.memory._snapshot()
with open(f"{memory_profile}.pickle", 'wb') as f:
from pickle import dump
dump(snapshot, f)
Expand All @@ -423,7 +434,10 @@ def callback(x):

tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
bandwidth = model_size * tokpersec
mem = torch.cuda.max_memory_reserved() /1e9
if device == "cuda":
mem = torch.cuda.max_memory_reserved() /1e9
elif device == "xpu":
mem = torch.xpu.max_memory_reserved() /1e9
print(f"Average tokens/sec: {tokpersec:.2f}")
if batch_size > 1:
print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}")
Expand Down
Loading