-
Notifications
You must be signed in to change notification settings - Fork 202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Created TensorFormat enum #191
Open
sevarac
wants to merge
1
commit into
tensorflow:master
Choose a base branch
from
sevarac:master
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
30 changes: 30 additions & 0 deletions
30
ndarray/src/main/java/org/tensorflow/ndarray/buffer/layout/TensorFormat.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package org.tensorflow.ndarray.buffer.layout; | ||
|
||
/** | ||
* Specifies the data format in tensor. | ||
* | ||
* With the default format "NHWC", the data is stored in the order of: | ||
* [batch, height, width, channels]. | ||
* | ||
* Alternatively, the format could be "NCHW", the data storage order of: | ||
* [batch, channels, height, width]. | ||
* | ||
* Additional format NCHW_VECT_C is specified by | ||
* https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_ | ||
* although not sure if it is used or needed in tf | ||
* | ||
* Even More formats are specified in https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#WXWZ-tensor-descriptor | ||
* CHWN 4d tensor description | ||
* NCDHW 5d tensor description | ||
* NDHWC | ||
* CDHWN | ||
* | ||
* https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t | ||
* | ||
// https://github.com/tensorflow/java/blob/f85623ed366d903cfddb75177725dc276f359b15/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/MaxPoolGradGrad.java | ||
|
||
*/ | ||
public enum TensorFormat { | ||
NCHW, | ||
NHWC; | ||
} |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @sevarac , as we discussed briefly during our last session, this enum should probably be moved at the
tensorflow-core-api
ortensorflow-framework
level. It would be helpful to know which one if you can provide some quick examples of usage. I think @JimClarke5 had some in mind too.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, in general everywhere where public Options dataFormat(String dataFormat) is used
there should be now public Options dataFormat(TensorFormat dataFormat)
which includes a bunch of classes mainly layers https://github.com/tensorflow/java/search?q=dataFormat
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is used in
losses.Losses
forlosses.CategoricalCrossentropy
andmetrics.CategoricalCrossentropy
.Once this PR is merged, I'll change the logic in losses and metrics.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be done in the C++ op generator, by looking at any argument called
dataFormat
. Though I don't think this naming convention is enforced to the kernel developers, which might lead to mistakes. But there will be possible workarounds so I'm fine if you want to give a try making that change as wellThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @sevarac , so any chance that you can move this enum to a different location, as I've suggested before?
If I need to pick one, I'll suggest
tensorflow-framework
overtensorflow-core-api
, what about underorg.tensorflow.framework.utils
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also please don't forget to add the header notice in your file (like this one, for example).