Skip to content

Commit

Permalink
Inference frame (FlagOpen#136)
Browse files Browse the repository at this point in the history
* upd ign

* init inference

* fix trtexec

* fix trtexec

* fix

* upd pipe

* rm secret

* fix

* add 5time 4perf and summary in run_inference

* update monitor (#1)

* finish logdir

* finish merge

* format

* fix

* lic & rdm

* ur

* Update README.md

* fix log output

* fix cal perf

* fix sync

* fix output

* fix

* fixbug

* fix frame

* ur

* add skip validation

* fix

* fix kunlun

* fix

---------

Co-authored-by: uuup <55571217+upvenly@users.noreply.github.com>
  • Loading branch information
shh2000 and upvenly authored Aug 3, 2023
1 parent b50c550 commit 7211a51
Show file tree
Hide file tree
Showing 33 changed files with 2,745 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@
__pycache__/
.pytest_cache
training/result/*
inference/result/*
inference/onnxs/*
86 changes: 86 additions & 0 deletions inference/benchmarks/resnet50/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
### 1. 推理数据集
> Download website:https://image-net.org/
We use ImageNet2012 Validation Images:
| Dataset | FileName | Size | Checksum |
| ----------------------------- | ---------------------- | ----- | ------------------------------------- |
| Validation images (all tasks) | ILSVRC2012_img_val.tar | 6.3GB | MD5: 29b22e2961454d5413ddabcf34fc5622 |
Dataset format conversion:
https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh

make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar are in the same directory with extract_ILSVRC.sh.
```bash
sh extract_ILSVRC.sh
```

preview directory structures of decompressed dataset.

```bash
tree -d -L 1
```

```
.
├── train
└── val
```
dataset samples size

```bash
find ./val -name "*JPEG" | wc -l
50000
```

### 2. 模型与权重

* 模型实现
* pytorch:torchvision.models.resnet50
* 权重下载
* pytorch:https://download.pytorch.org/models/resnet50-0676ba61.pth

### 2. 软硬件配置与运行信息参考

#### 2.1 Nvidia A100

- ##### 硬件环境
- 机器、加速卡型号: NVIDIA_A100-SXM4-40GB
- 多机网络类型、带宽: InfiniBand,200Gb/s

- ##### 软件环境
- OS版本:Ubuntu 20.04
- OS kernel版本: 5.4.0-113-generic
- 加速卡驱动版本:470.129.06
- Docker 版本:20.10.16
- 训练框架版本:pytorch-1.13.0a0+937e930
- 依赖软件版本:
- cuda: 11.8

- 推理工具包

- TensorRT 8.5.1.7
- torch_tensorrt 1.3.0

### 3. 运行情况

* 指标列表

| 指标名称 | 指标值索引 | 特殊说明 |
| ------------------ | ---------------- | -------------------------------------------- |
| 数据精度 | precision | 可选fp32/fp16 |
| 批尺寸 | bs | |
| 硬件存储使用 | mem | 通常称为“显存”,单位为GiB |
| 端到端时间 | e2e_time | 总时间+Perf初始化等时间 |
| 验证总吞吐量 | p_val_whole | 实际验证图片数除以总验证时间 |
| 验证计算吞吐量 | \*p_val_core | 不包含IO部分耗时 |
| 推理总吞吐量 | p_infer_whole | 实际推理图片数除以总推理时间 |
| **推理计算吞吐量** | **\*p_infer_core** | 不包含IO部分耗时 |
| 推理结果 | acc(推理/验证) | 单位为top1分类准确率(acc1) |

* 指标值

| 推理工具 | precision | bs | e2e_time | p_val_whole | \*p_val_core | p_infer_whole | \*p_infer_core | acc | mem |
| ----------- | --------- | ---- | -------- | ----------- | ---------- | ------------- | ------------ | ----------- | ---------- |
| tensorrt | fp16 | 256 | 613.4 | 1358.9 | 4263.3 | 1391.4 | 12406.0 | 76.2/76.2 | 19.7/40.0 |
| tensorrt | fp32 | 256 | 474.4 | 1487.3 | 2653.2 | 1560.3 | 6091.6 | 76.2/76.2 | 28.86/40.0 |
| torchtrt | fp16 | 256 | 716.4 | 1370.4 | 4282.6 | 1320.0 | 4723.0 | 76.2/76.2 | 9.42/40.0 |

5 changes: 5 additions & 0 deletions inference/benchmarks/resnet50/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .dataloader import build_dataloader
from .model import create_model
from .export import export_model
from .evaluator import evaluator
from .forward import model_forward, engine_forward
49 changes: 49 additions & 0 deletions inference/benchmarks/resnet50/pytorch/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torchvision as tv
from torch.utils.data import DataLoader as dl
import torch
import tqdm


def build_dataset(config):
crop = 256
c_crop = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

if config.fp16:

class ToFloat16(object):

def __call__(self, tensor):
return tensor.to(dtype=torch.float16)

tx = tv.transforms.Compose([
tv.transforms.Resize(crop),
tv.transforms.CenterCrop(c_crop),
tv.transforms.ToTensor(),
ToFloat16(),
tv.transforms.Normalize(mean=mean, std=std),
])
dataset = tv.datasets.ImageFolder(config.data_dir, tx)
else:
tx = tv.transforms.Compose([
tv.transforms.Resize(crop),
tv.transforms.CenterCrop(c_crop),
tv.transforms.ToTensor(),
tv.transforms.Normalize(mean=mean, std=std),
])
dataset = tv.datasets.ImageFolder(config.data_dir, tx)

return dataset


def build_dataloader(config):
dataset = build_dataset(config)
loader = dl(dataset,
batch_size=config.batch_size,
shuffle=False,
drop_last=True,
num_workers=config.num_workers,
pin_memory=True)

return loader
10 changes: 10 additions & 0 deletions inference/benchmarks/resnet50/pytorch/evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
def topk(output, target, ks=(1, )):
_, pred = output.topk(max(ks), 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
return [correct[:k].max(0)[0] for k in ks]


def evaluator(pred, ground_truth):
top1, top5 = topk(pred, ground_truth, ks=(1, 5))
return top1
34 changes: 34 additions & 0 deletions inference/benchmarks/resnet50/pytorch/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch
import os


def export_model(model, config):
if config.exist_onnx_path is not None:
return config.exist_onnx_path

filename = config.case + "_bs" + str(config.batch_size)
filename = filename + "_" + str(config.framework)
filename = filename + "_fp16" + str(config.fp16)
filename = "onnxs/" + filename + ".onnx"
onnx_path = config.perf_dir + "/" + filename

dummy_input = torch.randn(config.batch_size, 3, 224, 224)

if config.fp16:
dummy_input = dummy_input.half()
dummy_input = dummy_input.cuda()

dir_onnx_path = os.path.dirname(onnx_path)
os.makedirs(dir_onnx_path, exist_ok=True)

with torch.no_grad():
torch.onnx.export(model,
dummy_input,
onnx_path,
verbose=False,
input_names=["input"],
output_names=["output"],
training=torch.onnx.TrainingMode.EVAL,
do_constant_folding=True)

return onnx_path
103 changes: 103 additions & 0 deletions inference/benchmarks/resnet50/pytorch/forward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from loguru import logger
import torch
import numpy as np
import time
from tools import torch_sync


def cal_perf(config, dataloader_len, duration, core_time, str_prefix):
model_forward_perf = config.repeat * dataloader_len * config.batch_size / duration
logger.info(str_prefix + "(" + config.framework + ") Perf: " +
str(model_forward_perf) + " ips")
model_forward_core_perf = config.repeat * dataloader_len * config.batch_size / core_time
logger.info(str_prefix + "(" + config.framework + ") core Perf: " +
str(model_forward_core_perf) + " ips")
return round(model_forward_perf, 3), round(model_forward_core_perf, 3)


def model_forward(model, dataloader, evaluator, config):
if config.no_validation:
return None, None, None
start = time.time()
core_time = 0.0
acc = []

for times in range(config.repeat):

logger.debug("Repeat: " + str(times + 1))

all_top1 = []
for step, (x, y) in enumerate(dataloader):
torch_sync(config)
core_time_start = time.time()

if step % config.log_freq == 0:
logger.debug("Step: " + str(step) + " / " +
str(len(dataloader)))

with torch.no_grad():

x = x.cuda()
y = y.cuda()
pred = model(x)
torch_sync(config)
core_time += time.time() - core_time_start

top1 = evaluator(pred, y)

all_top1.extend(top1.cpu())

acc.append(np.mean(all_top1))

logger.info("Top1 Acc: " + str(acc))

duration = time.time() - start
model_forward_perf, model_forward_core_perf = cal_perf(
config, len(dataloader), duration, core_time, "Validation")

return model_forward_perf, model_forward_core_perf, round(
float(np.mean(acc)), 3)


def engine_forward(model, dataloader, evaluator, config):
start = time.time()
core_time = 0.0
foo_time = 0.0
acc = []

for times in range(config.repeat):

logger.debug("Repeat: " + str(times + 1))

all_top1 = []
for step, (x, y) in enumerate(dataloader):
torch_sync(config)
core_time_start = time.time()

if step % config.log_freq == 0:
logger.debug("Step: " + str(step) + " / " +
str(len(dataloader)))

with torch.no_grad():

outputs = model([x])
pred = outputs[0][0]
foo_time += outputs[1]
pred = pred.float()
torch_sync(config)
core_time += time.time() - core_time_start

top1 = evaluator(pred, y)

all_top1.extend(top1.cpu())

acc.append(np.mean(all_top1))

logger.info("Top1 Acc: " + str(acc))

duration = time.time() - start - foo_time
model_forward_perf, model_forward_core_perf = cal_perf(
config, len(dataloader), duration, core_time - foo_time, "Inference")

return model_forward_perf, model_forward_core_perf, round(
float(np.mean(acc)), 3)
15 changes: 15 additions & 0 deletions inference/benchmarks/resnet50/pytorch/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from torchvision.models import resnet50
from torchvision.models import ResNet50_Weights as w


def create_model(config):
if config.no_validation:
assert config.exist_onnx_path is not None
return None
model = resnet50(weights=w.IMAGENET1K_V1)
model.cuda()
model.eval()
if config.fp16:
model.half()

return model
17 changes: 17 additions & 0 deletions inference/configs/host.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
FLAGPERF_PATH: "/home/FlagPerf/inference"
FLAGPERF_LOG_PATH: "result"
VENDOR: "nvidia"
MODEL: "resnet50"
FLAGPERF_LOG_LEVEL: "INFO"
LOG_CALL_INFORMATION: True
HOSTS: ["127.0.0.1"]
SSH_PORT: "22"
HOSTS_PORTS: ["2222"]
MASTER_PORT: "29501"
SHM_SIZE: "32G"
ACCE_CONTAINER_OPT: " --gpus all"
PIP_SOURCE: "https://mirror.baidu.com/pypi/simple"
CLEAR_CACHES: True
ACCE_VISIBLE_DEVICE_ENV_NAME: "CUDA_VISIBLE_DEVICES"
CASES:
"resnet50:pytorch_1.13": "/raid/dataset/ImageNet/imagenet/val"
14 changes: 14 additions & 0 deletions inference/configs/resnet50/configurations.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
batch_size: 256
# 3*224*224(1 item in x)
input_size: 150528
fp16: true
compiler: tensorrt
num_workers: 8
log_freq: 30
repeat: 5
# skip validation(will also skip create_model, export onnx). Assert exist_onnx_path != null
no_validation: false
# set a real onnx_path to use exist, or set it to anything but null to avoid export onnx manually(like torch-tensorrt)
exist_onnx_path: null
# set a exist path of engine file like resnet50.trt/resnet50.plan/resnet50.engine
exist_compiler_path: null
2 changes: 2 additions & 0 deletions inference/configs/resnet50/parameters.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# contain case-specified parameters, like max_seq_length in BERT.
# There is no parameters for resnet50 case.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
trt_tmp_path: nvidia_tmp/resnet50.trt
has_dynamic_axis: false
torchtrt_full_compile: true
14 changes: 14 additions & 0 deletions inference/docker_images/nvidia/nvidia_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
def analysis_log(logpath):
logfile = open(logpath)

max_usage = 0.0
max_mem = 0.0
for line in logfile.readlines():
if "MiB" in line:
usage = line.split(" ")[2]
usage = float(usage[:-3])
max_usage = max(max_usage, usage)
max_mem = line.split(" ")[3]
max_mem = float(max_mem[:-3])

return round(max_usage / 1024.0, 2), round(max_mem / 1024.0, 2)
Loading

0 comments on commit 7211a51

Please sign in to comment.