Skip to content

Commit 7eee01e

Browse files
authored
Merge 32907fd into e6a8791
2 parents e6a8791 + 32907fd commit 7eee01e

9 files changed

+289
-1
lines changed

docker/serve/Dockerfile

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
ARG PYTORCH="1.6.0"
2+
ARG CUDA="10.1"
3+
ARG CUDNN="7"
4+
FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
5+
6+
ARG MMCV="1.3.1"
7+
ARG MMSEG="0.13.0"
8+
9+
ENV PYTHONUNBUFFERED TRUE
10+
11+
RUN apt-get update && \
12+
DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
13+
ca-certificates \
14+
g++ \
15+
openjdk-11-jre-headless \
16+
# MMDet Requirements
17+
ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \
18+
&& rm -rf /var/lib/apt/lists/*
19+
20+
ENV PATH="/opt/conda/bin:$PATH"
21+
RUN export FORCE_CUDA=1
22+
23+
# TORCHSEVER
24+
RUN pip install torchserve torch-model-archiver
25+
26+
# MMLAB
27+
RUN pip install mmcv-full==${MMCV} -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
28+
RUN pip install mmsegmentation==${MMSEG}
29+
30+
RUN useradd -m model-server \
31+
&& mkdir -p /home/model-server/tmp
32+
33+
COPY entrypoint.sh /usr/local/bin/entrypoint.sh
34+
35+
RUN chmod +x /usr/local/bin/entrypoint.sh \
36+
&& chown -R model-server /home/model-server
37+
38+
COPY config.properties /home/model-server/config.properties
39+
RUN mkdir /home/model-server/model-store && chown -R model-server /home/model-server/model-store
40+
41+
EXPOSE 8080 8081 8082
42+
43+
USER model-server
44+
WORKDIR /home/model-server
45+
ENV TEMP=/home/model-server/tmp
46+
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
47+
CMD ["serve"]

docker/serve/config.properties

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
inference_address=http://0.0.0.0:8080
2+
management_address=http://0.0.0.0:8081
3+
metrics_address=http://0.0.0.0:8082
4+
model_store=/home/model-server/model-store
5+
load_models=all

docker/serve/entrypoint.sh

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#!/bin/bash
2+
set -e
3+
4+
if [[ "$1" = "serve" ]]; then
5+
shift 1
6+
torchserve --start --ts-config /home/model-server/config.properties
7+
else
8+
eval "$@"
9+
fi
10+
11+
# prevent docker exit
12+
tail -f /dev/null

docs/useful_tools.md

+61
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,64 @@ Examples:
254254
```shell
255255
python tools/analyze_logs.py log.json --keys loss --legend loss
256256
```
257+
258+
## Model Serving
259+
260+
In order to serve an `MMSegmentation` model with [`TorchServe`](https://pytorch.org/serve/), you can follow the steps:
261+
262+
### 1. Convert model from MMSegmentation to TorchServe
263+
264+
```shell
265+
python tools/mmseg2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
266+
--output-folder ${MODEL_STORE} \
267+
--model-name ${MODEL_NAME}
268+
```
269+
270+
**Note**: ${MODEL_STORE} needs to be an absolute path to a folder.
271+
272+
### 2. Build `mmseg-serve` docker image
273+
274+
```shell
275+
docker build -t mmseg-serve:latest docker/serve/
276+
```
277+
278+
### 3. Run `mmseg-serve`
279+
280+
Check the official docs for [running TorchServe with docker](https://github.com/pytorch/serve/blob/master/docker/README.md#running-torchserve-in-a-production-docker-environment).
281+
282+
In order to run in GPU, you need to install [nvidia-docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). You can omit the `--gpus` argument in order to run in CPU.
283+
284+
Example:
285+
286+
```shell
287+
docker run --rm \
288+
--cpus 8 \
289+
--gpus device=0 \
290+
-p8080:8080 -p8081:8081 -p8082:8082 \
291+
--mount type=bind,source=$MODEL_STORE,target=/home/model-server/model-store \
292+
mmseg-serve:latest
293+
```
294+
295+
[Read the docs](https://github.com/pytorch/serve/blob/072f5d088cce9bb64b2a18af065886c9b01b317b/docs/rest_api.md) about the Inference (8080), Management (8081) and Metrics (8082) APis
296+
297+
### 4. Test deployment
298+
299+
```shell
300+
curl -O https://raw.githubusercontent.com/open-mmlab/mmsegmentation/master/resources/3dogs.jpg
301+
curl http://127.0.0.1:8080/predictions/${MODEL_NAME} -T 3dogs.jpg -o 3dogs_mask.png
302+
```
303+
304+
The response will be a ".png" mask.
305+
306+
You can visualize the output as follows:
307+
308+
```python
309+
import matplotlib.pyplot as plt
310+
import mmcv
311+
plt.imshow(mmcv.imread("3dogs_mask.png", "grayscale"))
312+
plt.show()
313+
```
314+
315+
You should see something similar to:
316+
317+
![3dogs_mask](../resources/3dogs_mask.png)

resources/3dogs.jpg

181 KB
Loading

resources/3dogs_mask.png

19.2 KB
Loading

setup.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ line_length = 79
88
multi_line_output = 0
99
known_standard_library = setuptools
1010
known_first_party = mmseg
11-
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch
11+
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch,ts
1212
no_lines_before = STDLIB,LOCALFOLDER
1313
default_section = THIRDPARTY

tools/mmseg2torchserve.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from argparse import ArgumentParser, Namespace
2+
from pathlib import Path
3+
from tempfile import TemporaryDirectory
4+
5+
import mmcv
6+
7+
try:
8+
from model_archiver.model_packaging import package_model
9+
from model_archiver.model_packaging_utils import ModelExportUtils
10+
except ImportError:
11+
package_model = None
12+
13+
14+
def mmseg2torchserve(
15+
config_file: str,
16+
checkpoint_file: str,
17+
output_folder: str,
18+
model_name: str,
19+
model_version: str = '1.0',
20+
force: bool = False,
21+
):
22+
"""Converts mmsegmentation model (config + checkpoint) to TorchServe
23+
`.mar`.
24+
25+
Args:
26+
config_file:
27+
In MMSegmentation config format.
28+
The contents vary for each task repository.
29+
checkpoint_file:
30+
In MMSegmentation checkpoint format.
31+
The contents vary for each task repository.
32+
output_folder:
33+
Folder where `{model_name}.mar` will be created.
34+
The file created will be in TorchServe archive format.
35+
model_name:
36+
If not None, used for naming the `{model_name}.mar` file
37+
that will be created under `output_folder`.
38+
If None, `{Path(checkpoint_file).stem}` will be used.
39+
model_version:
40+
Model's version.
41+
force:
42+
If True, if there is an existing `{model_name}.mar`
43+
file under `output_folder` it will be overwritten.
44+
"""
45+
mmcv.mkdir_or_exist(output_folder)
46+
47+
config = mmcv.Config.fromfile(config_file)
48+
49+
with TemporaryDirectory() as tmpdir:
50+
config.dump(f'{tmpdir}/config.py')
51+
52+
args = Namespace(
53+
**{
54+
'model_file': f'{tmpdir}/config.py',
55+
'serialized_file': checkpoint_file,
56+
'handler': f'{Path(__file__).parent}/mmseg_handler.py',
57+
'model_name': model_name or Path(checkpoint_file).stem,
58+
'version': model_version,
59+
'export_path': output_folder,
60+
'force': force,
61+
'requirements_file': None,
62+
'extra_files': None,
63+
'runtime': 'python',
64+
'archive_format': 'default'
65+
})
66+
manifest = ModelExportUtils.generate_manifest_json(args)
67+
package_model(args, manifest)
68+
69+
70+
def parse_args():
71+
parser = ArgumentParser(
72+
description='Convert mmseg models to TorchServe `.mar` format.')
73+
parser.add_argument('config', type=str, help='config file path')
74+
parser.add_argument('checkpoint', type=str, help='checkpoint file path')
75+
parser.add_argument(
76+
'--output-folder',
77+
type=str,
78+
required=True,
79+
help='Folder where `{model_name}.mar` will be created.')
80+
parser.add_argument(
81+
'--model-name',
82+
type=str,
83+
default=None,
84+
help='If not None, used for naming the `{model_name}.mar`'
85+
'file that will be created under `output_folder`.'
86+
'If None, `{Path(checkpoint_file).stem}` will be used.')
87+
parser.add_argument(
88+
'--model-version',
89+
type=str,
90+
default='1.0',
91+
help='Number used for versioning.')
92+
parser.add_argument(
93+
'-f',
94+
'--force',
95+
action='store_true',
96+
help='overwrite the existing `{model_name}.mar`')
97+
args = parser.parse_args()
98+
99+
return args
100+
101+
102+
if __name__ == '__main__':
103+
args = parse_args()
104+
105+
if package_model is None:
106+
raise ImportError('`torch-model-archiver` is required.'
107+
'Try: pip install torch-model-archiver')
108+
109+
mmseg2torchserve(args.config, args.checkpoint, args.output_folder,
110+
args.model_name, args.model_version, args.force)

tools/mmseg_handler.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import base64
2+
import io
3+
import os
4+
5+
import cv2
6+
import mmcv
7+
import torch
8+
from ts.torch_handler.base_handler import BaseHandler
9+
10+
from mmseg.apis import inference_segmentor, init_segmentor
11+
12+
13+
class MMsegHandler(BaseHandler):
14+
15+
def initialize(self, context):
16+
properties = context.system_properties
17+
self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
18+
self.device = torch.device(self.map_location + ':' +
19+
str(properties.get('gpu_id')) if torch.cuda.
20+
is_available() else self.map_location)
21+
self.manifest = context.manifest
22+
23+
model_dir = properties.get('model_dir')
24+
serialized_file = self.manifest['model']['serializedFile']
25+
checkpoint = os.path.join(model_dir, serialized_file)
26+
self.config_file = os.path.join(model_dir, 'config.py')
27+
28+
self.model = init_segmentor(self.config_file, checkpoint, self.device)
29+
self.initialized = True
30+
31+
def preprocess(self, data):
32+
images = []
33+
34+
for row in data:
35+
image = row.get('data') or row.get('body')
36+
if isinstance(image, str):
37+
image = base64.b64decode(image)
38+
image = mmcv.imfrombytes(image)
39+
images.append(image)
40+
41+
return images
42+
43+
def inference(self, data, *args, **kwargs):
44+
results = [inference_segmentor(self.model, img) for img in data]
45+
return results
46+
47+
def postprocess(self, data):
48+
output = []
49+
for image_result in data:
50+
buffer = io.BytesIO()
51+
_, buffer = cv2.imencode('.png', image_result[0].astype('uint8'))
52+
output.append(buffer.tobytes())
53+
return output

0 commit comments

Comments
 (0)