Update main with v4.31 #90

Closed
wants to merge 6 commits
45 changes: 45 additions & 0 deletions .github/workflows/build-nightly.yml
@@ -0,0 +1,45 @@
name: build-nightly
run-name: ${{ github.workflow }} builds the nightly wheel file for PyPI
on:
push:
branches:
- 'main'
schedule:
- cron: '0 0 * * *'
workflow_dispatch:
jobs:
build-nightly:
runs-on: ubuntu-22.04
permissions:
id-token: write
contents: read
steps:
- uses: aws-actions/configure-aws-credentials@v2
with:
role-to-assume: ${{ secrets.AWS_WEBIDENTITY_FOR_GITHUB_ACTIONS }}
aws-region: us-east-1
- uses: actions/checkout@v3
- run: |
pwd
sudo apt-get install python3-pip
pip3 --version
sudo pip3 install virtualenv
virtualenv venv
source venv/bin/activate
pip install -e .
make -B build
deactivate
ls dist/
aws s3 cp dist/*nightly*.whl s3://nm-github-actions/${{ github.event.repository.name }}/
todaytime=`date +%Y%m%d`
date '+%Y%m%d-%k:%M:%S' | tee log_${GITHUB_REF_NAME}_nightly_${todaytime}_${GITHUB_SHA:0:7}
aws s3 cp log_${GITHUB_REF_NAME}_nightly_${todaytime}_${GITHUB_SHA:0:7} s3://nm-github-actions/${{ github.event.repository.name }}/
oldDate=`date --date='-2 month' +%Y%m%d`
oldWhl=`(aws s3 ls s3://nm-github-actions/${{ github.event.repository.name }}/ | grep nightly | grep "${oldDate}") || echo "notfound"`
if [[ "${oldWhl}" != 'notfound' ]]; then
for oldwhl in $(echo "${oldWhl}" | awk '{print $4}')
do
echo "Remove old build ${oldwhl}"
aws s3 rm s3://nm-github-actions/${{ github.event.repository.name }}/${oldwhl}
done
fi
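Note on the cleanup step above: the nightly job does not prune everything older than two months; it only deletes wheels whose name contains the date exactly two months before the current run. A rough Python equivalent of that filter, offered as a sketch (the wheel name below is a placeholder, and the month arithmetic only approximates GNU date --date='-2 month'):

from datetime import date

def nightly_wheels_to_delete(wheel_names, today=None):
    # Mirror `date --date='-2 month' +%Y%m%d`: compute the single date string
    # two calendar months back and match nightly wheels containing it.
    today = today or date.today()
    month, year = today.month - 2, today.year
    if month <= 0:
        month, year = month + 12, year - 1
    day = min(today.day, 28)  # clamp for short months (approximation)
    old_date = f"{year:04d}{month:02d}{day:02d}"
    return [n for n in wheel_names if "nightly" in n and old_date in n]

# Example with a hypothetical wheel built exactly two months earlier:
print(nightly_wheels_to_delete(
    ["nm_transformers_nightly-1.6.0.20240101-py3-none-any.whl"],
    today=date(2024, 3, 1),
))  # -> that wheel is selected for removal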
35 changes: 35 additions & 0 deletions .github/workflows/build-release.yml
@@ -0,0 +1,35 @@
name: build-release
run-name: ${{ github.workflow }} builds the release wheel file for PyPI
on:
push:
branches:
- 'release/[0-9]+.[0-9]+'
workflow_dispatch:
jobs:
build-release:
runs-on: ubuntu-22.04
permissions:
id-token: write
contents: read
steps:
- uses: aws-actions/configure-aws-credentials@v2
with:
role-to-assume: ${{ secrets.AWS_WEBIDENTITY_FOR_GITHUB_ACTIONS }}
aws-region: us-east-1
- uses: actions/checkout@v3
- run: |
pwd
sudo apt-get install python3-pip
pip3 --version
sudo pip3 install virtualenv
virtualenv venv
source venv/bin/activate
pip install -e .
sed -i 's/is_release = False/is_release = True/g' src/${{ github.event.repository.name }}/version.py
make -B build
deactivate
ls dist/
aws s3 cp dist/*.whl s3://nm-github-actions/${{ github.event.repository.name }}/
todaytime=`date +%Y%m%d`
date '+%Y%m%d-%k:%M:%S' | tee log_${GITHUB_REF_NAME/\//-}_release_${todaytime}_${GITHUB_SHA:0:7}
aws s3 cp log_${GITHUB_REF_NAME/\//-}_release_${todaytime}_${GITHUB_SHA:0:7} s3://nm-github-actions/${{ github.event.repository.name }}/
6 changes: 5 additions & 1 deletion Makefile
@@ -1,4 +1,4 @@
.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples
.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples build

# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src
@@ -118,3 +118,7 @@ build-release:
python setup.py bdist_wheel
python setup.py sdist
python utils/check_build.py

# neuralmagic: creates wheel file
build:
python3 setup.py sdist bdist_wheel
11 changes: 8 additions & 3 deletions setup.py
@@ -425,17 +425,22 @@ def run(self):
deps["tqdm"], # progress bars in model download and training scripts
]

# default variable to be overwritten by the version.py file
version = "unknown"
# load and overwrite version and release info from version.py
exec(open(os.path.join("src", "transformers", "version.py")).read())

setup(
name="transformers",
version="4.31.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
name="nm-transformers" if is_release else "nm-transformers-nightly",
version=version, # major.minor.patch to match NM repos, fourth entry is either transformers base version or nightly date
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords="NLP vision speech deep learning transformer pytorch tensorflow jax BERT GPT-2 Wav2Vec2 ViT",
license="Apache 2.0 License",
url="https://github.com/huggingface/transformers",
url="https://github.com/neuralmagic/transformers",
package_dir={"": "src"},
packages=find_packages("src"),
include_package_data=True,
2 changes: 1 addition & 1 deletion src/transformers/__init__.py
@@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).

__version__ = "4.31.0"
from .version import *

from typing import TYPE_CHECKING

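Both the exec(...) call added to setup.py and the new "from .version import *" above rely on a src/transformers/version.py module that defines version, is_release, and __version__; that file itself is not shown in this diff. A minimal sketch consistent with the sed toggle in build-release.yml and the version comment in setup.py (the base version numbers and date handling here are assumptions, not the PR's actual file):

# src/transformers/version.py -- hypothetical sketch, not part of this diff
from datetime import date

is_release = False  # flipped to True by build-release.yml via sed before building

# major.minor.patch tracks the Neural Magic repos; the fourth entry is the
# upstream transformers base version for releases or the build date for nightlies.
_nm_base = "1.6.0"             # assumed value
_transformers_base = "4.31.0"

if is_release:
    version = f"{_nm_base}.{_transformers_base}"
else:
    version = f"{_nm_base}.{date.today().strftime('%Y%m%d')}"

__version__ = version

With something like this in place, setup.py's exec picks up version and is_release to choose between the nm-transformers and nm-transformers-nightly package names, and the star import re-exports __version__ where the hard-coded "4.31.0" string used to live.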
46 changes: 43 additions & 3 deletions src/transformers/hf_argparser.py
@@ -23,8 +23,16 @@
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints

import os
import yaml

from sparsezoo import Model

from .utils.logging import get_logger


logger = get_logger(__name__)


DataClass = NewType("DataClass", Any)
DataClassType = NewType("DataClassType", Any)
@@ -341,12 +349,17 @@ def parse_args_into_dataclasses(
# additional namespace.
outputs.append(namespace)
if return_remaining_strings:
return (*outputs, remaining_args)
return (
*[_download_dataclass_zoo_stub_files(output) for output in outputs],
remaining_args,
)
else:
if remaining_args:
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")

return (*outputs,)
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)

def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
@@ -374,7 +387,9 @@ def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tu
outputs.append(obj)
if not allow_extra_keys and unused_keys:
raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
return tuple(outputs)
return tuple(
[_download_dataclass_zoo_stub_files(output) for output in outputs]
)

def parse_json_file(self, json_file: str, allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
"""
@@ -417,3 +432,28 @@ def parse_yaml_file(self, yaml_file: str, allow_extra_keys: bool = False) -> Tup
"""
outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
return tuple(outputs)

def _download_dataclass_zoo_stub_files(data_class: DataClass):
for name, val in data_class.__dict__.items():
if not isinstance(val, str) or "recipe" in name or not val.startswith("zoo:"):
continue

logger.info(f"Downloading framework files for SparseZoo stub: {val}")

zoo_model = Model(val)
framework_file_paths = [file.path for file in zoo_model.training.default.files]
assert framework_file_paths, f"Unable to download any framework files for SparseZoo stub {val}"
framework_file_names = [os.path.basename(path) for path in framework_file_paths]
if "pytorch_model.bin" not in framework_file_names or ("config.json" not in framework_file_names):
raise RuntimeError(
"Unable to find 'pytorch_model.bin' and 'config.json' in framework "
f"files downloaded from {val}. Found {framework_file_names}. Check "
"if the given stub is for a transformers repo model"
)
framework_dir_path = Path(framework_file_paths[0]).parent.absolute()

logger.info(f"Overwriting argument {name} to downloaded {framework_dir_path}")

data_class.__dict__[name] = str(framework_dir_path)

return data_class
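The net effect of _download_dataclass_zoo_stub_files is that any non-recipe string argument beginning with "zoo:" is replaced, after parsing, by a local directory containing the downloaded pytorch_model.bin and config.json. A hedged usage sketch (the dataclass, field name, and stub string are placeholders, and running it requires sparsezoo plus a valid stub):

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    # Any non-recipe string field that starts with "zoo:" gets rewritten to a
    # local path once the framework files have been downloaded.
    model_name_or_path: str = field(default=None)

parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--model_name_or_path", "zoo:example/placeholder-stub"]
)
# model_name_or_path now points at the directory holding pytorch_model.bin
# and config.json instead of the original zoo: stub string.
print(model_args.model_name_or_path)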
25 changes: 23 additions & 2 deletions src/transformers/models/bert/modeling_bert.py
@@ -241,6 +241,22 @@ def forward(
return embeddings


class QATMatMul(nn.Module):
def __init__(self):
super().__init__()

# behaves like normal torch.matmul unless a SparseML QuantizationModifier
# is initialized
self.wrap_qat = True
self.qat_wrapper_kwargs = {
"num_inputs": 2,
"input_qconfigs": ["asymmetric", "symmetric"],
}

def forward(self, a: torch.Tensor, b: torch.Tensor):
return torch.matmul(a, b)


class BertSelfAttention(nn.Module):
def __init__(self, config, position_embedding_type=None):
super().__init__()
@@ -258,6 +274,11 @@ def __init__(self, config, position_embedding_type=None):
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)

# non-parameterized matmuls will behave as normal torch.matmul ops unless
# Quantization-Aware-Training is invoked
self.attention_scores_matmul = QATMatMul()
self.context_layer_matmul = QATMatMul()

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
@@ -322,7 +343,7 @@ def forward(
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = self.attention_scores_matmul(query_layer, key_layer.transpose(-1, -2))

if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
@@ -362,7 +383,7 @@ def forward(
if head_mask is not None:
attention_probs = attention_probs * head_mask

context_layer = torch.matmul(attention_probs, value_layer)
context_layer = self.context_layer_matmul(attention_probs, value_layer)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
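Outside of a SparseML quantization run, QATMatMul is a pure pass-through: wrap_qat and qat_wrapper_kwargs are just attributes a quantization modifier can look for, and forward simply delegates to torch.matmul. A small standalone sanity check of that equivalence (not part of the PR):

import torch
from torch import nn

class QATMatMul(nn.Module):
    # Matmul wrapper that a SparseML QuantizationModifier can target.
    def __init__(self):
        super().__init__()
        self.wrap_qat = True
        self.qat_wrapper_kwargs = {
            "num_inputs": 2,
            "input_qconfigs": ["asymmetric", "symmetric"],
        }

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.matmul(a, b)

q = torch.randn(2, 12, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 12, 128, 64)
wrapped = QATMatMul()(q, k.transpose(-1, -2))
plain = torch.matmul(q, k.transpose(-1, -2))
assert torch.equal(wrapped, plain)  # identical results when no QAT is active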
42 changes: 39 additions & 3 deletions src/transformers/models/distilbert/modeling_distilbert.py
Expand Up @@ -88,6 +88,38 @@ def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
out.detach_()


class QATAttentionScores(nn.Module):
def __init__(self):
super().__init__()

# behaves like normal torch.matmul unless a SparseML QuantizationModifier
# is initialized
self.wrap_qat = True
self.qat_wrapper_kwargs = {
"num_inputs": 2,
"input_qconfigs": ["asymmetric", "symmetric"],
}

def forward(self, a: torch.Tensor, b: torch.Tensor):
return torch.matmul(a, b)

class QATContextLayer(nn.Module):
def __init__(self):
super().__init__()

# behaves like normal torch.matmul unless a SparseML QuantizationModifier
# is initialized
self.wrap_qat = True
self.qat_wrapper_kwargs = {
"num_inputs": 2,
"num_outputs": 0,
"input_qconfigs": ["asymmetric", "symmetric"],
}

def forward(self, a: torch.Tensor, b: torch.Tensor):
return torch.matmul(a, b)


class Embeddings(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
@@ -159,6 +191,11 @@ def __init__(self, config: PretrainedConfig):
self.pruned_heads: Set[int] = set()
self.attention_head_size = self.dim // self.n_heads

# non-parameterized matmuls will behave as normal torch.matmul ops unless
# Quantization-Aware-Training is invoked
self.attention_scores_matmul = QATAttentionScores()
self.context_layer_matmul = QATContextLayer()

def prune_heads(self, heads: List[int]):
if len(heads) == 0:
return
@@ -217,7 +254,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
v = shape(self.v_lin(value)) # (bs, n_heads, k_length, dim_per_head)

q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
scores = self.attention_scores_matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
scores = scores.masked_fill(
mask, torch.tensor(torch.finfo(scores.dtype).min)
@@ -230,7 +267,7 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
if head_mask is not None:
weights = weights * head_mask

context = torch.matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
context = self.context_layer_matmul(weights, v) # (bs, n_heads, q_length, dim_per_head)
context = unshape(context) # (bs, q_length, dim)
context = self.out_lin(context) # (bs, q_length, dim)

@@ -687,7 +724,6 @@ def forward(
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

dlbrt_output = self.distilbert(
input_ids=input_ids,
attention_mask=attention_mask,
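DistilBERT gains two wrapper classes instead of one because the context-layer matmul also declares num_outputs: 0, signalling that its output should not receive a fake-quantization observer. Since both flags are plain module attributes, a quantization pass can discover every wrapped op with a simple traversal; a hedged illustration (the helper below is hypothetical, not SparseML's actual mechanism, and it only finds matches when this fork's DistilBERT is installed):

import torch
from torch import nn
from transformers import AutoModel

def find_qat_wrapped_modules(model: nn.Module):
    # Yield (name, module, kwargs) for every module flagged with wrap_qat=True.
    for name, module in model.named_modules():
        if getattr(module, "wrap_qat", False):
            yield name, module, getattr(module, "qat_wrapper_kwargs", {})

model = AutoModel.from_pretrained("distilbert-base-uncased")
for name, module, kwargs in find_qat_wrapped_modules(model):
    print(name, type(module).__name__, kwargs)
# e.g. transformer.layer.0.attention.attention_scores_matmul QATAttentionScores {...}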
19 changes: 18 additions & 1 deletion src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -169,6 +169,23 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:

NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm}

class QATEmbeddingTransformation(nn.Module):
def __init__(self, embedded_input_size, hidden_size):
super().__init__()

# Behaves like normal Linear module unless a SparseML QuantizationModifier
# is initialized.
# When initialized, does not quantize inputs.
# Only weights are quantized (inputs come quantized from embeddings)
self.linear = nn.Linear(embedded_input_size, hidden_size)
self.wrap_qat = True
self.qat_wrapper_kwargs = {
"num_inputs": 0,
"num_outputs": 1,
}

def forward(self, x: torch.Tensor):
return self.linear(x)

class MobileBertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
@@ -185,7 +202,7 @@ def __init__(self, config):

embed_dim_multiplier = 3 if self.trigram_input else 1
embedded_input_size = self.embedding_size * embed_dim_multiplier
self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size)
self.embedding_transformation = QATEmbeddingTransformation(embedded_input_size, config.hidden_size)

self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
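For MobileBERT the wrapped op is a parameterized nn.Linear, and num_inputs: 0 reflects that its inputs already arrive quantized from the embedding lookup, so only the weights get observers. With quantization disabled it behaves exactly like the Linear it replaces; a quick standalone check (the sizes are illustrative, not MobileBERT's real config):

import torch
from torch import nn

class QATEmbeddingTransformation(nn.Module):
    def __init__(self, embedded_input_size: int, hidden_size: int):
        super().__init__()
        # Only the weights are quantized under QAT; inputs come pre-quantized.
        self.linear = nn.Linear(embedded_input_size, hidden_size)
        self.wrap_qat = True
        self.qat_wrapper_kwargs = {"num_inputs": 0, "num_outputs": 1}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

wrapped = QATEmbeddingTransformation(embedded_input_size=384, hidden_size=512)
plain = nn.Linear(384, 512)
plain.load_state_dict({"weight": wrapped.linear.weight, "bias": wrapped.linear.bias})

x = torch.randn(4, 128, 384)  # (batch, seq_len, 3 * embedding_size) with trigram_input
assert torch.equal(wrapped(x), plain(x))  # identical outputs outside of QAT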