Skip to content

Commit 6ea5eec

Browse files
authored
[PHI] Skip check in DeformableConvInferMeta on dynamic dim (#74650)
1 parent 0665d48 commit 6ea5eec

File tree

1 file changed

+28
-25
lines changed

1 file changed

+28
-25
lines changed

paddle/phi/infermeta/multiary.cc

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1853,31 +1853,34 @@ void DeformableConvInferMeta(const MetaTensor& x,
18531853
paddings.size(),
18541854
strides.size()));
18551855

1856-
PADDLE_ENFORCE_EQ(
1857-
in_dims[1],
1858-
filter_dims[1] * groups,
1859-
common::errors::InvalidArgument(
1860-
"The number of input channels should be equal to filter "
1861-
"channels * groups. The difference is [%d]: [%d]",
1862-
in_dims[1],
1863-
filter_dims[1] * groups));
1864-
PADDLE_ENFORCE_EQ(
1865-
filter_dims[0] % groups,
1866-
0,
1867-
common::errors::InvalidArgument(
1868-
"The number of output channels should be divided by groups. But "
1869-
"received output channels:[%d], groups:[%d]",
1870-
filter_dims[0],
1871-
groups));
1872-
PADDLE_ENFORCE_EQ(
1873-
filter_dims[0] % deformable_groups,
1874-
0,
1875-
common::errors::InvalidArgument(
1876-
"The number of output channels should be "
1877-
"divided by deformable groups. The difference is [%d]: [%d]",
1878-
filter_dims[0] % groups,
1879-
0));
1880-
1856+
if (config.is_runtime || (filter_dims[1] != -1 && in_dims[1] != -1)) {
1857+
PADDLE_ENFORCE_EQ(
1858+
in_dims[1],
1859+
filter_dims[1] * groups,
1860+
common::errors::InvalidArgument(
1861+
"The number of input channels should be equal to filter "
1862+
"channels * groups. The difference is [%d]: [%d]",
1863+
in_dims[1],
1864+
filter_dims[1] * groups));
1865+
}
1866+
if (config.is_runtime || filter_dims[0] != -1) {
1867+
PADDLE_ENFORCE_EQ(
1868+
filter_dims[0] % groups,
1869+
0,
1870+
common::errors::InvalidArgument(
1871+
"The number of output channels should be divided by groups. But "
1872+
"received output channels:[%d], groups:[%d]",
1873+
filter_dims[0],
1874+
groups));
1875+
PADDLE_ENFORCE_EQ(
1876+
filter_dims[0] % deformable_groups,
1877+
0,
1878+
common::errors::InvalidArgument(
1879+
"The number of output channels should be "
1880+
"divided by deformable groups. The difference is [%d]: [%d]",
1881+
filter_dims[0] % groups,
1882+
0));
1883+
}
18811884
if (in_dims[0] > im2col_step) {
18821885
PADDLE_ENFORCE_EQ(
18831886
in_dims[0] % im2col_step,

0 commit comments

Comments
 (0)