diff --git a/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt b/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt index 317a74f7d2f..ed2c4ccafca 100644 --- a/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt +++ b/.azure-pipelines/scripts/codeScan/pydocstyle/scan_path.txt @@ -25,4 +25,5 @@ /neural-compressor/neural_compressor/torch/algorithms/static_quant /neural-compressor/neural_compressor/torch/algorithms/weight_only /neural-compressor/neural_compressor/torch/export +/neural-compressor/neural_compressor/torch/quantization /neural-compressor/neural_compressor/torch/utils diff --git a/.azure-pipelines/scripts/ut/3x/coverage.3x_pt b/.azure-pipelines/scripts/ut/3x/coverage.3x_pt index 2902c0c8f9c..dd4991f5fa7 100644 --- a/.azure-pipelines/scripts/ut/3x/coverage.3x_pt +++ b/.azure-pipelines/scripts/ut/3x/coverage.3x_pt @@ -6,7 +6,7 @@ include = */neural_compressor/common/* */neural_compressor/torch/* omit = - */neural_compressor/torch/algorithms/habana_fp8/* + */neural_compressor/torch/algorithms/fp8_quant/* */neural_compressor/torch/amp/* exclude_lines = pragma: no cover diff --git a/.azure-pipelines/scripts/ut/3x/coverage.3x_pt_fp8 b/.azure-pipelines/scripts/ut/3x/coverage.3x_pt_fp8 index f1bf27d8da3..9b12b354d83 100644 --- a/.azure-pipelines/scripts/ut/3x/coverage.3x_pt_fp8 +++ b/.azure-pipelines/scripts/ut/3x/coverage.3x_pt_fp8 @@ -3,8 +3,7 @@ branch = True [report] include = - */neural_compressor/torch/algorithms/habana_fp8/* - */neural_compressor/torch/amp/* + */neural_compressor/torch/algorithms/fp8_quant/* exclude_lines = pragma: no cover raise NotImplementedError diff --git a/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh b/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh index ce36d3d8bc3..fba15ce6c4e 100644 --- a/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh +++ b/.azure-pipelines/scripts/ut/3x/run_3x_pt.sh @@ -15,8 +15,8 @@ export COVERAGE_RCFILE=/neural-compressor/.azure-pipelines/scripts/ut/3x/coverag inc_path=$(python -c 'import neural_compressor; print(neural_compressor.__path__[0])') cd /neural-compressor/test/3x || exit 1 rm -rf tensorflow -rm -rf onnxrt rm -rf torch/algorithms/fp8_quant +rm -rf torch/quantization/fp8_quant LOG_DIR=/neural-compressor/log_dir mkdir -p ${LOG_DIR} diff --git a/.azure-pipelines/scripts/ut/3x/run_3x_pt_fp8.sh b/.azure-pipelines/scripts/ut/3x/run_3x_pt_fp8.sh index d2aef0c3045..753dd8ac440 100644 --- a/.azure-pipelines/scripts/ut/3x/run_3x_pt_fp8.sh +++ b/.azure-pipelines/scripts/ut/3x/run_3x_pt_fp8.sh @@ -5,11 +5,13 @@ echo "${test_case}" # install requirements echo "set up UT env..." 
+export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH sed -i '/^intel_extension_for_pytorch/d' /neural-compressor/test/3x/torch/requirements.txt pip install -r /neural-compressor/test/3x/torch/requirements.txt pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.16.0 pip install pytest-cov pip install pytest-html +pip install pytest-html-merger pip list export COVERAGE_RCFILE=/neural-compressor/.azure-pipelines/scripts/ut/3x/coverage.3x_pt_fp8 @@ -19,8 +21,13 @@ cd /neural-compressor/test/3x || exit 1 LOG_DIR=/neural-compressor/log_dir mkdir -p ${LOG_DIR} ut_log_name=${LOG_DIR}/ut_3x_pt_fp8.log -pytest --cov="${inc_path}" -vs --disable-warnings --html=report.html --self-contained-html torch/algorithms/fp8_quant 2>&1 | tee -a ${ut_log_name} +pytest --cov="${inc_path}" -vs --disable-warnings --html=report_1.html --self-contained-html torch/quantization/weight_only/test_load.py 2>&1 | tee -a ${ut_log_name} +pytest --cov="${inc_path}" -vs --disable-warnings --html=report_2.html --self-contained-html torch/quantization/weight_only/test_rtn.py 2>&1 | tee -a ${ut_log_name} +# pytest --cov="${inc_path}" -vs --disable-warnings --html=report_3.html --self-contained-html torch/quantization/weight_only/test_autoround.py 2>&1 | tee -a ${ut_log_name} +pytest --cov="${inc_path}" -vs --disable-warnings --html=report_4.html --self-contained-html torch/quantization/fp8_quant 2>&1 | tee -a ${ut_log_name} +mkdir -p report && mv *.html report +pytest_html_merger -i ./report -o ./report.html cp report.html ${LOG_DIR}/ if [ $(grep -c '== FAILURES ==' ${ut_log_name}) != 0 ] || [ $(grep -c '== ERRORS ==' ${ut_log_name}) != 0 ] || [ $(grep -c ' passed' ${ut_log_name}) == 0 ]; then diff --git a/.azure-pipelines/template/docker-template.yml b/.azure-pipelines/template/docker-template.yml index 34c30734791..51103c39e21 100644 --- a/.azure-pipelines/template/docker-template.yml +++ b/.azure-pipelines/template/docker-template.yml @@ -74,7 +74,7 @@ steps: - ${{ if eq(parameters.imageSource, 'pull') }}: - script: | - docker pull vault.habana.ai/gaudi-docker/1.16.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest + docker pull vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest displayName: "Pull habana docker image" - script: | @@ -95,7 +95,7 @@ steps: else docker run -dit --disable-content-trust --privileged --name=${{ parameters.containerName }} --shm-size="2g" \ --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host \ - -v ${BUILD_SOURCESDIRECTORY}:/neural-compressor vault.habana.ai/gaudi-docker/1.16.1/ubuntu22.04/habanalabs/pytorch-installer-2.2.2:latest + -v ${BUILD_SOURCESDIRECTORY}:/neural-compressor vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest fi echo "Show the container list after docker run ... 
" docker ps -a diff --git a/.azure-pipelines/ut-3x-pt-fp8.yml b/.azure-pipelines/ut-3x-pt-fp8.yml index 490db6db3be..e8a992b6e65 100644 --- a/.azure-pipelines/ut-3x-pt-fp8.yml +++ b/.azure-pipelines/ut-3x-pt-fp8.yml @@ -10,6 +10,12 @@ pr: include: - .azure-pipelines/scripts/ut/3x/run_3x_pt_fp8.sh - .azure-pipelines/ut-3x-pt-fp8.yml + - neural_compressor/common + - neural_compressor/torch + - test/3x/torch/algorithms/fp8_quant + - test/3x/torch/quantization/fp8_quant + - setup.py + - requirements_pt.txt pool: GAUDI diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e77a67f9f1..d93d64aba33 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -128,7 +128,8 @@ repos: examples/.*(txt|patch)| examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/ptq_static/prompt.json| examples/notebook/dynas/ResNet50_Quantiation_Search_Supernet_NAS.ipynb| - examples/notebook/dynas/Transformer_LT_Supernet_NAS.ipynb + examples/notebook/dynas/Transformer_LT_Supernet_NAS.ipynb| + neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/SR_evaluation/imagenet1000_clsidx_to_labels.txt )$ - repo: https://github.com/astral-sh/ruff-pre-commit diff --git a/README.md b/README.md index d7f02d5aa02..349a45a9aa3 100644 --- a/README.md +++ b/README.md @@ -71,66 +71,50 @@ pip install "neural-compressor>=2.3" "transformers>=4.34.0" torch torchvision ``` After successfully installing these packages, try your first quantization program. -### Weight-Only Quantization (LLMs) -Following example code demonstrates Weight-Only Quantization on LLMs, it supports Intel CPU, Intel Gaudi2 AI Accelerator, Nvidia GPU, best device will be selected automatically. +### [FP8 Quantization](./examples/3.x_api/pytorch/cv/fp8_quant/) +Following example code demonstrates FP8 Quantization, it is supported by Intel Gaudi2 AI Accelerator. To try on Intel Gaudi2, docker image with Gaudi Software Stack is recommended, please refer to following script for environment setup. More details can be found in [Gaudi Guide](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#launch-docker-image-that-was-built). 
```bash # Run a container with an interactive shell -docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.14.0/ubuntu22.04/habanalabs/pytorch-installer-2.1.1:latest - -# Install the optimum-habana -pip install --upgrade-strategy eager optimum[habana] - -# Install INC/auto_round -pip install neural-compressor auto_round +docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest ``` Run the example: ```python -from transformers import AutoModel, AutoTokenizer - -from neural_compressor.config import PostTrainingQuantConfig -from neural_compressor.quantization import fit -from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader - -model_name = "EleutherAI/gpt-neo-125m" -float_model = AutoModel.from_pretrained(model_name) -tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) -dataloader = get_dataloader(tokenizer, seqlen=2048) - -woq_conf = PostTrainingQuantConfig( - approach="weight_only", - op_type_dict={ - ".*": { # match all ops - "weight": { - "dtype": "int", - "bits": 4, - "algorithm": "AUTOROUND", - }, - } - }, +from neural_compressor.torch.quantization import ( + FP8Config, + prepare, + convert, ) -quantized_model = fit(model=float_model, conf=woq_conf, calib_dataloader=dataloader) +import torchvision.models as models + +model = models.resnet18() +qconfig = FP8Config(fp8_config="E4M3") +model = prepare(model, qconfig) +# customer defined calibration +calib_func(model) +model = convert(model) ``` -**Note:** -To try INT4 model inference, please directly use [Intel Extension for Transformers](https://github.com/intel/intel-extension-for-transformers), which leverages Intel Neural Compressor for model quantization. +### Weight-Only Large Language Model Loading (LLMs) -### Static Quantization (Non-LLMs) +Following example code demonstrates weight-only large language model loading on Intel Gaudi2 AI Accelerator. ```python -from torchvision import models +from neural_compressor.torch.quantization import load + +model_name = "TheBloke/Llama-2-7B-GPTQ" +model = load( + model_name_or_path=model_name, + format="huggingface", + device="hpu", + torch_dtype=torch.bfloat16, +) +``` -from neural_compressor.config import PostTrainingQuantConfig -from neural_compressor.data import DataLoader, Datasets -from neural_compressor.quantization import fit +**Note:** -float_model = models.resnet18() -dataset = Datasets("pytorch")["dummy"](shape=(1, 3, 224, 224)) -calib_dataloader = DataLoader(framework="pytorch", dataset=dataset) -static_quant_conf = PostTrainingQuantConfig() -quantized_model = fit(model=float_model, conf=static_quant_conf, calib_dataloader=calib_dataloader) -``` +Intel Neural Compressor will convert the model format from auto-gptq to hpu format on the first load and save hpu_model.safetensors to the local cache directory for the next load. So it may take a while to load for the first time. 
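+The FP8 example above leaves `calib_func` as a user-defined step. Below is a minimal sketch of what such a function could look like, mirroring the calibration loop in the ImageNet example added by this PR; the `calib_loader` built from random tensors is only an illustration, so substitute a small set of representative data in practice.
+
+```python
+import torch
+import habana_frameworks.torch.core as htcore
+
+# Stand-in calibration data; replace with a slice of the real dataset.
+calib_loader = torch.utils.data.DataLoader(
+    torch.utils.data.TensorDataset(
+        torch.randn(32, 3, 224, 224), torch.zeros(32, dtype=torch.long)
+    ),
+    batch_size=8,
+)
+
+
+def calib_func(model, num_batches=4):
+    # Feed a few batches through the prepared model so the observers
+    # inserted by prepare() can record tensor statistics.
+    model.eval()
+    with torch.no_grad():
+        for i, (images, _) in enumerate(calib_loader):
+            if i == num_batches:
+                break
+            model(images.to("hpu"))  # the prepared model already runs on HPU
+            htcore.mark_step()  # flush the accumulated HPU graph
+```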
## Documentation @@ -157,12 +141,13 @@ quantized_model = fit(model=float_model, conf=static_quant_conf, calib_dataloade Overview - Static Quantization Dynamic Quantization + Static Quantization Smooth Quantization - Weight-Only Quantization + Weight-Only Quantization + FP8 Quantization MX Quantization Mixed Precision diff --git a/docs/3x/PT_FP8Quant.md b/docs/3x/PT_FP8Quant.md new file mode 100644 index 00000000000..a0ed3352e8e --- /dev/null +++ b/docs/3x/PT_FP8Quant.md @@ -0,0 +1,113 @@ +FP8 Quantization +======= + +1. [Introduction](#introduction) +2. [Supported Parameters](#supported-parameters) +3. [Get Started with FP8 Quantization](#get-started-with-fp8-quantization) +4. [Examples](#examples) + +## Introduction + +Floating point 8 (FP8) is a promising data type for low precision quantization; it provides a data distribution that is completely different from INT8, as shown below. +
+*(Figure: FP8 data formats. Fig. 1: E5M2; Fig. 2: E4M3.)*
+
+Intel Gaudi2, also known as HPU, provides this data type capability for low precision quantization, which includes `E4M3` and `E5M2`. For more information about these two data types, please refer to [link](https://arxiv.org/abs/2209.05433).
+
+Intel Neural Compressor provides general quantization APIs to leverage the HPU FP8 capability; with a simple API call you get an 8-bit model with lower memory usage and lower compute cost.
+
+## Supported Parameters
+
+| Attribute | Description | Values |
+|-----------|-------------|--------|
+| fp8_config | The target data type of FP8 quantization. | E4M3 (default) - As Fig. 2 <br> E5M2 - As Fig. 1. |
+| hp_dtype | The high precision data type of non-FP8 operators. | bf16 (default) - torch.bfloat16 <br> fp16 - torch.float16. <br> fp32 - torch.float32. |
+| observer | The observer to measure the statistics. | maxabs (default), saves all tensors to files. |
+| allowlist | List of nn.Module names or types to quantize. When setting an empty list, all the supported modules will be quantized by default. See Supported Modules. Not setting the list at all is not recommended as it will set the allowlist to these modules only: torch.nn.Linear, torch.nn.Conv2d, and BMM. | Default = {'names': [], 'types': FP8_WHITE_LIST} |
+| blocklist | List of nn.Module names or types not to quantize. Defaults to empty list, so you may omit it from the config file. | Default = {'names': [], 'types': ()} |
+| mode | The mode, measure or quantize, to run HQT with. | MEASURE - Measure statistics of all modules and emit the results to dump_stats_path. <br> QUANTIZE - Quantize and run the model according to the provided measurements. <br> AUTO (default) - Select from [MEASURE, QUANTIZE] automatically. |
+| dump_stats_path | The path to save and load the measurements. The path is created up until the level before last "/". The string after the last / will be used as prefix to all the measurement files that will be created. | Default = "./hqt_output/measure" |
+| scale_method | The method for calculating the scale from the measurement. | - without_scale - Convert to/from FP8 without scaling. <br> - unit_scale - Always use scale of 1. <br> - maxabs_hw (default) - Scale is calculated to stretch/compress the maxabs measurement to the full-scale of FP8 and then aligned to the corresponding HW accelerated scale. <br> - maxabs_pow2 - Scale is calculated to stretch/compress the maxabs measurement to the full-scale of FP8 and then rounded to the power of 2. <br> - maxabs_hw_opt_weight - Scale of model params (weights) is chosen as the scale that provides minimal mean-square-error between quantized and non-quantized weights, from all possible HW accelerated scales. Scale of activations is calculated the same as maxabs_hw. <br> - act_maxabs_pow2_weights_pcs_opt_pow2 - Scale of model params (weights) is calculated per-channel of the params tensor. The scale per-channel is calculated the same as maxabs_hw_opt_weight. Scale of activations is calculated the same as maxabs_pow2. <br> - act_maxabs_hw_weights_pcs_maxabs_pow2 - Scale of model params (weights) is calculated per-channel of the params tensor. The scale per-channel is calculated the same as maxabs_pow2. Scale of activations is calculated the same as maxabs_hw. |
+| measure_exclude | If this attribute is not defined, the default is OUTPUT. Since most models do not require measuring output tensors, you can exclude it to speed up the measurement process. | NONE - All tensors are measured. <br> OUTPUT (default) - Excludes measurement of output tensors. |
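+
+As an illustration of how these attributes combine, the sketch below spells out a configuration that matches the table's defaults. It assumes `FP8Config` accepts the table attributes as keyword arguments; treat the exact names as an assumption and check the current API before relying on them.
+
+```python
+from neural_compressor.torch.quantization import FP8Config
+
+# Assumed keyword arguments, mirroring the attribute names in the table above.
+qconfig = FP8Config(
+    fp8_config="E4M3",                       # target FP8 format
+    hp_dtype="bf16",                         # precision of non-FP8 operators
+    observer="maxabs",                       # statistics observer
+    mode="AUTO",                             # MEASURE, QUANTIZE, or AUTO
+    dump_stats_path="./hqt_output/measure",  # where measurements are saved and loaded
+    scale_method="maxabs_hw",                # how scales are derived from measurements
+    measure_exclude="OUTPUT",                # skip measuring output tensors
+)
+```
+
+In the two-step flow described by `mode`, a first run with `MEASURE` collects statistics into `dump_stats_path`, and a second run with `QUANTIZE` (or a single `AUTO` run) applies them.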
+ +## Get Started with FP8 Quantization + +### Demo Usage + +```python +from neural_compressor.torch.quantization import ( + FP8Config, + prepare, + convert, +) +import torchvision.models as models + +model = models.resnet18() +qconfig = FP8Config(fp8_config="E4M3") +model = prepare(model, qconfig) +# customer defined calibration +calib_func(model) +model = convert(model) +``` + +## Examples + +| Task | Example | +|----------------------|---------| +| Computer Vision (CV) | [Link](../../examples/3.x_api/pytorch/cv/fp8_quant/) | +| Large Language Model (LLM) | [Link](https://github.com/HabanaAI/optimum-habana-fork/tree/habana-main/examples/text-generation#running-with-fp8) | + +> Note: For LLMs, Optimum-habana provides higher performance based on modified modeling files, so the LLM link above points to Optimum-habana, which utilizes Intel Neural Compressor for FP8 quantization internally. diff --git a/examples/.config/model_params_pytorch_3x.json b/examples/.config/model_params_pytorch_3x.json index bdb8a532561..c3ae3f6b5be 100644 --- a/examples/.config/model_params_pytorch_3x.json +++ b/examples/.config/model_params_pytorch_3x.json @@ -140,6 +140,13 @@ "main_script": "main.py", "batch_size": 1 }, + "resnet18_fp8_static":{ + "model_src_dir": "cv/fp8_quant", + "dataset_location": "/tf_dataset/pytorch/ImageNet/raw", + "input_model": "", + "main_script": "main.py", + "batch_size": 1 + }, "opt_125m_pt2e_static":{ "model_src_dir": "nlp/huggingface_models/language-modeling/quantization/static_quant/pt2e", "dataset_location": "", diff --git a/examples/3.x_api/pytorch/cv/fp8_quant/README.md b/examples/3.x_api/pytorch/cv/fp8_quant/README.md new file mode 100644 index 00000000000..72b8eb282b5 --- /dev/null +++ b/examples/3.x_api/pytorch/cv/fp8_quant/README.md @@ -0,0 +1,28 @@ +# ImageNet FP8 Quantization + +This example implements FP8 quantization of popular model architectures, such as ResNet, on the ImageNet dataset; it is supported by the Intel Gaudi2 AI Accelerator. + +## Requirements + +To try it on Intel Gaudi2, a docker image with the Gaudi Software Stack is recommended; please refer to the following script for environment setup. More details can be found in [Gaudi Guide](https://docs.habana.ai/en/latest/Installation_Guide/Bare_Metal_Fresh_OS.html#launch-docker-image-that-was-built). 
+```bash +# Run a container with an interactive shell +docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest +``` + +- Install requirements +- `pip install -r requirements.txt` +- Download the ImageNet dataset from http://www.image-net.org/ + - Then, move and extract the training and validation images to labeled subfolders, using [the following shell script](extract_ILSVRC.sh) + +## Quantizaiton + +To quant a model and validate accaracy, run `main.py` with the desired model architecture and the path to the ImageNet dataset: + +```bash +python main.py --pretrained -t -a resnet50 -b 30 /path/to/imagenet +``` +or +```bash +bash run_quant.sh --input_model=resnet50 --dataset_location=/path/to/imagenet +``` diff --git a/examples/3.x_api/pytorch/cv/fp8_quant/extract_ILSVRC.sh b/examples/3.x_api/pytorch/cv/fp8_quant/extract_ILSVRC.sh new file mode 100644 index 00000000000..3ec05e8f328 --- /dev/null +++ b/examples/3.x_api/pytorch/cv/fp8_quant/extract_ILSVRC.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# +# script to extract ImageNet dataset +# ILSVRC2012_img_train.tar (about 138 GB) +# ILSVRC2012_img_val.tar (about 6.3 GB) +# make sure ILSVRC2012_img_train.tar & ILSVRC2012_img_val.tar in your current directory +# +# Adapted from: +# https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md +# https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4 +# +# imagenet/train/ +# ├── n01440764 +# │ ├── n01440764_10026.JPEG +# │ ├── n01440764_10027.JPEG +# │ ├── ...... +# ├── ...... +# imagenet/val/ +# ├── n01440764 +# │ ├── ILSVRC2012_val_00000293.JPEG +# │ ├── ILSVRC2012_val_00002138.JPEG +# │ ├── ...... +# ├── ...... +# +# +# Make imagnet directory +# +mkdir imagenet +# +# Extract the training data: +# +# Create train directory; move .tar file; change directory +mkdir imagenet/train && mv ILSVRC2012_img_train.tar imagenet/train/ && cd imagenet/train +# Extract training set; remove compressed file +tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar +# +# At this stage imagenet/train will contain 1000 compressed .tar files, one for each category +# +# For each .tar file: +# 1. create directory with same name as .tar file +# 2. extract and copy contents of .tar file into directory +# 3. remove .tar file +find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done +# +# This results in a training directory like so: +# +# imagenet/train/ +# ├── n01440764 +# │ ├── n01440764_10026.JPEG +# │ ├── n01440764_10027.JPEG +# │ ├── ...... +# ├── ...... +# +# Change back to original directory +cd ../.. +# +# Extract the validation data and move images to subfolders: +# +# Create validation directory; move .tar file; change directory; extract validation .tar; remove compressed file +mkdir imagenet/val && mv ILSVRC2012_img_val.tar imagenet/val/ && cd imagenet/val && tar -xvf ILSVRC2012_img_val.tar && rm -f ILSVRC2012_img_val.tar +# get script from soumith and run; this script creates all class directories and moves images into corresponding directories +wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash +# +# This results in a validation directory like so: +# +# imagenet/val/ +# ├── n01440764 +# │ ├── ILSVRC2012_val_00000293.JPEG +# │ ├── ILSVRC2012_val_00002138.JPEG +# │ ├── ...... +# ├── ...... 
+# +# +# Check total files after extract +# +# $ find train/ -name "*.JPEG" | wc -l +# 1281167 +# $ find val/ -name "*.JPEG" | wc -l +# 50000 +# \ No newline at end of file diff --git a/examples/3.x_api/pytorch/cv/fp8_quant/main.py b/examples/3.x_api/pytorch/cv/fp8_quant/main.py new file mode 100644 index 00000000000..dfa7515343c --- /dev/null +++ b/examples/3.x_api/pytorch/cv/fp8_quant/main.py @@ -0,0 +1,391 @@ +import argparse +import os +import random +import shutil +import time +import warnings +import sys + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.distributed as dist +import torch.optim +import torch.multiprocessing as mp +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +import torchvision.models as models +from neural_compressor.torch.quantization import ( + FP8Config, + prepare, + convert, +) +import habana_frameworks.torch.core as htcore + + +model_names = models.list_models(module=models) + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', + choices=model_names, + help='model architecture: ' + + ' | '.join(model_names) + + ' (default: resnet18)') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 4)') +parser.add_argument('--epochs', default=90, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', + help='evaluate model on validation set') +parser.add_argument('-t', '--tune', dest='tune', action='store_true', + help='tune best int8 model on calibration dataset') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') +parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') +parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. 
') +parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') +parser.add_argument('--ppn', default=1, type=int, + help='number of processes on each node of distributed training') +parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') +parser.add_argument("--calib_iters", default=10, type=int, + help="For calibration only.") +parser.add_argument('-i', "--iter", default=0, type=int, + help='For accuracy measurement only.') +parser.add_argument('-w', "--warmup_iter", default=5, type=int, + help='For benchmark measurement only.') +parser.add_argument('--performance', dest='performance', action='store_true', + help='run benchmark') +parser.add_argument('-r', "--accuracy", dest='accuracy', action='store_true', + help='For accuracy measurement only.') +parser.add_argument("--tuned_checkpoint", default='./saved_results', type=str, metavar='PATH', + help='path to checkpoint tuned by Neural Compressor (default: ./)') +parser.add_argument('--int8', dest='int8', action='store_true', + help='run benchmark') +parser.add_argument('--device', default='hpu', type=str, + help='use hpu device for fp8 quantization') + +best_acc1 = 0 + + +def main(): + args = parser.parse_args() + + if 'mobilenet' in args.arch: + import torchvision.models.quantization as models + else: + import torchvision.models as models + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + + if args.pretrained: + print("=> using pre-trained model '{}'".format(args.arch)) + model = models.__dict__[args.arch](pretrained=True) + else: + print("=> creating model '{}'".format(args.arch)) + model = models.__dict__[args.arch]() + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss() + + optimizer = torch.optim.SGD(model.parameters(), args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume) + args.start_epoch = checkpoint['epoch'] + best_acc1 = checkpoint['best_acc1'] + if args.gpu is not None: + # best_acc1 may be from a checkpoint from a different GPU + best_acc1 = best_acc1.to(args.gpu) + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + # Data loading code + traindir = os.path.join(args.data, 'train') + valdir = os.path.join(args.data, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=True, + num_workers=args.workers, pin_memory=True, sampler=None) + + val_dataset = datasets.ImageFolder(valdir, transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + + val_loader = torch.utils.data.DataLoader( + val_dataset, + 
batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + if args.evaluate: + validate(val_loader, model, criterion, args) + return + + def eval_func(model): + accu = validate(val_loader, model, criterion, args) + return float(accu) + + if args.tune: + qconfig = FP8Config(fp8_config="E4M3") + model = prepare(model, qconfig) + + # Calibrate + # model is moved to HPU device automatically after preparing + with torch.no_grad(): + for i, (images, target) in enumerate(train_loader): + print("Calibrating batch:", i) + if i == args.calib_iters: + break + images = images.to(args.device) + model(images) + htcore.mark_step() + + model = convert(model) + eval_func(model) + # The saving and loading of fp8 quantization are planned in the next release. + + if args.performance or args.accuracy: + model.eval() + if args.int8: + from neural_compressor.utils.pytorch import load + new_model = load(os.path.abspath(os.path.expanduser(args.tuned_checkpoint)), + model, + dataloader=val_loader) + else: + new_model = model + if args.performance: + from neural_compressor.config import BenchmarkConfig + from neural_compressor import benchmark + b_conf = BenchmarkConfig(warmup=5, + iteration=args.iter, + cores_per_instance=4, + num_of_instance=1) + benchmark.fit(new_model, b_conf, b_dataloader=val_loader) + if args.accuracy: + validate(val_loader, new_model, criterion, args) + return + + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, + top5, prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + for i, (input, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + if args.gpu is not None: + input = input.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(acc1[0], input.size(0)) + top5.update(acc5[0], input.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.print(i) + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter('Time', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, + prefix='Test: ') + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + for i, (input, target) in enumerate(val_loader): + if i >= args.warmup_iter: + start = time.time() + input = input.to(args.device) + target = target.to(args.device) + if args.gpu is not None: + input = input.cuda(args.gpu, non_blocking=True) + target = target.cuda(args.gpu, non_blocking=True) + + # compute output + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), 
input.size(0)) + top1.update(acc1[0], input.size(0)) + top5.update(acc5[0], input.size(0)) + + # measure elapsed time + if i >= args.warmup_iter: + batch_time.update(time.time() - start) + + if i % args.print_freq == 0: + progress.print(i) + + if args.iter > 0 and i >= (args.warmup_iter + args.iter - 1): + break + + print('Batch size = %d' % args.batch_size) + print('Accuracy: {top1:.5f} Accuracy@5 {top5:.5f}' + .format(top1=(top1.avg / 100), top5=(top5.avg / 100))) + + return top1.avg/100 + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, *meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def print(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + +def adjust_learning_rate(optimizer, epoch, args): + """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" + lr = args.lr * (0.1 ** (epoch // 30)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + main() diff --git a/examples/3.x_api/pytorch/cv/fp8_quant/requirements.txt b/examples/3.x_api/pytorch/cv/fp8_quant/requirements.txt new file mode 100644 index 00000000000..ebd3df6ae7a --- /dev/null +++ b/examples/3.x_api/pytorch/cv/fp8_quant/requirements.txt @@ -0,0 +1,3 @@ +torch +torchvision +neural-compressor \ No newline at end of file diff --git a/examples/3.x_api/pytorch/cv/fp8_quant/run_quant.sh b/examples/3.x_api/pytorch/cv/fp8_quant/run_quant.sh new file mode 100644 index 00000000000..4d0047cf2d1 --- /dev/null +++ b/examples/3.x_api/pytorch/cv/fp8_quant/run_quant.sh @@ -0,0 +1,53 @@ +#!/bin/bash +set -x + +function main { + + init_params "$@" + run_tuning + +} + +# init params +function init_params { + output_model=saved_results + for var in "$@" + do + case $var in + --topology=*) + topology=$(echo $var |cut -f2 -d=) + ;; + --dataset_location=*) + dataset_location=$(echo $var |cut -f2 -d=) + ;; + --input_model=*) + input_model=$(echo $var |cut -f2 -d=) + ;; + --output_model=*) + output_model=$(echo $var |cut -f2 -d=) + ;; + *) + echo "Error: No 
such parameter: ${var}" + exit 1 + ;; + esac + done + +} + +# run_tuning +function run_tuning { + if [ "${topology}" = "resnet18_fp8_static" ]; then + input_model="resnet18" + output_dir="saved_results" + fi + python main.py \ + --pretrained \ + -t \ + -a ${input_model} \ + -b 30 \ + --tuned_checkpoint ${output_model} \ + ${dataset_location} +} + +main "$@" diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/README.md b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/README.md deleted file mode 100644 index eb39321b173..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/README.md +++ /dev/null @@ -1,24 +0,0 @@ -# Run - -## Run FP32 model -``` python -python run_llm.py --model [model_name_or_path] --to_graph [--performance]|[--accuracy --tasks lambada_openai --batch_size 8]|[--generate --max_new_tokens 10] -``` - -## Run BF16/FP16 model -``` python -python run_llm.py --model [model_name_or_path] --approach cast --precision [bf16|fp16] --to_graph [--performance]|[--accuracy --tasks lambada_openai --batch_size 8]|[--generate --max_new_tokens 10] -``` - -## Run FP8 model -``` python -python run_llm.py --model [model_name_or_path] --approach [dynamic|static|cast] --precision [fp8_e4m3|fp8_e5m2] --to_graph [--performance]|[--accuracy --tasks lambada_openai --batch_size 8]|[--generate --max_new_tokens 10] -``` - -# Multi-card Inference -With deepspeed we can leverage multi-cards inference with a prefix in command, below it's a demonstration of 4 card inference. - -```python -deepspeed --num_gpus=4 run_llm.py --model [model_name_or_path] --approach [dynamic|static|cast] --precision [fp8_e4m3|fp8_e5m2] --to_graph [--performance]|[--accuracy --tasks lambada_openai --batch_size 8]|[--generate --max_new_tokens 10] -``` -deepspeed --num_gpus=4 run_llm.py --model facebook/opt-125m --approach static --precision fp8_e4m3 --to_graph --accuracy --tasks lambada_openai --batch_size 8 \ No newline at end of file diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/configuration_chatglm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/configuration_chatglm.py deleted file mode 100644 index 35600185f5a..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/configuration_chatglm.py +++ /dev/null @@ -1,61 +0,0 @@ -from transformers import PretrainedConfig - - -class ChatGLMConfig(PretrainedConfig): - model_type = "chatglm" - def __init__( - self, - num_layers=28, - padded_vocab_size=65024, - hidden_size=4096, - ffn_hidden_size=13696, - kv_channels=128, - num_attention_heads=32, - seq_length=2048, - hidden_dropout=0.0, - classifier_dropout=None, - attention_dropout=0.0, - layernorm_epsilon=1e-5, - rmsnorm=True, - apply_residual_connection_post_layernorm=False, - post_layer_norm=True, - add_bias_linear=False, - add_qkv_bias=False, - bias_dropout_fusion=True, - multi_query_attention=False, - multi_query_group_num=1, - apply_query_key_layer_scaling=True, - attention_softmax_in_fp32=True, - fp32_residual_connection=False, - quantization_bit=0, - pre_seq_len=None, - prefix_projection=False, - **kwargs - ): - self.num_layers = num_layers - self.vocab_size = padded_vocab_size - self.padded_vocab_size = padded_vocab_size - self.hidden_size = hidden_size - self.ffn_hidden_size = ffn_hidden_size - 
self.kv_channels = kv_channels - self.num_attention_heads = num_attention_heads - self.seq_length = seq_length - self.hidden_dropout = hidden_dropout - self.classifier_dropout = classifier_dropout - self.attention_dropout = attention_dropout - self.layernorm_epsilon = layernorm_epsilon - self.rmsnorm = rmsnorm - self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm - self.post_layer_norm = post_layer_norm - self.add_bias_linear = add_bias_linear - self.add_qkv_bias = add_qkv_bias - self.bias_dropout_fusion = bias_dropout_fusion - self.multi_query_attention = multi_query_attention - self.multi_query_group_num = multi_query_group_num - self.apply_query_key_layer_scaling = apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = attention_softmax_in_fp32 - self.fp32_residual_connection = fp32_residual_connection - self.quantization_bit = quantization_bit - self.pre_seq_len = pre_seq_len - self.prefix_projection = prefix_projection - super().__init__(**kwargs) \ No newline at end of file diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_chatglm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_chatglm.py deleted file mode 100644 index be1cd520af5..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_chatglm.py +++ /dev/null @@ -1,1294 +0,0 @@ -""" PyTorch ChatGLM model. """ - -import math -import copy -import warnings -import re -import sys - -import torch -import torch.utils.checkpoint -import torch.nn.functional as F -from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss -from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable, Dict, Any -from copy import deepcopy - -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import logging -from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput - -from .configuration_chatglm import ChatGLMConfig - -# flags required to enable jit fusion kernels - -if sys.platform != 'darwin': - torch._C._jit_set_profiling_mode(False) - torch._C._jit_set_profiling_executor(False) - torch._C._jit_override_can_fuse_on_cpu(True) - torch._C._jit_override_can_fuse_on_gpu(True) - -logger = logging.get_logger(__name__) - -_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM" -_CONFIG_FOR_DOC = "ChatGLMConfig" - -CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "THUDM/chatglm3-6b", - # See all ChatGLM models at https://huggingface.co/models?filter=chatglm -] - - -def default_init(cls, *args, **kwargs): - return cls(*args, **kwargs) - - -class InvalidScoreLogitsProcessor(LogitsProcessor): - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if torch.isnan(scores).any() or torch.isinf(scores).any(): - scores.zero_() - scores[..., 5] = 5e4 - return scores - - -class PrefixEncoder(torch.nn.Module): - """ - The torch.nn model to encode the prefix - Input shape: (batch-size, prefix-length) - Output shape: (batch-size, prefix-length, 2*layers*hidden) - """ - - def __init__(self, config: ChatGLMConfig): - super().__init__() - 
self.prefix_projection = config.prefix_projection - if self.prefix_projection: - # Use a two-layer MLP to encode the prefix - kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 - self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) - self.trans = torch.nn.Sequential( - torch.nn.Linear(kv_size, config.hidden_size), - torch.nn.Tanh(), - torch.nn.Linear(config.hidden_size, kv_size) - ) - else: - self.embedding = torch.nn.Embedding(config.pre_seq_len, - config.num_layers * config.kv_channels * config.multi_query_group_num * 2) - - def forward(self, prefix: torch.Tensor): - if self.prefix_projection: - prefix_tokens = self.embedding(prefix) - past_key_values = self.trans(prefix_tokens) - else: - past_key_values = self.embedding(prefix) - return past_key_values - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - - Returns: - A list of Tensors - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = tensor.size()[last_dim] // num_partitions - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim, original_impl=False, device=None, dtype=None): - super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) - self.register_buffer("inv_freq", inv_freq) - self.dim = dim - self.original_impl = original_impl - - def forward_impl( - self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 - ): - """Enhanced Transformer with Rotary Position Embedding. - - Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ - transformers/rope/__init__.py. MIT License: - https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 
- """ - # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.outer(seq_idx, theta).float() - - cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - - # this is to mimic the behaviour of complex32, else we will get different results - if dtype in (torch.float16, torch.bfloat16, torch.int8): - cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() - return cache - - def forward(self, max_seq_len, offset=0): - return self.forward_impl( - max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device - ) - -### INC change ### -# @torch.jit.script - -def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: - # x: [sq, b, np, hn] - sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) - rot_dim = rope_cache.shape[-2] * 2 - x, x_pass = x[..., :rot_dim], x[..., rot_dim:] - # truncate to support variable sizes - rope_cache = rope_cache[:sq] - xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) - rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) - x_out2 = torch.stack( - [ - xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], - xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], - ], - -1, - ) - x_out2 = x_out2.flatten(3) - return torch.cat((x_out2, x_pass), dim=-1) - - -class RMSNorm(torch.nn.Module): - def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): - super().__init__() - self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) - self.eps = eps - - def forward(self, hidden_states: torch.Tensor): - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - - return (self.weight * hidden_states).to(input_dtype) - - -class CoreAttention(torch.nn.Module): - def __init__(self, config: ChatGLMConfig, layer_number): - super(CoreAttention, self).__init__() - - self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - - projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. 
- self.hidden_size_per_partition = projection_size - self.hidden_size_per_attention_head = projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - self.coeff = coeff - - self.attention_dropout = torch.nn.Dropout(config.attention_dropout) - - def forward(self, query_layer, key_layer, value_layer, attention_mask): - pytorch_major_version = int(torch.__version__.split('.')[0]) - if pytorch_major_version >= 2: - query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] - if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - is_causal=True) - else: - if attention_mask is not None: - attention_mask = ~attention_mask - context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, - attention_mask) - context_layer = context_layer.permute(2, 0, 1, 3) - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.reshape(*new_context_layer_shape) - else: - # Raw attention scores - - # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, - device=query_layer.device - ) - - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - if self.attention_softmax_in_fp32: - attention_scores = attention_scores.float() - if self.coeff is not None: - attention_scores = attention_scores * self.coeff - if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: - attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], - device=attention_scores.device, dtype=torch.bool) - attention_mask.tril_() - attention_mask = ~attention_mask - if attention_mask is not None: - attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) - attention_probs = F.softmax(attention_scores, dim=-1) - attention_probs = attention_probs.type_as(value_layer) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.attention_dropout(attention_probs) - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. 
- # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - - -class SelfAttention(torch.nn.Module): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(SelfAttention, self).__init__() - self.layer_number = max(1, layer_number) - - self.projection_size = config.kv_channels * config.num_attention_heads - - # Per attention head and per partition values. - self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads - self.num_attention_heads_per_partition = config.num_attention_heads - - self.multi_query_attention = config.multi_query_attention - self.qkv_hidden_size = 3 * self.projection_size - if self.multi_query_attention: - self.num_multi_query_groups_per_partition = config.multi_query_group_num - self.qkv_hidden_size = ( - self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num - ) - self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, - bias=config.add_bias_linear or config.add_qkv_bias, - device=device, **_config_to_kwargs(config) - ) - - self.core_attention = CoreAttention(config, self.layer_number) - - # Output. - self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, - device=device, **_config_to_kwargs(config) - ) - - def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): - if self.multi_query_attention: - num_attention_heads = self.num_multi_query_groups_per_partition - else: - num_attention_heads = self.num_attention_heads_per_partition - return torch.empty( - inference_max_sequence_len, - batch_size, - num_attention_heads, - self.hidden_size_per_attention_head, - dtype=dtype, - device=device, - ) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True - ): - # hidden_states: [sq, b, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. 
- # ================================================= - # ===================== - # Query, Key, and Value - # ===================== - - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer = self.query_key_value(hidden_states) - - if self.multi_query_attention: - (query_layer, key_layer, value_layer) = mixed_x_layer.split( - [ - self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, - ], - dim=-1, - ) - query_layer = query_layer.view( - query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - key_layer = key_layer.view( - key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.view( - value_layer.size()[:-1] - + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) - ) - else: - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) - - # adjust key and value for inference - if kv_cache is not None: - cache_k, cache_v = kv_cache - key_layer = torch.cat((cache_k, key_layer), dim=0) - value_layer = torch.cat((cache_v, value_layer), dim=0) - if use_cache: - kv_cache = (key_layer, value_layer) - else: - kv_cache = None - - if self.multi_query_attention: - key_layer = key_layer.unsqueeze(-2) - key_layer = key_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - key_layer = key_layer.contiguous().view( - key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - value_layer = value_layer.unsqueeze(-2) - value_layer = value_layer.expand( - -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 - ) - value_layer = value_layer.contiguous().view( - value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) - ) - - # ================================== - # core attention computation - # ================================== - - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - - # ================= - # Output. [sq, b, h] - # ================= - - output = self.dense(context_layer) - - return output, kv_cache - - -def _config_to_kwargs(args): - common_kwargs = { - "dtype": args.torch_dtype, - } - return common_kwargs - - -class MLP(torch.nn.Module): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, config: ChatGLMConfig, device=None): - super(MLP, self).__init__() - - self.add_bias = config.add_bias_linear - - # Project to 4h. 
If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf - self.dense_h_to_4h = nn.Linear( - config.hidden_size, - config.ffn_hidden_size * 2, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def swiglu(x): - x = torch.chunk(x, 2, dim=-1) - return F.silu(x[0]) * x[1] - - self.activation_func = swiglu - - # Project back to h. - self.dense_4h_to_h = nn.Linear( - config.ffn_hidden_size, - config.hidden_size, - bias=self.add_bias, - device=device, - **_config_to_kwargs(config) - ) - - def forward(self, hidden_states): - # [s, b, 4hp] - intermediate_parallel = self.dense_h_to_4h(hidden_states) - intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] - output = self.dense_4h_to_h(intermediate_parallel) - return output - - -class GLMBlock(torch.nn.Module): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__(self, config: ChatGLMConfig, layer_number, device=None): - super(GLMBlock, self).__init__() - self.layer_number = layer_number - - self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm - - self.fp32_residual_connection = config.fp32_residual_connection - - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Layernorm on the input data. - self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # Self attention. - self.self_attention = SelfAttention(config, layer_number, device=device) - self.hidden_dropout = config.hidden_dropout - - # Layernorm on the attention output - self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - # MLP - self.mlp = MLP(config, device=device) - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, - ): - # hidden_states: [s, b, h] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. - attention_output, kv_cache = self.self_attention( - layernorm_output, - attention_mask, - rotary_pos_emb, - kv_cache=kv_cache, - use_cache=use_cache - ) - - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + layernorm_input - - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - - # MLP. - mlp_output = self.mlp(layernorm_output) - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) - output = residual + output - - return output, kv_cache - - -class GLMTransformer(torch.nn.Module): - """Transformer class.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(GLMTransformer, self).__init__() - - self.fp32_residual_connection = config.fp32_residual_connection - self.post_layer_norm = config.post_layer_norm - - # Number of layers. - self.num_layers = config.num_layers - - # Transformer layers. 
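The dense_h_to_4h projection above is deliberately twice ffn_hidden_size wide because SwiGLU splits its input into a gate half and a value half. A minimal standalone sketch of that structure follows; hidden=8 and ffn=16 are illustrative sizes, not values from ChatGLMConfig.

import torch
import torch.nn.functional as F
from torch import nn

hidden, ffn = 8, 16
dense_h_to_4h = nn.Linear(hidden, ffn * 2, bias=False)   # doubled width for the gate
dense_4h_to_h = nn.Linear(ffn, hidden, bias=False)

def swiglu(x):
    gate, value = torch.chunk(x, 2, dim=-1)              # split the doubled projection
    return F.silu(gate) * value

x = torch.randn(3, 5, hidden)                            # [s, b, h]
out = dense_4h_to_h(swiglu(dense_h_to_4h(x)))
assert out.shape == x.shape                              # back to [s, b, h]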
- def build_layer(layer_number): - return GLMBlock(config, layer_number, device=device) - - self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) - - if self.post_layer_norm: - LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm - # Final layer norm before output. - self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, - dtype=config.torch_dtype) - - self.gradient_checkpointing = False - - def _get_layer(self, layer_number): - return self.layers[layer_number] - - def forward( - self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, - use_cache: Optional[bool] = True, - output_hidden_states: Optional[bool] = False, - ): - if not kv_caches: - kv_caches = [None for _ in range(self.num_layers)] - presents = () if use_cache else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - all_self_attentions = None - all_hidden_states = () if output_hidden_states else None - for index in range(self.num_layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer = self._get_layer(index) - if self.gradient_checkpointing and self.training: - layer_ret = torch.utils.checkpoint.checkpoint( - layer, - hidden_states, - attention_mask, - rotary_pos_emb, - kv_caches[index], - use_cache - ) - else: - layer_ret = layer( - hidden_states, - attention_mask, - rotary_pos_emb, - kv_cache=kv_caches[index], - use_cache=use_cache - ) - hidden_states, kv_cache = layer_ret - if use_cache: - presents = presents + (kv_cache,) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # Final layer norm. - if self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states, presents, all_hidden_states, all_self_attentions - - -class ChatGLMPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. 
- """ - - is_parallelizable = False - supports_gradient_checkpointing = True - config_class = ChatGLMConfig - base_model_prefix = "transformer" - _no_split_modules = ["GLMBlock"] - - def _init_weights(self, module: nn.Module): - """Initialize the weights.""" - return - - def get_masks(self, input_ids, past_key_values, padding_mask=None): - batch_size, seq_length = input_ids.shape - full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) - full_attention_mask.tril_() - past_length = 0 - if past_key_values: - past_length = past_key_values[0][0].shape[0] - if past_length: - full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, - device=input_ids.device), full_attention_mask), dim=-1) - if padding_mask is not None: - full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) - if not past_length and padding_mask is not None: - full_attention_mask -= padding_mask.unsqueeze(-1) - 1 - full_attention_mask = (full_attention_mask < 0.5).bool() - full_attention_mask.unsqueeze_(1) - return full_attention_mask - - def get_position_ids(self, input_ids, device): - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) - return position_ids - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, GLMTransformer): - module.gradient_checkpointing = value - - -class Embedding(torch.nn.Module): - """Language model embeddings.""" - - def __init__(self, config: ChatGLMConfig, device=None): - super(Embedding, self).__init__() - - self.hidden_size = config.hidden_size - # Word embeddings (parallel). - self.word_embeddings = nn.Embedding( - config.padded_vocab_size, - self.hidden_size, - dtype=config.torch_dtype, - device=device - ) - self.fp32_residual_connection = config.fp32_residual_connection - - def forward(self, input_ids): - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - embeddings = words_embeddings - # Data format change to avoid explicit transposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() - # If the input flag for fp32 residual connection is set, convert for float. 
- if self.fp32_residual_connection: - embeddings = embeddings.float() - return embeddings - - -class ChatGLMModel(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): - super().__init__(config) - if empty_init: - init_method = skip_init - else: - init_method = default_init - init_kwargs = {} - if device is not None: - init_kwargs["device"] = device - self.embedding = init_method(Embedding, config, **init_kwargs) - self.num_layers = config.num_layers - self.multi_query_group_num = config.multi_query_group_num - self.kv_channels = config.kv_channels - - # Rotary positional embeddings - self.seq_length = config.seq_length - rotary_dim = ( - config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels - ) - - self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device, - dtype=config.torch_dtype) - self.encoder = init_method(GLMTransformer, config, **init_kwargs) - self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, - dtype=config.torch_dtype, **init_kwargs) - self.pre_seq_len = config.pre_seq_len - self.prefix_projection = config.prefix_projection - if self.pre_seq_len is not None: - for param in self.parameters(): - param.requires_grad = False - self.prefix_tokens = torch.arange(self.pre_seq_len).long() - self.prefix_encoder = PrefixEncoder(config) - self.dropout = torch.nn.Dropout(0.1) - - def get_input_embeddings(self): - return self.embedding.word_embeddings - - def get_prompt(self, batch_size, device, dtype=torch.half): - prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) - past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) - past_key_values = past_key_values.view( - batch_size, - self.pre_seq_len, - self.num_layers * 2, - self.multi_query_group_num, - self.kv_channels - ) - # seq_len, b, nh, hidden_size - past_key_values = self.dropout(past_key_values) - past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) - return past_key_values - - def forward( - self, - input_ids, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.BoolTensor] = None, - full_attention_mask: Optional[torch.BoolTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size, seq_length = input_ids.shape - - if inputs_embeds is None: - inputs_embeds = self.embedding(input_ids) - - if self.pre_seq_len is not None: - if past_key_values is None: - past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, - dtype=inputs_embeds.dtype) - if attention_mask is not None: - attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), - attention_mask], dim=-1) - - if full_attention_mask is None: - if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) - - # Rotary 
positional embeddings - rotary_pos_emb = self.rotary_pos_emb(self.seq_length) - if position_ids is not None: - rotary_pos_emb = rotary_pos_emb[position_ids] - else: - rotary_pos_emb = rotary_pos_emb[None, :seq_length] - rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() - - # Run encoder. - hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( - inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, - kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states - ) - - if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=presents, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - def quantize(self, weight_bit_width: int): - from .quantization import quantize - quantize(self.encoder, weight_bit_width) - return self - - -class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.max_sequence_length = config.max_length - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - self.config = config - self.quantized = False - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - model_kwargs["past_key_values"] = self._extract_past_from_model_output( - outputs, standardize_cache_format=standardize_cache_format - ) - - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 - ) - - # update position ids - if "position_ids" in model_kwargs: - position_ids = model_kwargs["position_ids"] - new_position_id = position_ids[..., -1:].clone() - new_position_id += 1 - model_kwargs["position_ids"] = torch.cat( - [position_ids, new_position_id], dim=-1 - ) - - model_kwargs["is_first_forward"] = False - return model_kwargs - - def prepare_inputs_for_generation( - self, - input_ids: torch.LongTensor, - past_key_values: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - use_cache: Optional[bool] = None, - is_first_forward: bool = True, - **kwargs - ) -> dict: - # only last token for input_ids if past is not None - if position_ids is None: - position_ids = self.get_position_ids(input_ids, device=input_ids.device) - if not is_first_forward: - if past_key_values is not None: - position_ids = position_ids[..., -1:] - input_ids = input_ids[:, -1:] - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "position_ids": position_ids, - "attention_mask": attention_mask, - "return_last_logit": True, - "use_cache": use_cache - } - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - use_cache: 
Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - return_last_logit: Optional[bool] = False, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - if return_last_logit: - hidden_states = hidden_states[-1:] - lm_logits = self.transformer.output_layer(hidden_states) - lm_logits = lm_logits.transpose(0, 1).contiguous() - - loss = None - if labels is not None: - lm_logits = lm_logits.to(torch.float32) - - # Shift so that tokens < n predict n - shift_logits = lm_logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=-100) - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - lm_logits = lm_logits.to(hidden_states.dtype) - loss = loss.to(hidden_states.dtype) - - if not return_dict: - output = (lm_logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) - - @staticmethod - def _reorder_cache( - past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor - ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - - Output shares the same memory storage as `past`. 
- """ - return tuple( - ( - layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), - layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), - ) - for layer_past in past - ) - - def process_response(self, output, history): - content = "" - history = deepcopy(history) - for response in output.split("<|assistant|>"): - metadata, content = response.split("\n", maxsplit=1) - if not metadata.strip(): - content = content.strip() - history.append({"role": "assistant", "metadata": metadata, "content": content}) - content = content.replace("[[训练时间]]", "2023年") - else: - history.append({"role": "assistant", "metadata": metadata, "content": content}) - if history[0]["role"] == "system" and "tools" in history[0]: - content = "\n".join(content.split("\n")[1:-1]) - def tool_call(**kwargs): - return kwargs - parameters = eval(content) - content = {"name": metadata.strip(), "parameters": parameters} - else: - content = {"name": metadata.strip(), "content": content} - return content, history - - @torch.inference_mode() - def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", - max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, - **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - inputs = tokenizer.build_chat_input(query, history=history, role=role) - inputs = inputs.to(self.device) - eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), - tokenizer.get_command("<|observation|>")] - outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id) - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] - response = tokenizer.decode(outputs) - history.append({"role": role, "content": query}) - response, history = self.process_response(response, history) - return response, history - - @torch.inference_mode() - def stream_chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", - past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, - logits_processor=None, return_past_key_values=False, **kwargs): - if history is None: - history = [] - if logits_processor is None: - logits_processor = LogitsProcessorList() - logits_processor.append(InvalidScoreLogitsProcessor()) - eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), - tokenizer.get_command("<|observation|>")] - gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, - "temperature": temperature, "logits_processor": logits_processor, **kwargs} - if past_key_values is None: - inputs = tokenizer.build_chat_input(query, history=history, role=role) - else: - inputs = tokenizer.build_chat_input(query, role=role) - inputs = inputs.to(self.device) - if past_key_values is not None: - past_length = past_key_values[0][0].shape[0] - if self.transformer.pre_seq_len is not None: - past_length -= self.transformer.pre_seq_len - inputs.position_ids += past_length - attention_mask = inputs.attention_mask - attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) - inputs['attention_mask'] = attention_mask - history.append({"role": role, "content": query}) - for outputs in 
self.stream_generate(**inputs, past_key_values=past_key_values, - eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, - **gen_kwargs): - if return_past_key_values: - outputs, past_key_values = outputs - outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] - response = tokenizer.decode(outputs) - if response and response[-1] != "�": - response, new_history = self.process_response(response, history) - if return_past_key_values: - yield response, new_history, past_key_values - else: - yield response, new_history - - @torch.inference_mode() - def stream_generate( - self, - input_ids, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - return_past_key_values=False, - **kwargs, - ): - batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] - - if generation_config is None: - generation_config = self.generation_config - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - model_kwargs["use_cache"] = generation_config.use_cache - bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - - has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None - if has_default_max_length and generation_config.max_new_tokens is None: - warnings.warn( - f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " - "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", - UserWarning, - ) - elif generation_config.max_new_tokens is not None: - generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length - if not has_default_max_length: - logger.warn( - f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" - f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " - "Please refer to the documentation for more information. " - "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", - UserWarning, - ) - - if input_ids_seq_length >= generation_config.max_length: - input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" - logger.warning( - f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" - f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" - " increasing `max_new_tokens`." - ) - - # 2. 
Set generation parameters if not already defined - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() - - logits_processor = self._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=input_ids, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - logits_processor=logits_processor, - ) - - stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria - ) - logits_warper = self._get_logits_warper(generation_config) - - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - scores = None - while True: - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - - next_token_logits = outputs.logits[:, -1, :] - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = nn.functional.softmax(next_token_scores, dim=-1) - if generation_config.do_sample: - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - else: - next_tokens = torch.argmax(probs, dim=-1) - # update generated ids, model inputs, and length for next step - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) - ) - if return_past_key_values: - yield input_ids, outputs.past_key_values - else: - yield input_ids - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): - break - - def quantize(self, bits: int, empty_init=False, device=None, **kwargs): - if bits == 0: - return - - from .quantization import quantize - - if self.quantized: - logger.info("Already quantized.") - return self - - self.quantized = True - - self.config.quantization_bit = bits - - self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, - **kwargs) - return self - - -class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): - super().__init__(config) - - self.num_labels = config.num_labels - self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) - - self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half) - if config.classifier_dropout is not None: - self.dropout = nn.Dropout(config.classifier_dropout) - else: - self.dropout = None - self.config = config - - if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, empty_init=True) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - full_attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, - inputs_embeds: 
Optional[torch.LongTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.transformer( - input_ids=input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - full_attention_mask=full_attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = transformer_outputs[0] - pooled_hidden_states = hidden_states[-1] - if self.dropout is not None: - pooled_hidden_states = self.dropout(pooled_hidden_states) - logits = self.classifier_head(pooled_hidden_states) - - loss = None - if labels is not None: - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze().float(), labels.squeeze()) - else: - loss = loss_fct(logits.float(), labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits.float(), labels.view(-1, self.num_labels)) - - if not return_dict: - output = (logits,) + transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_llama.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_llama.py deleted file mode 100644 index 4cd1b6e18e8..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/modeling_llama.py +++ /dev/null @@ -1,1263 +0,0 @@ -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch LLaMA model.""" -import math -import warnings -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - -from transformers.activations import ACT2FN -from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - logging, - replace_return_docstrings, -) -from transformers.utils.import_utils import is_torch_fx_available -from transformers.models.llama.configuration_llama import LlamaConfig - - -if is_flash_attn_2_available(): - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa -### INC code ### -from neural_compressor.torch.quantization.modules import Matmul, BatchMatmul, Autocast - -# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. -# It means that the function will not be traced through and simply appear as a node in the graph. -if is_torch_fx_available(): - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "LlamaConfig" - - -def _get_unpad_data(attention_mask): - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) - - -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - warnings.warn( - "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils.AttentionMaskConverter._prepare_4d_attention_mask" - ) - return AttentionMaskConverter._prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len) - - -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): - warnings.warn( - "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. 
Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask" - ) - return AttentionMaskConverter._make_causal_mask( - input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length - ) - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) - - -class LlamaRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
- """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.config.pretraining_tp > 1: - slice = self.intermediate_size // self.config.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat( - [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 - ) - up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." 
- ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) - self._init_rope() - ### INC code ### - self.matmul1 = Matmul() - self.matmul2 = Matmul() - self.cast1 = Autocast() - self.cast2 = Autocast() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" - ) - - bsz, q_len, _ = hidden_states.size() - - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 - ) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] - query_states = torch.cat(query_states, dim=-1) - - key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] - key_states = torch.cat(key_states, dim=-1) - - value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - ### INC code ### - key_states = self.cast1(key_states) - value_states = self.cast2(value_states) - # import habana_frameworks.torch.core as htcore - # htcore.mark_step() - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - ### INC code ### - attn_weights = self.matmul1(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - - ### INC code ### - attn_output = self.matmul2(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.config.pretraining_tp > 1: - attn_output = 
attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaFlashAttention2(LlamaAttention): - """ - Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # LlamaFlashAttention2 attention does not support output_attentions - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. 
(LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - # Handle the case where the model is quantized - if hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) - """ - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=self.is_causal, - ) - - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) - else: - attn_output = flash_attn_func( - query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal - ) - - return attn_output - - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. 
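The unpadding path above consumes the indices, cu_seqlens, and max_seqlen derived from the padding mask (the same bookkeeping as _get_unpad_data earlier in this file). A toy illustration with a hand-written mask; the mask values are illustrative only.

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])                      # 1 = real token, 0 = padding
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)   # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()                # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# indices    -> tensor([0, 1, 2, 4, 5]): flattened positions of the real tokens
# cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32): per-sequence offsets for varlen attention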
- attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = ( - LlamaAttention(config=config) - if not getattr(config, "_flash_attn_2_enabled", False) - else LlamaFlashAttention2(config=config) - ) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -LLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. 
- - Parameters: - config ([`LlamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaPreTrainedModel(PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -LLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
- - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. -""" - - -@add_start_docstrings( - "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", - LLAMA_START_DOCSTRING, -) -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = 
inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if getattr(self.config, "_flash_attn_2_enabled", False): - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - # embed positions - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = past_key_values[idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions, - use_cache, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class LlamaForCausalLM(LlamaPreTrainedModel): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = LlamaModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - 
return self.model - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - 
past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs - ): - if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] - - # Some generation methods already pass only the last input ID - if input_ids.shape[1] > past_length: - remove_prefix_length = past_length - else: - # Default to old behavior: keep only final ID - remove_prefix_length = input_ids.shape[1] - 1 - - input_ids = input_ids[:, remove_prefix_length:] - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - } - ) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), - ) - return reordered_past - - -@add_start_docstrings( - """ - The LLaMa Model transformer with a sequence classification head on top (linear layer). - - [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). 
- """, - LLAMA_START_DOCSTRING, -) -class LlamaForSequenceClassification(LlamaPreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = LlamaModel(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, SequenceClassifierOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - transformer_outputs = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = transformer_outputs[0] - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - sequence_lengths = -1 - else: - if input_ids is not None: - sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to( - logits.device - ) - else: - sequence_lengths = -1 - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(pooled_logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(pooled_logits, labels) - if not return_dict: - output = (pooled_logits,) + 
transformer_outputs[1:] - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/tokenization_baichuan.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/tokenization_baichuan.py deleted file mode 100644 index 5b7054d3227..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/models/tokenization_baichuan.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright 2023 Baichuan Inc. All Rights Reserved. - -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple - -import sentencepiece as spm - -from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer -from transformers.utils import logging - - -logger = logging.get_logger(__name__) - -VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} - -PRETRAINED_VOCAB_FILES_MAP = { - "vocab_file": {}, - "tokenizer_file": {}, -} -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} - - -class BaichuanTokenizer(PreTrainedTokenizer): - """ - Construct a Baichuan tokenizer. Based on byte-level Byte-Pair-Encoding. - - Args: - vocab_file (`str`): - Path to the vocabulary file. 
- """ - - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ["input_ids", "attention_mask"] - - def __init__( - self, - vocab_file, - unk_token="", - bos_token="", - eos_token="", - pad_token=None, - sp_model_kwargs: Optional[Dict[str, Any]] = None, - add_bos_token=True, - add_eos_token=False, - clean_up_tokenization_spaces=False, - **kwargs, - ): - self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token - eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token - unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token - pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token - ### INC code ### - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.Load(vocab_file) - - super().__init__( - bos_token=bos_token, - eos_token=eos_token, - unk_token=unk_token, - pad_token=pad_token, - add_bos_token=add_bos_token, - add_eos_token=add_eos_token, - sp_model_kwargs=self.sp_model_kwargs, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - **kwargs, - ) - self.vocab_file = vocab_file - self.add_bos_token = add_bos_token - self.add_eos_token = add_eos_token - #self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - #self.sp_model.Load(vocab_file) - - def __getstate__(self): - state = self.__dict__.copy() - state["sp_model"] = None - return state - - def __setstate__(self, d): - self.__dict__ = d - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.Load(self.vocab_file) - - @property - def vocab_size(self): - """Returns vocab size""" - return self.sp_model.get_piece_size() - - def get_vocab(self): - """Returns vocab as a dict""" - vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} - vocab.update(self.added_tokens_encoder) - return vocab - - def _tokenize(self, text): - """Returns a tokenized string.""" - return self.sp_model.encode(text, out_type=str) - - def _convert_token_to_id(self, token): - """Converts a token (str) in an id using the vocab.""" - return self.sp_model.piece_to_id(token) - - def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" - token = self.sp_model.IdToPiece(index) - return token - - def convert_tokens_to_string(self, tokens): - """Converts a sequence of tokens (string) in a single string.""" - current_sub_tokens = [] - out_string = "" - prev_is_special = False - for i, token in enumerate(tokens): - # make sure that special tokens are not decoded using sentencepiece model - if token in self.all_special_tokens: - if not prev_is_special and i != 0: - out_string += " " - out_string += self.sp_model.decode(current_sub_tokens) + token - prev_is_special = True - current_sub_tokens = [] - else: - current_sub_tokens.append(token) - prev_is_special = False - out_string += self.sp_model.decode(current_sub_tokens) - return out_string - - def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: - """ - Save the vocabulary and special tokens file to a directory. - - Args: - save_directory (`str`): - The directory in which to save the vocabulary. 
- - Returns: - `Tuple(str)`: Paths to the files saved. - """ - if not os.path.isdir(save_directory): - logger.error(f"Vocabulary path ({save_directory}) should be a directory") - return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] - ) - - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): - copyfile(self.vocab_file, out_vocab_file) - elif not os.path.isfile(self.vocab_file): - with open(out_vocab_file, "wb") as fi: - content_spiece_model = self.sp_model.serialized_model_proto() - fi.write(content_spiece_model) - - return (out_vocab_file,) - - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - bos_token_id = [self.bos_token_id] if self.add_bos_token else [] - eos_token_id = [self.eos_token_id] if self.add_eos_token else [] - - output = bos_token_id + token_ids_0 + eos_token_id - - if token_ids_1 is not None: - output = output + bos_token_id + token_ids_1 + eos_token_id - - return output - - def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False - ) -> List[int]: - """ - Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer `prepare_for_model` method. - - Args: - token_ids_0 (`List[int]`): - List of IDs. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - already_has_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not the token list is already formatted with special tokens for the model. - - Returns: - `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. - """ - if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True - ) - - bos_token_id = [1] if self.add_bos_token else [] - eos_token_id = [1] if self.add_eos_token else [] - - if token_ids_1 is None: - return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return ( - bos_token_id - + ([0] * len(token_ids_0)) - + eos_token_id - + bos_token_id - + ([0] * len(token_ids_1)) - + eos_token_id - ) - - def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: - """ - Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT - sequence pair mask has the following format: - - ``` - 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 - | first sequence | second sequence | - ``` - - if token_ids_1 is None, only returns the first portion of the mask (0s). - - Args: - token_ids_0 (`List[int]`): - List of ids. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). 
- """ - bos_token_id = [self.bos_token_id] if self.add_bos_token else [] - eos_token_id = [self.eos_token_id] if self.add_eos_token else [] - - output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) - - if token_ids_1 is not None: - output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) - - return output diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt deleted file mode 100644 index d3655acd742..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/requirement.txt +++ /dev/null @@ -1,7 +0,0 @@ -transformers -datasets -accelerate -SentencePiece -lm_eval==0.3.0 -openpyxl -einops diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py deleted file mode 100644 index e77ef2c6a33..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/run_llm.py +++ /dev/null @@ -1,222 +0,0 @@ -import os -os.environ["EXPERIMENTAL_WEIGHT_SHARING"] = "False" - -### USE_GAUDI2_SCALE requires PT_USE_FP8_AMAX for torch.mm/bmm, or got failure -# os.environ["USE_GAUDI2_SCALE"] = "True" -# os.environ["PT_USE_FP8_AMAX"] = "True" - -### graphs will dump to .graph_dumps folder -# os.environ["GRAPH_VISUALIZATION"] = "True" -# import shutil -# shutil.rmtree(".graph_dumps", ignore_errors=True) - -import argparse -import time -import json -import re -import torch -import habana_frameworks.torch.hpex -import torch.nn.functional as F -import deepspeed -import transformers -from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig -import habana_frameworks.torch.core as htcore - -from utils import show_msg, eval_func, init_empty_model, init_model, init_tokenizer - - -torch.set_grad_enabled(False) -htcore.hpu_set_env() -torch.device('hpu') - - -parser = argparse.ArgumentParser() -parser.add_argument( - "--model", nargs="?", default="facebook/opt-125m" -) -parser.add_argument( - "--trust_remote_code", default=True, - help="Transformers parameter: use the external repo") -parser.add_argument( - "--revision", default=None, - help="Transformers parameter: set the model hub commit number") -parser.add_argument("--dataset", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k") -parser.add_argument("--output_dir", nargs="?", default="./saved_results") -parser.add_argument("--to_graph", action="store_true") -parser.add_argument("--approach", type=str, default=None, - help="Select from ['dynamic', 'static' 'cast']") -parser.add_argument("--precision", type=str, default='fp32', - help="Select from ['fp8_e4m3', 'fp8_e5m2', 'bf16', 'fp16', 'fp32'], \ - ['bf16', 'fp16'] only work with cast approach") -parser.add_argument("--autotune", action="store_true") -parser.add_argument("--accuracy", action="store_true") -parser.add_argument("--performance", action="store_true") -parser.add_argument("--generate", action="store_true") -parser.add_argument("--skip_fp8_mm", action="store_true") -parser.add_argument("--dump_to_excel", action="store_true") -parser.add_argument("--save", action="store_true") -parser.add_argument("--load", action="store_true") -parser.add_argument("--batch_size", default=1, type=int, - help="For accuracy measurement only.") 
-parser.add_argument("--pad_max_length", default=512, type=int, - help="Pad input ids to max length.") -parser.add_argument("--calib_iters", default=100, type=int, - help="calibration iters.") -parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], \ - type=str, choices=["hellaswag", "lambada_openai", "piqa", "winogrande", "copa", - "rte", "openbookqa", "lambada_standard", "wikitext"], - help="tasks list for accuracy validation") -parser.add_argument("--limit", default=None, type=int, - help="the sample num of evaluation.") -parser.add_argument("--max_new_tokens", default=100, type=int, - help="calibration iters.") -parser.add_argument('--buckets', type=int, nargs='+', \ - help="Input length buckets to use with static_shapes", default=[256, 512]) -parser.add_argument("--local_rank", - type=int, - default=-1, - help="local_rank for distributed training on gpus") -parser.add_argument("--skip_lm_head", action="store_true") -args = parser.parse_args() - - -world_size = int(os.getenv('WORLD_SIZE', '1')) -local_rank = int(os.getenv('LOCAL_RANK', '-1')) - - -if args.load: - user_model = init_empty_model(args.model) -else: - user_model = init_model(args) -user_model.eval() - - -tokenizer = init_tokenizer(args) - - -### dynamic & static quantization ### -if args.approach in ["dynamic", "static"] and not args.load: - print("device:", next(user_model.parameters()).device) - from neural_compressor.torch.quantization import ( - quantize, autotune, FP8Config, get_default_fp8_config, TuningConfig, get_default_fp8_config_set - ) - dtype = args.precision - if args.approach == "dynamic": - from neural_compressor.torch.algorithms.habana_fp8 import quantize_dynamic - user_model = quantize_dynamic(user_model, dtype, inplace=True) - elif args.approach == "static": - qconfig = FP8Config(w_dtype=dtype, act_dtype=dtype, approach="static") - if args.skip_lm_head: - fp32_config = FP8Config(w_dtype="fp32", act_dtype="fp32") - qconfig.set_local("lm_head", fp32_config) - # dataset - from datasets import load_dataset - calib_dataset = load_dataset(args.dataset, split="train").select(range(100)) - calib_dataset = calib_dataset.shuffle(seed=42) - calib_data = [] - for examples in calib_dataset: - calib_data.append( - tokenizer( - examples["text"], - return_tensors="pt", - max_length=64, - padding="max_length", - truncation=True - ) - ) - - def calib_func(model): - for i, calib_input in enumerate(calib_data): - if i >= args.calib_iters: - break - model( - input_ids=calib_input["input_ids"].to('hpu'), - attention_mask=calib_input["attention_mask"].to('hpu'), - ) - - user_model = quantize(user_model, qconfig, calib_func, inplace=True) - # saving - print(user_model) - if args.save and local_rank in [-1, 0]: - user_model.save("saved_results") - - -if args.load: - from neural_compressor.torch.quantization import load - user_model = load("saved_results", user_model) - - -if args.approach in ["dynamic", "static"] or args.load: - # It enables weights constant folding - from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const - _mark_params_as_const(user_model) # can reduce memory allocated and speed up - _check_params_as_const(user_model) - - - -# If torch.matmul and torch.bmm are not replaced by INC module, -# Below codes can make torch.matmul and torch.bmm run on fp8 by injection. 
-if not args.skip_fp8_mm and args.precision in ['fp8_e4m3', 'fp8_e5m2']: - def replace_torch_mm_bmm(): - from neural_compressor.torch.amp.fp8.functions import fp8_matmul - torch.matmul = fp8_matmul - torch.bmm = fp8_matmul - - replace_torch_mm_bmm() - - -# inference optimization -if args.to_graph: - import habana_frameworks.torch.hpu.graphs as htgraphs - user_model = htgraphs.wrap_in_hpu_graph(user_model) - - -# dump message of HPU after quantization or reloading -show_msg() - - -### generation, performance and accuracy validation ### -if args.generate: - input_prompt = "Here is my prompt" - print("Prompt sentence:", input_prompt) - generation_config = { - "min_new_tokens": args.max_new_tokens, "max_new_tokens": args.max_new_tokens, - # "do_sample": False, "temperature": 0.9, "num_beams": 4, - } - input_tokens = tokenizer(input_prompt, return_tensors="pt").to('hpu') - eval_start = time.perf_counter() - if args.approach == "cast": - from neural_compressor.torch.amp import autocast - if args.precision == "fp8_e4m3": - dtype = torch.float8_e4m3fn - elif args.precision == "fp8_e5m2": - dtype = torch.float8_e5m2 - elif args.precision == "fp16": - dtype = torch.float16 - elif args.precision == "bf16": - dtype = torch.bfloat16 - with autocast('hpu', dtype=dtype): - outputs = user_model.generate(**input_tokens, **generation_config) - else: - outputs = user_model.generate(**input_tokens, **generation_config) - - output_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) - eval_end = time.perf_counter() - print("Generated sentence:", output_sentence) - print("Duration:", eval_end - eval_start) - - -if args.performance: - eval_start = time.perf_counter() - input_prompt = "Intel is a company which" - input_tokens = torch.ones((1, 128), dtype=torch.long).to('hpu') - generation_config = {"min_new_tokens": 100, "max_new_tokens": 100} - outputs = user_model.generate(input_tokens, **generation_config) - print("Duration of generating 100 tokens :", time.perf_counter() - eval_start) - - -if args.accuracy: - eval_func(user_model, tokenizer=tokenizer, args=args) - -# dump final message of HPU -show_msg() diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py deleted file mode 100644 index 843287cddfa..00000000000 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/habana_fp8/utils.py +++ /dev/null @@ -1,255 +0,0 @@ -import os -import re -import torch -from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer - - -world_size = int(os.getenv('WORLD_SIZE', '1')) -local_rank = int(os.getenv('LOCAL_RANK', '-1')) - - -def init_model(args): - import deepspeed - model_dtype = torch.float32 - if re.search("llama", args.model.lower()) or re.search("bloom", args.model.lower()): - if world_size > 1: - config = AutoConfig.from_pretrained(args.model) - model_dtype = torch.bfloat16 # RuntimeErrorCastToFp8V2 input must be of float or bfloat16 dtype - deepspeed.init_distributed(dist_backend="hccl") - with deepspeed.OnDevice(dtype=model_dtype, device="meta"): - user_model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype) - import tempfile - checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w") - from optimum.habana.checkpoint_utils import write_checkpoints_json # in optimum-habana - write_checkpoints_json( - args.model, - local_rank, - checkpoints_json, - token=None, - ) 
- else: - user_model = AutoModelForCausalLM.from_pretrained( - args.model, - device_map='hpu', - torch_dtype=model_dtype, - ) - elif re.search("chatglm", args.model.lower()): - from models.modeling_chatglm import ChatGLMForConditionalGeneration - user_model = ChatGLMForConditionalGeneration.from_pretrained( - args.model, - revision=args.revision, - device_map='hpu', - torch_dtype=model_dtype, - ) - # print(user_model.transformer.output_layer.weight.dtype) # always fp16 - user_model.float() # static fp8 need float32 for graph compiler - else: - user_model = AutoModelForCausalLM.from_pretrained( - args.model, - trust_remote_code=args.trust_remote_code, - revision=args.revision, - device_map='hpu', - torch_dtype=model_dtype, - ) - # load weight for multi-cards - if world_size > 1: - if re.search("llama", args.model.lower()) or re.search("bloom", args.model.lower()): - ds_inference_kwargs = {"dtype": model_dtype} - ds_inference_kwargs["tensor_parallel"] = {"tp_size": world_size} - ds_inference_kwargs["enable_cuda_graph"] = False - from transformers.models.llama.modeling_llama import LlamaDecoderLayer - ds_inference_kwargs["injection_policy"] = {LlamaDecoderLayer: ("self_attn.o_proj", "mlp.down_proj")} - ds_inference_kwargs["checkpoint"] = checkpoints_json.name - ds_model = deepspeed.init_inference(user_model, **ds_inference_kwargs) - else: - ds_model = deepspeed.init_inference(user_model, - mp_size=world_size, - replace_with_kernel_inject=False) - user_model = ds_model.module - return user_model - - -def init_empty_model(model_name): - from accelerate import init_empty_weights - model_dtype = torch.float32 - config = AutoConfig.from_pretrained(model_name) - with init_empty_weights(): - model = AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype) - return model - - -def init_tokenizer(args): - # tokenizer - if re.search("baichuan", args.model.lower()): - from models.tokenization_baichuan import BaichuanTokenizer - tokenizer = BaichuanTokenizer.from_pretrained( - args.model, - trust_remote_code=args.trust_remote_code - ) - else: - tokenizer = AutoTokenizer.from_pretrained( - args.model, - trust_remote_code=args.trust_remote_code - ) - tokenizer.pad_token = tokenizer.eos_token - return tokenizer - - -def show_msg(): - import numpy as np - import glob - from habana_frameworks.torch.hpu import memory_stats - print("Number of HPU graphs:", len(glob.glob(".graph_dumps/*PreGraph*"))) - mem_stats = memory_stats() - mem_dict = { - "memory_allocated (GB)": np.round(mem_stats["InUse"] / 1024**3, 2), - "max_memory_allocated (GB)": np.round(mem_stats["MaxInUse"] / 1024**3, 2), - "total_memory_available (GB)": np.round(mem_stats["Limit"] / 1024**3, 2), - } - for k, v in mem_dict.items(): - print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v)) - - -def itrex_bootstrap_stderr(f, xs, iters): - from lm_eval.metrics import _bootstrap_internal, sample_stddev - res = [] - chunk_size = min(1000, iters) - it = _bootstrap_internal(f, chunk_size) - for i in range(iters // chunk_size): - bootstrap = it((i, xs)) - res.extend(bootstrap) - return sample_stddev(res) - - -def save_to_excel(dict): - import pandas as pd - df_new = pd.DataFrame(dict) - try: - df_existing = pd.read_excel('output.xlsx') - except FileNotFoundError: - df_existing = pd.DataFrame() - df_combined = pd.concat([df_existing, df_new], axis=0, ignore_index=True) - df_combined.to_excel('output.xlsx', index=False, engine='openpyxl', header=True) - - -def eval_func(user_model, tokenizer, args): - import os - import re - import 
time - import json - import torch - import habana_frameworks.torch.hpex - import torch.nn.functional as F - import lm_eval - import lm_eval.tasks - import lm_eval.evaluator - - # to avoid out-of-memory caused by Popen for large language models. - lm_eval.metrics.bootstrap_stderr = itrex_bootstrap_stderr - - class HabanaModelAdapter(lm_eval.base.BaseLM): - def __init__(self, tokenizer, model, args, options): - super().__init__() - self.tokenizer = tokenizer - self.model = model.eval() - self._batch_size = args.batch_size - self.buckets = list(sorted(args.buckets)) - self.options = options - self._device = "hpu" - torch.set_grad_enabled(False) - - @property - def eot_token_id(self): - return self.model.config.eos_token_id - - @property - def max_length(self): - return self.buckets[-1] - - @property - def max_gen_toks(self): - raise NotImplementedError() - - @property - def batch_size(self): - return self._batch_size - - @property - def device(self): - # We need to do padding ourselves, otherwise we'll end up with recompilations - # Returning 'cpu' to keep tensors on CPU in lm_eval code - return 'cpu' # 'hpu' - - def tok_encode(self, string): - if ( - re.search("chatglm3", args.model.lower()) or - re.search("llama", args.model.lower()) or - re.search("mistral", args.model.lower()) - ): - string = string.lstrip() - return self.tokenizer.encode(string, add_special_tokens=False) - - def tok_decode(self, tokens): - return self.tokenizer.decode(tokens, skip_special_tokens=True) - - def _model_generate(self, context, max_length, eos_token_id): - raise NotImplementedError() - - def find_bucket(self, length): - return [b for b in self.buckets if b >= length][0] - - def _model_call(self, inputs): - seq_length = inputs.shape[-1] - padding_length = 0 - bucket_length = self.find_bucket(seq_length) - padding_length = bucket_length - seq_length - inputs = F.pad(inputs, (0, padding_length), value=self.model.config.pad_token_id) - logits = self.model(inputs.to(self._device))["logits"].cpu() - - if padding_length > 0: - logits = logits[:, :-padding_length, :] - logits = logits.to(torch.float32) - return logits - - lm_tasks = lm_eval.tasks.get_task_dict(args.tasks) - options = None - lm = HabanaModelAdapter(tokenizer, user_model, args, options) - - eval_start = time.perf_counter() - if args.approach == "cast": - from neural_compressor.torch.amp import autocast - if args.precision == "fp8_e4m3": - dtype = torch.float8_e4m3fn - elif args.precision == "fp8_e5m2": - dtype = torch.float8_e5m2 - elif args.precision == "fp16": - dtype = torch.float16 - elif args.precision == "bf16": - dtype = torch.bfloat16 - with autocast('hpu', dtype=dtype): - results = lm_eval.evaluator.evaluate(lm, lm_tasks, limit=args.limit) - else: - results = lm_eval.evaluator.evaluate(lm, lm_tasks, limit=args.limit) - print(lm_eval.evaluator.make_table(results)) - eval_end = time.perf_counter() - print("Duration:", eval_end - eval_start) - results['args'] = vars(args) - results['duration'] = eval_end - eval_start - - # make sure that result is dumped only once during multi-cards evaluation - local_rank = int(os.getenv('LOCAL_RANK', '-1')) - if local_rank in [-1, 0]: - dumped = json.dumps(results, indent=2) - accu_dict = {} - case_name = str(args.approach) + "-" + args.precision - for task_name in args.tasks: - if task_name == "wikitext": - print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]), flush=True) - accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["word_perplexity"]] 
- else: - print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]), flush=True) - accu_dict[task_name] = [args.model, case_name, results["results"][task_name]["acc"]] - accu_dict["duration"] = [args.model, case_name, results["duration"]] - if args.dump_to_excel: - save_to_excel(accu_dict) - return results["results"][task_name]["acc"] diff --git a/examples/helloworld/fp8_example/README.md b/examples/helloworld/fp8_example/README.md new file mode 100644 index 00000000000..b758768ef0f --- /dev/null +++ b/examples/helloworld/fp8_example/README.md @@ -0,0 +1,96 @@ +### Usage demo: + +#### two steps to get quantized model + +```diff +import torch ++ from neural_compressor.torch.quantization import FP8Config, convert, prepare, finalize_calibration +import habana_frameworks.torch.core as htcore + +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(10, 5) + self.fc2 = torch.nn.Linear(5, 10) + + def forward(self, inp): + x1 = self.fc1(inp) + x2 = self.fc2(x1) + return x2 + +model = M().eval() + ++ config = FP8Config.from_json_file(args.quant_config) # args.quant_config is the path of json file + ++ if config.measure: ++ model = prepare(model, config) + ++ if config.quantize: ++ htcore.hpu_initialize() ++ model = convert(model, config) + +# user code run +with torch.no_grad(): + model.to("hpu") + output = model(torch.randn(1, 10).to("hpu")) + print(output) + ++ if config.measure: ++ finalize_calibration(model) +``` + + +Whole script and config refer to [sample_two_steps.py](./sample_two_steps.py), [maxabs_measure.json](./maxabs_measure.json) and [maxabs_quant.json](./maxabs_quant.json). + +First, measure the tensor quantization statistic: +```shell +python sample_two_steps.py --quant_config=maxabs_measure.json +``` + +Then quantize the model based on previous measurements: +```shell +python sample_two_steps.py --quant_config=maxabs_quant.json +``` + +#### one step to get quantized model + +```diff +import torch ++ from neural_compressor.torch.quantization import FP8Config, convert, prepare, finalize_calibration +import habana_frameworks.torch.core as htcore + +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(10, 5) + self.fc2 = torch.nn.Linear(5, 10) + + def forward(self, inp): + x1 = self.fc1(inp) + x2 = self.fc2(x1) + return x2 + +model = M().to("hpu") + ++ config = FP8Config.from_json_file(args.quant_config) # args.quant_config is the path of json file ++ model = prepare(model, config) + +# user code run to do calibration +with torch.no_grad(): + output = model(torch.randn(1, 10).to("hpu")) + print(output) + ++ finalize_calibration(model) ++ model = convert(model) + +# user code to run benchmark for quantized model +with torch.no_grad(): + output = model(torch.randn(1, 10).to("hpu")) + print(output) +``` + +Whole script and config refer to [sample_one_step.py](./sample_one_step.py). 
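The one-step flow relies on `quant_config.json` setting `"mode": "AUTO"`; a minimal sketch of how that mode appears to map onto the `config.measure` / `config.quantize` flags checked in `sample_two_steps.py` (the attribute semantics are assumed here, not spelled out by the samples):

```python
# Sketch under assumptions: FP8Config.measure / FP8Config.quantize are taken to mirror the
# JSON "mode" field ("MEASURE", "QUANTIZE", or "AUTO" meaning both), as sample_two_steps.py implies.
import torch
from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare

model = torch.nn.Linear(10, 10).eval().to("hpu")
config = FP8Config.from_json_file("quant_config.json")    # "mode": "AUTO"

if config.measure:                         # measurement requested (MEASURE or AUTO)
    model = prepare(model, config)
    with torch.no_grad():
        model(torch.randn(1, 10).to("hpu"))    # calibration pass
    finalize_calibration(model)

if config.quantize:                        # quantization requested (QUANTIZE or AUTO)
    model = convert(model, config)
```

The sample itself is launched with: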
+ +```shell +python sample_one_step.py --quant_config=quant_config.json +``` diff --git a/examples/helloworld/fp8_example/maxabs_measure.json b/examples/helloworld/fp8_example/maxabs_measure.json new file mode 100644 index 00000000000..8d55f33e57a --- /dev/null +++ b/examples/helloworld/fp8_example/maxabs_measure.json @@ -0,0 +1,7 @@ +{ + "mode": "MEASURE", + "observer": "maxabs", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure" +} diff --git a/examples/helloworld/fp8_example/maxabs_quant.json b/examples/helloworld/fp8_example/maxabs_quant.json new file mode 100644 index 00000000000..d1f76f8f630 --- /dev/null +++ b/examples/helloworld/fp8_example/maxabs_quant.json @@ -0,0 +1,8 @@ +{ + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure" +} diff --git a/examples/helloworld/fp8_example/quant_config.json b/examples/helloworld/fp8_example/quant_config.json new file mode 100644 index 00000000000..c139d13bbea --- /dev/null +++ b/examples/helloworld/fp8_example/quant_config.json @@ -0,0 +1,8 @@ +{ + "mode": "AUTO", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "allowlist": {"types": [], "names": []}, + "blocklist": {"types": [], "names": []}, + "dump_stats_path": "./hqt_output/measure" +} diff --git a/examples/helloworld/fp8_example/sample_one_step.py b/examples/helloworld/fp8_example/sample_one_step.py new file mode 100644 index 00000000000..54a4090a833 --- /dev/null +++ b/examples/helloworld/fp8_example/sample_one_step.py @@ -0,0 +1,56 @@ +import argparse +import torch +import habana_frameworks.torch.core as htcore +htcore.hpu_set_env() + +from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare + +torch.manual_seed(1) + + +# 1. python sample_one_step.py --quant_config=quant_config.json + + +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(10, 5) + self.fc2 = torch.nn.Linear(5, 10) + + def forward(self, inp): + x1 = self.fc1(inp) + x2 = self.fc2(x1) + return x2 + + +def eval_func(model): + # user's eval func + input = torch.randn(1, 10) + model(input.to("hpu")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Habana FP8 sample code.", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--quant_config", type=str, help="json file of quantization config") + args = parser.parse_args() + + model = M().eval().to("hpu") + htcore.hpu_initialize() + + config = FP8Config.from_json_file(args.quant_config) + model = prepare(model, config) + + # for calibration + with torch.no_grad(): + # model.to("hpu") + output = model(torch.randn(1, 10).to("hpu")) + + model = convert(model) + print(model) + + # for benchmark + with torch.no_grad(): + output = model(torch.randn(1, 10).to("hpu")) + print(output) diff --git a/examples/helloworld/fp8_example/sample_two_steps.py b/examples/helloworld/fp8_example/sample_two_steps.py new file mode 100644 index 00000000000..9e17748b9b0 --- /dev/null +++ b/examples/helloworld/fp8_example/sample_two_steps.py @@ -0,0 +1,50 @@ +import argparse +import torch +import habana_frameworks.torch.core as htcore +htcore.hpu_set_env() + +from neural_compressor.torch.quantization import FP8Config, convert, finalize_calibration, prepare + +torch.manual_seed(1) + +# 1. 
python sample_two_steps.py --quant_config=maxabs_measure.json +# 2. python sample_two_steps.py --quant_config=maxabs_quant.json + + +class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.fc1 = torch.nn.Linear(10, 5) + self.fc2 = torch.nn.Linear(5, 10) + + def forward(self, inp): + x1 = self.fc1(inp) + x2 = self.fc2(x1) + return x2 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Habana FP8 sample code.", formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--quant_config", type=str, help="json file of quantization config") + args = parser.parse_args() + + model = M().eval() + config = FP8Config.from_json_file(args.quant_config) + + if config.measure: + model = prepare(model, config) + + if config.quantize: + htcore.hpu_initialize() + model = convert(model, config) + print(model) + + with torch.no_grad(): + model.to("hpu") + output = model(torch.randn(1, 10).to("hpu")) + print(output) + + if config.measure: + finalize_calibration(model) diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 09032360cc0..d54e2e6515b 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -379,8 +379,10 @@ def to_json_file(self, filename): Args: filename (str): The path to save the JSON file. """ - # Implementation details omitted for brevity - pass + config_dict = self.to_dict() + with open(filename, "w", encoding="utf-8") as file: + json.dump(config_dict, file, indent=4) + logger.info("Dump the config into %s.", filename) def to_json_string(self, use_diff: bool = False) -> str: """Serializes this instance to a JSON string. diff --git a/neural_compressor/torch/algorithms/fp8_quant/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/__init__.py new file mode 100644 index 00000000000..bea97db811c --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
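The `to_json_file` body added to `base_config.py` above makes serialization symmetric with `from_json_file`; a brief sketch of the round trip, assuming `FP8Config` inherits both helpers from `BaseConfig`:

```python
# Hedged sketch: assumes FP8Config inherits to_json_file/from_json_file from BaseConfig.
from neural_compressor.torch.quantization import FP8Config

config = FP8Config.from_json_file("maxabs_measure.json")
config.to_json_file("maxabs_measure_copy.json")     # writes self.to_dict() via json.dump(..., indent=4)
restored = FP8Config.from_json_file("maxabs_measure_copy.json")
```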
+ +from neural_compressor.torch.algorithms.fp8_quant.common import ( + update_mode, + save_calib_result, + restore_patched_module, + with_patched_module, +) +from neural_compressor.torch.algorithms.fp8_quant.prepare_quant.prepare_model import finish_measurements, prep_model +from neural_compressor.torch.algorithms.fp8_quant.fp8_quant import FP8Quantizer diff --git a/neural_compressor/torch/algorithms/habana_fp8/tensor/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/_core/__init__.py similarity index 100% rename from neural_compressor/torch/algorithms/habana_fp8/tensor/__init__.py rename to neural_compressor/torch/algorithms/fp8_quant/_core/__init__.py diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/common.py b/neural_compressor/torch/algorithms/fp8_quant/_core/common.py new file mode 100644 index 00000000000..cefe46e77f0 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/common.py @@ -0,0 +1,256 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import importlib.util +import json +import os + +import numpy as np +import torch + +from .._quant_common.helper_modules import * +from .._quant_common.quant_config import get_hqt_config +from ..utils.logger import logger + +deepspeed_exists = False +if importlib.util.find_spec("deepspeed"): # check if deepspeed is installed + deepspeed_exists = True + +UNMEASURED_MODELS = "UnmeasuredModels" + + +class ModuleInfo: + def __init__(self, type, patched_module): + self.type = type + self.patched_module = patched_module + + +class ModuleConfig: + def __init__(self, inputs=(None,), outputs=(None,), params=None): + self.inputs = inputs + self.outputs = outputs + self.params = params if params is not None else {} + + +class ModuleExtraConfig: + def __init__(self, inputs=(None,), outputs=(None,), params=None, scale=None, config_params=None): + self.inputs = inputs + self.outputs = outputs + self.params = params if params is not None else {} + self.scale = scale + self.config_params = config_params if config_params is not None else {} + + +class ModuleType: + def __init__(self, num_inputs, param_names, num_outputs, required_output): + self.num_inputs = num_inputs + self.param_names = param_names + self.num_outputs = num_outputs + self.required_output = required_output + + +mod_types = { + "linear": ModuleType(1, ["weight"], 1, False), + "matmul": ModuleType(2, [], 1, False), + "kv_cache": ModuleType(1, [], 1, False), + "softmax": ModuleType(1, [], 1, True), + "fused_sdpa": ModuleType(3, [], 2, True), +} +descale_fcn = lambda x, scale: torch.mul(x, scale) +scale_fcn = lambda x, scale: torch.div(x, scale) +cast_fcn = lambda x, dtype: x.to(dtype=dtype) +cast_to_fp8_fcn = lambda x, dtype, scale_inv=None: torch.ops.hpu.cast_to_fp8_v2(x, scale_inv, False, False, dtype)[0] +cast_from_fp8_fcn = lambda x, dtype, scale=None: torch.ops.hpu.cast_from_fp8(x, scale, dtype) + + +class ShapeList: + data = None + + +def rec_fn(x, fn): + if isinstance(x, dict): + return {k: 
rec_fn(x[k], fn) for k in x} + elif isinstance(x, list): + return [rec_fn(k, fn) for k in x] + elif isinstance(x, tuple): + return tuple([rec_fn(k, fn) for k in x]) + else: + return fn(x) + + +def save_json(d, fname): + with open(fname, "w") as f: + json.dump(d, f, indent=4) + + +def load_json(fname): + with open(fname, "r") as f: + d = json.load(f) + return d + + +def save_npz(d, fname): + np.savez(fname, d) + + +def load_npz(fname): + d = np.load(fname, allow_pickle=True) + return d["arr_0"].item() + + +def save_file(model, d, source_format, fname, mode): + config = get_hqt_config(model) + logger.debug("Saving %s file: %s", mode, fname) + ext = os.path.splitext(fname)[1] + target_format = file_functions[ext][0] + dc = rec_fn(d, format_functions[(source_format, target_format)]) + df = { + "GlobalRank": config.cfg["global_rank"], + "LocalRank": config.cfg["local_rank"], + "Mode": mode, + "Nodes": dc, + } + try: + file_functions[ext][1](df, fname) + except: + pass + + +# convert module config data to other format +def module_convert(m, fcn): + mt = ModuleConfig( + tuple([fcn(x) for x in m.inputs]), + ( + tuple( + [fcn(m.outputs)], + ) + if type(m.outputs) == np.ndarray + else tuple([fcn(y) for y in m.outputs]) + ), + {k: fcn(m.params[k]) for k in m.params}, + ) + return mt + + +def fix_fields(d): + if "input" in d: + d["inputs"] = d.pop("input") + if "output" in d: + d["outputs"] = d.pop("output") + return d + + +def load_file(fname, target_format, fail_on_file_not_exist): + logger.debug("Loading file: %s", fname) + ext = os.path.splitext(fname)[1] + source_format = file_functions[ext][0] + d = {} + if os.path.isfile(fname): + d = file_functions[ext][2](fname) + elif fail_on_file_not_exist: + raise FileNotFoundError(f"Failed to load file {fname}") + if "Nodes" in d: + dc = {k: ModuleConfig(**fix_fields(d["Nodes"][k])) for k in d["Nodes"]} + dc = {k: module_convert(dc[k], format_functions[(source_format, target_format)]) for k in dc} + else: + dc = {} + return dc + + +def save_scales(model, d, source_format, fname): + dc = {k: d[k].__dict__ for k in d} + save_file(model, dc, source_format, fname, "Scale") + + +def load_scales(fname, target_format): + logger.debug("Loading scales file %s", fname) + d = load_file(fname, target_format, False) + return d + + +def convert_scales_to_tensors_dict(scales_obj, scales_file_format, hp_dtype): + scales_temp = {k: scales_obj[k].__dict__ for k in scales_obj} + scales_temp = format_functions_rec((scales_file_format, torch.Tensor))(scales_temp) + scales_temp = rec_fn(scales_temp, lambda x: x.to(dtype=hp_dtype, device="hpu")) + scales = {k: ModuleConfig(**scales_temp[k]) for k in scales_temp} + return scales + + +file_functions = { + ".json": (list, save_json, load_json), + ".npz": (np.ndarray, save_npz, load_npz), +} + +format_functions = { + (torch.Tensor, torch.Tensor): lambda x: x, + (np.ndarray, np.ndarray): lambda x: x, + (list, list): lambda x: x, + (torch.Tensor, np.ndarray): lambda x: x.detach().cpu().float().numpy(), + (torch.Tensor, list): lambda x: x.detach().cpu().float().numpy().tolist(), + (np.ndarray, torch.Tensor): torch.tensor, + (np.ndarray, list): lambda x: x.tolist(), + (list, torch.Tensor): torch.tensor, + (list, np.ndarray): lambda x: np.array(x), + (list, ShapeList): lambda x: [int(s) for s in x[0]], +} + + +format_functions_rec = lambda k: functools.partial(rec_fn, fn=format_functions[k]) + +mod_default_dict = { + "Matmul": ModuleInfo("matmul", PatchedMatmul), + "Linear": ModuleInfo("linear", PatchedLinear), + "RowParallelLinear": 
ModuleInfo("linear", PatchedRowParallelLinear), + "ColumnParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear), + "MergedColumnParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear), + "QKVParallelLinear": ModuleInfo("linear", PatchedColumnParallelLinear), + "FalconLinear": ModuleInfo("linear", PatchedLinear), + "KVCache": ModuleInfo("kv_cache", PatchedKVCache), + "VLLMKVCache": ModuleInfo("kv_cache", PatchedVLLMKVCache), + "Conv2d": ModuleInfo("linear", PatchedConv2d), + "LoRACompatibleLinear": ModuleInfo("linear", PatchedLoRACompatibleLinear), + "LoRACompatibleConv": ModuleInfo("linear", PatchedLoRACompatibleConv), + "Softmax": ModuleInfo("softmax", PatchedSoftmax), + "ModuleFusedSDPA": ModuleInfo("fused_sdpa", PatchedModuleFusedSDPA), +} + + +if deepspeed_exists: + mod_default_dict.update( + { + "LinearLayer": ModuleInfo("linear", PatchedLinear), + "LinearAllreduce": ModuleInfo("linear", PatchedLinearAllReduce), + "ScopedLinearAllReduce": ModuleInfo("linear", PatchedLinearAllReduce), + "LmHeadLinearAllreduce": ModuleInfo("linear", PatchedLmHeadLinearAllreduce), + } + ) + + +class ModInstInfo: + def __init__(self, name, parent): + self.name = name + self.parent = parent + + +parent_child_mod_dict = {} + + +def generate_model_info(model): + def create_mod_info_recursion(parent): + for name, mod in parent.named_children(): + parent_child_mod_dict[mod] = ModInstInfo(name, parent) + create_mod_info_recursion(mod) + + create_mod_info_recursion(model) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py b/neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py new file mode 100644 index 00000000000..67ca91b7684 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/fp_utils.py @@ -0,0 +1,187 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import habana_frameworks.torch.core as htcore +import habana_frameworks.torch.utils.experimental as htexp +import torch + +from .common import * + +GAUDI2 = htexp.synDeviceType.synDeviceGaudi2 +GAUDI3 = htexp.synDeviceType.synDeviceGaudi3 + +EXP_WIDTH = { + torch.float32: 8, + torch.bfloat16: 8, + torch.float8_e4m3fn: 4, + torch.float8_e5m2: 5, +} + + +def get_default_exp_bias(dtype): + exp_width = EXP_WIDTH[dtype] + return 2 ** (exp_width - 1) - 1 + + +EXP_BIAS_SETS = { + (GAUDI2, torch.float8_e4m3fn): [3, 7, 11, 15], + (GAUDI2, torch.float8_e5m2): [15], + (GAUDI3, torch.float8_e4m3fn): range(0, 63), + (GAUDI3, torch.float8_e5m2): range(0, 63), +} + +MAX_RANGE = { + torch.float32: 2 ** ((2**8 - 2 - get_default_exp_bias(torch.float32))) * (2 - 2 ** -(23)), + torch.bfloat16: 2 ** ((2**8 - 2 - get_default_exp_bias(torch.bfloat16))) * (2 - 2 ** -(7)), + torch.float8_e4m3fn: 2 ** ((2**4 - 2 - get_default_exp_bias(torch.float8_e4m3fn))) * (2 - 2 ** -(8 - 1 - 4)), + torch.float8_e5m2: 2 ** ((2**5 - 2 - get_default_exp_bias(torch.float8_e5m2))) * (2 - 2 ** -(8 - 1 - 5)), +} + + +def get_fullscale(dtype, exp_bias=None): + default_exp_bias = get_default_exp_bias(dtype) + fullscale = MAX_RANGE[dtype] + exp_bias = default_exp_bias if exp_bias is None else exp_bias + fullscale = fullscale * (2 ** (default_exp_bias - exp_bias)) + return fullscale + + +def get_fullscales_by_expbias_set(dtype, expbias_set): + return [get_fullscale(dtype, exp_bias=eb) for eb in expbias_set] + + +def get_fp8_hw_alligned_scales(dtype, device): + exp_bias_set = EXP_BIAS_SETS.get((device, dtype), None) + return ( + None + if exp_bias_set is None + else [x / MAX_RANGE[dtype] for x in get_fullscales_by_expbias_set(dtype, exp_bias_set)] + ) + + +DEVICES_SCALE_FACTORS = { + htexp.synDeviceType.synDeviceGaudi2: 4, + htexp.synDeviceType.synDeviceGaudi3: 1, +} +FP8_143_SCALES = { + device: get_fp8_hw_alligned_scales(torch.float8_e4m3fn, device) for device in DEVICES_SCALE_FACTORS.keys() +} +FP8_143_SCALES_TRAITS = { + device: ( + min(FP8_143_SCALES[device]), + max(FP8_143_SCALES[device]), + DEVICES_SCALE_FACTORS[device], + ) + for device in DEVICES_SCALE_FACTORS.keys() +} + + +def calc_maxabs_scale(xmaxabs, fullscale, backoff=1): + scale = xmaxabs / (fullscale * backoff) + return scale + + +def scale_to_pow2(scale): + scale_pow2 = 2 ** torch.ceil(torch.log2(scale)) + return scale_pow2 + + +# Considering range of hw aligned scales: 2^a, 2^a+1,..., 2^b (a=2^b then s=2^b, therefore min(_, 2^b) +# if m<=2^a then s=2^a, therefore max(_, 2^a) --> 2^a <= min(max(_,2^a),2^b) <=2^b +# if s^a 0: + sd[mname]["params"] = dict() + sdl[mname]["params"] = dict() + for param_name in mcd[mname].params: + if mcd[mname].params[param_name].state is not None: + sd[mname]["params"][param_name] = mcd[mname].params[param_name].state.detach().cpu().float().numpy() + sdl[mname]["params"][param_name] = ( + mcd[mname].params[param_name].state.detach().cpu().float().numpy().tolist() + ) + return sd, sdl + + +def save_measurements(model, fname=None): + config = get_hqt_config(model).cfg + if config["mode"] in [QuantMode.MEASURE, QuantMode.SHAPE]: + if fname is None: + if ("measure_file" in config) and (config["measure_file"] is not None): + fname_base = config["measure_file"] + measure_type = "DynamicRange" + elif ("shape_file" in config) and (config["shape_file"] is not None) and (config["observer"] == "shape"): + fname_base = config["shape_file"] + measure_type = "Shape" + fname_np = fname_base + ".npz" + fname_list = fname_base + ".json" + else: + 
logger.warning("'fname' is not None - Measurements/Shapes will not be saved") + return + mcd = get_mod_extra_config_dict(model) + sd, sdl = measure_control_to_state_dict(mcd) + + logger.info("Dumping measurements") + save_file(model, sd, np.ndarray, fname_np, measure_type) + save_file(model, sdl, list, fname_list, measure_type) + save_json(gmod_list, fname_base + "_mod_list.json") + + +def load_measurements(model, fname): + config = get_hqt_config(model).cfg + source_fname = fname if fname is not None else config["measure_file"] + fname_np = source_fname + ".npz" + d = load_file( + fname_np, + np.ndarray, + fail_on_file_not_exist=(config["scale_method"] != ScaleMethod.UNIT_SCALE), + ) + from collections import defaultdict + + d = defaultdict(lambda: None, d) + + return d + + +def save_json(d, fname): + with open(fname, "w") as f: + json.dump(d, f, indent=4) + + +def load_json(fname): + with open(fname, "r") as f: + d = json.load(f) + return d + + +class MaxAbsObserver: + def __init__(self, name, mod, d_shape=None, params=None): + self.name = name + self.mod = mod + self.first = True + self.used = False + self.state = self.init_state_from_shape(d_shape) + + def init_state(self, x): + device = x.device + state = torch.zeros((1, 1), device=device, dtype=torch.float32) + self.shape = list(x.shape) + return state + + def init_state_from_shape(self, x_shape, device="hpu"): + state = torch.zeros((1, 1), device=device, dtype=torch.float32) + self.first = False + return state + + def update_state(self, x): + self.state.copy_(torch.maximum(torch.max(torch.abs(x)), self.state)) + + def measure(self, x): + if self.first: + self.state = self.init_state(x) + self.first = False + self.update_state(x) + self.used = True + + def is_used(self): + return self.used + + +class MaxAbsPerChannelObserver: + def __init__(self, name, mod, d_shape=None, params=None): + self.name = name + self.mod = mod + self.first = True + self.state = None + self.used = False + self.dim = params["dim"] if (params is not None) and ("dim" in params) else -1 + if d_shape is not None: + p = list(range(len(d_shape))) + self.dim = self.dim if self.dim >= 0 else len(d_shape) + self.dim + p[-1] = self.dim + p[self.dim] = len(d_shape) - 1 + self.p = p + self.state = self.init_state_from_shape(d_shape) + + def init_state(self, x): + device = x.device + Nch = x.shape[self.dim] + self.Nch = Nch + state = torch.zeros((Nch, 1), device=device, dtype=torch.float32) + self.shape = list(x.shape) + return state + + def init_state_from_shape(self, x_shape, device="hpu"): + device = device + Nch = x_shape[self.dim] + self.Nch = Nch + state = torch.zeros((Nch, 1), device=device, dtype=torch.float32) + self.first = False + return state + + def update_state(self, x): + self.state.copy_( + torch.maximum( + torch.max( + torch.abs(x.permute(self.p).reshape([-1, self.Nch])), + dim=0, + keepdim=True, + )[0].t(), + self.state, + ) + ) + + def measure(self, x): + if self.first: + self.state = self.init_state(x) + self.first = False + self.update_state(x) + self.used = True + + def is_used(self): + return self.used + + +def save_module(mod): + folder_name = os.path.join(mod.config["dump_stats_base_path"], "tensors") + os.makedirs(folder_name, exist_ok=True) + file_base_name = os.path.join(folder_name, imod_dict[mod] + "_module.pt") + torch.save(mod.state_dict(), file_base_name) + + +class SaveObserver: + def __init__(self, name, mod, d_shape=None, params=None): + self.name = name + self.mod = mod + self.first = True + self.cnt = -1 + self.folder_name = 
os.path.join(config["dump_stats_base_path"], "tensors") + os.makedirs(self.folder_name, exist_ok=True) + self.file_base_name = os.path.join(self.folder_name, imod_dict[mod] + "_" + name + "_iter") + self.state = self.init_state_from_shape(d_shape) + self.used = False + + def init_state(self, x): + device = x.device + state = torch.zeros((1, 1), device=device, dtype=torch.float32) + self.shape = list(x.shape) + return state + + def init_state_from_shape(self, x_shape, device="hpu"): + state = torch.zeros((1, 1), device=device, dtype=torch.float32) + self.first = False + return state + + def update_state(self, x): + self.cnt += 1 + torch.save(x, self.file_base_name + str(self.cnt) + ".pt") + + def measure(self, x): + self.update_state(x) + self.used = True + + def is_used(self): + return self.used + + +class ShapeObserver: + def __init__(self, name, mod, d_shape=None, params=None): + self.name = name + self.mod = mod + self.state = None + + def init_state(self, x): + device = x.device + Ndim = len(x.shape) + self.Ndim = Ndim + state = torch.tensor(x.shape, device=device, dtype=torch.int32).reshape((1, Ndim)) + return state + + def init_state_from_shape(self, x_shape, device="hpu"): + logger.info("ShapeObserver doesn't support init_state_from_shape") + return + + def update_state(self, x): + logger.info("ShapeObserver doesn't support update_state") + return + + def measure(self, x): + self.state = self.init_state(x) + + def is_used(self): + return self.state is not None + + +observer_types = { + "shape": ShapeObserver, + "maxabs": MaxAbsObserver, + "maxabs_per_channel": MaxAbsPerChannelObserver, + "save": SaveObserver, +} + +observer_params = { + "maxabs_per_channel": { + "linear": ModuleConfig(({"dim": -1},), ({"dim": -1},), {"weight": {"dim": 0}}), + "matmul": ModuleConfig( + ( + {"dim": -1}, + {"dim": -2}, + ), + ({"dim": -1},), + None, + ), + } +} diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py b/neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py new file mode 100644 index 00000000000..0f32be4b00c --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/quant_dequant.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
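Note (not part of the patch): the core of MaxAbsObserver.measure() is a running maximum of |x| across calibration batches; the sketch below mirrors that update rule with plain CPU tensors (the observer itself allocates its state on the HPU device). The batches are random stand-ins.

import torch

state = torch.zeros((1, 1))                                      # same (1, 1) state the observer keeps
for x in (torch.randn(4, 8), torch.randn(4, 8)):                 # stand-in calibration activations
    state.copy_(torch.maximum(torch.max(torch.abs(x)), state))   # same logic as update_state()
print(state)                                                     # the dynamic range later dumped by save_measurements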
+ +from abc import abstractmethod + +import torch.nn as nn + +from .common import * + + +class QuantDequantBase(nn.Module): + def __init__(self, lp_dtype, hp_dtype="", *args, **kwargs): + super(QuantDequantBase, self).__init__(*args, **kwargs) + self.lp_dtype = lp_dtype + self.hp_dtype = hp_dtype + + @abstractmethod + def forward(self, *args, **kwargs): + pass + + def extra_repr(self) -> str: + return f"lp_dtype={self.lp_dtype}, hp_dtype={self.hp_dtype}" + + +class QuantDequantNone(QuantDequantBase): + def __init__(self, lp_dtype, hp_dtype, *args, **kwargs): + super(QuantDequantNone, self).__init__(lp_dtype, hp_dtype, *args, **kwargs) + + def forward(self, *args, **kwargs): + return args[0] + + def extra_repr(self) -> str: + repr = super(QuantDequantNone, self).extra_repr() + return f"{repr}, doesn't quantize nor dequantize" + + +class QuantInput(QuantDequantBase): + def __init__(self, scale_inv, lp_dtype, hp_dtype, *args, **kwargs): + super(QuantInput, self).__init__(lp_dtype, hp_dtype, *args, **kwargs) + self.scale_inv = nn.Parameter(scale_inv) + + def forward(self, x): + return cast_to_fp8_fcn(x, self.lp_dtype, self.scale_inv) + + def extra_repr(self) -> str: + repr = super(QuantInput, self).extra_repr() + return f"{repr}, scale_inv dtype={self.scale_inv.dtype}" + + +class DequantOutput(QuantDequantBase): + def __init__(self, scale, lp_dtype, hp_dtype, *args, **kwargs): + super(DequantOutput, self).__init__(lp_dtype, hp_dtype, *args, **kwargs) + self.scale = nn.Parameter(scale) + + def forward(self, x): + return cast_from_fp8_fcn(x, self.hp_dtype, self.scale) + + def extra_repr(self) -> str: + repr = super(DequantOutput, self).extra_repr() + return f"{repr}, scale dtype={self.scale.dtype}" diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py b/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py new file mode 100644 index 00000000000..efe412cc16c --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/quantize.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
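Note (not part of the patch): QuantInput and DequantOutput are thin wrappers around the Habana fp8 cast ops. The sketch below emulates the intended round trip with plain PyTorch (recent versions expose torch.float8_e4m3fn on CPU), assuming the HPU casts scale by scale_inv on the way in and by scale on the way out; the scale value is made up.

import torch

scale = torch.tensor(0.5)                               # hypothetical per-tensor scale
x = torch.randn(2, 4, dtype=torch.bfloat16)
x_lp = (x.float() / scale).to(torch.float8_e4m3fn)      # roughly what QuantInput(scale_inv=1/scale) emits
x_hp = (x_lp.float() * scale).to(torch.bfloat16)        # roughly what DequantOutput(scale) recovers
print((x - x_hp).abs().max())                           # quantization error, bounded by the fp8 step size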
+ +import habana_frameworks.torch.core as htcore +import torch +import torch.nn as nn + +from .._quant_common.helper_modules import PatchedUnmeasuredModule +from .._quant_common.quant_config import get_hqt_config +from ..utils.logger import logger +from .common import UNMEASURED_MODELS, generate_model_info, mod_default_dict, parent_child_mod_dict +from .measure import load_measurements +from .scale import get_config, scale_method_mapping, scaling_methods + + +def patch_module(mod, qconfig, mod_dict, patched_mod=None): + parent = parent_child_mod_dict[mod].parent + name = parent_child_mod_dict[mod].name + if patched_mod is None: + patched_mod = mod_dict[mod.__class__.__name__].patched_module(mod, qconfig) + setattr(parent, name, patched_mod) + + +def apply_hf_hook(module): + if hasattr(module, "_hf_hook"): + module._hf_hook.pre_forward(module) + module._hf_hook.detach_hook(module) + delattr(module, "_hf_hook") + if hasattr(module, "_old_forward"): + module.forward = module._old_forward + delattr(module, "_old_forward") + + +def quantize_params(mod, mod_extra_config): + for param_name in mod_extra_config.params: + quantizer = mod_extra_config.params[param_name] + param = getattr(mod, param_name) + quantized_param = quantizer(param.to("hpu")) + delattr(mod, param_name) + setattr(mod, param_name, nn.Parameter(quantized_param)) + quantized_param = getattr(mod, param_name) + quantized_param.requires_grad_(False) + htcore.mark_step() + + +def prepare_model(model, qconfig, mod_list, hp_dtype=torch.float): + config = get_hqt_config(model) + patched_modules = [] + patched_module_types = set() + with torch.no_grad(): + for name, mod in model.named_modules(): + if name in qconfig[UNMEASURED_MODELS]: + if not config.cfg["ignore_modules_wo_measures"]: + patch_module(mod, None, None, PatchedUnmeasuredModule(name)) + else: + logger.debug("Module %s was not quantized.", name) + continue + # When offloading weight to disk, need to transfer the weight from disk to cpu using hf_hook + apply_hf_hook(mod) + if name in mod_list: + mod_extra_config = qconfig[name] + quantize_params(mod, mod_extra_config) + patch_module(mod, mod_extra_config, mod_default_dict) + patched_modules.append(name) + patched_module_types.add(type(mod)) + logger.debug("Patched module types: %s", patched_module_types) + logger.debug("Patched modules: %s", patched_modules) + logger.debug("Total patched modules: %d", len(patched_modules)) + model = model.to("hpu") + htcore.mark_step() + + +def quantize(model, mod_list): + config = get_hqt_config(model) + generate_model_info(model) + hp_dtype = config.cfg["hp_dtype"] + lp_dtype = config.cfg["fp8_config"] + measurement = load_measurements(model, config.cfg["measure_file"]) + # FIXME make sure this takes unit_scale or measured scale, from Configs + scaling_method_name = scale_method_mapping[(config.cfg["scale_method"], config.cfg["observer"])] + scaling_method = scaling_methods[scaling_method_name] + params = config.cfg["scale_params"] + params["hp_dtype"] = hp_dtype + params["lp_dtype"] = lp_dtype + qconfig = get_config( + model, + measurement, + mod_default_dict, + scaling_method, + params, + config.cfg["scale_file"], + False, + mod_list, + ) + prepare_model(model, qconfig, mod_list, hp_dtype=hp_dtype) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py new file mode 100644 index 00000000000..67491b42e8e --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale.py @@ -0,0 +1,439 @@ +# 
Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch + +from .._quant_common.quant_config import ScaleMethod, set_hqt_config +from ..utils.logger import logger +from .common import * +from .fp_utils import * +from .quant_dequant import * +from .scale_methods import * + + +def matmul_scales_to_mod_config(mod, scales, params): + scales_inv = invert_scales(scales) + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + input_config = [QuantInput(s_inv, lp_dtype, hp_dtype) for s_inv in scales_inv.inputs] + # outputs as bf16, and descaled in gemm under PatchedMatmul, so no need to work here + output_config = [QuantDequantNone(lp_dtype, hp_dtype)] + config = ModuleConfig(input_config, output_config, {}) + return config + + +def fsdpa_scales_to_mod_config(mod, scales, params): + scales_inv = invert_scales(scales) + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + input_config = [QuantInput(s_inv, lp_dtype, hp_dtype) for s_inv in scales_inv.inputs] + output_config = [DequantOutput(scales.outputs[0], lp_dtype, hp_dtype)] + config = ModuleConfig(input_config, output_config, {}) + return config + + +def linear_scales_to_mod_config(mod, scales, params): + scales_inv = invert_scales(scales) + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + input_config = [QuantInput(scales_inv.inputs[0], lp_dtype, hp_dtype)] + # outputs as bf16, and descaled in gemm under PatchedLinear, so no need to work here + output_config = [QuantDequantNone(lp_dtype, hp_dtype)] + if isinstance(scales_inv.params["weight"], (torch.Tensor, float)): + weight_config = QuantInput(scales_inv.params["weight"], lp_dtype, hp_dtype) + elif isinstance(scales_inv.params["weight"], dict): + weight_scale_inv_out_ch = scales_inv.params["weight"][0] + weight_scale_inv_in_ch = scales_inv.params["weight"][1] + if isinstance(weight_scale_inv_out_ch, torch.Tensor): + scale_inv = torch.mul( + weight_scale_inv_in_ch.reshape([1, -1]), + weight_scale_inv_out_ch.reshape([-1, 1]), + ) + else: + # TODO SW-169781: Handle here scalar weight for PCQ + raise TypeError(f"Unknown weight scales type: {type(weight_scale_inv_out_ch)}.") + weight_config = QuantInput(scale_inv, lp_dtype, hp_dtype) + else: + logger.error("Unknown weight scales format.") + params_config = {"weight": weight_config} + if hasattr(mod, "bias") and (getattr(mod, "bias") is not None): + # In PatchedLinear the bias is added to the output of gemm. + # The output is expected to be descaled and in bf16, so we don't need to touch the bias. 
+ bias_config = QuantDequantNone(lp_dtype, hp_dtype) + params_config.update({"bias": bias_config}) + config = ModuleConfig(input_config, output_config, params_config) + return config + + +def kv_cache_scales_to_mod_config(mod, scales, params): + # how quant/dequant will be applied on layer tensors + scales_inv = invert_scales(scales) + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + input_config = [QuantInput(scales_inv.inputs[0], lp_dtype, hp_dtype)] + output_config = [DequantOutput(scales.outputs[0], lp_dtype, hp_dtype)] + config = ModuleConfig(input_config, output_config) + return config + + +def softmax_scales_to_mod_config(mod, scales, params): + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + output_config = [DequantOutput(scales.outputs[0], lp_dtype, hp_dtype)] + return ModuleConfig(None, output_config) + + +def get_config( + model, + measurement, + mod_dict, + method, + params, + scales_file=None, + recalc_scales=False, + mod_list=None, +): + with torch.no_grad(): + top_level_config = get_hqt_config(model) + qconfig = {UNMEASURED_MODELS: []} + scales_file_format = np.ndarray # file_functions[os.path.splitext(scales_file)[1]][0] + scales_obj = ( + load_scales(scales_file + ".npz", scales_file_format) + if (scales_file is not None) and not recalc_scales + else {} + ) + scales = convert_scales_to_tensors_dict(scales_obj, scales_file_format, params["hp_dtype"]) + model_dict = dict(model.named_modules()) + for mname in mod_list: + mod = model_dict[mname] + set_hqt_config(mod, top_level_config) # set config in the module, as it consumed by the patched module + mod_type_str = mod.__class__.__name__ + layer_type = mod_dict[mod_type_str].type + if mname not in scales: + logger.debug("Calculating scales for layer %s", mname) + if mname not in measurement: + qconfig[UNMEASURED_MODELS].append(mname) + logger.debug( + "Layer '%s' has no measurements therefore it can't be quantized.", + mname, + ) + continue + layer_measure = measurement[mname] # ModuleConfig() of measurements + scales[mname] = method[layer_type][0](mod, layer_measure, params) # ModuleConfig() of scales + if scales_file is not None: + scales_obj[mname] = ModuleConfig( + **format_functions_rec((torch.Tensor, scales_file_format))(scales[mname].__dict__) + ) + + logger.debug( + "Preparing quantization functions for layer %s layer_type=%s", + mname, + layer_type, + ) + mod_config = method[layer_type][1](mod, scales[mname], params) # ModuleConfig() of QuantDequant + mod_extra_config = ModuleExtraConfig( + mod_config.inputs, + mod_config.outputs, + mod_config.params, + scales[mname], + params, + ) + qconfig[mname] = mod_extra_config + if scales_file is not None: + save_scales(model, scales_obj, scales_file_format, scales_file + ".npz") + save_scales(model, scales_obj, scales_file_format, scales_file + ".json") + return qconfig + + +scaling_methods = { + "unit_scale": { + "linear": (linear_unit_scale_scales, linear_scales_to_mod_config), + "matmul": (matmul_unit_scale_scales, matmul_scales_to_mod_config), + "softmax": (softmax_unit_scale_scales, softmax_scales_to_mod_config), + "kv_cache": (kv_cache_unit_scale_scales, kv_cache_scales_to_mod_config), + "fused_sdpa": (fsdpa_unit_scale_scales, fsdpa_scales_to_mod_config), + }, + "act_maxabs_pts_weight_maxabs_pts_pow2_hw": { + "linear": ( + linear_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + "kv_cache": ( + 
kv_cache_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + kv_cache_scales_to_mod_config, + ), + "softmax": ( + softmax_input_unit_output_maxabs_pts_hw_scales, + softmax_scales_to_mod_config, + ), + "fused_sdpa": ( + fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + fsdpa_scales_to_mod_config, + ), + }, + "act_maxabs_pts_weight_maxabs_pts_pow2": { + "linear": ( + linear_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + matmul_scales_to_mod_config, + ), + }, + "act_maxabs_pts_pow2_hw_weights_maxabs_pcs_pow2": { + "linear": ( + linear_act_maxabs_pts_pow2_hw_weights_maxabs_pcs_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + # kv_cache is pts as op in hw doesn't work in pcs + "kv_cache": ( + kv_cache_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + kv_cache_scales_to_mod_config, + ), + "fused_sdpa": ( + fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + fsdpa_scales_to_mod_config, + ), + }, + "act_maxabs_pts_weight_opt_pts_pow2": { + "linear": ( + linear_act_maxabs_pts_weight_opt_pts_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + matmul_scales_to_mod_config, + ), + }, + "act_maxabs_pts_weight_opt_pts_hw": { + "linear": ( + linear_act_maxabs_pts_weight_opt_pts_hw_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + "softmax": ( + softmax_input_unit_output_maxabs_pts_hw_scales, + softmax_scales_to_mod_config, + ), + "fused_sdpa": ( + fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + fsdpa_scales_to_mod_config, + ), + }, + "act_maxabs_pts_pow2_hw_weights_opt_pcs_pow2": { + "linear": ( + linear_act_maxabs_pts_pow2_hw_weights_opt_pcs_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + # kv_cache is pts as op in hw doesn't work in pcs + "kv_cache": ( + kv_cache_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + kv_cache_scales_to_mod_config, + ), + "fused_sdpa": ( + fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + fsdpa_scales_to_mod_config, + ), + }, + "act_maxabs_pts_pow2_weights_maxabs_pcs_pow2": { + "linear": ( + linear_act_maxabs_pts_pow2_weights_maxabs_pcs_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + matmul_scales_to_mod_config, + ), + # kv_cache is pts as op in hw doesn't work in pcs + "kv_cache": ( + kv_cache_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + kv_cache_scales_to_mod_config, + ), + "fused_sdpa": ( + fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + fsdpa_scales_to_mod_config, + ), + }, + "act_maxabs_pts_pow2_weights_opt_pcs_pow2": { + "linear": ( + linear_act_maxabs_pts_pow2_weights_opt_pcs_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + matmul_scales_to_mod_config, + ), + # kv_cache is pts as op in hw doesn't work in pcs + "kv_cache": ( + kv_cache_act_maxabs_pts_pow2_weight_opt_pcs_pow2_scales, + kv_cache_scales_to_mod_config, + ), + "fused_sdpa": ( + fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_scales, + fsdpa_scales_to_mod_config, + ), + }, + "smoothquant_weights_opt_pow2": { + "linear": ( + 
linear_smoothquant_weights_opt_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + }, + "smoothquant_weights_maxabs_pow2": { + "linear": ( + linear_smoothquant_weights_maxabs_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + }, + "weaksmoothquant_weights_maxabs_pow2": { + "linear": ( + linear_weaksmoothquant_weights_maxabs_pow2_scales, + linear_scales_to_mod_config, + ), + "matmul": ( + matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales, + matmul_scales_to_mod_config, + ), + }, +} + +scale_method_mapping = { + (ScaleMethod.UNIT_SCALE, "maxabs"): "unit_scale", + (ScaleMethod.UNIT_SCALE, "maxabs_per_channel"): "unit_scale", + (ScaleMethod.MAXABS_HW, "maxabs"): "act_maxabs_pts_weight_maxabs_pts_pow2_hw", + (ScaleMethod.MAXABS_POW2, "maxabs"): "act_maxabs_pts_weight_maxabs_pts_pow2", + (ScaleMethod.MAXABS_HW_OPT_WEIGHT, "maxabs"): "act_maxabs_pts_weight_opt_pts_hw", + ( + ScaleMethod.MAXABS_POW2_OPT_WEIGHT, + "maxabs", + ): "act_maxabs_pts_weight_opt_pts_pow2", + ( + ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2, + "maxabs", + ): "act_maxabs_pts_pow2_hw_weights_maxabs_pcs_pow2", + ( + ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2, + "maxabs_per_channel", + ): "act_maxabs_pts_pow2_hw_weights_maxabs_pcs_pow2", + ( + ScaleMethod.SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2, + "maxabs_per_channel", + ): "smoothquant_weights_maxabs_pow2", + ( + ScaleMethod.WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2, + "maxabs_per_channel", + ): "weaksmoothquant_weights_maxabs_pow2", + ( + ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2, + "maxabs", + ): "act_maxabs_pts_pow2_hw_weights_opt_pcs_pow2", + ( + ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2, + "maxabs_per_channel", + ): "act_maxabs_pts_pow2_hw_weights_opt_pcs_pow2", + ( + ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2, + "maxabs", + ): "act_maxabs_pts_pow2_weights_maxabs_pcs_pow2", + ( + ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2, + "maxabs_per_channel", + ): "act_maxabs_pts_pow2_weights_maxabs_pcs_pow2", + ( + ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2, + "maxabs", + ): "act_maxabs_pts_pow2_weights_opt_pcs_pow2", + (ScaleMethod.SMOOTHQUANT_OPT, "maxabs_per_channel"): "smoothquant_weights_opt_pow2", +} + +scaling_params = { + "unit_scale": {}, + "act_maxabs_pts_weight_maxabs_pts_pow2_hw": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + }, + "act_maxabs_pts_weight_maxabs_pts_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + }, + "act_maxabs_pts_weight_opt_pts_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "weight_scales": [2.0**s for s in range(-10, 10)], + }, + "act_maxabs_pts_weight_opt_pts_hw": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "weight_scales": [2.0**s for s in [4, 0, -4, -8]], + }, + "smoothquant_weights_maxabs_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "alpha": 0.5, + }, + "weaksmoothquant_weights_maxabs_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "alpha": 0.5, + }, + "act_maxabs_pts_pow2_hw_weights_maxabs_pcs_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + }, + "act_maxabs_pts_pow2_hw_weights_opt_pcs_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "weight_scales": [2.0**s for s in range(-3, 5)], + }, + "act_maxabs_pts_pow2_weights_maxabs_pcs_pow2": { + "input_backoff": 0.25, + 
"weight_backoff": 0.5, + }, + "act_maxabs_pts_pow2_weights_opt_pcs_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "weight_scales": [2.0**s for s in range(-3, 5)], + }, + "smoothquant_weights_opt_pow2": { + "input_backoff": 0.25, + "weight_backoff": 0.5, + "alpha": 0.5, + "transformed_weight_scales": [2.0**s for s in range(-3, 5)], + }, +} diff --git a/neural_compressor/torch/algorithms/habana_fp8/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/__init__.py similarity index 85% rename from neural_compressor/torch/algorithms/habana_fp8/__init__.py rename to neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/__init__.py index fe3a05d7d0b..23d3c7686d4 100644 --- a/neural_compressor/torch/algorithms/habana_fp8/__init__.py +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .fp8_quant import quantize_dynamic, quantize, white_list -from .save_load import save, load +from .max_abs import * +from .unit_scale import * +from .smooth_quant import * diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/max_abs.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/max_abs.py new file mode 100644 index 00000000000..fd295a07374 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/max_abs.py @@ -0,0 +1,411 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +from ..common import * +from ..fp_utils import * + + +def linear_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + weight_scale = calc_maxabs_scale( + torch.max(torch.abs(mod.weight.detach())).to(dtype=hp_dtype, device=device), + fullscale, + weight_backoff, + ) + input_scale = scale_to_pow2_hw(input_scale, device_type=config["device_type"]) + weight_scale = scale_to_pow2_hw(weight_scale, device_type=config["device_type"]) + output_scale = input_scale * weight_scale + return ModuleConfig((input_scale,), (output_scale,), {"weight": weight_scale}) + + +def linear_act_maxabs_pts_weight_maxabs_pts_pow2_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + weight_scale = calc_maxabs_scale( + torch.max(torch.abs(mod.weight.detach())).to(dtype=hp_dtype, device=device), + fullscale, + weight_backoff, + ) + input_scale = scale_to_pow2(input_scale) + weight_scale = scale_to_pow2(weight_scale) + output_scale = input_scale * weight_scale + return ModuleConfig((input_scale,), (output_scale,), {"weight": weight_scale}) + + +def matmul_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + input_scale = [ + calc_maxabs_scale( + torch.tensor(x, dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + for x in measurement.inputs + ] + input_scale = [scale_to_pow2_hw(x, device_type=config["device_type"]) for x in input_scale] + output_scale = [input_scale[0] * input_scale[1]] + return ModuleConfig(input_scale, output_scale, {}) + + +def matmul_act_maxabs_pts_weight_maxabs_pts_pow2_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + input_scale = [ + calc_maxabs_scale( + torch.tensor(x, dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + for x in measurement.inputs + ] + input_scale = [scale_to_pow2(x) for x in input_scale] + output_scale = [input_scale[0] * input_scale[1]] + return ModuleConfig(input_scale, output_scale, {}) + + +def fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + input_scale = [ + calc_maxabs_scale( + torch.tensor(x, dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + for x in 
measurement.inputs + ] + # add amax scale to input scales + input_scale.append( + calc_maxabs_scale( + torch.tensor(measurement.outputs[1], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + ) + input_scale = [scale_to_pow2_hw(x, device_type=config["device_type"]) for x in input_scale] + output_scale = calc_maxabs_scale( + torch.tensor(measurement.outputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + output_scale = [scale_to_pow2_hw(output_scale, device_type=config["device_type"])] + return ModuleConfig(input_scale, output_scale, {}) + + +def fsdpa_act_maxabs_pts_weight_maxabs_pts_pow2_scales(mod, measurement, params): + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + input_scale = [ + calc_maxabs_scale( + torch.tensor(x, dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + for x in measurement.inputs + ] + # fsdpa is combined out of - BMM1(Q,K) -> Softmax -> BMM2(AMAX,V) + # during measure we receive the amax value from the cguid and apply it during quant as input + input_scale.append( + calc_maxabs_scale( + torch.tensor(measurement.outputs[1], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + ) + input_scale = [scale_to_pow2(x) for x in input_scale] + output_scale = calc_maxabs_scale( + torch.tensor(measurement.outputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + output_scale = [scale_to_pow2(output_scale)] + return ModuleConfig(input_scale, output_scale, {}) + + +def linear_act_maxabs_pts_weight_opt_pts_pow2_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + scales = params["weight_scales"] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + weight_scale = mmse_scale(mod.weight, scales, lp_dtype, hp_dtype) + input_scale = scale_to_pow2(input_scale) + weight_scale = scale_to_pow2(weight_scale) + output_scale = input_scale * weight_scale + return ModuleConfig((input_scale,), (output_scale,), {"weight": weight_scale}) + + +def linear_act_maxabs_pts_weight_opt_pts_hw_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + scales = params["weight_scales"] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + weight_scale = mmse_scale(mod.weight, scales, lp_dtype, hp_dtype) + input_scale = scale_to_pow2_hw(input_scale, device_type=config["device_type"]) + weight_scale = scale_to_pow2_hw(weight_scale, device_type=config["device_type"]) + output_scale = input_scale * weight_scale + return ModuleConfig((input_scale,), (output_scale,), {"weight": weight_scale}) + + +def kv_cache_act_maxabs_pts_weight_maxabs_pts_pow2_hw_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + # calc the scale per layer tensor + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = 
params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + input_scale_list = [scale_to_pow2_hw(input_scale, device_type=config["device_type"])] + output_scale = [input_scale_list[0]] # output scale is same as the first input (current data) since range is same + return ModuleConfig(input_scale_list, output_scale, {}) + + +def kv_cache_act_maxabs_pts_pow2_weight_opt_pcs_pow2_scales(mod, measurement, params): + # calc the scale per layer tensor + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + input_scale_list = [scale_to_pow2(input_scale)] + output_scale = [input_scale_list[0]] # output scale is same as the first input (current data) since range is same + return ModuleConfig(input_scale_list, output_scale, {}) + + +def softmax_input_unit_output_maxabs_pts_hw_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + output_scale = calc_maxabs_scale( + torch.tensor(measurement.outputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + output_scale = [scale_to_pow2_hw(output_scale, device_type=config["device_type"])] + return ModuleConfig((), output_scale, {}) + + +def linear_act_maxabs_pts_pow2_hw_weights_maxabs_pcs_pow2_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + input_scale = scale_to_pow2_hw(input_scale, device_type=config["device_type"]) + weight_scale_in_ch = torch.ones([mod.weight.shape[1], 1], dtype=hp_dtype, device=device) + + weight_range_out_ch = torch.max(torch.abs(mod.weight), dim=1)[0].reshape([-1, 1]) + weight_maxabs_scale_out_ch = calc_maxabs_scale(weight_range_out_ch, fullscale, weight_backoff) + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) + output_scale = weight_maxabs_scale_out_ch * input_scale + return ModuleConfig( + (input_scale.flatten(),), + (output_scale.flatten(),), + { + "weight": { + 0: weight_maxabs_scale_out_ch.flatten(), + 1: weight_scale_in_ch.flatten(), + } + }, + ) + + +def linear_act_maxabs_pts_pow2_weights_maxabs_pcs_pow2_scales(mod, measurement, params): + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + input_scale = scale_to_pow2(input_scale) + weight_scale_in_ch = torch.ones([mod.weight.shape[1], 1], dtype=hp_dtype, device=device) + + weight_range_out_ch = torch.max(torch.abs(mod.weight), 
dim=1)[0].reshape([-1, 1]) + weight_maxabs_scale_out_ch = calc_maxabs_scale(weight_range_out_ch, fullscale, weight_backoff) + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) + output_scale = weight_maxabs_scale_out_ch * input_scale + return ModuleConfig( + (input_scale.flatten(),), + (output_scale.flatten(),), + { + "weight": { + 0: weight_maxabs_scale_out_ch.flatten(), + 1: weight_scale_in_ch.flatten(), + } + }, + ) + + +def linear_act_maxabs_pts_pow2_hw_weights_opt_pcs_pow2_scales(mod, measurement, params): + config = get_hqt_config(mod).cfg + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + weight_scales = params["weight_scales"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + input_scale = scale_to_pow2_hw(input_scale, device_type=config["device_type"]) + weight_scale_in_ch = torch.ones([mod.weight.shape[1], 1], dtype=hp_dtype, device=device) + + weight_range_out_ch = torch.max(torch.abs(mod.weight), dim=1)[0].reshape([-1, 1]) + weight_maxabs_scale_out_ch = calc_maxabs_scale(weight_range_out_ch, fullscale, weight_backoff) + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) + weight_opt_scale_out_ch = mmse_scale_multi( + torch.transpose(mod.weight, 0, 1), + weight_maxabs_scale_out_ch.squeeze(), + weight_scales, + lp_dtype, + hp_dtype, + ).unsqueeze(1) + weight_maxabs_scale_out_ch = weight_opt_scale_out_ch + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) # should be power of 2, just making sure + output_scale = weight_maxabs_scale_out_ch * input_scale + return ModuleConfig( + (input_scale.flatten(),), + (output_scale.flatten(),), + { + "weight": { + 0: weight_maxabs_scale_out_ch.flatten(), + 1: weight_scale_in_ch.flatten(), + } + }, + ) + + +def linear_act_maxabs_pts_pow2_weights_opt_pcs_pow2_scales(mod, measurement, params): + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + weight_scales = params["weight_scales"] + input_scale = calc_maxabs_scale( + torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max(), + fullscale, + input_backoff, + ) + input_scale = scale_to_pow2(input_scale) + weight_scale_in_ch = torch.ones([mod.weight.shape[1], 1], dtype=hp_dtype, device=device) + + weight_range_out_ch = torch.max(torch.abs(mod.weight), dim=1)[0].reshape([-1, 1]) + weight_maxabs_scale_out_ch = calc_maxabs_scale(weight_range_out_ch, fullscale, weight_backoff) + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) + weight_opt_scale_out_ch = mmse_scale_multi( + torch.transpose(mod.weight, 0, 1), + weight_maxabs_scale_out_ch.squeeze(), + weight_scales, + lp_dtype, + hp_dtype, + ).unsqueeze(1) + weight_maxabs_scale_out_ch = weight_opt_scale_out_ch + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) # should be power of 2, just making sure + output_scale = weight_maxabs_scale_out_ch * input_scale + return ModuleConfig( + (input_scale.flatten(),), + (output_scale.flatten(),), + { + "weight": { + 0: weight_maxabs_scale_out_ch.flatten(), + 1: weight_scale_in_ch.flatten(), + } + }, + ) diff --git 
a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/smooth_quant.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/smooth_quant.py new file mode 100644 index 00000000000..0c3e5f8cd67 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/smooth_quant.py @@ -0,0 +1,132 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from tqdm import tqdm + +from ..common import * +from ..fp_utils import * + + +def linear_smoothquant_weights_opt_pow2_scales(mod, measurement, params): + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + alpha = params["alpha"] + transformed_weight_scales = params["transformed_weight_scales"] + input_range = torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device) + weight_range_in_ch = torch.max(torch.abs(mod.weight), dim=0)[0].reshape([-1, 1]) + input_scale = calc_maxabs_scale(input_range, fullscale, input_backoff) + weight_scale_in_ch = calc_maxabs_scale(weight_range_in_ch, fullscale, weight_backoff) + input_scale = (input_scale**alpha) / (weight_scale_in_ch ** (1 - alpha)) + input_scale = scale_to_pow2(input_scale) + weight_scale_in_ch = 1 / input_scale + trans_weight = scale_fcn(mod.weight, weight_scale_in_ch.reshape([1, -1])) + trans_weight_range_out_ch = torch.max(torch.abs(trans_weight), dim=1)[0].reshape([-1, 1]) + trans_weight_maxabs_scale_out_ch = calc_maxabs_scale(trans_weight_range_out_ch, fullscale, weight_backoff) + trans_weight_maxabs_scale_out_ch = scale_to_pow2(trans_weight_maxabs_scale_out_ch) + trans_weight_scale_out_ch = torch.zeros(mod.weight.shape[0]) + for k in tqdm(range(trans_weight_scale_out_ch.shape[0])): + trans_weight_scale_out_ch[k] = mmse_scale( + trans_weight[k, :], + [s * trans_weight_maxabs_scale_out_ch[k] for s in transformed_weight_scales], + lp_dtype, + hp_dtype, + ) + weight_scale_out_ch = scale_to_pow2(trans_weight_scale_out_ch) + output_scale = torch.tensor(weight_scale_out_ch, dtype=hp_dtype, device=device) + return ModuleConfig( + (input_scale.flatten(),), + (output_scale.flatten(),), + {"weight": {0: weight_scale_out_ch.flatten(), 1: weight_scale_in_ch.flatten()}}, + ) + + +def linear_smoothquant_weights_maxabs_pow2_scales(mod, measurement, params): + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + alpha = params["alpha"] + input_range = torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device) + weight_range_in_ch = torch.max(torch.abs(mod.weight), dim=0)[0].reshape([-1, 1]) + input_scale = calc_maxabs_scale(input_range, 1.0, 1.0) + weight_scale_in_ch = calc_maxabs_scale(weight_range_in_ch, 1.0, 1.0) + input_scale = (input_scale**alpha) / (weight_scale_in_ch ** (1 - 
alpha)) + input_scale = scale_to_pow2(input_scale) + input_range_post = input_range / input_scale + input_scale_post = calc_maxabs_scale(input_range_post.max(), fullscale, input_backoff) + input_scale_post = scale_to_pow2(input_scale_post) + input_scale = input_scale * input_scale_post + weight_scale_in_ch = 1 / input_scale + trans_weight = scale_fcn(mod.weight, weight_scale_in_ch.reshape([1, -1])) + trans_weight_range_out_ch = torch.max(torch.abs(trans_weight), dim=1)[0].reshape([-1, 1]) + trans_weight_maxabs_scale_out_ch = calc_maxabs_scale(trans_weight_range_out_ch, fullscale, weight_backoff) + trans_weight_maxabs_scale_out_ch = scale_to_pow2(trans_weight_maxabs_scale_out_ch) + weight_scale_out_ch = scale_to_pow2(trans_weight_maxabs_scale_out_ch) + output_scale = torch.tensor(weight_scale_out_ch, dtype=hp_dtype, device=device) + return ModuleConfig( + (input_scale.flatten(),), + (output_scale.flatten(),), + {"weight": {0: weight_scale_out_ch.flatten(), 1: weight_scale_in_ch.flatten()}}, + ) + + +def linear_weaksmoothquant_weights_maxabs_pow2_scales(mod, measurement, params): + device = torch.device("hpu") + lp_dtype = params["lp_dtype"] + hp_dtype = params["hp_dtype"] + fullscale = MAX_RANGE[lp_dtype] + input_backoff = params["input_backoff"] + weight_backoff = params["weight_backoff"] + alpha = params["alpha"] + input_range = torch.tensor(measurement.inputs[0], dtype=hp_dtype, device=device).max().clamp(min=1e-5) + input_range_mid = input_range.max() / torch.sqrt(input_range.max() / input_range.min().clamp(min=1e-5)) + input_scale_pcs = calc_maxabs_scale(input_range.clamp(min=1e-5), input_range_mid, 1.0).clamp(min=1e-5) + weight_range_in_ch = torch.max(torch.abs(mod.weight), dim=0)[0].reshape([-1, 1]).clamp(min=1e-5) + weight_range_in_ch_mid = weight_range_in_ch.max() / torch.sqrt( + weight_range_in_ch.max() / weight_range_in_ch.min().clamp(min=1e-5) + ).clamp(min=1e-5) + weight_scale_pcs = calc_maxabs_scale(weight_range_in_ch.clamp(min=1e-5), weight_range_in_ch_mid, 1.0).clamp( + min=1e-5 + ) + + input_scale = ((input_scale_pcs**alpha) / (weight_scale_pcs ** (1 - alpha))).clamp(min=1e-5) + input_scale = scale_to_pow2(input_scale) + input_scale_post = calc_maxabs_scale((input_range / input_scale).max(), fullscale, input_backoff) + input_scale_post = scale_to_pow2(input_scale_post) + + weight_scale_in_ch = torch.ones([mod.weight.shape[1], 1], dtype=hp_dtype, device=device) * (1 / input_scale) + + trans_weight = scale_fcn(mod.weight, weight_scale_in_ch.reshape([1, -1])) + weight_range_out_ch = torch.max(torch.abs(trans_weight), dim=1)[0].reshape([-1, 1]) + + weight_maxabs_scale_out_ch = calc_maxabs_scale(weight_range_out_ch, fullscale, weight_backoff) + weight_maxabs_scale_out_ch = scale_to_pow2(weight_maxabs_scale_out_ch) + output_scale = weight_maxabs_scale_out_ch * input_scale_post + return ModuleConfig( + (input_scale.flatten() * input_scale_post,), + (output_scale.flatten(),), + { + "weight": { + 0: weight_maxabs_scale_out_ch.flatten(), + 1: weight_scale_in_ch.flatten(), + } + }, + ) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/unit_scale.py b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/unit_scale.py new file mode 100644 index 00000000000..4ced867fe4a --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/scale_methods/unit_scale.py @@ -0,0 +1,66 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with 
the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ..common import * +from ..fp_utils import * + + +def linear_unit_scale_scales(mod, measurement, params): + device = torch.device("hpu") + hp_dtype = params["hp_dtype"] + input_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + weight_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + output_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + return ModuleConfig((input_scale,), (output_scale,), {"weight": weight_scale}) + + +def fsdpa_unit_scale_scales(mod, measurement, params): + device = torch.device("hpu") + hp_dtype = torch.float32 # params["hp_dtype"] + q_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + k_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + v_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + softmax_scale = torch.tensor(1.0, dtype=hp_dtype, device=device) + input_scale = (q_scale, k_scale, v_scale, softmax_scale) + output_scale = (torch.tensor(1.0, dtype=hp_dtype, device=device),) + return ModuleConfig(input_scale, output_scale, {}) + + +def matmul_unit_scale_scales(mod, measurement, params): + device = torch.device("hpu") + hp_dtype = params["hp_dtype"] + input_scale = ( + torch.tensor(1.0, dtype=hp_dtype, device=device), + torch.tensor(1.0, dtype=hp_dtype, device=device), + ) + output_scale = (torch.tensor(1.0, dtype=hp_dtype, device=device),) + return ModuleConfig(input_scale, output_scale, {}) + + +def softmax_unit_scale_scales(mod, measurement, params): + device = torch.device("hpu") + hp_dtype = params["hp_dtype"] + input_scale = (torch.tensor(1.0, dtype=hp_dtype, device=device),) + output_scale = (torch.tensor(1.0, dtype=hp_dtype, device=device),) + return ModuleConfig(input_scale, output_scale) + + +def kv_cache_unit_scale_scales(mod, measurement, params): + device = torch.device("hpu") + hp_dtype = params["hp_dtype"] + input_scale = (torch.tensor(1.0, dtype=hp_dtype, device=device),) + output_scale = (torch.tensor(1.0, dtype=hp_dtype, device=device),) + return ModuleConfig(input_scale, output_scale) diff --git a/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py b/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py new file mode 100644 index 00000000000..30635109c2e --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_core/utils.py @@ -0,0 +1,73 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .._quant_common.helper_modules import * +from .._quant_common.quant_config import QuantMode, get_hqt_config +from ..utils.logger import logger +from .common import mod_default_dict +from .measure import prepare_model as prepare_model_for_measure +from .quantize import quantize +from .scale import scale_method_mapping, scaling_params + + +def update_mod_dict(config): + assert ( + len(config.cfg["mod_dict"]) == 0 + ), f"Custom modules are not supported: {config.cfg['mod_dict'].keys()}. Please add it in the code." + config.cfg["mod_dict"].update({k: mod_default_dict[k].type for k in mod_default_dict}) + + +def print_init_info(config): + import importlib.metadata + + versionStr = importlib.metadata.version("neural_compressor_3x_pt") + locationStr = versionStr.find("git") + 3 + logger.info("neural_compressor_3x_pt Git revision = %s", versionStr[locationStr:]) + logger.info("neural_compressor_3x_pt Configuration = %s", config) + + +def is_substr(substr_list, target): + return any([x in target for x in substr_list]) + + +def prepare_model(model): + config = get_hqt_config(model) + update_mod_dict(config) + allowlist = set(config.cfg["mod_dict"].keys()) + blocklist = set() + for type_st in config.cfg["blocklist"]["types"]: + blocklist.add(type_st) + allowlist.difference_update(blocklist) + allowlist_tuple = tuple(allowlist) + mod_list = [] + for name, mod in model.named_modules(): + mod_type = mod.__class__.__name__ + if ( + (mod_type in allowlist_tuple) + and (is_substr(config.cfg["allowlist"]["names"], name) or len(config.cfg["allowlist"]["names"]) == 0) + and (not is_substr(config.cfg["blocklist"]["names"], name)) + ): + mod_list.append(name) + + print_init_info(config) + + logger.debug("Module list: %s", mod_list) + logger.info("Total modules : %d", len(mod_list)) + if (config.cfg["mode"] == QuantMode.MEASURE) or (config.cfg["mode"] == QuantMode.SHAPE): + return prepare_model_for_measure(model, mod_list) + elif config.cfg["mode"] == QuantMode.QUANTIZE: + scaling_method_name = scale_method_mapping[(config.cfg["scale_method"], config.cfg["observer"])] + scaling_params[scaling_method_name].update(config.cfg["scale_params"]) + config.cfg["scale_params"] = scaling_params[scaling_method_name] + return quantize(model, mod_list) diff --git a/neural_compressor/torch/amp/fp8/__init__.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/__init__.py similarity index 100% rename from neural_compressor/torch/amp/fp8/__init__.py rename to neural_compressor/torch/algorithms/fp8_quant/_quant_common/__init__.py diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py new file mode 100644 index 00000000000..8957096bbc4 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/helper_modules.py @@ -0,0 +1,829 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.nn as nn + +from .quant_config import QuantMode, get_hqt_config + +try: # backwards compatibility for 1.16 + from habana_frameworks.torch.hpex.kernels import fp8_fused_sdpa +except ImportError: + pass + + +class BMM(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.bmm(x, y) + + +class Matmul(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, *args, **kwargs): + return torch.matmul(*args, **kwargs) + + +class Identity(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.clone() + + +class Softmax(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, dim=None): + return torch.softmax(x, dim) + + +def matmul_fp8( + input, + other, + out=None, + out_dtype=torch.bfloat16, + scale_input_inv=None, + scale_other_inv=None, +): + res = torch.ops.hpu.fp8_gemm_v2( + input, + False, + other, + False, + out, + out_dtype, + scale_input_inv, + scale_other_inv, + None, + False, + ) + return res + + +def measure_input(input, observer): + for i in range(len(observer)): + observer[i].measure(input[i]) + + +def measure_output(output, observer): + if observer: + for i in range(len(observer)): + observer[i].measure(output[i]) + + +def conv2d_fp8( + input, + other, + bias, + stride, + padding, + dilation, + groups, + out_dtype=torch.bfloat16, + scale_input_inv=None, + scale_other_inv=None, +): + return torch.ops.hpu.conv2d_fp8( + input=input, + weight=other, + bias=bias, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + out_dtype=out_dtype, + scale_input=scale_input_inv, + scale_weight=scale_other_inv, + ) + + +def set_attrs_from_orig_model(cls_instance, mod, mod_extra_config, *func_names): + cls_instance.__dict__.update(mod.__dict__) + config = get_hqt_config(cls_instance) + cls_instance.extra_repr_org = mod.extra_repr + cls_instance.class_name_org = mod.__class__.__name__ + cls_instance._mod_extra_config = mod_extra_config + cls_instance.quantization_mode = config.cfg["mode"] + # store original module in order to invoke its functions during measurements. + # this may be omitted of torch remove the related validation from dynamo. see SW-187731. + cls_instance.__dict__["orig_mod"] = mod + cls_instance.forward_orig = mod.forward + if func_names is not None: + for func in func_names: + setattr(cls_instance, func, getattr(mod, func)) + + +def get_current_repr(cls_instance, *member_names): + curr_repr = "" + if cls_instance.quantization_mode == QuantMode.QUANTIZE: + first_name = True + for name in member_names: + if not first_name: + curr_repr += ", " + curr_repr += f"{name} dtype={getattr(cls_instance, name).dtype}" + first_name = False + return curr_repr + + +def extra_representation(org_repr, org_name, curr_repr): + repr = f"original={org_name}," + (" " + org_repr + "," if org_repr != "" else "") + return f"{repr} {curr_repr}" + + +def _raise_lora_layer_error(layer_class): + raise RuntimeError( + f"{layer_class} quantization is not supported in case of lora_layer member is not None." 
+ f" Can add {layer_class} to 'blocklist' field in quantization config file" + ) + + +class PatchedMatmul(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input_0 = self._mod_extra_config.inputs[0] + self.quant_input_1 = self._mod_extra_config.inputs[1] + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + self.scale_other = nn.Parameter(mod_extra_config.scale.inputs[1]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input, other): + qinput = self.quant_input_0(input) + qother = self.quant_input_1(other) + output = matmul_fp8( + qinput, + qother, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_other, + ) + return output + + def forward_measure(self, input, other): + measure_input((input, other), observer=self._mod_extra_config.inputs) + output = self.orig_mod(input, other) + measure_output((output,), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_other"), + ) + + +class PatchedLinear(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + # When offloading weights to disk using device_map, the module forward is overridden. + # __dict__.update call again overrides the PatchedLinear forward with the forward that device_map planted. + # So need to set PatchedLinear forawrd to be the right forward. 
+ self.forward = self.forward_quant + self.quant_input = self._mod_extra_config.inputs[0] + self.weight = nn.Parameter(self.weight.t().contiguous()) + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + if isinstance(mod_extra_config.scale.params["weight"], (torch.Tensor, float)): + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif isinstance(mod_extra_config.scale.params["weight"], dict): + # PCQ weight is calculated with actual weight [0] and ones [1] + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"][0]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward_quant(self, input): + qinput = self.quant_input(input) + y = matmul_fp8( + qinput, + self.weight, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + output = y + self.bias if (self.bias is not None) else y + return output + + def forward_measure(self, input): + measure_input((input,), observer=self._mod_extra_config.inputs) + output = self.orig_mod(input) + measure_output((output,), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedLinearAllReduce(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + self.scoped_version = mod.__class__.__name__ == "ScopedLinearAllReduce" + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.quant_output = self._mod_extra_config.outputs[0] + self.weight = nn.Parameter(self.weight.t().contiguous()) + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + if isinstance(mod_extra_config.scale.params["weight"], (torch.Tensor, float)): + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif isinstance(mod_extra_config.scale.params["weight"], dict): + # PCQ weight is calculated with actual weight [0] and ones [1] + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"][0]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input): + # pre_all_reduce + qinput = self.quant_input(input) + output = matmul_fp8( + qinput, + self.weight, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + dqoutput = self.quant_output(output) + if not self.scoped_version: + self.all_reduce(dqoutput) + dqoutput = self.post_all_reduce(dqoutput) + return dqoutput + + def forward_measure(self, input): + measure_input((input,), observer=self._mod_extra_config.inputs) + output = torch.matmul(input, self.weight.transpose(-1, -2)) + measure_output((output,), self._mod_extra_config.outputs) + # in scoped version all reduce is being called outside of the layer + if not self.scoped_version: + self.all_reduce(output) + output = self.post_all_reduce(output) + return output + + def all_reduce(self, input): + if self.mp_group is not None: + from deepspeed import comm as dist + + dist.inference_all_reduce(input, group=self.mp_group) + + def post_all_reduce(self, input): + output = input + self.bias 
if (self.bias is not None) else input + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedRowParallelLinear(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config, "resolve_input") + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.quant_output = self._mod_extra_config.outputs[0] + self.weight = nn.Parameter(self.weight.t().contiguous()) + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + if isinstance(mod_extra_config.scale.params["weight"], (torch.Tensor, float)): + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif isinstance(mod_extra_config.scale.params["weight"], dict): + # PCQ weight is calculated with actual weight [0] and ones [1] + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"][0]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input): + resolved_input = self.resolve_input(input) + qinput = self.quant_input(resolved_input) + output = matmul_fp8( + qinput, + self.weight, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + dqoutput = self.quant_output(output) + if self.reduce_results: + dqoutput = self.collective_func(dqoutput) + return self.post_all_reduce(dqoutput) + + def forward_measure(self, input): + resolved_input = self.resolve_input(input) + measure_input((resolved_input,), observer=self._mod_extra_config.inputs) + output = torch.matmul(resolved_input, self.weight.transpose(-1, -2)) + measure_output((output,), self._mod_extra_config.outputs) + if self.reduce_results: + output = self.collective_func(output) + return self.post_all_reduce(output) + + def post_all_reduce(self, output): + assert ( + self.reduce_results or (not self.bias) or self.skip_bias_add + ), "When not reduce the results, adding bias to the results can lead to incorrect results" + if not self.skip_bias_add: + output = output + self.bias if self.bias is not None else output + output_bias = None + else: + output_bias = self.bias + return output, output_bias + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedColumnParallelLinear(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.quant_output = self._mod_extra_config.outputs[0] + self.weight = nn.Parameter(self.weight.t().contiguous()) + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + if isinstance(mod_extra_config.scale.params["weight"], (torch.Tensor, float)): + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif isinstance(mod_extra_config.scale.params["weight"], dict): + # PCQ weight is calculated with actual weight [0] and ones [1] + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"][0]) + elif (self.quantization_mode == QuantMode.MEASURE) or 
(self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input): + qinput = self.quant_input(input) + output = matmul_fp8( + qinput, + self.weight, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + dqoutput = self.quant_output(output) + if self.gather_output: + dqoutput = self.orig_mod.collective_func(dqoutput) + return self.post_all_reduce(dqoutput) + + def forward_measure(self, input): + measure_input((input,), observer=self._mod_extra_config.inputs) + output = torch.matmul(input, self.weight.transpose(-1, -2)) + measure_output((output,), self._mod_extra_config.outputs) + if self.gather_output: + output = self.orig_mod.collective_func(output) + return self.post_all_reduce(output) + + def post_all_reduce(self, output): + if not self.skip_bias_add: + output = output + self.bias if self.bias is not None else output + output_bias = None + else: + output_bias = self.bias + return output, output_bias + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedLmHeadLinearAllreduce(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.quant_output = self._mod_extra_config.outputs[0] + self.weight = nn.Parameter(self.weight.t().contiguous()) + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + if isinstance(mod_extra_config.scale.params["weight"], (torch.Tensor, float)): + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif isinstance(mod_extra_config.scale.params["weight"], dict): + # PCQ weight is calculated with actual weight [0] and ones [1] + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"][0]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input): + assert ( + input.shape[-1] % self.world_size == 0 + ), "Please ensure that self.world_size is divisible by input.shape[-1]" + input_shard = input.shape[-1] // self.world_size + splittedInput = input[:, :, self.rank * input_shard : (self.rank + 1) * input_shard] + qinput = self.quant_input(splittedInput) + output = matmul_fp8( + qinput, + self.weight, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + dqoutput = self.quant_output(output) + + if self.mp_group is not None: + from deepspeed import comm as dist + + dist.inference_all_reduce(dqoutput, group=self.mp_group) + if self.bias is not None: + dqoutput += self.bias + return dqoutput + + def forward_measure(self, input): + assert ( + input.shape[-1] % self.world_size == 0 + ), "Please ensure that self.world_size is divisible by input.shape[-1]" + input_shard = input.shape[-1] // self.world_size + splittedInput = input[:, :, self.rank * input_shard : (self.rank + 1) * input_shard] + measure_input((splittedInput,), observer=self._mod_extra_config.inputs) + output = torch.matmul(splittedInput, self.weight.t()) + measure_output((output,), self._mod_extra_config.outputs) + + if self.mp_group is not None: + from deepspeed import comm as dist + + 
dist.inference_all_reduce(output, group=self.mp_group) + if self.bias is not None: + output += self.bias + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedKVCache(nn.Module): + # Module to patch KVCache module from llama model + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config, "forward", "get_shape") + self.org_allocate = mod.allocate + self.org_update = mod.update + if self.quantization_mode == QuantMode.QUANTIZE: + mod.update = self.update + self.quant_input = self._mod_extra_config.inputs[0] + self.quant_output = self._mod_extra_config.outputs[0] + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.update = self.update_measure + mod.update = self.update_measure + + # overwrite allocate function of original module to force allocation in fp8 + def allocate(self, inp_seq_len, dtype, device, shape): + dtype = torch.float8_e4m3fn if (self.quantization_mode == QuantMode.QUANTIZE) else dtype + return self.org_allocate(inp_seq_len, dtype, device, shape) + + # overwrite update function of original module to force quant and dequant of cache input and output + def update(self, prev, cur, dim, idx, inp_seq_len): + qinput = self.quant_input(cur) + output = self.org_update(prev, qinput, dim, idx, inp_seq_len) + if output.dtype == torch.float8_e4m3fn: + return self.quant_output(output) + else: + return output + + # overwrite update function of original module to force quant and dequant of cache input and output + def update_measure(self, prev, cur, dim, idx, inp_seq_len): + measure_input((cur,), self._mod_extra_config.inputs) + output = self.org_update(prev, cur, dim, idx, inp_seq_len) + measure_output((output,), self._mod_extra_config.outputs) + return output + + +class PatchedVLLMKVCache(nn.Module): + # Module to patch VLLMKVCache module from llama model + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.quant_output = self._mod_extra_config.outputs[0] + self.orig_fetch_from_cache = mod.fetch_from_cache + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.fetch_from_cache = mod.fetch_from_cache + self.forward = self.forward_measure + + def forward(self, input, cache, block_indices, block_offset): + qinput = self.quant_input(input) + output_cache = self.forward_orig(qinput, cache, block_indices, block_offset) + return self.quant_output(output_cache) + + def forward_measure(self, input, cache, block_indices, block_offset): + measure_input((input), self._mod_extra_config.inputs) + output_cache = self.forward_orig(input, cache, block_indices, block_offset) + measure_output((output_cache), self._mod_extra_config.outputs) + return output_cache + + def fetch_from_cache(self, cache, blocks, permutations): + quant_cache = self.quant_input(cache) + output_cache = self.orig_fetch_from_cache(quant_cache, blocks, permutations) + for i in range(len(output_cache)): + output_cache[i] = self.quant_output(output_cache[i]) + return output_cache + + +class PatchedConv2d(nn.Conv2d): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + 
set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input): + qinput = self.quant_input(input) + output = conv2d_fp8( + qinput, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + return output + + def forward_measure(self, input): + measure_input((input,), observer=self._mod_extra_config.inputs) + output = self.orig_mod(input) + measure_output((output,), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedSoftmax(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_output = self._mod_extra_config.outputs[0] + # input scale is 1 assuming the input to SM is descaled because we are using HW supported scales + self.scale_input = nn.Parameter(torch.Tensor([1.0])) + self.scale_output = nn.Parameter(torch.Tensor([1 / mod_extra_config.scale.outputs[0]])) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, x, dim=None, invAttnHead=None): + output = torch.ops.hpu.softmax_fp8(x, dim, self.scale_input, self.scale_output, invAttnHead) + return self.quant_output(output) + + def forward_measure(self, x, dim=None, invAttnHead=None): + measure_input((x,), observer=self._mod_extra_config.inputs) + output = self.orig_mod(x, dim, invAttnHead) + measure_output((output,), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_output"), + ) + + +class PatchedLoRACompatibleLinear(nn.Linear): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.weight = nn.Parameter(self.weight.t().contiguous()) + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input, scale: float = 1.0): + qinput = self.quant_input(input) + y = matmul_fp8( + qinput, + self.weight, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + output = y + self.bias if (self.bias is not None) else y + if self.lora_layer is not None: + # TODO SW-174899 support lora layer quantization + _raise_lora_layer_error(self.class_name_org) + # output = output + (scale * 
self.lora_layer(input)) + return output + + def forward_measure(self, input, scale: float = 1.0): + measure_input((input,), observer=self._mod_extra_config.inputs) + output = self.orig_mod(input, scale) + measure_output((output,), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedLoRACompatibleConv(nn.Conv2d): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_input = self._mod_extra_config.inputs[0] + self.scale_input = nn.Parameter(mod_extra_config.scale.inputs[0]) + self.scale_weight = nn.Parameter(mod_extra_config.scale.params["weight"]) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = self.forward_measure + + def forward(self, input, scale: float = 1.0): + qinput = self.quant_input(input) + if self.lora_layer is not None: + # TODO SW-174899 support lora layer quantization + _raise_lora_layer_error(self.class_name_org) + # output = conv2d_fp8(qinput, self.weight, None, self.stride, self.padding, self.dilation, self.groups, \ + # out_dtype=self._mod_extra_config.config_params["hp_dtype"], scale_input_inv=self.scale_input, scale_other_inv=self.scale_weight) + # output = output + (scale * self.lora_layer(input)) + # output = output+torch.unsqueeze(torch.unsqueeze(self.bias,1), 1) if (self.bias is not None) else output + else: + output = conv2d_fp8( + qinput, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + out_dtype=self._mod_extra_config.config_params["hp_dtype"], + scale_input_inv=self.scale_input, + scale_other_inv=self.scale_weight, + ) + return output + + def forward_measure(self, input, scale: float = 1.0): + measure_input((input,), observer=self._mod_extra_config.inputs) + output = self.orig_mod(input, scale) + measure_output((output,), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr(self, "scale_input", "scale_weight"), + ) + + +class PatchedModuleFusedSDPA(nn.Module): + def __init__(self, mod, mod_extra_config, *args, **kwargs): + # fsdpa is combined out of - BMM1(Q,K) -> Softmax -> BMM2(AMAX,V) + # during measure we receive the amax value from the cguid and apply it during quant as input + super().__init__() + set_attrs_from_orig_model(self, mod, mod_extra_config) + if self.quantization_mode == QuantMode.QUANTIZE: + self.quant_q = self._mod_extra_config.inputs[0] + self.quant_k = self._mod_extra_config.inputs[1] + self.quant_v = self._mod_extra_config.inputs[2] + self.dequant_output = self._mod_extra_config.outputs[0] + self.scale_q = nn.Parameter(mod_extra_config.scale.inputs[0].type(torch.float32)) + self.scale_k = nn.Parameter(mod_extra_config.scale.inputs[1].type(torch.float32)) + self.scale_v = nn.Parameter(mod_extra_config.scale.inputs[2].type(torch.float32)) + self.descale_amax = nn.Parameter(mod_extra_config.scale.inputs[3].type(torch.float32)) + self.scale_output = nn.Parameter(1 / mod_extra_config.scale.outputs[0].type(torch.float32)) + self.scale_amax = nn.Parameter(1 / self.descale_amax) + elif (self.quantization_mode == QuantMode.MEASURE) or (self.quantization_mode == QuantMode.SHAPE): + self.forward = 
self.forward_measure + + def forward( + self, + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + softmax_mode="None", + ): + qinput = self.quant_q(q).detach() + kinput = self.quant_k(k).detach() + vinput = self.quant_v(v).detach() + results = fp8_fused_sdpa( + qinput, + kinput, + vinput, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + # fp8_fused_sdpa in fp8 mode supports only FastSoftmax + softmax_mode="None", + d_scale_q=self.scale_q, + d_scale_k=self.scale_k, + d_scale_v=self.scale_v, + q_scale_s=self.scale_amax, + q_scale_o=self.scale_output, + d_scale_s=self.descale_amax, + is_amax_s=False, + ) + output = results[0] + d_out = self.dequant_output(output) + return d_out + + def forward_measure( + self, + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + softmax_mode="fast", + ): + dq = q.detach() + dk = k.detach() + dv = v.detach() + measure_input((dq, dk, dv), observer=self._mod_extra_config.inputs) + results = fp8_fused_sdpa( + dq, + dk, + dv, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + # fp8_fused_sdpa in bf16 can use either FastSoftmax or regular + softmax_mode="fast", + is_amax_s=True, + ) + output = results[0] + amax = results[1] + measure_output((output, amax), self._mod_extra_config.outputs) + return output + + def extra_repr(self) -> str: + return extra_representation( + self.extra_repr_org(), + self.class_name_org, + get_current_repr( + self, + "scale_q", + "scale_k", + "scale_v", + "descale_amax", + "scale_amax", + "scale_output", + ), + ) + + +class PatchedUnmeasuredModule(nn.Module): + def __init__(self, name, *args, **kwargs): + super().__init__() + self.name = name + + def forward(self, *args, **kwargs): + raise Exception( + "Error - Layer '{}' was called but was not quantized because no measures were supplied.".format(self.name) + ) + + def extra_repr(self) -> str: + return f"Dummy patch of {self.name} to raise exception as there are no measurements provided." diff --git a/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py new file mode 100644 index 00000000000..1cf343e1a22 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/_quant_common/quant_config.py @@ -0,0 +1,274 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +import os +from dataclasses import dataclass +from enum import Enum, Flag, auto +from json.decoder import JSONDecodeError +from typing import Any, Mapping + +import habana_frameworks.torch.utils.experimental as htexp +import torch + +from ..utils.logger import logger + +local_rank = int(os.getenv("LOCAL_RANK", "-1")) +world_size = int(os.getenv("WORLD_SIZE", "-1")) + + +class QuantMode(Enum): + NONE = 0 + QUANTIZE = 1 + MEASURE = 2 + SHAPE = 3 + + +class MeasureExclude(Flag): + NONE = auto() + INPUT = auto() + OUTPUT = auto() + PARAMS = auto() + ALL = auto() + + +class ScaleMethod(Enum): + MAX = 1 + UNIT_SCALE = 2 + MAXABS_HW = 3 + MAXABS_POW2 = 4 + SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2 = 5 + WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2 = 6 + ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2 = 7 + ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2 = 8 + ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2 = 9 + ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2 = 10 + SMOOTHQUANT_OPT = 11 + MAXABS_HW_OPT_WEIGHT = 12 + MAXABS_POW2_OPT_WEIGHT = 13 + + +def get_hqt_config(mod) -> Fp8cfg: + return mod.__hqt_config__ + + +def set_hqt_config(mod, config): + mod.__hqt_config__ = config + + +@dataclass +class Fp8cfg: + cfg: Mapping[str, Any] + + def parse(custom_config: Mapping[str, str]) -> Fp8cfg: + measured_global_config = { + "dump_stats_path": "stats", + "fp8_config": torch.float8_e4m3fn, # The parameters of the chosen Quantization methed + "hp_dtype": torch.bfloat16, # The parameters of the chosen Quantization methed + "blocklist": { + "names": [], + "types": (), + }, # types and names to not be quantized + "allowlist": { + "names": [], + "types": ("torch.nn.Linear", "torch.nn.Conv2d", "BMM"), + }, # types and names to be quantized. Allowlist by names is not yet implemented + "mode": QuantMode.QUANTIZE, # Quantize or Measure + "scale_method": ScaleMethod.UNIT_SCALE, # Method to quantize with + "scale_params": {}, # scaling parameters that are different then the default ones + "observer": "maxabs", # Supported ['shape', 'maxabs', 'maxabs_per_channel', 'save'] + "mod_dict": {}, + "ignore_modules_wo_measures": False, # Determines whether to fail quantization on modules without existing measures or not to quantize them + "local_rank": local_rank if local_rank >= 0 else None, + "global_rank": None, + "world_size": world_size if world_size >= 0 else None, + "seperate_measure_files": True, # Determines whether to expect one or several measure files when using more than one gaudi + "device_type": htexp._get_device_type(), # Determines device type: Gaudi2, Gaudi3... + "measure_exclude": MeasureExclude.OUTPUT, + } + # assert measured_global_config['allowlist']['names'] == [''], "Allowlist names not yet implemented" + + # go over all user-defined keys from json, handle various cases + for keys in custom_config: + if keys == "mode": + if custom_config[keys] == "NONE": + custom_config[keys] = QuantMode.NONE + elif custom_config[keys] == "QUANTIZE": + custom_config[keys] = QuantMode.QUANTIZE + elif custom_config[keys] == "MEASURE": + custom_config[keys] = QuantMode.MEASURE + elif custom_config[keys] == "SHAPE": + custom_config[keys] = QuantMode.SHAPE + else: + raise ValueError("invalid mode in custom config. 
Enter Quantize or Measure") + + if keys == "measure_exclude": + if custom_config[keys] == "NONE": + custom_config[keys] = MeasureExclude.NONE + elif custom_config[keys] == "OUTPUT": + custom_config[keys] = MeasureExclude.OUTPUT + elif custom_config[keys] == "INPUT": + custom_config[keys] = MeasureExclude.INPUT + elif custom_config[keys] == "ALL": + custom_config[keys] = MeasureExclude.ALL + else: + raise ValueError("invalid measure exclude value in custom config. Enter OUTPUT or NONE") + + if keys == "fp8_config": + if custom_config[keys].lower() == "e4m3": + custom_config[keys] = torch.float8_e4m3fn + + elif custom_config[keys].lower() == "e5m2": + custom_config[keys] = torch.float8_e5m2 + else: + raise ValueError("invalid fp8_config in custom config. Enter E4M3 or E5M2") + + if keys == "hp_dtype": + if custom_config[keys].lower() == "bf16": + custom_config[keys] = torch.bfloat16 + elif custom_config[keys].lower() == "fp16": + custom_config[keys] = torch.float16 + elif custom_config[keys].lower() == "fp32": + custom_config[keys] = torch.float32 + else: + raise ValueError("invalid hp_dtype in custom config. Enter bf16, fp16 or fp32") + + if keys == "scale_method": + if custom_config[keys].lower() == "unit_scale": + custom_config[keys] = ScaleMethod.UNIT_SCALE + elif custom_config[keys].lower() == "max": + custom_config[keys] = ScaleMethod.MAX + elif custom_config[keys].lower() == "maxabs_hw": + custom_config[keys] = ScaleMethod.MAXABS_HW + elif custom_config[keys].lower() == "maxabs_pow2": + custom_config[keys] = ScaleMethod.MAXABS_POW2 + elif custom_config[keys].lower() == "maxabs_hw_opt_weight": + custom_config[keys] = ScaleMethod.MAXABS_HW_OPT_WEIGHT + elif custom_config[keys].lower() == "maxabs_pow2_opt_weight": + custom_config[keys] = ScaleMethod.MAXABS_POW2_OPT_WEIGHT + elif custom_config[keys].lower() == "smoothquant_weights_output_channel_maxabs_pow2": + custom_config[keys] = ScaleMethod.SMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2 + elif custom_config[keys].lower() == "weaksmoothquant_weights_output_channel_maxabs_pow2": + custom_config[keys] = ScaleMethod.WEAKSMOOTHQUANT_WEIGHTS_OUTPUT_CHANNEL_MAXABS_POW2 + elif custom_config[keys].lower() == "act_maxabs_hw_weights_pcs_maxabs_pow2": + custom_config[keys] = ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2 + elif custom_config[keys].lower() == "act_maxabs_hw_weights_pcs_opt_pow2": + custom_config[keys] = ScaleMethod.ACT_MAXABS_HW_WEIGHTS_PCS_OPT_POW2 + elif custom_config[keys].lower() == "act_maxabs_pow2_weights_pcs_maxabs_pow2": + custom_config[keys] = ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_MAXABS_POW2 + elif custom_config[keys].lower() == "act_maxabs_pow2_weights_pcs_opt_pow2": + custom_config[keys] = ScaleMethod.ACT_MAXABS_POW2_WEIGHTS_PCS_OPT_POW2 + elif custom_config[keys].lower() == "smoothquant_opt": + custom_config[keys] = ScaleMethod.SMOOTHQUANT_OPT + else: + raise ValueError( + f'Invalid fp8_config in custom config ({custom_config[keys]}). should be in ["max", "unit_scale", "maxabs_hw", "maxabs_pow2", "maxabs_per_channel_pow2", "smoothquant_opt"]' + ) + + if keys == "ignore_modules_wo_measures": + custom_config[keys] = custom_config[keys].lower() == "true" + + # TODO [SW-175936] - remove checking for old key names whitelist and blacklist. 
+ if isinstance(custom_config[keys], dict): + for keys_2 in custom_config[keys]: + if keys == "whitelist": + measured_global_config["allowlist"][keys_2] = custom_config[keys][keys_2] + elif keys == "blacklist": + measured_global_config["blocklist"][keys_2] = custom_config[keys][keys_2] + else: + measured_global_config[keys][keys_2] = custom_config[keys][keys_2] + else: + if keys == "whitelist": + measured_global_config["allowlist"] = custom_config[keys] + elif keys == "blacklist": + measured_global_config["blocklist"] = custom_config[keys] + else: + measured_global_config[keys] = custom_config[keys] + + # If seperate_measure_files is True (default value), then it is assumed that there are multiple distinct measure and scale files + # and they are stored in / loaded from paths with the correct index as a suffix. Else, only one is searched for. + measured_global_config["local_rank"] = ( + local_rank if local_rank >= 0 and custom_config.get("seperate_measure_files", True) else None + ) + + base_name = measured_global_config["dump_stats_path"].split("/")[-1] + folder_name = measured_global_config["dump_stats_path"][: -(len(base_name))] + measured_global_config["dump_stats_base_path"] = folder_name + os.makedirs(folder_name, exist_ok=True) + worker_st = ( + "" + if measured_global_config["local_rank"] is None + else "_" + str(measured_global_config["local_rank"]) + "_" + str(measured_global_config["world_size"]) + ) + measured_global_config["shape_file"] = measured_global_config["dump_stats_path"] + "_hooks_shape" + worker_st + measured_global_config["scale_file"] = ( + measured_global_config["dump_stats_path"] + + "_hooks_" + + measured_global_config["observer"] + + "_" + + measured_global_config["scale_method"].name + + worker_st + ) + if (measured_global_config["mode"] == QuantMode.MEASURE) or ( + measured_global_config["mode"] == QuantMode.QUANTIZE + ): + measured_global_config["measure_file"] = ( + measured_global_config["dump_stats_path"] + "_hooks_" + measured_global_config["observer"] + worker_st + ) + # measured_global_config['dump_stats_path'] += '_hooks_.json' + + logger.debug("HQT Paths:") + logger.debug("base_name='%s'", base_name) + logger.debug("folder_name='%s'", folder_name) + logger.debug( + "measured_global_config['shape_file']='%s'", + measured_global_config["shape_file"], + ) + logger.debug( + "measured_global_config['scale_file']='%s'", + measured_global_config["scale_file"], + ) + if "measure_file" in measured_global_config.keys(): + logger.debug( + "measured_global_config['measure_file']='%s'", + measured_global_config["measure_file"], + ) + logger.debug( + "measured_global_config['dump_stats_path']='%s'", + measured_global_config["dump_stats_path"], + ) + + return Fp8cfg(cfg=measured_global_config) + + +def _read_config_from_file(config_path: str) -> Mapping[str, str]: + logger.debug("QUANT PACKAGE: using %s config", config_path) + + module_directory = os.path.dirname(os.path.abspath(__file__)) + + # if file in absolute path doesn't exist, try looking in cfg directory + if not os.path.isfile(config_path): + config_path = os.path.join(module_directory, "..", f"custom_config/{config_path}.json") + try: + logger.info("QUANT PACKAGE: Loading %s", config_path) + with open(config_path) as config_json: + config = json.load(config_json) + except FileNotFoundError as e: + raise Exception(f"Got exception: {e}. QUANT PACKAGE: Can't open {config_path}!") + except JSONDecodeError as e: + config_json.close() + raise Exception(f"Got exception: {e}. 
QUANT PACKAGE: Can't load {config_path} json!") + return config diff --git a/neural_compressor/torch/algorithms/fp8_quant/common.py b/neural_compressor/torch/algorithms/fp8_quant/common.py new file mode 100644 index 00000000000..163509a6048 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/common.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile +from collections import namedtuple +from pathlib import Path +from typing import Union + +import torch + +from neural_compressor.torch.algorithms.fp8_quant._quant_common.quant_config import Fp8cfg +from neural_compressor.torch.algorithms.fp8_quant.prepare_quant.prepare_model import finish_measurements + + +def save_calib_result(model): + if hasattr(model, "__hqt_config__") and isinstance(model.__hqt_config__, Fp8cfg): + # TODO SW-184714 modify hqt notation to inc notation once code is ported + finish_measurements(model) + else: + raise NotImplementedError("Saving calibration results currently supported only in HPU.") + + +def update_mode(config_path, measure_step=False, quant_step=False): + with open(config_path, "r") as file: + config = json.load(file) + + if (measure_step and config.get("mode") == "MEASURE") or (quant_step and config.get("mode") == "QUANTIZE"): + return config_path + else: + if measure_step: + config["mode"] = "MEASURE" + if quant_step: + config["mode"] = "QUANTIZE" + + temp_file = tempfile.NamedTemporaryFile(suffix=".json", delete=False) + temp_file_path = temp_file.name + + with open(temp_file_path, "w") as temp_file: + json.dump(config, temp_file) + + return temp_file_path + + +def generate_model_info(model): + mod_inst_info = namedtuple("ModInstInfo", ["name", "parent"]) + parent_child_mod_dict = {} + + def create_mod_info_recursion(parent): + for name, mod in parent.named_children(): + parent_child_mod_dict[mod] = mod_inst_info(name=name, parent=parent) + create_mod_info_recursion(mod) + + create_mod_info_recursion(model) + return parent_child_mod_dict + + +def get_patched_mod_list(): + from ._core.common import mod_default_dict + + patched_mod_list = [] + for patched_mod in mod_default_dict.values(): + patched_mod_list.append(patched_mod.patched_module.__name__) + return patched_mod_list + + +def restore_patched_module(patched_model): + from neural_compressor.torch.algorithms.fp8_quant.helper_modules import helper_mods + + patched_mod_list = get_patched_mod_list() + + parent_child_mod_dict = generate_model_info(patched_model) + with torch.no_grad(): + for name, patched_mod in patched_model.named_modules(): + patched_mod_type_str = patched_mod.__class__.__name__ + if patched_mod_type_str in patched_mod_list: + parent = parent_child_mod_dict[patched_mod].parent + name = parent_child_mod_dict[patched_mod].name + class_name_org = ( + getattr(patched_mod, "class_name_org", None) or patched_mod.__class__.__name__.split("Patched")[-1] + ) + patched_mod.__dict__.pop("forward", None) + origin_mod = 
helper_mods[class_name_org](patched_mod) + setattr(parent, name, origin_mod) + + +def with_patched_module(model): + patched_mod_list = get_patched_mod_list() + + for name, mod in model.named_modules(): + mod_type = mod.__class__.__name__ + if mod_type in patched_mod_list: + return True + return False diff --git a/neural_compressor/torch/algorithms/fp8_quant/custom_config/custom_example.json b/neural_compressor/torch/algorithms/fp8_quant/custom_config/custom_example.json new file mode 100644 index 00000000000..26b8af220a7 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/custom_config/custom_example.json @@ -0,0 +1,5 @@ +{ + "mode": "MEASURE", + "scale_method": "MAX", + "fp8_config": "E4M3" +} \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_measure.json b/neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_measure.json new file mode 100644 index 00000000000..fc675067c22 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_measure.json @@ -0,0 +1,14 @@ +{ + "mode": "MEASURE", + "observer": "maxabs", + "allowlist": { + "types": [], + "names": [] + }, + "blocklist": { + "types": [], + "names": [] + }, + "quantize_weight": false, + "dump_stats_path": "./llama_output/7b_measure" +} \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_quant.json b/neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_quant.json new file mode 100644 index 00000000000..f341964187a --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/custom_config/llama_quant.json @@ -0,0 +1,17 @@ +{ + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "allowlist": { + "types": [], + "names": [] + }, + "blocklist": { + "types": [], + "names": [ + "lm_head" + ] + }, + "quantize_weight": false, + "dump_stats_path": "./llama_output/7b_measure" +} \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/fp8_quant/custom_config/measure_config.json b/neural_compressor/torch/algorithms/fp8_quant/custom_config/measure_config.json new file mode 100755 index 00000000000..b8c4d29b781 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/custom_config/measure_config.json @@ -0,0 +1,12 @@ +{ + "mode": "MEASURE", + "scale_method": "MAX", + "quantize_weight": true, + "dump_stats_path": "./run_outputs/fp8/stats", + "allowlist": { + "types": [ + "torch.nn.Linear", + "torch.nn.Conv2d" + ] + } +} \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/fp8_quant/custom_config/quant_config.json b/neural_compressor/torch/algorithms/fp8_quant/custom_config/quant_config.json new file mode 100755 index 00000000000..286a1632257 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/custom_config/quant_config.json @@ -0,0 +1,13 @@ +{ + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "fp8_config": "E4M3", + "allowlist": { + "types": [ + "torch.nn.Linear", + "torch.nn.Conv2d" + ] + }, + "dump_stats_path": "./run_outputs/fp8/stats" +} \ No newline at end of file diff --git a/neural_compressor/torch/algorithms/fp8_quant/fp8_quant.py b/neural_compressor/torch/algorithms/fp8_quant/fp8_quant.py new file mode 100644 index 00000000000..f160f208612 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/fp8_quant.py @@ -0,0 +1,60 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not 
use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from neural_compressor.common.utils import FP8_QUANT +from neural_compressor.torch.algorithms import Quantizer +from neural_compressor.torch.algorithms.fp8_quant import ( + finish_measurements, + prep_model, + restore_patched_module, + update_mode, + with_patched_module, +) + + +class FP8Quantizer(Quantizer): + def __init__(self, quant_config): + super().__init__(quant_config) + if isinstance(quant_config, dict): + json_file = [cfg.json_file for cfg in quant_config.values()] + assert len(json_file) > 0, "Cannot get json file from config." + self.quant_config = json_file[0] + + def prepare(self, model): + _prepare(model, self.quant_config) + return model + + def convert(self, model): + if with_patched_module(model): # if model was calibrated on hpu + finish_measurements(model) # dump the measurements into files to be loaded in _convert + # for INC flow, it calls `prepare` and then `convert` user-facing API in one run + restore_patched_module(model) + _convert(model, self.quant_config) + return model + + +def _convert(model, config_path): + # update mode to QUANTIZE + config_path = update_mode(config_path, quant_step=True) + + return prep_model(model, config_path) + + +def _prepare(model, config_path): + # update mode to MEASURE + config_path = update_mode(config_path, measure_step=True) + + return prep_model(model, config_path) diff --git a/neural_compressor/torch/algorithms/fp8_quant/helper_modules.py b/neural_compressor/torch/algorithms/fp8_quant/helper_modules.py new file mode 100644 index 00000000000..a31d4910979 --- /dev/null +++ b/neural_compressor/torch/algorithms/fp8_quant/helper_modules.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +# For mapping revert patched module to origin module + +helper_mods = {} + + +def helper_mod_register(name): + def decorator(mod): + helper_mods[name] = mod + return mod + + return decorator + + +@helper_mod_register(name="Matmul") +class Matmul(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + + +@helper_mod_register(name="Linear") +class Linear(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + + +@helper_mod_register(name="FalconLinear") +class FalconLinear(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + + +@helper_mod_register(name="KVCache") +class KVCache(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.allocate = patched_mod.org_allocate + self.get_shape = patched_mod.get_shape + self.forward = patched_mod.forward + self.update = patched_mod.update + + +@helper_mod_register(name="Conv2d") +class Conv2d(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + + +@helper_mod_register(name="LoRACompatibleLinear") +class LoRACompatibleLinear(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + + +@helper_mod_register(name="LoRACompatibleConv") +class LoRACompatibleConv(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + + +@helper_mod_register(name="Softmax") +class Softmax(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + + +@helper_mod_register(name="LinearLayer") +class LinearLayer(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + + +@helper_mod_register(name="LinearAllreduce") +class LinearAllreduce(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + + +@helper_mod_register(name="ScopedLinearAllReduce") +class ScopedLinearAllReduce(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + + +@helper_mod_register(name="LmHeadLinearAllreduce") +class LmHeadLinearAllreduce(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org + + +@helper_mod_register(name="ModuleFusedSDPA") +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, patched_mod, *args, **kwargs): + super().__init__() + self.__dict__.update(patched_mod.__dict__) + self.extra_repr = patched_mod.extra_repr_org diff --git 
diff --git a/neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/README b/neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/README
new file mode 100644
index 00000000000..dfd014918d4
--- /dev/null
+++ b/neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/README
@@ -0,0 +1,32 @@
+How to calculate FID and CLIP score:
+
+We use the MS-COCO dataset for two purposes:
+- Generating a large number of prompts from which diffusion images are created.
+- Providing the "ground truth" images needed to calculate FID once the diffusion images exist.
+
+1) Run a Python script that does the following:
+ - Takes a subset of MS-COCO.
+ - Creates a CSV of prompts that can be fed to the diffusion model; the prompts are the captions of the images in the subset.
+ - Creates a new folder with the images from the subset.
+ - The standard number of images for this evaluation is 30K or 10K.
+
+Run the following:
+
+python create_dataset.py /datasets/coco2014
+
+Now generate the diffusion images from the prompts in the CSV file.
+
+IMPORTANT: the evaluation script (explained below) expects each generated image to be titled with the prompt that created it. For example, if the prompt is "a monster playing the guitar", the file produced by the diffusion model should be named "a monster playing the guitar.png" (or .jpg, etc.).
+
+IMPORTANT #2: stable diffusion inference returns an error for prompts containing the character '/'. Such prompts are rare (roughly one in a thousand). If you want to evaluate N images, create a subset of size N+30 and delete the prompts containing '/'. For now this is done manually after the CSV is created; automating it could be a future commit.
+
+2) Run the evaluation script, which does the following:
+- Calculates the CLIP score: takes the CLIP embedding of each generated image and the embedding of the caption that created it (here, the image's file name), then computes the cosine similarity between them.
+- Calculates the FID: compares the real and generated images according to the FID distance metric.
+- Takes the number of images to evaluate with, which can be the number of images in the subset created above or fewer.
+
+To do this, run:
+
+python evaluator.py --device hpu --real_images_path /datasets/coco2014/val2014 --diff_images_path <path to generated images> --num_of_images <number of images>
+
diff --git a/neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/SR_evaluation/README.md b/neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/SR_evaluation/README.md
new file mode 100644
index 00000000000..f837681413e
--- /dev/null
+++ b/neural_compressor/torch/algorithms/fp8_quant/internal/diffusion_evaluation/SR_evaluation/README.md
@@ -0,0 +1,37 @@
+How to calculate PSNR and SSIM for Super Resolution
+We will use the ImageNet validation dataset.
+
+The evaluation is done in the following steps:
+1) Take the ImageNet validation set, which has 50,000 images (a subset can also be used).
+2) Center-crop these images to 256*256 and save them as the "ground truth" dataset. Each saved image is named after its label.
+3) Downsample the images to 64*64 (using bicubic interpolation) and then restore them using Super Resolution.
+4) Calculate PSNR and SSIM between each ground truth image and its restored counterpart, and print the means (a minimal sketch of this step follows below).
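The snippet below is only a minimal sketch of step 4, assuming the ground-truth and restored images share file names; it relies on scikit-image's PSNR/SSIM implementations, and the directory paths are placeholders rather than anything defined in this PR.

import os

import numpy as np
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity


def mean_psnr_ssim(gt_dir, restored_dir):
    # Pair each ground-truth image with the restored image of the same name.
    psnr_vals, ssim_vals = [], []
    for name in sorted(os.listdir(gt_dir)):
        gt = np.asarray(Image.open(os.path.join(gt_dir, name)).convert("RGB"))
        sr = np.asarray(Image.open(os.path.join(restored_dir, name)).convert("RGB"))
        psnr_vals.append(peak_signal_noise_ratio(gt, sr, data_range=255))
        ssim_vals.append(structural_similarity(gt, sr, channel_axis=-1, data_range=255))
    print(f"mean PSNR: {np.mean(psnr_vals):.2f} dB, mean SSIM: {np.mean(ssim_vals):.4f}")


mean_psnr_ssim("/path/to/ground_truth_256", "/path/to/restored_256")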
+
+Steps 1, 2 and 4 are included here, while step 3 (downsampling and restoring) should be done separately, using the
+desired Super Resolution method. Keep in mind that this script assumes the images are stored in a specific format
+(detailed later). The path to the restored images is then given as input to step 4.
+
+You can skip steps 1 and 2 and use the images at /datasets/imagenet/val_cropped_labeled
+You can also run a Python script that does the following to the ImageNet validation dataset:
+ - Crops images to 256*256 (this can be changed using the argument --resize; 256*256 is the default)
+ - Saves the images with the convention /