Vanilla LeViT model run with ORT is slower than PyTorch (seems even slower for large batch size) #12522
Comments
Running the same script on an AWS EC2 c6i instance gives:
lscpu gives
To follow up on this, I ran the onnxruntime profiler on my model (on my laptop) to see what is taking so much time. Here is my finding with batch size = 4: this issue does indeed seem related to #12130. When profiling the PyTorch model with FX, it is clearly not the batchnorm taking most of the time: https://pastebin.com/CxmXYbY7

Script:

```python
import json

import matplotlib.pyplot as plt
import onnxruntime
import pandas as pd
import torch

model_path = "/path/to/model.onnx"

# Enable profiling so that each kernel execution is recorded to a JSON trace.
options = onnxruntime.SessionOptions()
options.enable_profiling = True

ort_session = onnxruntime.InferenceSession(
    model_path,
    sess_options=options,
    providers=["CPUExecutionProvider"],
)

batch_size = 4
print(f"\n--- BATCH SIZE {batch_size} ---")

pt_inputs = {"pixel_values": torch.ones(batch_size, 3, 224, 224, dtype=torch.float32)}
onnx_inputs = {"pixel_values": pt_inputs["pixel_values"].cpu().detach().numpy()}

for i in range(200):
    res = ort_session.run(None, onnx_inputs)

# end_profiling() returns the path of the JSON trace written by ORT.
prof = ort_session.end_profiling()
print(prof)

json_path = f"/path/to/{prof}"
with open(json_path, "r") as f:
    js = json.load(f)


def process_profiling(js):
    """
    Flattens the JSON returned by onnxruntime profiling.

    :param js: json
    :return: list of dictionaries
    """
    rows = []
    for row in js:
        if "args" in row and isinstance(row["args"], dict):
            for k, v in row["args"].items():
                row[f"args_{k}"] = v
            del row["args"]
        rows.append(row)
    return rows


df = pd.DataFrame(process_profiling(js))

# Total duration and number of calls per operator type.
gr_dur = df[["dur", "args_op_name"]].groupby("args_op_name").sum().sort_values("dur")
gr_n = df[["dur", "args_op_name"]].groupby("args_op_name").count().sort_values("dur")
gr_n = gr_n.loc[gr_dur.index, :]

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
gr_dur.plot.barh(ax=ax[0])
gr_n.plot.barh(ax=ax[1])
ax[0].set_title("duration")
ax[1].set_title("n occurrences")
plt.show()
```
Thanks a lot for your help. Should I do the fusing by hand, or is this an optimization provided by onnxruntime? I could not find any resources on this in the documentation. Edit: note that I am using an exotic model where there is a Flatten in between the MatMul and the BatchNorm.
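For reference, onnxruntime applies operator fusions automatically at session creation time, controlled by the graph optimization level on the session options. A minimal sketch (the model paths are placeholders; whether the MatMul → Flatten → BatchNorm pattern here actually gets fused is exactly what is in question):

```python
import onnxruntime

options = onnxruntime.SessionOptions()
# ORT_ENABLE_ALL is the default level and includes extended fusions and
# layout optimizations on top of the basic ones.
options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
# Dump the optimized graph to disk to inspect which fusions were applied.
options.optimized_model_filepath = "/path/to/model_optimized.onnx"

ort_session = onnxruntime.InferenceSession(
    "/path/to/model.onnx",
    sess_options=options,
    providers=["CPUExecutionProvider"],
)
```

Opening the dumped optimized model (e.g. in Netron) shows which nodes were fused and which were left as-is.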
Describe the bug
The model https://huggingface.co/facebook/levit-256 converted to ONNX is slower than the PyTorch model.
Reproduce
Convert the model to onnx:
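The exact export command is not shown here; a minimal sketch of how the conversion could be done with `torch.onnx.export` (the input/output names, dynamic axes, and opset below are assumptions, not the original command):

```python
import torch
from transformers import LevitForImageClassificationWithTeacher

model = LevitForImageClassificationWithTeacher.from_pretrained("facebook/levit-256")
model.eval()
# Return plain tuples instead of a ModelOutput so the exporter can trace the outputs.
model.config.return_dict = False

dummy = torch.ones(1, 3, 224, 224, dtype=torch.float32)
torch.onnx.export(
    model,
    (dummy,),
    "model.onnx",
    input_names=["pixel_values"],
    output_names=["logits"],
    dynamic_axes={"pixel_values": {0: "batch"}, "logits": {0: "batch"}},
    opset_version=14,
)
```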
For ease of reproduction, you can as well find the model.onnx file from this conversion here: https://huggingface.co/fxmarty/bad-levit-onnx/tree/main
Run the code:
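A timing comparison along these lines (the warmup and iteration counts are my choices, not necessarily those of the original script):

```python
import time

import onnxruntime
import torch
from transformers import LevitForImageClassificationWithTeacher

batch_size = 4
pixel_values = torch.ones(batch_size, 3, 224, 224, dtype=torch.float32)

# PyTorch eager baseline.
model = LevitForImageClassificationWithTeacher.from_pretrained("facebook/levit-256")
model.eval()
with torch.no_grad():
    for _ in range(10):  # warmup
        model(pixel_values)
    start = time.perf_counter()
    for _ in range(100):
        model(pixel_values)
    print(f"PyTorch: {(time.perf_counter() - start) / 100 * 1e3:.2f} ms/iter")

# ONNX Runtime on CPU.
ort_session = onnxruntime.InferenceSession(
    "model.onnx", providers=["CPUExecutionProvider"]
)
onnx_inputs = {"pixel_values": pixel_values.numpy()}
for _ in range(10):  # warmup
    ort_session.run(None, onnx_inputs)
start = time.perf_counter()
for _ in range(100):
    ort_session.run(None, onnx_inputs)
print(f"ONNX Runtime: {(time.perf_counter() - start) / 100 * 1e3:.2f} ms/iter")
```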
Output:
Note that this is not consistent across models; for example, https://huggingface.co/google/vit-base-patch16-224 gives the timings below, where ONNX Runtime performs fine:
lscpu
Urgency
None
System information
Expected behavior
ONNX Runtime should be at least as fast as PyTorch.
Additional context
The PyTorch default `torch.get_num_threads()` is 14, and 14 cores are used during inference. For ONNX Runtime, it seems like only the 10 physical cores are used. I tried to play with those numbers, but it does not really help. At first I thought the issue was related to #12130, but I am not 100% sure.
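For reference, the thread-count knobs in question (the value 14 is illustrative, matching the PyTorch default observed above):

```python
import onnxruntime
import torch

# PyTorch: set the intra-op thread count explicitly.
torch.set_num_threads(14)

# ONNX Runtime: set the intra-op thread count on the session options.
options = onnxruntime.SessionOptions()
options.intra_op_num_threads = 14
ort_session = onnxruntime.InferenceSession(
    "model.onnx", sess_options=options, providers=["CPUExecutionProvider"]
)
```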
Am I doing something wrong?