Skip to content

Commit

Permalink
Merge pull request #903 from wjj19950828/Add_onnx_tests
Browse files Browse the repository at this point in the history
Fixed ToPILImage && rm SymbolicShapeInference
  • Loading branch information
jiangjiajun authored Oct 10, 2022
2 parents 270c5dc + 0680f32 commit 23fa4be
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ from paddle.vision.transforms import functional as F

class ToPILImage(BaseTransform):
def __init__(self, mode=None, keys=None):
super(ToTensor, self).__init__(keys)
super(ToPILImage, self).__init__(keys)
self.mode = mode

def _apply_image(self, pic):
"""
Expand Down Expand Up @@ -53,7 +54,7 @@ class ToPILImage(BaseTransform):

npimg = pic
if isinstance(pic, paddle.Tensor) and "float" in str(pic.numpy(
).dtype) and mode != 'F':
).dtype) and self.mode != 'F':
pic = pic.mul(255).byte()
if isinstance(pic, paddle.Tensor):
npimg = np.transpose(pic.numpy(), (1, 2, 0))
Expand All @@ -74,40 +75,40 @@ class ToPILImage(BaseTransform):
expected_mode = 'I'
elif npimg.dtype == np.float32:
expected_mode = 'F'
if mode is not None and mode != expected_mode:
if self.mode is not None and self.mode != expected_mode:
raise ValueError(
"Incorrect mode ({}) supplied for input type {}. Should be {}"
.format(mode, np.dtype, expected_mode))
mode = expected_mode
.format(self.mode, np.dtype, expected_mode))
self.mode = expected_mode

elif npimg.shape[2] == 2:
permitted_2_channel_modes = ['LA']
if mode is not None and mode not in permitted_2_channel_modes:
if self.mode is not None and self.mode not in permitted_2_channel_modes:
raise ValueError("Only modes {} are supported for 2D inputs".
format(permitted_2_channel_modes))

if mode is None and npimg.dtype == np.uint8:
mode = 'LA'
if self.mode is None and npimg.dtype == np.uint8:
self.mode = 'LA'

elif npimg.shape[2] == 4:
permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX']
if mode is not None and mode not in permitted_4_channel_modes:
if self.mode is not None and self.mode not in permitted_4_channel_modes:
raise ValueError("Only modes {} are supported for 4D inputs".
format(permitted_4_channel_modes))

if mode is None and npimg.dtype == np.uint8:
mode = 'RGBA'
if self.mode is None and npimg.dtype == np.uint8:
self.mode = 'RGBA'
else:
permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV']
if mode is not None and mode not in permitted_3_channel_modes:
if self.mode is not None and self.mode not in permitted_3_channel_modes:
raise ValueError("Only modes {} are supported for 3D inputs".
format(permitted_3_channel_modes))
if mode is None and npimg.dtype == np.uint8:
mode = 'RGB'
if self.mode is None and npimg.dtype == np.uint8:
self.mode = 'RGB'

if mode is None:
if self.mode is None:
raise TypeError('Input type {} is not supported'.format(
npimg.dtype))

return Image.fromarray(npimg, mode=mode)
return Image.fromarray(npimg, mode=self.mode)
```
11 changes: 6 additions & 5 deletions x2paddle/decoder/onnx_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,15 @@ def __init__(self, onnx_model, input_shape_dict):
self.value_infos = {}
self.graph = onnx_model.graph
self.get_place_holder_nodes()
print("shape inferencing ...")
self.graph = SymbolicShapeInference.infer_shapes(
onnx_model, fixed_input_shape=self.fixed_input_shape)
if self.graph is None:
print("Shape inferencing ...")
try:
self.graph = SymbolicShapeInference.infer_shapes(
onnx_model, fixed_input_shape=self.fixed_input_shape)
except:
print('[WARNING] Shape inference by ONNX offical interface.')
onnx_model = shape_inference.infer_shapes(onnx_model)
self.graph = onnx_model.graph
print("shape inferenced.")
print("Shape inferenced.")
self.build()
self.collect_value_infos()
self.allocate_shapes()
Expand Down
8 changes: 5 additions & 3 deletions x2paddle/decoder/onnx_shape_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def _insert_ready_nodes():
if pending_nodes and self.verbose_ > 0:
print('SymbolicShapeInference: orphaned nodes discarded: ')
print(
*[n.op_type + ': ' + n.output[0] for n in pending_nodes],
* [n.op_type + ': ' + n.output[0] for n in pending_nodes],
sep='\n')

if input_shapes is not None:
Expand Down Expand Up @@ -1588,7 +1588,9 @@ def infer_shapes(in_mp,
assert version.parse(onnx.__version__) >= version.parse("1.5.0")
onnx_opset = get_opset(in_mp)
if not onnx_opset or onnx_opset < 7:
print('[WARNING] Symbolic shape inference only support models of onnx opset 7 and above.')
print(
'[WARNING] Symbolic shape inference only support models of onnx opset 7 and above.'
)
return
symbolic_shape_inference = SymbolicShapeInference(
int_max, auto_merge, guess_output_rank, verbose)
Expand All @@ -1608,4 +1610,4 @@ def infer_shapes(in_mp,
print('[WARNING] Incomplete symbolic shape inference')
symbolic_shape_inference.out_mp_ = shape_inference.infer_shapes(
symbolic_shape_inference.out_mp_)
return symbolic_shape_inference.out_mp_.graph
return symbolic_shape_inference.out_mp_.graph

0 comments on commit 23fa4be

Please sign in to comment.