Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ViT stdcase #186

Merged
merged 7 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions inference/benchmarks/vit_l_16/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
### 1. 推理数据集
> Download website:https://image-net.org/

We use ImageNet2012 Validation Images:
| Dataset | FileName | Size | Checksum |
| ----------------------------- | ---------------------- | ----- | ------------------------------------- |
| Validation images (all tasks) | ILSVRC2012_img_val.tar | 6.3GB | MD5: 29b22e2961454d5413ddabcf34fc5622 |
Dataset format conversion:
https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh

make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar are in the same directory with extract_ILSVRC.sh.
```bash
sh extract_ILSVRC.sh
```

preview directory structures of decompressed dataset.

```bash
tree -d -L 1
```

```
.
├── train
└── val
```
dataset samples size

```bash
find ./val -name "*JPEG" | wc -l
50000
```

### 2. 模型与权重

* 模型实现
* pytorch:transformers.ViTForImageClassification(hugging face)
* 权重下载
* pytorch:from_pretrained("google/vit-large-patch16-224")(hugging face)

### 2. 软硬件配置与运行信息参考

#### 2.1 Nvidia A100

- ##### 硬件环境
- 机器、加速卡型号: NVIDIA_A100-SXM4-40GB
- 多机网络类型、带宽: InfiniBand,200Gb/s

- ##### 软件环境
- OS版本:Ubuntu 20.04
- OS kernel版本: 5.4.0-113-generic
- 加速卡驱动版本:470.129.06
- Docker 版本:20.10.16
- 训练框架版本:pytorch-1.13.0a0+937e930
- 依赖软件版本:
- cuda: 11.8

- 推理工具包

- TensorRT 8.5.1.7
- torch_tensorrt 1.3.0

### 3. 运行情况

* 指标列表

| 指标名称 | 指标值索引 | 特殊说明 |
| ------------------ | ---------------- | -------------------------------------------- |
| 数据精度 | precision | 可选fp32/fp16 |
| 批尺寸 | bs | |
| 硬件存储使用 | mem | 通常称为“显存”,单位为GiB |
| 端到端时间 | e2e_time | 总时间+Perf初始化等时间 |
| 验证总吞吐量 | p_val_whole | 实际验证图片数除以总验证时间 |
| 验证计算吞吐量 | p_val_core | 不包含IO部分耗时 |
| 推理总吞吐量 | p_infer_whole | 实际推理图片数除以总推理时间 |
| **推理计算吞吐量** | **\*p_infer_core** | 不包含IO部分耗时 |
| **计算卡使用率** | **\*MFU** | model flops utilization |
| 推理结果 | acc(推理/验证) | 单位为top1分类准确率(acc1) |

* 指标值

| 推理工具 | precision | bs | e2e_time | p_val_whole | p_val_core | p_infer_whole | \*p_infer_core | \*MFU | acc | mem |
| ----------- | --------- | ---- | ---- | -------- | ----------- | ---------- | ------------- | ------------ | ----------- | ----------- |
| tensorrt | fp16 | 64 |1009.7 | 777.8 | 796.7 | 825.8 | 1329.2 | 26.2% | 79.0/79.3 | 35.0/40.0 |
| tensorrt | fp32 | 32 | 1275.9 | 482.4 | 491.1 | 555.5 | 590.5 | 23.3% | 79.3/79.3 | 35.0/40.0 |

5 changes: 5 additions & 0 deletions inference/benchmarks/vit_l_16/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .dataloader import build_dataloader
from .model import create_model
from .export import export_model
from .evaluator import evaluator
from .forward import model_forward, engine_forward
49 changes: 49 additions & 0 deletions inference/benchmarks/vit_l_16/pytorch/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torchvision as tv
from torch.utils.data import DataLoader as dl
import torch
import tqdm


def build_dataset(config):
crop = 256
c_crop = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

if config.fp16:

class ToFloat16(object):

def __call__(self, tensor):
return tensor.to(dtype=torch.float16)

tx = tv.transforms.Compose([
tv.transforms.Resize(crop),
tv.transforms.CenterCrop(c_crop),
tv.transforms.ToTensor(),
ToFloat16(),
tv.transforms.Normalize(mean=mean, std=std),
])
dataset = tv.datasets.ImageFolder(config.data_dir, tx)
else:
tx = tv.transforms.Compose([
tv.transforms.Resize(crop),
tv.transforms.CenterCrop(c_crop),
tv.transforms.ToTensor(),
tv.transforms.Normalize(mean=mean, std=std),
])
dataset = tv.datasets.ImageFolder(config.data_dir, tx)

return dataset


def build_dataloader(config):
dataset = build_dataset(config)
loader = dl(dataset,
batch_size=config.batch_size,
shuffle=False,
drop_last=True,
num_workers=config.num_workers,
pin_memory=True)

return loader
10 changes: 10 additions & 0 deletions inference/benchmarks/vit_l_16/pytorch/evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
def topk(output, target, ks=(1, )):
_, pred = output.topk(max(ks), 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
return [correct[:k].max(0)[0] for k in ks]


def evaluator(pred, ground_truth):
top1, top5 = topk(pred, ground_truth, ks=(1, 5))
return top1
34 changes: 34 additions & 0 deletions inference/benchmarks/vit_l_16/pytorch/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch
import os


def export_model(model, config):
if config.exist_onnx_path is not None:
return config.exist_onnx_path

filename = config.case + "_bs" + str(config.batch_size)
filename = filename + "_" + str(config.framework)
filename = filename + "_fp16" + str(config.fp16)
filename = "onnxs/" + filename + ".onnx"
onnx_path = config.perf_dir + "/" + filename

dummy_input = torch.randn(config.batch_size, 3, 224, 224)

if config.fp16:
dummy_input = dummy_input.half()
dummy_input = dummy_input.cuda()

dir_onnx_path = os.path.dirname(onnx_path)
os.makedirs(dir_onnx_path, exist_ok=True)

with torch.no_grad():
torch.onnx.export(model,
dummy_input,
onnx_path,
verbose=False,
input_names=["input"],
output_names=["output"],
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True)

return onnx_path
106 changes: 106 additions & 0 deletions inference/benchmarks/vit_l_16/pytorch/forward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from loguru import logger
import torch
import numpy as np
import time
from tools import torch_sync


def cal_perf(config, dataloader_len, duration, core_time, str_prefix):
model_forward_perf = config.repeat * dataloader_len * config.batch_size / duration
logger.info(str_prefix + "(" + config.framework + ") Perf: " +
str(model_forward_perf) + " ips")
model_forward_core_perf = config.repeat * dataloader_len * config.batch_size / core_time
logger.info(str_prefix + "(" + config.framework + ") core Perf: " +
str(model_forward_core_perf) + " ips")
return round(model_forward_perf, 3), round(model_forward_core_perf, 3)


def model_forward(model, dataloader, evaluator, config):
if config.no_validation:
return None, None, None
start = time.time()
core_time = 0.0
acc = []

for times in range(config.repeat):

logger.debug("Repeat: " + str(times + 1))

all_top1 = []
for step, (x, y) in enumerate(dataloader):
torch_sync(config)
core_time_start = time.time()

if step % config.log_freq == 0:
logger.debug("Step: " + str(step) + " / " +
str(len(dataloader)))

with torch.no_grad():

x = x.cuda()
y = y.cuda()
pred = model(x)[0]
torch_sync(config)
core_time += time.time() - core_time_start

top1 = evaluator(pred, y)

all_top1.extend(top1.cpu())

acc.append(np.mean(all_top1))

logger.info("Top1 Acc: " + str(acc))

duration = time.time() - start
model_forward_perf, model_forward_core_perf = cal_perf(
config, len(dataloader), duration, core_time, "Validation")

return model_forward_perf, model_forward_core_perf, round(
float(np.mean(acc)), 3)


def engine_forward(model, dataloader, evaluator, config):
start = time.time()
core_time = 0.0
foo_time = 0.0
acc = []

for times in range(config.repeat):

logger.debug("Repeat: " + str(times + 1))

all_top1 = []
for step, (x, y) in enumerate(dataloader):
torch_sync(config)
core_time_start = time.time()

if step % config.log_freq == 0:
logger.debug("Step: " + str(step) + " / " +
str(len(dataloader)))

with torch.no_grad():

outputs = model([x])
pred = outputs[0]
foo_time += outputs[1]

torch_sync(config)
core_time += time.time() - core_time_start

pred = pred[0].float()
pred = pred.reshape(config.batch_size, -1)
pred = pred.cpu()
top1 = evaluator(pred, y)

all_top1.extend(top1.cpu())

acc.append(np.mean(all_top1))

logger.info("Top1 Acc: " + str(acc))

duration = time.time() - start - foo_time
model_forward_perf, model_forward_core_perf = cal_perf(
config, len(dataloader), duration, core_time - foo_time, "Inference")

return model_forward_perf, model_forward_core_perf, round(
float(np.mean(acc)), 3)
14 changes: 14 additions & 0 deletions inference/benchmarks/vit_l_16/pytorch/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from transformers import ViTForImageClassification as vit


def create_model(config):
if config.no_validation:
assert config.exist_onnx_path is not None
return None
model = vit.from_pretrained(config.weights)
model.cuda()
model.eval()
if config.fp16:
model.half()

return model
1 change: 1 addition & 0 deletions inference/benchmarks/vit_l_16/pytorch/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
transformers
16 changes: 16 additions & 0 deletions inference/configs/vit_l_16/configurations.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
batch_size: 32
# 1 item(like 1 sequence, 1 image) flops
# Attention! For transformer decoder like bert, 1 token cause 2*param flops, so we need 2*length*params like 2*512*0.33B here
# format: a_1*a*2*...*a_nea_0,like 2*512*0.33e9(bert) or 4.12e9(resnet50)
flops: 6.16e10
fp16: false
compiler: tensorrt
num_workers: 8
log_freq: 30
repeat: 5
# skip validation(will also skip create_model, export onnx). Assert exist_onnx_path != null
no_validation: false
# set a real onnx_path to use exist, or set it to anything but null to avoid export onnx manually(like torch-tensorrt)
exist_onnx_path: null
# set a exist path of engine file like resnet50.trt/resnet50.plan/resnet50.engine
exist_compiler_path: null
1 change: 1 addition & 0 deletions inference/configs/vit_l_16/parameters.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
weights: "google/vit-large-patch16-224"
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
trt_tmp_path: nvidia_tmp/vit.trt
has_dynamic_axis: false
torchtrt_full_compile: true