Skip to content

Commit

Permalink
fix trt convert conv2d skip (#38999)
Browse files Browse the repository at this point in the history
* fix trt convert conv2d skip

* fix trt convert conv2d skip
  • Loading branch information
JZZ-NOTE authored Jan 18, 2022
1 parent 27f8460 commit dfa242e
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 95 deletions.
9 changes: 8 additions & 1 deletion paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,17 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("dilations"));
const std::vector<int> strides =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("strides"));
const std::vector<int> paddings =
std::vector<int> paddings =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("paddings"));
std::string padding_algorithm = "EXPLICIT";
if (op_desc.HasAttr("padding_algorithm"))
padding_algorithm =
BOOST_GET_CONST(std::string, op_desc.GetAttr("padding_algorithm"));
if (padding_algorithm == "VALID") {
for (size_t i = 0; i < paddings.size(); i++) {
paddings[i] = 0;
}
}

nvinfer1::DimsHW nv_ksize(filter_h, filter_w);
nvinfer1::DimsHW nv_dilations(dilations[0], dilations[1]);
Expand Down Expand Up @@ -139,6 +144,8 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
layer->setNbGroups(groups);
if (padding_algorithm == "SAME") {
layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
nv_dilations.d[0] = 1;
nv_dilations.d[1] = 1;
}
// set dilations
fset_dilation(layer, nv_dilations);
Expand Down
30 changes: 0 additions & 30 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,36 +271,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
}

if (desc.HasAttr("padding_algorithm")) {
auto padding_algorithm =
BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
if (padding_algorithm == "VALID") {
return false;
}
if (padding_algorithm == "SAME") {
if (desc.HasAttr("dilations")) {
const std::vector<int> dilations =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("dilations"));
if (dilations[0] != 1 || dilations[1] != 1) {
VLOG(3) << "In Same mode, Dilations must be (1, 1) for "
"tensorRT, but given ("
<< dilations[0] << ", " << dilations[1] << ")";
return false;
}
}
}
}

if (use_no_calib_int8) {
if (desc.HasAttr("padding_algorithm")) {
auto padding_algorithm =
BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
if (padding_algorithm == "SAME") {
return false;
}
}
}

if (desc.HasAttr("enable_int8")) {
if (op_type == "conv2d" || op_type == "conv2d_fusion") {
if (!desc.HasAttr("Input_scale")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def is_program_valid(self, program_config: ProgramConfig) -> bool:
1] * attrs[0]['groups']:
return False

ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:
if attrs[0]['padding_algorithm'] == 'SAME' and (
attrs[0]['strides'][0] > 1 or attrs[0]['strides'][1] > 1):
return False

return True

def sample_program_configs(self):
Expand Down Expand Up @@ -68,39 +74,27 @@ def generate_weight1(attrs: List[Dict[str, Any]]):
"data_format": data_format
}, {}]

if padding_algorithm == 'EXPLICIT':
ops_config = [{
"op_type": "conv2d",
"op_inputs": {
"Input": ["input_data"],
"Filter": ["conv2d_weight"]
},
"op_outputs": {
"Output": ["conv_output_data"]
},
"op_attrs": dics[0]
}, {
"op_type": "relu",
"op_inputs": {
"X": ["conv_output_data"]
},
"op_outputs": {
"Out": ["output_data"]
},
"op_attrs": dics[1]
}]
else:
ops_config = [{
"op_type": "conv2d",
"op_inputs": {
"Input": ["input_data"],
"Filter": ["conv2d_weight"]
},
"op_outputs": {
"Output": ["output_data"]
},
"op_attrs": dics[0]
}]
ops_config = [{
"op_type": "conv2d",
"op_inputs": {
"Input": ["input_data"],
"Filter": ["conv2d_weight"]
},
"op_outputs": {
"Output": ["conv_output_data"]
},
"op_attrs": dics[0]
}, {
"op_type": "relu",
"op_inputs": {
"X": ["conv_output_data"]
},
"op_outputs": {
"Out": ["output_data"]
},
"op_attrs": dics[1]
}]

ops = self.generate_op_config(ops_config)

program_config = ProgramConfig(
Expand Down Expand Up @@ -188,7 +182,6 @@ def generate_trt_nodes_num(attrs, dynamic_shape):
attrs, False), (1e-5, 1e-5)

# for dynamic_shape

generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
Expand All @@ -200,25 +193,10 @@ def generate_trt_nodes_num(attrs, dynamic_shape):
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), (1e-5, 1e-5)

def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs[
'padding_algorithm'] == "SAME" or program_config.ops[
0].attrs['padding_algorithm'] == "VALID":
return True
return False

self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"When padding_algorithm is 'SAME' or 'VALID', Trt dose not support. In this case, trt build error is caused by scale op."
)

def test(self):
self.add_skip_trt_case()
self.run_test()

def test_quant(self):
self.add_skip_trt_case()
self.run_test(quant=True)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def is_program_valid(self, program_config: ProgramConfig) -> bool:
if attrs[0]['groups'] <= 1:
return False

ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:
if attrs[0]['padding_algorithm'] == 'SAME' and (
attrs[0]['strides'][0] > 1 or attrs[0]['strides'][1] > 1):
return False

return True

def sample_program_configs(self):
Expand Down Expand Up @@ -184,25 +190,10 @@ def generate_trt_nodes_num(attrs, dynamic_shape):
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), (1e-5, 1e-5)

def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if program_config.ops[0].attrs[
'padding_algorithm'] == "SAME" or program_config.ops[
0].attrs['padding_algorithm'] == "VALID":
return True
return False

self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"When padding_algorithm is 'SAME' or 'VALID', Trt dose not support. In this case, trt build error is caused by scale op."
)

def test(self):
self.add_skip_trt_case()
self.run_test()

def test_quant(self):
self.add_skip_trt_case()
self.run_test(quant=True)


Expand Down

0 comments on commit dfa242e

Please sign in to comment.