Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NNVM] Bug fix Prevent fusing convolution with injective op #1608

Merged
merged 4 commits into from
Aug 17, 2018

Conversation

masahi
Copy link
Member

@masahi masahi commented Aug 16, 2018

Continuing from #1603.

Modify operator fuser to include following logics.

  • During the initial partition step, if injective op is followed by broadcast op, mark broadcast op's op pattern to be injective
  • In the last grouping step, check for input nodes op patterns. If one of them is kOutEWiseFusable and the other is kInjective, ignore the kOutEWiseFusable node and fuse only kInjective node.

A test case is included. This test fails without this PR.
I confirmed that this change does not affect compling resnet, mobilenet and vgg.

@tqchen please check if the logic is good.

@tqchen
Copy link
Member

tqchen commented Aug 16, 2018

@merrymercy can you also review this?

bool parent_injective = false;
for (const auto& e : inode.inputs) {
TOpPattern pt = pattern_vec[e.node_id];
if (pt == kOutEWiseFusable) parent_out_ewise = true;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style, multiline if always enclose with {}

@tqchen tqchen merged commit 6cd5a8f into apache:master Aug 17, 2018
@tqchen
Copy link
Member

tqchen commented Aug 17, 2018

Thanks @masahi , this is now merged

@kazum
Copy link
Contributor

kazum commented Aug 19, 2018

The below script fails after this commit. This kind of graph pattern can be found, for example, in NASNet.

import nnvm
from nnvm import symbol as sym

ch = 3
size = 8
data = sym.Variable(name="data")
concat = sym.concatenate(data, data)
conv = sym.conv2d(data=concat, kernel_size=(1,1), channels=ch*2)
net = sym.elemwise_add(concat, conv)
nnvm.compiler.build(net, "llvm", {"data": (1, ch, size, size)})

The result is as follows.

Traceback (most recent call last):
  File "tmp.py", line 10, in <module>
    nnvm.compiler.build(net, "llvm", {"data": (1, ch, size, size)})
  File "/home/kazutaka/git/tvm/nnvm/python/nnvm/compiler/build_module.py", line 304, in build
    graph = graph.apply("GraphCompile")
  File "/home/kazutaka/git/tvm/nnvm/python/nnvm/graph.py", line 234, in apply
    check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
  File "/home/kazutaka/git/tvm/nnvm/python/nnvm/_base.py", line 75, in check_call
    raise NNVMError(py_str(_LIB.NNGetLastError()))
nnvm._base.NNVMError: [05:30:54] /home/kazutaka/git/tvm/nnvm/include/nnvm/././op.h:530: Check failed: op != nullptr

@masahi, can you check this out?

@masahi
Copy link
Member Author

masahi commented Aug 19, 2018

@kazum Thanks! I sent a PR to fix this #1622.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants