Commit 1917ece (1 parent: 1104f1b)
Showing 3 changed files with 283 additions and 0 deletions.
examples/post_training_quantization/torch/fx/resnet18/README.md (31 additions, 0 deletions)
@@ -0,0 +1,31 @@
# Post-Training Quantization of ResNet18 in PyTorch (TorchFX)

This example demonstrates how to use the Post-Training Quantization API of the Neural Network Compression Framework (NNCF) to quantize a PyTorch model captured as a TorchFX graph, using a ResNet18 model pretrained on the Tiny ImageNet-200 dataset.

The example includes the following steps:

- Loading the Tiny ImageNet-200 dataset (~237 MB) and the ResNet18 PyTorch model pretrained on this dataset.
- Quantizing the model using the NNCF Post-Training Quantization algorithm (sketched below).
- Output of the following characteristics of the quantized model:
  - Accuracy drop of the quantized model (INT8) relative to the pretrained model (FP32)
  - Performance speedup of the quantized model (INT8)
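For orientation, here is a minimal sketch of the quantization flow that `main.py` implements. The names `model` and `calibration_loader` are placeholders, and `capture_pre_autograd_graph` is a private `torch._export` API that matches the pinned `torch==2.4.0` but may change in later releases.

```python
import nncf
import torch
from torch._export import capture_pre_autograd_graph

model = ...                # placeholder: any FP32 torch.nn.Module, e.g. ResNet18
calibration_loader = ...   # placeholder: a DataLoader with batch size 1

def transform_fn(data_item):
    # The loader yields (image, label) pairs; NNCF only needs the input tensor.
    return data_item[0]

calibration_dataset = nncf.Dataset(calibration_loader, transform_fn)

# Capture the model as a TorchFX graph, then run Post-Training Quantization:
# activation statistics are collected on the calibration samples.
example_input = torch.ones(1, 3, 64, 64)
fx_model = capture_pre_autograd_graph(model.eval(), args=(example_input,))
quantized_model = nncf.quantize(fx_model, calibration_dataset)
```

The quantized graph can then be compiled for inference, for example with `torch.compile(quantized_model, backend="openvino")` as `main.py` does.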
## Install requirements

It is assumed at this point that you have already installed NNCF. You can find information on how to install NNCF [here](https://github.com/openvinotoolkit/nncf#user-content-installation).
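If NNCF is not installed yet, one option (assuming the PyPI release suits your environment) is to install it from PyPI:

```bash
pip install nncf
```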
To work with the example, install the corresponding Python package dependencies:

```bash
pip install -r requirements.txt
```
## Run Example

The example does not require any additional preparation: it downloads the dataset and the pretrained model checkpoint itself. Just run:

```bash
python main.py
```
examples/post_training_quantization/torch/fx/resnet18/main.py (248 additions, 0 deletions)
@@ -0,0 +1,248 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import warnings
from pathlib import Path
from time import time
from typing import Tuple

import torch
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from fastdownload import FastDownload
from torch._export import capture_pre_autograd_graph
from torch.jit import TracerWarning

import nncf
import nncf.torch
from nncf.common.logging.track_progress import track
from nncf.common.utils.helpers import create_table
from nncf.torch.dynamic_graph.patch_pytorch import unpatch_torch_operators

# Remove NNCF's patching of PyTorch operators so that the model can be
# captured with the torch.export-based TorchFX tracing below.
unpatch_torch_operators()

warnings.filterwarnings("ignore", category=TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)

IMAGE_SIZE = 64
BATCH_SIZE = 128


ROOT = Path(__file__).parent.resolve()
BEST_CKPT_NAME = "resnet18_int8_best.pt"
CHECKPOINT_URL = (
    "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/302_resnet18_fp32_v1.pth"
)
DATASET_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
DATASET_PATH = "~/.cache/nncf/datasets"

def download_dataset() -> Path:
    downloader = FastDownload(base=DATASET_PATH, archive="downloaded", data="extracted")
    return downloader.get(DATASET_URL)


def load_checkpoint(model: torch.nn.Module) -> Tuple[torch.nn.Module, float]:
    checkpoint = torch.hub.load_state_dict_from_url(CHECKPOINT_URL, map_location=torch.device("cpu"), progress=False)
    model.load_state_dict(checkpoint["state_dict"])
    return model, checkpoint["acc1"]


def get_resnet18_model(device: torch.device) -> torch.nn.Module:
    num_classes = 200  # 200 is for Tiny ImageNet, default is 1000 for ImageNet
    model = models.resnet18(weights=None)
    # Update the last FC layer for Tiny ImageNet number of classes.
    model.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True)
    model.to(device)
    return model

def measure_latency(model, example_inputs, num_iters=2000) -> float:
    """Return the average inference time per call, in milliseconds."""
    with torch.no_grad():
        # Warm-up call before timing.
        model(example_inputs)
        total_time = 0
        for _ in range(num_iters):
            start_time = time()
            model(example_inputs)
            total_time += time() - start_time
        average_time = (total_time / num_iters) * 1000
        return average_time

def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module, device: torch.device) -> float:
    """Return the top-1 accuracy of the model, averaged over validation batches."""
    top1_sum = 0.0

    with torch.no_grad():
        for images, target in track(val_loader, total=len(val_loader), description="Validation:"):
            images = images.to(device)
            target = target.to(device)

            # Compute output.
            output = model(images)

            # Measure accuracy.
            [acc1] = accuracy(output, target, topk=(1,))
            top1_sum += acc1.item()

    num_batches = len(val_loader)
    top1_avg = top1_sum / num_batches
    return top1_avg

def accuracy(output: torch.Tensor, target: torch.Tensor, topk: Tuple[int, ...] = (1,)):
    """Compute the top-k accuracy (in percent) for each value of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def create_data_loaders():
    dataset_path = download_dataset()

    prepare_tiny_imagenet_200(dataset_path)
    print(f"Successfully downloaded and prepared dataset at: {dataset_path}")

    val_dir = dataset_path / "val"

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    val_dataset = datasets.ImageFolder(
        val_dir,
        transforms.Compose(
            [
                transforms.Resize(IMAGE_SIZE),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True
    )

    # Create a separate dataloader with batch size = 1,
    # as dataloaders with batch size > 1 are not supported yet.
    calibration_dataset = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True
    )

    return val_loader, calibration_dataset

def prepare_tiny_imagenet_200(dataset_dir: Path):
    # Format validation set the same way as train set is formatted.
    val_data_dir = dataset_dir / "val"
    val_images_dir = val_data_dir / "images"
    if not val_images_dir.exists():
        return

    val_annotations_file = val_data_dir / "val_annotations.txt"
    with open(val_annotations_file, "r") as f:
        val_annotation_data = map(lambda line: line.split("\t")[:2], f.readlines())
        for image_filename, image_label in val_annotation_data:
            from_image_filepath = val_images_dir / image_filename
            to_image_dir = val_data_dir / image_label
            if not to_image_dir.exists():
                to_image_dir.mkdir()
            to_image_filepath = to_image_dir / image_filename
            from_image_filepath.rename(to_image_filepath)
    val_annotations_file.unlink()
    val_images_dir.rmdir()

def get_model_size(ir_path: str, m_type: str = "Mb") -> float:
    """Return the size of an OpenVINO IR model (the .xml file plus its
    companion .bin file) in the requested unit: "bytes", "Kb", or "Mb"."""
    xml_size = os.path.getsize(ir_path)
    bin_size = os.path.getsize(os.path.splitext(ir_path)[0] + ".bin")
    for t in ["bytes", "Kb", "Mb"]:
        if m_type == t:
            break
        xml_size /= 1024
        bin_size /= 1024
    model_size = xml_size + bin_size
    return model_size

def main():
    torch.manual_seed(0)
    device = torch.device("cpu")
    print(f"Using {device} device")

    ###############################################################################
    # Step 1: Prepare model and dataset
    print(os.linesep + "[Step 1] Prepare model and dataset")

    model = get_resnet18_model(device)
    model, acc1_fp32 = load_checkpoint(model)

    print(f"Accuracy@1 of original FP32 model: {acc1_fp32}")

    val_loader, calibration_dataset = create_data_loaders()

    def transform_fn(data_item):
        # The calibration loader yields (image, label) pairs; only the input tensor is needed.
        return data_item[0].to(device)

    quantization_dataset = nncf.Dataset(calibration_dataset, transform_fn)

    ###############################################################################
    # Step 2: Quantize model
    print(os.linesep + "[Step 2] Quantize model")

    input_shape = (1, 3, IMAGE_SIZE, IMAGE_SIZE)
    example_input = torch.ones(*input_shape).cpu()

    # Capture the model as a TorchFX graph, quantize it with NNCF,
    # and compile the quantized graph with the OpenVINO backend.
    fx_model = capture_pre_autograd_graph(model.eval(), args=(example_input,))
    quantized_fx_model = nncf.quantize(fx_model, quantization_dataset)
    quantized_fx_model = torch.compile(quantized_fx_model, backend="openvino")

    acc1_int8 = validate(val_loader, quantized_fx_model, device)
    print(f"Accuracy@1 of INT8 model: {acc1_int8:.3f}")

    ###############################################################################
    # Step 3: Run benchmarks
    print(os.linesep + "[Step 3] Run benchmarks")
    print("Run benchmark for FP32 model ...")
    fp32_latency = measure_latency(model, example_inputs=example_input)

    print("Run benchmark for INT8 model ...")
    int8_latency = measure_latency(quantized_fx_model, example_inputs=example_input)

    ###############################################################################
    # Step 4: Summary
    print(os.linesep + "[Step 4] Summary")
    tabular_data = [
        ["Accuracy@1", acc1_fp32, acc1_int8, f"Diff: {acc1_fp32 - acc1_int8:.3f}"],
        ["Performance, ms", fp32_latency, int8_latency, f"Speedup x{fp32_latency / int8_latency:.3f}"],
    ]
    print(create_table(["", "FP32", "INT8", "Summary"], tabular_data))

    return acc1_fp32, acc1_int8, fp32_latency, int8_latency


if __name__ == "__main__":
    main()
examples/post_training_quantization/torch/fx/resnet18/requirements.txt (4 additions, 0 deletions)
@@ -0,0 +1,4 @@
fastdownload==0.0.7
openvino==2024.3
torch==2.4.0
torchvision==0.19.0