Commit
feat: Implement fast approximation of Gelu as lowering pass to improve performance

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>

chore: refactor converters

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>

chore: Upload reduce_gelu.cpp

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>

chore: Add files

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
peri044 committed Nov 15, 2021
1 parent da15fa5 commit 8024ea2
Showing 8 changed files with 99 additions and 32 deletions.
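
For context, this commit removes the CustomGeluPluginDynamic converter and instead lowers aten::gelu into pointwise ops implementing the tanh approximation GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), where exact GELU is 0.5 * x * (1 + erf(x / sqrt(2))). The standalone libtorch sketch below is illustrative only (not part of the commit); it prints the worst-case error of the approximation over a sample range:

#include <torch/torch.h>
#include <cmath>
#include <iostream>

int main() {
  // Sample points spanning the region where GELU is nonlinear.
  auto x = torch::linspace(-6.0, 6.0, 10001, torch::kFloat64);
  // Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
  auto exact = 0.5 * x * (1.0 + torch::erf(x / std::sqrt(2.0)));
  // Tanh approximation used by the new ReduceGelu lowering pass.
  auto approx = 0.5 * x * (1.0 + torch::tanh(0.7978845608 * (x + 0.044715 * x * x * x)));
  std::cout << "max abs error: " << (exact - approx).abs().max().item<double>() << std::endl;
  return 0;
}
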
32 changes: 0 additions & 32 deletions core/conversion/converters/impl/activation.cpp
@@ -166,39 +166,7 @@ auto acthardtanh TORCHTRT_UNUSED =
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
return true;
}})
.pattern({"aten::gelu(Tensor self) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensorOrFreeze(ctx);
nvinfer1::DataType type = in->getType();
TORCHTRT_CHECK(
type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF,
"gelu only supports kFLOAT and kHALF");
std::string pluginName = "CustomGeluPluginDynamic";
nvinfer1::PluginFieldCollection fc;
std::vector<nvinfer1::PluginField> f;
// REVIEW is this right?
int type_id = ctx->settings.enabled_precisions.find(nvinfer1::DataType::kHALF) ==
ctx->settings.enabled_precisions.end()
? 0
: 1; // Integer encoding the DataType (0: FP32, 1: FP16)
f.emplace_back(nvinfer1::PluginField("type_id", &type_id, nvinfer1::PluginFieldType::kINT32, 1));
fc.nbFields = f.size();
fc.fields = f.data();

auto creator = getPluginRegistry()->getPluginCreator("CustomGeluPluginDynamic", "1", "");
auto gelu_plugin = creator->createPlugin("gelu", &fc);

TORCHTRT_CHECK(gelu_plugin, "Unable to create gelu plugin from TensorRT plugin registry" << *n);
auto new_layer =
ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *gelu_plugin);
new_layer->setName(util::node_info(n).c_str());
auto out_tensor = new_layer->getOutput(0);
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
return true;
}});

} // namespace
} // namespace impl
} // namespace converters
1 change: 1 addition & 0 deletions core/lowering/lowering.cpp
@@ -43,6 +43,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
passes::UnpackHardSwish(g);
passes::EliminateExceptionOrPassPattern(g);
passes::ReduceToOperation(g);
passes::ReduceGelu(g);
passes::RemoveContiguous(g);
passes::RemoveDropout(g);
passes::LinearToAddMM(g);
1 change: 1 addition & 0 deletions core/lowering/passes/BUILD
@@ -17,6 +17,7 @@ cc_library(
"module_fallback.cpp",
"op_aliasing.cpp",
"reduce_to.cpp",
"reduce_gelu.cpp",
"remove_bn_dim_check.cpp",
"remove_contiguous.cpp",
"remove_dropout.cpp",
1 change: 1 addition & 0 deletions core/lowering/passes/passes.h
@@ -20,6 +20,7 @@ void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);
void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims);
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
44 changes: 44 additions & 0 deletions core/lowering/passes/reduce_gelu.cpp
@@ -0,0 +1,44 @@
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
std::string gelu_pattern = R"IR(
graph(%x):
%out : Tensor = aten::gelu(%x)
return (%out))IR";

std::string gelu_reduce_pattern = R"IR(
graph(%x.1 : Tensor):
%6 : float = prim::Constant[value=0.044714999999999998]()
%5 : float = prim::Constant[value=0.79788456080000003]()
%4 : float = prim::Constant[value=1.]()
%3 : float = prim::Constant[value=0.5]()
%2 : int = prim::Constant[value=1]()
%7 : Tensor = aten::mul(%x.1, %3)
%8 : Tensor = aten::mul(%x.1, %5)
%9 : Tensor = aten::mul(%x.1, %6)
%10 : Tensor = aten::mul(%9, %x.1)
%11 : Tensor = aten::add(%10, %4, %2)
%12 : Tensor = aten::mul(%8, %11)
%13 : Tensor = aten::tanh(%12)
%14 : Tensor = aten::add(%13, %4, %2)
%15 : Tensor = aten::mul(%7, %14)
return (%15))IR";

// replace aten::gelu with pointwise operations
torch::jit::SubgraphRewriter map_gelu_to_pointwise_ops;
map_gelu_to_pointwise_ops.RegisterRewritePattern(gelu_pattern, gelu_reduce_pattern);
map_gelu_to_pointwise_ops.runOnGraph(graph);

LOG_GRAPH("Post lowering of [aten::gelu] -> " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
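
For readability, the pointwise sequence in gelu_reduce_pattern corresponds node-for-node to the following tensor arithmetic. This is an illustrative libtorch sketch, not part of the commit; the name reduced_gelu is hypothetical. Note that aten::add(%a, %b, %alpha) with alpha = 1 is plain addition:

#include <torch/torch.h>

// Each statement mirrors one node of gelu_reduce_pattern above.
torch::Tensor reduced_gelu(const torch::Tensor& x) {
  auto t7 = x * 0.5;             // %7  : aten::mul(%x.1, %3)
  auto t8 = x * 0.7978845608;    // %8  : aten::mul(%x.1, %5), sqrt(2/pi)
  auto t9 = x * 0.044715;        // %9  : aten::mul(%x.1, %6)
  auto t10 = t9 * x;             // %10 : aten::mul(%9, %x.1)
  auto t11 = t10 + 1.0;          // %11 : aten::add(%10, %4, %2), alpha = 1
  auto t12 = t8 * t11;           // %12 : aten::mul(%8, %11)
  auto t13 = torch::tanh(t12);   // %13 : aten::tanh(%12)
  auto t14 = t13 + 1.0;          // %14 : aten::add(%13, %4, %2), alpha = 1
  return t7 * t14;               // %15 : aten::mul(%7, %14)
}
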
5 changes: 5 additions & 0 deletions tests/core/conversion/converters/test_activation.cpp
@@ -1,5 +1,6 @@
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
@@ -211,6 +212,10 @@ TEST(Converters, ATenGELUConvertsCorrectly) {

auto in = at::randint(-5, 5, {5}, {at::kCUDA});

// Lower aten::gelu to pointwise operators using the fast tanh approximation
// Gelu(x) = 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
torch_tensorrt::core::lowering::passes::ReduceGelu(g);

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

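
Because ReduceGelu(g) runs before the JIT reference is executed, both the JIT and TensorRT paths in this test evaluate the same tanh approximation and should agree to ordinary floating-point tolerance. A hedged sketch of an explicit elementwise check follows; the jit_results/trt_results names mirror the test's RunGraph/RunGraphEngine outputs, and the repository's actual assertion helpers may differ:

// Assumes jit_results and trt_results hold the outputs of RunGraph and
// RunGraphEngine respectively; both compute the lowered (approximate) GELU.
ASSERT_TRUE(torch::allclose(
    jit_results[0].cpu(), trt_results[0].cpu().reshape_as(jit_results[0]),
    /*rtol=*/1e-4, /*atol=*/1e-5));
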
5 changes: 5 additions & 0 deletions tests/core/lowering/BUILD
@@ -42,6 +42,10 @@ lowering_test(
name = "test_reduce_to_pass",
)

lowering_test(
name = "test_reduce_gelu",
)

lowering_test(
name = "test_remove_detach_pass",
)
@@ -73,6 +77,7 @@ test_suite(
":test_remove_detach_pass",
":test_remove_dropout_pass",
":test_reduce_to_pass",
":test_reduce_gelu",
":test_unpack_hardswish",
":test_unpack_reduce_ops"
],
42 changes: 42 additions & 0 deletions tests/core/lowering/test_reduce_gelu.cpp
@@ -0,0 +1,42 @@
#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/ir/subgraph_matcher.h"

TEST(LoweringPasses, ReduceGeluCorrectly) {
std::string source_graph = R"IR(
graph(%x):
%out : Tensor = aten::gelu(%x)
return (%out))IR";
std::string target_graph = R"IR(
graph(%x.1 : Tensor):
%6 : float = prim::Constant[value=0.044714999999999998]()
%5 : float = prim::Constant[value=0.79788456080000003]()
%4 : float = prim::Constant[value=1.]()
%3 : float = prim::Constant[value=0.5]()
%2 : int = prim::Constant[value=1]()
%7 : Tensor = aten::mul(%x.1, %3)
%8 : Tensor = aten::mul(%x.1, %5)
%9 : Tensor = aten::mul(%x.1, %6)
%10 : Tensor = aten::mul(%9, %x.1)
%11 : Tensor = aten::add(%10, %4, %2)
%12 : Tensor = aten::mul(%8, %11)
%13 : Tensor = aten::tanh(%12)
%14 : Tensor = aten::add(%13, %4, %2)
%15 : Tensor = aten::mul(%7, %14)
return (%15))IR";

torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
auto sg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(source_graph, &*sg);
torch_tensorrt::core::lowering::passes::ReduceGelu(sg);

auto tg = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(target_graph, &*tg);

ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
}
