From 46f78c11496c5814758d48794322d3b4288a5cd4 Mon Sep 17 00:00:00 2001
From: Konstantin Gulin <66528950+KSGulin@users.noreply.github.com>
Date: Tue, 18 Oct 2022 16:22:42 +0200
Subject: [PATCH] Upgrade to transformers release V4.23.1 (#62)

* Update trainer and model flows to accommodate sparseml

Disable FP16 on QAT start (#12)

* Override LRScheduler when using LRModifiers

* Disable FP16 on QAT start

* keep wrapped scaler object for training after disabling

Using QATMatMul in DistilBERT model class (#41)

Removed double quantization of output of context layer. (#45)

Fix DataParallel validation forward signatures (#47)

* Fix: DataParallel validation forward signatures

* Update: generalize forward_fn selection

Best model after epoch (#46)

fix scaler check for non fp16 mode in trainer (#38)

Mobilebert QAT (#55)

* Remove duplicate quantization of vocabulary.

enable a QATWrapper for non-parameterized matmuls in BERT self attention (#9)

* Utils and auxiliary changes

update Zoo stub loading for SparseZoo 1.1 refactor (#54)

add flag to signal NM integration is active (#32)

Add recipe_name to file names

* Fix errors introduced in manual cherry-pick upgrade

Co-authored-by: Benjamin Fineran
---
 src/transformers/hf_argparser.py              | 47 +++++++++++++++++--
 .../models/distilbert/modeling_distilbert.py  | 42 +++++++++++++++--
 .../models/mobilebert/modeling_mobilebert.py  | 19 +++++++-
 src/transformers/trainer.py                   | 38 +++++++++++----
 src/transformers/trainer_seq2seq.py           |  2 +-
 src/transformers/utils/import_utils.py        |  3 ++
 6 files changed, 133 insertions(+), 18 deletions(-)

diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py
index 06a10ff5a0554b..fdae8a690f25d0 100644
--- a/src/transformers/hf_argparser.py
+++ b/src/transformers/hf_argparser.py
@@ -21,9 +21,16 @@
 from inspect import isclass
 from pathlib import Path
 from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints
-
+import os
 import yaml
 
+from sparsezoo import Model
+
+from .utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
 
 DataClass = NewType("DataClass", Any)
 DataClassType = NewType("DataClassType", Any)
@@ -229,12 +236,17 @@ def parse_args_into_dataclasses(
             # additional namespace.
             outputs.append(namespace)
         if return_remaining_strings:
-            return (*outputs, remaining_args)
+            return tuple(
+                *[_download_dataclass_zoo_stub_files(output) for output in outputs],
+                remaining_args,
+            )
         else:
             if remaining_args:
                 raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
 
-            return (*outputs,)
+            return tuple(
+                [_download_dataclass_zoo_stub_files(output) for output in outputs]
+            )
 
     def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
         """
@@ -262,7 +274,9 @@ def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tu
             outputs.append(obj)
         if not allow_extra_keys and unused_keys:
             raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
-        return tuple(outputs)
+        return tuple(
+            [_download_dataclass_zoo_stub_files(output) for output in outputs]
+        )
 
     def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
         """
@@ -305,3 +319,28 @@ def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tup
         """
         outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
         return tuple(outputs)
+
+def _download_dataclass_zoo_stub_files(data_class: DataClass):
+    for name, val in data_class.__dict__.items():
+        if not isinstance(val, str) or "recipe" in name or not val.startswith("zoo:"):
+            continue
+
+        logger.info(f"Downloading framework files for SparseZoo stub: {val}")
+
+        zoo_model = Model(val)
+        framework_file_paths = [file.path for file in zoo_model.training.default.files]
+        assert framework_file_paths, f"Unable to download any framework files for SparseZoo stub {val}"
+        framework_file_names = [os.path.basename(path) for path in framework_file_paths]
+        if "pytorch_model.bin" not in framework_file_names or ("config.json" not in framework_file_names):
+            raise RuntimeError(
+                "Unable to find 'pytorch_model.bin' and 'config.json' in framework "
+                f"files downloaded from {val}. Found {framework_file_names}. Check "
Check " + "if the given stub is for a transformers repo model" + ) + framework_dir_path = Path(framework_file_paths[0]).parent.absolute() + + logger.info(f"Overwriting argument {name} to downloaded {framework_dir_path}") + + data_class.__dict__[name] = str(framework_dir_path) + + return data_class diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index a2713128e901a3..dd31960cefdc97 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -89,6 +89,38 @@ def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor): out.detach_() +class QATAttentionScores(nn.Module): + def __init__(self): + super().__init__() + + # behaves like normal torch.matmul unless a SparseML QuantizationModifier + # is initialized + self.wrap_qat = True + self.qat_wrapper_kwargs = { + "num_inputs": 2, + "input_qconfigs": ["asymmetric", "symmetric"], + } + + def forward(self, a: torch.Tensor, b: torch.Tensor): + return torch.matmul(a, b) + +class QATContextLayer(nn.Module): + def __init__(self): + super().__init__() + + # behaves like normal torch.matmul unless a SparseML QuantizationModifier + # is initialized + self.wrap_qat = True + self.qat_wrapper_kwargs = { + "num_inputs": 2, + "num_outputs": 0, + "input_qconfigs": ["asymmetric", "symmetric"], + } + + def forward(self, a: torch.Tensor, b: torch.Tensor): + return torch.matmul(a, b) + + class Embeddings(nn.Module): def __init__(self, config: PretrainedConfig): super().__init__() @@ -150,6 +182,11 @@ def __init__(self, config: PretrainedConfig): self.pruned_heads: Set[int] = set() + # non-parameterized matmuls will behave as normal torch.matmul ops unless + # Quantization-Aware-Training is invoked + self.attention_scores_matmul = QATAttentionScores() + self.context_layer_matmul = QATContextLayer() + def prune_heads(self, heads: List[int]): attention_head_size = self.dim // self.n_heads if len(heads) == 0: @@ -207,7 +244,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor: v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head) q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) - scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) + scores = self.attention_scores_matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length) scores = scores.masked_fill( mask, torch.tensor(torch.finfo(scores.dtype).min) @@ -220,7 +257,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor: if head_mask is not None: weights = weights * head_mask - context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head) + context = self.context_layer_matmul(weights, v) # (bs, n_heads, q_length, dim_per_head) context = unshape(context) # (bs, q_length, dim) context = self.out_lin(context) # (bs, q_length, dim) @@ -645,7 +682,6 @@ def forward( loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - dlbrt_output = self.distilbert( input_ids=input_ids, attention_mask=attention_mask, diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 6bc306a6e05eb3..af7ceeee02a903 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -170,6 +170,23 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm} +class QATEmbeddingTransformation(nn.Module): + def __init__(self, embedded_input_size, hidden_size): + super().__init__() + + # Behaves like normal Linear module unless a SparseML QuantizationModifier + # is initialized. + # When initialized, does not quantize inputs. + # Only weights are quantized (inputs come quantized from embeddings) + self.linear = nn.Linear(embedded_input_size, hidden_size) + self.wrap_qat = True + self.qat_wrapper_kwargs = { + "num_inputs": 0, + "num_outputs": 1, + } + + def forward(self, x: torch.Tensor): + return self.linear(x) class MobileBertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" @@ -186,7 +203,7 @@ def __init__(self, config): embed_dim_multiplier = 3 if self.trigram_input else 1 embedded_input_size = self.embedding_size * embed_dim_multiplier - self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size) + self.embedding_transformation = QATEmbeddingTransformation(embedded_input_size, config.hidden_size) self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 214e7a9789d2c7..abb967a2e216dd 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1687,6 +1687,10 @@ def _inner_training_loop( _ = list(train_dataloader.sampler) for epoch in range(epochs_trained, num_train_epochs): + if self.use_cuda_amp and hasattr(self, "qat_active") and callable(self.qat_active) and self.qat_active(epoch): + logger.info("entering QAT phase, disabling FP16 training") + self.scaler._enabled = False + if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): train_dataloader.sampler.set_epoch(epoch) elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): @@ -2167,7 +2171,12 @@ def _save_checkpoint(self, model, trial, metrics=None): torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) # Determine the new best metric / best model checkpoint - if metrics is not None and self.args.metric_for_best_model is not None: + if ( + metrics is not None + and self.args.metric_for_best_model is not None + and self.args.best_model_after_epoch is not None + and self.state.epoch > self.args.best_model_after_epoch + ): metric_to_check = self.args.metric_for_best_model if not metric_to_check.startswith("eval_"): metric_to_check = f"eval_{metric_to_check}" @@ -2421,14 +2430,14 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s return inputs - def compute_loss_context_manager(self): + def compute_loss_context_manager(self, enabled): """ A helper wrapper to group together context managers. 
""" return ContextManagers( [ self.torchdynamo_smart_context_manager(), - self.autocast_smart_context_manager(), + self.autocast_smart_context_manager(enabled=enabled), ] ) @@ -2438,7 +2447,7 @@ def torchdynamo_smart_context_manager(self): """ return self.ctx_manager_torchdynamo - def autocast_smart_context_manager(self): + def autocast_smart_context_manager(self, enabled): """ A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired arguments, depending on the situation. @@ -2448,10 +2457,10 @@ def autocast_smart_context_manager(self): ctx_manager = ( torch.cpu.amp.autocast(dtype=self.amp_dtype) if self.use_cpu_amp - else torch.cuda.amp.autocast(dtype=self.amp_dtype) + else torch.cuda.amp.autocast(dtype=self.amp_dtype, enabled=enabled) ) else: - ctx_manager = torch.cuda.amp.autocast() + ctx_manager = torch.cuda.amp.autocast(enabled=enabled) else: ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress() @@ -2482,7 +2491,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) return loss_mb.reduce_mean().detach().to(self.args.device) - with self.compute_loss_context_manager(): + with self.compute_loss_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()): loss = self.compute_loss(model, inputs) if self.args.n_gpu > 1: @@ -2939,7 +2948,14 @@ def evaluation_loop( observed_num_examples = 0 # Main evaluation loop + module_forward_fn = model.module.forward if isinstance(model, nn.DataParallel) else model.forward for step, inputs in enumerate(dataloader): + inputs = { + k: inputs[k] + for k in inputs + if k in list(inspect.signature(module_forward_fn).parameters.keys()) + } + # Update the observed num examples observed_batch_size = find_batch_size(inputs) if observed_batch_size is not None: @@ -3191,7 +3207,9 @@ def prediction_step( logits = smp_nested_concat(logits_mb) else: if has_labels: - with self.compute_loss_context_manager(): + with self.compute_loss_context_manager( + enabled=hasattr(self, "scaler") and self.scaler.is_enabled() + ): loss, outputs = self.compute_loss(model, inputs, return_outputs=True) loss = loss.mean().detach() @@ -3201,7 +3219,9 @@ def prediction_step( logits = outputs[1:] else: loss = None - with self.compute_loss_context_manager(): + with self.compute_loss_context_manager( + enabled=hasattr(self, "scaler") and self.scaler.is_enabled() + ): outputs = model(**inputs) if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) diff --git a/src/transformers/trainer_seq2seq.py b/src/transformers/trainer_seq2seq.py index 7689998c051b8f..39bbc10d4faad7 100644 --- a/src/transformers/trainer_seq2seq.py +++ b/src/transformers/trainer_seq2seq.py @@ -208,7 +208,7 @@ def prediction_step( generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1) with torch.no_grad(): - with self.compute_loss_context_manager(): + with self.compute_loss_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()): outputs = model(**inputs) if has_labels: if self.label_smoother is not None: diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 140cfc78f4edd1..c73d58fb5db767 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1016,6 +1016,9 @@ class _LazyModule(ModuleType): Module class that 
     """
 
+    # flag to signal NM integration is active
+    NM_INTEGRATED = True
+
     # Very heavily inspired by optuna.integration._IntegrationModule
     # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
     def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
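
Note on the HfArgumentParser change above: any string-valued dataclass field that starts with "zoo:" (and whose field name does not contain "recipe") is resolved by _download_dataclass_zoo_stub_files into a local directory holding pytorch_model.bin and config.json before the value reaches the rest of the pipeline. A minimal usage sketch, assuming this fork and the sparsezoo package are installed; the ModelArguments dataclass and the stub string are hypothetical placeholders, and a real run needs network access to a valid SparseZoo stub:

    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser


    @dataclass
    class ModelArguments:
        # No "recipe" in the field name, so a "zoo:" value is resolved to a local path.
        model_name_or_path: Optional[str] = field(default=None)
        # Recipe fields are intentionally skipped by the helper ("recipe" is in the name).
        recipe: Optional[str] = field(default=None)


    parser = HfArgumentParser(ModelArguments)
    (model_args,) = parser.parse_args_into_dataclasses(
        args=["--model_name_or_path", "zoo:example/hypothetical/stub"]
    )
    # After parsing, model_name_or_path points at the downloaded framework directory.
    print(model_args.model_name_or_path)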
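
Note on the QAT wrapper modules (QATAttentionScores, QATContextLayer, QATEmbeddingTransformation): they only advertise themselves for wrapping through the wrap_qat flag and qat_wrapper_kwargs; until a SparseML QuantizationModifier picks them up, they compute a plain torch.matmul (or nn.Linear). A quick equivalence check under that assumption, runnable only with this fork installed:

    import torch

    from transformers.models.distilbert.modeling_distilbert import QATAttentionScores

    # Shapes follow the comments in MultiHeadSelfAttention: (bs, n_heads, length, dim_per_head)
    q = torch.randn(1, 12, 8, 64)
    k = torch.randn(1, 12, 8, 64)

    attention_scores_matmul = QATAttentionScores()
    # With no QuantizationModifier applied, the wrapper is exactly torch.matmul.
    assert torch.equal(
        attention_scores_matmul(q, k.transpose(2, 3)),
        torch.matmul(q, k.transpose(2, 3)),
    )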
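
Note on the evaluation_loop change: each batch is filtered down to the keyword arguments that the model's forward actually accepts, using model.module.forward when the model is wrapped in nn.DataParallel so the real signature is inspected. The same idea as a standalone sketch (not the Trainer code path itself):

    import inspect

    import torch.nn as nn


    def filter_inputs_for_forward(model: nn.Module, inputs: dict) -> dict:
        # Unwrap DataParallel so the underlying model's forward signature is inspected.
        forward_fn = model.module.forward if isinstance(model, nn.DataParallel) else model.forward
        accepted = set(inspect.signature(forward_fn).parameters.keys())
        return {k: v for k, v in inputs.items() if k in accepted}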