Fixes lower train metrics when using Keras Masking (SequenceMaskRandom, SequenceMaskLast) #983
Conversation
The decorator was added to fix that dataloader issue. There was an incompatibility with list columns in the dataloader, and adding the decorator fixed it. Do we still have issues with metrics if we use both?
@gabrielspmoreira I tested the PR and now I am getting more consistent results between model.fit() and model.evaluate().
Hi Edward. I remember you added a @tf.function decorator to deal with list features.
Based on the discussion offline, it sounds like the CI failure is unrelated to tf.function. Please ignore my previous comment.
Fixes lower train metrics when using Keras Masking (SequenceMaskRandom, SequenceMaskLast) (#983)
* Removed @tf.function from train_compute_metrics, as it is not needed and was causing a lower than real accuracy in model.fit() when using preds._keras_mask
* Turned the if condition into tf.cond to remove the tf.function decorator
* Made the should_compute_train_metrics_for_batch variable True by default
Fixes #961
Goals ⚽
This PR fixes an issue that caused metrics obtained with model.fit() to be much lower than the ones obtained with model.evaluate() when Keras Masking is used. The bug was observed when comparing training and evaluation metrics of a Transformer example (as described in #961), which uses Keras Masking (SequenceMaskRandom, SequenceMaskLast) to select items of the sequence for training / eval.
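To illustrate the symptom outside of Merlin Models, here is a minimal, self-contained sketch in plain Keras (illustrative values only, not the library code): masked positions are padding, and counting them as predictions deflates the reported accuracy, which is what happened to the train metrics when the mask was dropped.

```python
import tensorflow as tf

# Toy batch: 2 real items followed by 2 padded positions.
y_true = tf.constant([[1, 2, 0, 0]])              # 0 marks padding
y_pred = tf.one_hot([[1, 2, 3, 3]], depth=4)      # correct on the real items only
mask = tf.constant([[1.0, 1.0, 0.0, 0.0]])        # 1 = real item, 0 = padded

acc_masked = tf.keras.metrics.SparseCategoricalAccuracy()
acc_masked.update_state(y_true, y_pred, sample_weight=mask)

acc_unmasked = tf.keras.metrics.SparseCategoricalAccuracy()
acc_unmasked.update_state(y_true, y_pred)

print(acc_masked.result().numpy())    # 1.0 -> consistent with model.evaluate()
print(acc_unmasked.result().numpy())  # 0.5 -> the deflated model.fit() metric
```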
Implementation Details 🚧
The root cause was the @tf.function decorator we had in model.train_compute_metrics(). After replacing a condition inside that function with tf.cond(), it was possible to remove the @tf.function decorator and fix the error when using Keras Masking (i.e., setting predictions._keras_mask).
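A rough sketch of the shape of that refactor is below. Only should_compute_train_metrics_for_batch comes from the actual change; the class, method bodies, and helper names are assumed structure for illustration, not the real Merlin Models implementation.

```python
import tensorflow as tf

class ModelSketch:
    """Illustrative stand-in for the Merlin Models Model class."""

    def __init__(self, metrics):
        self.metrics = metrics
        # Defaults to True, so train metrics are computed unless explicitly skipped.
        self.should_compute_train_metrics_for_batch = tf.Variable(True, trainable=False)

    # Before: this method carried a @tf.function decorator and used a Python `if`,
    # which could drop predictions._keras_mask when retraced.
    # After: no decorator; the branch is expressed with tf.cond, so it still works
    # when called from the (already traced) train step.
    def train_compute_metrics(self, y_true, y_pred):
        return tf.cond(
            self.should_compute_train_metrics_for_batch,
            true_fn=lambda: self._update_metrics(y_true, y_pred),
            false_fn=lambda: {m.name: m.result() for m in self.metrics},
        )

    def _update_metrics(self, y_true, y_pred):
        # Keras turns predictions._keras_mask into a sample weight; keeping the
        # mask intact is what restores the correct train metrics.
        mask = getattr(y_pred, "_keras_mask", None)
        sample_weight = tf.cast(mask, tf.float32) if mask is not None else None
        for m in self.metrics:
            m.update_state(y_true, y_pred, sample_weight=sample_weight)
        return {m.name: m.result() for m in self.metrics}
```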
Testing Details 🔍
Added a new test, test_train_metrics_steps, to double-check that the logic inside model.train_compute_metrics() that skips steps when computing metrics continues to work in both eager and graph mode.
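For readers unfamiliar with the skip-steps behavior, here is a hypothetical, self-contained outline of the kind of check such a test performs. The real test lives in Merlin Models and exercises the Model API; the helper function and names below are illustrative only.

```python
import pytest
import tensorflow as tf

def compute_train_metrics_for_batch(step, train_metrics_steps):
    # Mirrors the skipping logic: metrics are only updated every
    # `train_metrics_steps` batches.
    return tf.math.equal(tf.math.floormod(step, train_metrics_steps), 0)

@pytest.mark.parametrize("run_eagerly", [True, False])
def test_train_metrics_steps(run_eagerly):
    fn = compute_train_metrics_for_batch
    if not run_eagerly:
        # Graph mode: trace the same logic with tf.function.
        fn = tf.function(fn)
    flags = [
        bool(fn(tf.constant(step, tf.int64), tf.constant(3, tf.int64)))
        for step in range(6)
    ]
    # Metrics are computed on steps 0 and 3 only when train_metrics_steps=3.
    assert flags == [True, False, False, True, False, False]
```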