
Vanilla LeViT model run with ORT is slower than PyTorch (seems even slower for large batch size) #12522

Closed
fxmarty opened this issue Aug 9, 2022 · 4 comments · Fixed by huggingface/optimum#348
Labels
core runtime issues related to core runtime

Comments

fxmarty (Contributor) commented Aug 9, 2022

Describe the bug
The model https://huggingface.co/facebook/levit-256 converted to ONNX is slower than the PyTorch model.

Reproduce

conda create -n myenv python=3.9
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
pip install onnx onnxruntime transformers

Convert the model to onnx:

python -m transformers.onnx -m facebook/levit-256 --feature image-classification --opset 13 outdir_levit

For ease of reproduction, you can also find the model.onnx file from this conversion here: https://huggingface.co/fxmarty/bad-levit-onnx/tree/main

Run the code:

import onnxruntime
import time
import torch
import numpy as np

from transformers import AutoModelForImageClassification

model_path = "/path/to/outdir_levit/model.onnx"

model_pt = AutoModelForImageClassification.from_pretrained("facebook/levit-256")
model_pt.eval()

ort_session = onnxruntime.InferenceSession(
    model_path,
    providers=['CPUExecutionProvider']
)

print(model_pt.device)
print(ort_session.get_providers())

# check models yield the same output
pt_inputs = dict()
pt_inputs["pixel_values"] = torch.ones(4, 3, 224, 224, dtype=torch.float32)

onnx_inputs = {
    "pixel_values": pt_inputs["pixel_values"].cpu().detach().numpy(),
}

with torch.no_grad():
    res_pt = model_pt(**pt_inputs)
res_ort = ort_session.run(None, onnx_inputs)

assert np.allclose(res_ort[0], res_pt["logits"].numpy(), atol=1e-4)

for batch_size in [1, 2, 4, 8, 16]:
    print(f"\n--- BATCH SIZE {batch_size} ---")
    pt_inputs = dict()
    pt_inputs["pixel_values"] = torch.ones(batch_size, 3, 224, 224, dtype=torch.float32)

    onnx_inputs = {
        "pixel_values": pt_inputs["pixel_values"].cpu().detach().numpy(),
    }

    # warmup pytorch
    with torch.no_grad():
        for i in range(20):
            res = model_pt(**pt_inputs)

        start = time.time()
        for i in range(200):
            res = model_pt(**pt_inputs)
        runtime_pt = time.time() - start
        print(f"PyTorch: {runtime_pt:.2f} s")

    # warmup onnxruntime
    for i in range(20):
        res = ort_session.run(None, onnx_inputs)

    start = time.time()
    for i in range(200):
        res = ort_session.run(None, onnx_inputs)
    runtime_ort = time.time() - start
    print(f"ONNX Runtime: {runtime_ort:.2f} s")
    print(f"ORT {(runtime_ort - runtime_pt) / runtime_pt * 100:.2f} % slower than PT")

Output:

cpu
['CPUExecutionProvider']

--- BATCH SIZE 1 ---
PyTorch: 2.66 s
ONNX Runtime: 3.20 s
ORT 20.49 % slower than PT

--- BATCH SIZE 2 ---
PyTorch: 4.40 s
ONNX Runtime: 5.54 s
ORT 26.14 % slower than PT

--- BATCH SIZE 4 ---
PyTorch: 7.25 s
ONNX Runtime: 10.71 s
ORT 47.69 % slower than PT

--- BATCH SIZE 8 ---
PyTorch: 14.65 s
ONNX Runtime: 21.37 s
ORT 45.90 % slower than PT

--- BATCH SIZE 16 ---
PyTorch: 33.66 s
ONNX Runtime: 49.81 s
ORT 47.96 % slower than PT

Note that this is not consistent across models; for example, https://huggingface.co/google/vit-base-patch16-224 gives the timings below, which are fine:

--- BATCH SIZE 1 ---
PyTorch: 2.27 s
ONNX Runtime: 1.96 s
ORT -13.75 % slower than PT

--- BATCH SIZE 2 ---
PyTorch: 5.58 s
ONNX Runtime: 3.77 s
ORT -32.35 % slower than PT

--- BATCH SIZE 4 ---
PyTorch: 10.16 s
ONNX Runtime: 7.75 s
ORT -23.65 % slower than PT

--- BATCH SIZE 8 ---
PyTorch: 20.66 s
ONNX Runtime: 15.73 s
ORT -23.85 % slower than PT

lscpu

Architecture:            x86_64
  CPU op-mode(s):        32-bit, 64-bit
  Address sizes:         46 bits physical, 48 bits virtual
  Byte Order:            Little Endian
CPU(s):                  20
  On-line CPU(s) list:   0-19
Vendor ID:               GenuineIntel
  Model name:            12th Gen Intel(R) Core(TM) i7-1280P
    CPU family:          6
    Model:               154
    Thread(s) per core:  2
    Core(s) per socket:  14
    Socket(s):           1
    Stepping:            3
    CPU max MHz:         4800,0000
    CPU min MHz:         400,0000
    BogoMIPS:            3993.60
    Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx p
                         dpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclm
                         ulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer 
                         aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l2 invpcid_single cdp_l2 ssbd ibrs ibpb stibp ibrs_enhanced 
                         tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdt_a rdseed adx smap clflushopt c
                         lwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window 
                         hwp_epp hwp_pkg_req umip pku ospke waitpkg gfni vaes vpclmulqdq tme rdpid movdiri movdir64b fsrm md_clear serialize pconfig arch_lbr
                          flush_l1d arch_capabilities
Virtualization features: 
  Virtualization:        VT-x
Caches (sum of all):     
  L1d:                   544 KiB (14 instances)
  L1i:                   704 KiB (14 instances)
  L2:                    11,5 MiB (8 instances)
  L3:                    24 MiB (1 instance)
NUMA:                    
  NUMA node(s):          1
  NUMA node0 CPU(s):     0-19
Vulnerabilities:         
  Itlb multihit:         Not affected
  L1tf:                  Not affected
  Mds:                   Not affected
  Meltdown:              Not affected
  Mmio stale data:       Not affected
  Spec store bypass:     Mitigation; Speculative Store Bypass disabled via prctl and seccomp
  Spectre v1:            Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2:            Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
  Srbds:                 Not affected
  Tsx async abort:       Not affected

Urgency
None

System information

  • OS Platform and Distribution: Linux 5.15.0-43-generic 46-Ubuntu SMP Tue Jul 12 10:30:17 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux
  • ONNX Runtime installed from (source or binary): 1.12.0 (binary)
  • ONNX Runtime version: 1.12.0
  • Python version: 3.9.12
  • PyTorch version: 1.12.0


Expected behavior
ONNX Runtime should be at least as fast as PyTorch.

Additional context
PyTorch's default torch.get_num_threads() is 14 here, and 14 cores are used during inference. For ONNX Runtime, it seems like only the 10 physical cores are used. I tried playing with those numbers, but it does not really help.
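For reference, here is a minimal sketch of how the thread counts can be pinned on both sides (torch.set_num_threads and SessionOptions.intra_op_num_threads are the standard knobs; the value 14 below is just an example matching the PyTorch default on this machine):

import torch
import onnxruntime

# Match PyTorch's default thread count on this machine (example value).
torch.set_num_threads(14)

sess_options = onnxruntime.SessionOptions()
sess_options.intra_op_num_threads = 14  # threads used within an operator
sess_options.inter_op_num_threads = 1   # threads used across independent operators

ort_session = onnxruntime.InferenceSession(
    "/path/to/outdir_levit/model.onnx",
    sess_options=sess_options,
    providers=["CPUExecutionProvider"],
)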

At first I thought the issue was related to #12130, but I am not 100% sure.

Am I doing something wrong?

fxmarty (Contributor, Author) commented Aug 9, 2022

Running the same script on an AWS EC2 c6i instance gives:

cpu
['CPUExecutionProvider']

--- BATCH SIZE 1 ---
PyTorch: 1.99 s
ONNX Runtime: 5.56 s
ORT 179.97 % slower than PT

--- BATCH SIZE 2 ---
PyTorch: 2.61 s
ONNX Runtime: 10.81 s
ORT 313.96 % slower than PT

--- BATCH SIZE 4 ---
PyTorch: 3.67 s
ONNX Runtime: 19.81 s
ORT 439.57 % slower than PT

--- BATCH SIZE 8 ---
PyTorch: 6.62 s
ONNX Runtime: 37.55 s
ORT 466.96 % slower than PT

lscpu gives

Architecture:            x86_64
  CPU op-mode(s):        32-bit, 64-bit
  Address sizes:         46 bits physical, 48 bits virtual
  Byte Order:            Little Endian
CPU(s):                  16
  On-line CPU(s) list:   0-15
Vendor ID:               GenuineIntel
  Model name:            Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
    CPU family:          6
    Model:               106
    Thread(s) per core:  2
    Core(s) per socket:  8
    Socket(s):           1
    Stepping:            6
    BogoMIPS:            5799.98
    Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm
                          constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2
                         apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb st
                         ibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb a
                         vx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmu
                         lqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
Virtualization features: 
  Hypervisor vendor:     KVM
  Virtualization type:   full
Caches (sum of all):     
  L1d:                   384 KiB (8 instances)
  L1i:                   256 KiB (8 instances)
  L2:                    10 MiB (8 instances)
  L3:                    54 MiB (1 instance)
NUMA:                    
  NUMA node(s):          1
  NUMA node0 CPU(s):     0-15
Vulnerabilities:         
  Itlb multihit:         Not affected
  L1tf:                  Not affected
  Mds:                   Not affected
  Meltdown:              Not affected
  Spec store bypass:     Mitigation; Speculative Store Bypass disabled via prctl and seccomp
  Spectre v1:            Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2:            Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
  Srbds:                 Not affected
  Tsx async abort:       Not affected

fxmarty (Contributor, Author) commented Aug 9, 2022

To follow up on this, I ran the onnxruntime profiler on my model (on my laptop) to see what is taking so much time. Here are my findings with batch size = 4:

[profiling plot: BatchNormalization accounts for the largest share of the total op duration]

So it seems this issue is indeed related to #12130.

When profiling the PyTorch model with FX, it is clearly not the batchnorm taking most of the time: https://pastebin.com/CxmXYbY7
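On the PyTorch side, a per-op breakdown can also be obtained with torch.profiler; a minimal sketch (the pastebin above was produced with an FX-based profiler, not this one):

import torch
from torch.profiler import profile, ProfilerActivity
from transformers import AutoModelForImageClassification

model_pt = AutoModelForImageClassification.from_pretrained("facebook/levit-256")
model_pt.eval()

inputs = torch.ones(4, 3, 224, 224, dtype=torch.float32)

with torch.no_grad(), profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(20):
        model_pt(pixel_values=inputs)

# Aggregate self-CPU time per operator, most expensive first.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=15))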

ONNX Runtime profiling script:

import json

import matplotlib.pyplot as plt
import onnxruntime
import pandas as pd
import torch

model_path = "/path/to/model.onnx"

options = onnxruntime.SessionOptions()
options.enable_profiling = True

ort_session = onnxruntime.InferenceSession(
    model_path,
    sess_options=options,
    providers=['CPUExecutionProvider']
)

batch_size = 4
print(f"\n--- BATCH SIZE {batch_size} ---")
pt_inputs = dict()
pt_inputs["pixel_values"] = torch.ones(batch_size, 3, 224, 224, dtype=torch.float32)

onnx_inputs = {
    "pixel_values": pt_inputs["pixel_values"].cpu().detach().numpy(),
}

for i in range(200):
    res = ort_session.run(None, onnx_inputs)

prof = ort_session.end_profiling()
print(prof)

json_path = f"/path/to/{prof}"
with open(json_path, "r") as f:
    js = json.load(f)

def process_profiling(js):
    """
    Flattens json returned by onnxruntime profiling.
    :param js: json
    :return: list of dictionaries
    """
    rows = []
    for row in js:
        if 'args' in row and isinstance(row['args'], dict):
            for k, v in row['args'].items():
                row[f'args_{k}'] = v
            del row['args']
        rows.append(row)
    return rows

df = pd.DataFrame(process_profiling(js))

gr_dur = df[['dur', "args_op_name"]].groupby(
    "args_op_name").sum().sort_values('dur')
gr_n = df[['dur', "args_op_name"]].groupby(
    "args_op_name").count().sort_values('dur')
gr_n = gr_n.loc[gr_dur.index, :]

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
gr_dur.plot.barh(ax=ax[0])
gr_n.plot.barh(ax=ax[1])
ax[0].set_title("duration")
ax[1].set_title("n occurrences")

plt.show()

yufenglee (Member) commented:

BatchNormalization itself is not specifically optimized because it usually follows a Conv and can be fused into the Conv.
For this model, BatchNormalization can be fused together with MatMul into an extended Gemm or a MatMul + Bias.

[graph screenshot: the MatMul followed by BatchNormalization pattern in the exported model]
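To make the fusion concrete, here is a minimal numpy sketch (toy shapes, not the actual LeViT graph) showing that an inference-mode BatchNormalization applied to the output of a MatMul collapses into a single MatMul plus bias:

import numpy as np

rng = np.random.default_rng(0)

# Toy shapes: (batch, in_features) @ (in_features, out_features)
x = rng.standard_normal((4, 8)).astype(np.float32)
W = rng.standard_normal((8, 16)).astype(np.float32)

# BatchNormalization parameters over the 16 output channels (inference mode).
gamma = rng.standard_normal(16).astype(np.float32)
beta = rng.standard_normal(16).astype(np.float32)
mean = rng.standard_normal(16).astype(np.float32)
var = (rng.random(16) + 0.5).astype(np.float32)
eps = 1e-5

# Reference: MatMul followed by BatchNormalization.
z = x @ W
ref = gamma * (z - mean) / np.sqrt(var + eps) + beta

# Folded: a single MatMul with rescaled weights plus a bias (MatMul + Bias / Gemm).
scale = gamma / np.sqrt(var + eps)
W_folded = W * scale              # scales each output column of W
b_folded = beta - mean * scale
fused = x @ W_folded + b_folded

assert np.allclose(ref, fused, atol=1e-5)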

fxmarty (Contributor, Author) commented Aug 10, 2022

Thanks a lot for your help. Should I do the fusing by hand, or is this an optimization performed by onnxruntime? I could not find resources on this in the documentation.

Edit: note that I am using an exotic model where there is a Flatten in between the MatMul and the BatchNorm.
Edit 2: I can confirm BatchNorm2d is folded into Conv2d when converting from PyTorch. Here it is an exotic case, hence no automatic folding.
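For reference, session-level graph optimizations in onnxruntime are controlled through SessionOptions; a minimal sketch of enabling all optimization passes and dumping the optimized graph for inspection (whether the MatMul + BatchNormalization pattern in this model actually gets fused by these passes is exactly what is in question here):

import onnxruntime

sess_options = onnxruntime.SessionOptions()
# Enable all graph optimizations, including extended fusions.
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
# Save the optimized graph so the (un)fused pattern can be inspected, e.g. in Netron.
sess_options.optimized_model_filepath = "/path/to/model_optimized.onnx"

ort_session = onnxruntime.InferenceSession(
    "/path/to/outdir_levit/model.onnx",
    sess_options=sess_options,
    providers=["CPUExecutionProvider"],
)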
