Commit 4e32eff

feat: insert nodes by dependencies for nonTensor inputs/outputs
Signed-off-by: Bo Wang <wangbo1995ee@163.com>
1 parent 54e407e commit 4e32eff

2 files changed: +148 -46 lines

core/partitioning/partitioning.cpp (+134 -40)
@@ -1,14 +1,26 @@
 #include "partitioning.h"
+#include <queue>
 #include "core/conversion/evaluators/eval_util.h"
 #include "core/lowering/passes/passes.h"
 #include "core/util/prelude.h"
 #include "torch/csrc/jit/api/module.h"
-#include "torch/csrc/jit/ir/constants.h"
+#include "torch/csrc/jit/passes/constant_pooling.h"

 namespace trtorch {
 namespace core {
 namespace partitioning {

+inline bool isTensorOrTensorList(torch::jit::Value* val) {
+  return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
+      val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
+}
+
+struct usage_info {
+  int produce_id = -1;
+  std::vector<int> torch_use_id;
+  std::vector<int> tensorrt_use_id;
+};
+
 torch::jit::Value* getOrAddInputForValue(
     torch::jit::Value* old_value,
     std::shared_ptr<torch::jit::Graph>& graph,
@@ -39,6 +51,7 @@ torch::jit::Node* cloneNode(
   auto* block = graph->block();
   auto env = [&](torch::jit::Value* v) { return getOrAddInputForValue(v, graph, old_to_new); };

+  // create node for current graph by using the metadata in node and input Values in env
   auto new_node = block->appendNode(graph->createClone(node, env));
   for (size_t i = 0; i < node->outputs().size(); ++i) {
     auto oo = node->outputs()[i];
@@ -68,7 +81,6 @@ void registerSegmentInOutIValues(
   // create a module to run the graph
   auto g = seg_block.g();
   auto copy_g = g->copy();
-  // LOG_INFO(*copy_g << "(copy graph)\n");

   // create tuple for multiple outputs
   if (seg_block.raw_outputs().size() > 1) {
@@ -110,7 +122,10 @@ void registerSegmentInOutIValues(

   // run segments to get outputs for later segments input shape, and other arguments such as Int
   std::vector<torch::jit::IValue> jit_results;
+  printf("before forward\n");
   torch::jit::IValue jit_results_ivalues = cur_mod.forward(jit_inputs_ivalues);
+  printf("after forward\n");
+
   if (jit_results_ivalues.isTuple()) {
     auto results = jit_results_ivalues.toTuple()->elements();
     for (auto r : results) {
@@ -149,13 +164,10 @@ std::vector<torch::jit::IValue> generateRandomInputs(std::vector<conversion::Inp
   return random_inputs;
 }

-void registerSegmentsInputsOutputs(
-    std::vector<SegmentedBlock>& segmented_blocks,
-    std::shared_ptr<torch::jit::Graph> g) {
+void registerSegmentsOutputs(std::vector<SegmentedBlock>& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
   // find the corresponding raw values in original global graph for this segmented block's inputs/outputs
   std::set<torch::jit::Value*> input_values;
   for (auto& seg_block : segmented_blocks) {
-    seg_block.registerInputs();
     for (auto& input : seg_block.raw_inputs()) {
       input_values.insert(input);
     }
@@ -165,51 +177,124 @@ void registerSegmentsInputsOutputs(
     input_values.insert(graph_output);
   }

-  // should be careful here because some in-place operations don't return any values
+  // should be careful here because some in-place operations don't return any values, there is no output for this kind
+  // of segment. Identify the output for each mini-graph by checking if any value in this graph is used later; we
+  // shouldn't register nonTensor output for TensorRT segments
   for (auto& seg_block : segmented_blocks) {
     for (auto& mini_graph_input : input_values) {
       if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
               seg_block.raw_inputs().end() &&
-          seg_block.contain_raw_input(mini_graph_input)) {
+          seg_block.contain_raw_value(mini_graph_input)) {
+        if (!isTensorOrTensorList(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT)
+          continue;
         seg_block.registerOutput(mini_graph_input);
       }
     }
+    // if no output, then register the last node's output as current graph's output
     if (seg_block.raw_outputs().empty()) {
-      seg_block.registerOutput(seg_block.raw_inputs()[0]);
+      // for Torch segments, register input as output
+      if (seg_block.target() == SegmentedBlock::kTorch) {
+        seg_block.registerOutput(seg_block.raw_inputs()[0]);
+      } else {
+        // for TensorRT segments, register last nonInput Tensor outputs
+        for (int i = seg_block.raw_nodes().size() - 1; i >= 0; --i) {
+          for (auto node_output : seg_block.raw_nodes()[i]->outputs()) {
+            if (isTensorOrTensorList(node_output))
+              seg_block.registerOutput(node_output);
+          }
+          if (!seg_block.raw_outputs().empty())
+            break;
+        }
+      }
     }
   }
+  // erase segments which still have no output
+  segmented_blocks.erase(
+      std::remove_if(
+          segmented_blocks.begin(),
+          segmented_blocks.end(),
+          [](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }),
+      segmented_blocks.end());

   return;
 }

-void eraseNonTensorInputsOutputs(
-    SegmentedBlock& seg_block,
-    std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
-  if (seg_block.target() == SegmentedBlock::kTorch)
-    return;
-  auto mini_graph = seg_block.g();
-
-  for (int i = seg_block.raw_inputs().size() - 1; i >= 0; --i) {
-    // erase this input and prepend a prim::Constant if it's not Tensor
-    if (!seg_block.raw_inputs()[i]->type()->isSubtypeOf(torch::jit::TensorType::get()) &&
-        !seg_block.raw_inputs()[i]->type()->isSubtypeOf(c10::ListType::ofTensors())) {
-      auto new_val = torch::jit::insertConstant(*mini_graph, ivalues_maps[seg_block.raw_inputs()[i]]);
-      seg_block.inputs()[i]->replaceAllUsesWith(new_val);
-      seg_block.eraseInput(i);
+std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*>& vals) {
+  // using bfs to get the DAG dependency nodes for input value
+  std::queue<torch::jit::Value*, std::deque<torch::jit::Value*>> q(
+      std::deque<torch::jit::Value*>(vals.begin(), vals.end()));
+  std::unordered_set<torch::jit::Node*> visited;
+  std::vector<torch::jit::Node*> stk;
+  while (!q.empty()) {
+    auto cur_val = q.front();
+    q.pop();
+    auto node = cur_val->node();
+    if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
+      stk.push_back(node);
+      for (auto input : node->inputs()) {
+        if (!isTensorOrTensorList(input)) {
+          q.push(input);
+        }
+      }
     }
   }
+  std::reverse(stk.begin(), stk.end());
+  return stk;
+}

-  for (int i = seg_block.raw_outputs().size() - 1; i >= 0; --i) {
-    if (!seg_block.raw_outputs()[i]->type()->isSubtypeOf(torch::jit::TensorType::get()) &&
-        !seg_block.raw_outputs()[i]->type()->isSubtypeOf(c10::ListType::ofTensors())) {
-      seg_block.eraseOutput(i);
+SegmentedBlock injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
+  // reconstruct segmented_block if this block requires nonTensor input
+  std::vector<torch::jit::Value*> nontensor_inputs;
+  for (auto input : seg_block.raw_inputs()) {
+    if (!isTensorOrTensorList(input)) {
+      nontensor_inputs.push_back(input);
     }
   }
+  std::vector<torch::jit::Node*> new_block_nodes = getDependencyNodes(nontensor_inputs);
+  new_block_nodes.insert(new_block_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
+  return SegmentedBlock(seg_block.target(), new_block_nodes);
+}

-  // not sure to delete this block or just fallback to pytorch
-  if (seg_block.raw_outputs().empty()) {
-    seg_block.update_target(SegmentedBlock::kTorch);
+void resolveNonTensorInputs(std::vector<SegmentedBlock>& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
+  // for NonTensor inputs in TensorRT segments, count the usages on Torch segments and TensorRT segments
+  std::unordered_map<torch::jit::Value*, usage_info> usage_counts;
+  for (int i = segmented_blocks.size() - 1; i >= 0; --i) {
+    for (auto input : segmented_blocks[i].raw_inputs()) {
+      if (!isTensorOrTensorList(input)) {
+        segmented_blocks[i].target() == SegmentedBlock::kTorch ? usage_counts[input].torch_use_id.push_back(i)
+                                                               : usage_counts[input].tensorrt_use_id.push_back(i);
+      }
+    }
+    for (auto& use : usage_counts) {
+      if (segmented_blocks[i].contain_raw_value(use.first)) {
+        use.second.produce_id = i;
+      }
+    }
   }
+  std::unordered_set<int> updated_segments;
+  for (auto& use : usage_counts) {
+    auto use_info = use.second;
+    // if the segment that produces this nonTensor value is kTensorRT but it is consumed in kTorch, inject nodes in
+    // the first kTorch segment
+    if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
+      int first_torch_id = use_info.torch_use_id.front();
+      if (!updated_segments.count(first_torch_id)) {
+        auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
+        segmented_blocks[first_torch_id] = new_torch_block;
+        updated_segments.insert(first_torch_id);
+      }
+    } else {
+      // kTensorRT segments always need to inject nodes for the nonTensor inputs
+      for (int i : use_info.tensorrt_use_id) {
+        if (!updated_segments.count(i)) {
+          auto new_seg_block = injectNodesForNonTensorInputs(segmented_blocks[i]);
+          segmented_blocks[i] = new_seg_block;
+          updated_segments.insert(i);
+        }
+      }
+    }
+  }
+  return;
 }

 void construct_segments(
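The reverse-BFS in getDependencyNodes above is the core of this commit: starting from the nonTensor values a segment consumes, it walks back to the nodes that produce them and returns those nodes in topological order, so injectNodesForNonTensorInputs can prepend them to the segment's original node list. A minimal standalone sketch of the same traversal, using toy node/value types instead of torch::jit IR (all names below are illustrative, not part of the commit):

// Toy model of getDependencyNodes(): collect, in topological order, the nodes
// that produce a set of required nonTensor values. Illustrative only.
#include <algorithm>
#include <deque>
#include <iostream>
#include <queue>
#include <unordered_set>
#include <vector>

struct ToyNode;
struct ToyValue {
  ToyNode* producer; // node that produces this value (nullptr ~ graph input or constant)
  bool is_tensor;    // stands in for isTensorOrTensorList()
};
struct ToyNode {
  const char* name;
  std::vector<ToyValue*> inputs;
};

std::vector<ToyNode*> getDependencyNodes(std::vector<ToyValue*>& vals) {
  std::queue<ToyValue*, std::deque<ToyValue*>> q(std::deque<ToyValue*>(vals.begin(), vals.end()));
  std::unordered_set<ToyNode*> visited;
  std::vector<ToyNode*> stk;
  while (!q.empty()) {
    ToyValue* cur_val = q.front();
    q.pop();
    ToyNode* node = cur_val->producer;
    if (node && !visited.count(node)) {  // the real code also skips prim::Constant here
      visited.insert(node);
      stk.push_back(node);
      for (ToyValue* input : node->inputs) {
        if (!input->is_tensor) {         // only chase nonTensor dependencies
          q.push(input);
        }
      }
    }
  }
  std::reverse(stk.begin(), stk.end());  // producers end up before their consumers
  return stk;
}

int main() {
  // %size = aten::size(%x); a later segment needs %size (a nonTensor value)
  ToyValue x{nullptr, true};
  ToyNode size_node{"aten::size", {&x}};
  ToyValue size_val{&size_node, false};
  std::vector<ToyValue*> needed{&size_val};
  for (ToyNode* n : getDependencyNodes(needed)) {
    std::cout << n->name << "\n";  // prints: aten::size
  }
  return 0;
}

In the commit, injectNodesForNonTensorInputs then concatenates this dependency list in front of the segment's raw nodes and rebuilds the SegmentedBlock from the combined list.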
@@ -231,20 +316,18 @@ void construct_segments(
   }
 }

-std::vector<SegmentedBlock> segment_graph(
+void segment_graph(
     std::shared_ptr<torch::jit::Graph> g,
-    std::vector<conversion::InputRange>& input_ranges,
-    const conversion::TorchFallback& fallback_info) {
+    const conversion::TorchFallback& fallback_info,
+    std::vector<SegmentedBlock>& segmented_blocks) {
   auto min_block_size = fallback_info.min_block_size;
   std::unordered_set<std::string> forced_fallback_operators(
       fallback_info.forced_fallback_operators.begin(), fallback_info.forced_fallback_operators.end());
-  std::vector<SegmentedBlock> segmented_blocks;

   auto nodes = g->block()->nodes();

   // segment the nodes
   std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
-
   for (const auto n : nodes) {
     if (n->kind() == torch::jit::prim::Constant)
       continue;
@@ -261,22 +344,33 @@ std::vector<SegmentedBlock> segment_graph(
   if (!pytorch_nodes.empty()) {
     segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
   }
+}
+
+std::vector<SegmentedBlock> Partition(
+    std::shared_ptr<torch::jit::Graph> g,
+    std::vector<conversion::InputRange>& input_ranges,
+    const conversion::TorchFallback& fallback_info) {
+  // segment lowering global graph into blocks
+  std::vector<SegmentedBlock> segmented_blocks;
+  segment_graph(g, fallback_info, segmented_blocks);

-  // register input/output torch::jit::Value for segmetned graphs
-  registerSegmentsInputsOutputs(segmented_blocks, g);
+  // resolve nonTensor inputs/outputs
+  resolveNonTensorInputs(segmented_blocks, g);
+
+  // register input/output torch::jit::Value for segmented graphs
+  registerSegmentsOutputs(segmented_blocks, g);

   // store the mapping from lowering graph torch::jit::Value => torch::jit::IValue that we get by running segments
   std::unordered_map<torch::jit::Value*, torch::jit::IValue> ivalues_maps;
-
   std::vector<torch::jit::IValue> random_inputs = generateRandomInputs(input_ranges);
   for (size_t i = 0; i < g->inputs().size(); ++i) {
     ivalues_maps[g->inputs()[i]] = random_inputs[i];
   }

-  // register every segment's input shape, and it's running output Ivalues
+  // register every segment's input shape, and its running output IValues
   for (auto& seg_block : segmented_blocks) {
+    torch::jit::ConstantPooling(seg_block.g());
     registerSegmentInOutIValues(seg_block, ivalues_maps);
-    eraseNonTensorInputsOutputs(seg_block, ivalues_maps);
   }

   return segmented_blocks;
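resolveNonTensorInputs applies a simple placement rule per nonTensor value: if a TensorRT segment produces the value and some Torch segment consumes it, the dependency nodes are injected into the first consuming Torch segment; otherwise every consuming TensorRT segment gets the dependency nodes. A self-contained sketch of just that rule, with plain structs standing in for SegmentedBlock (names are illustrative, not the commit's API):

// Simplified model of the placement decision in resolveNonTensorInputs().
#include <unordered_set>
#include <vector>

enum class Target { kTensorRT, kTorch };

struct UsageInfo {                   // mirrors the commit's usage_info struct
  int produce_id = -1;               // segment that produces the nonTensor value
  std::vector<int> torch_use_id;     // Torch segments that consume it
  std::vector<int> tensorrt_use_id;  // TensorRT segments that consume it
};

// Returns the indices of segments that must be rebuilt with injected dependency nodes.
std::vector<int> segmentsToRewrite(const std::vector<Target>& targets, const UsageInfo& use) {
  std::unordered_set<int> updated;
  std::vector<int> result;
  if (targets[use.produce_id] == Target::kTensorRT && !use.torch_use_id.empty()) {
    // produced by a TensorRT segment but consumed by Torch: rewrite the first Torch consumer
    int first_torch_id = use.torch_use_id.front();
    if (updated.insert(first_torch_id).second) {
      result.push_back(first_torch_id);
    }
  } else {
    // TensorRT consumers always need the dependency nodes injected
    for (int i : use.tensorrt_use_id) {
      if (updated.insert(i).second) {
        result.push_back(i);
      }
    }
  }
  return result;
}

int main() {
  // segment 0 (TensorRT) produces a nonTensor value; segment 1 (Torch) consumes it
  std::vector<Target> targets{Target::kTensorRT, Target::kTorch, Target::kTensorRT};
  UsageInfo use;
  use.produce_id = 0;
  use.torch_use_id = {1};
  std::vector<int> rewrite = segmentsToRewrite(targets, use);  // {1}
  (void)rewrite;
  return 0;
}

In the real implementation the updated_segments set persists across all nonTensor values, so each segment is rebuilt at most once.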

core/partitioning/partitioning.h (+14 -6)
@@ -30,11 +30,13 @@ struct SegmentedBlock {

   SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}

-  SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes)
+  SegmentedBlock(SegmentedBlockTarget blk_target, std::vector<torch::jit::Node*>& nodes)
       : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {
     for (auto& node : nodes) {
+      nodes_.push_back(node);
       appendNode(node);
     }
+    registerInputs();
   }

   SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
@@ -53,9 +55,9 @@ struct SegmentedBlock {
     }
   }

-  void registerOutput(torch::jit::Value* raw_input) {
-    outputs_.push_back(raw_input);
-    g_->registerOutput(old_to_new_[raw_input]);
+  void registerOutput(torch::jit::Value* raw_output) {
+    outputs_.push_back(raw_output);
+    g_->registerOutput(old_to_new_[raw_output]);
   }

   torch::jit::Block* block() {
@@ -88,7 +90,11 @@ struct SegmentedBlock {
     return outputs_;
   }

-  bool contain_raw_input(torch::jit::Value* input) {
+  const std::vector<torch::jit::Node*>& raw_nodes() const {
+    return nodes_;
+  }
+
+  bool contain_raw_value(torch::jit::Value* input) {
     return old_to_new_.count(input);
   }

@@ -121,15 +127,17 @@ struct SegmentedBlock {
   std::vector<nvinfer1::Dims> in_shape_;
   std::vector<torch::jit::Value*> inputs_;
   std::vector<torch::jit::Value*> outputs_;
+  std::vector<torch::jit::Node*> nodes_;
   std::shared_ptr<torch::jit::Graph> g_;
   std::string trt_engine;
   std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
 };

-std::vector<SegmentedBlock> segment_graph(
+std::vector<SegmentedBlock> Partition(
     std::shared_ptr<torch::jit::Graph> g,
     std::vector<conversion::InputRange>& input_ranges,
     const conversion::TorchFallback& fallback_info);
+
 } // namespace partitioning
 } // namespace core
 } // namespace trtorch
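With the declaration above, Partition replaces segment_graph as the public entry point for fallback partitioning. A rough, hypothetical call site is sketched below; min_block_size and forced_fallback_operators appear in this diff, but the InputRange single-shape constructor and the exact TorchFallback layout are assumptions, not confirmed here:

// Hypothetical caller of the new Partition() entry point (sketch only).
#include "core/partitioning/partitioning.h"

using namespace trtorch::core;

std::vector<partitioning::SegmentedBlock> partition_for_fallback(std::shared_ptr<torch::jit::Graph> g) {
  // Assumed: InputRange can be constructed from a single static shape.
  std::vector<conversion::InputRange> input_ranges;
  input_ranges.push_back(conversion::InputRange({1, 3, 224, 224}));

  conversion::TorchFallback fallback_info;
  fallback_info.min_block_size = 3;                             // consumed by segment_graph above
  fallback_info.forced_fallback_operators = {"aten::nonzero"};  // ops forced to run in Torch

  // Segments the graph, injects dependency nodes for nonTensor inputs,
  // registers segment outputs, and runs each segment on random inputs
  // to record shapes and IValues.
  return partitioning::Partition(g, input_ranges, fallback_info);
}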
