From e66224c05a91e259e8b500e065ebfe4db928670d Mon Sep 17 00:00:00 2001
From: matt
Date: Thu, 16 Dec 2021 12:43:44 +0000
Subject: [PATCH 01/17] Working on splitting out labels

---
 src/transformers/keras_callbacks.py | 120 +++++++++++++++++++++-------
 1 file changed, 93 insertions(+), 27 deletions(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index ff1b938cec07e3..1bc9efc55373f8 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -2,8 +2,10 @@
 import os
 from pathlib import Path
 from time import sleep
-from typing import Optional, Union
+from typing import Optional, Union, Callable
 
+import numpy as np
+import tensorflow as tf
 from tensorflow.keras.callbacks import Callback
 
 from huggingface_hub import Repository
@@ -16,6 +18,67 @@
 logger = logging.getLogger(__name__)
 
 
+class KerasMetricCallback(Callback):
+    """
+    Callback to prompt metrics at the end of every epoch.
+
+    Args:
+        metric_fn: Metric function provided by the user.
+        val_dataset: Validation data to be used to evaluate the model at
+            the end of the epoch.
+        metric_name: Name of the metric calculated in metric_fn.
+        batch_size: Batch size.
+        labels: Labels.
+    """
+
+    def __init__(self, metric_fn: Callable,
+                 val_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
+                 metric_name: Optional[str],
+                 label_names: Optional[str],
+                 batch_size: Optional[int] = None):
+        super().__init__()
+        self.metric_fn = metric_fn
+        self.batch_size = batch_size
+        if not isinstance(val_dataset, tf.data.Dataset):
+            if batch_size is None:
+                raise ValueError("When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
+                                 "the batch_size argument must be set.")
+            # Wrap a tf.data.Dataset around it
+            val_dataset = tf.data.Dataset.from_tensor_slices(val_dataset).batch(batch_size, drop_remainder=False)
+        self.val_dataset = val_dataset
+        self.metric_name = metric_name
+        self.label_names = "labels" if label_names is None else label_names
+
+    def on_epoch_end(self, epoch, logs=None):
+
+        prediction_list = []
+        label_list = []
+
+        for batch in self.val_dataset:
+
+            if isinstance(batch, tuple):
+                batch, labels = batch
+                labels = np.asarray(labels)
+
+            elif isinstance(batch, dict):
+                labels = np.asarray(batch["labels"])
+
+            predictions = self.model.predict(batch)
+            predictions_dict = dict(predictions)
+
+            for prediction in predictions_dict["logits"]:
+                prediction_list.append(predictions)
+
+            for label in labels:
+                label_list.append(label)
+
+        metric_value = self.metric_fn(predictions=np.asarray(prediction_list), labels=np.asarray(label_list))
+
+        if metric_name is not None:
+            print(f"{self.metric_name} for epoch {epoch} is {metric_value}")
+        else:
+            print(f"At epoch {epoch}: {metric_value}")
+
+
 class PushToHubCallback(Callback):
     def __init__(
         self,
@@ -29,32 +92,35 @@ def __init__(
         **model_card_args
     ):
         """
-        output_dir (:obj:`str`):
-            The output directory where the model predictions and checkpoints will be written and synced with the
-            repository on the Hub.
-        save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"epoch"`):
-            The checkpoint save strategy to adopt during training. Possible values are:
-
-            * :obj:`"no"`: No save is done during training.
-            * :obj:`"epoch"`: Save is done at the end of each epoch.
-            * :obj:`"steps"`: Save is done every :obj:`save_steps`
-        save_steps (:obj:`int`, `optional`):
-            The number of steps between saves when using the "steps" save_strategy.
-        tokenizer (:obj:`PreTrainedTokenizerBase`, `optional`):
-            The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
-        hub_model_id (:obj:`str`, `optional`):
-            The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
-            which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
-            for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member of with
-            :obj:`"organization_name/model"`.
-
-            Will default to the name of :obj:`output_dir`.
-        hub_token (:obj:`str`, `optional`):
-            The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
-            :obj:`huggingface-cli login`.
-        checkpoint (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
-            resumed. Only usable when `save_strategy` is `epoch`.
+        Callback for pushing the model to the Hub after training.
+
+        Args:
+            output_dir (:obj:`str`):
+                The output directory where the model predictions and checkpoints will be written and synced with the
+                repository on the Hub.
+            save_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"epoch"`):
+                The checkpoint save strategy to adopt during training. Possible values are:
+
+                * :obj:`"no"`: No save is done during training.
+                * :obj:`"epoch"`: Save is done at the end of each epoch.
+                * :obj:`"steps"`: Save is done every :obj:`save_steps`
+            save_steps (:obj:`int`, `optional`):
+                The number of steps between saves when using the "steps" save_strategy.
+            tokenizer (:obj:`PreTrainedTokenizerBase`, `optional`):
+                The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
+            hub_model_id (:obj:`str`, `optional`):
+                The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
+                which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
+                for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member of with
+                :obj:`"organization_name/model"`.
+
+                Will default to the name of :obj:`output_dir`.
+            hub_token (:obj:`str`, `optional`):
+                The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
+                :obj:`huggingface-cli login`.
+            checkpoint (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
+                resumed. Only usable when `save_strategy` is `epoch`.
         """
         super().__init__()
         if checkpoint and save_strategy != "epoch":
From e1c8adbe2ef933fe96729b0875da579e44d27519 Mon Sep 17 00:00:00 2001
From: matt
Date: Tue, 21 Dec 2021 17:11:59 +0000
Subject: [PATCH 02/17] First working version

---
 src/transformers/keras_callbacks.py | 110 ++++++++++++++++++----------
 1 file changed, 73 insertions(+), 37 deletions(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index 1bc9efc55373f8..4229ffabbb86d9 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -2,9 +2,9 @@
 import os
 from pathlib import Path
 from time import sleep
-from typing import Optional, Union, Callable
+from typing import Callable, List, Optional, Union
 
 import numpy as np
 import tensorflow as tf
 from tensorflow.keras.callbacks import Callback
@@ -26,58 +26,94 @@ class KerasMetricCallback(Callback):
         metric_fn: Metric function provided by the user.
         val_dataset: Validation data to be used to evaluate the model at
             the end of the epoch.
-        metric_name: Name of the metric calculated in metric_fn.
         batch_size: Batch size.
         labels: Labels.
     """
 
-    def __init__(self, metric_fn: Callable,
-                 val_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
-                 metric_name: Optional[str],
-                 label_names: Optional[str],
-                 batch_size: Optional[int] = None):
+    def __init__(
+        self,
+        metric_fn: Callable,
+        val_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
+        output_cols: Optional[List[str]] = None,
+        label_cols: Optional[List[str]] = None,
+        batch_size: Optional[int] = None,
+        predict_with_generate: Optional[bool] = False,
+    ):
         super().__init__()
         self.metric_fn = metric_fn
         self.batch_size = batch_size
         if not isinstance(val_dataset, tf.data.Dataset):
             if batch_size is None:
-                raise ValueError("When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
-                                 "the batch_size argument must be set.")
+                raise ValueError(
+                    "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
+                    "the batch_size argument must be set."
+                )
             # Wrap a tf.data.Dataset around it
             val_dataset = tf.data.Dataset.from_tensor_slices(val_dataset).batch(batch_size, drop_remainder=False)
         self.val_dataset = val_dataset
-        self.metric_name = metric_name
-        self.label_names = "labels" if label_names is None else label_names
+        self.predict_with_generate = predict_with_generate
+        self.output_cols = output_cols
 
-    def on_epoch_end(self, epoch, logs=None):
+        # This next block attempts to parse out which elements of the dataset should be appended to the labels list
+        # that is passed to the metric_fn
+        if isinstance(val_dataset.element_spec, tuple) and len(val_dataset.element_spec) == 2:
+            input_spec, label_spec = val_dataset.element_spec
+        else:
+            input_spec = val_dataset.element_spec
+            label_spec = None
+        if label_cols is not None:
+            for label in label_cols:
+                if label not in input_spec:
+                    raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!")
+            self.label_cols = label_cols
+            self.use_keras_label = False
+        elif label_spec is not None:
+            # If the dataset inputs are split into a 2-tuple of inputs and labels,
+            # assume the second element is the labels
+            self.label_cols = None
+            self.use_keras_label = True
+        elif "labels" in input_spec:
+            self.label_cols = ["labels"]
+            self.use_keras_label = False
+            logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
+        else:
+            raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
 
+    def on_epoch_end(self, epoch, logs=None):
         prediction_list = []
         label_list = []
 
+        # The whole predict/generate loop is handled inside this method
         for batch in self.val_dataset:
-
             if isinstance(batch, tuple):
                 batch, labels = batch
-                labels = np.asarray(labels)
-
-            elif isinstance(batch, dict):
-                labels = np.asarray(batch["labels"])
-
-            predictions = self.model.predict(batch)
-            predictions_dict = dict(predictions)
-
-            for prediction in predictions_dict["logits"]:
+            else:
+                labels = None
+            if self.predict_with_generate:
+                predictions = self.model.generate(batch)
+            else:
+                predictions = self.model.predict(batch)
+            predictions = dict(predictions)
+            if self.output_cols is not None:
+                prediction_list.append({key: predictions[key] for key in self.output_cols})
+            else:
                 prediction_list.append(predictions)
+            if self.use_keras_label:
+                label_list.append(labels)
+            else:
+                label_list.append({key: batch[key] for key in self.label_cols})
 
-            for label in labels:
-                label_list.append(label)
-
-        metric_value = self.metric_fn(predictions=np.asarray(prediction_list), labels=np.asarray(label_list))
-
-        if metric_name is not None:
-            print(f"{self.metric_name} for epoch {epoch} is {metric_value}")
-        else:
-            print(f"At epoch {epoch}: {metric_value}")
+        metric_output = self.metric_fn(prediction_list, label_list)
+        if not isinstance(metric_output, dict):
+            raise TypeError(
+                f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
+            )
+        # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch
+        # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of
+        # new keys in there, which will then get read by the History callback and treated like any other metric value.
+        # I promise that I have it in writing from Chollet that this is okay.
+        logs.update(metric_output)
 
 
 class PushToHubCallback(Callback):
@@ -110,14 +146,14 @@ def __init__(
                 The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
             hub_model_id (:obj:`str`, `optional`):
                 The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
-                which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
-                for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member of with
-                :obj:`"organization_name/model"`.
+                which case the model will be pushed in your namespace. Otherwise it should be the whole repository
+                name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a
+                member of with :obj:`"organization_name/model"`.
 
                 Will default to the name of :obj:`output_dir`.
             hub_token (:obj:`str`, `optional`):
-                The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
-                :obj:`huggingface-cli login`.
+                The token to use to push the model to the Hub. Will default to the token in the cache folder obtained
+                with :obj:`huggingface-cli login`.
             checkpoint (:obj:`bool`, `optional`, defaults to :obj:`False`):
                 Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
                 resumed. Only usable when `save_strategy` is `epoch`.
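To make the new contract concrete, here is a usage sketch — not part of the patch — assuming `model` is a Transformers TF model whose `predict()` returns a dict-like output with a `logits` key, and that `tf_train_dataset`/`tf_eval_dataset` are placeholder batched `tf.data.Dataset`s yielding `(features, labels)` tuples. At this stage the `metric_fn` receives lists of per-batch outputs rather than concatenated arrays:

```python
import numpy as np

def metric_fn(predictions, labels):
    # In this version of the callback, `predictions` is a list of per-batch
    # output dicts and `labels` is a list of per-batch label tensors
    logits = np.concatenate([batch["logits"] for batch in predictions], axis=0)
    y_true = np.concatenate([np.asarray(batch) for batch in labels], axis=0)
    return {"accuracy": float((logits.argmax(axis=-1) == y_true).mean())}

metric_callback = KerasMetricCallback(metric_fn=metric_fn, val_dataset=tf_eval_dataset)
history = model.fit(tf_train_dataset, epochs=2, callbacks=[metric_callback])
# Because metric_fn's outputs are merged into `logs`, they surface in
# history.history["accuracy"] like any built-in Keras metric
```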
From d1157057e3b91a133ffc4e7a78b4275b9934e9db Mon Sep 17 00:00:00 2001
From: matt
Date: Tue, 21 Dec 2021 19:09:26 +0000
Subject: [PATCH 03/17] Fixed concatenation of outputs and labels

---
 src/transformers/keras_callbacks.py | 41 ++++++++++++++++++++++++-----
 1 file changed, 35 insertions(+), 6 deletions(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index 4229ffabbb86d9..6fc584f4e6a96d 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -79,6 +79,26 @@ def __init__(
         else:
             raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
 
+    @staticmethod
+    def _concatenate_batches(batches):
+        # Flattens Numpy array batches into a list of single samples, where each sample is still np.ndarray
+        return [sample for batch in batches for sample in batch]
+
+    def _postprocess_predictions_or_labels(self, inputs):
+        if isinstance(inputs[0], dict):
+            outputs = dict()
+            for key in inputs[0].keys():
+                outputs[key] = self._concatenate_batches(batch[key] for batch in inputs)
+        elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
+            outputs = []
+            for input_list in zip(*inputs):
+                outputs.append(self._concatenate_batches(input_list))
+        elif isinstance(inputs[0], np.ndarray):
+            outputs = self._concatenate_batches(inputs)
+        else:
+            raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
+        return outputs
+
     def on_epoch_end(self, epoch, logs=None):
         prediction_list = []
         label_list = []
@@ -95,13 +115,22 @@ def on_epoch_end(self, epoch, logs=None):
             predictions = self.model.predict(batch)
             predictions = dict(predictions)
             if self.output_cols is not None:
-                prediction_list.append({key: predictions[key] for key in self.output_cols})
+                predictions = {key: predictions[key] for key in self.output_cols}
+            prediction_list.append(predictions)
+            if not self.use_keras_label:
+                labels = {key: batch[key].numpy() for key in self.label_cols}
+            elif isinstance(labels, dict):
+                labels = {key: array.numpy() for key, array in labels.items()}
+            elif isinstance(labels, list) or isinstance(labels, tuple):
+                labels = [array.numpy() for array in labels]
+            elif isinstance(labels, tf.Tensor):
+                labels = labels.numpy()
             else:
-                prediction_list.append(predictions)
-            if self.use_keras_label:
-                label_list.append(labels)
-            else:
-                label_list.append({key: batch[key] for key in self.label_cols})
+                raise TypeError(f"Confused by labels of type {type(labels)}")
+            label_list.append(labels)
+
+        prediction_list = self._postprocess_predictions_or_labels(prediction_list)
+        label_list = self._postprocess_predictions_or_labels(label_list)
 
         metric_output = self.metric_fn(prediction_list, label_list)
         if not isinstance(metric_output, dict):
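The flattening helpers this patch adds are easiest to see in isolation. This standalone snippet (not from the patch) mimics `_concatenate_batches` on a ragged final batch, where stacking whole batches would fail:

```python
import numpy as np

# Three prediction batches; the last is smaller, as happens with
# drop_remainder=False, so np.stack(batches) would raise, but a flat
# list of per-sample arrays works for any batch sizes
batches = [np.zeros((8, 5)), np.zeros((8, 5)), np.zeros((4, 5))]
flat = [sample for batch in batches for sample in batch]
print(len(flat), flat[0].shape)  # 20 (5,)

# Dict outputs get the same treatment applied per key
dict_batches = [{"logits": np.zeros((8, 2))}, {"logits": np.zeros((4, 2))}]
flat_logits = [s for b in dict_batches for s in b["logits"]]
print(len(flat_logits))  # 12
```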
From db0350cdc8742f7b4e177cf247caa25a33423d7e Mon Sep 17 00:00:00 2001
From: matt
Date: Wed, 22 Dec 2021 15:59:42 +0000
Subject: [PATCH 04/17] val_dataset -> eval_dataset

---
 src/transformers/keras_callbacks.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index 6fc584f4e6a96d..ea89e3d6c7aca5 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -24,7 +24,7 @@ class KerasMetricCallback(Callback):
 
     Args:
         metric_fn: Metric function provided by the user.
-        val_dataset: Validation data to be used to evaluate the model at
+        eval_dataset: Validation data to be used to evaluate the model at
            the end of the epoch.
         batch_size: Batch size.
         labels: Labels.
@@ -33,7 +33,7 @@ class KerasMetricCallback(Callback):
     def __init__(
         self,
         metric_fn: Callable,
-        val_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
+        eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
         output_cols: Optional[List[str]] = None,
         label_cols: Optional[List[str]] = None,
         batch_size: Optional[int] = None,
@@ -42,24 +42,24 @@ def __init__(
         super().__init__()
         self.metric_fn = metric_fn
         self.batch_size = batch_size
-        if not isinstance(val_dataset, tf.data.Dataset):
+        if not isinstance(eval_dataset, tf.data.Dataset):
             if batch_size is None:
                 raise ValueError(
                     "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
                     "the batch_size argument must be set."
                 )
             # Wrap a tf.data.Dataset around it
-            val_dataset = tf.data.Dataset.from_tensor_slices(val_dataset).batch(batch_size, drop_remainder=False)
-        self.val_dataset = val_dataset
+            eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False)
+        self.eval_dataset = eval_dataset
         self.predict_with_generate = predict_with_generate
         self.output_cols = output_cols
 
         # This next block attempts to parse out which elements of the dataset should be appended to the labels list
         # that is passed to the metric_fn
-        if isinstance(val_dataset.element_spec, tuple) and len(val_dataset.element_spec) == 2:
-            input_spec, label_spec = val_dataset.element_spec
+        if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2:
+            input_spec, label_spec = eval_dataset.element_spec
         else:
-            input_spec = val_dataset.element_spec
+            input_spec = eval_dataset.element_spec
             label_spec = None
         if label_cols is not None:
             for label in label_cols:
@@ -104,7 +104,7 @@ def on_epoch_end(self, epoch, logs=None):
         label_list = []
 
         # The whole predict/generate loop is handled inside this method
-        for batch in self.val_dataset:
+        for batch in self.eval_dataset:
             if isinstance(batch, tuple):
                 batch, labels = batch
             else:

From 29560f37853fe3d79c8d5f0c3fe545107a03de47 Mon Sep 17 00:00:00 2001
From: matt
Date: Wed, 22 Dec 2021 16:03:36 +0000
Subject: [PATCH 05/17] Only pass input arrays in tokenizer.model_input_names

---
 src/transformers/keras_callbacks.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index ea89e3d6c7aca5..241b827945e8bd 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -33,6 +33,7 @@ class KerasMetricCallback(Callback):
     def __init__(
         self,
         metric_fn: Callable,
+        tokenizer: PreTrainedTokenizerBase,
         eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
         output_cols: Optional[List[str]] = None,
         label_cols: Optional[List[str]] = None,
@@ -53,6 +54,7 @@ def __init__(
         self.eval_dataset = eval_dataset
         self.predict_with_generate = predict_with_generate
         self.output_cols = output_cols
+        self.model_input_names = tokenizer.model_input_names
 
         # This next block attempts to parse out which elements of the dataset should be appended to the labels list
         # that is passed to the metric_fn
@@ -109,6 +111,9 @@ def on_epoch_end(self, epoch, logs=None):
                 batch, labels = batch
             else:
                 labels = None
+            if isinstance(batch, dict):
+                batch = {key: array for key, array in batch.items()
+                         if key in self.model_input_names}
             if self.predict_with_generate:
                 predictions = self.model.generate(batch)
             else:

From a77fee15a0ac4d2ad546ef9948888f4fadbead7a Mon Sep 17 00:00:00 2001
From: matt
Date: Wed, 22 Dec 2021 16:03:56 +0000
Subject: [PATCH 06/17] Only pass input arrays in tokenizer.model_input_names

---
 src/transformers/keras_callbacks.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index 241b827945e8bd..6cbe4f3d0485fb 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -112,8 +112,7 @@ def on_epoch_end(self, epoch, logs=None):
             else:
                 labels = None
             if isinstance(batch, dict):
-                batch = {key: array for key, array in batch.items()
-                         if key in self.model_input_names}
+                batch = {key: array for key, array in batch.items() if key in self.model_input_names}
             if self.predict_with_generate:
                 predictions = self.model.generate(batch)
             else:
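For reference, `model_input_names` comes straight from the tokenizer, so the filtering added in the last two patches behaves like this standalone sketch (the checkpoint name is just an example):

```python
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
print(tokenizer.model_input_names)
# ['input_ids', 'token_type_ids', 'attention_mask'] for BERT-style tokenizers

batch = {
    "input_ids": np.ones((2, 8), dtype=np.int32),
    "attention_mask": np.ones((2, 8), dtype=np.int32),
    "labels": np.zeros((2,), dtype=np.int32),  # dropped by the filter
    "example_id": np.arange(2),                # dropped by the filter
}
batch = {key: array for key, array in batch.items() if key in tokenizer.model_input_names}
print(sorted(batch))  # ['attention_mask', 'input_ids']
```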
From bd0875819893df8ed8634d9794e8f0cbcbf1c823 Mon Sep 17 00:00:00 2001
From: matt
Date: Wed, 22 Dec 2021 17:04:29 +0000
Subject: [PATCH 07/17] Only remove unexpected keys when predict_with_generate is True

---
 src/transformers/keras_callbacks.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index 6cbe4f3d0485fb..fb66043b8f6f4a 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -23,8 +23,9 @@ class KerasMetricCallback(Callback):
     Callback to prompt metrics at the end of every epoch.
 
     Args:
-        metric_fn: Metric function provided by the user.
-        eval_dataset: Validation data to be used to evaluate the model at
+        metric_fn (`Callable`):
+            Metric function provided by the user.
+        eval_dataset (`tf.data.Dataset` or `dict` or `tuple`): Validation data to be used to evaluate the model at
             the end of the epoch.
         batch_size: Batch size.
         labels: Labels.
@@ -111,9 +112,10 @@ def on_epoch_end(self, epoch, logs=None):
                 batch, labels = batch
             else:
                 labels = None
-            if isinstance(batch, dict):
-                batch = {key: array for key, array in batch.items() if key in self.model_input_names}
             if self.predict_with_generate:
+                if isinstance(batch, dict):
+                    # generate() gets stressed out by any unexpected keys
+                    batch = {key: array for key, array in batch.items() if key in self.model_input_names}
                 predictions = self.model.generate(batch)
             else:
                 predictions = self.model.predict(batch)
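The motivation for moving the filter: stripping the whole batch up front also removed the label columns that are read back out of `batch` a few lines later, and only `generate()` actually objects to extra keys. A hedged sketch of the generation path, assuming `model` and `tokenizer` are a matching Transformers TF seq2seq model and tokenizer:

```python
inputs = tokenizer(["Some long document to summarize."], return_tensors="tf")
# Keep only genuine model inputs; unexpected keys can confuse generate()
filtered = {key: array for key, array in dict(inputs).items() if key in tokenizer.model_input_names}
summary_ids = model.generate(**filtered)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```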
From 700b4b591ee0287dc769b41765de96e32d25f7c8 Mon Sep 17 00:00:00 2001
From: matt
Date: Wed, 22 Dec 2021 18:32:12 +0000
Subject: [PATCH 08/17] Adding proper docstring

---
 src/transformers/keras_callbacks.py | 27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index fb66043b8f6f4a..4d1ca6647bbe9f 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -24,18 +24,31 @@ class KerasMetricCallback(Callback):
 
     Args:
         metric_fn (`Callable`):
-            Metric function provided by the user.
-        eval_dataset (`tf.data.Dataset` or `dict` or `tuple`): Validation data to be used to evaluate the model at
-            the end of the epoch.
-        batch_size: Batch size.
-        labels: Labels.
+            Metric function provided by the user. It will be called with two arguments - `predictions` and `labels`.
+            These contain the model's outputs and matching labels from the dataset. It should return a dict mapping
+            metric names to numerical values.
+        eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
+            Validation data to be used to generate predictions for the `metric_fn`.
+        tokenizer ([`PreTrainedTokenizerBase`], *optional*):
+            Tokenizer used to validate column names to be passed to the generate() function. Required only if
+            predict_with_generate is True.
+        output_cols: (`List[str]`, *optional*):
+            A list of columns to be retained from the model output as the predictions. Defaults to all.
+        label_cols: (`List[str]`, *optional*):
+            A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not
+            supplied.
+        batch_size (`int`, *optional*):
+            Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
+        predict_with_generate (`bool`, *optional*, defaults to *False*):
+            Whether we should use `model.generate()` to get outputs for the model.
     """
 
@@ -51,9 +64,9 @@ class KerasMetricCallback(Callback):
     def __init__(
         self,
         metric_fn: Callable,
-        tokenizer: PreTrainedTokenizerBase,
         eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
         output_cols: Optional[List[str]] = None,
         label_cols: Optional[List[str]] = None,
         batch_size: Optional[int] = None,
@@ -60,6 +73,8 @@ def __init__(
         super().__init__()
         self.metric_fn = metric_fn
         self.batch_size = batch_size
+        if predict_with_generate and tokenizer is None:
+            raise ValueError("A tokenizer is required when using predict_with_generate!")
         if not isinstance(eval_dataset, tf.data.Dataset):
             if batch_size is None:
                 raise ValueError(

From 9d175b60b041403c97aa444b5777f1e48e3f5520 Mon Sep 17 00:00:00 2001
From: matt
Date: Wed, 22 Dec 2021 18:49:05 +0000
Subject: [PATCH 09/17] Adding example to docstring

---
 src/transformers/keras_callbacks.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index 4d1ca6647bbe9f..d7e85b0dd1d9bf 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -20,7 +20,22 @@
 
 class KerasMetricCallback(Callback):
     """
-    Callback to prompt metrics at the end of every epoch.
+    Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be
+    compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string
+    operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the
+    `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute
+    metrics and return a dict mapping metric names to metric values.
+
+    A simple example of a suitable metric_fn that computes accuracy on a SequenceClassification model:
+
+    ```py
+    def accuracy_fn(predictions, labels):
+        class_predictions = np.argmax(predictions['logits'], axis=-1)
+        correct = np.sum(class_predictions == labels)
+        return {"accuracy": correct / len(labels)}
+    ```
+
+    In practice, of course, functions this simple should usually be implemented as a straightforward Keras metric!
 
     Args:
         metric_fn (`Callable`):
@@ -70,7 +85,10 @@ def __init__(
         self.eval_dataset = eval_dataset
         self.predict_with_generate = predict_with_generate
         self.output_cols = output_cols
-        self.model_input_names = tokenizer.model_input_names
+        if tokenizer is not None:
+            self.model_input_names = tokenizer.model_input_names
+        else:
+            self.model_input_names = None
 
         # This next block attempts to parse out which elements of the dataset should be appended to the labels list
         # that is passed to the metric_fn
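A hypothetical end-to-end wiring for the docstring's `accuracy_fn`; the checkpoint name and the `tf_train_dataset`/`tf_eval_dataset` objects are placeholders, not anything defined in the patch:

```python
import numpy as np
from transformers import TFAutoModelForSequenceClassification

def accuracy_fn(predictions, labels):
    class_predictions = np.argmax(predictions["logits"], axis=-1)
    correct = np.sum(class_predictions == labels)
    return {"accuracy": correct / len(labels)}

model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
model.compile(optimizer="adam")  # loss/metric setup elided for brevity
callback = KerasMetricCallback(metric_fn=accuracy_fn, eval_dataset=tf_eval_dataset)
model.fit(tf_train_dataset, epochs=2, callbacks=[callback])
```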
From 00fbff683e406ca4db6592c667050ac470316f6c Mon Sep 17 00:00:00 2001
From: matt
Date: Wed, 22 Dec 2021 19:10:29 +0000
Subject: [PATCH 10/17] Add a proper ROUGE metric example

---
 src/transformers/keras_callbacks.py | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index 815e2f5d67cade..5a211c365ffbf4 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -26,16 +26,26 @@ class KerasMetricCallback(Callback):
     `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute
     metrics and return a dict mapping metric names to metric values.
 
-    A simple example of a suitable metric_fn that computes accuracy on a SequenceClassification model:
+    An example of a suitable metric_fn that computes ROUGE scores for a summarization model:
 
     ```py
-    def accuracy_fn(predictions, labels):
-        class_predictions = np.argmax(predictions['logits'], axis=-1)
-        correct = np.sum(class_predictions == labels)
-        return {"accuracy": correct / len(labels)}
+    from datasets import load_metric
+
+    rouge_metric = load_metric("rouge")
+
+    def rouge_fn(predictions, labels):
+        # Note that this example skips some post-processing for readability and simplicity,
+        # and may not be directly applicable to any particular use-case
+        decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+        result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)
+        return {key: value.mid.fmeasure * 100 for key, value in result.items()}
     ```
 
-    In practice, of course, functions this simple should usually be implemented as a straightforward Keras metric!
+    The above function will return a dict containing values which will be logged like any other Keras metric:
+
+    ```py
+    {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781}
+    ```
 
     Args:
         metric_fn (`Callable`):
@@ -44,6 +54,8 @@ def rouge_fn(predictions, labels):
             metric names to numerical values.
         eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
             Validation data to be used to generate predictions for the `metric_fn`.
+        metric_fn_kwargs (`dict`, *optional*):
+            Additional keyword arguments to be passed to the metric_fn.
         tokenizer ([`PreTrainedTokenizerBase`], *optional*):
             Tokenizer used to validate column names to be passed to the generate() function. Required only if
             predict_with_generate is True.
@@ -70,6 +76,7 @@ def rouge_fn(predictions, labels):
     def __init__(
         self,
         metric_fn: Callable,
         eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
         tokenizer: Optional[PreTrainedTokenizerBase] = None,
+        metric_fn_kwargs: Optional[dict] = None,
         output_cols: Optional[List[str]] = None,
         label_cols: Optional[List[str]] = None,
         batch_size: Optional[int] = None,
@@ -85,6 +98,7 @@ def __init__(
         self.eval_dataset = eval_dataset
         self.predict_with_generate = predict_with_generate
         self.output_cols = output_cols
+        self.metric_fn_kwargs = metric_fn_kwargs or dict()
         if tokenizer is not None:
             self.model_input_names = tokenizer.model_input_names
         else:
@@ -171,7 +185,7 @@ def on_epoch_end(self, epoch, logs=None):
         prediction_list = self._postprocess_predictions_or_labels(prediction_list)
         label_list = self._postprocess_predictions_or_labels(label_list)
 
-        metric_output = self.metric_fn(prediction_list, label_list)
+        metric_output = self.metric_fn(prediction_list, label_list, **self.metric_fn_kwargs)
         if not isinstance(metric_output, dict):
             raise TypeError(
                 f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
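A matching sketch of driving the documented `rouge_fn` with generation — at this point in the series a tokenizer is still mandatory whenever `predict_with_generate=True` (model, tokenizer and datasets are again placeholders):

```python
callback = KerasMetricCallback(
    metric_fn=rouge_fn,
    eval_dataset=tf_eval_dataset,  # batched, tokenized article/summary pairs
    tokenizer=tokenizer,           # mandatory here when predict_with_generate=True
    predict_with_generate=True,    # collect model.generate() output instead of logits
)
model.fit(tf_train_dataset, epochs=1, callbacks=[callback])
```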
From e5c59b06252fff0a731e74b82500e3ec62ef3062 Mon Sep 17 00:00:00 2001
From: matt
Date: Wed, 22 Dec 2021 19:12:02 +0000
Subject: [PATCH 11/17] Add a proper ROUGE metric example

---
 src/transformers/keras_callbacks.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index 5a211c365ffbf4..b6a1e333ff7674 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -26,15 +26,15 @@ class KerasMetricCallback(Callback):
     `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute
     metrics and return a dict mapping metric names to metric values.
 
-    An example of a suitable metric_fn that computes ROUGE scores for a summarization model:
+    We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below.
+    Note that this example skips some post-processing for readability and simplicity, and should probably
+    not be used as-is!
 
     ```py
     from datasets import load_metric
 
     rouge_metric = load_metric("rouge")
 
     def rouge_fn(predictions, labels):
-        # Note that this example skips some post-processing for readability and simplicity,
-        # and may not be directly applicable to any particular use-case
         decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
         decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
         result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)

From 9016bc932e56c0c23bd976af752298e30ef16ddb Mon Sep 17 00:00:00 2001
From: matt
Date: Wed, 22 Dec 2021 19:21:48 +0000
Subject: [PATCH 12/17] Add version checking

---
 src/transformers/keras_callbacks.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index b6a1e333ff7674..71e334c3534938 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -6,6 +6,7 @@
 
 import numpy as np
 import tensorflow as tf
+from packaging.version import parse
 from tensorflow.keras.callbacks import Callback
 
 from huggingface_hub import Repository
@@ -128,6 +129,8 @@ def __init__(
             logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
         else:
             raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
+        if parse(tf.__version__) < parse("2.7"):
+            logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")

From 2b44bd566a1ab86493455d43c6ff6be6b63f3481 Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 22 Dec 2021 19:22:22 +0000
Subject: [PATCH 13/17] Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 src/transformers/keras_callbacks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index 71e334c3534938..dcb57b1a489328 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -60,7 +60,7 @@ def rouge_fn(predictions, labels):
         tokenizer ([`PreTrainedTokenizerBase`], *optional*):
             Tokenizer used to validate column names to be passed to the generate() function. Required only if
             predict_with_generate is True.
-        output_cols: (`List[str]`, *optional*):
+        output_cols (`List[str]`, *optional*):
             A list of columns to be retained from the model output as the predictions. Defaults to all.
         label_cols: (`List[str]`, *optional*):
             A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not
             supplied.
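One feature added earlier but never exercised in the docstring: `metric_fn_kwargs` forwards fixed extra arguments to the metric function on every epoch. A toy illustration of the calling convention (the dataset name is a placeholder):

```python
def sample_count_fn(predictions, labels, scale=1.0):
    # Extra keyword arguments arrive via metric_fn_kwargs on every epoch
    return {"n_eval_samples": scale * len(labels)}

callback = KerasMetricCallback(
    metric_fn=sample_count_fn,
    eval_dataset=tf_eval_dataset,
    metric_fn_kwargs={"scale": 1.0},
)
```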
From fb89390ffe3f8f2746e034227233348a72f2c0bc Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 22 Dec 2021 19:22:27 +0000
Subject: [PATCH 14/17] Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 src/transformers/keras_callbacks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index dcb57b1a489328..a9aaa9fa61a0e0 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -62,7 +62,7 @@ def rouge_fn(predictions, labels):
             predict_with_generate is True.
         output_cols (`List[str]`, *optional*):
             A list of columns to be retained from the model output as the predictions. Defaults to all.
-        label_cols: (`List[str]`, *optional*):
+        label_cols (`List[str]`, *optional*):
             A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not
             supplied.
         batch_size (`int`, *optional*):

From 340d3bf5dba90bdfa0a36433c6f22358752c6483 Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 22 Dec 2021 19:22:30 +0000
Subject: [PATCH 15/17] Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 src/transformers/keras_callbacks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index a9aaa9fa61a0e0..cf2bf75e2fa0a8 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -67,7 +67,7 @@ def rouge_fn(predictions, labels):
             supplied.
         batch_size (`int`, *optional*):
             Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
-        predict_with_generate (`bool`, *optional*, defaults to *False*):
+        predict_with_generate (`bool`, *optional*, defaults to `False`):
             Whether we should use `model.generate()` to get outputs for the model.
     """

From d9df77bb6ae71dd523eddf2920af42cc23e95ab6 Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 22 Dec 2021 19:22:34 +0000
Subject: [PATCH 16/17] Update src/transformers/keras_callbacks.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 src/transformers/keras_callbacks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index cf2bf75e2fa0a8..1fdf63366e5702 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -44,7 +44,7 @@ def rouge_fn(predictions, labels):
 
     The above function will return a dict containing values which will be logged like any other Keras metric:
 
-    ```py
+    ```
     {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781}
     ```
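Because the callback writes its results into the Keras `logs` dict, whatever the `metric_fn` returns is visible to the rest of the callback machinery, not just printed. A sketch with placeholder names:

```python
import tensorflow as tf

# KerasMetricCallback must precede EarlyStopping in the list so that "rouge2"
# is already in `logs` when EarlyStopping reads them at the end of each epoch
early_stopping = tf.keras.callbacks.EarlyStopping(monitor="rouge2", mode="max", patience=2)
history = model.fit(tf_train_dataset, epochs=10, callbacks=[metric_callback, early_stopping])
print(history.history.keys())  # includes 'rouge1', 'rouge2', ... alongside 'loss'
```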
From 9947001e7cf302cbb3406bae87366d80121dc026 Mon Sep 17 00:00:00 2001
From: matt
Date: Wed, 22 Dec 2021 19:26:41 +0000
Subject: [PATCH 17/17] Remove requirement for tokenizer with predict_with_generate

---
 src/transformers/keras_callbacks.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/transformers/keras_callbacks.py b/src/transformers/keras_callbacks.py
index 1fdf63366e5702..de0b3b8bfd5fb3 100644
--- a/src/transformers/keras_callbacks.py
+++ b/src/transformers/keras_callbacks.py
@@ -58,8 +58,7 @@ def rouge_fn(predictions, labels):
         metric_fn_kwargs (`dict`, *optional*):
             Additional keyword arguments to be passed to the metric_fn.
         tokenizer ([`PreTrainedTokenizerBase`], *optional*):
-            Tokenizer used to validate column names to be passed to the generate() function. Required only if
-            predict_with_generate is True.
+            Tokenizer used to validate column names to be passed to the generate() function.
         output_cols (`List[str]`, *optional*):
             A list of columns to be retained from the model output as the predictions. Defaults to all.
         label_cols (`List[str]`, *optional*):
@@ -86,8 +85,6 @@ def __init__(
         super().__init__()
         self.metric_fn = metric_fn
         self.batch_size = batch_size
-        if predict_with_generate and tokenizer is None:
-            raise ValueError("A tokenizer is required when using predict_with_generate!")
         if not isinstance(eval_dataset, tf.data.Dataset):
             if batch_size is None:
                 raise ValueError(
@@ -103,7 +100,7 @@ def __init__(
         if tokenizer is not None:
             self.model_input_names = tokenizer.model_input_names
         else:
-            self.model_input_names = None
+            self.model_input_names = ["input_ids"]
 
         # This next block attempts to parse out which elements of the dataset should be appended to the labels list
         # that is passed to the metric_fn