@@ -217,9 +217,84 @@ static void fuseListAndListUnpack(Block* b) {
217
217
}
218
218
}
219
219
220
+ static void decomposeLinear (Block* b) {
221
+ std::vector<Node*> linear_nodes;
222
+ for (auto it = b->nodes ().begin (), end = b->nodes ().end (); it != end; ++it) {
223
+ for (auto * child_block : it->blocks ()) {
224
+ decomposeLinear (child_block);
225
+ }
226
+ if (it->kind () == aten::linear) {
227
+ linear_nodes.push_back (*it);
228
+ }
229
+ }
230
+ for (Node* node : linear_nodes) {
231
+ auto g = b->owningGraph ();
232
+
233
+ if (node->inputs ()[2 ]->mustBeNone ()) {
234
+ auto t_weight_n =
235
+ g->create (aten::t, {node->inputs ()[1 ]}, 1 )->insertBefore (node);
236
+ auto matmul_n =
237
+ g->create (aten::matmul, {node->inputs ()[0 ], t_weight_n->output ()}, 1 )
238
+ ->insertBefore (node);
239
+ node->output ()->replaceAllUsesWith (matmul_n->output ());
240
+ node->destroy ();
241
+ } else {
242
+ auto dim_n =
243
+ g->create (aten::dim, {node->inputs ()[0 ]}, 1 )->insertBefore (node);
244
+ auto const_2 = g->insertConstant (IValue (2 ));
245
+ const_2->node ()->moveBefore (node);
246
+ auto eq_n = g->create (aten::eq, {dim_n->output (), const_2}, 1 )
247
+ ->insertBefore (node);
248
+
249
+ auto if_n = g->create (prim::If, {eq_n->output ()}, 1 )->insertBefore (node);
250
+
251
+ auto true_block = if_n->addBlock ();
252
+ auto false_block = if_n->addBlock ();
253
+
254
+ {
255
+ WithInsertPoint guard (true_block->return_node ());
256
+ auto const_1 = g->insertConstant (IValue (1.0 ));
257
+ auto t_weight_n = g->create (aten::t, {node->inputs ()[1 ]}, 1 )
258
+ ->insertBefore (true_block->return_node ());
259
+ auto addmm_n = g->create (
260
+ aten::addmm,
261
+ {node->inputs ()[2 ],
262
+ node->inputs ()[0 ],
263
+ t_weight_n->output (),
264
+ const_1,
265
+ const_1},
266
+ 1 )
267
+ ->insertBefore (true_block->return_node ());
268
+ true_block->registerOutput (addmm_n->output ());
269
+ }
270
+
271
+ {
272
+ WithInsertPoint guard (false_block->return_node ());
273
+ auto const_1 = g->insertConstant (IValue (1.0 ));
274
+ auto t_weight_n = g->create (aten::t, {node->inputs ()[1 ]}, 1 )
275
+ ->insertBefore (false_block->return_node ());
276
+ auto matmul_n =
277
+ g->create (
278
+ aten::matmul, {node->inputs ()[0 ], t_weight_n->output ()}, 1 )
279
+ ->insertBefore (false_block->return_node ());
280
+ auto add_n =
281
+ g->create (
282
+ aten::add, {matmul_n->output (), node->inputs ()[2 ], const_1}, 1 )
283
+ ->insertBefore (false_block->return_node ());
284
+ false_block->registerOutput (add_n->output ());
285
+ }
286
+ node->output ()->replaceAllUsesWith (if_n->output ());
287
+ node->destroy ();
288
+ }
289
+ }
290
+ }
291
+
220
292
} // namespace
221
293
222
294
void PreprocessForONNX (std::shared_ptr<Graph>& graph) {
295
+ GRAPH_DEBUG (" priot to decompose linear" , graph);
296
+ decomposeLinear (graph->block ());
297
+ GRAPH_DEBUG (" after decompose linear" , graph);
223
298
FuseWithListUnpack (graph->block ());
224
299
ReplaceAddWithConcat (graph->block ());
225
300
fuseListAndListUnpack (graph->block ());
0 commit comments