
The axes of the output of function gconv2d are inconsistent with the axes of its input #2

Open
yangyu12 opened this issue Mar 15, 2017 · 2 comments

Comments

@yangyu12

Hello @tscohen

I'm trying to use the TensorFlow API in your GrouPy library and have run into a problem. In GrouPy/groupy/gconv/tensorflow_gconv/splitgconv2d.py, I found that the axes of the returned tensor are (batch, out channels, height, width), whereas the axes of the input are (batch, height, width, in channels).
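Converting between the two layouts is a pure axis permutation. As a minimal sketch (the helpers `layout_perm` and `permute_shape` are hypothetical, not part of GrouPy), the `perm` argument for `tf.transpose` or `np.transpose` can be derived directly from the layout strings:

```python
def layout_perm(src, dst):
    """Return the axis permutation that turns a `src`-layout tensor into `dst`.

    Hypothetical helper: the result is meant to be passed as `perm` to
    tf.transpose(x, perm) or np.transpose(x, perm).
    """
    if sorted(src) != sorted(dst):
        raise ValueError("layouts must contain the same axes: %r vs %r" % (src, dst))
    # For each axis of the target layout, look up where it sits in the source.
    return [src.index(axis) for axis in dst]


def permute_shape(shape, perm):
    """Apply a permutation to a shape tuple (what transpose does to the shape)."""
    return tuple(shape[i] for i in perm)


to_nchw = layout_perm("NHWC", "NCHW")
print(to_nchw)                                 # [0, 3, 1, 2]
print(permute_shape((8, 32, 32, 3), to_nchw))  # (8, 3, 32, 32)
print(layout_perm("NCHW", "NHWC"))             # [0, 2, 3, 1]
```

A transpose, not a reshape, is what changes the layout: reshaping an (N, C, H, W) tensor to (N, H, W, C) keeps the underlying element order and therefore scrambles the values.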

However, in your TensorFlow sample code, you simply feed the output of the previous conv layer into the next conv layer without any reshape. Does that make sense?
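To see why the layout matters when layers are chained, here is a small plain-Python illustration (the helper `label_axes` and the example shape are hypothetical) of how the same tensor shape is read under the two conventions:

```python
def label_axes(shape, layout):
    """Map each axis letter of `layout` (e.g. 'NHWC') to its size in `shape`."""
    if len(shape) != len(layout):
        raise ValueError("shape and layout must have the same rank")
    return dict(zip(layout, shape))


# Hypothetical layer output: batch 8, 16 channels, 30x30 maps, stored as NCHW.
shape = (8, 16, 30, 30)
print(label_axes(shape, "NCHW"))  # {'N': 8, 'C': 16, 'H': 30, 'W': 30}
# A next layer that assumes NHWC reads the same tensor as 16x30 maps
# with 30 channels.
print(label_axes(shape, "NHWC"))  # {'N': 8, 'H': 16, 'W': 30, 'C': 30}
```

If the next filter's in-channel size does not match the misread channel axis, the convolution will typically fail with a shape error; if the sizes happen to coincide, the mix-up can go unnoticed.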

Thanks a lot!

@tscohen
Owner

tscohen commented Mar 15, 2017

Thanks for catching this! As you can see, the TensorFlow version has not been battle-tested the way the Chainer version has (I only ran the unit tests in check_gconv2d.py).

Looking at the gconv2d function, there is a default parameter data_format='NHWC' and a check:

if data_format != 'NHWC':
    raise NotImplementedError('Currently only NHWC data_format is supported. Got: ' + str(data_format))
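If NCHW support is added, the check could be relaxed along these lines. This is only a sketch; `channel_axis` and `SUPPORTED_DATA_FORMATS` are hypothetical names, not GrouPy code. Note that the exception class to raise is `NotImplementedError`: `NotImplemented` is a special constant returned by binary operator methods, and calling it raises a `TypeError` instead.

```python
SUPPORTED_DATA_FORMATS = ("NHWC", "NCHW")


def channel_axis(data_format):
    """Hypothetical relaxed check: accept NHWC and NCHW and return the index
    of the channel axis; raise NotImplementedError for anything else."""
    if data_format not in SUPPORTED_DATA_FORMATS:
        raise NotImplementedError(
            "Only NHWC and NCHW data_format are supported. Got: " + str(data_format))
    return data_format.index("C")


print(channel_axis("NHWC"))  # 3
print(channel_axis("NCHW"))  # 1
```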

But I don't remember why we can't support NCHW. The filter transformation should not be affected by data_format, because the filter has the same shape under both NHWC and NCHW. tf.nn.conv2d should also support NCHW, although perhaps it did not in earlier versions.
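The point that the filter shape does not depend on data_format can be made concrete by computing conv2d output shapes. The sketch below (hypothetical helper `conv2d_output_shape`, plain Python, no dilation) follows `tf.nn.conv2d`'s conventions: the filter is always `[filter_height, filter_width, in_channels, out_channels]`, while the input, the `strides` list, and the output follow data_format. For a GrouPy layer, the filter here is assumed to be the already-transformed filter, whose channel dimensions include the group axis.

```python
import math


def conv2d_output_shape(input_shape, filter_shape, strides, padding, data_format="NHWC"):
    """Output shape of a tf.nn.conv2d-style convolution (sketch, no dilation).

    filter_shape is [filter_height, filter_width, in_channels, out_channels]
    for both layouts; input_shape, strides and the result follow data_format.
    """
    if data_format not in ("NHWC", "NCHW"):
        raise ValueError("unsupported data_format: " + str(data_format))
    kh, kw, c_in, c_out = filter_shape
    dims = dict(zip(data_format, input_shape))
    steps = dict(zip(data_format, strides))
    if dims["C"] != c_in:
        raise ValueError("input has %d channels, filter expects %d" % (dims["C"], c_in))
    if padding == "SAME":
        out_h = math.ceil(dims["H"] / steps["H"])
        out_w = math.ceil(dims["W"] / steps["W"])
    elif padding == "VALID":
        out_h = math.ceil((dims["H"] - kh + 1) / steps["H"])
        out_w = math.ceil((dims["W"] - kw + 1) / steps["W"])
    else:
        raise ValueError("padding must be 'SAME' or 'VALID'")
    out = {"N": dims["N"], "H": out_h, "W": out_w, "C": c_out}
    # The output is laid out in the same format as the input.
    return tuple(out[axis] for axis in data_format)


filt = (3, 3, 3, 16)  # the same filter shape for both layouts
print(conv2d_output_shape((8, 32, 32, 3), filt, [1, 1, 1, 1], "VALID", "NHWC"))  # (8, 30, 30, 16)
print(conv2d_output_shape((8, 3, 32, 32), filt, [1, 1, 1, 1], "VALID", "NCHW"))  # (8, 16, 30, 30)
```

In both calls the filter is identical and only the positions of the input and output axes change, which is why the filter transformation itself should be layout-agnostic.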

Could you try removing the check and running with data_format='NCHW'? To make sure nothing silently broke, it is probably a good idea to test the equivariance of the layer using something like the code in check_gconv2d.py. You can also do this for your whole network.

@yangyu12
Author

Thanks for your reply!
I still have some confusion.
I tried using this function directly to construct a network. If I use it as-is to build each layer, my guess is that the actual data formats of the layer outputs are:
NHWC (input) -> NCHW -> NHWC -> NCHW -> ...
I'm not sure whether that is right.

Do you mean that the usage above actually works? Maybe I need to reread your paper to get more familiar with the underlying principles :)
