From 1457ac6a17394e20d182db76867308b4695545d3 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 11 Aug 2022 15:30:08 +0200 Subject: [PATCH 1/2] Support audio classification architectures for labels generation, as well as provides a flag to print warnings or not --- src/transformers/utils/fx.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 2198928eadb325..149fc231ff4d88 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -19,6 +19,7 @@ import inspect import math import operator +import os import random import warnings from typing import Any, Callable, Dict, List, Optional, Type, Union @@ -53,6 +54,7 @@ logger = logging.get_logger(__name__) +_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", False) in ["true", "True", "1"] def _generate_supported_model_class_names( @@ -678,7 +680,12 @@ def _generate_dummy_input( if input_name in ["labels", "start_positions", "end_positions"]: batch_size = shape[0] - if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES): + if model_class_name in [ + *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES), + *get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES), + *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES), + *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES), + ]: inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) elif model_class_name in [ *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES), @@ -710,11 +717,6 @@ def _generate_dummy_input( ) inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device) - elif model_class_name in [ - *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES), - *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES), - ]: - inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) elif model_class_name in [ *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES), *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES), @@ -725,7 +727,9 @@ def _generate_dummy_input( ]: inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device) else: - raise NotImplementedError(f"{model_class_name} not supported yet.") + raise NotImplementedError( + f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet." + ) elif "pixel_values" in input_name: batch_size = shape[0] image_size = getattr(model.config, "image_size", None) @@ -846,7 +850,8 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr raise ValueError("Don't support composite output yet") rv.install_metadata(meta_out) except Exception as e: - warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}") + if _IS_IN_DEBUG_MODE: + warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}") return rv From 2a07ed02d7c230f5b590b99dea30a7a357f212b7 Mon Sep 17 00:00:00 2001 From: Michael Benayoun Date: Thu, 11 Aug 2022 16:21:46 +0200 Subject: [PATCH 2/2] Use ENV_VARS_TRUE_VALUES --- src/transformers/utils/fx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index 149fc231ff4d88..990f278b0d5066 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -49,12 +49,12 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_MAPPING_NAMES, ) -from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available +from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available from ..utils.versions import importlib_metadata logger = logging.get_logger(__name__) -_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", False) in ["true", "True", "1"] +_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES def _generate_supported_model_class_names(