【SCU】【Paddle TensorRT No.57】Add pd_op.temporal_shift converter (#69848)
The first hunk extends the converter's import list; presumably the added line is `trt_sub`, which the new converter uses:

```diff
@@ -26,6 +26,7 @@
     trt_concat,
     trt_prod,
     trt_shape,
+    trt_sub,
     trt_sum,
 )
 from paddle.tensorrt.register import converter_registry
```
The second hunk appends the new converter after `share_data_converter`:

```diff
@@ -299,3 +300,109 @@ def share_data_converter(network, paddle_op, inputs):
     identity_layer = network.add_identity(x)
 
     return identity_layer.get_output(0)
```

The added converter, shown as plain Python:

```python
@converter_registry.register("pd_op.temporal_shift", trt_version="8.x")
def temporal_shift_converter(network, paddle_op, inputs):
    input_tensor = inputs[0]
    shift_ratio = paddle_op.attrs().get("shift_ratio")
    T = paddle_op.attrs().get("seg_num")
    data_format = paddle_op.attrs().get("data_format", "NCHW")

    if data_format == "NHWC":
        # Transpose input to [N, C, H, W]
        transpose_layer = network.add_shuffle(input_tensor)
        transpose_layer.first_transpose = trt.Permutation([0, 3, 1, 2])
        input_tensor = transpose_layer.get_output(0)

    input_dims = input_tensor.shape
    C, H, W = input_dims[1], input_dims[2], input_dims[3]

    # Reshape input to [N, T, C, H, W]
    reshape_layer = network.add_shuffle(input_tensor)
    reshape_layer.reshape_dims = trt.Dims([-1, T, C, H, W])
    input_tensor = reshape_layer.get_output(0)

    # Pad input to [N, T + 2, C, H, W]
    pre_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0])
    post_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0])
    dims = 5
    zeros = add_1D_constant_layer(network, [0] * dims)
    start = trt_sum(network, zeros, pre_pad)
    total_padding = trt_sum(network, pre_pad, post_pad)
    input_shape = trt_shape(network, input_tensor)
    size = trt_sum(network, input_shape, total_padding)
    stride = [1] * dims
    dummy = stride  # placeholder start/shape; real values bound via set_input

    slice_layer = network.add_slice(input_tensor, dummy, dummy, stride)
    slice_layer.set_input(1, start)
    slice_layer.set_input(2, size)
```

> Reviewer: the `start` here should be `trt_sub`.
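Presumably the intended line, per that comment (a sketch: with `trt_sub` the slice start becomes `[0, -1, 0, 0, 0]`, which together with FILL mode pads one zero frame on each side of the T axis instead of shifting the slice window):

```python
    # Negative start + FILL mode = zero-padding before the tensor;
    # trt_sum produced [0, 1, 0, 0, 0], which starts the slice one frame late.
    start = trt_sub(network, zeros, pre_pad)
```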
```python
    trt_version = trt.__version__.split('.')
    if int(trt_version[0]) > 8 or (
        int(trt_version[0]) == 8 and int(trt_version[1]) >= 5
    ):
        # trt.SliceMode was renamed to trt.SampleMode in TensorRT 8.5
        slice_layer.mode = trt.SampleMode.FILL
    else:
        slice_layer.mode = trt.SliceMode.FILL

    slice_c = int(C * shift_ratio)
    slice_c2 = int(C * shift_ratio * 2)

    slice_start1 = zeros
    slice_start2 = add_1D_constant_layer(network, [0, 2, slice_c, 0, 0])
    slice_start3 = add_1D_constant_layer(network, [0, 1, slice_c2, 0, 0])

    slice_size_base = trt_shape(network, input_tensor)
    sub_size1 = add_1D_constant_layer(network, [0, 0, C - slice_c, 0, 0])
    sub_size2 = add_1D_constant_layer(
        network, [0, 0, C + slice_c - slice_c2, 0, 0]
    )
    sub_size3 = add_1D_constant_layer(network, [0, 0, slice_c2, 0, 0])

    slice_size1 = trt_sub(network, slice_size_base, sub_size1)
    slice_size2 = trt_sub(network, slice_size_base, sub_size2)
    slice_size3 = trt_sub(network, slice_size_base, sub_size3)

    slice1_layer = network.add_slice(
        slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
    )
    slice1_layer.set_input(1, slice_start1)
    slice1_layer.set_input(2, slice_size1)
    slice2_layer = network.add_slice(
        slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
    )
    slice2_layer.set_input(1, slice_start2)
    slice2_layer.set_input(2, slice_size2)
    slice3_layer = network.add_slice(
        slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
    )
    slice3_layer.set_input(1, slice_start3)
    slice3_layer.set_input(2, slice_size3)

    if slice_c == 0:
        concat_inputs = [slice2_layer.get_output(0), slice3_layer.get_output(0)]
```

> Reviewer: wouldn't this error if `concat_inputs` is used without having been defined earlier? Suggest defining it once up front.
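One way to read that suggestion, as a minimal sketch (hypothetical restructuring, not part of the diff):

```python
    # Define concat_inputs once; prepend slice1 only when it selects channels.
    concat_inputs = [slice2_layer.get_output(0), slice3_layer.get_output(0)]
    if slice_c != 0:
        concat_inputs.insert(0, slice1_layer.get_output(0))
    concat_layer = network.add_concatenation(concat_inputs)
    concat_layer.axis = 2  # channel axis of [N, T, C, H, W]
```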
```python
        concat_layer = network.add_concatenation(concat_inputs)
        concat_layer.axis = 2
    else:
        concat_inputs = [
            slice1_layer.get_output(0),
            slice2_layer.get_output(0),
            slice3_layer.get_output(0),
        ]
        concat_layer = network.add_concatenation(concat_inputs)
        concat_layer.axis = 2

    # Reshape output to [N*T, C, H, W]
    reshape_layer = network.add_shuffle(concat_layer.get_output(0))
```

> Reviewer: this name also collides with the `reshape_layer` above; rename it to match the old IR implementation.
```python
    reshape_layer.reshape_dims = trt.Dims(inputs[0].shape)

    if data_format == "NHWC":
        transpose_layer = network.add_shuffle(reshape_layer.get_output(0))
```

> Reviewer: please rename this one; it shares `transpose_layer` with the earlier layer, so this effectively reassigns it.
```python
        transpose_layer.first_transpose = trt.Permutation([0, 2, 3, 1])
        output_tensor = transpose_layer.get_output(0)
    else:
        output_tensor = reshape_layer.get_output(0)

    return output_tensor
```
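For orientation, the net effect of the padding plus the three slices can be sketched in NumPy, as I read the slice offsets above (a reference, not part of the diff; `temporal_shift_ref` is a hypothetical helper name):

```python
import numpy as np

def temporal_shift_ref(x, seg_num, shift_ratio):
    """Reference for NCHW input of shape [N*T, C, H, W]."""
    nt, c, h, w = x.shape
    t = seg_num
    c1 = int(c * shift_ratio)
    c2 = int(c * shift_ratio * 2)
    x = x.reshape(-1, t, c, h, w)
    out = np.zeros_like(x)
    out[:, 1:, :c1] = x[:, :-1, :c1]      # first c1 channels: previous frame
    out[:, :-1, c1:c2] = x[:, 1:, c1:c2]  # next c2-c1 channels: next frame
    out[:, :, c2:] = x[:, :, c2:]         # remaining channels: unchanged
    return out.reshape(nt, c, h, w)
```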
The test file hunk appends new cases after the existing `test_trt_result`:

```diff
@@ -394,5 +394,90 @@ def test_trt_result(self):
         self.check_trt_result()
```

The added test classes, as plain Python:

```python
class TestTemporalShiftTRTPatternBasic(TensorRTBaseTest):
    def setUp(self):
        self.python_api = paddle.nn.functional.temporal_shift
        self.api_args = {
            "x": np.random.random([4, 9, 7, 7]).astype(np.float32),
            "seg_num": 2,
            "shift_ratio": 0.2,
            "data_format": "NCHW",
        }
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [2, 9, 7, 7]}
        self.max_shape = {"x": [8, 9, 7, 7]}

    def test_trt_result(self):
        self.check_trt_result()


class TestTemporalShiftTRTPatternDifferentSegNum(TensorRTBaseTest):
    def setUp(self):
        self.python_api = paddle.nn.functional.temporal_shift
        self.api_args = {
            "x": np.random.random([4, 9, 7, 7]).astype(np.float32),
            "seg_num": 4,
            "shift_ratio": 0.2,
            "data_format": "NCHW",
        }
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [4, 9, 7, 7]}
        self.max_shape = {"x": [8, 9, 7, 7]}

    def test_trt_result(self):
        self.check_trt_result()
```

> Reviewer: please add an FP16 test here.
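A sketch of what that could look like, assuming `check_trt_result` accepts a `precision_mode` argument as sibling converter tests in this suite do (the exact flag name is an assumption):

```python
    def test_fp16_trt_result(self):
        # Hypothetical: assumes TensorRTBaseTest.check_trt_result supports
        # precision_mode; added inside TestTemporalShiftTRTPatternDifferentSegNum.
        self.check_trt_result(precision_mode="fp16")
```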
```python
class TestTemporalShiftTRTPatternDifferentShiftRatio(TensorRTBaseTest):
    def setUp(self):
        self.python_api = paddle.nn.functional.temporal_shift
        self.api_args = {
            "x": np.random.random([4, 9, 7, 7]).astype(np.float32),
            "seg_num": 2,
            "shift_ratio": 0.4,
            "data_format": "NCHW",
        }
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [2, 9, 7, 7]}
        self.max_shape = {"x": [8, 9, 7, 7]}

    def test_trt_result(self):
        self.check_trt_result()


class TestTemporalShiftTRTPatternDifferentDataFormat(TensorRTBaseTest):
    def setUp(self):
        self.python_api = paddle.nn.functional.temporal_shift
        self.api_args = {
            "x": np.random.random([4, 9, 7, 7]).astype(np.float32),
            "seg_num": 2,
            "shift_ratio": 0.2,
            "data_format": "NHWC",
        }
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [2, 9, 7, 7]}
        self.max_shape = {"x": [8, 9, 7, 7]}

    def test_trt_result(self):
        self.check_trt_result()


class TestTemporalShiftTRTPatternMinMaxShape(TensorRTBaseTest):
    def setUp(self):
        self.python_api = paddle.nn.functional.temporal_shift
        self.api_args = {
            "x": np.random.random([4, 9, 7, 7]).astype(np.float32),
            "seg_num": 2,
            "shift_ratio": 0.2,
            "data_format": "NCHW",
        }
        self.program_config = {"feed_list": ["x"]}
        self.min_shape = {"x": [2, 9, 7, 7]}
        self.max_shape = {"x": [10, 9, 7, 7]}

    def test_trt_result(self):
        self.check_trt_result()


if __name__ == '__main__':
    unittest.main()
```