@@ -68,6 +68,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*
     q.pop();
     auto node = cur_val->node();
     if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
+      visited.insert(node);
       stk.push_back(node);
       for (auto input : node->inputs()) {
         if (!isTensorOrTensorList(input)) {
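The one-line fix above is what actually populates the `visited` set that the surrounding `!visited.count(node)` check consults; without it, a producer reachable through several non-tensor values gets appended to the dependency list repeatedly. A minimal sketch of the same bottom-up walk, using hypothetical `Value`/`Node` stand-ins rather than the real `torch::jit` IR types:

```cpp
// Minimal sketch (hypothetical types, not the torch::jit API): shows why marking a
// node as visited when it is first reached keeps it from being collected twice.
#include <iostream>
#include <queue>
#include <unordered_set>
#include <vector>

struct Node;                       // stand-in for torch::jit::Node
struct Value { Node* producer; };  // stand-in for torch::jit::Value

struct Node {
  const char* name;
  std::vector<Value*> inputs;      // non-tensor inputs whose producers we must pull in
};

// Collect the producers of all (transitive) non-tensor inputs, bottom-up.
std::vector<Node*> getDependencyNodes(std::vector<Value*>& vals) {
  std::queue<Value*> q;
  std::unordered_set<Node*> visited;
  std::vector<Node*> stk;
  for (auto v : vals) q.push(v);

  while (!q.empty()) {
    auto cur_val = q.front();
    q.pop();
    auto node = cur_val->producer;
    if (!visited.count(node)) {
      visited.insert(node);        // mark immediately, so a node reachable through
                                   // several values is only recorded once
      stk.push_back(node);
      for (auto input : node->inputs) q.push(input);
    }
  }
  return stk;                      // discovered consumer-first (bottom-up)
}

int main() {
  // d consumes the same non-tensor value twice; without the visited set its
  // producer b would be emitted twice.
  Node a{"a", {}};
  Value va{&a};
  Node b{"b", {&va}};
  Value vb{&b};
  Node d{"d", {&vb, &vb}};
  std::vector<Value*> deps = d.inputs;
  for (auto n : getDependencyNodes(deps)) std::cout << n->name << "\n";  // prints b, then a
}
```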
@@ -89,14 +90,14 @@ std::vector<torch::jit::Node*> getOutputNodes(
   std::unordered_set<torch::jit::Node*> visited;
   q.push(value);

-  // top-down order traveling
+  // top-down order traversing
   while (!q.empty()) {
     auto cur_val = q.front();
     q.pop();
     for (auto use : cur_val->uses()) {
       auto node = use.user;
       // use node must be in seg_block_nodes
-      if (seg_block_nodes.count(node) != 0 && !visited.count(node)) {
+      if (seg_block_nodes.count(node) && !visited.count(node)) {
         stk.push_back(node);
         visited.insert(node);
         // travel its' all outputs
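For contrast, `getOutputNodes` walks downstream from a value through its uses, only following users that belong to the segment, so the collected list is already in top-down order (the comment in the next hunk notes it needs no reversal, unlike the bottom-up dependency walk). A minimal sketch over the same kind of hypothetical stand-in types, not the real `torch::jit` API:

```cpp
// Minimal sketch (hypothetical types): walk a value's uses downstream, restricted
// to the segment's nodes; the result is already topologically ordered.
#include <iostream>
#include <queue>
#include <unordered_set>
#include <vector>

struct Node;
struct Value {
  std::vector<Node*> users;   // nodes that consume this value ("uses")
};

struct Node {
  const char* name;
  std::vector<Value*> outputs;
};

std::vector<Node*> getOutputNodes(Value* value, const std::unordered_set<Node*>& seg_block_nodes) {
  std::queue<Value*> q;
  std::unordered_set<Node*> visited;
  std::vector<Node*> stk;
  q.push(value);

  while (!q.empty()) {
    auto cur_val = q.front();
    q.pop();
    for (auto node : cur_val->users) {
      // only follow users that belong to the segment being split
      if (seg_block_nodes.count(node) && !visited.count(node)) {
        visited.insert(node);
        stk.push_back(node);
        for (auto out : node->outputs) q.push(out);  // keep walking downstream
      }
    }
  }
  return stk;  // already top-down; no reversal required
}

int main() {
  Node a{"a", {}}, b{"b", {}}, c{"c", {}};
  Value v0{{&a}};        // v0 is consumed by a
  Value v1{{&b, &c}};    // a's output feeds both b and c
  a.outputs = {&v1};
  std::unordered_set<Node*> block{&a, &b, &c};
  for (auto n : getOutputNodes(&v0, block)) std::cout << n->name << "\n";  // a, b, c
}
```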
@@ -109,10 +110,41 @@ std::vector<torch::jit::Node*> getOutputNodes(
     }
   }

-  // top-down order and we don't need reverse it
+  // top-down order and we don't need to reverse it
   return stk;
 }

+void getDirtyNodes(
+    std::unordered_set<torch::jit::Node*>& dirty_nodes,
+    const std::unordered_set<torch::jit::Node*>& seg_block_nodes) {
+  std::queue<torch::jit::Node*> q;
+  for (auto& node : dirty_nodes) {
+    q.push(node);
+  }
+  dirty_nodes.clear();
+
+  while (!q.empty()) {
+    auto cur_node = q.front();
+    q.pop();
+    if (!dirty_nodes.count(cur_node) && seg_block_nodes.count(cur_node)) {
+      dirty_nodes.insert(cur_node);
+      for (auto input : cur_node->inputs()) {
+        if (!isTensorOrTensorList(input)) {
+          q.push(input->node());
+        }
+      }
+      for (auto output : cur_node->outputs()) {
+        if (!isTensorOrTensorList(output)) {
+          for (auto use : output->uses()) {
+            auto node = use.user;
+            q.push(node);
+          }
+        }
+      }
+    }
+  }
+}
+
 std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock> segmentBlocksWithTensorListInputs(
     SegmentedBlock& seg_block,
     const std::unordered_map<torch::jit::Value*, SegmentedBlock>& tensorlist_inputs) {
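The new `getDirtyNodes` pass grows the seeded dirty set in both directions: upstream through producers of non-tensor inputs and downstream through users of non-tensor outputs, never stepping outside the segment's own nodes. The result is the complete set of nodes that must stay in a PyTorch block. A minimal sketch of that propagation, assuming simplified stand-in types (the `producer`/`users`/`is_tensor` fields are illustrative, not the torch::jit API):

```cpp
// Minimal sketch (hypothetical types) of the dirty-node propagation: starting from
// seed nodes, BFS along non-tensor edges in both directions, staying inside the segment.
#include <iostream>
#include <queue>
#include <unordered_set>
#include <vector>

struct Node;
struct Value {
  Node* producer;
  std::vector<Node*> users;
  bool is_tensor;
};

struct Node {
  const char* name;
  std::vector<Value*> inputs;
  std::vector<Value*> outputs;
};

void getDirtyNodes(std::unordered_set<Node*>& dirty_nodes,
                   const std::unordered_set<Node*>& seg_block_nodes) {
  std::queue<Node*> q;
  for (auto node : dirty_nodes) q.push(node);
  dirty_nodes.clear();

  while (!q.empty()) {
    auto cur = q.front();
    q.pop();
    if (dirty_nodes.count(cur) || !seg_block_nodes.count(cur)) continue;
    dirty_nodes.insert(cur);
    for (auto in : cur->inputs)              // upstream: producers of non-tensor inputs
      if (!in->is_tensor) q.push(in->producer);
    for (auto out : cur->outputs)            // downstream: consumers of non-tensor outputs
      if (!out->is_tensor)
        for (auto user : out->users) q.push(user);
  }
}

int main() {
  // a --(non-tensor)--> b --(tensor)--> c : a and b end up dirty, c stays clean.
  Node a{"a", {}, {}}, b{"b", {}, {}}, c{"c", {}, {}};
  Value ab{&a, {&b}, /*is_tensor=*/false};
  Value bc{&b, {&c}, /*is_tensor=*/true};
  a.outputs = {&ab};
  b.inputs = {&ab};
  b.outputs = {&bc};
  c.inputs = {&bc};

  std::unordered_set<Node*> block{&a, &b, &c};
  std::unordered_set<Node*> dirty{&b};       // seed: suppose b was marked dirty
  getDirtyNodes(dirty, block);
  std::cout << dirty.count(&a) << dirty.count(&b) << dirty.count(&c) << "\n";  // 110
}
```

In effect this computes the connected components, over non-tensor edges inside the segment, that touch the seed nodes.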
@@ -163,25 +195,29 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
   } else {
     // if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
     std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
-    std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes(dependency_nodes.begin(), dependency_nodes.end());
+    std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;

-    bool prev_non_tensor_outputs = false;
+    // take all nodes with non_tensor_inputs as initial dirty nodes (nodes that should be in the PyTorch block), then
+    // use dfs/bfs to find all dirty nodes that consume non_tensor values produced by dirty nodes or produce
+    // non_tensor values consumed by dirty nodes
+    std::unordered_set<torch::jit::Node*> dirty_nodes;
+    const std::unordered_set<torch::jit::Node*> seg_block_nodes(
+        seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
+
+    for (auto n : seg_block.raw_nodes()) {
+      if (containTargetInputs(n, nontensor_inputs_set)) {
+        dirty_nodes.insert(n);
+      }
+    }
+    getDirtyNodes(dirty_nodes, seg_block_nodes);
     for (auto n : seg_block.raw_nodes()) {
-      // Check if the node has non-tensor inputs or if it consumes non-tensor outputs of previous node.
-      // In these cases, these nodes are placed into a new Pytorch SegmentedBlock. Else, they form a new TensorRT
-      // SegmentedBlock.
-      if (containTargetInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
-        // If tensorrt_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
-        // TensorRT segmented_block and clear the tensorrt_nodes list to be later used for new TRT segments.
+      if (dirty_nodes.count(n)) {
         if (!tensorrt_nodes.empty()) {
           new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTensorRT, tensorrt_nodes);
           tensorrt_nodes.clear();
         }
         pytorch_nodes.push_back(n);
-        prev_non_tensor_outputs = containNonTensorOutputs(n);
       } else {
-        // If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes. Construct a
-        // Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments.
         if (!pytorch_nodes.empty()) {
           new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTorch, pytorch_nodes);
           pytorch_nodes.clear();
@@ -190,7 +226,7 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
       }
     }

-    // Form the last segmented_block with the left over nodes in tensorrt_nodes or pytorch_nodes correspondingly.
+    // Form the last segmented_block with the leftover nodes in tensorrt_nodes or pytorch_nodes correspondingly.
     if (!tensorrt_nodes.empty()) {
       new_seg_blocks.emplace_back(new_seg_blocks.size(), SegmentedBlock::kTensorRT, tensorrt_nodes);
     } else {
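Given the dirty set, the rewritten loop above simply cuts a new sub-block whenever the dirty/clean classification flips, so maximal clean runs become TensorRT blocks and maximal dirty runs become PyTorch blocks, with the leftover run flushed at the end. A minimal sketch of that splitting step, with a hypothetical `Block` in place of `SegmentedBlock`:

```cpp
// Minimal sketch (hypothetical Block type): split an ordered node list into
// alternating TensorRT / PyTorch runs based on a precomputed dirty set.
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

enum class Target { kTensorRT, kTorch };

struct Block {
  Target target;
  std::vector<std::string> nodes;
};

std::vector<Block> splitByDirtyNodes(const std::vector<std::string>& nodes,
                                     const std::unordered_set<std::string>& dirty) {
  std::vector<Block> blocks;
  std::vector<std::string> trt_nodes, torch_nodes;
  for (const auto& n : nodes) {
    if (dirty.count(n)) {
      // flush the pending TensorRT run before starting/continuing a PyTorch run
      if (!trt_nodes.empty()) {
        blocks.push_back({Target::kTensorRT, trt_nodes});
        trt_nodes.clear();
      }
      torch_nodes.push_back(n);
    } else {
      // flush the pending PyTorch run before starting/continuing a TensorRT run
      if (!torch_nodes.empty()) {
        blocks.push_back({Target::kTorch, torch_nodes});
        torch_nodes.clear();
      }
      trt_nodes.push_back(n);
    }
  }
  // leftover run forms the final block
  if (!trt_nodes.empty()) {
    blocks.push_back({Target::kTensorRT, trt_nodes});
  } else if (!torch_nodes.empty()) {
    blocks.push_back({Target::kTorch, torch_nodes});
  }
  return blocks;
}

int main() {
  std::vector<std::string> nodes{"a", "b", "c", "d", "e"};
  std::unordered_set<std::string> dirty{"b", "c"};
  for (const auto& blk : splitByDirtyNodes(nodes, dirty)) {
    std::cout << (blk.target == Target::kTensorRT ? "TRT:" : "Torch:");
    for (const auto& n : blk.nodes) std::cout << " " << n;
    std::cout << "\n";
  }
  // prints: TRT: a / Torch: b c / TRT: d e
}
```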