-
Notifications
You must be signed in to change notification settings - Fork 3
AutoBroadcast class with broadcast_add first usage #57
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall, I think it's looking great! There are a few sanity checks/reference changes I'd like, but over all I think it's pretty much there.
src/ngraph/ngraph_autobroadcast.cc
Outdated
SetShapesAndAxes(); | ||
|
||
// if auto broadcast is possible | ||
if (broadcastshape_.size()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a sanity check that broadcastshape_ is equal to the output node shape?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure I understand. this is the job of the unit tests, in my opinion. but, I may be missing your point. if I want to check that broadcast shape is equal to output shape don't I need to have type propagation inside the class?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me rephrase: Mxnet has inferred what the output of this broadcast should be. We're doing a second inference to identify how to expand the axes. Should we validate that mxnet's inferred shape and autobroadcast's inferred shape are the same?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not end up doing this. Was wondering... do we actually need to check for this error or will it be caught by some other piece of code. Let me know.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this will actually show an error, but I'm paranoid ;) If it fails for some obscure test case, it will through an error somewhere, but that error might be hard to debug. It's okay for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if I purposely mess up broadcastshape by adding '2' as a leading dimension after exiting the loop I get:
terminate called after throwing an instance of 'std::invalid_argument'
what(): Error with node Node(Broadcast_2): Broadcast arg, shape, and axes are incompatible
Aborted (core dumped)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@adstraw @mbrookhart , I second adding an assertion. What @adstraw is describing is the current implementation. Imagine, Alice comes in and makes some changes to SetShapesAndAxes
so this error isn't produced Error with node Node(Broadcast_2):
and yet the shapes are indeed different (that the promise SetShapesAndAxes
is required to fulfill). In debug builds we would like to fail as early as possible to the source of a problem and asserting that the shapes indeed match is a way to confirm that SetShapesAndAxes
fulfills its promise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am leaving this out, for now. If we find that we crash and burn at some later date we can add this check. As it stands, AutoBroadcast produces no errors which simplifies things. It simply 1) broadcasts if possible or 2) leaves things as-is and punts error handling down the road.
EXPECT_EQ(getShapeFromParam(ab.rhs()), s1345); | ||
} | ||
|
||
} // namespace ngraph_bridge |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These tests look great, thank you!
auto rhsShape = TShape_to_NShape(node->inputs[1]->shape); | ||
|
||
AutoBroadcast ab(lhsNode, lhsShape, rhsNode, rhsShape); | ||
return ab.lhs() + ab.rhs(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still confused by the object that's simply a constructor and two getters, but that's a longer term discussion and doesn't block this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No test for this function? I'm worried that this will fail since ab.lhs() and ab.rhs() are const references to shared pointers only defined in the AutoBroadcast...they will be destroyed when ab is when this function returns, then the returned ngraph node will have nullptrs in it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
simply forgot the test for broadcast_add, will do that now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still confused by the object that's simply a constructor and two getters, but that's a longer term discussion and doesn't block this.
It's actually a really nice API. It's impossible to misuse 😄
node.ptr, broadcastshape_, node.axes); | ||
} | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I lean towards minimal comments, I find they clutter the code and make it harder for me to read. I find this over-commented, but that may just be me, and if everyone else wants more comments, that's okay.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I deleted a few comments, hope it's better
src/ngraph/ngraph_autobroadcast.h
Outdated
AutoBroadcast(const NgraphNodePtr &lhsNode, const ngraph::Shape &lhsShape, | ||
const NgraphNodePtr &rhsNode, const ngraph::Shape &rhsShape); | ||
const NgraphNodePtr &lhs() { return lhs_.ptr; } | ||
const NgraphNodePtr &rhs() { return rhs_.ptr; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should return a copy of the pointer, not a reference. I'm worried that the shared pointer will go out of scope and delete the node between construction and getting called in the graph executor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you!
src/ngraph/ngraph_autobroadcast.cc
Outdated
SetShapesAndAxes(); | ||
|
||
// if auto broadcast is possible | ||
if (broadcastshape_.size()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this will actually show an error, but I'm paranoid ;) If it fails for some obscure test case, it will through an error somewhere, but that error might be hard to debug. It's okay for now.
lhs_.reshape.insert(lhs_.reshape.begin(), lhsDim); | ||
rhs_.reshape.insert(rhs_.reshape.begin(), rhsDim); | ||
|
||
} else if (rhsDim == 1) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rhsDim == 1
and lhsDim == 1
cases seem to duplicate the same logic. Would it make sense to add a small helper and call it like in this snippet below?
if (rhsDim == 1)
{
collectBroadcastAndReshapeAxes(lhsDim, rhsDim, lhs_, rhs_);
}
else (lhsDim == 1)
{
collectBroadcastAndReshapeAxes(rhsDim, lhsDim, rhs_, lhs_);
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
saw this after my last patch. I agree - it could clean up the code. let's see if I get more feedback and I can address.
// basic reshape and broadcast test | ||
// rhs reshape to 2,3,4 then | ||
// rhs broadcast to 2,3,4,5 | ||
TEST(NGRAPH_AUTOBROADCAST, RESHAPE_1X_BROADCAST) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In future, we could probably add a few more cases:
- scalar -> vector
- scalar -> matrix
- vector -> matrix
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
data1 = op_map[in1]; | ||
data2 = op_map[in2]; | ||
}; | ||
}; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpick: no new line
ngraph::Shape reshape; | ||
// axes (0-based) to broadcast by ngraph::op::Broadcast | ||
ngraph::AxisSet axes; | ||
} lhs_, rhs_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
an off-topic question? is this the convention we are using for members (members' names end w/ underscores)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My opinion: It seems to be the mxnet guideline therefore it is our guideline.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is standard practice so I think that we should follow that. It is something I've put into the Coding guideline document. I would say that it's pretty conventional at this point in C++.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in my previous gig the standard was m_camelCaseVariable. I don't really care, just so long as we have a plan.
for basic broadcast 2D and 3D cases also to handle edge input cases (empty, zero dimension)
// a zero dimension is invalid | ||
// so we should not hit this case "in the wild" | ||
// make explicit: no action taken on shapes with zero dimensions | ||
if (std::find(lhs_.shape.begin(), lhs_.shape.end(), 0) != lhs_.shape.end()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we turn these into assertions? If it's indeed invalid input and we never expect to see such shapes that's what we should assert. OTOH, if graphs w/ such shapes are valid but they don't make sense to broadcast we should throw, instead. What do you guys think @adstraw @mbrookhart
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point. I don't see any other case in src/ngraph where we are using assert. @mbrookhart is there any reason not to use assert in this case?
// mxnet scalars are pre-broadcast to requisite shape | ||
// so we should not hit this case "in the wild" | ||
// make explicit: no action taken on empty shape(s) | ||
if (lhs_.shape.size() == 0 || rhs_.shape.size() == 0) return false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
again, maybe this should be an assertion? my arguments are in the above comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm pretty late on this code review. I just added my comments on this but do not want to be the one to hold up the check-in.
if (node.shape != node.reshape) { | ||
// tell reshape to examine input dimensions in order | ||
ngraph::AxisVector order(node.shape.size()); | ||
std::iota(order.begin(), order.end(), 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using STL algorithms!! Nice!!! 👍 💯
const ngraph::Shape &lhsShape, | ||
const NgraphNodePtr &rhsNode, | ||
const ngraph::Shape &rhsShape) { | ||
lhs_.ptr = lhsNode; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These should be initialized using the initializer list. It's faster. I'm also not a big fan of having so much logic in the constructor, but I understand there is precedence of that in our code already. Constructors should at most be initializing the basic components for having an object instantiated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point. code is now merged. I can fix in a future patch.
ngraph::Shape reshape; | ||
// axes (0-based) to broadcast by ngraph::op::Broadcast | ||
ngraph::AxisSet axes; | ||
} lhs_, rhs_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is standard practice so I think that we should follow that. It is something I've put into the Coding guideline document. I would say that it's pretty conventional at this point in C++.
// e.g. when adding (2,3) tensor A to (2,1) tensor B | ||
// first Reshape tensor B to (2) | ||
// then Broadcast tensor B to (2,3) | ||
void ReshapeAndBroadcast(Node &node); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No const
on the parameter???
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in this case we are actually modifying the node within the function so we want a non-const reference here.
// a zero dimension is invalid | ||
// so we should not hit this case "in the wild" | ||
// make explicit: no action taken on shapes with zero dimensions | ||
if (std::find(lhs_.shape.begin(), lhs_.shape.end(), 0) != lhs_.shape.end()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm a fan of the STL algorithm usage 🥇
No description provided.