Skip to content

Commit

Permalink
Upgrade to transformers release V4.23.1 (huggingface#62)
Browse files Browse the repository at this point in the history
* Update trainer and model flows to accommodate sparseml

Disable FP16 on QAT start (huggingface#12)

* Override LRScheduler when using LRModifiers

* Disable FP16 on QAT start

* keep wrapped scaler object for training after disabling

Using QATMatMul in DistilBERT model class (huggingface#41)

Removed double quantization of output of context layer. (huggingface#45)

Fix DataParallel validation forward signatures (huggingface#47)

* Fix: DataParallel validation forward signatures

* Update: generalize forward_fn selection

Best model after epoch (huggingface#46)

fix sclaer check for non fp16 mode in trainer (huggingface#38)

Mobilebert QAT (huggingface#55)

* Remove duplicate quantization of vocabulary.

enable a QATWrapper for non-parameterized matmuls in BERT self attention (huggingface#9)

* Utils and auxillary changes

update Zoo stub loading for SparseZoo 1.1 refactor (huggingface#54)

add flag to signal NM integration is active (huggingface#32)

Add recipe_name to file names

* Fix errors introduced in manual cherry-pick upgrade

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
  • Loading branch information
KSGulin and bfineran authored Oct 18, 2022
1 parent 53c407d commit 46f78c1
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 18 deletions.
47 changes: 43 additions & 4 deletions src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,16 @@
from inspect import isclass
from pathlib import Path
from typing import Any, Dict, Iterable, NewType, Optional, Tuple, Union, get_type_hints

import os
import yaml

from sparsezoo import Model

from .utils.logging import get_logger


logger = get_logger(__name__)


DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
Expand Down Expand Up @@ -229,12 +236,17 @@ def parse_args_into_dataclasses(
# additional namespace.
outputs.append(namespace)
if return_remaining_strings:
return (*outputs, remaining_args)
return tuple(
*[_download_dataclass_zoo_stub_files(output) for output in outputs],
remaining_args,
)
else:
if remaining_args:
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")

return (*outputs,)
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)

def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
Expand Down Expand Up @@ -262,7 +274,9 @@ def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tu
outputs.append(obj)
if not allow_extra_keys and unused_keys:
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
return tuple(outputs)
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)

def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
Expand Down Expand Up @@ -305,3 +319,28 @@ def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tup
"""
outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
return tuple(outputs)

def _download_dataclass_zoo_stub_files(data_class: DataClass):
for name, val in data_class.__dict__.items():
if not isinstance(val, str) or "recipe" in name or not val.startswith("zoo:"):
continue

logger.info(f"Downloading framework files for SparseZoo stub: {val}")

zoo_model = Model(val)
framework_file_paths = [file.path for file in zoo_model.training.default.files]
assert framework_file_paths, "Unable to download any framework files for SparseZoo stub {val}"
framework_file_names = [os.path.basename(path) for path in framework_file_paths]
if "pytorch_model.bin" not in framework_file_names or ("config.json" not in framework_file_names):
raise RuntimeError(
"Unable to find 'pytorch_model.bin' and 'config.json' in framework "
f"files downloaded from {val}. Found {framework_file_names}. Check "
"if the given stub is for a transformers repo model"
)
framework_dir_path = Path(framework_file_paths[0]).parent.absolute()

logger.info(f"Overwriting argument {name} to downloaded {framework_dir_path}")

data_class.__dict__[name] = str(framework_dir_path)

return data_class
42 changes: 39 additions & 3 deletions src/transformers/models/distilbert/modeling_distilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,38 @@ def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
out.detach_()


class QATAttentionScores(nn.Module):
def __init__(self):
super().__init__()

# behaves like normal torch.matmul unless a SparseML QuantizationModifier
# is initialized
self.wrap_qat = True
self.qat_wrapper_kwargs = {
"num_inputs": 2,
"input_qconfigs": ["asymmetric", "symmetric"],
}

def forward(self, a: torch.Tensor, b: torch.Tensor):
return torch.matmul(a, b)

class QATContextLayer(nn.Module):
def __init__(self):
super().__init__()

# behaves like normal torch.matmul unless a SparseML QuantizationModifier
# is initialized
self.wrap_qat = True
self.qat_wrapper_kwargs = {
"num_inputs": 2,
"num_outputs": 0,
"input_qconfigs": ["asymmetric", "symmetric"],
}

def forward(self, a: torch.Tensor, b: torch.Tensor):
return torch.matmul(a, b)


class Embeddings(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
Expand Down Expand Up @@ -150,6 +182,11 @@ def __init__(self, config: PretrainedConfig):

self.pruned_heads: Set[int] = set()

# non-parameterized matmuls will behave as normal torch.matmul ops unless
# Quantization-Aware-Training is invoked
self.attention_scores_matmul = QATAttentionScores()
self.context_layer_matmul = QATContextLayer()

def prune_heads(self, heads: List[int]):
attention_head_size = self.dim // self.n_heads
if len(heads) == 0:
Expand Down Expand Up @@ -207,7 +244,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)

q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
scores = self.attention_scores_matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
scores = scores.masked_fill(
mask, torch.tensor(torch.finfo(scores.dtype).min)
Expand All @@ -220,7 +257,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
if head_mask is not None:
weights = weights * head_mask

context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
context = self.context_layer_matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
context = unshape(context) # (bs, q_length, dim)
context = self.out_lin(context) # (bs, q_length, dim)

Expand Down Expand Up @@ -645,7 +682,6 @@ def forward(
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

dlbrt_output = self.distilbert(
input_ids=input_ids,
attention_mask=attention_mask,
Expand Down
19 changes: 18 additions & 1 deletion src/transformers/models/mobilebert/modeling_mobilebert.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,23 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:

NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm}

class QATEmbeddingTransformation(nn.Module):
def __init__(self, embedded_input_size, hidden_size):
super().__init__()

# Behaves like normal Linear module unless a SparseML QuantizationModifier
# is initialized.
# When initialized, does not quantize inputs.
# Only weights are quantized (inputs come quantized from embeddings)
self.linear = nn.Linear(embedded_input_size, hidden_size)
self.wrap_qat = True
self.qat_wrapper_kwargs = {
"num_inputs": 0,
"num_outputs": 1,
}

def forward(self, x: torch.Tensor):
return self.linear(x)

class MobileBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
Expand All @@ -186,7 +203,7 @@ def __init__(self, config):

embed_dim_multiplier = 3 if self.trigram_input else 1
embedded_input_size = self.embedding_size * embed_dim_multiplier
self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size)
self.embedding_transformation = QATEmbeddingTransformation(embedded_input_size, config.hidden_size)

self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
Expand Down
38 changes: 29 additions & 9 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,10 @@ def _inner_training_loop(
_ = list(train_dataloader.sampler)

for epoch in range(epochs_trained, num_train_epochs):
if self.use_cuda_amp and hasattr(self, "qat_active") and callable(self.qat_active) and self.qat_active(epoch):
logger.info("entering QAT phase, disabling FP16 training")
self.scaler._enabled = False

if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
Expand Down Expand Up @@ -2167,7 +2171,12 @@ def _save_checkpoint(self, model, trial, metrics=None):
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

# Determine the new best metric / best model checkpoint
if metrics is not None and self.args.metric_for_best_model is not None:
if (
metrics is not None
and self.args.metric_for_best_model is not None
and self.args.best_model_after_epoch is not None
and self.state.epoch > self.args.best_model_after_epoch
):
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
Expand Down Expand Up @@ -2421,14 +2430,14 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s

return inputs

def compute_loss_context_manager(self):
def compute_loss_context_manager(self, enabled):
"""
A helper wrapper to group together context managers.
"""
return ContextManagers(
[
self.torchdynamo_smart_context_manager(),
self.autocast_smart_context_manager(),
self.autocast_smart_context_manager(enabled=enabled),
]
)

Expand All @@ -2438,7 +2447,7 @@ def torchdynamo_smart_context_manager(self):
"""
return self.ctx_manager_torchdynamo

def autocast_smart_context_manager(self):
def autocast_smart_context_manager(self, enabled):
"""
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
arguments, depending on the situation.
Expand All @@ -2448,10 +2457,10 @@ def autocast_smart_context_manager(self):
ctx_manager = (
torch.cpu.amp.autocast(dtype=self.amp_dtype)
if self.use_cpu_amp
else torch.cuda.amp.autocast(dtype=self.amp_dtype)
else torch.cuda.amp.autocast(dtype=self.amp_dtype, enabled=enabled)
)
else:
ctx_manager = torch.cuda.amp.autocast()
ctx_manager = torch.cuda.amp.autocast(enabled=enabled)
else:
ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()

Expand Down Expand Up @@ -2482,7 +2491,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)

with self.compute_loss_context_manager():
with self.compute_loss_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()):
loss = self.compute_loss(model, inputs)

if self.args.n_gpu > 1:
Expand Down Expand Up @@ -2939,7 +2948,14 @@ def evaluation_loop(

observed_num_examples = 0
# Main evaluation loop
module_forward_fn = model.module.forward if isinstance(model, nn.DataParallel) else model.forward
for step, inputs in enumerate(dataloader):
inputs = {
k: inputs[k]
for k in inputs
if k in list(inspect.signature(module_forward_fn).parameters.keys())
}

# Update the observed num examples
observed_batch_size = find_batch_size(inputs)
if observed_batch_size is not None:
Expand Down Expand Up @@ -3191,7 +3207,9 @@ def prediction_step(
logits = smp_nested_concat(logits_mb)
else:
if has_labels:
with self.compute_loss_context_manager():
with self.compute_loss_context_manager(
enabled=hasattr(self, "scaler") and self.scaler.is_enabled()
):
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
loss = loss.mean().detach()

Expand All @@ -3201,7 +3219,9 @@ def prediction_step(
logits = outputs[1:]
else:
loss = None
with self.compute_loss_context_manager():
with self.compute_loss_context_manager(
enabled=hasattr(self, "scaler") and self.scaler.is_enabled()
):
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/trainer_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def prediction_step(
generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)

with torch.no_grad():
with self.compute_loss_context_manager():
with self.compute_loss_context_manager(enabled=hasattr(self, "scaler") and self.scaler.is_enabled()):
outputs = model(**inputs)
if has_labels:
if self.label_smoother is not None:
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,9 @@ class _LazyModule(ModuleType):
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""

# flag to signal NM integration is active
NM_INTEGRATED = True

# Very heavily inspired by optuna.integration._IntegrationModule
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
Expand Down

0 comments on commit 46f78c1

Please sign in to comment.