#include "partitioning.h"
+ #include <queue>
#include "core/conversion/evaluators/eval_util.h"
#include "core/lowering/passes/passes.h"
#include "core/util/prelude.h"
#include "torch/csrc/jit/api/module.h"
- #include "torch/csrc/jit/ir/constants.h"
+ #include "torch/csrc/jit/passes/constant_pooling.h"

namespace trtorch {
namespace core {
namespace partitioning {

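+ // returns true if the value is a Tensor or a List of Tensors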
+ inline bool isTensorOrTensorList(torch::jit::Value* val) {
+   return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
+       val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
+ }
+
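+ // records which segment produces a nonTensor value and which Torch/TensorRT segments consume it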
+ struct usage_info {
+   int produce_id = -1;
+   std::vector<int> torch_use_id;
+   std::vector<int> tensorrt_use_id;
+ };
+

torch::jit::Value* getOrAddInputForValue(
    torch::jit::Value* old_value,
    std::shared_ptr<torch::jit::Graph>& graph,
@@ -39,6 +51,7 @@ torch::jit::Node* cloneNode(
  auto* block = graph->block();
  auto env = [&](torch::jit::Value* v) { return getOrAddInputForValue(v, graph, old_to_new); };

+   // create a node in the current graph by cloning `node` and remapping its input Values through env
  auto new_node = block->appendNode(graph->createClone(node, env));
  for (size_t i = 0; i < node->outputs().size(); ++i) {
    auto oo = node->outputs()[i];
@@ -68,7 +81,6 @@ void registerSegmentInOutIValues(
  // create a module to run the graph
  auto g = seg_block.g();
  auto copy_g = g->copy();
-   // LOG_INFO(*copy_g << "(copy graph)\n");

  // create tuple for multiple outputs
  if (seg_block.raw_outputs().size() > 1) {
@@ -110,7 +122,10 @@ void registerSegmentInOutIValues(

  // run segments to get outputs for later segments' input shapes, and other arguments such as Int
  std::vector<torch::jit::IValue> jit_results;
+   printf("before forward\n");
  torch::jit::IValue jit_results_ivalues = cur_mod.forward(jit_inputs_ivalues);
+   printf("after forward\n");
+
  if (jit_results_ivalues.isTuple()) {
    auto results = jit_results_ivalues.toTuple()->elements();
    for (auto r : results) {
@@ -149,13 +164,10 @@ std::vector<torch::jit::IValue> generateRandomInputs(std::vector<conversion::Inp
  return random_inputs;
}

- void registerSegmentsInputsOutputs(
-     std::vector<SegmentedBlock>& segmented_blocks,
-     std::shared_ptr<torch::jit::Graph> g) {
+ void registerSegmentsOutputs(std::vector<SegmentedBlock>& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
  // find the corresponding raw values in original global graph for this segmented block's inputs/outputs
  std::set<torch::jit::Value*> input_values;
  for (auto& seg_block : segmented_blocks) {
-     seg_block.registerInputs();
    for (auto& input : seg_block.raw_inputs()) {
      input_values.insert(input);
    }
@@ -165,51 +177,124 @@ void registerSegmentsInputsOutputs(
    input_values.insert(graph_output);
  }

-   // should be careful here because some in-place operations don't return any values
+   // Be careful here: some in-place operations don't return any values, so a segment made of them has no output.
+   // Identify the outputs of each mini-graph by checking whether any value in it is used later in the graph;
+   // nonTensor outputs should not be registered for TensorRT segments.
  for (auto& seg_block : segmented_blocks) {
    for (auto& mini_graph_input : input_values) {
      if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
              seg_block.raw_inputs().end() &&
-           seg_block.contain_raw_input(mini_graph_input)) {
+           seg_block.contain_raw_value(mini_graph_input)) {
+         if (!isTensorOrTensorList(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT)
+           continue;
        seg_block.registerOutput(mini_graph_input);
      }
    }
+     // if no output, then register the last node's output as the current graph's output
    if (seg_block.raw_outputs().empty()) {
-       seg_block.registerOutput(seg_block.raw_inputs()[0]);
+       // for Torch segments, register the first input as the output
+       if (seg_block.target() == SegmentedBlock::kTorch) {
+         seg_block.registerOutput(seg_block.raw_inputs()[0]);
+       } else {
+         // for TensorRT segments, register the last non-input Tensor outputs
+         for (int i = seg_block.raw_nodes().size() - 1; i >= 0; --i) {
+           for (auto node_output : seg_block.raw_nodes()[i]->outputs()) {
+             if (isTensorOrTensorList(node_output))
+               seg_block.registerOutput(node_output);
+           }
+           if (!seg_block.raw_outputs().empty())
+             break;
+         }
+       }
    }
  }
+   // erase segments which still have no output
+   segmented_blocks.erase(
+       std::remove_if(
+           segmented_blocks.begin(),
+           segmented_blocks.end(),
+           [](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }),
+       segmented_blocks.end());

  return;
}

- void eraseNonTensorInputsOutputs(
-     SegmentedBlock& seg_block,
-     std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
-   if (seg_block.target() == SegmentedBlock::kTorch)
-     return;
-   auto mini_graph = seg_block.g();
-
-   for (int i = seg_block.raw_inputs().size() - 1; i >= 0; --i) {
-     // erase this input and prepend a prim::Constant if it's not Tensor
-     if (!seg_block.raw_inputs()[i]->type()->isSubtypeOf(torch::jit::TensorType::get()) &&
-         !seg_block.raw_inputs()[i]->type()->isSubtypeOf(c10::ListType::ofTensors())) {
-       auto new_val = torch::jit::insertConstant(*mini_graph, ivalues_maps[seg_block.raw_inputs()[i]]);
-       seg_block.inputs()[i]->replaceAllUsesWith(new_val);
-       seg_block.eraseInput(i);
+ std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*>& vals) {
+   // use BFS to collect the DAG of dependency nodes for the input values
+   std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
+       std::deque<torch::jit::Value*>(vals.begin(), vals.end()));
+   std::unordered_set<torch::jit::Node*> visited;
+   std::vector<torch::jit::Node*> stk;
+   while (!q.empty()) {
+     auto cur_val = q.front();
+     q.pop();
+     auto node = cur_val->node();
+     if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
+       stk.push_back(node);
+       for (auto input : node->inputs()) {
+         if (!isTensorOrTensorList(input)) {
+           q.push(input);
+         }
+       }
    }
  }
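+   // reverse the BFS order so that dependency nodes come before the nodes that consume them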
+   std::reverse(stk.begin(), stk.end());
+   return stk;
+ }

-   for (int i = seg_block.raw_outputs().size() - 1; i >= 0; --i) {
-     if (!seg_block.raw_outputs()[i]->type()->isSubtypeOf(torch::jit::TensorType::get()) &&
-         !seg_block.raw_outputs()[i]->type()->isSubtypeOf(c10::ListType::ofTensors())) {
-       seg_block.eraseOutput(i);
+ SegmentedBlock injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
+   // reconstruct segmented_block if this block requires nonTensor input
+   std::vector<torch::jit::Value*> nontensor_inputs;
+   for (auto input : seg_block.raw_inputs()) {
+     if (!isTensorOrTensorList(input)) {
+       nontensor_inputs.push_back(input);
    }
  }
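+   // prepend the dependency nodes so the rebuilt block can compute its nonTensor inputs itself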
+   std::vector<torch::jit::Node*> new_block_nodes = getDependencyNodes(nontensor_inputs);
+   new_block_nodes.insert(new_block_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
+   return SegmentedBlock(seg_block.target(), new_block_nodes);
+ }

-   // not sure to delete this block or just fallback to pytorch
-   if (seg_block.raw_outputs().empty()) {
-     seg_block.update_target(SegmentedBlock::kTorch);
+ void resolveNonTensorInputs(std::vector<SegmentedBlock>& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
+   // for nonTensor inputs in TensorRT segments, count their usages in Torch segments and TensorRT segments
+   std::unordered_map<torch::jit::Value*, usage_info> usage_counts;
+   for (int i = segmented_blocks.size() - 1; i >= 0; --i) {
+     for (auto input : segmented_blocks[i].raw_inputs()) {
+       if (!isTensorOrTensorList(input)) {
+         segmented_blocks[i].target() == SegmentedBlock::kTorch ? usage_counts[input].torch_use_id.push_back(i)
+                                                                : usage_counts[input].tensorrt_use_id.push_back(i);
+       }
+     }
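+     // record the segment that contains this nonTensor value as its producer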
+     for (auto& use : usage_counts) {
+       if (segmented_blocks[i].contain_raw_value(use.first)) {
+         use.second.produce_id = i;
+       }
+     }
  }
+   std::unordered_set<int> updated_segments;
+   for (auto& use : usage_counts) {
+     auto use_info = use.second;
+     // if the segment that produces this nonTensor value is kTensorRT but it is consumed in kTorch, inject the
+     // dependency nodes into the first kTorch segment that uses it
+     if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
+       int first_torch_id = use_info.torch_use_id.front();
+       if (!updated_segments.count(first_torch_id)) {
+         auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
+         segmented_blocks[first_torch_id] = new_torch_block;
+         updated_segments.insert(first_torch_id);
+       }
+     } else {
+       // kTensorRT segments always need the dependency nodes injected for their nonTensor inputs
+       for (int i : use_info.tensorrt_use_id) {
+         if (!updated_segments.count(i)) {
+           auto new_seg_block = injectNodesForNonTensorInputs(segmented_blocks[i]);
+           segmented_blocks[i] = new_seg_block;
+           updated_segments.insert(i);
+         }
+       }
+     }
+   }
+   return;
}

void construct_segments(
@@ -231,20 +316,18 @@ void construct_segments(
  }
}

- std::vector<SegmentedBlock> segment_graph(
+ void segment_graph(
    std::shared_ptr<torch::jit::Graph> g,
-     std::vector<conversion::InputRange>& input_ranges,
-     const conversion::TorchFallback& fallback_info) {
+     const conversion::TorchFallback& fallback_info,
+     std::vector<SegmentedBlock>& segmented_blocks) {
  auto min_block_size = fallback_info.min_block_size;
  std::unordered_set<std::string> forced_fallback_operators(
      fallback_info.forced_fallback_operators.begin(), fallback_info.forced_fallback_operators.end());
-   std::vector<SegmentedBlock> segmented_blocks;

  auto nodes = g->block()->nodes();

  // segment the nodes
  std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
-
  for (const auto n : nodes) {
    if (n->kind() == torch::jit::prim::Constant)
      continue;
@@ -261,22 +344,33 @@ std::vector<SegmentedBlock> segment_graph(
  if (!pytorch_nodes.empty()) {
    segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
  }
+ }
+
+ std::vector<SegmentedBlock> Partition(
+     std::shared_ptr<torch::jit::Graph> g,
+     std::vector<conversion::InputRange>& input_ranges,
+     const conversion::TorchFallback& fallback_info) {
+   // segment the lowered global graph into blocks
+   std::vector<SegmentedBlock> segmented_blocks;
+   segment_graph(g, fallback_info, segmented_blocks);

-   // register input/output torch::jit::Value for segmetned graphs
-   registerSegmentsInputsOutputs(segmented_blocks, g);
+   // resolve nonTensor inputs/outputs
+   resolveNonTensorInputs(segmented_blocks, g);
+
+   // register input/output torch::jit::Value for segmented graphs
+   registerSegmentsOutputs(segmented_blocks, g);

  // store the mapping from lowering graph torch::jit::Value => torch::jit::IValue that we get by running segments
  std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_maps;
-
  std::vector<torch::jit::IValue> random_inputs = generateRandomInputs(input_ranges);
  for (size_t i = 0; i < g->inputs().size(); ++i) {
    ivalues_maps[g->inputs()[i]] = random_inputs[i];
  }

-   // register every segment's input shape, and it's running output Ivalues
+   // register every segment's input shape and its running output IValues
  for (auto& seg_block : segmented_blocks) {
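+     // pool duplicate constants in this segment's graph before running it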
+     torch::jit::ConstantPooling(seg_block.g());
    registerSegmentInOutIValues(seg_block, ivalues_maps);
-     eraseNonTensorInputsOutputs(seg_block, ivalues_maps);
  }

  return segmented_blocks;