
Commit 3cdaa07

violetch24 authored and bmyrcha committed

add notebook example for pytorch (#1300)

Signed-off-by: Cheng, Zixuan <zixuan.cheng@intel.com>
Signed-off-by: bmyrcha <bartosz.myrcha@intel.com>

1 parent de099ab · commit 3cdaa07

3 files changed, +424 −0 lines changed

examples/README.md

Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,8 @@ Intel® Neural Compressor validated examples with multiple compression techniques
 
 * [Quick Get Started Notebook of Intel® Neural Compressor for Tensorflow](/examples/notebook/tensorflow/resnet/resnet_quantization.ipynb)
 
+* [Quick Get Started Notebook of Intel® Neural Compressor for Pytorch](/examples/notebook/pytorch/Quick_Started_Notebook_of_INC_for_Pytorch.ipynb)
+
 # Helloworld Examples
 
 * [tf_example1](/examples/helloworld/tf_example1): quantize with built-in dataloader and metric.
examples/notebook/pytorch/Quick_Started_Notebook_of_INC_for_Pytorch.ipynb

Lines changed: 327 additions & 0 deletions

@@ -0,0 +1,327 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Quick Get Started Notebook of Intel® Neural Compressor for Pytorch\n",
"\n",
"\n",
"This notebook is designed to provide an easy-to-follow guide for getting started with the [Intel® Neural Compressor](https://github.com/intel/neural-compressor) (INC) library for the [PyTorch](https://github.com/pytorch/pytorch) framework.\n",
"\n",
"In the following sections, we use a DistilBERT model fine-tuned on MRPC as an example to show how to apply post-training quantization to [transformers](https://github.com/huggingface/transformers) models using the INC library.\n",
"\n",
"\n",
"The main objectives of this notebook are:\n",
"\n",
"1. Prerequisite: Prepare the necessary environment, model, and dataset.\n",
"2. Quantization with INC: Walk through the step-by-step process of applying post-training quantization.\n",
"3. Benchmark with INC: Evaluate the performance of the FP32 and INT8 models.\n",
"\n",
"\n",
"## 1. Prerequisite\n",
"\n",
"### 1.1 Environment\n",
"\n",
"If you already have Jupyter Notebook, you may run this notebook directly. We will use pip to install or upgrade [neural-compressor](https://github.com/intel/neural-compressor), [PyTorch](https://github.com/pytorch/pytorch), and the other required packages.\n",
"\n",
"Otherwise, you can set up a new environment. First, install [Anaconda](https://www.anaconda.com/distribution/). Then open an Anaconda prompt window and run the following commands:\n",
"\n",
"```shell\n",
"conda create -n inc_notebook python==3.8\n",
"conda activate inc_notebook\n",
"pip install jupyter\n",
"jupyter notebook\n",
"```\n",
"The last command launches Jupyter Notebook; open this notebook in the browser to continue.\n",
"\n",
"Then, let's install the necessary packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# install neural-compressor from source\n",
"!git clone https://github.com/intel/neural-compressor.git\n",
"%cd ./neural-compressor\n",
"!pip install -r requirements.txt\n",
"!python setup.py install\n",
"%cd ..\n",
"\n",
"# or install the stable basic version from PyPI\n",
"!pip install neural-compressor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# install the other packages used in this notebook\n",
"# (version specifiers are quoted so the shell does not treat '>' or '<' as redirection)\n",
"!pip install \"torch>=1.9.0\" \"transformers>=4.16.0\" accelerate sympy numpy \"sentencepiece!=0.1.92\" \"protobuf<=3.20.3\" \"datasets>=1.1.3\" scipy scikit-learn Keras-Preprocessing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.2 Load Dataset\n",
"\n",
"The General Language Understanding Evaluation (GLUE) benchmark is a group of nine classification tasks on sentences or pairs of sentences:\n",
"\n",
"- [CoLA](https://nyu-mll.github.io/CoLA/) (Corpus of Linguistic Acceptability) Determine if a sentence is grammatically correct or not.\n",
"- [MNLI](https://arxiv.org/abs/1704.05426) (Multi-Genre Natural Language Inference) Determine if a sentence entails, contradicts, or is unrelated to a given hypothesis. This dataset has two versions: a matched version where the validation and test sets come from the same distribution, and a mismatched version where they use out-of-domain data.\n",
"- [MRPC](https://www.microsoft.com/en-us/download/details.aspx?id=52398) (Microsoft Research Paraphrase Corpus) Determine if two sentences are paraphrases of one another or not.\n",
"- [QNLI](https://rajpurkar.github.io/SQuAD-explorer/) (Question-answering Natural Language Inference) Determine if the answer to a question is in the second sentence or not. This dataset is built from the SQuAD dataset.\n",
"- [QQP](https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs) (Quora Question Pairs) Determine if two questions are semantically equivalent or not.\n",
"- [RTE](https://aclweb.org/aclwiki/Recognizing_Textual_Entailment) (Recognizing Textual Entailment) Determine if a sentence entails a given hypothesis or not.\n",
"- [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) Determine if the sentence has a positive or negative sentiment.\n",
"- [STS-B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark) (Semantic Textual Similarity Benchmark) Determine the similarity of two sentences with a score from 1 to 5.\n",
"- [WNLI](https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WS.html) (Winograd Natural Language Inference) Determine if a sentence with an anonymous pronoun and a sentence with this pronoun replaced are entailed or not. This dataset is built from the Winograd Schema Challenge dataset.\n",
"\n",
"Here, we use the MRPC task. We download and load the required dataset from the Hugging Face hub."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import datasets\n",
"import numpy as np\n",
"import transformers\n",
"from datasets import load_dataset, load_metric\n",
"from transformers import (\n",
"    AutoConfig,\n",
"    AutoModelForSequenceClassification,\n",
"    AutoTokenizer,\n",
"    EvalPrediction,\n",
"    Trainer,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"task_name = 'mrpc'\n",
"raw_datasets = load_dataset(\"glue\", task_name)\n",
"label_list = raw_datasets[\"train\"].features[\"label\"].names\n",
"num_labels = len(label_list)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.3 Prepare Model\n",
"Download the pretrained model [textattack/distilbert-base-uncased-MRPC](https://huggingface.co/textattack/distilbert-base-uncased-MRPC) from the Hugging Face hub and load it as a PyTorch model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_name = 'textattack/distilbert-base-uncased-MRPC'\n",
"\n",
"config = AutoConfig.from_pretrained(\n",
"    model_name,\n",
"    num_labels=num_labels,\n",
"    finetuning_task=task_name,\n",
"    use_auth_token=None,\n",
")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\n",
"    model_name,\n",
"    use_auth_token=None,\n",
")\n",
"\n",
"model = AutoModelForSequenceClassification.from_pretrained(\n",
"    model_name,\n",
"    from_tf=False,\n",
"    config=config,\n",
"    use_auth_token=None,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1.4 Dataset Preprocessing\n",
"We need to preprocess the raw dataset: tokenize the sentence pairs, pad them to `max_seq_length`, and truncate longer inputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sentence1_key, sentence2_key = (\"sentence1\", \"sentence2\")\n",
"padding = \"max_length\"\n",
"label_to_id = None\n",
"max_seq_length = 128\n",
"\n",
"def preprocess_function(examples):\n",
"    args = (examples[sentence1_key], examples[sentence2_key])\n",
"    result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)\n",
"    return result\n",
"\n",
"raw_datasets = raw_datasets.map(preprocess_function, batched=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Quantization with Intel® Neural Compressor"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.1 Define the metric, evaluation function, and dataloader\n",
"\n",
"In this part, we define a GLUE metric and use it to build an evaluation function for INC.\n",
"\n",
"Refer to [metric.md](https://github.com/intel/neural-compressor/blob/master/docs/source/metric.md#build-custom-metric-with-python-api) for how to build your own metric.\n",
"Refer to [dataset.md](https://github.com/intel/neural-compressor/blob/master/docs/source/dataset.md#user-specific-dataset) and [dataloader.md](https://github.com/intel/neural-compressor/blob/master/docs/source/dataloader.md#build-custom-dataloader-with-python-apiapi) for how to build your own dataset and dataloader."
]
},
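{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of the custom-metric interface described in metric.md (the `update`/`reset`/`result` method names follow that doc; the accuracy computation here is only illustrative and is not part of the original notebook), a user-defined metric could look like the cell below. This notebook does not register such a metric; it builds an `eval_func` on top of the Trainer in the cell that follows this sketch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch only (not used later in this notebook): a custom metric\n",
"# exposing the update/reset/result interface described in INC's metric.md.\n",
"# numpy was imported as np in the earlier import cell.\n",
"class MyAccuracyMetric(object):\n",
"    def __init__(self):\n",
"        self.correct = 0\n",
"        self.total = 0\n",
"\n",
"    def update(self, predict, label):\n",
"        # predict: model outputs (logits) for a batch; label: ground-truth ids\n",
"        preds = np.argmax(np.array(predict), axis=-1)\n",
"        self.correct += int((preds == np.array(label)).sum())\n",
"        self.total += len(label)\n",
"\n",
"    def reset(self):\n",
"        self.correct = 0\n",
"        self.total = 0\n",
"\n",
"    def result(self):\n",
"        return self.correct / max(self.total, 1)"
]
},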
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eval_dataset = raw_datasets[\"validation\"]\n",
"metric = load_metric(\"glue\", task_name)\n",
"data_collator = None\n",
"\n",
"def compute_metrics(p: EvalPrediction):\n",
"    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions\n",
"    preds = np.argmax(preds, axis=1)\n",
"    result = metric.compute(predictions=preds, references=p.label_ids)\n",
"    if len(result) > 1:\n",
"        result[\"combined_score\"] = np.mean(list(result.values())).item()\n",
"    return result\n",
"\n",
"# Initialize our Trainer\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    train_dataset=None,\n",
"    eval_dataset=eval_dataset,\n",
"    compute_metrics=compute_metrics,\n",
"    tokenizer=tokenizer,\n",
"    data_collator=data_collator,\n",
")\n",
"\n",
"eval_dataloader = trainer.get_eval_dataloader()\n",
"\n",
"# With transformers 4.31.0 the eval dataloader is wrapped by accelerate and its\n",
"# batch_size attribute can be None; wrap it so INC can read the batch size.\n",
"if eval_dataloader.batch_size is None:\n",
"    def _build_inc_dataloader(dataloader):\n",
"        class INCDataLoader:\n",
"            __iter__ = dataloader.__iter__\n",
"            def __init__(self) -> None:\n",
"                self.dataloader = dataloader\n",
"                self.batch_size = dataloader.total_batch_size\n",
"        return INCDataLoader()\n",
"    eval_dataloader = _build_inc_dataloader(eval_dataloader)\n",
"batch_size = eval_dataloader.batch_size\n",
"\n",
"def take_eval_steps(model, trainer, save_metrics=False):\n",
"    trainer.model = model\n",
"    metrics = trainer.evaluate()\n",
"    bert_task_acc_keys = ['eval_f1', 'eval_accuracy', 'eval_matthews_correlation',\n",
"                          'eval_pearson', 'eval_mcc', 'eval_spearmanr']\n",
"    for key in bert_task_acc_keys:\n",
"        if key in metrics.keys():\n",
"            throughput = metrics.get(\"eval_samples_per_second\")\n",
"            print('Batch size = %d' % batch_size)\n",
"            print(\"Final eval {}: {}\".format(key, metrics[key]))\n",
"            print(\"Latency: %.3f ms\" % (1000 / throughput))\n",
"            print(\"Throughput: {} samples/sec\".format(throughput))\n",
"            return metrics[key]\n",
"    assert False, \"No metric returned, please check the inference metrics!\"\n",
"\n",
"def eval_func(model):\n",
"    return take_eval_steps(model, trainer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2.2 Run Quantization\n",
"\n",
"Now we can finally start to quantize the model.\n",
"\n",
"To start, we set the configuration for post-training quantization using the `PostTrainingQuantConfig` class. Once the configuration is set, we call the `quantization.fit()` function, which runs the quantization and tuning process on the model and returns the best quantized model it finds."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from neural_compressor.quantization import fit\n",
"from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion\n",
"tuning_criterion = TuningCriterion(max_trials=600)\n",
"conf = PostTrainingQuantConfig(approach=\"static\", tuning_criterion=tuning_criterion)\n",
"q_model = fit(model, conf=conf, calib_dataloader=eval_dataloader, eval_func=eval_func)"
]
},
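{
"cell_type": "markdown",
"metadata": {},
"source": [
"The benchmark step in section 3 reads the FP32 model from `./pytorch_model.bin` and the INT8 model from `./saved_results/best_model.pt`. As a minimal sketch, assuming the standard `Trainer.save_model()` behavior and INC's `q_model.save()` API (both are assumptions to verify against your installed versions), the two models could be written to those locations as follows; adjust the paths to match your setup."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumed API): save the FP32 model so ./pytorch_model.bin exists\n",
"# for the FP32 benchmark below.\n",
"trainer.save_model(\"./\")\n",
"\n",
"# Sketch (assumed API): save the quantized model; INC writes best_model.pt\n",
"# into the given directory, matching ./saved_results/best_model.pt used below.\n",
"q_model.save(\"./saved_results\")"
]
},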
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Benchmark with Intel® Neural Compressor\n",
"\n",
"INC provides a benchmark feature to measure model performance under the specified objective settings."
]
},
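{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell shells out to a separate `benchmark.py` helper script, which is not shown in this excerpt. As a rough sketch, assuming INC's Python benchmark API (`BenchmarkConfig` from `neural_compressor.config` and `fit` from `neural_compressor.benchmark`), a similar measurement could also be run directly on the evaluation dataloader; the warmup, iteration, and instance settings below are placeholder assumptions, not values from the original example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (assumed API): benchmark the in-memory model with INC's Python API\n",
"# instead of the external benchmark.py script.\n",
"from neural_compressor.config import BenchmarkConfig\n",
"from neural_compressor.benchmark import fit as benchmark_fit\n",
"\n",
"# Placeholder settings: adjust cores/instances to your machine.\n",
"b_conf = BenchmarkConfig(warmup=10, iteration=100, cores_per_instance=4, num_of_instance=1)\n",
"benchmark_fit(model, b_conf, b_dataloader=eval_dataloader)"
]
},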
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# fp32 benchmark\n",
"!python benchmark.py --input_model ./pytorch_model.bin 2>&1|tee fp32_benchmark.log\n",
"\n",
"# int8 benchmark\n",
"!python benchmark.py --input_model ./saved_results/best_model.pt 2>&1|tee int8_benchmark.log\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.12"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
