Skip to content

Commit

Permalink
Update pydoc for the issue of using Keras BatchNorm in TFF.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 402693657
  • Loading branch information
xiaoyux11 authored and tensorflow-copybara committed Oct 13, 2021
1 parent 53676bc commit 9c3e888
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tensorflow_federated/python/learning/keras_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def from_keras_model(
guaranteed to exist through the functional or Sequential API but are
not necessarily present for subclassed models.
Note: This function raises a UserWarning if the `tf.keras.Model` contains a
BatchNormalization layer, as the batch mean and variance will be treated as
non-trainable variables and won't be updated during the training (see
b/186845846 for more information). Consider using Group Normalization instead.
Args:
keras_model: A `tf.keras.Model` object that is not compiled.
loss: A single `tf.keras.losses.Loss` or a list of losses-per-output. If a
Expand Down

0 comments on commit 9c3e888

Please sign in to comment.