perf: scalable benchmark (#680)
* fix: multithread

* fix: disable threads

* fix: benchmark

* fix: move bench to scripts folder

* fix: tune threads used by pytorch

* fix: docstr

* fix: optimize onnx

* fix: onnx optimization

* fix: remove providers argument from onnx executor

* fix: clear codes

* fix: more ort optimization

* fix: minor revision

* fix: add onnx optim

* fix: revision

* fix: bump onnxruntime-gpu version

* fix: temp float16 support

* fix: add onnx quantize

* fix: polish codes

* fix: cast embedding fp16 to fp32

* fix: clean codes

* fix: revert quantization

* fix: clean setup
numb3r3 authored Apr 18, 2022
1 parent fb229ae commit 10d53eb
Showing 5 changed files with 286 additions and 30 deletions.
141 changes: 141 additions & 0 deletions scripts/benchmark.py
@@ -0,0 +1,141 @@
import random
import time
from typing import Optional
import threading
import click
import numpy as np
from docarray import Document, DocumentArray


def warn(*args, **kwargs):
    pass


import warnings

# silence noisy warnings raised by dependencies during benchmarking
warnings.warn = warn


class BenchmarkClient(threading.Thread):
    def __init__(
        self,
        server: str,
        batch_size: int = 1,
        modality: str = 'text',
        num_iter: Optional[int] = 100,
        image_sample: str = None,
        **kwargs,
    ):
        """
        @param server: the clip-as-service server URI
        @param batch_size: number of documents in each request batch
        @param modality: the document modality, either `text` or `image`
        @param num_iter: number of repeated runs per experiment
        @param image_sample: uri of the test image
        """
        assert num_iter > 2, 'num_iter must be greater than 2'
        super().__init__()
        self.server = server
        self.batch_size = batch_size
        self.modality = modality
        self.image_sample = image_sample
        self.num_iter = num_iter
        self.avg_time = 0

    def run(self):
        try:
            from clip_client import Client
        except ImportError:
            raise ImportError(
                'The clip_client module is not available; it is required for benchmarking. '
                'Please use "pip install clip-client" to install it.'
            )

        if self.modality == 'text':
            from clip_server.model.simple_tokenizer import SimpleTokenizer

            tokenizer = SimpleTokenizer()
            vocab = list(tokenizer.encoder.keys())
            batch = DocumentArray(
                [
                    Document(text=' '.join(random.choices(vocab, k=78)))
                    for _ in range(self.batch_size)
                ]
            )
        elif self.modality == 'image':
            batch = DocumentArray(
                [
                    Document(blob=open(self.image_sample, 'rb').read())
                    for _ in range(self.batch_size)
                ]
            )
        else:
            raise ValueError(f'The modality "{self.modality}" is unsupported')

        client = Client(self.server)

        time_costs = []
        for _ in range(self.num_iter):
            start = time.perf_counter()
            r = client.encode(batch)
            time_costs.append(time.perf_counter() - start)

        # discard the first two runs as warm-up before averaging
        self.avg_time = np.mean(time_costs[2:])


@click.command(name='clip-as-service benchmark')
@click.argument('server')
@click.option(
    '--batch_sizes',
    multiple=True,
    type=int,
    default=[1, 8, 16, 32, 64],
    help='batch sizes to benchmark',
)
@click.option(
    '--num_iter', default=10, help='number of repeated runs per experiment (must be > 2)'
)
@click.option(
    '--concurrent_clients',
    multiple=True,
    type=int,
    default=[1, 4, 16, 32, 64],
    help='number of concurrent clients per experiment',
)
@click.option('--image_sample', help='path to the image sample file')
def main(server, batch_sizes, num_iter, concurrent_clients, image_sample):
    # wait until the server is ready
    for batch_size in batch_sizes:
        for num_client in concurrent_clients:
            all_clients = [
                BenchmarkClient(
                    server,
                    batch_size=batch_size,
                    num_iter=num_iter,
                    modality='image' if (image_sample is not None) else 'text',
                    image_sample=image_sample,
                )
                for _ in range(num_client)
            ]

            for bc in all_clients:
                bc.start()

            clients_speed = []
            for bc in all_clients:
                bc.join()
                clients_speed.append(batch_size / bc.avg_time)

            max_speed, min_speed, avg_speed = (
                max(clients_speed),
                min(clients_speed),
                np.mean(clients_speed),
            )

            print(
                '(concurrent client=%d, batch_size=%d) avg speed: %.3f\tmax speed: %.3f\tmin speed: %.3f'
                % (num_client, batch_size, avg_speed, max_speed, min_speed),
                flush=True,
            )


if __name__ == '__main__':
    main()
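
For reference, here is a minimal sketch of driving the new benchmark client programmatically instead of through the click CLI; the import path and the server address below are placeholders and not part of this commit:

# hedged usage sketch: adjust the import path and server address to your setup
from benchmark import BenchmarkClient  # assumes scripts/ is on PYTHONPATH

client = BenchmarkClient(
    'grpc://127.0.0.1:51000',  # placeholder clip-as-service server address
    batch_size=8,
    modality='text',
    num_iter=10,
)
client.start()
client.join()
print(f'average latency per batch: {client.avg_time:.3f}s')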
58 changes: 47 additions & 11 deletions server/clip_server/executors/clip_onnx.py
@@ -1,8 +1,9 @@
import io
import os
from multiprocessing.pool import ThreadPool, Pool
from typing import List, Sequence, Tuple
from typing import List, Tuple, Optional

import onnxruntime as ort

from PIL import Image
from jina import Executor, requests, DocumentArray

from clip_server.model import clip
@@ -24,26 +25,61 @@ class CLIPEncoder(Executor):
    def __init__(
        self,
        name: str = 'ViT-B/32',
        providers: Sequence = (
            'TensorrtExecutionProvider',
            'CUDAExecutionProvider',
            'CPUExecutionProvider',
        ),
        device: Optional[str] = None,
        num_worker_preprocess: int = 4,
        minibatch_size: int = 64,
        pool_backend: str = 'thread',
        **kwargs
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._preprocess_blob = clip._transform_blob(_SIZE[name])
        self._preprocess_tensor = clip._transform_ndarray(_SIZE[name])
        self._model = CLIPOnnxModel(name)
        if pool_backend == 'thread':
            self._pool = ThreadPool(processes=num_worker_preprocess)
        else:
            self._pool = Pool(processes=num_worker_preprocess)
        self._minibatch_size = minibatch_size
        self._model.start_sessions(providers=providers)

        self._model = CLIPOnnxModel(name)

        import torch

        if not device:
            self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self._device = device

        # define the priority order for the execution providers
        providers = ['CPUExecutionProvider']

        # prefer the CUDA Execution Provider over the CPU Execution Provider
        if self._device == 'cuda':
            providers.insert(0, 'CUDAExecutionProvider')
            # TODO: support tensorrt
            # providers.insert(0, 'TensorrtExecutionProvider')

        sess_options = ort.SessionOptions()

        # enables all available optimizations including layout optimizations
        sess_options.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        )

        if self._device != 'cuda' and (not os.environ.get('OMP_NUM_THREADS')):
            num_threads = torch.get_num_threads() // self.runtime_args.replicas
            if num_threads < 2:
                self.logger.warning(
                    f'Too many encoder replicas ({self.runtime_args.replicas})'
                )

            # run the operators in the graph in parallel (not supported by the CUDA Execution Provider)
            sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL

            # the number of threads used to parallelize the execution of the graph (across nodes)
            sess_options.inter_op_num_threads = 1
            sess_options.intra_op_num_threads = max(num_threads, 1)

        self._model.start_sessions(sess_options=sess_options, providers=providers)

def _preproc_image(self, da: 'DocumentArray') -> 'DocumentArray':
for d in da:
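
For context, the ONNX Runtime session tuning used above can be sketched standalone; the model path and the intra-op thread count below are placeholders, not part of this commit:

import os
import onnxruntime as ort

sess_options = ort.SessionOptions()
# enable all graph optimizations, including layout optimizations
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

if not os.environ.get('OMP_NUM_THREADS'):
    # run independent graph operators in parallel on CPU and cap both thread pools
    sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    sess_options.inter_op_num_threads = 1
    sess_options.intra_op_num_threads = 4  # placeholder: CPU threads available to this replica

session = ort.InferenceSession(
    'visual.onnx',  # placeholder path to a cached CLIP ONNX model
    sess_options=sess_options,
    providers=['CPUExecutionProvider'],
)
session.disable_fallback()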
39 changes: 33 additions & 6 deletions server/clip_server/executors/clip_torch.py
@@ -1,10 +1,10 @@
import io
import os
import numpy as np
from multiprocessing.pool import ThreadPool, Pool
from typing import Optional, List, Tuple

import torch
from PIL import Image
from jina import Executor, requests, DocumentArray
from jina.logging.logger import JinaLogger

from clip_server.model import clip

@@ -18,13 +18,32 @@ def __init__(
        num_worker_preprocess: int = 4,
        minibatch_size: int = 64,
        pool_backend: str = 'thread',
        **kwargs
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.logger = JinaLogger(self.__class__.__name__)

        import torch

        if not device:
            self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self._device = device

        if self._device != 'cuda' and (not os.environ.get('OMP_NUM_THREADS')):
            num_threads = torch.get_num_threads() // self.runtime_args.replicas
            if num_threads < 2:
                self.logger.warning(
                    f'Too many encoder replicas ({self.runtime_args.replicas})'
                )

            # NOTE: make sure to set the threads right after the torch import,
            # as `torch.set_num_threads` always takes precedence over the environment variable `OMP_NUM_THREADS`.
            # For more details, please see https://pytorch.org/docs/stable/generated/torch.set_num_threads.html
            # FIXME: This hack would harm the performance in K8S deployment.
            torch.set_num_threads(max(num_threads, 1))
            torch.set_num_interop_threads(1)

        self._minibatch_size = minibatch_size
        self._model, self._preprocess_blob, self._preprocess_tensor = clip.load(
            name, device=self._device, jit=jit
@@ -59,6 +78,8 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
        )
        _txt_da = docs.find({'text': {'$exists': True}})

        import torch

        with torch.inference_mode():
            # for image
            if _img_da:
@@ -68,7 +89,10 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
                    pool=self._pool,
                ):
                    minibatch.embeddings = (
                        self._model.encode_image(minibatch.tensors).cpu().numpy()
                        self._model.encode_image(minibatch.tensors)
                        .cpu()
                        .numpy()
                        .astype(np.float32)
                    )

            # for text
@@ -79,7 +103,10 @@ async def encode(self, docs: 'DocumentArray', **kwargs):
                    pool=self._pool,
                ):
                    minibatch.embeddings = (
                        self._model.encode_text(minibatch.tensors).cpu().numpy()
                        self._model.encode_text(minibatch.tensors)
                        .cpu()
                        .numpy()
                        .astype(np.float32)
                    )
                    minibatch.texts = _texts

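
As a point of reference, the PyTorch-side thread splitting above boils down to the following standalone sketch; the replica count is a placeholder, not part of this commit:

import os
import torch

replicas = 2  # placeholder: number of encoder replicas sharing the host CPU

# only tune threads when the user has not pinned them via OMP_NUM_THREADS
if not os.environ.get('OMP_NUM_THREADS'):
    num_threads = torch.get_num_threads() // replicas
    # give each replica its own share of intra-op threads and a single inter-op thread
    torch.set_num_threads(max(num_threads, 1))
    torch.set_num_interop_threads(1)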
70 changes: 60 additions & 10 deletions server/clip_server/model/clip_onnx.py
@@ -1,6 +1,6 @@
import onnxruntime

import os
import onnx
import onnxruntime as ort
from .clip import _download, available_models

_S3_BUCKET = 'https://clip-as-service.s3.us-east-2.amazonaws.com/models/onnx/'
@@ -17,10 +17,7 @@


class CLIPOnnxModel:
    def __init__(
        self,
        name: str = None,
    ):
    def __init__(self, name: str = None):
        if name in _MODELS:
            cache_dir = os.path.expanduser(f'~/.cache/clip/{name.replace("/", "-")}')
            self._textual_path = _download(_S3_BUCKET + _MODELS[name][0], cache_dir)
@@ -34,10 +31,11 @@ def start_sessions(
        self,
        **kwargs,
    ):
        self._visual_session = onnxruntime.InferenceSession(self._visual_path, **kwargs)
        self._textual_session = onnxruntime.InferenceSession(
            self._textual_path, **kwargs
        )
        self._visual_session = ort.InferenceSession(self._visual_path, **kwargs)
        self._visual_session.disable_fallback()

        self._textual_session = ort.InferenceSession(self._textual_path, **kwargs)
        self._textual_session.disable_fallback()

    def encode_image(self, onnx_image):
        onnx_input_image = {self._visual_session.get_inputs()[0].name: onnx_image}
@@ -48,3 +46,55 @@ def encode_text(self, onnx_text):
        onnx_input_text = {self._textual_session.get_inputs()[0].name: onnx_text}
        (textual_output,) = self._textual_session.run(None, onnx_input_text)
        return textual_output


def convert_float_to_float16(model_path: str, output_model_path: str):
    from onnxmltools.utils.float16_converter import (
        convert_float_to_float16_model_path,
    )

    new_onnx_model = convert_float_to_float16_model_path(model_path)

    onnx.save(new_onnx_model, output_model_path)

    # Alternate approach
    # from onnx import load_model
    # from onnxruntime.transformers import optimizer, onnx_model
    #
    # # optimized_model = optimizer.optimize_model(model_path, model_type='bert')
    #
    # model = load_model(model_path)
    # optimized_model = onnx_model.OnnxModel(model)
    #
    # if hasattr(optimized_model, 'convert_float32_to_float16'):
    #     optimized_model.convert_float_to_float16()
    # else:
    #     optimized_model.convert_model_float32_to_float16()
    #
    # self._textual_path = f'{self._textual_path[:-5]}_optimized.onnx'
    # optimized_model.save_model_to_file(output_model_path)


def quantize(model_path: str, output_model_path: str):
    """
    Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPUs.

    Uses unsigned ints for activation values and signed ints for weights; per
    https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
    this is faster on most CPU architectures.

    Args:
        model_path: Path to the location where the exported ONNX model is stored
        output_model_path: Path where the quantized model will be written
    """
    from onnxruntime.quantization import quantize_dynamic, QuantType

    quantize_dynamic(
        model_input=model_path,
        model_output=output_model_path,
        per_channel=True,
        reduce_range=True,  # should be the same as per_channel
        activation_type=QuantType.QUInt8,
        weight_type=QuantType.QInt8,  # per docs, signed is faster on most CPUs
        optimize_model=True,
        op_types_to_quantize=['MatMul', 'Attention', 'Mul', 'Add'],
        extra_options={'WeightSymmetric': False, 'MatMulConstBOnly': True},
    )  # op_types_to_quantize=['MatMul', 'Relu', 'Add', 'Mul'],
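
For illustration, a minimal sketch of applying the two new helpers to a cached CLIP ONNX model; the cache layout and file names below are assumptions, not part of this commit:

import os
from clip_server.model.clip_onnx import convert_float_to_float16, quantize

# placeholder: cache directory and file name used for the 'ViT-B/32' ONNX model
cache_dir = os.path.expanduser('~/.cache/clip/ViT-B-32')
visual_path = os.path.join(cache_dir, 'visual.onnx')

convert_float_to_float16(visual_path, os.path.join(cache_dir, 'visual_fp16.onnx'))
quantize(visual_path, os.path.join(cache_dir, 'visual_int8.onnx'))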