Add Losses #129
Conversation
Sync with master tensorflow on upstream
* @param tf The TensorFlow Ops
* @param labels the labels
* @param predictions the predictions
* @param <T> the data type of the result

perhaps "of the predictions and result"?

OK
*
* @param <T> the data type of the Tuple entries.
*/
public class Tuple<T extends TNumber> {

The Tuple class name is uncomfortably vanilla for me. Perhaps LossTuple?

This object will also be used in Metrics as many metrics are built using loss classes or Losses methods. I have changed it to LossTuple.
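For readers skimming this thread, a rough sketch of the kind of container being discussed follows; the field and accessor names are illustrative guesses, not the PR's actual LossTuple code.

import org.tensorflow.Operand;
import org.tensorflow.types.family.TNumber;

/** Illustrative sketch only: a simple holder for the adjusted labels, predictions and weights. */
public class LossTuple<T extends TNumber> {
  private final Operand<T> labels;
  private final Operand<T> target;        // the (possibly squeezed) predictions
  private final Operand<T> sampleWeights; // may be null

  public LossTuple(Operand<T> labels, Operand<T> target) {
    this(labels, target, null);
  }

  public LossTuple(Operand<T> labels, Operand<T> target, Operand<T> sampleWeights) {
    this.labels = labels;
    this.target = target;
    this.sampleWeights = sampleWeights;
  }

  public Operand<T> getLabels() { return labels; }
  public Operand<T> getTarget() { return target; }
  public Operand<T> getSampleWeights() { return sampleWeights; }
}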
public class LossesImpl {

  /**

The Javadocs in this file are still partly in markdown.

OK, I thought I caught them all, I will fix.
* @param tf the TensorFlow Ops
* @param predictions Predicted values, a <code>Operand</code> of arbitrary dimensions.
* @param labels Optional label <code>Operand</code> whose dimensions match <code>prediction</code>.
* @return Tuple of <code>prediction</code>, <code>label</code> and <code>sampleWeight</code>. Each of them possibly has the last

For this method, the returned sampleWeight is always null.

That is not always the case when we do Metrics.

I'm just thinking our documentation for this method might take into account that the returned sampleWeight is always null.

Now I see what you are talking about. I added a comment in the @return that sampleWeight will be null for this particular method signature.
*
* @param tf the TensorFlow Ops
* @param predictions Predicted values, a <code>Operand</code> of arbitrary dimensions.
* @param labels Optional label <code>Operand</code> whose dimensions match <code>prediction</code>.

Is "match" the right way to describe the precondition relationship between predictions and labels?

It is definitely not the same Shape. I was thinking of "compatible", but that has a specific meaning in Shape.isCompatibleWith. The description is saying the ranks must be equal or differ by one. I am not sure of one word that describes that. "match" was the word used in the Python version of this method.
Hmm, here's a suggestion:

- We could decide what we want the convention to be, in terms of squeeze-or-expand plus maybe broadcasting.
- Write this up carefully in the class javadoc for either Loss or Losses.
- Mention that documentation in the class javadoc for every other loss class.
- Also mention it in Loss#call.
- And be silent about it in the individual methods of Losses and LossesImpl.

Perhaps?
That said, it just occurred to me that we have another gap, and that filling that gap might help this issue.

We don't specify the behavior of these methods when labels and predictions don't have a permitted shape relationship. Nor do we make sure our behavior is consistent in that case.

Perhaps we should

- spell out that there's an IllegalArgumentException for that in the statically-known-dimensions case,
- rename squeezeOrExpandDimensions into something like validateAndAdjustLossDimensions,
- have that method throw IllegalArgumentException when appropriate,
- and then link to a fuller explanation in the documentation of the IllegalArgumentException?

Although I have never been in the habit of subclassing IllegalArgumentException, I see Oracle does that sometimes. That could be an alternative way of pointing people to the fuller explanation.
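A minimal sketch of the statically-known-dimensions check proposed above; the helper name and message are hypothetical, not code from this PR.

import org.tensorflow.ndarray.Shape;

// Hypothetical: throw IllegalArgumentException when both ranks are statically known
// and they are neither equal nor off by one.
static void validateStaticLossShapes(Shape labelsShape, Shape predictionsShape) {
  int labelsRank = labelsShape.numDimensions();
  int predictionsRank = predictionsShape.numDimensions();
  if (labelsRank == Shape.UNKNOWN_SIZE || predictionsRank == Shape.UNKNOWN_SIZE) {
    return; // rank not statically known; defer to a runtime (graph) check
  }
  if (Math.abs(predictionsRank - labelsRank) > 1) {
    throw new IllegalArgumentException(
        "labels rank " + labelsRank + " and predictions rank " + predictionsRank
            + " must be equal or differ by one");
  }
}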
match must mean that the shapes of the input operands are capable of being molded into the relationships defined for the result of this method. Again, LossesImpl is intended to be marked as module private (JDK 11) and should only be accessible from the losses or metrics package. It is not intended to be a general-use API.

We should probably note in the javadoc for the class that this is an internal implementation class and subject to change (and being locked off under the module system).

Added this comment for the LossesImpl class:

/**
 * These are helper methods for Losses and will be module private when
 * Java modularity is applied to TensorFlow Java.
 * These methods should not be used outside of the Loss package.
 */
if (labels != null) {
  Shape labelsShape = labels.asOutput().shape();
  long labelRank = labelsShape.numDimensions();

For consistency, labelsRank.

OK
/**
 * Creates a Binary Crossentropy Loss using {@link Class#getSimpleName()} as the loss name, {@link
 * #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for labelSmoothing and a
 * Loss Reduction of {@link * Reduction#AUTO}

Extraneous *

Deleted
/**
 * Creates a categorical cross entropy Loss using {@link Class#getSimpleName()} as the loss name,
 * {@link #FROM_LOGITS_DEFAULT} for fromLogits, {@link #LABEL_SMOOTHING_DEFAULT} for
 * labelSmoothing, a Loss Reduction of {@link * Reduction#AUTO}, and an axis of {@link

Extraneous *

Removed all extraneous @link *
* Creates a Loss using a Loss Reduction of {@link Reduction#AUTO}
*
* @param tf the TensorFlow Ops
* @param name the name of this Loss

. . . , or null to use {@link Class#getSimpleName()}

Why would someone want to pass null, when there are other CTORs that handle that condition?

For APIs that will get enough use to be worth some polish, I tend toward carefully documenting edge cases. I don't know whether we want to invest in that now.

I think it's worth documenting it in case users build their own losses.

OK, added this to the name param: if null the name will be {@link Class#getSimpleName()}.
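For concreteness, the documented behavior might look like the following in a constructor (a sketch; the actual Loss constructor body isn't shown in this thread):

// Sketch: default the loss name to the class's simple name when name is null.
protected Loss(Ops tf, String name, Reduction reduction) {
  this.tf = tf;
  this.name = name != null ? name : getClass().getSimpleName();
  this.reduction = reduction;
}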
* Creates a Loss
*
* @param tf the TensorFlow Ops
* @param name the name of this loss

. . . , or null to use {@link Class#getSimpleName()}

OK, added this to all name params: if null the name will be {@link Class#getSimpleName()}.
*
* @param labels the truth values or labels
* @param predictions the predictions
* @param <T> The data type of the labels, predictions and loss.

Actually, there's a separate <U> for the labels.

Fixed
// Compute cross entropy from probabilities.
Operand<T> cce =
    tf.reduceSum(
        tf.math.mul(tLabels, tf.math.log(predictions)),

Although in this internal case of this method, we do broadcast. I'll stop commenting on this issue.

We can return to the "squeezeOrExpandDimensions followed by broadcasting" topic when I work on #130.

Resolved.
*/
public static <T extends TNumber, U extends TNumber> Operand<T> meanAbsoluteError(
    Ops tf, Operand<U> labels, Operand<T> predictions) {
  Operand<T> tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType());

Do we want to avoid this cast in the case where labels already has the same data type?

I guess the question is what is the overhead of casting onto oneself vs the overhead of checking? I would hope that tf.dtypes.cast already handles this, but I could be mistaken.

The code for checking could be something like this:

@SuppressWarnings("unchecked")
private static <T extends TNumber, U extends TNumber> Operand<T> castIfNecessary(
    Operand<U> value, DataType<T> requiredType) {
  return (value.asOutput().dataType() == requiredType)
      ? (Operand<T>) value
      : tf.dtypes.cast(value, requiredType);
}

So the overhead of checking would be the function call plus value.asOutput().dataType() == requiredType.

Looking at the code for tf.dtypes.cast, unless we think a cast is almost always needed, it would be cheaper to do the check to sometimes avoid it.
public <U extends TType, T extends TType> Cast<U> cast(Operand<T> x, DataType<U> DstT,
Cast.Options... options) {
return Cast.create(scope, x, DstT, options);
}
@Endpoint(describeByClass = true)
public static <U extends TType, T extends TType> Cast<U> create(Scope scope, Operand<T> x, DataType<U> DstT, Options... options) {
OperationBuilder opBuilder = scope.env().opBuilder("Cast", scope.makeOpName("Cast"));
opBuilder.addInput(x.asOutput());
opBuilder = scope.applyControlDependencies(opBuilder);
opBuilder.setAttr("DstT", DstT);
if (options != null) {
for (Options opts : options) {
if (opts.Truncate != null) {
opBuilder.setAttr("Truncate", opts.Truncate);
}
}
}
return new Cast<U>(opBuilder.build());
}
In graph construction mode the overhead is probably irrelevant because it's only called once during construction. In eager mode it could be faster as it could sidestep a JNI call in each step, but I suspect we've got other issues to get speed in eager mode.

I like castIfNecessary as a general util method. It would be used almost everywhere, so it would be a huge change. Perhaps create a new PR for castIfNecessary, then once that is merged we can start retrofitting all packages under framework.
In graph construction mode, an unnecessary call to cast creates an unnecessary graph operation.

Shrug, it'll be a no-op most of the time and compiled away if we get XLA working. Given the relative size of the computation around it I suspect it won't be an issue.

I also vote for an explicit check in the code to avoid adding an extra operation to the graph when it is not required.

OK, I will add a helper class in org.tensorflow.framework.utils, then retrofit the Loss classes.
Just a comment on @deansher's proposed method here: the data types for <U> and <T> should not be restricted to TNumber, because it is valid to cast to/from TNumber and TBool.
public static <T extends TNumber, U extends TNumber> Operand<T> meanAbsolutePercentageError(
    Ops tf, Operand<U> labels, Operand<T> predictions) {
  DataType<T> dataType = predictions.asOutput().dataType();
  Operand<T> tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType());

Can just use dataType.

Fixed
* @param <T> the data type of the Operands
* @return the binary crossentropy loss.
*/
private static <T extends TNumber> Operand<T> binaryCrossentropy(

I tripped over this private method having the usual naming of a loss method, since I didn't notice that it was private and so expected it to follow the conventions of public loss methods, such as invoking squeezeOrExpandDimensions. Also (if I'm navigating accurately through unfamiliar territory), this method doesn't compute a binaryCrossentropy since it depends on its caller to compute the mean at the end.

This method does the grunt work for the binaryCrossentropy after the operands have had their shapes and types manipulated and after smoothing the labels. Perhaps a new name would remove some of the confusion.

Yes, I wonder if we want to call it something like binaryCrossentropyHelper?

OK, Changed
import org.tensorflow.types.family.TNumber;

/**
 * Computes the categorical hinge loss between labels and predictions.

Should we follow the Python in documenting that labels are expected to be 0 or 1?

Yes. The Python CategoricalHinge class does not mention that at all, but it is mentioned in the categorical_hinge method. I have added an entry to the class JavaDoc and to the Losses.categoricalHinge method.
Actually, the values can be [-1, 0, 1]. [0, 1] is converted to [-1, 1]. I have added a value check to make sure the values are wholly contained in the allowed value set [-1, 0, 1]. This will throw TFInvalidArgumentException if run in Graph mode, via a control dependency, or IllegalArgumentException if created in Eager mode with the call method.
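A rough graph-mode sketch of the kind of value check described above; everything here (the predicate construction, the use of tf.assertThat and tf.withControlDependencies) is an assumption about shape, not the PR's actual rangeCheck/valueCheck code, and the eager-mode path would instead fetch the boolean result and throw IllegalArgumentException directly.

// Sketch: assert that every label value is -1, 0 or 1 before computing the loss.
// (Assumes tf and Operand<T> labels are in scope; imports omitted.)
DataType<T> dt = labels.asOutput().dataType();
Operand<TBool> allowed =
    tf.reduceAll(
        tf.math.logicalOr(
            tf.math.equal(labels, tf.dtypes.cast(tf.constant(-1), dt)),
            tf.math.logicalOr(
                tf.math.equal(labels, tf.dtypes.cast(tf.constant(0), dt)),
                tf.math.equal(labels, tf.dtypes.cast(tf.constant(1), dt)))),
        tf.range(tf.constant(0), tf.rank(labels), tf.constant(1)));
Op assertion =
    tf.assertThat(allowed, Arrays.<Operand<?>>asList(tf.constant("labels must contain only -1, 0 or 1")));
// Downstream ops created from this Ops instance now depend on the assertion.
Ops tfc = tf.withControlDependencies(Collections.singletonList(assertion));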
Cool -- Resolved.

What does it do if all three of [-1, 0, 1] are present? That's probably an invalid input; does it throw?
public static <T extends TNumber, U extends TNumber> Operand<T> meanSquaredLogarithmicError(
    Ops tf, Operand<U> labels, Operand<T> predictions) {
  DataType<T> dataType = predictions.asOutput().dataType();
  Operand<T> tLabels = tf.dtypes.cast(labels, predictions.asOutput().dataType());

Could just use dataType. I'll stop mentioning this.

Fixed, hopefully I have fixed them all.
}

/**
 * Calculates the mean squared logarithmic percentage error between labels and predictions.

I think "percentage" is extraneous here.

Fixed
// Use dynamic rank.

// TODO Operand<TInt32> rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels));
if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) {

If the rank is unknown, then the size of the last dimension is guaranteed to be unknown, so isCompatible is guaranteed true. (But there may be some idiomatic reason for writing it this way, of which I am blissfully unaware.)

Correct, it should have been or, not and.
if (labels != null) {
  Shape labelsShape = labels.asOutput().shape();
  long labelRank = labelsShape.numDimensions();
  if (labelRank != Shape.UNKNOWN_SIZE && predictionsRank != Shape.UNKNOWN_SIZE) {

I'm pretty sure this logic is wrong. Perhaps either

- document preconditions of removeSqueezableDimensions and check exactly those,
- or (my leaning) just invoke removeSqueezableDimensions and make it however smart it needs to be.

This logic is checking to see if both operands' ranks are known (not Shape.unknown()). If both ranks are known, then it checks to see if the shapes are already in the right relationship or not. If not in the right relationship, then call removeSqueezableDimensions. It is basically an optimization to avoid doing the work in removeSqueezableDimensions if it does not need to be done.
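To restate that fast path in one place, a sketch using the variable names from the quoted snippet (not a verbatim extract from the PR):

// Only do the squeeze work when it might actually be needed.
if (labelRank != Shape.UNKNOWN_SIZE && predictionsRank != Shape.UNKNOWN_SIZE) {
  // Both ranks statically known: squeeze only if the off-by-one relationship
  // does not already hold, or the trailing dimension is 1.
  if (predictionsRank - labelRank != 1 || predictionsShape.size(-1) == 1) {
    tuple = removeSqueezableDimensions(tf, labels, predictions);
  }
} else {
  // At least one rank is only known at run time: let removeSqueezableDimensions decide.
  tuple = removeSqueezableDimensions(tf, labels, predictions);
}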
*/
private static <T extends TNumber> int[] allAxis(Operand<T> op) {
  int rank = op.asOutput().shape().numDimensions();
  int[] axes = new int[rank];

rank could be -1 at this point.
Fixed
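The fix itself isn't shown in the thread; one possible shape for it, assuming a fallback to a dynamically computed axis range, is sketched below (names and the exact ops used are guesses).

// Sketch: use the static rank when known, otherwise build the axis list dynamically.
// (Assumes tf and op are in scope; IntStream from java.util.stream.)
int rank = op.asOutput().shape().numDimensions();
Operand<TInt32> axes =
    rank == Shape.UNKNOWN_SIZE
        ? tf.range(tf.constant(0), tf.rank(op), tf.constant(1)) // dynamic rank
        : tf.constant(IntStream.range(0, rank).toArray());      // static rank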
* @param <T> the type of Operand
* @return the integer array representing all the axes of the operand.
*/
private static <T extends TNumber> int[] allAxis(Operand<T> op) {

allAxes?

Changed name to allAxes
}
Shape weightsShape = sampleWeight.asOutput().shape();
long weightsRank = weightsShape.numDimensions();
if (weightsRank == 0) { // scalar

What should happen if weightsRank is UNKNOWN?

It falls through and executes the last part of the method after the // Use dynamic rank. comment.

:-) Oh yeah, that.
if (predictionsRank != Shape.UNKNOWN_SIZE && weightsRank != Shape.UNKNOWN_SIZE) {

  if (weightsRank - predictionsRank == 1) {

Here we are working with the original predictionsRank, when we wanted to be working with the new rank.

This matches the original Python code, but when you think about it, the predictions rank would never change from UNKNOWN to KNOWN and vice versa in a static context.

I was thinking perhaps predictions changed rank through our squeezing it to match labels earlier in the method. But I think there's a more pernicious problem. Here's an elided version of some of the method's code. Notice that we may squeeze predictions, but we only store the result in tuple. If we then also work with sampleWeight, we neither reference the squeezed version of predictions nor return it.
if (labels != null) {
. . .
if (predictionsRank - labelRank != 1 || predictionsShape.size(-1) == 1) {
tuple = removeSqueezableDimensions(tf, labels, predictions);
}
} else { // use dynamic rank
tuple = removeSqueezableDimensions(tf, labels, predictions);
}
}
. . .
if (predictionsRank != Shape.UNKNOWN_SIZE && weightsRank != Shape.UNKNOWN_SIZE) {
if (weightsRank - predictionsRank == 1) {
sampleWeight = tf.squeeze(sampleWeight);
. . .
}
return new Tuple<>(labels, predictions, sampleWeight);
}
Yes, we should probably fetch the labels and predictions from tuple first. I'll fix it.
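A small sketch of the fix being agreed on here (accessor names are assumed, as in the LossTuple sketch earlier in this thread):

// After any squeezing, continue with the operands held by the tuple rather than
// the original references, so the sampleWeight handling sees the adjusted shapes.
labels = tuple.getLabels();
predictions = tuple.getTarget();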
* Each of them possibly has the last dimension squeezed, <code>sampleWeight</code> could be
* extended by one dimension. If <code>sampleWeight</code> is null, (prediction, label) is
* returned.
*/

This method has a myriad of complex cases, so I think it deserves its own direct unit test.

I could not find direct test cases for this method in Python. It's defined in tensorflow/tensorflow/python/ops/losses/utils.py. You want to take a stab at it?

:-) Totally. I want to do some work on #92 first, so I'll open an issue for myself.
…ictions and weights are returned in LossTuple
Hi @JimClarke5, what would be the best order for reviewing your PRs? You have this one, #123 and #106 that are still open.

I would do activations #123 first.

@karllessard I have started working on Metrics, which depends on Loss and thus this PR. I plan to do Metrics in two PRs; the first PR will focus on Metrics that depend on
In addition to the specific comments, I think it might be a good idea to add checks that the values are in the expected range (i.e. if it's expecting probabilities then it should check that they are in the range 0-1). Otherwise it's a right pain to track that down. Not sure if it will add too much overhead, but the loss computation tends to be much cheaper than the forward or backward passes, so hopefully it'll be fine.
*
* @param tf the TensorFlow Ops
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, we compute the

Does labelSmoothing = 1.0 mean the true label distribution is set to 1/n? I'm not sure what "squeezing the values towards 0.5" means, because it would only be 0.5 in a binary problem.

Actually this is the comment for BinaryCrossentropy. It should be:

Float in <code>[0, 1]</code>. When <code>> 0</code>, label values are smoothed, meaning the confidence on label values is relaxed. e.g. <code>label_smoothing=0.2</code> means that we will use a value of <code>0.1</code> for label <code>0</code> and <code>0.9</code> for label <code>1</code>.

I'll fix it.
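For reference, the binary smoothing described by that corrected text reduces to a one-line transform. This is only a sketch (a rename to smoothBinaryLabels is suggested later in this thread, but the actual implementation isn't shown; assumes tf, labels and a float labelSmoothing in scope):

// smoothed = labels * (1 - labelSmoothing) + 0.5 * labelSmoothing
// e.g. labelSmoothing = 0.2 maps label 0 -> 0.1 and label 1 -> 0.9.
DataType<T> dt = labels.asOutput().dataType();
Operand<T> smoothing = tf.dtypes.cast(tf.constant(labelSmoothing), dt);
Operand<T> oneMinus = tf.math.sub(tf.dtypes.cast(tf.constant(1.0f), dt), smoothing);
Operand<T> half = tf.dtypes.cast(tf.constant(0.5f), dt);
Operand<T> smoothed = tf.math.add(tf.math.mul(labels, oneMinus), tf.math.mul(half, smoothing));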
@Craigacp we could set a range check, but it would have to be a control dependency, e.g. using …

I tested the range check, but found out that …

I have added 2 methods to … First, control dependencies do not work in Eager mode. To handle this, I throw an … One question: should these utilities be stored in a common location like …
Add in rangeCheck and valueCheck

Misc fixes based on review
* These are helper methods for Losses and will be module private when Java modularity is applied to
* TensorFlow Java. These methods should not be used outside of the Loss package.
*/
public class LossesImpl {

For me, a *Impl should be the implementation of an interface; this one looks more like a LossesHelper with all its static methods (and the class should probably be final). I did not go through the whole thing, but it looks like these helpers could also be moved directly to Loss as protected methods?
The split really comes from module visibility. Losses should be publicly accessible, while LossesImpl should be module private. Some LossesImpl methods may be used by metrics. Whether we call it LossesImpl or LossesHelper is a matter of preference. The current methods in LossesImpl should not be restricted to Loss classes, as metrics classes may also make use of them, therefore protected is not the right semantic.
It feels uncomfortable to me that we plan to use the LossesImpl methods from other parts of our framework while restricting them from public use. When a system's built-ins rely on privileged capabilities that aren't available to 3rd-party code, I think it is commonly a big problem for the system's extensibility. In this case, I do see room to argue that these methods aren't "capabilities", but are just "implementation" which can safely be hidden. But given that it is important to us to reuse them for our own metrics, I lean toward thinking of them as capabilities that we should expose.
There is a tight symmetry between Losses and Metrics, as many (but not all) metrics rely on the methods in Losses. I don't think other packages will have this close of a relationship.
Is there a potential use case justifying exposing these to the public? Seeing as they are utilities needed to implement Losses/Metrics.
Agree with a rename to LossesHelper or LossesUtility to differentiate from interface implementation, however.
* predictions is scaled by the corresponding value of sample_weight. (Note on dN-1: all loss
* functions reduce by 1 dimension, usually axis=-1.)
* @param <T> The data type of the predictions, sampleWeights and loss.
* @param <U> The data type of the labels.

Personally, I'd lean toward using some of our own single-letter conventions for situations that are common in our own code, including L as the labels type.

This may be hard to follow consistently once several letters have been used, e.g. 'L' might be needed for something other than label type. Seems a tad more confusing than the standard type names.
import org.tensorflow.types.family.TNumber;

import java.util.Collections;

Class should have Javadoc description, no?

I have changed the class name to LossesHelper. I don't understand your comment on class JavaDoc. This is what I have in my copy:

/**
 * These are helper methods for Losses and Metrics and will be module private when Java modularity is applied to
 * TensorFlow Java. These methods should not be used outside of the losses and metrics packages.
 */

The basic comment was put in a while ago, and I just updated it to mention metrics.

Rendering issue, I think. Looks good, thanks.
// * @param <T> The data type of the predictions, sampleWeights and loss.
// * @param <U> The data type of the labels.
// * @return the loss
// *

Can we remove this commented-out documentation?

OK
That's just a transient network error.

@saudet, @JimClarke5, it might also be related that we are having trouble these days building our artifacts, as you can see here: the Linux GPU build runs out of space and that prevents the last step from occurring (i.e. the bulk deploy that normalizes all snapshot timestamps), which might explain why a few artifacts disappeared lately: #142. I've retried many times but without success. Samuel, any idea how to solve this again?

It just looks like GitHub Actions is below its guaranteed 14 GB of disk space. We'll have to wait until they fix that, again, I guess?
* @return the value cast to the required data type.
*/
@SuppressWarnings("unchecked")
public static <T extends TType, U extends TType> Operand<T> cast(

We should open an issue to track inserting these cast checks into the optimizers for uniformity.

I could do it in the #106 Learning Rate PR if that works.

No, let's not hold anything up for it, it's just something to clean up later.

A few more documentation things.
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing A number in the range, [0, 1]. When 0, no smoothing occurs. When > 0,
* compute the loss between the predicted labels and a smoothed version of the true labels,
* where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing

label_smoothing -> labelSmoothing, here and elsewhere in this file.

OK
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, we compute the
* loss between the predicted labels and a smoothed version of the true labels, where the
* smoothing squeezes the labels towards 0.5. Larger values of label_smoothing correspond to

This one's still got the doc from BinaryCrossEntropy wrt label_smoothing. And it's snake_case.

OK
* @param tf the TensorFlow Ops
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, we compute the
* loss between the predicted labels and a smoothed version of the true labels, where the

Incorrect doc.

OK
* @param fromLogits Whether to interpret predictions as a tensor of logit values
* @param labelSmoothing Float in [0, 1]. When 0, no smoothing occurs. When > 0, we compute the
* loss between the predicted labels and a smoothed version of the true labels, where the
* smoothing squeezes the labels towards 0.5. Larger values of label_smoothing correspond to

Incorrect doc.

OK
*
* <p>Note that it is a number between -1 and 1. When it is a negative number between -1 and 0, 0
* indicates orthogonality and values closer to -1 indicate greater similarity. The values closer
* to 1 indicate greater dissimilarity. This makes it usable as a loss function in a setting where
This javadoc is better, but I think it should mention that this function is inverted from the regular cosine similarity, as that's 1 when the values are most similar and -1 when they point in opposite directions. It makes sense that it is inverted because then you can minimise it sensibly, but it is confusing if you're just browsing through.
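To make the inversion concrete, a sketch of the relationship (l2Normalize is an assumed helper, the axis handling is simplified, and tf, labels, predictions and an int axis are assumed in scope):

// Standard cosine similarity is +1 for identical directions; the loss negates it so
// that minimizing the loss maximizes similarity.
Operand<T> labelsNorm = l2Normalize(tf, labels, axis);           // assumed helper
Operand<T> predictionsNorm = l2Normalize(tf, predictions, axis); // assumed helper
Operand<T> cosine = tf.reduceSum(tf.math.mul(labelsNorm, predictionsNorm), tf.constant(axis));
Operand<T> loss = tf.math.neg(cosine); // -1 = most similar, 0 = orthogonal, +1 = opposite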
* @param <T> the data type of the labels
* @return the smoothed binary labels
*/
private static <T extends TNumber> Operand<T> smoothLabelsBinaryX(

I think this would be better called smoothBinaryLabels as it's not specific to the binary cross entropy as far as I can tell. But it's a private method so it's not too much of an issue.

OK
* @param <T> the data type of the labels
* @return the smoothed categorical labels
*/
private static <T extends TNumber> Operand<T> smoothLabelsCatX(

Similar comment to above, but smoothCategoricalLabels. Also I think the doc should explicitly state that it's smoothing the labels towards 1/n where n is the number of classes.
OK
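A sketch of the categorical smoothing being described, assuming the class count is the statically known size of the labels' last axis (names are illustrative, not the PR's code; tf, labels and a float labelSmoothing assumed in scope):

// smoothed = labels * (1 - labelSmoothing) + labelSmoothing / numClasses,
// i.e. labelSmoothing = 1.0 pushes the label distribution all the way to 1/numClasses.
DataType<T> dt = labels.asOutput().dataType();
long numClasses = labels.asOutput().shape().size(-1);
Operand<T> smoothing = tf.dtypes.cast(tf.constant(labelSmoothing), dt);
Operand<T> oneMinus = tf.math.sub(tf.dtypes.cast(tf.constant(1.0f), dt), smoothing);
Operand<T> uniform = tf.math.div(smoothing, tf.dtypes.cast(tf.constant((float) numClasses), dt));
Operand<T> smoothed = tf.math.add(tf.math.mul(labels, oneMinus), uniform);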
@Craigacp I have modified the JavaDoc for

smoothLabelsCatX to smoothCategoricalLabels. Added clarification in JavaDoc for cosineSimilarity to describe the difference between the mathematical definition for cosine similarity and the loss definition.
7eefbb7
fix typo error in JavaDoc comment
@Craigacp, it looks like @JimClarke5 pushed all commits for the last changes you've requested; I'll let you validate and dismiss your review if that's the case, thanks. @JimClarke5, some of your unit test files don't have a valid source header, can you please add one? Also, I've noticed that now you are authoring your work to Oracle; that is ok but I just want to validate with you if it was intentional, thanks.

I have fixed the copyright issues. The attribution to Oracle was a mistake, as I copied the copyright from Optimizers to the Loss classes. I have replaced them with the

@JimClarke5 yes, I'm required to put the Oracle copyright header on substantive external open source contributions that are part of my job. There's an internal review process for all the things that I write that are longer than a line or two. @karllessard I've resolved my two comments, I think this is good to be merged now. I can approve it if you want.

All right, this one is merged now, thanks for the great contribution again @JimClarke5!
::party::
This PR adds losses to framework.

All the loss sub-classes inherit from Loss. The Losses class has methods that can be called directly to get raw loss values. These are utilized by the Loss subclasses before applying a Reduction to the loss. The Losses class will also be used by some of the Metric classes when that feature is submitted.

The impl package has some helper methods and classes utilized by the loss classes, and is not expected to be exposed outside the framework module, when we do modules.
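To show how the pieces fit together, here is a hypothetical usage sketch; the class names follow the PR description, but the exact constructor and call signatures are assumptions, and imports and session execution are elided.

try (Graph graph = new Graph()) {
  Ops tf = Ops.create(graph);
  Operand<TFloat32> labels = tf.constant(new float[][] {{0f, 1f}, {1f, 0f}});
  Operand<TFloat32> predictions = tf.constant(new float[][] {{0.1f, 0.9f}, {0.8f, 0.2f}});

  // A Loss subclass applies a Reduction to the raw per-example losses.
  MeanSquaredError mse = new MeanSquaredError(tf);
  Operand<TFloat32> reduced = mse.call(labels, predictions);

  // The static methods in Losses return the raw, unreduced values.
  Operand<TFloat32> raw = Losses.meanSquaredError(tf, labels, predictions);
}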