feat: add fp16 inference support (torch/onnx) (#871)
* feat: add fp16 inference in clip_torch

* Revert "feat: add fp16 inference in clip_torch"

This reverts commit 326e265.

* feat: add fp16 inference in clip_torch

* fix: device

* fix: str to torch.dtype

* fix: layernorm

* feat: add fp16 inference in clip_trt

* feat: add fp16 inference in clip_onnx

* fix: housekeeping

* fix: ci

* fix: ci

* fix: ci

* fix: ci and get test path

* fix: dtype amp and gpu test dependency

* fix: layernorm

* fix: cast dtype in visiontransformer

* fix: clip_onnx

* fix: clip_onnx

* fix: convert onnx to fp16

* fix: dtype in preproc images

* fix: dtype in preproc images

* fix: typo

* fix: dtype in clip_torch and fp16 in trt

* fix: remove plain text in trt_test

* fix: test

* fix: typo

* fix: stash

* Revert "fix: stash"

This reverts commit f72fd99.

* fix: for test

* fix: onnx

* fix: for test

* fix: for test

* fix: trt

* fix: convert onnx to fp16 before convert trt

* fix: discard changes in trt

* fix: optimize fp16 test

* fix: move __cast_dtype__

* Revert "fix: move __cast_dtype__"

This reverts commit edf4629.

* fix: ci
OrangeSodahub authored Dec 8, 2022
1 parent fd16e5a commit 1fe3a5a
Showing 9 changed files with 82 additions and 21 deletions.
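
With this change both the PyTorch and the ONNX executor accept a new `dtype` argument (`'fp32'` or `'fp16'`; the torch executor also takes a `torch.dtype`), and `None` resolves to fp32 on CPU and fp16 on CUDA. A minimal usage sketch follows; constructing the executors directly rather than inside a Jina Flow, and the model name `ViT-B-32::openai`, are illustrative assumptions, not part of this diff:

```python
# Sketch only: direct construction and the model name are illustrative assumptions.
from clip_server.executors.clip_torch import CLIPEncoder as TorchCLIPEncoder
from clip_server.executors.clip_onnx import CLIPEncoder as OnnxCLIPEncoder

# dtype=None would resolve to float32 on CPU and float16 on CUDA.
torch_encoder = TorchCLIPEncoder(name='ViT-B-32::openai', device='cuda', dtype='fp16')

# For ONNX, dtype='fp16' additionally converts the cached textual.onnx/visual.onnx
# graphs to float16 via onnxmltools before the inference sessions are created.
onnx_encoder = OnnxCLIPEncoder(name='ViT-B-32::openai', device='cuda', dtype='fp16')
```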
6 changes: 4 additions & 2 deletions .github/workflows/ci.yml
@@ -113,12 +113,11 @@ jobs:
pip install --no-cache-dir "server/[onnx]"
pip install --no-cache-dir "server/[transformers]"
pip install --no-cache-dir "server/[search]"
pip install --no-cache-dir "server/[transformers]"
- name: Test
id: test
run: |
pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \
-v -s -m "not gpu" ${{ matrix.test-path }}
-v -s ${{ matrix.test-path }}
echo "::set-output name=codecov_flag::cas"
timeout-minutes: 30
- name: Check codecov file
@@ -158,6 +157,7 @@ jobs:
python -m pip install wheel pytest pytest-cov nvidia-pyindex
pip install -e "client/[test]"
pip install -e "server/[tensorrt]"
pip install -e "server/[onnx]"
{
pip install -e "server/[flash-attn]"
} || {
@@ -168,6 +168,8 @@
run: |
pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \
-v -s -m "gpu" ./tests/test_tensorrt.py
pytest --suppress-no-test-exit-code --cov=clip_client --cov=clip_server --cov-report=xml \
-v -s -m "gpu" ./tests/test_simple.py
echo "::set-output name=codecov_flag::cas"
timeout-minutes: 30
env:
20 changes: 12 additions & 8 deletions server/clip_server/executors/clip_onnx.py
@@ -27,6 +27,7 @@ def __init__(
minibatch_size: int = 32,
access_paths: str = '@r',
model_path: Optional[str] = None,
dtype: Optional[str] = None,
**kwargs,
):
"""
@@ -41,8 +42,17 @@ def __init__(
:param model_path: The path to the model to be used. If not specified, the model will be downloaded or loaded
from the local cache. Visit https://clip-as-service.jina.ai/user-guides/server/#use-custom-model-for-onnx
to learn how to finetune custom models.
:param dtype: inference data type, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
"""
super().__init__(**kwargs)
import torch

if not device:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self._device = device
if not dtype:
dtype = 'fp32' if self._device in ('cpu', torch.device('cpu')) else 'fp16'
self._dtype = dtype

self._minibatch_size = minibatch_size
self._access_paths = access_paths
@@ -55,18 +65,11 @@ def __init__(
self._num_worker_preprocess = num_worker_preprocess
self._pool = ThreadPool(processes=num_worker_preprocess)

self._model = CLIPOnnxModel(name, model_path)
self._model = CLIPOnnxModel(name, model_path, dtype)
self._tokenizer = Tokenizer(name)

self._image_transform = clip._transform_ndarray(self._model.image_size)

import torch

if not device:
self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self._device = device

# define the priority order for the execution providers
providers = ['CPUExecutionProvider']

@@ -116,6 +119,7 @@ def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
preprocess_fn=self._image_transform,
return_np=True,
drop_image_content=drop_image_content,
dtype=self._dtype,
)

def _preproc_texts(self, docs: 'DocumentArray'):
24 changes: 19 additions & 5 deletions server/clip_server/executors/clip_torch.py
@@ -2,7 +2,7 @@
import warnings
from functools import partial
from multiprocessing.pool import ThreadPool
from typing import Dict, Optional
from typing import Dict, Union, Optional

import numpy as np
import torch
@@ -12,6 +12,7 @@
set_rank,
split_img_txt_da,
)
from clip_server.helper import __cast_dtype__
from clip_server.model import clip
from clip_server.model.clip_model import CLIPModel
from clip_server.model.tokenization import Tokenizer
@@ -28,6 +29,7 @@ def __init__(
num_worker_preprocess: int = 4,
minibatch_size: int = 32,
access_paths: str = '@r',
dtype: Optional[Union[str, torch.dtype]] = None,
**kwargs,
):
"""
@@ -40,6 +42,7 @@
number if you encounter OOM errors.
:param access_paths: The access paths to traverse on the input documents to get the images and texts to be
processed. Visit https://docarray.jina.ai/fundamentals/documentarray/access-elements for more details.
:param dtype: inference data type, if None defaults to torch.float32 if device == 'cpu' else torch.float16.
"""
super().__init__(**kwargs)

@@ -52,9 +55,17 @@ def __init__(
self._access_paths = kwargs['traversal_paths']

if not device:
self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
else:
self._device = device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self._device = device
if isinstance(dtype, str):
dtype = __cast_dtype__.get(dtype)
elif not dtype:
dtype = (
torch.float32
if self._device in ('cpu', torch.device('cpu'))
else torch.float16
)
self._dtype = dtype

if not self._device.startswith('cuda') and (
'OMP_NUM_THREADS' not in os.environ
@@ -77,7 +88,9 @@ def __init__(
self._num_worker_preprocess = num_worker_preprocess
self._pool = ThreadPool(processes=num_worker_preprocess)

self._model = CLIPModel(name, device=self._device, jit=jit, **kwargs)
self._model = CLIPModel(
name, device=self._device, jit=jit, dtype=dtype, **kwargs
)
self._tokenizer = Tokenizer(name)
self._image_transform = clip._transform_ndarray(self._model.image_size)

@@ -96,6 +109,7 @@ def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
device=self._device,
return_np=False,
drop_image_content=drop_image_content,
dtype=self._dtype,
)

def _preproc_texts(self, docs: 'DocumentArray'):
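
For reference, the resolution rule the torch executor now applies to `dtype` — a string is looked up in `__cast_dtype__`, `None` falls back by device, and a `torch.dtype` passes through — as a standalone sketch (an illustrative helper mirroring the diff above, not a clip_server API):

```python
import torch
from clip_server.helper import __cast_dtype__

def resolve_dtype(dtype, device):
    # Illustrative helper mirroring the executor logic above; not part of clip_server.
    if isinstance(dtype, str):
        return __cast_dtype__.get(dtype)
    if dtype is None:
        return torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16
    return dtype

assert resolve_dtype('fp16', 'cuda') is torch.float16
assert resolve_dtype(None, 'cpu') is torch.float32
```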
9 changes: 7 additions & 2 deletions server/clip_server/executors/helper.py
@@ -1,8 +1,9 @@
from typing import Tuple, List, Callable, Any, Dict
from typing import Tuple, List, Callable, Any, Dict, Union
import torch
import numpy as np
from docarray import Document, DocumentArray
from docarray.math.distance.numpy import cosine
from clip_server.helper import __cast_dtype__


from clip_server.model.tokenization import Tokenizer
@@ -22,8 +23,12 @@ def preproc_image(
device: str = 'cpu',
return_np: bool = False,
drop_image_content: bool = False,
dtype: Union[str, torch.dtype] = torch.float32,
) -> Tuple['DocumentArray', Dict]:

if isinstance(dtype, str):
dtype = __cast_dtype__.get(dtype)

tensors_batch = []

for d in da:
@@ -42,7 +47,7 @@
if drop_image_content:
d.pop('blob', 'tensor')

tensors_batch = torch.stack(tensors_batch).type(torch.float32)
tensors_batch = torch.stack(tensors_batch).type(dtype)

if return_np:
tensors_batch = tensors_batch.cpu().numpy()
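
To make the new `dtype` parameter of `preproc_image` concrete, here is a small sketch (not part of the diff) of the cast it now applies: a string is first resolved through `__cast_dtype__` (defined in `clip_server/helper.py`, shown in the next file), and the stacked batch is cast to that dtype instead of always float32:

```python
import torch
from clip_server.helper import __cast_dtype__

dtype = __cast_dtype__.get('fp16')                # 'fp16' -> torch.float16
tensors_batch = torch.stack([torch.rand(3, 224, 224) for _ in range(2)])
tensors_batch = tensors_batch.type(dtype)         # previously hard-coded to torch.float32
assert tensors_batch.dtype is torch.float16
```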
4 changes: 4 additions & 0 deletions server/clip_server/helper.py
@@ -2,6 +2,7 @@
import os
import sys
import threading
import torch
from packaging.version import Version
from urllib.request import Request, urlopen

@@ -19,6 +20,9 @@
)


__cast_dtype__ = {'fp16': torch.float16, 'fp32': torch.float32, 'bf16': torch.bfloat16}


def _version_check(package: str = None, github_repo: str = None):
try:

23 changes: 21 additions & 2 deletions server/clip_server/model/clip_onnx.py
@@ -1,5 +1,5 @@
import os
from typing import Dict
from typing import Dict, Optional

from clip_server.model.pretrained_models import (
download_model,
@@ -201,8 +201,11 @@


class CLIPOnnxModel(BaseCLIPModel):
def __init__(self, name: str, model_path: str = None):
def __init__(
self, name: str, model_path: str = None, dtype: Optional[str] = 'fp32'
):
super().__init__(name)
self._dtype = dtype
if name in _MODELS:
if not model_path:
cache_dir = os.path.expanduser(
@@ -237,6 +240,22 @@ def __init__(self, name: str, model_path: str = None):
f'The given model path {model_path} should be a folder containing both '
f'`textual.onnx` and `visual.onnx`.'
)
if dtype == 'fp16':
import onnx
from onnxmltools.utils import float16_converter

_textual_model_fp16 = (
float16_converter.convert_float_to_float16_model_path(
self._textual_path
)
)
_visual_model_fp16 = (
float16_converter.convert_float_to_float16_model_path(
self._visual_path
)
)
onnx.save_model(_textual_model_fp16, self._textual_path)
onnx.save_model(_visual_model_fp16, self._visual_path)
else:
raise RuntimeError(
'CLIP model {} not found or not supports ONNX backend; below is a list of all available models:\n{}'.format(
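
The fp16 conversion used above can be reproduced standalone with onnxmltools; the file name below is a placeholder for illustration, not a path shipped by this repository:

```python
import onnx
from onnxmltools.utils import float16_converter

# Convert an fp32 ONNX graph to fp16 and overwrite it in place,
# as CLIPOnnxModel does for visual.onnx and textual.onnx when dtype == 'fp16'.
model_fp16 = float16_converter.convert_float_to_float16_model_path('visual.onnx')
onnx.save_model(model_fp16, 'visual.onnx')
```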
14 changes: 12 additions & 2 deletions server/clip_server/model/model.py
@@ -15,6 +15,7 @@
from dataclasses import dataclass
from typing import Tuple, Union, Optional
from copy import deepcopy
from clip_server.helper import __cast_dtype__
from open_clip.transformer import QuickGELU, LayerNorm, LayerNormFp32, Attention
from open_clip.timm_model import TimmModel
from open_clip.factory import _MODEL_CONFIGS
@@ -81,6 +82,11 @@ def __init__(
super().__init__(image_size, patch_size, output_dim=output_dim, **kwargs)
self.transformer = Transformer(dtype=dtype, **kwargs)

def forward(self, x: torch.Tensor):
dtype = self.transformer.get_cast_dtype()
x = x.to(dtype)
return super().forward(x)


class TextTransformer(_TextTransformer):
def __init__(
@@ -435,7 +441,9 @@ def load_openai_model(
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if dtype is None:
if isinstance(dtype, str):
dtype = __cast_dtype__.get(dtype, 'amp')
elif dtype is None:
dtype = (
torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16
)
@@ -550,7 +558,9 @@ def load_openclip_model(
pretrained_image: bool = False,
dtype: Optional[Union[str, torch.dtype]] = None,
):
if dtype is None:
if isinstance(dtype, str):
dtype = __cast_dtype__.get(dtype)
elif dtype is None:
dtype = (
torch.float32 if device in ('cpu', torch.device('cpu')) else torch.float16
)
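
The new `VisionTransformer.forward` exists because preprocessed image batches arrive as float32 while the model weights may now be float16; casting the input to the transformer's cast dtype avoids a dtype mismatch in the first projection. A plain-torch illustration of the cast (not the CLIP model itself):

```python
import torch

weights_dtype = torch.float16          # what dtype='fp16' loads the model in
x = torch.rand(1, 3, 224, 224)         # preprocessed pixels default to float32
x = x.to(weights_dtype)                # mirrors `x = x.to(dtype)` in forward()
assert x.dtype is torch.float16
```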
1 change: 1 addition & 0 deletions server/setup.py
@@ -53,6 +53,7 @@
'onnx': [
'onnxruntime',
'onnx',
'onnxmltools',
]
+ (['onnxruntime-gpu>=1.8.0'] if sys.platform != 'darwin' else []),
'tensorrt': ['nvidia-tensorrt'],
2 changes: 2 additions & 0 deletions tests/test_simple.py
@@ -27,6 +27,7 @@ def test_protocols(port_generator, protocol, jit, pytestconfig):
c.profile(content=f'{pytestconfig.rootdir}/tests/img/00000.jpg')


@pytest.mark.gpu
@pytest.mark.parametrize(
'inputs',
[
@@ -48,6 +49,7 @@ def test_plain_inputs(make_flow, inputs):
)


@pytest.mark.gpu
@pytest.mark.parametrize(
'inputs',
[
