Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tf_keras imports to prepare for Keras 3 #28588

Merged
merged 19 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/source/ja/troubleshooting.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ TensorFlowの[model.save](https://www.tensorflow.org/tutorials/keras/save_and_lo

```py
>>> from transformers import TFPreTrainedModel
>>> from tensorflow import keras

>>> model.save_weights("some_folder/tf_model.h5")
>>> model = TFPreTrainedModel.from_pretrained("some_folder")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
set_seed,
)
from transformers.keras_callbacks import KerasMetricCallback
from transformers.modeling_tf_utils import keras
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
Expand Down Expand Up @@ -363,7 +364,7 @@ def main():

def _train_transforms(image):
img_size = image_size
image = tf.keras.utils.img_to_array(image)
image = keras.utils.img_to_array(image)
image = random_resized_crop(image, size=img_size)
image = tf.image.random_flip_left_right(image)
image /= 255.0
Expand All @@ -372,7 +373,7 @@ def _train_transforms(image):
return image

def _val_transforms(image):
image = tf.keras.utils.img_to_array(image)
image = keras.utils.img_to_array(image)
image = tf.image.resize(image, size=image_size)
# image = np.array(image) # FIXME - use tf.image function
image = center_crop(image, size=image_size)
Expand Down
16 changes: 15 additions & 1 deletion examples/tensorflow/language-modeling-tpu/run_mlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import re

import tensorflow as tf
from packaging.version import parse

from transformers import (
AutoConfig,
Expand All @@ -33,6 +34,19 @@
)


try:
import tf_keras as keras
except (ModuleNotFoundError, ImportError):
import keras

if parse(keras.__version__).major > 2:
raise ValueError(
"Your currently installed version of Keras is Keras 3, but this is not yet supported in "
"Transformers. Please install the backwards-compatible tf-keras package with "
"`pip install tf-keras`."
)


logger = logging.getLogger(__name__)

AUTO = tf.data.AUTOTUNE
Expand Down Expand Up @@ -209,7 +223,7 @@ def main(args):
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

if args.bfloat16:
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
keras.mixed_precision.set_global_policy("mixed_bfloat16")

tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
config = AutoConfig.from_pretrained(args.pretrained_model_config)
Expand Down
16 changes: 15 additions & 1 deletion examples/tensorflow/question-answering/run_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import evaluate
import tensorflow as tf
from datasets import load_dataset
from packaging.version import parse
from utils_qa import postprocess_qa_predictions

import transformers
Expand All @@ -48,6 +49,19 @@
from transformers.utils import CONFIG_NAME, TF2_WEIGHTS_NAME, check_min_version, send_example_telemetry


try:
import tf_keras as keras
except (ModuleNotFoundError, ImportError):
import keras

if parse(keras.__version__).major > 2:
raise ValueError(
"Your currently installed version of Keras is Keras 3, but this is not yet supported in "
"Transformers. Please install the backwards-compatible tf-keras package with "
"`pip install tf-keras`."
)


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.38.0.dev0")

Expand Down Expand Up @@ -233,7 +247,7 @@ def __post_init__(self):


# region Helper classes
class SavePretrainedCallback(tf.keras.callbacks.Callback):
class SavePretrainedCallback(keras.callbacks.Callback):
# Hugging Face models have a save_pretrained() method that saves both the weights and the necessary
# metadata to allow them to be loaded as a pretrained model in future. This is a simple Keras callback
# that saves the model with this method after each epoch.
Expand Down
16 changes: 15 additions & 1 deletion examples/tensorflow/test_tensorflow_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@
from unittest.mock import patch

import tensorflow as tf
from packaging.version import parse


try:
import tf_keras as keras
except (ModuleNotFoundError, ImportError):
import keras

if parse(keras.__version__).major > 2:
raise ValueError(
"Your currently installed version of Keras is Keras 3, but this is not yet supported in "
"Transformers. Please install the backwards-compatible tf-keras package with "
"`pip install tf-keras`."
)

from transformers.testing_utils import TestCasePlus, get_gpu_count, slow

Expand Down Expand Up @@ -115,7 +129,7 @@ def test_run_text_classification(self):
with patch.object(sys, "argv", testargs):
run_text_classification.main()
# Reset the mixed precision policy so we don't break other tests
tf.keras.mixed_precision.set_global_policy("float32")
keras.mixed_precision.set_global_policy("float32")
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.75)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import numpy as np
from datasets import load_dataset
from packaging.version import parse

from transformers import (
AutoConfig,
Expand All @@ -46,11 +47,24 @@
import tensorflow as tf # noqa: E402


try:
import tf_keras as keras
except (ModuleNotFoundError, ImportError):
import keras

if parse(keras.__version__).major > 2:
raise ValueError(
"Your currently installed version of Keras is Keras 3, but this is not yet supported in "
"Transformers. Please install the backwards-compatible tf-keras package with "
"`pip install tf-keras`."
)


logger = logging.getLogger(__name__)


# region Helper classes
class SavePretrainedCallback(tf.keras.callbacks.Callback):
class SavePretrainedCallback(keras.callbacks.Callback):
# Hugging Face models have a save_pretrained() method that saves both the weights and the necessary
# metadata to allow them to be loaded as a pretrained model in future. This is a simple Keras callback
# that saves the model with this method after each epoch.
Expand Down
31 changes: 22 additions & 9 deletions src/transformers/activations_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,20 @@
import math

import tensorflow as tf
from packaging import version
from packaging.version import parse


try:
import tf_keras as keras
except (ModuleNotFoundError, ImportError):
import keras

if parse(keras.__version__).major > 2:
raise ValueError(
"Your currently installed version of Keras is Keras 3, but this is not yet supported in "
"Transformers. Please install the backwards-compatible tf-keras package with "
"`pip install tf-keras`."
)
Comment on lines +21 to +31
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand the rational for when the version compatible import is just handled e.g. in src/transformers/keras_callbacks.py and when the version is checked e.g. here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've rewritten things to always import from modeling_tf_utils. The only exceptions, where I copy-pasted it instead, are here in activations_tf, because that would create a circular import, and in the example files, which are designed for users to read and modify so I want to make it clear exactly which keras they're getting.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect :)

Comment on lines +21 to +31
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm slightly concerned that raising this exception at the top-level of a module is going to cause lots of issues. Users, who might not even be using tensorflow with transformers, but will have it in their environment will start having exception raised if they do e.g. from transformers import * (or something less stupid which ends up loading this module).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I was worried about this too! I think that specific problem shouldn't happen, because if TF isn't available then all the TF objects will be dummies and TF-specific files shouldn't be executed.

I'm still a bit unsure about the import from modeling_tf_utils, though - it feels dubious to me, but it was the only way I could think of to avoid pasting boilerplate everywhere. Maybe for peace of mind it'd be better to just do that, though, which would resolve this issue (and most of the other comments here)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can just install the different possibilities: no TF, old keras, new keras and check to see if the exceptions / warnings are egregious. I suspect the issue will not be users that don't have tensorflow, but for people that have it installed in their environment but aren't necessarily using it with transformers.

As TF is always a pain to find the set of compatible libraries, I can imagine a lot of people complaining.

Could we have the if parse logic put into a function and then add that as validation when people instantiate objects or use functions?

Or perhaps, create something like require_keras_lt_2 we can use to decorate objects?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did some testing with this! I set up an environment that should trigger the exception (TF 2.15 + Keras 3, no tf_keras). Initializing TF objects caused the exception to be thrown. However, I was able to initialize torch models and run them fine. I think our lazy loading protects us here - the exception should only appear when the user explicitly initializes TF objects.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Also in general I think the library maintainers segregated TF stuff pretty hard in the past to make sure that it didn't spam the console or allocate GPU memory)



def _gelu(x):
Expand Down Expand Up @@ -99,12 +112,12 @@ def glu(x, axis=-1):
return a * tf.math.sigmoid(b)


if version.parse(tf.version.VERSION) >= version.parse("2.4"):
if parse(tf.version.VERSION) >= parse("2.4"):

def approximate_gelu_wrap(x):
return tf.keras.activations.gelu(x, approximate=True)
return keras.activations.gelu(x, approximate=True)

gelu = tf.keras.activations.gelu
gelu = keras.activations.gelu
gelu_new = approximate_gelu_wrap
else:
gelu = _gelu
Expand All @@ -119,11 +132,11 @@ def approximate_gelu_wrap(x):
"glu": glu,
"mish": mish,
"quick_gelu": quick_gelu,
"relu": tf.keras.activations.relu,
"sigmoid": tf.keras.activations.sigmoid,
"silu": tf.keras.activations.swish,
"swish": tf.keras.activations.swish,
"tanh": tf.keras.activations.tanh,
"relu": keras.activations.relu,
"sigmoid": keras.activations.sigmoid,
"silu": keras.activations.swish,
"swish": keras.activations.swish,
"tanh": keras.activations.tanh,
}


Expand Down
6 changes: 3 additions & 3 deletions src/transformers/keras_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@
import tensorflow as tf
from huggingface_hub import Repository, create_repo
from packaging.version import parse
from tensorflow.keras.callbacks import Callback

from . import IntervalStrategy, PreTrainedTokenizerBase
from .modelcard import TrainingSummary
from .modeling_tf_utils import keras


logger = logging.getLogger(__name__)


class KerasMetricCallback(Callback):
class KerasMetricCallback(keras.callbacks.Callback):
"""
Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be
compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string
Expand Down Expand Up @@ -265,7 +265,7 @@ def generation_function(inputs, attention_mask):
logs.update(metric_output)


class PushToHubCallback(Callback):
class PushToHubCallback(keras.callbacks.Callback):
"""
Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can
be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/modelcard.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ def from_keras(

def parse_keras_history(logs):
"""
Parse the `logs` of either a `tf.keras.History` object returned by `model.fit()` or an accumulated logs `dict`
Parse the `logs` of either a `keras.History` object returned by `model.fit()` or an accumulated logs `dict`
passed to the `PushToHubCallback`. Returns lines and logs compatible with those returned by `parse_log_history`.
"""
if hasattr(logs, "history"):
Expand Down Expand Up @@ -800,14 +800,14 @@ def parse_log_history(log_history):


def extract_hyperparameters_from_keras(model):
import tensorflow as tf
from .modeling_tf_utils import keras

hyperparameters = {}
if hasattr(model, "optimizer") and model.optimizer is not None:
hyperparameters["optimizer"] = model.optimizer.get_config()
else:
hyperparameters["optimizer"] = None
hyperparameters["training_precision"] = tf.keras.mixed_precision.global_policy().name
hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name

return hyperparameters

Expand Down
3 changes: 1 addition & 2 deletions src/transformers/modeling_tf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def load_pytorch_state_dict_in_tf2_model(
"""Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
safetensors archive created with the safe_open() function."""
import tensorflow as tf
from keras import backend as K

if tf_inputs is None:
tf_inputs = tf_model.dummy_inputs
Expand Down Expand Up @@ -360,7 +359,7 @@ def load_pytorch_state_dict_in_tf2_model(

tf_loaded_numel += tensor_size(array)

K.set_value(symbolic_weight, array)
symbolic_weight.assign(tf.cast(array, symbolic_weight.dtype))
amyeroberts marked this conversation as resolved.
Show resolved Hide resolved
del array # Immediately free memory to keep peak usage as low as possible
all_pytorch_weights.discard(name)

Expand Down
Loading
Loading