
Commit 988fd33

[Fea] Support python inference (#773)
* [Doc] Add pretrained model for laplace2d & refine comments (#639)
* update laplace2d pretrained model
* remove 'after finished training' comment in evaluate function
* update README.md
* add deploy module for aneurysm
* update code
* update aneurysm code
* update code
* update code
* update code
* update aneurysm document
* update export and inference document
* fix docstring
1 parent a1ed7a3 commit 988fd33

13 files changed (+706 −226 lines)


deploy/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
deploy module is designed for inference and deployment.
"""

deploy/python_infer/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
deploy/python_infer/base.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import platform
from os import path as osp
from typing import TYPE_CHECKING
from typing import Optional
from typing import Tuple

from paddle import inference as paddle_inference
from typing_extensions import Literal

from ppsci.utils import logger

if TYPE_CHECKING:
    import onnxruntime


class Predictor:
    """
    Initializes the inference engine with the given parameters.

    Args:
        pdmodel_path (Optional[str]): Path to the PaddlePaddle model file. Defaults to None.
        pdpiparams_path (Optional[str]): Path to the PaddlePaddle model parameters file. Defaults to None.
        device (Literal["gpu", "cpu", "npu", "xpu"], optional): Device to use for inference. Defaults to "cpu".
        engine (Literal["native", "tensorrt", "onnx", "mkldnn"], optional): Inference engine to use. Defaults to "native".
        precision (Literal["fp32", "fp16", "int8"], optional): Precision to use for inference. Defaults to "fp32".
        onnx_path (Optional[str], optional): Path to the ONNX model file. Defaults to None.
        ir_optim (bool, optional): Whether to use IR optimization. Defaults to True.
        min_subgraph_size (int, optional): Minimum subgraph size for IR optimization. Defaults to 15.
        gpu_mem (int, optional): Initial size of GPU memory pool(MB). Defaults to 500(MB).
        gpu_id (int, optional): GPU ID to use. Defaults to 0.
        num_cpu_threads (int, optional): Number of CPU threads to use. Defaults to 10.
    """

    def __init__(
        self,
        pdmodel_path: Optional[str] = None,
        pdpiparams_path: Optional[str] = None,
        *,
        device: Literal["gpu", "cpu", "npu", "xpu"] = "cpu",
        engine: Literal["native", "tensorrt", "onnx", "mkldnn"] = "native",
        precision: Literal["fp32", "fp16", "int8"] = "fp32",
        onnx_path: Optional[str] = None,
        ir_optim: bool = True,
        min_subgraph_size: int = 15,
        gpu_mem: int = 500,
        gpu_id: int = 0,
        max_batch_size: int = 10,
        num_cpu_threads: int = 10,
    ):
        self.pdmodel_path = pdmodel_path
        self.pdpiparams_path = pdpiparams_path

        self._check_device(device)
        self.device = device
        self._check_engine(engine)
        self.engine = engine
        self._check_precision(precision)
        self.precision = precision

        self.onnx_path = onnx_path
        self.ir_optim = ir_optim
        self.min_subgraph_size = min_subgraph_size
        self.gpu_mem = gpu_mem
        self.gpu_id = gpu_id
        self.max_batch_size = max_batch_size
        self.num_cpu_threads = num_cpu_threads

        if self.engine == "onnx":
            self.predictor, self.config = self._create_onnx_predictor()
        else:
            self.predictor, self.config = self._create_paddle_predictor()

        logger.message(
            f"Inference with engine: {self.engine}, precision: {self.precision}, "
            f"device: {self.device}."
        )

    def predict(self, image):
        raise NotImplementedError

    def _create_paddle_predictor(
        self,
    ) -> Tuple[paddle_inference.Predictor, paddle_inference.Config]:
        if not osp.exists(self.pdmodel_path):
            raise FileNotFoundError(
                f"Given 'pdmodel_path': {self.pdmodel_path} does not exist. "
                "Please check if it is correct."
            )
        if not osp.exists(self.pdpiparams_path):
            raise FileNotFoundError(
                f"Given 'pdpiparams_path': {self.pdpiparams_path} does not exist. "
                "Please check if it is correct."
            )

        config = paddle_inference.Config(self.pdmodel_path, self.pdpiparams_path)
        if self.device == "gpu":
            config.enable_use_gpu(self.gpu_mem, self.gpu_id)
            if self.engine == "tensorrt":
                if self.precision == "fp16":
                    precision = paddle_inference.Config.Precision.Half
                elif self.precision == "int8":
                    precision = paddle_inference.Config.Precision.Int8
                else:
                    precision = paddle_inference.Config.Precision.Float32
                config.enable_tensorrt_engine(
                    workspace_size=1 << 30,
                    precision_mode=precision,
                    max_batch_size=self.max_batch_size,
                    min_subgraph_size=self.min_subgraph_size,
                    use_calib_mode=False,
                )
                # collect shape
                pdmodel_dir = osp.dirname(self.pdmodel_path)
                trt_shape_path = osp.join(pdmodel_dir, "trt_dynamic_shape.txt")

                if not osp.exists(trt_shape_path):
                    config.collect_shape_range_info(trt_shape_path)
                    logger.info(
                        f"Save collected dynamic shape info to: {trt_shape_path}"
                    )
                try:
                    config.enable_tuned_tensorrt_dynamic_shape(trt_shape_path, True)
                except Exception as e:
                    logger.warning(e)
                    logger.warning(
                        "TRT dynamic shape is disabled for your paddlepaddle < 2.3.0"
                    )

        elif self.device == "npu":
            config.enable_custom_device("npu")
        elif self.device == "xpu":
            config.enable_xpu(10 * 1024 * 1024)
        else:
            config.disable_gpu()
            if self.engine == "mkldnn":
                # 'set_mkldnn_cache_capacity' is not available on macOS
                if platform.system() != "Darwin":
                    ...
                    # cache 10 different shapes for mkldnn to avoid memory leak
                    # config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()

                if self.precision == "fp16":
                    config.enable_mkldnn_bfloat16()

        config.set_cpu_math_library_num_threads(self.num_cpu_threads)

        # enable memory optim
        config.enable_memory_optim()
        config.disable_glog_info()
        # enable zero copy
        config.switch_use_feed_fetch_ops(False)
        config.switch_ir_optim(self.ir_optim)

        predictor = paddle_inference.create_predictor(config)
        return predictor, config

    def _create_onnx_predictor(
        self,
    ) -> Tuple["onnxruntime.InferenceSession", "onnxruntime.SessionOptions"]:
        if not osp.exists(self.onnx_path):
            raise FileNotFoundError(
                f"Given 'onnx_path' {self.onnx_path} does not exist. "
                "Please check if it is correct."
            )

        try:
            import onnxruntime as ort
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "Please install onnxruntime with `pip install onnxruntime`."
            )

        # set config for onnx predictor
        config = ort.SessionOptions()
        config.intra_op_num_threads = self.num_cpu_threads
        if self.ir_optim:
            config.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        # instantiate onnx predictor
        predictor = ort.InferenceSession(self.onnx_path, sess_options=config)
        return predictor, config

    def _check_device(self, device: str):
        if device not in ["gpu", "cpu", "npu", "xpu"]:
            raise ValueError(
                "Inference only supports 'gpu', 'cpu', 'npu' and 'xpu' devices, "
                f"but got {device}."
            )

    def _check_engine(self, engine: str):
        if engine not in ["native", "tensorrt", "onnx", "mkldnn"]:
            raise ValueError(
                "Inference only supports 'native', 'tensorrt', 'onnx' and 'mkldnn' "
                f"engines, but got {engine}."
            )

    def _check_precision(self, precision: str):
        if precision not in ["fp32", "fp16", "int8"]:
            raise ValueError(
                "Inference only supports 'fp32', 'fp16' and 'int8' "
                f"precision, but got {precision}."
            )
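
A minimal usage sketch of the new Predictor base class (not part of this commit): the exported model paths below are hypothetical, and predict() is left to subclasses such as the PINN predictor added in the next file.

from deploy.python_infer import base

# Hypothetical exported model files; any matching *.pdmodel / *.pdiparams pair works.
predictor = base.Predictor(
    pdmodel_path="./inference/aneurysm.pdmodel",
    pdpiparams_path="./inference/aneurysm.pdiparams",
    device="gpu",       # one of "gpu", "cpu", "npu", "xpu"
    engine="native",    # one of "native", "tensorrt", "onnx", "mkldnn"
    precision="fp32",   # one of "fp32", "fp16", "int8"
)
# predictor.predictor holds the underlying paddle_inference.Predictor;
# Predictor.predict() itself raises NotImplementedError until a subclass overrides it.
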
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict
from typing import Union

import numpy as np
import paddle
from omegaconf import DictConfig

from deploy.python_infer import base
from ppsci.utils import logger
from ppsci.utils import misc


class PINNPredictor(base.Predictor):
    """General predictor for PINN-based models.

    Args:
        cfg (DictConfig): Running configuration.
    """

    def __init__(
        self,
        cfg: DictConfig,
    ):
        super().__init__(
            cfg.INFER.pdmodel_path,
            cfg.INFER.pdpiparams_path,
            device=cfg.INFER.device,
            engine=cfg.INFER.engine,
            precision=cfg.INFER.precision,
            onnx_path=cfg.INFER.onnx_path,
            ir_optim=cfg.INFER.ir_optim,
            min_subgraph_size=cfg.INFER.min_subgraph_size,
            gpu_mem=cfg.INFER.gpu_mem,
            gpu_id=cfg.INFER.gpu_id,
            max_batch_size=cfg.INFER.max_batch_size,
            num_cpu_threads=cfg.INFER.num_cpu_threads,
        )
        self.log_freq = cfg.log_freq

    def predict(
        self,
        input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
        batch_size: int = 64,
    ) -> Dict[str, np.ndarray]:
        """
        Predicts the output of the model for the given input.

        Args:
            input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]):
                A dictionary containing the input data.
            batch_size (int, optional): The batch size to use for prediction.
                Defaults to 64.

        Returns:
            Dict[str, np.ndarray]: A dictionary containing the predicted output.
        """
        if batch_size > self.max_batch_size:
            logger.warning(
                f"batch_size({batch_size}) is larger than "
                f"max_batch_size({self.max_batch_size}), which may cause errors."
            )

        # prepare input handle(s)
        input_handles = {
            name: self.predictor.get_input_handle(name) for name in input_dict
        }
        # prepare output handle(s)
        output_handles = {
            name: self.predictor.get_output_handle(name)
            for name in self.predictor.get_output_names()
        }

        num_samples = len(next(iter(input_dict.values())))
        batch_num = (num_samples + (batch_size - 1)) // batch_size
        pred_dict = misc.Prettydefaultdict(list)

        # inference by batch
        for batch_id in range(1, batch_num + 1):
            if batch_id % self.log_freq == 0 or batch_id == batch_num:
                logger.info(f"Predicting batch {batch_id}/{batch_num}")

            # prepare batch input dict
            st = (batch_id - 1) * batch_size
            ed = min(num_samples, batch_id * batch_size)
            batch_input_dict = {key: input_dict[key][st:ed] for key in input_dict}

            # send batch input data to input handle(s)
            for name, handle in input_handles.items():
                handle.copy_from_cpu(batch_input_dict[name])

            # run predictor
            self.predictor.run()

            # receive batch output data from output handle(s)
            batch_output_dict = {
                name: output_handles[name].copy_to_cpu() for name in output_handles
            }

            # collect batch output data
            for key, batch_output in batch_output_dict.items():
                pred_dict[key].append(batch_output)

        # concatenate local predictions
        pred_dict = {key: np.concatenate(value) for key, value in pred_dict.items()}

        return pred_dict
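
A hedged end-to-end sketch of driving the new PINNPredictor (not part of this commit): the config keys mirror the cfg.INFER fields read above, but the module name, file paths, and input names are assumptions for illustration.

import numpy as np
from omegaconf import OmegaConf

from deploy.python_infer import pinn_predictor  # module name assumed for illustration

# Assumed configuration; only the keys read by PINNPredictor/Predictor are set.
cfg = OmegaConf.create(
    {
        "log_freq": 20,
        "INFER": {
            "pdmodel_path": "./inference/aneurysm.pdmodel",  # hypothetical path
            "pdpiparams_path": "./inference/aneurysm.pdiparams",  # hypothetical path
            "device": "gpu",
            "engine": "native",
            "precision": "fp32",
            "onnx_path": None,
            "ir_optim": True,
            "min_subgraph_size": 15,
            "gpu_mem": 500,
            "gpu_id": 0,
            "max_batch_size": 1024,
            "num_cpu_threads": 10,
        },
    }
)

predictor = pinn_predictor.PINNPredictor(cfg)

# Input keys must match the exported model's input names ("x", "y", "z" assumed here).
input_dict = {
    "x": np.random.rand(4096, 1).astype("float32"),
    "y": np.random.rand(4096, 1).astype("float32"),
    "z": np.random.rand(4096, 1).astype("float32"),
}
output_dict = predictor.predict(input_dict, batch_size=1024)
# output_dict maps each exported output name to a numpy array with 4096 rows.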
