Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Created TensorFormat enum #191

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package org.tensorflow.ndarray.buffer.layout;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @sevarac , as we discussed briefly during our last session, this enum should probably be moved at the tensorflow-core-api or tensorflow-framework level. It would be helpful to know which one if you can provide some quick examples of usage. I think @JimClarke5 had some in mind too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, in general everywhere where public Options dataFormat(String dataFormat) is used
there should be now public Options dataFormat(TensorFormat dataFormat)
which includes a bunch of classes mainly layers https://github.com/tensorflow/java/search?q=dataFormat

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is used in losses.Losses for losses.CategoricalCrossentropy and metrics.CategoricalCrossentropy.

  public static final int CHANNELS_LAST = -1;
  public static final int CHANNELS_FIRST = 1;

Once this PR is merged, I'll change the logic in losses and metrics.

Copy link
Collaborator

@karllessard karllessard Feb 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be done in the C++ op generator, by looking at any argument called dataFormat. Though I don't think this naming convention is enforced to the kernel developers, which might lead to mistakes. But there will be possible workarounds so I'm fine if you want to give a try making that change as well

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @sevarac , so any chance that you can move this enum to a different location, as I've suggested before?

If I need to pick one, I'll suggesttensorflow-framework over tensorflow-core-api, what about under org.tensorflow.framework.utils?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also please don't forget to add the header notice in your file (like this one, for example).


/**
* Specifies the data format in tensor.
*
* With the default format "NHWC", the data is stored in the order of:
* [batch, height, width, channels].
*
* Alternatively, the format could be "NCHW", the data storage order of:
* [batch, channels, height, width].
*
* Additional format NCHW_VECT_C is specified by
* https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_
* although not sure if it is used or needed in tf
*
* Even More formats are specified in https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#WXWZ-tensor-descriptor
* CHWN 4d tensor description
* NCDHW 5d tensor description
* NDHWC
* CDHWN
*
* https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t
*
// https://github.com/tensorflow/java/blob/f85623ed366d903cfddb75177725dc276f359b15/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGrad.java

*/
public enum TensorFormat {
NCHW,
NHWC;
}