#include <torch/csrc/jit/runtime/operator.h>

#include <vector>

#include "torch/csrc/jit/ir/alias_analysis.h"
#include "torch/csrc/jit/jit_log.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/guard_elimination.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
#include "torch/csrc/jit/runtime/graph_executor.h"

#include "core/util/prelude.h"
413
514namespace trtorch {
615namespace core {
716namespace lowering {
817namespace passes {
918
19+ void replaceLinearWithBiasNonePattern (std::shared_ptr<torch::jit::Graph> graph) {
20+ // Define the decomposition function for aten::linear for the case where bias (mat2) is None.
21+ static torch::jit::CompilationUnit decompose_funcs (R"SCRIPT(
22+ def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
23+ return torch.matmul(self, mat1.t())
24+ )SCRIPT" );
25+
26+ // Iterate through nodes and search for aten::linear nodes where bias is not a Tensor (includes bias=None case)
27+ auto block = graph->block ();
28+ for (auto it = block->nodes ().begin (); it != block->nodes ().end (); it++) {
29+ auto n = *it;
30+ if (n->kind ().toQualString () == std::string (" aten::linear" )) {
31+ auto input_values = n->inputs ();
32+ // input_values[2] is the bias. If none, replace it with the decomposed linear graph.
33+ if (input_values[2 ]->type ()->isSubtypeOf (c10::TensorType::get ())) {
34+ continue ;
35+ } else {
36+ torch::jit::WithInsertPoint guard (*it);
37+ std::shared_ptr<torch::jit::Graph> d_graph = decompose_funcs.get_function (" linear" ).graph ();
38+ torch::jit::Value* new_output = insertGraph (*it->owningGraph (), *d_graph, it->inputs ()).at (0 );
39+ new_output->setType (it->output ()->type ());
40+ it->output ()->replaceAllUsesWith (new_output);
41+ it.destroyCurrent ();
42+ }
43+ }
44+ }
45+ }
46+
1047void LinearToAddMM (std::shared_ptr<torch::jit::Graph>& graph) {
1148 // TensorRT implicitly adds a flatten layer infront of FC layers if necessary
1249 std::string flatten_linear_pattern = R"IR(
1350 graph(%input, %weight, %bias):
1451 %res = aten::linear(%input, %weight, %bias)
1552 return (%res))IR" ;
16- std::string flatten_linear_bias_none_pattern = R"IR(
17- graph(%input, %weight):
18- %bias: Tensor? = prim::Constant()
19- %res = aten::linear(%input, %weight, %bias)
20- return (%res))IR" ;
2153
2254 std::string fused_linear = R"IR(
2355 graph(%input, %weight_t, %bias):
@@ -27,20 +59,13 @@ void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
2759 %b_f: Tensor = trt::const(%bias)
2860 %out: Tensor = aten::add(%b_f, %mm, %1)
2961 return (%out))IR" ;
30- std::string fused_linear_bias_none = R"IR(
31- graph(%input, %weight_t):
32- %weight = aten::t(%weight_t)
33- %mm: Tensor = aten::matmul(%input, %weight)
34- return (%mm))IR" ;
62+
63+ // First find and replace aten::linear nodes with non-tensor bias values.
64+ replaceLinearWithBiasNonePattern (graph);
3565
3666 torch::jit::SubgraphRewriter flatten_linear_to_linear;
3767 flatten_linear_to_linear.RegisterRewritePattern (flatten_linear_pattern, fused_linear);
3868 flatten_linear_to_linear.runOnGraph (graph);
39-
40- torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
41- flatten_linear_bias_none_to_linear.RegisterRewritePattern (flatten_linear_bias_none_pattern, fused_linear_bias_none);
42- flatten_linear_bias_none_to_linear.runOnGraph (graph);
43- LOG_GRAPH (" Post linear to addmm: " << *graph);
4469}
4570
4671} // namespace passes
0 commit comments