-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add tf_keras imports to prepare for Keras 3 #28588
Changes from all commits
f0ddbd8
0de7d71
728b7a1
316ea58
dbe1e05
9201f68
3d86ddf
f0fd9bf
f4ba81e
cc6aaa0
a2d325c
ea92cd3
166b093
d7bb7b6
4224206
df5b486
193cbca
383bc79
b5e59ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,20 @@ | |
import math | ||
|
||
import tensorflow as tf | ||
from packaging import version | ||
from packaging.version import parse | ||
|
||
|
||
try: | ||
import tf_keras as keras | ||
except (ModuleNotFoundError, ImportError): | ||
import keras | ||
|
||
if parse(keras.__version__).major > 2: | ||
raise ValueError( | ||
"Your currently installed version of Keras is Keras 3, but this is not yet supported in " | ||
"Transformers. Please install the backwards-compatible tf-keras package with " | ||
"`pip install tf-keras`." | ||
) | ||
Comment on lines
+21
to
+31
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm slightly concerned that raising this exception at the top-level of a module is going to cause lots of issues. Users, who might not even be using tensorflow with transformers, but will have it in their environment will start having exception raised if they do e.g. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I was worried about this too! I think that specific problem shouldn't happen, because if TF isn't available then all the TF objects will be dummies and TF-specific files shouldn't be executed. I'm still a bit unsure about the import from There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can just install the different possibilities: no TF, old keras, new keras and check to see if the exceptions / warnings are egregious. I suspect the issue will not be users that don't have tensorflow, but for people that have it installed in their environment but aren't necessarily using it with transformers. As TF is always a pain to find the set of compatible libraries, I can imagine a lot of people complaining. Could we have the Or perhaps, create something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did some testing with this! I set up an environment that should trigger the exception (TF 2.15 + Keras 3, no tf_keras). Initializing TF objects caused the exception to be thrown. However, I was able to initialize torch models and run them fine. I think our lazy loading protects us here - the exception should only appear when the user explicitly initializes TF objects. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Also in general I think the library maintainers segregated TF stuff pretty hard in the past to make sure that it didn't spam the console or allocate GPU memory) |
||
|
||
|
||
def _gelu(x): | ||
|
@@ -99,12 +112,12 @@ def glu(x, axis=-1): | |
return a * tf.math.sigmoid(b) | ||
|
||
|
||
if version.parse(tf.version.VERSION) >= version.parse("2.4"): | ||
if parse(tf.version.VERSION) >= parse("2.4"): | ||
|
||
def approximate_gelu_wrap(x): | ||
return tf.keras.activations.gelu(x, approximate=True) | ||
return keras.activations.gelu(x, approximate=True) | ||
|
||
gelu = tf.keras.activations.gelu | ||
gelu = keras.activations.gelu | ||
gelu_new = approximate_gelu_wrap | ||
else: | ||
gelu = _gelu | ||
|
@@ -119,11 +132,11 @@ def approximate_gelu_wrap(x): | |
"glu": glu, | ||
"mish": mish, | ||
"quick_gelu": quick_gelu, | ||
"relu": tf.keras.activations.relu, | ||
"sigmoid": tf.keras.activations.sigmoid, | ||
"silu": tf.keras.activations.swish, | ||
"swish": tf.keras.activations.swish, | ||
"tanh": tf.keras.activations.tanh, | ||
"relu": keras.activations.relu, | ||
"sigmoid": keras.activations.sigmoid, | ||
"silu": keras.activations.swish, | ||
"swish": keras.activations.swish, | ||
"tanh": keras.activations.tanh, | ||
} | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure I understand the rational for when the version compatible import is just handled e.g. in
src/transformers/keras_callbacks.py
and when the version is checked e.g. hereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've rewritten things to always import from
modeling_tf_utils
. The only exceptions, where I copy-pasted it instead, are here inactivations_tf
, because that would create a circular import, and in the example files, which are designed for users to read and modify so I want to make it clear exactly whichkeras
they're getting.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perfect :)