Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
refactor for op::Result (#165)
Browse files Browse the repository at this point in the history
* refactor for op::Result

* fix optimize graph
  • Loading branch information
Matthew Brookhart authored Mar 7, 2018
1 parent 1fe35c6 commit 11fae70
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/ngraph/ngraph_sgcompiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ void OptimizeGraph(std::shared_ptr<Graph> sub_graph,
// if we're in CPU, combine the graphs
ngraph::NodeVector dYdXs;
for (size_t i = 0; i < bf->get_output_size(); ++i) {
dYdXs.push_back(bf->get_output_op(i));
dYdXs.push_back(bf->get_output_op(i)->get_input_op(0));
}
ngraph::NodeVector combined_outputs{f->get_output_op(0)};
ngraph::NodeVector combined_outputs{f->get_output_op(0)->get_input_op(0)};
combined_outputs.insert(combined_outputs.end(), dYdXs.begin(), dYdXs.end());

std::vector<std::shared_ptr<ngraph::op::Parameter>> combined_parameters =
Expand Down Expand Up @@ -171,7 +171,7 @@ std::shared_ptr<ngraph::Function> SGCompiler::MakeForwardFunction(
std::shared_ptr<ngraph::Function> SGCompiler::MakeBackwardFunction(
std::shared_ptr<Graph> sub_graph, std::shared_ptr<ngraph::Function> f) {
// Get the output
auto Y = f->get_output_op(0);
auto Y = f->get_output_op(0)->get_input_op(0);

// Create the Adjoint
auto C = std::make_shared<ngraph::op::Parameter>(Y->get_element_type(),
Expand Down

0 comments on commit 11fae70

Please sign in to comment.