diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index a631332dae360..da2a4330ba0df 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -34,29 +34,59 @@ namespace tensorrt { class FcOpConverter : public OpConverter { public: nvinfer1::ILayer* reshape_before_fc(nvinfer1::ITensor* before_fc, - nvinfer1::Dims x_dim, int x_num_col_dims, + nvinfer1::Dims x_dim, + int x_num_col_dims, std::string output_name) { // add shuffle before fc nvinfer1::Dims reshape_before_fc_dim; reshape_before_fc_dim.nbDims = x_num_col_dims + 3; // padding shape "* x q x 1 x 1" - for (int i = 0; i < reshape_before_fc_dim.nbDims; i++) { - reshape_before_fc_dim.d[i] = 1; - } - for (int i = 0; i < x_dim.nbDims; i++) { - if (i < x_num_col_dims) { - reshape_before_fc_dim.d[i] = 0; - } else { - if (x_dim.d[i] < 0) { - reshape_before_fc_dim.d[x_num_col_dims] = -1; - break; + + nvinfer1::ITensor* filal_reshape_before_fc_shape_tensor = nullptr; + + if (!engine_->with_dynamic_shape()) { + for (int i = 0; i < reshape_before_fc_dim.nbDims; i++) { + reshape_before_fc_dim.d[i] = 1; + } + for (int i = 0; i < x_dim.nbDims; i++) { + if (i < x_num_col_dims) { + reshape_before_fc_dim.d[i] = 0; + } else { + reshape_before_fc_dim.d[x_num_col_dims] *= x_dim.d[i]; + } + } + } else { + std::vector reshape_before_fc_shape_tensor; + nvinfer1::ITensor* input_shape_tensor = Shape(before_fc); + + for (int i = 0; i < reshape_before_fc_dim.nbDims; i++) { + reshape_before_fc_shape_tensor.push_back(Add1DConstantLayer(1)); + } + for (int i = 0; i < x_dim.nbDims; i++) { + if (i < x_num_col_dims) { + reshape_before_fc_shape_tensor[i] = + GetEleTensorOfShape(input_shape_tensor, i); + } else { + reshape_before_fc_shape_tensor[x_num_col_dims] = + Prod(GetEleTensorOfShape(input_shape_tensor, i), + reshape_before_fc_shape_tensor[x_num_col_dims]); + // If not set, test_trt_matmul_quant_dequant in trt 6015 will fail + reshape_before_fc_shape_tensor[x_num_col_dims]->setType( + nvinfer1::DataType::kINT32); } - reshape_before_fc_dim.d[x_num_col_dims] *= x_dim.d[i]; } + filal_reshape_before_fc_shape_tensor = + Concat(reshape_before_fc_shape_tensor); } + auto* reshape_before_fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *before_fc); - reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + if (!engine_->with_dynamic_shape()) { + reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + } else { + reshape_before_fc_layer->setInput(1, + *filal_reshape_before_fc_shape_tensor); + } reshape_before_fc_layer->setName( ("fc_op_reshape_before_fc: Shuffle (Output: " + output_name + ")") .c_str()); @@ -64,21 +94,37 @@ class FcOpConverter : public OpConverter { } nvinfer1::ILayer* reshape_after_fc(nvinfer1::ITensor* after_fc, - nvinfer1::Dims x_dim, int x_num_col_dims) { + nvinfer1::Dims x_dim, + int x_num_col_dims) { // add shuffle after fc nvinfer1::Dims reshape_after_fc_dim; reshape_after_fc_dim.nbDims = x_num_col_dims + 1; - for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) { - reshape_after_fc_dim.d[i] = 0; + + nvinfer1::ITensor* filal_reshape_after_fc_shape_tensor = nullptr; + if (!engine_->with_dynamic_shape()) { + for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) { + reshape_after_fc_dim.d[i] = 0; + } + } else { + std::vector gather_indices(x_num_col_dims + 1); + std::iota(gather_indices.begin(), gather_indices.end(), 0); + filal_reshape_after_fc_shape_tensor = + Gather(Shape(after_fc), gather_indices); } + auto* reshape_after_fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *after_fc); - reshape_after_fc_layer->setReshapeDimensions(reshape_after_fc_dim); + if (!engine_->with_dynamic_shape()) { + reshape_after_fc_layer->setReshapeDimensions(reshape_after_fc_dim); + } else { + reshape_after_fc_layer->setInput(1, *filal_reshape_after_fc_shape_tensor); + } return reshape_after_fc_layer; } void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope, bool test_mode) override { + const framework::Scope& scope, + bool test_mode) override { VLOG(3) << "convert a fluid fc op to tensorrt fc layer without bias"; framework::OpDesc op_desc(op, nullptr); auto output_name = op_desc.Output("Out").front(); @@ -96,8 +142,9 @@ class FcOpConverter : public OpConverter { // Declare weights auto* Y_v = scope.FindVar(op_desc.Input(w_name).front()); PADDLE_ENFORCE_NOT_NULL( - Y_v, platform::errors::NotFound( - "Can not find %s presistale var of fc in scope.", w_name)); + Y_v, + platform::errors::NotFound( + "Can not find %s presistale var of fc in scope.", w_name)); auto* Y_t = Y_v->GetMutable(); int x_num_col_dims = op_desc.HasAttr("x_num_col_dims") @@ -128,7 +175,8 @@ class FcOpConverter : public OpConverter { } weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t); - PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL, + PADDLE_ENFORCE_EQ(Y_t->dims().size(), + 2UL, platform::errors::InvalidArgument( "The fc's weight should be a matrix with 2 dims, but " "it's %d-dimensional.", @@ -143,7 +191,8 @@ class FcOpConverter : public OpConverter { } }; - auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output, + auto regist_fc = [&](nvinfer1::ITensor* inputs, + int n_output, TensorRTEngine::Weight& weight, TensorRTEngine::Weight& bias) { if (enable_int8 || support_int8) { @@ -151,7 +200,8 @@ class FcOpConverter : public OpConverter { float out_scale = 0; if (enable_int8) { PADDLE_ENFORCE_EQ( - op_desc.HasAttr("out_threshold"), true, + op_desc.HasAttr("out_threshold"), + true, platform::errors::InvalidArgument( "must have out threshold in fc layers in int8 mode")); out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); @@ -159,9 +209,13 @@ class FcOpConverter : public OpConverter { out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out")); } nvinfer1::DimsHW nv_ksize(1, 1); - auto* fc_layer_int8 = - TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output, - nv_ksize, weight.get(), bias.get()); + auto* fc_layer_int8 = TRT_ENGINE_ADD_LAYER(engine_, + Convolution, + *inputs, + n_output, + nv_ksize, + weight.get(), + bias.get()); fc_layer_int8->setName( ("fc_op_int8_conv1x1: Convolution (Output: " + output_name + ")") .c_str()); @@ -174,21 +228,29 @@ class FcOpConverter : public OpConverter { .c_str()); engine_->SetTensorDynamicRange(fc_after_reshape_int8->getOutput(0), out_scale); - nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER( - engine_, Activation, *(fc_after_reshape_int8->getOutput(0)), - nvinfer1::ActivationType::kRELU); - RreplenishLayerAndOutput(relu_layer_int8, "relu_after_fc_shuffle", - {output_name}, test_mode); + nvinfer1::IActivationLayer* relu_layer_int8 = + TRT_ENGINE_ADD_LAYER(engine_, + Activation, + *(fc_after_reshape_int8->getOutput(0)), + nvinfer1::ActivationType::kRELU); + RreplenishLayerAndOutput(relu_layer_int8, + "relu_after_fc_shuffle", + {output_name}, + test_mode); } else { RreplenishLayerAndOutput(fc_after_reshape_int8, "fc_op_int8_reshape_after_fc: Shuffle", - {output_name}, test_mode); + {output_name}, + test_mode); } } else { // add fc layer - auto* fc_layer_float = - TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *inputs, n_output, - weight.get(), bias.get()); + auto* fc_layer_float = TRT_ENGINE_ADD_LAYER(engine_, + FullyConnected, + *inputs, + n_output, + weight.get(), + bias.get()); fc_layer_float->setName( ("fc_op_float: FullyConnected (Output: " + output_name + ")") .c_str()); @@ -198,14 +260,20 @@ class FcOpConverter : public OpConverter { fc_after_reshape_float->setName( ("float_reshape_after_fc: Shuffle (Output: " + output_name + ")") .c_str()); - nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER( - engine_, Activation, *(fc_after_reshape_float->getOutput(0)), - nvinfer1::ActivationType::kRELU); - RreplenishLayerAndOutput(relu_layer_float, "relu_after_fc_shuffle", - {output_name}, test_mode); + nvinfer1::IActivationLayer* relu_layer_float = + TRT_ENGINE_ADD_LAYER(engine_, + Activation, + *(fc_after_reshape_float->getOutput(0)), + nvinfer1::ActivationType::kRELU); + RreplenishLayerAndOutput(relu_layer_float, + "relu_after_fc_shuffle", + {output_name}, + test_mode); } else { - RreplenishLayerAndOutput(fc_after_reshape_float, "shuffle_after_fc", - {output_name}, test_mode); + RreplenishLayerAndOutput(fc_after_reshape_float, + "shuffle_after_fc", + {output_name}, + test_mode); } } }; @@ -255,15 +323,20 @@ class FcOpConverter : public OpConverter { if (enable_int8 || support_int8) { // add conv1x1 layer nvinfer1::DimsHW nv_ksize(1, 1); - auto* fc_layer_int8 = - TRT_ENGINE_ADD_LAYER(engine_, Convolution, *X, n_output, nv_ksize, - weight.get(), bias.get()); + auto* fc_layer_int8 = TRT_ENGINE_ADD_LAYER(engine_, + Convolution, + *X, + n_output, + nv_ksize, + weight.get(), + bias.get()); if (activation_type == "relu") { fc_layer_int8->setName( ("ernie_fc_op_int8: Convolution (Output: " + output_name + ")") .c_str()); PADDLE_ENFORCE_EQ( - op_desc.HasAttr("out_threshold"), true, + op_desc.HasAttr("out_threshold"), + true, platform::errors::InvalidArgument( "must have out threshold in fc layers in int8 mode")); float out_scale = 0; @@ -275,15 +348,20 @@ class FcOpConverter : public OpConverter { } engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), out_scale); - nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER( - engine_, Activation, *(fc_layer_int8->getOutput(0)), - nvinfer1::ActivationType::kRELU); - RreplenishLayerAndOutput(relu_layer_int8, "relu_after_ernie_fc_int8", - {output_name}, test_mode); + nvinfer1::IActivationLayer* relu_layer_int8 = + TRT_ENGINE_ADD_LAYER(engine_, + Activation, + *(fc_layer_int8->getOutput(0)), + nvinfer1::ActivationType::kRELU); + RreplenishLayerAndOutput(relu_layer_int8, + "relu_after_ernie_fc_int8", + {output_name}, + test_mode); } else { RreplenishLayerAndOutput(fc_layer_int8, "ernie_fc_op_int8: Convolution", - {output_name}, test_mode); + {output_name}, + test_mode); } } else { // add fc layer @@ -292,25 +370,30 @@ class FcOpConverter : public OpConverter { if (activation_type == "relu") { fc_layer_float->setName( ("ernie_fc_op_float: (Output: " + output_name + ")").c_str()); - nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER( - engine_, Activation, *(fc_layer_float->getOutput(0)), - nvinfer1::ActivationType::kRELU); + nvinfer1::IActivationLayer* relu_layer_float = + TRT_ENGINE_ADD_LAYER(engine_, + Activation, + *(fc_layer_float->getOutput(0)), + nvinfer1::ActivationType::kRELU); RreplenishLayerAndOutput(relu_layer_float, - "relu_after_ernie_fc_float", {output_name}, + "relu_after_ernie_fc_float", + {output_name}, test_mode); } else { - RreplenishLayerAndOutput(fc_layer_float, "ernie_fc_op_float", - {output_name}, test_mode); + RreplenishLayerAndOutput( + fc_layer_float, "ernie_fc_op_float", {output_name}, test_mode); } } } else { // need reshape input before and after fc PADDLE_ENFORCE_GT( - x_dim.nbDims, x_num_col_dims, + x_dim.nbDims, + x_num_col_dims, platform::errors::InvalidArgument( "Params and input dims mismatch. Paddle-TRT FC " "converter expects x_dim.nbDims > x_num_col_dims, but " "x_dim.nbDims : %d, x_num_col_dims : %d.", - x_dim.nbDims, x_num_col_dims)); + x_dim.nbDims, + x_num_col_dims)); auto* reshape_before_fc_layer = reshape_before_fc(X, x_dim, x_num_col_dims, output_name); auto* reshape_itensor = reshape_before_fc_layer->getOutput(0); diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index 21c79f0edd27f..f1d7eae826031 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -22,7 +22,8 @@ namespace tensorrt { class MultiheadMatMulOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope, bool test_mode) override { + const framework::Scope& scope, + bool test_mode) override { VLOG(3) << "convert a fluid multihead_mamul op to a corresponding tensorrt " "network structure"; framework::OpDesc op_desc(op, nullptr); @@ -52,8 +53,8 @@ class MultiheadMatMulOpConverter : public OpConverter { float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t); std::vector weight_data_tmp; weight_data_tmp.reserve(weight_t->numel()); - memcpy(weight_data_tmp.data(), weight_data, - weight_t->numel() * sizeof(float)); + memcpy( + weight_data_tmp.data(), weight_data, weight_t->numel() * sizeof(float)); // (hidden_in, 3, hidden_out) auto weight_dims = weight_t->dims(); @@ -98,14 +99,15 @@ class MultiheadMatMulOpConverter : public OpConverter { nvinfer1::ILayer* fc_layer = nullptr; float dp_probs = 1.0 / 127.0; nvinfer1::DimsHW nv_ksize(1, 1); - fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, - nv_ksize, weight, bias); + fc_layer = TRT_ENGINE_ADD_LAYER( + engine_, Convolution, *input, n, nv_ksize, weight, bias); fc_layer->setName( ("Multihead: Convolution/FullyConnected: (Output: " + output_name + ")") .c_str()); PADDLE_ENFORCE_EQ( - op_desc.HasAttr("fc_out_threshold"), true, + op_desc.HasAttr("fc_out_threshold"), + true, platform::errors::InvalidArgument( "must have out_threshold in multihead layers in int8 mode")); float out_scale = @@ -119,13 +121,19 @@ class MultiheadMatMulOpConverter : public OpConverter { "CustomQKVToContextPluginDynamic", "3"); assert(creator != nullptr); std::vector fields{ - {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, + {"hidden_size", + &hidden_out, + nvinfer1::PluginFieldType::kINT32, 1}, - {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, + {"num_heads", + &head_number, + nvinfer1::PluginFieldType::kINT32, 1}}; if (qkv2context_plugin_int8) { - fields.push_back({"dq_probs", &dp_probs, - nvinfer1::PluginFieldType::kFLOAT32, 1}); + fields.push_back({"dq_probs", + &dp_probs, + nvinfer1::PluginFieldType::kFLOAT32, + 1}); } nvinfer1::PluginFieldCollection* plugin_collection = static_cast(malloc( @@ -154,7 +162,8 @@ class MultiheadMatMulOpConverter : public OpConverter { engine_->GetITensor(engine_->network()->getInput(3)->getName()); engine_->SetTensorDynamicRange(max_seqlen_tensor, 1.0f); auto* shuffle_layer = TRT_ENGINE_ADD_LAYER( - engine_, Shuffle, + engine_, + Shuffle, *const_cast(max_seqlen_tensor)); nvinfer1::Dims shape_dim; shape_dim.nbDims = 1; @@ -173,8 +182,11 @@ class MultiheadMatMulOpConverter : public OpConverter { // [3, head_number, head_size, hidden_in] -> [head_number, 3, // head_size, // hidden_in] - auto transpose_weight_v2 = [](const float* src, float* dst, int three, - int head_number, int head_size, + auto transpose_weight_v2 = [](const float* src, + float* dst, + int three, + int head_number, + int head_size, int hidden_in) { const int HH = head_size * hidden_in; for (int i = 0; i < three; ++i) { @@ -187,41 +199,47 @@ class MultiheadMatMulOpConverter : public OpConverter { } }; // [3, head_number, head_size] -> [head_number, 3, head_size] - auto transpose_bias_v2 = [](const float* src, float* dst, int N, - int H) { - for (int i = 0; i < 3; ++i) { - for (int n = 0; n < N; ++n) { - for (int h = 0; h < H; ++h) { - dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h]; + auto transpose_bias_v2 = + [](const float* src, float* dst, int N, int H) { + for (int i = 0; i < 3; ++i) { + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H; ++h) { + dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h]; + } + } } - } - } - }; - memcpy(weight_data_tmp.data(), weight_data, + }; + memcpy(weight_data_tmp.data(), + weight_data, weight_t->numel() * sizeof(float)); - transpose_weight_v2(weight_data_tmp.data(), weight_data, three, - head_number, head_size, hidden_in); + transpose_weight_v2(weight_data_tmp.data(), + weight_data, + three, + head_number, + head_size, + hidden_in); std::vector bias_data_tmp; bias_data_tmp.reserve(bias_t->numel()); - memcpy(bias_data_tmp.data(), bias_data, - bias_t->numel() * sizeof(float)); - transpose_bias_v2(bias_data_tmp.data(), bias_data, head_number, - head_size); + memcpy( + bias_data_tmp.data(), bias_data, bias_t->numel() * sizeof(float)); + transpose_bias_v2( + bias_data_tmp.data(), bias_data, head_number, head_size); nvinfer1::ILayer* fc_layer = nullptr; float dp_probs = 1.0 / 127.0; if (op_desc.HasAttr("Input_scale")) { nvinfer1::DimsHW nv_ksize(1, 1); - fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, - nv_ksize, weight, bias); + fc_layer = TRT_ENGINE_ADD_LAYER( + engine_, Convolution, *input, n, nv_ksize, weight, bias); } else { - fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n, - weight, bias); + fc_layer = TRT_ENGINE_ADD_LAYER( + engine_, FullyConnected, *input, n, weight, bias); } if (op_desc.HasAttr("fc_out_threshold")) { - PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), true, + PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), + true, platform::errors::InvalidArgument( "must have out threshold in multihead layers " "in int8 mode")); @@ -248,15 +266,21 @@ class MultiheadMatMulOpConverter : public OpConverter { int var_seqlen = 1; std::vector fields{ {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, - {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, + {"hidden_size", + &hidden_out, + nvinfer1::PluginFieldType::kINT32, 1}, {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1}, - {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, + {"var_seqlen", + &var_seqlen, + nvinfer1::PluginFieldType::kINT32, 1}}; if (qkv2context_plugin_int8) { - fields.push_back({"dq_probs", &dp_probs, - nvinfer1::PluginFieldType::kFLOAT32, 1}); + fields.push_back({"dq_probs", + &dp_probs, + nvinfer1::PluginFieldType::kFLOAT32, + 1}); } nvinfer1::PluginFieldCollection* plugin_collection = static_cast(malloc( @@ -285,7 +309,8 @@ class MultiheadMatMulOpConverter : public OpConverter { auto max_seqlen_tensor = engine_->GetITensor(engine_->network()->getInput(3)->getName()); auto* shuffle_layer = TRT_ENGINE_ADD_LAYER( - engine_, Shuffle, + engine_, + Shuffle, *const_cast(max_seqlen_tensor)); nvinfer1::Dims shape_dim; shape_dim.nbDims = 1; @@ -301,7 +326,8 @@ class MultiheadMatMulOpConverter : public OpConverter { } } else { PADDLE_ENFORCE_EQ( - input->getDimensions().nbDims, 3, + input->getDimensions().nbDims, + 3, platform::errors::InvalidArgument( "The Input dim of the MultiheadMatMul should be 3, " "but it's (%d) now.", @@ -320,20 +346,25 @@ class MultiheadMatMulOpConverter : public OpConverter { static_cast(bias_t->numel())}; // add shuffle before fc - nvinfer1::Dims reshape_before_fc_dim; - reshape_before_fc_dim.nbDims = 5; - reshape_before_fc_dim.d[0] = 0; - reshape_before_fc_dim.d[1] = 0; - reshape_before_fc_dim.d[2] = 0; - reshape_before_fc_dim.d[3] = 1; - reshape_before_fc_dim.d[4] = 1; + std::vector reshape_before_fc_shape_tensor; + nvinfer1::ITensor* input_shape_tensor = Shape(input); + + for (int i = 0; i < 5; i++) { + reshape_before_fc_shape_tensor.push_back(Add1DConstantLayer(1)); + } + for (int i = 0; i < 3; i++) { + reshape_before_fc_shape_tensor[i] = + GetEleTensorOfShape(input_shape_tensor, i); + } + auto* reshape_before_fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); if (op_desc.HasAttr("Input_scale")) { engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0), in_scale); } - reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + reshape_before_fc_layer->setInput( + 1, *Concat(reshape_before_fc_shape_tensor)); reshape_before_fc_layer->setName( ("shuffle_before_multihead_mamul(Output: " + output_name + ")") .c_str()); @@ -342,18 +373,28 @@ class MultiheadMatMulOpConverter : public OpConverter { nvinfer1::ILayer* fc_layer = nullptr; if (op_desc.HasAttr("Input_scale")) { nvinfer1::DimsHW nv_ksize(1, 1); - fc_layer = TRT_ENGINE_ADD_LAYER( - engine_, Convolution, *reshape_before_fc_layer->getOutput(0), n, - nv_ksize, weight.get(), bias.get()); + fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, + Convolution, + *reshape_before_fc_layer->getOutput(0), + n, + nv_ksize, + weight.get(), + bias.get()); } else { - fc_layer = TRT_ENGINE_ADD_LAYER( - engine_, FullyConnected, *reshape_before_fc_layer->getOutput(0), - n, weight.get(), bias.get()); + fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, + FullyConnected, + *reshape_before_fc_layer->getOutput(0), + n, + weight.get(), + bias.get()); } if (op_desc.HasAttr("fc_out_threshold")) { PADDLE_ENFORCE_EQ( - op_desc.HasAttr("fc_out_threshold"), true, + op_desc.HasAttr("fc_out_threshold"), + true, platform::errors::InvalidArgument( "must have out threshold in multihead layers in int8 mode")); float out_scale = @@ -380,8 +421,8 @@ class MultiheadMatMulOpConverter : public OpConverter { with_fp16 = true; } plugin::DynamicPluginTensorRT* plugin = - new plugin::QkvToContextPluginDynamic(hidden_in, head_number, - head_size, scale, with_fp16); + new plugin::QkvToContextPluginDynamic( + hidden_in, head_number, head_size, scale, with_fp16); layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin); } } else { @@ -391,8 +432,8 @@ class MultiheadMatMulOpConverter : public OpConverter { "You can use the config.SetTRTDynamicShapeInfo(...) interface to set " "the shape information to run the dynamic shape mode.")); } - RreplenishLayerAndOutput(layer, "multihead_matmul", {output_name}, - test_mode); + RreplenishLayerAndOutput( + layer, "multihead_matmul", {output_name}, test_mode); } }; diff --git a/paddle/fluid/inference/tensorrt/convert/slice_op.cc b/paddle/fluid/inference/tensorrt/convert/slice_op.cc index dea9a1ec3d76d..4f85e4f07cc4e 100644 --- a/paddle/fluid/inference/tensorrt/convert/slice_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/slice_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -14,7 +11,6 @@ limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h" -#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h" namespace paddle { namespace inference { @@ -23,7 +19,8 @@ namespace tensorrt { class SliceOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope, bool test_mode) override { + const framework::Scope& scope, + bool test_mode) override { // This OP is implemented by trt dynamic shpae plugin. // Dynamic shape plugin requires TRT version greater than 6.0. VLOG(4) << "convert slice op to tensorrt layer"; @@ -64,63 +61,141 @@ class SliceOpConverter : public OpConverter { } ends[i] = std::min(ends[i], input_dims.d[axes[i]]); PADDLE_ENFORCE_GT( - ends[i], starts[i], + ends[i], + starts[i], platform::errors::InvalidArgument( "Attr(ends) should be greater than attr(starts) in " "slice op. But received ends = %d, starts = %d.", - ends[i], starts[i])); + ends[i], + starts[i])); } } nvinfer1::ILayer* layer = nullptr; if (engine_->with_dynamic_shape()) { - if (engine_->use_oss() && engine_->with_ernie() && - input_dims.nbDims == 4) { - std::vector plugin_inputs; - if (engine_->with_interleaved()) { - auto* shuffler_slice = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); - nvinfer1::Permutation transpose_embed{2, 1, 0, 3}; - shuffler_slice->setSecondTranspose(transpose_embed); - engine_->SetTensorDynamicRange(shuffler_slice->getOutput(0), - out_scale); - shuffler_slice->setName( - ("SpecialSlice_interleaved: transpose: (Output: " + output_name + - ")") - .c_str()); - plugin_inputs.emplace_back(shuffler_slice->getOutput(0)); +#if IS_TRT_VERSION_GE(6000) + auto nchw_input_dims = input->getDimensions(); + nvinfer1::Dims trt_start_dims; + trt_start_dims.nbDims = nchw_input_dims.nbDims; + memset(trt_start_dims.d, 0, sizeof(int32_t) * nchw_input_dims.nbDims); + nvinfer1::Dims trt_size_dims = trt_start_dims; + nvinfer1::Dims trt_end_dims = trt_start_dims; + nvinfer1::Dims trt_step_dims = trt_start_dims; + for (int i = 0; i < trt_step_dims.nbDims; i++) trt_step_dims.d[i] = 1; + + // input : [N,C,H,W] + bool has_neg_indices = false; + for (size_t i = 0; i < axes.size(); i++) { + int trt_axis = axes[i]; + trt_start_dims.d[trt_axis] = starts[i]; + trt_end_dims.d[trt_axis] = ends[i]; + if (starts[i] < 0 || ends[i] < 0) has_neg_indices = true; + } + auto* shape_tensor = Shape(input); + auto* start_tensor = Add1DConstantLayer(trt_start_dims); + if (has_neg_indices) { + start_tensor = FixNegIndices(shape_tensor, start_tensor); + } + + std::vector end_vec_tensor; + for (int i = 0; i < trt_end_dims.nbDims; i++) { + end_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, i)); + } + + for (size_t i = 0; i < axes.size(); i++) { + int trt_axis = axes[i]; + if (ends[i] >= 0) { + end_vec_tensor[trt_axis] = Add1DConstantLayer(ends[i]); } else { - plugin_inputs.emplace_back(input); + end_vec_tensor[trt_axis] = + Sum(end_vec_tensor[trt_axis], Add1DConstantLayer(ends[i])); } - std::string pos_name; - if (engine_->Has("ernie_pos_name")) { - pos_name = engine_->Get("ernie_pos_name"); - } else { - // hard code for compatibility - pos_name = engine_->network()->getInput(2)->getName(); + } + +// CI failed in trt 6015 but success in 7134, may be a trt bug +#if IS_TRT_VERSION_GE(7134) + auto* size_tensor = + Sub(Min(Concat(end_vec_tensor), shape_tensor), start_tensor); +#else + auto* size_tensor = Sub(Concat(end_vec_tensor), start_tensor); +#endif + + layer = TRT_ENGINE_ADD_LAYER( + engine_, Slice, *input, trt_start_dims, trt_size_dims, trt_step_dims); + layer->setInput(1, *start_tensor); + layer->setInput(2, *size_tensor); + + if (decrease_axises.size() > 0) { + std::vector gather_indices; + for (int i = 0; i < trt_size_dims.nbDims; i++) { + if (decrease_axises.end() != + std::find(decrease_axises.begin(), decrease_axises.end(), i)) + continue; + gather_indices.push_back(i); } - plugin_inputs.emplace_back( - engine_->GetITensor(pos_name)); // cu_seqlens, eval_placeholder_2 - - // bool ban_fp16 = engine_->disable_trt_plugin_fp16(); - plugin::SpecialSlicePluginDynamic* plugin = - new plugin::SpecialSlicePluginDynamic(); - layer = engine_->AddDynamicPlugin(plugin_inputs.data(), - plugin_inputs.size(), plugin); - } else { - bool with_fp16 = - engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); - int decrease_axis = - decrease_axises.size() == 0 ? -1 : decrease_axises[0]; - plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic( - starts, ends, axes, decrease_axis, with_fp16); - layer = engine_->AddDynamicPlugin(&input, 1, plugin); + if (gather_indices.empty()) + gather_indices.push_back(decrease_axises[0]); + auto real_size_tensor = Gather(size_tensor, gather_indices); + layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0)); + layer->setInput(1, *real_size_tensor); } +#else + bool with_fp16 = + engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + int decrease_axis = decrease_axises.size() == 0 ? -1 : decrease_axises[0]; + plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic( + starts, ends, axes, decrease_axis, with_fp16); + layer = engine_->AddDynamicPlugin(&input, 1, plugin); +#endif } else { +#if IS_TRT_VERSION_GE(6000) + auto chw_input_dims = input->getDimensions(); + nvinfer1::Dims trt_start_dims; + trt_start_dims.nbDims = chw_input_dims.nbDims; + memset(trt_start_dims.d, 0, sizeof(int32_t) * chw_input_dims.nbDims); + nvinfer1::Dims trt_size_dims = chw_input_dims; + nvinfer1::Dims trt_step_dims; + trt_step_dims.nbDims = chw_input_dims.nbDims; + for (int i = 0; i < trt_step_dims.nbDims; i++) trt_step_dims.d[i] = 1; + + // input : [C,H,W] + for (size_t i = 0; i < axes.size(); i++) { + int trt_axis = axes[i] - 1; + trt_start_dims.d[trt_axis] = starts[i]; + trt_size_dims.d[trt_axis] = ends[i] - starts[i]; + } + layer = TRT_ENGINE_ADD_LAYER( + engine_, Slice, *input, trt_start_dims, trt_size_dims, trt_step_dims); + nvinfer1::Dims real_trt_size_dims; + real_trt_size_dims.nbDims = 0; + + if (decrease_axises.size() > 0) { + for (size_t i = 0; i < decrease_axises.size(); i++) { + decrease_axises[i]--; + } + for (int i = 0; i < trt_size_dims.nbDims; i++) { + if (decrease_axises.end() != + std::find(decrease_axises.begin(), decrease_axises.end(), i)) + continue; + real_trt_size_dims.d[real_trt_size_dims.nbDims] = trt_size_dims.d[i]; + real_trt_size_dims.nbDims++; + } + if (real_trt_size_dims.nbDims == 0) { + real_trt_size_dims.nbDims = 1; + real_trt_size_dims.d[0] = 1; + } + auto reshape_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0)); + reshape_layer->setReshapeDimensions(real_trt_size_dims); + layer = static_cast(reshape_layer); + } +#else bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); plugin::SlicePlugin* plugin = new plugin::SlicePlugin(starts, ends, axes, with_fp16); layer = engine_->AddPlugin(&input, 1, plugin); +#endif } RreplenishLayerAndOutput(layer, "slice", {output_name}, test_mode); } diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 33386c746ae5a..3a9504d9c67d9 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -48,7 +48,8 @@ void TensorRTEngine::InitNetwork() { optim_profiles_[i] = infer_builder_->createOptimizationProfile(); } -void TensorRTEngine::Execute(int batch_size, std::vector *buffers, +void TensorRTEngine::Execute(int batch_size, + std::vector *buffers, cudaStream_t stream) { freshDeviceId(); auto infer_context = context(); @@ -126,14 +127,32 @@ void TensorRTEngine::FreezeNetwork() { } #if IS_TRT_VERSION_GE(5122) - auto is_layer_int8 = [&](nvinfer1::ILayer *layer) -> bool { + auto layer_int8_fallback = [&](nvinfer1::ILayer *layer) -> bool { + if (layer->getType() == nvinfer1::LayerType::kSHAPE) { + return false; + } + bool all_int = true; + for (int j = 0; j < layer->getNbInputs(); j++) { + auto *temp_in = layer->getInput(j); + if (temp_in->getType() != nvinfer1::DataType::kINT32) { + all_int = false; + } + } + for (int j = 0; j < layer->getNbOutputs(); j++) { + auto *temp_out = layer->getOutput(j); + if (temp_out->getType() != nvinfer1::DataType::kINT32) { + all_int = false; + } + } + if (all_int) return false; + for (int j = 0; j < layer->getNbInputs(); j++) { auto *temp_in = layer->getInput(j); if (!temp_in->dynamicRangeIsSet()) { VLOG(1) << "Layer(Name: " << layer->getName() << ") is set to float32 because its input(" << temp_in->getName() << ") doesn't have dynamic range."; - return false; + return true; } } for (int j = 0; j < layer->getNbOutputs(); j++) { @@ -142,10 +161,10 @@ void TensorRTEngine::FreezeNetwork() { VLOG(1) << "Layer(Name: " << layer->getName() << ") is set to float32 because its output(" << temp_out->getName() << ") doesn't have dynamic range."; - return false; + return true; } } - return true; + return false; }; // If a layer's output is the network's output, or not all of its inputs // and outputs have scales, @@ -154,7 +173,7 @@ void TensorRTEngine::FreezeNetwork() { int layers_no_int8 = 0; for (int i = 0; i < network()->getNbLayers(); i++) { auto layer = network()->getLayer(i); - if (!is_layer_int8(layer)) { + if (layer_int8_fallback(layer)) { layer->setPrecision(nvinfer1::DataType::kFLOAT); ++layers_no_int8; } @@ -205,7 +224,8 @@ void TensorRTEngine::FreezeNetwork() { for (auto &input : min_input_shape_) { #if IS_TRT_VERSION_LT(7000) // trt6 will check all_of input > 0 - if (!(std::all_of(input.second.begin(), input.second.end(), + if (!(std::all_of(input.second.begin(), + input.second.end(), [](int x) { return x > 0; }) && std::all_of(max_input_shape_[input.first].begin(), max_input_shape_[input.first].end(), @@ -222,13 +242,16 @@ void TensorRTEngine::FreezeNetwork() { << ", opt: " << Vec2Str(optim_input_shape_[input.first]); optim_profiles_[i]->setDimensions( - input.first.c_str(), nvinfer1::OptProfileSelector::kMIN, + input.first.c_str(), + nvinfer1::OptProfileSelector::kMIN, Vec2TRT_Dims(input.second, input.first, true)); optim_profiles_[i]->setDimensions( - input.first.c_str(), nvinfer1::OptProfileSelector::kMAX, + input.first.c_str(), + nvinfer1::OptProfileSelector::kMAX, Vec2TRT_Dims(max_input_shape_[input.first], input.first, true)); optim_profiles_[i]->setDimensions( - input.first.c_str(), nvinfer1::OptProfileSelector::kOPT, + input.first.c_str(), + nvinfer1::OptProfileSelector::kOPT, Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true)); } infer_builder_config_->addOptimizationProfile(optim_profiles_[i]); @@ -262,9 +285,10 @@ void TensorRTEngine::FreezeNetwork() { #endif PADDLE_ENFORCE_NOT_NULL( - infer_engine_, platform::errors::Fatal( - "Build TensorRT cuda engine failed! Please recheck " - "you configurations related to paddle-TensorRT.")); + infer_engine_, + platform::errors::Fatal( + "Build TensorRT cuda engine failed! Please recheck " + "you configurations related to paddle-TensorRT.")); binding_num_ = infer_engine_->getNbBindings(); // reset status for dynamic shape clone @@ -279,16 +303,19 @@ void TensorRTEngine::FreezeNetwork() { nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name, nvinfer1::DataType dtype, const nvinfer1::Dims &dims) { - PADDLE_ENFORCE_EQ(network() != nullptr, true, + PADDLE_ENFORCE_EQ(network() != nullptr, + true, platform::errors::InvalidArgument( "The TRT network should be initialized first.")); auto *input = network()->addInput(name.c_str(), dtype, dims); PADDLE_ENFORCE_NOT_NULL( - input, platform::errors::InvalidArgument("Adding input %s failed in " - "TensorRT inference network. " - "Please recheck your input.", - name)); - PADDLE_ENFORCE_EQ(input->isNetworkInput(), true, + input, + platform::errors::InvalidArgument("Adding input %s failed in " + "TensorRT inference network. " + "Please recheck your input.", + name)); + PADDLE_ENFORCE_EQ(input->isNetworkInput(), + true, platform::errors::InvalidArgument( "Input %s is not the input of TRT inference network. " "Please recheck your input.", @@ -297,22 +324,26 @@ nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name, return input; } -void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset, +void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, + int offset, const std::string &name) { auto *output = layer->getOutput(offset); SetITensor(name, output); PADDLE_ENFORCE_NOT_NULL( - output, platform::errors::InvalidArgument( - "The output %s of TRT engine should not be null.", name)); + output, + platform::errors::InvalidArgument( + "The output %s of TRT engine should not be null.", name)); output->setName(name.c_str()); - PADDLE_ENFORCE_EQ(output->isNetworkInput(), false, + PADDLE_ENFORCE_EQ(output->isNetworkInput(), + false, platform::errors::InvalidArgument( "The output %s of TRT engine should not be the input " "of the network at the same time.", name)); network()->markOutput(*output); PADDLE_ENFORCE_EQ( - output->isNetworkOutput(), true, + output->isNetworkOutput(), + true, platform::errors::InvalidArgument( "The output %s of TRT engine should be the output of the network.", name)); @@ -321,10 +352,12 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset, void TensorRTEngine::DeclareOutput(const std::string &name) { auto *output = TensorRTEngine::GetITensor(name); PADDLE_ENFORCE_NOT_NULL( - output, platform::errors::InvalidArgument( - "The output %s of TRT engine should not be null.", name)); + output, + platform::errors::InvalidArgument( + "The output %s of TRT engine should not be null.", name)); output->setName(name.c_str()); - PADDLE_ENFORCE_EQ(output->isNetworkInput(), false, + PADDLE_ENFORCE_EQ(output->isNetworkInput(), + false, platform::errors::InvalidArgument( "The output %s of TRT engine should not be the input " "of the network at the same time.", @@ -335,17 +368,20 @@ void TensorRTEngine::DeclareOutput(const std::string &name) { void TensorRTEngine::SetITensor(const std::string &name, nvinfer1::ITensor *tensor) { PADDLE_ENFORCE_NOT_NULL( - tensor, platform::errors::InvalidArgument( - "Tensor named %s of TRT engine should not be null.", name)); + tensor, + platform::errors::InvalidArgument( + "Tensor named %s of TRT engine should not be null.", name)); PADDLE_ENFORCE_EQ( - 0, itensor_map_.count(name), + 0, + itensor_map_.count(name), platform::errors::InvalidArgument( "Tensor named %s of TRT engine should not be duplicated", name)); itensor_map_[name] = tensor; } nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) { - PADDLE_ENFORCE_EQ(itensor_map_.count(name), true, + PADDLE_ENFORCE_EQ(itensor_map_.count(name), + true, platform::errors::NotFound( "Tensor named %s is not found in TRT engine", name)); return itensor_map_[name]; @@ -362,15 +398,16 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name, std::string splitter = "__"; std::string name_with_suffix = name + splitter + name_suffix; platform::CPUPlace cpu_place; - PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix), 0, + PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix), + 0, platform::errors::AlreadyExists( "The weight named %s is set into the weight map " "twice in TRT OP converter.", name_with_suffix)); weight_map[name_with_suffix].reset(new framework::Tensor()); weight_map[name_with_suffix]->Resize(weight_tensor->dims()); - paddle::framework::TensorCopySync(*weight_tensor, cpu_place, - weight_map[name_with_suffix].get()); + paddle::framework::TensorCopySync( + *weight_tensor, cpu_place, weight_map[name_with_suffix].get()); float *weight_data = weight_map[name_with_suffix]->mutable_data(cpu_place); name_suffix_counter += 1; @@ -380,21 +417,24 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name, int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; } nvinfer1::IPluginV2Layer *TensorRTEngine::AddPlugin( - nvinfer1::ITensor *const *inputs, int num_inputs, + nvinfer1::ITensor *const *inputs, + int num_inputs, plugin::PluginTensorRT *plugin) { owned_plugin_.emplace_back(plugin); return network()->addPluginV2(inputs, num_inputs, *plugin); } nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2Ext( - nvinfer1::ITensor *const *inputs, int num_inputs, + nvinfer1::ITensor *const *inputs, + int num_inputs, plugin::PluginTensorRTV2Ext *plugin) { owned_plugin_v2ext_.emplace_back(plugin); return network()->addPluginV2(inputs, num_inputs, *plugin); } nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2IOExt( - nvinfer1::ITensor *const *inputs, int num_inputs, + nvinfer1::ITensor *const *inputs, + int num_inputs, nvinfer1::IPluginV2IOExt *plugin) { owned_plugin_v2ioext_.emplace_back(plugin); return network()->addPluginV2(inputs, num_inputs, *plugin); @@ -403,10 +443,12 @@ nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2IOExt( void TensorRTEngine::freshDeviceId() { int count; cudaGetDeviceCount(&count); - PADDLE_ENFORCE_LT(device_id_, count, + PADDLE_ENFORCE_LT(device_id_, + count, platform::errors::OutOfRange( "Device id %d exceeds the current device count: %d.", - device_id_, count)); + device_id_, + count)); platform::SetDeviceId(device_id_); } diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 22147d6dc6352..d545d6f0e67e2 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1063,14 +1063,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, if (desc.HasAttr("decrease_axis")) { std::vector decrease_axis = BOOST_GET_CONST(std::vector, desc.GetAttr("decrease_axis")); - if (with_dynamic_shape) { - if (decrease_axis.size() > 1) { - return false; - } - } else { - if (decrease_axis.size() > 0) { - VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0" - "is not supported in TensorRT"; + if (!with_dynamic_shape) { + if (decrease_axis.end() != + std::find(decrease_axis.begin(), decrease_axis.end(), 0)) { return false; } } @@ -1102,15 +1097,28 @@ bool OpTeller::Tell(const framework::ir::Node* node, return false; } } - } else { - for (size_t i = 0; i < axes.size(); i++) { - if (starts[i] < 0 || ends[i] < 0) { - VLOG(3) << "Invalid slice attribute 'starts' or 'ends'. " - "Negative starts or ends not supported in TensorRT " - "when running in dynamic shape mode."; - return false; - } - } + } + } + // not support following four inputs for slice in paddle-trt + auto slice_inputs = desc.Inputs(); // its size == 5 + if (slice_inputs.find("StartsTensor") != slice_inputs.end()) { + if (desc.Input("StartsTensor").size()) { + return false; + } + } + if (slice_inputs.find("EndsTensor") != slice_inputs.end()) { + if (desc.Input("EndsTensor").size()) { + return false; + } + } + if (slice_inputs.find("StartsTensorList") != slice_inputs.end()) { + if (desc.Input("StartsTensorList").size()) { + return false; + } + } + if (slice_inputs.find("EndsTensorList") != slice_inputs.end()) { + if (desc.Input("EndsTensorList").size()) { + return false; } } } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py index 86c52dad23af0..deac7ef9d2a14 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py @@ -1,11 +1,11 @@ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -22,17 +22,14 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: inputs = program_config.inputs weights = program_config.weights attrs = [ - program_config.ops[i].attrs - for i in range(len(program_config.ops)) + program_config.ops[i].attrs for i in range(len(program_config.ops)) ] - - for x in attrs[0]["decrease_axis"]: - if x < 0: - return False + out_shape = list(inputs['input_data'].shape) for x in range(len(attrs[0]["axes"])): start = 0 end = 0 @@ -42,24 +39,30 @@ def is_program_valid(self, program_config: ProgramConfig) -> bool: else: start = attrs[0]["starts"][x] if attrs[0]["ends"][x] < 0: - end = attrs[0]["ends"][x] + inputs['input_data'].shape[attrs[0][ - "axes"][x]] + end = attrs[0]["ends"][x] + inputs['input_data'].shape[ + attrs[0]["axes"][x]] else: end = attrs[0]["ends"][x] start = max(0, start) end = max(0, end) + out_shape[attrs[0]["axes"][x]] = end - start if start >= end: return False - + for x in attrs[0]["decrease_axis"]: + if x < 0: + return False + if (out_shape[x] != 1): + return False return True def sample_program_configs(self): + def generate_input1(attrs: List[Dict[str, Any]]): - return np.ones([6, 6, 64, 64]).astype(np.float32) + return np.random.random([6, 6, 64, 64]).astype(np.float32) for axes in [[0, 1], [1, 3], [2, 3]]: for starts in [[0, 1]]: - for ends in [[2, 2], [5, 5]]: + for ends in [[2, 2], [5, 5], [1, -1]]: for decrease_axis in [[], [1], [2], [-1], [-100]]: for infer_flags in [[-1]]: dics = [{ @@ -86,8 +89,9 @@ def generate_input1(attrs: List[Dict[str, Any]]): ops=ops, weights={}, inputs={ - "input_data": TensorConfig(data_gen=partial( - generate_input1, dics)) + "input_data": + TensorConfig( + data_gen=partial(generate_input1, dics)) }, outputs=["slice_output_data"]) @@ -95,6 +99,7 @@ def generate_input1(attrs: List[Dict[str, Any]]): def sample_predictor_configs( self, program_config) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]} self.dynamic_shape.max_input_shape = {"input_data": [8, 8, 64, 64]} @@ -106,17 +111,6 @@ def clear_dynamic_shape(): self.dynamic_shape.opt_input_shape = {} def generate_trt_nodes_num(attrs, dynamic_shape): - inputs = program_config.inputs - if dynamic_shape == True and len(attrs[0]["decrease_axis"]) == 0: - return 1, 2 - if dynamic_shape == True and len(attrs[0]["decrease_axis"]) != 1: - return 0, 3 - if dynamic_shape == False and len(attrs[0]["decrease_axis"]) != 0: - return 0, 3 - if dynamic_shape: - for i in range(len(attrs[0]["starts"])): - if attrs[0]["starts"][i] < 0 or attrs[0]["ends"][i] < 0: - return 0, 3 if not dynamic_shape: for x in attrs[0]["axes"]: if x == 0: @@ -124,8 +118,7 @@ def generate_trt_nodes_num(attrs, dynamic_shape): return 1, 2 attrs = [ - program_config.ops[i].attrs - for i in range(len(program_config.ops)) + program_config.ops[i].attrs for i in range(len(program_config.ops)) ] self.trt_param.max_batch_size = 9 # for static_shape @@ -140,11 +133,11 @@ def generate_trt_nodes_num(attrs, dynamic_shape): # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), generate_trt_nodes_num(attrs, - True), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), generate_trt_nodes_num(attrs, - True), 1e-4 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-4 def test(self): # TODO(inference): fix.