You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
feat(//core/plugins): Add adaptive_max_pool2d plugin, enable the plugins to run on GPU
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
"Adaptive pooling layer will be run through ATen, via not TensorRT, performace will be lower than expected. Consider switching either to static input shape or moving to non adaptive pooling if this is an issue");
67
-
#else
68
-
LOG_WARNING(
69
-
"Adaptive pooling layer will be run through ATen (on CPU), via not TensorRT, performace will suffer. Consider switching either to static input shape or moving to non adaptive pooling");
70
-
#endif
63
+
/*======CONFIGURE PLUGIN PARAMETERS======*/
64
+
nvinfer1::PluginFieldCollection fc;
65
+
std::vector<nvinfer1::PluginField> f;
71
66
72
-
TRTORCH_CHECK(
73
-
pool_type == nvinfer1::PoolingType::kAVERAGE,
74
-
"Unable to create MAX pooling (interpolation) plugin from node" << *n);
"Adaptive pooling layer will be using Aten library kernels in pytorch for execution. TensorRT does not support adaptive pooling natively. Consider switching to non-adaptive pooling if this is an issue");
126
104
127
-
LOG_DEBUG("Window: " << util::toDims(window));
105
+
auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch");
106
+
auto interpolate_plugin = creator->createPlugin(mode.c_str(), &fc);
128
107
129
-
auto pooling_layer = ctx->net->addPoolingNd(*in, pool_type, util::toDims(window));
130
-
TRTORCH_CHECK(pooling_layer, "Unable to create average pooling layer from node: " << *n);
0 commit comments