【SCU】【Paddle TensorRT No.57】Add pd_op.temporal_shift converter #69848

Open · wants to merge 16 commits into base: develop
29 changes: 29 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -2094,6 +2094,34 @@ class AssignValueOpPattern
}
};

class TemporalShiftOpPattern
: public pir::OpRewritePattern<paddle::dialect::TemporalShiftOp> {
public:
using pir::OpRewritePattern<
paddle::dialect::TemporalShiftOp>::OpRewritePattern;

bool MatchAndRewrite(paddle::dialect::TemporalShiftOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
if (!op->HasAttribute("shift_ratio") || !op->HasAttribute("seg_num")) {
VLOG(3) << "temporal shift need attributes : shift_ratio and seg_num";
return false;
}
auto x = op.operand_source(0);
auto x_shape = pir::GetShapeFromValue(x);
if (x_shape.size() != 4) {
VLOG(3) << "The input and grid tensors must be shape tensors of rank 4 "
"when using TRT TemporalShift layer.";
return false;
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};

class TrtOpMarkerPass : public pir::PatternRewritePass {
public:
TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
@@ -2207,6 +2235,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<OneHotOpPattern>(context));
ps.Add(std::make_unique<AssignValueOpPattern>(context));
ps.Add(std::make_unique<AssignValue_OpPattern>(context));
ps.Add(std::make_unique<TemporalShiftOpPattern>(context));
return ps;
}
};
107 changes: 107 additions & 0 deletions python/paddle/tensorrt/impls/others.py
@@ -26,6 +26,7 @@
trt_concat,
trt_prod,
trt_shape,
trt_sub,
trt_sum,
)
from paddle.tensorrt.register import converter_registry
@@ -299,3 +300,109 @@ def share_data_converter(network, paddle_op, inputs):
identity_layer = network.add_identity(x)

return identity_layer.get_output(0)


@converter_registry.register("pd_op.temporal_shift", trt_version="8.x")
def temporal_shift_converter(network, paddle_op, inputs):
input_tensor = inputs[0]
shift_ratio = paddle_op.attrs().get("shift_ratio")
T = paddle_op.attrs().get("seg_num")
data_format = paddle_op.attrs().get("data_format", "NCHW")

if data_format == "NHWC":
# Transpose input to [N, C, H, W]
transpose_layer = network.add_shuffle(input_tensor)
transpose_layer.first_transpose = trt.Permutation([0, 3, 1, 2])
input_tensor = transpose_layer.get_output(0)

input_dims = input_tensor.shape
C, H, W = input_dims[1], input_dims[2], input_dims[3]

# Reshape input to [N, T, C, H, W]
reshape_layer = network.add_shuffle(input_tensor)
reshape_layer.reshape_dims = trt.Dims([-1, T, C, H, W])
input_tensor = reshape_layer.get_output(0)

# Pad input to [N, T + 2, C, H, W]
pre_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0])
post_pad = add_1D_constant_layer(network, [0, 1, 0, 0, 0])
dims = 5
zeros = add_1D_constant_layer(network, [0] * dims)
start = trt_sum(network, zeros, pre_pad)
Contributor: start here should be trt_sub.
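A minimal sketch of the fix this comment asks for: the slice start must be the negative of the pre-padding so that, combined with the FILL mode set below, the window reads one padded frame past each end of the time axis. trt_sub takes the same (network, lhs, rhs) arguments as trt_sum, as its later uses in this file show.

    # start = 0 - pre_pad, i.e. [0, -1, 0, 0, 0]: begin one padded frame early.
    start = trt_sub(network, zeros, pre_pad)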

total_padding = trt_sum(network, pre_pad, post_pad)
input_shape = trt_shape(network, input_tensor)
size = trt_sum(network, input_shape, total_padding)
stride = [1] * dims
dummy = stride

slice_layer = network.add_slice(input_tensor, dummy, dummy, stride)
slice_layer.set_input(1, start)
slice_layer.set_input(2, size)

trt_version = trt.__version__.split('.')
if int(trt_version[0]) > 8 or (
int(trt_version[0]) == 8 and int(trt_version[1]) >= 5
):
slice_layer.mode = trt.SampleMode.FILL
else:
slice_layer.mode = trt.SliceMode.FILL

slice_c = int(C * shift_ratio)
slice_c2 = int(C * shift_ratio * 2)

slice_start1 = zeros
slice_start2 = add_1D_constant_layer(network, [0, 2, slice_c, 0, 0])
slice_start3 = add_1D_constant_layer(network, [0, 1, slice_c2, 0, 0])

slice_size_base = trt_shape(network, input_tensor)
sub_size1 = add_1D_constant_layer(network, [0, 0, C - slice_c, 0, 0])
sub_size2 = add_1D_constant_layer(
network, [0, 0, C + slice_c - slice_c2, 0, 0]
)
sub_size3 = add_1D_constant_layer(network, [0, 0, slice_c2, 0, 0])

slice_size1 = trt_sub(network, slice_size_base, sub_size1)
slice_size2 = trt_sub(network, slice_size_base, sub_size2)
slice_size3 = trt_sub(network, slice_size_base, sub_size3)

slice1_layer = network.add_slice(
slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
)
slice1_layer.set_input(1, slice_start1)
slice1_layer.set_input(2, slice_size1)
slice2_layer = network.add_slice(
slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
)
slice2_layer.set_input(1, slice_start2)
slice2_layer.set_input(2, slice_size2)
slice3_layer = network.add_slice(
slice_layer.get_output(0), start=dummy, shape=dummy, stride=stride
)
slice3_layer.set_input(1, slice_start3)
slice3_layer.set_input(2, slice_size3)

if slice_c == 0:
concat_inputs = [slice2_layer.get_output(0), slice3_layer.get_output(0)]
Contributor: if concat_inputs is not defined before this point, this will raise an error; better to define one up front.
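A sketch of that suggestion: build the list unconditionally before the branch, so every path that reaches add_concatenation sees a defined concat_inputs. Only the list construction changes; the slice layers above stay as they are.

    # Define concat_inputs up front so it exists on every code path.
    concat_inputs = [slice2_layer.get_output(0), slice3_layer.get_output(0)]
    if slice_c != 0:
        concat_inputs.insert(0, slice1_layer.get_output(0))
    concat_layer = network.add_concatenation(concat_inputs)
    concat_layer.axis = 2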

concat_layer = network.add_concatenation(concat_inputs)
concat_layer.axis = 2
else:
concat_inputs = [
slice1_layer.get_output(0),
slice2_layer.get_output(0),
slice3_layer.get_output(0),
]
concat_layer = network.add_concatenation(concat_inputs)
concat_layer.axis = 2

# Reshape output to [N*T,C,H,W]
reshape_layer = network.add_shuffle(concat_layer.get_output(0))
Contributor: this name collides with the one above as well; rename it to match the old IR converter.

reshape_layer.reshape_dims = trt.Dims(inputs[0].shape)

if data_format == "NHWC":
transpose_layer = network.add_shuffle(reshape_layer.get_output(0))
Contributor: give this a different name; it collides with the earlier transpose_layer, which amounts to reassigning it.

transpose_layer.first_transpose = trt.Permutation([0, 2, 3, 1])
output_tensor = transpose_layer.get_output(0)
else:
output_tensor = reshape_layer.get_output(0)

return output_tensor
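For context (not part of the diff): the converter above has to reproduce the TSM-style shift that paddle.nn.functional.temporal_shift defines, where the first shift_ratio fraction of channels is shifted one segment back in time, the next shift_ratio fraction one segment forward, and the rest is copied through. A plain NumPy sketch of those semantics for NCHW input, useful for checking the slice starts and sizes:

    import numpy as np

    def temporal_shift_ref(x, seg_num, shift_ratio):
        # x: [N*T, C, H, W] -> expose the time axis as [N, T, C, H, W].
        nt, c, h, w = x.shape
        x = x.reshape(-1, seg_num, c, h, w)
        c1 = int(c * shift_ratio)
        c2 = int(c * shift_ratio * 2)
        out = np.zeros_like(x)
        out[:, :-1, :c1] = x[:, 1:, :c1]        # first chunk: shift back in time
        out[:, 1:, c1:c2] = x[:, :-1, c1:c2]    # second chunk: shift forward
        out[:, :, c2:] = x[:, :, c2:]           # remaining channels unchanged
        return out.reshape(nt, c, h, w)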
85 changes: 85 additions & 0 deletions test/tensorrt/test_converter_others.py
@@ -394,5 +394,90 @@ def test_trt_result(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternBasic(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.2,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [8, 9, 7, 7]}

def test_trt_result(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternDifferentSegNum(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 4,
"shift_ratio": 0.2,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [4, 9, 7, 7]}
self.max_shape = {"x": [8, 9, 7, 7]}

def test_trt_result(self):
Contributor: add an fp16 test for this one.

self.check_trt_result()
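A sketch of the fp16 test the comment asks for, assuming check_trt_result accepts a precision_mode argument (hypothetical here; adjust to whatever fp16 switch TensorRTBaseTest actually exposes):

    def test_trt_result_fp16(self):
        # Assumes the harness exposes precision_mode; rename if the
        # actual fp16 switch differs.
        self.check_trt_result(precision_mode="fp16")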


class TestTemporalShiftTRTPatternDifferentShiftRatio(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.4,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [8, 9, 7, 7]}

def test_trt_result(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternDifferentDataFormat(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.2,
"data_format": "NHWC",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [8, 9, 7, 7]}

def test_trt_result(self):
self.check_trt_result()


class TestTemporalShiftTRTPatternMinMaxShape(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.nn.functional.temporal_shift
self.api_args = {
"x": np.random.random([4, 9, 7, 7]).astype(np.float32),
"seg_num": 2,
"shift_ratio": 0.2,
"data_format": "NCHW",
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [2, 9, 7, 7]}
self.max_shape = {"x": [10, 9, 7, 7]}

def test_trt_result(self):
self.check_trt_result()


if __name__ == '__main__':
unittest.main()