diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index 0080e1a74f..1324d6d555 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -12,6 +12,17 @@ namespace jit { namespace { +struct WorkBlock : public std::pair { + using pair::pair; + + Node* begin() { + return this->first; + } + Node* end() { + return this->second; + } +}; + class SubgraphSlicer { public: SubgraphSlicer( @@ -37,14 +48,20 @@ class SubgraphSlicer { // c = f(a, b) // e = f(d) <- iter still here // d = f(c) <- this was node moved on the other side. - bool any_changed = true; - while (any_changed) { - any_changed = false; - AliasDb aliasDb(graph_); - for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) { - bool changed; - std::tie(it, changed) = scanNode(*it, aliasDb); - any_changed |= changed; + + // see [workblocks] + auto workblocks = buildWorkBlocks(); + for (auto& workblock : workblocks) { + bool any_changed = true; + while (any_changed) { + AliasDb aliasDb(graph_); + any_changed = false; + for (auto it = workblock.end()->reverseIterator(); + it != workblock.begin()->reverseIterator();) { + bool changed; + std::tie(it, changed) = scanNode(*it, aliasDb); + any_changed |= changed; + } } } @@ -77,6 +94,46 @@ class SubgraphSlicer { } private: + std::vector buildWorkBlocks() { + // [workblocks] + // the IR has many nodes which can never be reordered around, such as a + // prim::Bailout. if a node N is surrounded by two nodes which cannot be + // reordered, A and B, then a differentiable subgraph that is created from N + // can only contain nodes from (A, B) The nodes from A to B represent one + // work block for the subgraph slicer to work on. By creating these up + // front, we avoid retraversing the whole graph block any time scanNode + // returns, and we can also avoid attempting to create differentiable + // subgraphs in work blocks that do not contain a # of differentiable nodes + // >= minSubgraphSize_ + + Node* end_bound_node = block_->return_node(); + Node* curr = end_bound_node->prev(); + + std::vector worklist; + size_t differentiable_nodes = 0; + + while (curr != block_->param_node()) { + differentiable_nodes += shouldConsiderForMerge(curr); + + // cannot reorder around side effectful nodes + if (curr->hasSideEffects()) { + // not enough differentiable nodes to create a differentiable subgraph + if (differentiable_nodes >= minSubgraphSize_) { + worklist.emplace_back(curr, end_bound_node); + } + differentiable_nodes = 0; + end_bound_node = curr; + } + curr = curr->prev(); + } + + if (differentiable_nodes >= minSubgraphSize_) { + worklist.emplace_back(curr, end_bound_node); + } + + return worklist; + } + // Inline this node's group subgraph into the outer graph if it's smaller // than the specified minimum size. //