fix: Fix linear lowering pass, lift layer_norm scale layer restriction and matmul layer nbdims restriction

Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
peri044 committed Jun 16, 2021
1 parent 3cb4917 commit 930d582
Showing 3 changed files with 71 additions and 24 deletions.
29 changes: 24 additions & 5 deletions core/conversion/converters/impl/layer_norm.cpp
@@ -117,12 +117,31 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
}

auto power = Weights(ctx, at::ones(expand_size));
-auto scale_nd = ctx->net->addScaleNd(
-    *div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);
-scale_nd->setName((util::node_info(n) + "_scale_nd").c_str());
-auto scale_nd_out = scale_nd->getOutput(0);

-ctx->AssociateValueAndTensor(n->outputs()[0], scale_nd_out);
+auto gamma_tensor = ctx->net->addConstant(gamma_weights.shape, gamma_weights.data)->getOutput(0);
+auto scale_l = add_elementwise(
+    ctx, nvinfer1::ElementWiseOperation::kPROD, div_out, gamma_tensor, (util::node_info(n) + "_scale").c_str());
+
+auto beta_tensor = ctx->net->addConstant(beta_weights.shape, beta_weights.data)->getOutput(0);
+auto shift_l = add_elementwise(
+    ctx,
+    nvinfer1::ElementWiseOperation::kSUM,
+    scale_l->getOutput(0),
+    beta_tensor,
+    (util::node_info(n) + "_shift").c_str());
+
+auto power_tensor = ctx->net->addConstant(power.shape, power.data)->getOutput(0);
+auto power_l = add_elementwise(
+    ctx,
+    nvinfer1::ElementWiseOperation::kPOW,
+    shift_l->getOutput(0),
+    power_tensor,
+    (util::node_info(n) + "_power").c_str());
+
+power_l->setName((util::node_info(n) + "_scale_nd").c_str());
+auto power_l_out = power_l->getOutput(0);
+
+ctx->AssociateValueAndTensor(n->outputs()[0], power_l_out);
return true;
}});
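
Note on the hunk above: the removed addScaleNd call builds a single IScaleLayer, whose kELEMENTWISE mode places rank/shape constraints on its input (the "scale layer restriction" in the commit title). The replacement builds the same affine transform from three elementwise layers, computing y = (x_hat * gamma + beta) ^ power, where x_hat is the normalized input (div_out) and power is all ones, so the final kPOW step is numerically an identity. A minimal libtorch sketch of that equivalence, checked eagerly; the shapes and tolerances are illustrative assumptions, not taken from the converter:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randn({2, 4, 8});
  auto gamma = torch::randn({8});
  auto beta = torch::randn({8});

  // Reference: layer_norm with the affine parameters applied internally.
  auto ref = torch::layer_norm(x, {8}, gamma, beta);

  // Decomposition mirroring the converter: normalize without affine params,
  // then elementwise PROD (gamma), SUM (beta), POW (ones, an identity).
  auto x_hat = torch::layer_norm(x, {8});
  auto dec = torch::pow(x_hat * gamma + beta, torch::ones({8}));

  std::cout << std::boolalpha << torch::allclose(ref, dec, 1e-4, 1e-5) << "\n";  // expect true
}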

9 changes: 6 additions & 3 deletions core/conversion/converters/impl/matrix_multiply.cpp
@@ -1,3 +1,4 @@
#include "core/conversion/converters/converter_util.h"
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"

@@ -16,10 +17,12 @@ auto mm_registrations TRTORCH_UNUSED =
LOG_DEBUG("self tensor shape: " << self->getDimensions());

auto other = args[1].ITensorOrFreeze(ctx);
LOG_DEBUG("other tensor shape: " << other->getDimensions());
// "other" tensor should have same nbDims as self
auto wt_tensor = addPadding(ctx, n, other, self->getDimensions().nbDims, false, false);
LOG_DEBUG("other tensor shape: " << wt_tensor->getDimensions());

auto mm_layer = ctx->net->addMatrixMultiply(
-    *self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
+    *self, nvinfer1::MatrixOperation::kNONE, *wt_tensor, nvinfer1::MatrixOperation::kNONE);
TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
mm_layer->setName(util::node_info(n).c_str());
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));
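
Context for this change: TensorRT's IMatrixMultiplyLayer expects both operands to carry the same number of dimensions, so "other" is now rank-padded to match "self" before the layer is created (assuming addPadding prepends unit-sized axes, as its use here suggests). Prepending 1-sized dimensions leaves the matmul result unchanged, which a small libtorch sketch can illustrate; the shapes are illustrative assumptions:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto self = torch::randn({2, 3, 6, 4});  // 4-D activation
  auto other = torch::randn({4, 5});       // 2-D weight

  // Broadcast matmul vs. the same weight padded to self's rank with leading unit axes.
  auto a = torch::matmul(self, other);
  auto b = torch::matmul(self, other.view({1, 1, 4, 5}));

  std::cout << std::boolalpha << torch::allclose(a, b) << "\n";  // expect true
}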
@@ -73,4 +76,4 @@ auto mm_registrations TRTORCH_UNUSED =
} // namespace converters
} // namespace conversion
} // namespace core
-} // namespace trtorch
\ No newline at end of file
+} // namespace trtorch
57 changes: 41 additions & 16 deletions core/lowering/passes/linear_to_addmm.cpp
@@ -1,23 +1,55 @@
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

+#include <torch/csrc/jit/runtime/operator.h>
+#include "torch/csrc/jit/ir/alias_analysis.h"
+#include "torch/csrc/jit/jit_log.h"
+#include "torch/csrc/jit/passes/constant_propagation.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/guard_elimination.h"
+#include "torch/csrc/jit/passes/peephole.h"
+#include "torch/csrc/jit/runtime/graph_executor.h"
+
#include "core/util/prelude.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

namespace trtorch {
namespace core {
namespace lowering {
namespace passes {

+void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph) {
+  // Define the decomposition function for aten::linear for the case where bias (mat2) is None.
+  static torch::jit::CompilationUnit decompose_funcs(R"SCRIPT(
+    def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
+        return torch.matmul(self, mat1.t())
+    )SCRIPT");
+
+  // Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case)
+  auto block = graph->block();
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
+    auto n = *it;
+    if (n->kind().toQualString() == std::string("aten::linear")) {
+      auto input_values = n->inputs();
+      // input_values[2] is the bias. If none, replace it with the decomposed linear graph.
+      if (input_values[2]->type()->isSubtypeOf(c10::TensorType::get())) {
+        continue;
+      } else {
+        torch::jit::WithInsertPoint guard(*it);
+        std::shared_ptr<torch::jit::Graph> d_graph = decompose_funcs.get_function("linear").graph();
+        torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
+        new_output->setType(it->output()->type());
+        it->output()->replaceAllUsesWith(new_output);
+        it.destroyCurrent();
+      }
+    }
+  }
+}
+
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
// TensorRT implicitly adds a flatten layer infront of FC layers if necessary
std::string flatten_linear_pattern = R"IR(
graph(%input, %weight, %bias):
%res = aten::linear(%input, %weight, %bias)
return (%res))IR";
-  std::string flatten_linear_bias_none_pattern = R"IR(
-    graph(%input, %weight):
-      %bias: Tensor? = prim::Constant()
-      %res = aten::linear(%input, %weight, %bias)
-      return (%res))IR";

std::string fused_linear = R"IR(
graph(%input, %weight_t, %bias):
@@ -27,20 +59,13 @@ void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
%b_f: Tensor = trt::const(%bias)
%out: Tensor = aten::add(%b_f, %mm, %1)
return (%out))IR";
-  std::string fused_linear_bias_none = R"IR(
-    graph(%input, %weight_t):
-      %weight = aten::t(%weight_t)
-      %mm: Tensor = aten::matmul(%input, %weight)
-      return (%mm))IR";

+  // First find and replace aten::linear nodes with non-tensor bias values.
+  replaceLinearWithBiasNonePattern(graph);

torch::jit::SubgraphRewriter flatten_linear_to_linear;
flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
flatten_linear_to_linear.runOnGraph(graph);

-  torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
-  flatten_linear_bias_none_to_linear.RegisterRewritePattern(flatten_linear_bias_none_pattern, fused_linear_bias_none);
-  flatten_linear_bias_none_to_linear.runOnGraph(graph);
+  LOG_GRAPH("Post linear to addmm: " << *graph);
}

} // namespace passes
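
Taken together, the pass now handles both bias cases while preserving aten::linear semantics: with a Tensor bias, the fused_linear pattern rewrites linear(x, W, b) to b + matmul(x, W.t()); with bias=None, replaceLinearWithBiasNonePattern inlines matmul(x, W.t()) directly. A minimal libtorch check of both identities; the shapes and tolerances are illustrative assumptions:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randn({2, 5, 16});
  auto w = torch::randn({8, 16});
  auto b = torch::randn({8});

  // bias=None case: the decomposition inserted by replaceLinearWithBiasNonePattern.
  bool no_bias_ok = torch::allclose(torch::linear(x, w), torch::matmul(x, w.t()), 1e-5, 1e-6);

  // Tensor-bias case: the fused_linear rewrite (aten::add of bias and matmul).
  bool bias_ok = torch::allclose(torch::linear(x, w, b), b + torch::matmul(x, w.t()), 1e-5, 1e-6);

  std::cout << std::boolalpha << no_bias_ok << " " << bias_ok << "\n";  // expect true true
}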
