Improve const-correctness of JIT.
This started off as a minor fix based on Adam's question, "why is printing
a graph not const?", and snowballed into a giant yak-shaving exercise.

- The Graph and Node APIs now uniformly enforce deep constness: if you have
  a const Node* or const Graph*, it is not possible to obtain a non-const
  Node*/Graph* anywhere else in the graph, even though the member variables
  of these classes are non-const.  (Hooray for the private access specifier;
  a minimal sketch follows.)
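  A minimal sketch of the deep-const pattern (hypothetical class, member,
  and accessor names, not the real ir.h declarations):

      #include <cstddef>
      #include <vector>

      struct Node;

      struct GraphSketch {
        // Mutable access is only reachable through a non-const pointer.
        Node* node(std::size_t i) { return nodes_[i]; }
        // The const overload only ever hands out const Node*, so a
        // const GraphSketch* cannot leak mutable access to any node.
        const Node* node(std::size_t i) const { return nodes_[i]; }
       private:
        std::vector<Node*> nodes_;  // non-const storage, guarded by `private`
      };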

- A big pile of functions got const versions, most notably the printing
  functions and the functions for accessing inputs().

- REALLY IMPORTANT, BC-BREAKING CHANGE: inputs() now returns a COPY of the
  inputs, rather than a reference to the underlying vector.  I was forced to
  do this because there is no portable way to turn a std::vector<Node*> into
  a std::vector<const Node*>, which is necessary to provide a const-correct
  version of inputs() that enforces deep const-correctness.  I then justified
  this choice to myself with the observation that outputs() already returned
  a copy (by necessity), so this makes the API more uniform.  (The forced
  copy is sketched below, at the end of this item.)

  But making this change uncovered two very subtle bugs:

    1. If you change a function from returning a reference to returning a
       copy, the idiom node->inputs().begin() is no longer valid, because the
       memory the iterator points into is freed as soon as the temporary
       vector is destroyed.  THIS SUCKS.  Honestly, we should add a lint rule
       rejecting calls to begin()/end() on temporaries, because this is very
       dangerous.  To excise this pattern from the codebase, I added begin()
       and end() methods to Graph, getting rid of the graph->nodes().begin()
       idiom.  (That idiom happens to be sound, despite nodes() not returning
       a reference, because graph_node_list is a non-owning view; see the
       dangling-iterator sketch below.)

    2. pybind11 doesn't handle casting std::vector<Node*> out of the box.
       Fortunately, I found a simple fix in the GitHub issue tracker that
       involved adding an extra type converter.  And yes, this does mean
       that outputs() in Python never worked correctly before this change.
       (A hedged sketch of such a converter appears below.)
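  Why the copy is forced, in a sketch (hypothetical NodeSketch class; the
  real declarations live in torch/csrc/jit/ir.h): std::vector<Node*> and
  std::vector<const Node*> are unrelated types, so the const-correct
  overload has to materialize a fresh vector.

      #include <vector>

      struct Node;

      struct NodeSketch {
        // Non-const callers get a copy of the stored pointers as-is.
        std::vector<Node*> inputs() { return inputs_; }
        // std::vector<Node*> cannot be portably reinterpreted as
        // std::vector<const Node*>, so the const overload rebuilds the
        // vector element by element -- hence copies on both overloads.
        std::vector<const Node*> inputs() const {
          return std::vector<const Node*>(inputs_.begin(), inputs_.end());
        }
       private:
        std::vector<Node*> inputs_;
      };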
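  The dangling-iterator trap from bug 1, illustrated against the same
  sketch:

      // BROKEN once inputs() returns by value: the temporary vector is
      // destroyed at the end of the full expression, leaving `it` dangling:
      //   auto it = node->inputs().begin();
      void useInputs(NodeSketch* node) {
        // SAFE: name the copy so it outlives every use of its iterators.
        auto inputs = node->inputs();
        for (auto it = inputs.begin(); it != inputs.end(); ++it) {
          // ... use *it ...
        }
      }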
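  For bug 2, a hedged sketch of the kind of extra type converter involved.
  It assumes Node is the bound torch::jit::Node class and leans on the
  documented list_caster helper from <pybind11/stl.h>; the converter this
  commit actually adds may differ.

      #include <pybind11/pybind11.h>
      #include <pybind11/stl.h>
      #include "torch/csrc/jit/ir.h"  // assumed to declare torch::jit::Node

      namespace pybind11 { namespace detail {
      // Cast std::vector<Node*> to a Python list (and back), converting
      // each element through the bound Node class.
      template <>
      struct type_caster<std::vector<torch::jit::Node*>>
          : list_caster<std::vector<torch::jit::Node*>, torch::jit::Node*> {};
      }}  // namespace pybind11::detail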

- New const_graph_node_list, which is a graph_node_list that yields
  const Node* (a usage sketch follows).
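  A usage sketch, assuming (consistent with the ir.cpp changes below) that
  the const overload of Graph::nodes() returns a const_graph_node_list:

      // Every node visited through a const Graph& is a const Node*.
      std::size_t countNodes(const Graph& g) {
        std::size_t n = 0;
        for (const Node* node : g.nodes()) {
          (void)node;  // no mutation possible through this pointer
          ++n;
        }
        return n;
      }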

There are some more miscellaneous improvements:

- Applied CR comment fixes in export.cpp: use replaceInput, and rename
  variables for clarity.

- Added an assertValidInput helper method and applied it in replaceInput
  (a hedged sketch follows this list).

- Use an explicit function to print THPObjectPtr; otherwise we get the
  wrong overload (illustrated below).
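
A hedged sketch of the assertValidInput idea, written here as a free
function with one plausible check; the real helper is presumably a Node
method and its exact signature and checks may differ:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    struct Node;

    // Fail loudly if an input index is out of range before
    // replaceInput-style code touches it (the real code would use
    // JIT_ASSERT rather than assert).
    void assertValidInput(const std::vector<Node*>& inputs, std::size_t i) {
      assert(i < inputs.size() && "input index out of range");
    }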
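And the THPObjectPtr pitfall, reduced to a self-contained mock -- this
assumes it is THPObjectPtr's implicit pointer conversion that selects the
wrong overload:

    #include <iostream>

    // A smart-pointer-ish wrapper with an implicit conversion, standing
    // in for THPObjectPtr.
    struct MockObjectPtr {
      const char* p;
      operator const void*() const { return p; }
    };

    int main() {
      MockObjectPtr obj{"hello"};
      // The built-in operator<<(const void*) wins, printing an address
      // instead of the contents -- hence the explicit printPyObject.
      std::cout << obj << "\n";
      return 0;
    }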

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
ezyang committed Nov 1, 2017
1 parent 20f569e commit 79b068a
Showing 13 changed files with 173 additions and 78 deletions.
4 changes: 4 additions & 0 deletions test/expect/TestJit.test_dropout.expect
@@ -0,0 +1,4 @@
+graph(%1 : Double(2, 2)) {
+  %3 : Double(2, 2), %4 : Handle = ^Dropout(0.6, True, False)(%1), uses = [[%0.i0], []];
+  return (%3);
+}
5 changes: 5 additions & 0 deletions test/test_jit.py
@@ -621,6 +621,11 @@ def test_batchnorm(self):
         trace, _ = torch.jit.trace(nn.BatchNorm2d(2), x)
         self.assertExpected(str(trace))

+    def test_dropout(self):
+        x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True)
+        trace, _ = torch.jit.trace(nn.Dropout(0.6), x)
+        self.assertExpected(str(trace))
+
     @unittest.skip("unrecognized NodeKind: SpatialBN")
     def test_batchnorm_run_twice(self):
         @torch.jit.compile(nderivs=0)
2 changes: 1 addition & 1 deletion torch/csrc/autograd/functions/jit_closure.cpp
@@ -501,7 +501,7 @@ struct StageClosure {

     // All Eval nodes take context edges as an input, and we need to register
     // all such places
-    auto & inputs = value->inputs();
+    auto inputs = value->inputs();
     JIT_ASSERT(inputs.size() > 0);
     auto handle_input = inputs[inputs.size() - 1];
     JIT_ASSERT(handle_input->type()->kind() == TypeKind::HandleType);
27 changes: 14 additions & 13 deletions torch/csrc/jit/export.cpp
@@ -258,7 +258,7 @@ bool fusibleExpandTo(at::IntList from, at::IntList to) {
 // local information. This optimization is not useful for PyTorch as 'expand'
 // is free.
 void fuseBroadcast(const std::shared_ptr<Graph>& graph) {
-  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
+  for (auto it = graph->begin(); it != graph->end(); ++it) {
     auto* n = *it;

     // Can't fuse into nodes that don't support broadcasting
@@ -271,28 +271,29 @@ void fuseBroadcast(const std::shared_ptr<Graph>& graph) {
     if (n->hasAttribute(kbroadcast) && n->i(kbroadcast)) continue;
     JIT_ASSERT(!n->hasAttribute(kaxis));

-    auto* rhs = n->inputs().at(n->inputs().size() - 1);
+    auto input_index = n->inputs().size() - 1;
+    auto* expanded_rhs = n->inputs().at(input_index);

-    // The rhs input isn't actually an expand, so no fusion available
-    if (rhs->kind() != kExpand) continue;
+    // The expanded_rhs input isn't actually an expand, so no fusion available
+    if (expanded_rhs->kind() != kExpand) continue;

-    auto* new_rhs = rhs->input();
+    auto* unexpanded_rhs = expanded_rhs->input();

     // We need to know what the type pre-expand is. We should basically
     // always have this information (because expands are only ever traced,
     // not generated from symbolic), but if for some reason we don't
     // have it, we need to skip.
-    if (!new_rhs->hasType()) continue;
+    if (!unexpanded_rhs->hasType()) continue;

     // Not all broadcasts are supported by ONNX broadcast.
-    if (!fusibleExpandTo(new_rhs->type()->expect<TensorType>()->sizes(), // from
-                         rhs->type()->expect<TensorType>()->sizes()) // to
+    if (!fusibleExpandTo(unexpanded_rhs->type()->expect<TensorType>()->sizes(), // from
+                         expanded_rhs->type()->expect<TensorType>()->sizes()) // to
     ) continue;

-    n->replaceInput(n->inputs().size() - 1, new_rhs);
-    n->i_(kbroadcast,1);
-    if (rhs->uses().size() == 0) {
-      rhs->destroy();
+    n->replaceInput(input_index, unexpanded_rhs);
+    n->i_(kbroadcast, 1);
+    if (expanded_rhs->uses().size() == 0) {
+      expanded_rhs->destroy();
     }
   }
 }
@@ -302,7 +303,7 @@ void standardizeGraph(const std::shared_ptr<Graph>& graph) {
   // TODO: move this out of here...
   fuseBroadcast(graph);

-  for (auto it = graph->nodes().begin(); it != graph->nodes().end(); ++it) {
+  for (auto it = graph->begin(); it != graph->end(); ++it) {
     // Macro'ed so we get a marginally better line number on failed export
 #define FAIL_EXPORT(name) \
       throw std::runtime_error(std::string("ONNX export failed: ") + name + "\n\nGraph we tried to export:\n" + graph->toString());
1 change: 1 addition & 0 deletions torch/csrc/jit/graph_node_list.h
@@ -35,6 +35,7 @@ struct generic_graph_node_list_iterator;

 struct Node;
 using graph_node_list = generic_graph_node_list<Node>;
+using const_graph_node_list = generic_graph_node_list<const Node>;
 using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;
 using const_graph_node_list_iterator = generic_graph_node_list_iterator<const Node>;
55 changes: 28 additions & 27 deletions torch/csrc/jit/ir.cpp
@@ -34,11 +34,12 @@ std::string getPythonName(const PyObject* obj, bool is_legacy) {
     return THPUtils_unpackString(name.get());
   }
 }
-void printNodeRef(std::ostream & out, Node * n) {
+void printNodeRef(std::ostream & out, const Node * n) {
   out << "%" << n->uniqueName();
 }

-std::ostream& operator<<(std::ostream & out, const node_list & nodes) {
+template <typename T>
+std::ostream& operator<<(std::ostream & out, const std::vector<T> & nodes) {
   size_t i = 0;
   for(auto n : nodes) {
     if(i++ > 0)
@@ -48,8 +49,9 @@ std::ostream& operator<<(std::ostream & out, const node_list & nodes) {
   return out;
 }

-static std::ostream& operator<<(std::ostream & out, THPObjectPtr& obj) {
-  auto pyobj = py::handle(obj.get());
+std::ostream& printPyObject(std::ostream & out, const THPObjectPtr& obj) {
+  AutoGIL gil;
+  auto pyobj = py::handle(const_cast<PyObject*>(obj.get()));
   if (py::isinstance<py::tuple>(pyobj)) {
     // This special-case for printing tuples handles a problem where
     // str((2L, 3L)) outputs "(2L, 3L)" in Python 2 but "(2, 3)"
@@ -82,20 +84,19 @@ static std::ostream& operator<<(std::ostream & out, THPObjectPtr& obj) {
     out << ")";
     return out;
   } else {
-    THPObjectPtr str { PyObject_Str(obj.get()) };
-    return out << THPUtils_unpackString(str.get());
+    return out << THPUtils_unpackString(py::str(pyobj).ptr());
   }
 }

-std::string PythonOp::name() {
+std::string PythonOp::name() const {
   return getPythonName(pyobj.get(),is_legacy);
 }

-std::string CppOp::name() {
+std::string CppOp::name() const {
   return fn->name();
 }

-static void emitUses(std::ostream & out, Node * n) {
+static void emitUses(std::ostream & out, const Node * n) {
   size_t i = 0;
   for(auto u : n->uses()) {
     if(i++ > 0)
@@ -105,13 +106,13 @@ static void emitUses(std::ostream & out, Node * n) {
   }
 }

-struct node_list_with_types {
-  const node_list& nodes;
+struct const_node_list_with_types {
+  const std::vector<const Node*>& nodes;
   bool use_newlines;
-  node_list_with_types(const node_list& nodes, bool use_newlines = false)
+  const_node_list_with_types(const std::vector<const Node*>& nodes, bool use_newlines = false)
     : nodes(nodes), use_newlines(use_newlines) {}
 };
-std::ostream& operator<<(std::ostream & out, node_list_with_types l) {
+std::ostream& operator<<(std::ostream & out, const_node_list_with_types l) {
   size_t i = 0;
   size_t prev_stage = 0;
   for(auto n : l.nodes) {
@@ -147,7 +148,7 @@ void printPrimList(std::ostream & out, const std::vector<T> & items) {
   }
   out << "]";
 }
-void printAttributes(std::ostream & out, Node * n) {
+void printAttributes(std::ostream & out, const Node * n) {
   out << "[";
   auto names = n->attributeNames();
   int i = 0;
@@ -214,18 +215,18 @@ void printAttributes(std::ostream & out, Node * n) {
   out << "]";
 }

-std::ostream& printNode(std::ostream & out, Node * n, std::vector<Node*> * groups) {
-  node_list outputs = n->outputs();
-  out << node_list_with_types(outputs);
+std::ostream& printNode(std::ostream & out, const Node * n, std::vector<const Node*> * groups) {
+  auto outputs = n->outputs();
+  out << const_node_list_with_types(outputs);
   out << " = ";
-  IR_IFM(n,PythonOp)
+  IR_IFM_CONST(n,PythonOp)
     out << "^" << value->name();
     out << "(";
     int i = 0;
    for (auto& scalar : value->scalar_args) {
       if (i++ > 0)
         out << ", ";
-      out << scalar;
+      printPyObject(out, scalar);
     }
     out << ")";
   IR_ELSEIF(FusionGroup)
@@ -235,7 +236,7 @@ std::ostream& printNode(std::ostream & out, Node * n, std::vector<Node*> * group
   } else {
     out << "fusion_group[" << *n->g(kSubgraph) << "]";
   }
-  IR_ELSEIFM(CppOp)
+  IR_ELSEIFM_CONST(CppOp)
     out << "CppOp[" << value->name() << "]";
   IR_ELSE()
     out << symbolToString(n->kind());
@@ -260,13 +261,13 @@ std::ostream& printNode(std::ostream & out, Node * n, std::vector<Node*> * group
   return out;
 }

-std::ostream& operator<<(std::ostream & out, Node & n) {
+std::ostream& operator<<(std::ostream & out, const Node & n) {
   return printNode(out, &n, nullptr);
 }

-std::ostream& operator<<(std::ostream & out, Graph & g) {
-  out << "graph(" << node_list_with_types(g.inputs(), true) << ") {\n";
-  std::vector<Node*> groups;
+std::ostream& operator<<(std::ostream & out, const Graph & g) {
+  out << "graph(" << const_node_list_with_types(g.inputs(), true) << ") {\n";
+  std::vector<const Node*> groups;
   size_t prev_stage = 0;
   for(auto n : g.nodes()) {
     if(n->kind() != kSelect) { //improve readibility by printing selects inline
@@ -490,14 +491,14 @@ void Graph::lint() const {
   JIT_ASSERT(std::includes(ALL_OF(sum_set), ALL_OF(all_nodes_set)));

   // graph->stage() should be equal to max(node.stage for node in graph->nodes())
-  if (nodes().begin() == nodes().end()) {
+  if (begin() == end()) {
     JIT_ASSERT(stage() == 0);
   } else {
-    JIT_ASSERT(stage() == nodes().rbegin()->stage());
+    JIT_ASSERT(stage() == rbegin()->stage());
   }
 }

-void Graph::dump() {
+void Graph::dump() const {
   std::cout << *this << "\n";
 }
(Diffs for the remaining 7 changed files are not shown.)
