Skip to content

Commit

Permalink
[Cherry-pick] Fix Paddle-TRT UT fails (#61605)
Browse files Browse the repository at this point in the history
* Fix Paddle-TRT UT fails

* Fix typo
  • Loading branch information
leo0519 authored Feb 26, 2024
1 parent c0f4a49 commit 867ab0d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
28 changes: 15 additions & 13 deletions test/ir/inference/program_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def generate_weight():
self.outputs = outputs
self.input_type = input_type
self.no_cast_list = [] if no_cast_list is None else no_cast_list
self.supported_cast_type = [np.float32, np.float16]

def __repr__(self):
log_str = ''
Expand All @@ -292,11 +293,9 @@ def __repr__(self):
return log_str

def set_input_type(self, _type: np.dtype) -> None:
assert _type in [
np.float32,
np.float16,
None,
], "PaddleTRT only supports FP32 / FP16 IO"
assert (
_type in self.supported_cast_type or _type is None
), "PaddleTRT only supports FP32 / FP16 IO"

ver = paddle.inference.get_trt_compile_version()
trt_version = ver[0] * 1000 + ver[1] * 100 + ver[2] * 10
Expand All @@ -309,15 +308,14 @@ def set_input_type(self, _type: np.dtype) -> None:
def get_feed_data(self) -> Dict[str, Dict[str, Any]]:
feed_data = {}
for name, tensor_config in self.inputs.items():
do_casting = (
self.input_type is not None and name not in self.no_cast_list
)
data = tensor_config.data
# Cast to target input_type
data = (
tensor_config.data.astype(self.input_type)
if do_casting
else tensor_config.data
)
if (
self.input_type is not None
and name not in self.no_cast_list
and data.dtype in self.supported_cast_type
):
data = data.astype(self.input_type)
# Truncate FP32 tensors to FP16 precision for FP16 test stability
if data.dtype == np.float32 and name not in self.no_cast_list:
data = data.astype(np.float16).astype(np.float32)
Expand All @@ -334,10 +332,14 @@ def _cast(self) -> None:
for name, inp in self.inputs.items():
if name in self.no_cast_list:
continue
if inp.dtype not in self.supported_cast_type:
continue
inp.convert_type_inplace(self.input_type)
for name, weight in self.weights.items():
if name in self.no_cast_list:
continue
if weight.dtype not in self.supported_cast_type:
continue
weight.convert_type_inplace(self.input_type)
return self

Expand Down
5 changes: 2 additions & 3 deletions test/ir/inference/test_trt_convert_assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,8 @@ def clear_dynamic_shape():
self.dynamic_shape.opt_input_shape = {}

def generate_trt_nodes_num(attrs, dynamic_shape):
if not dynamic_shape and (
self.has_bool_dtype or self.dims == 1 or self.dims == 0
):
# Static shape does not support 0 or 1 dim's input
if not dynamic_shape and (self.dims == 1 or self.dims == 0):
return 0, 4
return 1, 2

Expand Down
1 change: 1 addition & 0 deletions test/ir/inference/test_trt_convert_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def generate_input(type):
)
},
outputs=["cast_output_data1"],
no_cast_list=["input_data"],
)

yield program_config
Expand Down

0 comments on commit 867ab0d

Please sign in to comment.