Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TF version compatibility fixes #23663

Merged
merged 5 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, TFGenerationMixin
from .tf_utils import shape_list
from .tf_utils import expand_1d, load_attributes_from_hdf5_group, save_attributes_to_hdf5_group, shape_list
from .utils import (
DUMMY_INPUTS,
SAFE_WEIGHTS_INDEX_NAME,
Expand All @@ -65,16 +65,15 @@
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files


if parse(tf.__version__) >= parse("2.11.0"):
if parse(tf.__version__).minor >= 13:
from keras import backend as K
from keras.__internal__ import KerasTensor
elif parse(tf.__version__).minor >= 11:
from keras import backend as K
from keras.engine import data_adapter
from keras.engine.keras_tensor import KerasTensor
from keras.saving.legacy import hdf5_format
else:
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine.keras_tensor import KerasTensor
from tensorflow.python.keras.saving import hdf5_format


if is_safetensors_available():
Expand Down Expand Up @@ -797,9 +796,7 @@ def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatch
try:
with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
# Retrieve the name of each layer from the H5 file
saved_h5_model_layers_name = set(
hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")
)
saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))
weight_value_tuples = []

# Compute missing and unexpected sub layers
Expand Down Expand Up @@ -898,9 +895,7 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
# Read the H5 file
with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
# Retrieve the name of each layer from the H5 file
saved_h5_model_layers_name = set(
hdf5_format.load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names")
)
saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))

# Find the missing layers from the high level list of layers
missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name)
Expand All @@ -924,7 +919,7 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size

# Create a dict from the H5 saved model that looks like {"weight_name": weight_value}
# And a set with only the names
for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
for weight_name in load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
# TF names always start with the model name so we ignore it
name = "/".join(weight_name.split("/")[1:])

Expand Down Expand Up @@ -1528,8 +1523,8 @@ def train_step(self, data):
output_to_label = {val: key for key, val in label_to_output.items()}
if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
# Newer TF train steps leave this out
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
data = expand_1d(data)
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has this been here for all the TF versions we support?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, since at least 2.4!

# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
# them during input/label pre-processing. This avoids surprising the user by wrecking their data.
# In addition, modifying mutable Python inputs makes XLA compilation impossible.
Expand Down Expand Up @@ -1635,8 +1630,8 @@ def test_step(self, data):
output_to_label = {val: key for key, val in label_to_output.items()}
if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
# Newer versions leave this out
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
data = expand_1d(data)
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
# If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
# them during input/label pre-processing. This avoids surprising the user by wrecking their data.
# In addition, modifying mutable Python inputs makes XLA compilation impossible.
Expand Down Expand Up @@ -2402,7 +2397,7 @@ def save_pretrained(
)
param_dset[:] = layer.numpy()
layers.append(layer_name.encode("utf8"))
hdf5_format.save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
save_attributes_to_hdf5_group(shard_file, "layer_names", layers)

if push_to_hub:
self._upload_modified_files(
Expand Down
7 changes: 3 additions & 4 deletions src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import collections
import csv
import importlib
import inspect
import json
import os
import pickle
Expand All @@ -36,7 +35,7 @@
from ..modelcard import ModelCard
from ..models.auto.configuration_auto import AutoConfig
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available, logging
from ..utils import ModelOutput, add_end_docstrings, infer_framework, is_tf_available, is_torch_available, logging


GenericTensor = Union[List["GenericTensor"], "torch.Tensor", "tf.Tensor"]
Expand Down Expand Up @@ -278,7 +277,7 @@ def infer_framework_load_model(
if isinstance(model, str):
raise ValueError(f"Could not load model {model} with any of the following classes: {class_tuple}.")

framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
framework = infer_framework(model.__class__)
return framework, model


Expand Down Expand Up @@ -351,7 +350,7 @@ def get_framework(model, revision: Optional[str] = None):
except OSError:
model = TFAutoModel.from_pretrained(model, revision=revision)

framework = "tf" if "keras.engine.training.Model" in str(inspect.getmro(model.__class__)) else "pt"
framework = infer_framework(model.__class__)
return framework


Expand Down
87 changes: 87 additions & 0 deletions src/transformers/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,90 @@ def check_embeddings_within_bounds(tensor: tf.Tensor, embed_dim: int, tensor_nam
f"layer's input dimension ({embed_dim}). The likely cause is some problem at tokenization time."
),
)


def save_attributes_to_hdf5_group(group, name, data):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two functions are also moving targets I was trying to import from Keras - I've just given up and copied them into the transformers codebase.

"""Saves attributes (data) of the specified name into the HDF5 group.

This method deals with an inherent problem of HDF5 file which is not able to store data larger than
HDF5_OBJECT_HEADER_LIMIT bytes.

Args:
group: A pointer to a HDF5 group.
name: A name of the attributes to save.
data: Attributes data to store.

Raises:
RuntimeError: If any single attribute is too large to be saved.

Copied from Keras to Transformers to avoid versioning issues.
"""
HDF5_OBJECT_HEADER_LIMIT = 64512
# Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`
# because in that case even chunking the array would not make the saving
# possible.
bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]

# Expecting this to never be true.
if bad_attributes:
raise RuntimeError(
"The following attributes cannot be saved to HDF5 file because "
f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} "
f"bytes: {bad_attributes}"
)

data_npy = np.asarray(data)

num_chunks = 1
chunked_data = np.array_split(data_npy, num_chunks)

# This will never loop forever thanks to the test above.
while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):
num_chunks += 1
chunked_data = np.array_split(data_npy, num_chunks)

if num_chunks > 1:
for chunk_id, chunk_data in enumerate(chunked_data):
group.attrs["%s%d" % (name, chunk_id)] = chunk_data
else:
group.attrs[name] = data


def load_attributes_from_hdf5_group(group, name):
"""Loads attributes of the specified name from the HDF5 group.

This method deals with an inherent problem of HDF5 file which is not able to store data larger than
HDF5_OBJECT_HEADER_LIMIT bytes.

Args:
group: A pointer to a HDF5 group.
name: A name of the attributes to load.

Returns:
data: Attributes data.

Copied from Keras to Transformers to avoid versioning issues.
"""
if name in group.attrs:
data = [n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs[name]]
else:
data = []
chunk_id = 0
while "%s%d" % (name, chunk_id) in group.attrs:
data.extend(
[n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs["%s%d" % (name, chunk_id)]]
)
chunk_id += 1
return data


def expand_1d(data):
"""Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s.
Copied from Keras to here to avoid versioning issues."""

def _expand_single_1d_tensor(t):
if isinstance(t, tf.Tensor) and t.shape.rank == 1:
return tf.expand_dims(t, axis=-1)
return t

return tf.nest.map_structure(_expand_single_1d_tensor, data)
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
expand_dims,
find_labels,
flatten_dict,
infer_framework,
is_jax_tensor,
is_numpy_array,
is_tensor,
Expand Down
32 changes: 24 additions & 8 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,11 +398,10 @@ def can_return_loss(model_class):
Args:
model_class (`type`): The class of the model.
"""
base_classes = str(inspect.getmro(model_class))

if "keras.engine.training.Model" in base_classes:
framework = infer_framework(model_class)
if framework == "tf":
signature = inspect.signature(model_class.call) # TensorFlow models
elif "torch.nn.modules.module.Module" in base_classes:
elif framework == "pt":
signature = inspect.signature(model_class.forward) # PyTorch models
else:
signature = inspect.signature(model_class.__call__) # Flax models
Expand All @@ -422,11 +421,10 @@ def find_labels(model_class):
model_class (`type`): The class of the model.
"""
model_name = model_class.__name__
base_classes = str(inspect.getmro(model_class))

if "keras.engine.training.Model" in base_classes:
framework = infer_framework(model_class)
if framework == "tf":
signature = inspect.signature(model_class.call) # TensorFlow models
elif "torch.nn.modules.module.Module" in base_classes:
elif framework == "pt":
signature = inspect.signature(model_class.forward) # PyTorch models
else:
signature = inspect.signature(model_class.__call__) # Flax models
Expand Down Expand Up @@ -565,3 +563,21 @@ def add_model_info_to_auto_map(auto_map, repo_id):
auto_map[key] = f"{repo_id}--{value}"

return auto_map


def infer_framework(model_class):
"""
Infers the framework of a given model without using isinstance(), because we cannot guarantee that the relevant
classes are imported or available.
"""
for base_class in inspect.getmro(model_class):
module = base_class.__module__
name = base_class.__name__
if module.startswith("tensorflow") or module.startswith("keras") or name == "TFPreTrainedModel":
return "tf"
elif module.startswith("torch") or name == "PreTrainedModel":
return "pt"
elif module.startswith("flax") or module.startswith("jax") or name == "FlaxPreTrainedModel":
return "flax"
else:
raise TypeError(f"Could not infer framework from class {model_class}.")