[JIT] optimize autodiff subgraph slicing (#41437)
Summary:
Pull Request resolved: pytorch/pytorch#41437

[copied from code comment]
The IR has many nodes which can never be reordered around, such as a
prim::Bailout. If a node N is surrounded by two nodes which cannot be
reordered, A and B, then a differentiable subgraph that is created from N
can only contain nodes from (A, B). The nodes from A to B represent one
work block for the subgraph slicer to work on. By creating these up
front, we avoid retraversing the whole graph block any time scanNode
returns, and we can also avoid attempting to create differentiable
subgraphs in work blocks that do not contain a minimum number of
differentiable nodes (minSubgraphSize_).
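
The partitioning described above can be sketched outside of PyTorch as follows. This is a minimal illustration, not the code from the commit: `ToyNode`, `ToyWorkBlock` and `buildToyWorkBlocks` are hypothetical names, nodes live in a `std::vector` rather than the JIT's linked node list, and the indices -1 and `nodes.size()` stand in for the block's param and return nodes.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for a JIT IR node. Only the two properties that the
// partitioning looks at are modelled here.
struct ToyNode {
  bool differentiable;    // plays the role of shouldConsiderForMerge(node)
  bool has_side_effects;  // plays the role of node->hasSideEffects()
};

// A work block is a pair of boundary positions (begin, end). Both bounds are
// barriers, so the nodes available for merging lie strictly between them.
using ToyWorkBlock = std::pair<long, long>;

std::vector<ToyWorkBlock> buildToyWorkBlocks(
    const std::vector<ToyNode>& nodes,
    std::size_t min_subgraph_size) {
  std::vector<ToyWorkBlock> worklist;
  long end_bound = static_cast<long>(nodes.size());  // "return node"
  std::size_t differentiable_nodes = 0;

  // Walk backwards, closing a block every time a barrier is reached.
  for (long curr = end_bound - 1; curr >= 0; --curr) {
    const ToyNode& node = nodes[static_cast<std::size_t>(curr)];
    differentiable_nodes += node.differentiable ? 1 : 0;

    if (node.has_side_effects) {
      // Keep the block only if it could hold a large enough subgraph.
      if (differentiable_nodes >= min_subgraph_size) {
        worklist.emplace_back(curr, end_bound);
      }
      differentiable_nodes = 0;
      end_bound = curr;
    }
  }

  // The final block is bounded below by the "param node" (-1 here).
  if (differentiable_nodes >= min_subgraph_size) {
    worklist.emplace_back(-1L, end_bound);
  }
  return worklist;
}
```

With seven nodes where positions 2 and 4 have side effects and every other node is differentiable, a minimum size of 2 yields two blocks, (4, 7) and (-1, 2). The single differentiable node between the two barriers is skipped, which is the pruning the summary refers to.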

This improved the compilation time of densenet (the model with the slowest compilation time we're tracking) from 56s -> 28s, and of mobilenet from 8s -> 6s.

Test Plan: Imported from OSS

Reviewed By: Krovatkin, ZolotukhinM

Differential Revision: D22600607

Pulled By: eellison

fbshipit-source-id: e5ab6ed87bf6820b4e22c86eabafd9d17bf7cedc
Elias Ellison authored and facebook-github-bot committed Jul 23, 2020
1 parent da3ff5e commit 25b6e2e
Showing 1 changed file with 65 additions and 8 deletions.
73 changes: 65 additions & 8 deletions torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
@@ -12,6 +12,17 @@ namespace jit {

namespace {

struct WorkBlock : public std::pair<Node*, Node*> {
  using pair::pair;

  Node* begin() {
    return this->first;
  }
  Node* end() {
    return this->second;
  }
};

class SubgraphSlicer {
public:
SubgraphSlicer(
@@ -37,14 +48,20 @@ class SubgraphSlicer {
// c = f(a, b)
// e = f(d) <- iter still here
// d = f(c) <- this node was moved to the other side.
bool any_changed = true;
while (any_changed) {
any_changed = false;
AliasDb aliasDb(graph_);
for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
bool changed;
std::tie(it, changed) = scanNode(*it, aliasDb);
any_changed |= changed;

// see [workblocks]
auto workblocks = buildWorkBlocks();
for (auto& workblock : workblocks) {
bool any_changed = true;
while (any_changed) {
AliasDb aliasDb(graph_);
any_changed = false;
for (auto it = workblock.end()->reverseIterator();
it != workblock.begin()->reverseIterator();) {
bool changed;
std::tie(it, changed) = scanNode(*it, aliasDb);
any_changed |= changed;
}
}
}

@@ -77,6 +94,46 @@ class SubgraphSlicer {
}

private:
  std::vector<WorkBlock> buildWorkBlocks() {
    // [workblocks]
    // the IR has many nodes which can never be reordered around, such as a
    // prim::Bailout. if a node N is surrounded by two nodes which cannot be
    // reordered, A and B, then a differentiable subgraph that is created from N
    // can only contain nodes from (A, B). The nodes from A to B represent one
    // work block for the subgraph slicer to work on. By creating these up
    // front, we avoid retraversing the whole graph block any time scanNode
    // returns, and we can also avoid attempting to create differentiable
    // subgraphs in work blocks that do not contain a # of differentiable nodes
    // >= minSubgraphSize_

    Node* end_bound_node = block_->return_node();
    Node* curr = end_bound_node->prev();

    std::vector<WorkBlock> worklist;
    size_t differentiable_nodes = 0;

    while (curr != block_->param_node()) {
      differentiable_nodes += shouldConsiderForMerge(curr);

      // cannot reorder around side effectful nodes
      if (curr->hasSideEffects()) {
        // only keep the work block if it has enough differentiable nodes to
        // create a differentiable subgraph
        if (differentiable_nodes >= minSubgraphSize_) {
          worklist.emplace_back(curr, end_bound_node);
        }
        differentiable_nodes = 0;
        end_bound_node = curr;
      }
      curr = curr->prev();
    }

    if (differentiable_nodes >= minSubgraphSize_) {
      worklist.emplace_back(curr, end_bound_node);
    }

    return worklist;
  }

// Inline this node's group subgraph into the outer graph if it's smaller
// than the specified minimum size.
//
