chore: add use_gpu for cifar finetuning #882

Merged: 5 commits on Sep 25, 2024
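Taken together, the changes apply one pattern across the notebooks: replace the unconditional `torch.cuda.is_available()` check with an explicit opt-in flag gated on the Concrete compiler's own GPU detection, then thread the resulting device into FHE compilation. A minimal sketch of that pattern, assembled from the hunks below (all names are taken from the diff itself):

```python
from concrete.compiler import check_gpu_available
from concrete.fhe.compilation import Configuration

# GPU use is opt-in: the flag must be set to True explicitly,
# and the Concrete compiler must actually detect a usable GPU.
use_gpu_if_available = False
device = "cuda" if use_gpu_if_available and check_gpu_available() else "cpu"

# The device choice is threaded into compilation through `use_gpu`.
configuration = Configuration(use_gpu=(device == "cuda"))
```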
@@ -32,6 +32,7 @@
" plot_dataset,\n",
" torch_inference,\n",
")\n",
"from concrete.compiler import check_gpu_available\n",
"from models import QuantVGG11\n",
"from torchvision import datasets\n",
"\n",
@@ -62,7 +63,8 @@
"bit = 5\n",
"seed = 42\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"use_gpu_if_available = False\n",
"device = \"cuda\" if use_gpu_if_available and check_gpu_available() else \"cpu\"\n",
"\n",
"print(f\"Device Type: {device}\")"
]
@@ -206,7 +208,7 @@
"\n",
"data_calibration, _ = next(iter(train_loader_c10))\n",
"\n",
"qmodel_c10 = fhe_compatibility(quant_vgg_c10, data_calibration)\n",
"qmodel_c10 = fhe_compatibility(quant_vgg_c10, data_calibration, device=device)\n",
"\n",
"print(\n",
" f\"Maximum bit-width in the circuit: {qmodel_c10.fhe_circuit.graph.maximum_integer_bit_width()}\"\n",
@@ -394,7 +396,7 @@
"\n",
"data_calibration, _ = next(iter(train_loader_c100))\n",
"\n",
"qmodel_c100 = fhe_compatibility(quant_vgg_c100, data_calibration)\n",
"qmodel_c100 = fhe_compatibility(quant_vgg_c100, data_calibration, device=device)\n",
"\n",
"print(\n",
" f\"Maximum bit-width in the circuit: {qmodel_c100.fhe_circuit.graph.maximum_integer_bit_width()}\"\n",
@@ -33,6 +33,8 @@
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from cifar_utils import fhe_simulation_inference, get_dataloader, torch_inference\n",
"from concrete.compiler import check_gpu_available\n",
"from concrete.fhe.compilation import Configuration\n",
"from models import QuantVGG11\n",
"from torch.utils.data.dataloader import DataLoader\n",
"from torchvision import datasets\n",
@@ -59,7 +61,8 @@
"seed = 42\n",
"rounding_thresholds_bits = [8, 7, 6, 5, 3]\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
"use_gpu_if_available = False\n",
"device = \"cuda\" if use_gpu_if_available and check_gpu_available() else \"cpu\""
]
},
{
@@ -91,6 +94,7 @@
" model.to(\"cpu\"),\n",
" torch_inputset=X_train,\n",
" rounding_threshold_bits=max_bitwidth,\n",
" configuration=Configuration(use_gpu=(device == \"cuda\")),\n",
" )\n",
"\n",
" acc_fhe_s = fhe_simulation_inference(qmodel, test_loader, True)\n",
@@ -40,6 +40,7 @@
"    torch_inference,\n",
"    train,\n",
")\n",
"from concrete.compiler import check_gpu_available\n",
"\n",
"# We follow the same quantization-aware-training methodology for CIFAR-10 and CIFAR-100,\n",
"# so let's import some generic functions.\n",
@@ -88,7 +89,8 @@
"source": [
"bit = 5\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"use_gpu_if_available = False\n",
"device = \"cuda\" if use_gpu_if_available and check_gpu_available() else \"cpu\"\n",
"\n",
"print(f\"Device Type: {device}\")"
]
@@ -280,7 +282,7 @@
"\n",
"data_calibration, _ = next(iter(train_loader_c100))\n",
"\n",
"qmodel = fhe_compatibility(quant_vgg, data_calibration)\n",
"qmodel = fhe_compatibility(quant_vgg, data_calibration, device=device)\n",
"\n",
"print(\n",
" f\"With {param_c100['dataset_name']}, the maximum bit-width in the circuit = \"\n",
@@ -544,7 +546,7 @@
"# Check the FHE-compatibility.\n",
"data, _ = next(iter(train_loader_c10))\n",
"\n",
"qmodel = fhe_compatibility(quant_vgg, data)\n",
"qmodel = fhe_compatibility(quant_vgg, data, device=device)\n",
"\n",
"print(\n",
" f\"With {param_c10['dataset_name']}, the circuit has a maximum bit-width of \"\n",
@@ -36,6 +36,7 @@
"\n",
"import torch\n",
"from cifar_utils import get_dataloader, plot_dataset, plot_history, torch_inference, train\n",
"from concrete.compiler import check_gpu_available\n",
"from models import Fp32VGG11\n",
"\n",
"warnings.filterwarnings(\"ignore\")"
@@ -64,7 +65,8 @@
"source": [
"dataset_name = \"CIFAR_100\"\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"use_gpu_if_available = False\n",
"device = \"cuda\" if use_gpu_if_available and check_gpu_available() else \"cpu\"\n",
"\n",
"param_c10 = {\n",
" \"output_size\": 10,\n",
@@ -36,6 +36,8 @@
"import numpy\n",
"import torch\n",
"from cifar_utils import get_dataloader, mapping_keys, plot_dataset, torch_inference, train\n",
"from concrete.compiler import check_gpu_available\n",
"from concrete.fhe.compilation import Configuration\n",
"from sklearn.metrics import top_k_accuracy_score\n",
"\n",
"from concrete.ml.pytest.torch_models import QNNFashionMNIST\n",
@@ -73,7 +75,8 @@
" \"seed\": 42,\n",
"}\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"use_gpu_if_available = False\n",
"device = \"cuda\" if use_gpu_if_available and check_gpu_available() else \"cpu\"\n",
"\n",
"print(f\"Device Type: {device}\")"
]
@@ -245,7 +248,10 @@
"\n",
" start_time = time()\n",
" qmodel = compile_brevitas_qat_model(\n",
" torch_model=quant_model, torch_inputset=X_calib, p_error=p_error\n",
" torch_model=quant_model,\n",
" torch_inputset=X_calib,\n",
" p_error=p_error,\n",
" configuration=Configuration(use_gpu=(device == \"cuda\")),\n",
" )\n",
" compilation_time.append((time() - start_time) / 60.0)\n",
"\n",
@@ -4,7 +4,6 @@
import warnings
from collections import OrderedDict
from pathlib import Path
from time import time
from typing import Callable, Dict, Optional, Tuple

import matplotlib.pyplot as plt
@@ -14,7 +13,6 @@
from brevitas import config
from concrete.fhe.compilation import Configuration
from models import Fp32VGG11
from sklearn.metrics import top_k_accuracy_score
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
@@ -441,12 +439,13 @@ def torch_inference(
return np.mean(np.vstack(correct), dtype="float64")


def fhe_compatibility(model: Callable, data: DataLoader) -> Callable:
def fhe_compatibility(model: Callable, data: DataLoader, device: str) -> Callable:
"""Test if the model is FHE-compatible.

Args:
model (Callable): The Brevitas model.
data (DataLoader): The data loader.
device (str): Specifies the device to run on, either 'cpu' or 'cuda'.

Returns:
Callable: Quantized model.
@@ -458,6 +457,7 @@ def fhe_compatibility(model: Callable, data: DataLoader) -> Callable:
torch_inputset=data,
show_mlir=False,
output_onnx_file="test.onnx",
configuration=Configuration(use_gpu=(device == "cuda")),
)

return qmodel
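With the new `device` argument, callers pass the resolved device explicitly, exactly as the notebook hunks above do. A short usage sketch, assuming a `train_loader` and `quant_vgg` like those defined in the notebooks:

```python
# Calibrate on a single batch, then check FHE compatibility on the chosen device.
data_calibration, _ = next(iter(train_loader))
qmodel = fhe_compatibility(quant_vgg, data_calibration, device=device)
```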
16 changes: 8 additions & 8 deletions use_case_examples/resnet/README.md
@@ -104,14 +104,14 @@ GPU machine: 8xH100 GPU machine

Summary of the accuracy evaluation on ImageNet (100 images):

| w&a bits | p_error | Accuracy | Top-5 Accuracy | Runtime* | Device |
| -------- | ------- | -------- | -------------- | --------------- | ------ |
| fp32 | - | 67% | 87% | - | - |
| 6/6 | 0.05 | 55% | 78% | 56 min | GPU |
| 6/6 | 0.05 | 55% | 78% | 1 h 31 min | CPU |
| 7/7 | 0.05 | **66%** | **87%** | **2 h 12 min** | CPU |

*Runtime reported to run the inference on a single image
| w&a bits | p_error | Accuracy | Top-5 Accuracy | Runtime\* | Device |
| -------- | ------- | -------- | -------------- | -------------- | ------ |
| fp32 | - | 67% | 87% | - | - |
| 6/6 | 0.05 | 55% | 78% | 56 min | GPU |
| 6/6 | 0.05 | 55% | 78% | 1 h 31 min | CPU |
| 7/7 | 0.05 | **66%** | **87%** | **2 h 12 min** | CPU |

\*Runtime reported to run the inference on a single image

6/6 `n_bits` configuration: {"model_inputs": 8, "op_inputs": 6, "op_weights": 6, "model_outputs": 9}
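An `n_bits` setting of this shape maps onto Concrete ML's post-training compilation API, which accepts it as a dictionary. A minimal sketch, assuming the standard `compile_torch_model` entry point; the model and calibration tensor are placeholders rather than the ImageNet ResNet setup benchmarked above:

```python
import torch
import torchvision
from concrete.ml.torch.compile import compile_torch_model

# The 6/6 row of the table, expressed as Concrete ML's n_bits dictionary.
n_bits = {"model_inputs": 8, "op_inputs": 6, "op_weights": 6, "model_outputs": 9}

model = torchvision.models.resnet18(weights=None)  # placeholder float model
calibration_set = torch.randn(10, 3, 224, 224)     # placeholder calibration data

quantized_module = compile_torch_model(
    model,
    calibration_set,
    n_bits=n_bits,
    p_error=0.05,  # matches the 6/6 rows in the table
)
```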
