Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanlehome committed May 11, 2023
1 parent 638660e commit a38f81c
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 41 deletions.
6 changes: 0 additions & 6 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1862,12 +1862,6 @@ struct SimpleOpTypeSetTeller : public Teller {
"with static shape.";
return false;
}

#if IS_TRT_VERSION_LT(8000)
if (with_dynamic_shape && x_shape.size() == 0) {
return false; // not supported 0 dim.
}
#endif
}

if (op_type == "mish") {
Expand Down
120 changes: 85 additions & 35 deletions test/ir/inference/test_trt_convert_prelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,20 @@ def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True

def sample_program_configs(self):
def generate_input(attrs: List[Dict[str, Any]], batch, dims):
if dims == 0:
def generate_input(attrs: List[Dict[str, Any]], batch):
if self.dims == 0:
return np.random.random([]).astype(np.float32)
elif dims == 1:
elif self.dims == 1:
return np.random.random([16]).astype(np.float32)
elif dims == 2:
return np.random.random([batch, 3]).astype(np.float32)
elif dims == 3:
elif self.dims == 2:
return np.random.random([1, 3]).astype(np.float32)
elif self.dims == 3:
if attrs[0]["data_format"] == "NCHW":
return np.random.random([batch, 3, 16]).astype(np.float32)
else:
elif attrs[0]["data_format"] == "NHWC":
return np.random.random([batch, 16, 3]).astype(np.float32)
else:
raise AssertionError()
else:
if attrs[0]["data_format"] == "NCHW":
return np.random.random([batch, 3, 16, 32]).astype(
Expand All @@ -50,27 +52,43 @@ def generate_input(attrs: List[Dict[str, Any]], batch, dims):
np.float32
)

def generate_alpha(attrs: List[Dict[str, Any]], dims):
if dims == 0:
def generate_alpha(attrs: List[Dict[str, Any]]):
if self.dims == 0:
return np.random.random([]).astype(np.float32)
if attrs[0]["mode"] == "all":
return np.random.random([1]).astype(np.float32)
elif attrs[0]["mode"] == "channel":
return np.random.random([3]).astype(np.float32)
elif attrs[0]["mode"] == "element":
if dims == 1:
if self.dims == 1:
return np.random.random([16]).astype(np.float32)
elif dims == 2:
return np.random.random([3, 16]).astype(np.float32)
elif dims == 3:
return np.random.random([3, 16, 32]).astype(np.float32)
elif self.dims == 2:
return np.random.random([1, 3]).astype(np.float32)
elif self.dims == 3:
if attrs[0]["data_format"] == "NCHW":
return np.random.random([1, 3, 16]).astype(np.float32)
elif attrs[0]["data_format"] == "NHWC":
return np.random.random([1, 16, 3]).astype(np.float32)
else:
raise AssertionError()
else:
return np.random.random([1, 3, 16, 32]).astype(np.float32)
if attrs[0]["data_format"] == "NCHW":
return np.random.random([1, 3, 16, 32]).astype(
np.float32
)
elif attrs[0]["data_format"] == "NHWC":
return np.random.random([1, 16, 32, 3]).astype(
np.float32
)
else:
raise AssertionError()

for batch in [1, 4]:
for dims in [0, 1, 2, 3, 4]:
for mode in ["channel"]:
for data_format in ["NHWC"]:
for mode in ["all", "element", "channel"]:
for data_format in ["NCHW", "NHWC"]:
if (mode == "element" or mode == "all") and dims == 0:
continue
if mode == "channel" and dims != 4:
continue
self.dims = dims
Expand All @@ -92,13 +110,13 @@ def generate_alpha(attrs: List[Dict[str, Any]], dims):
ops=ops,
weights={
"alpha_weight": TensorConfig(
data_gen=partial(generate_alpha, dics, dims)
data_gen=partial(generate_alpha, dics)
)
},
inputs={
"input_data": TensorConfig(
data_gen=partial(
generate_input, dics, batch, dims
generate_input, dics, batch
)
),
},
Expand All @@ -116,27 +134,59 @@ def generate_dynamic_shape(attrs):
self.dynamic_shape.max_input_shape = {"input_data": []}
self.dynamic_shape.opt_input_shape = {"input_data": []}
elif self.dims == 1:
self.dynamic_shape.min_input_shape = {"input_data": [1]}
self.dynamic_shape.max_input_shape = {"input_data": [32]}
self.dynamic_shape.min_input_shape = {"input_data": [16]}
self.dynamic_shape.max_input_shape = {"input_data": [16]}
self.dynamic_shape.opt_input_shape = {"input_data": [16]}
elif self.dims == 2:
self.dynamic_shape.min_input_shape = {"input_data": [1, 3]}
self.dynamic_shape.max_input_shape = {"input_data": [4, 3]}
self.dynamic_shape.opt_input_shape = {"input_data": [3, 3]}
self.dynamic_shape.max_input_shape = {"input_data": [1, 3]}
self.dynamic_shape.opt_input_shape = {"input_data": [1, 3]}
elif self.dims == 3:
self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 16]}
self.dynamic_shape.max_input_shape = {"input_data": [4, 16, 16]}
self.dynamic_shape.opt_input_shape = {"input_data": [3, 3, 16]}
if attrs[0]["data_format"] == "NCHW":
self.dynamic_shape.min_input_shape = {
"input_data": [1, 3, 16]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 3, 16]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [1, 3, 16]
}
elif attrs[0]["data_format"] == "NHWC":
self.dynamic_shape.min_input_shape = {
"input_data": [1, 16, 3]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 16, 3]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [1, 16, 3]
}
else:
raise AssertionError()
else:
self.dynamic_shape.min_input_shape = {
"input_data": [1, 3, 16, 3]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 16, 32, 32]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [1, 3, 16, 32]
}
if attrs[0]["data_format"] == "NCHW":
self.dynamic_shape.min_input_shape = {
"input_data": [1, 3, 16, 32]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 3, 16, 32]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [1, 3, 16, 32]
}
elif attrs[0]["data_format"] == "NHWC":
self.dynamic_shape.min_input_shape = {
"input_data": [1, 16, 32, 3]
}
self.dynamic_shape.max_input_shape = {
"input_data": [4, 16, 32, 3]
}
self.dynamic_shape.opt_input_shape = {
"input_data": [1, 16, 32, 3]
}
else:
raise AssertionError()

def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
Expand Down

0 comments on commit a38f81c

Please sign in to comment.