Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add use_gpu for cifar finetuning #882

Merged
merged 5 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"source": [
"import warnings\n",
"\n",
"import concrete.compiler\n",
"import torch\n",
"from cifar_utils import (\n",
" fhe_compatibility,\n",
Expand Down Expand Up @@ -67,6 +68,34 @@
"print(f\"Device Type: {device}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concrete ML also supports a CUDA-enabled backend. To set it up, follow the instructions in the official [guide](../../../docs/guides/using_gpu.md) for installing the GPU-enabled Concrete compiler."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Is GPU enabled: False\n",
"Is GPU available: False\n"
]
}
],
"source": [
"compilation_device = \"cuda\" if concrete.compiler.check_gpu_available() else \"cpu\"\n",
"\n",
"print(f\"Is GPU enabled: {concrete.compiler.check_gpu_enabled()}\")\n",
"print(f\"Is GPU available: {concrete.compiler.check_gpu_available()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -206,7 +235,7 @@
"\n",
"data_calibration, _ = next(iter(train_loader_c10))\n",
"\n",
"qmodel_c10 = fhe_compatibility(quant_vgg_c10, data_calibration)\n",
"qmodel_c10 = fhe_compatibility(quant_vgg_c10, data_calibration, device=compilation_device)\n",
"\n",
"print(\n",
" f\"Maximum bit-width in the circuit: {qmodel_c10.fhe_circuit.graph.maximum_integer_bit_width()}\"\n",
Expand Down Expand Up @@ -394,7 +423,7 @@
"\n",
"data_calibration, _ = next(iter(train_loader_c100))\n",
"\n",
"qmodel_c100 = fhe_compatibility(quant_vgg_c100, data_calibration)\n",
"qmodel_c100 = fhe_compatibility(quant_vgg_c100, data_calibration, device=compilation_device)\n",
"\n",
"print(\n",
" f\"Maximum bit-width in the circuit: {qmodel_c100.fhe_circuit.graph.maximum_integer_bit_width()}\"\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"import warnings\n",
"from typing import Callable, List, Tuple\n",
"\n",
"import concrete.compiler\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from cifar_utils import fhe_simulation_inference, get_dataloader, torch_inference\n",
Expand Down Expand Up @@ -62,6 +63,25 @@
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concrete ML also supports a CUDA-enabled backend. To set it up, follow the instructions in the official [guide](../../../docs/guides/using_gpu.md) for installing the GPU-enabled Concrete compiler."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"compilation_device = \"cuda\" if concrete.compiler.check_gpu_available() else \"cpu\"\n",
"\n",
"print(f\"Is GPU enabled: {concrete.compiler.check_gpu_enabled()}\")\n",
"print(f\"Is GPU available: {concrete.compiler.check_gpu_available()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -91,6 +111,7 @@
" model.to(\"cpu\"),\n",
" torch_inputset=X_train,\n",
" rounding_threshold_bits=max_bitwidth,\n",
" device=compilation_device,\n",
" )\n",
"\n",
" acc_fhe_s = fhe_simulation_inference(qmodel, test_loader, True)\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"metadata": {},
"outputs": [],
"source": [
"import concrete.compiler\n",
"import torch\n",
"from cifar_utils import (\n",
" fhe_compatibility,\n",
Expand Down Expand Up @@ -93,6 +94,25 @@
"print(f\"Device Type: {device}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concrete ML also supports a CUDA-enabled backend. To set it up, follow the instructions in the official [guide](../../../docs/guides/using_gpu.md) for installing the GPU-enabled Concrete compiler."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"compilation_device = \"cuda\" if concrete.compiler.check_gpu_available() else \"cpu\"\n",
"\n",
"print(f\"Is GPU enabled: {concrete.compiler.check_gpu_enabled()}\")\n",
"print(f\"Is GPU available: {concrete.compiler.check_gpu_available()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -280,7 +300,7 @@
"\n",
"data_calibration, _ = next(iter(train_loader_c100))\n",
"\n",
"qmodel = fhe_compatibility(quant_vgg, data_calibration)\n",
"qmodel = fhe_compatibility(quant_vgg, data_calibration, device=compilation_device)\n",
"\n",
"print(\n",
" f\"With {param_c100['dataset_name']}, the maximum bit-width in the circuit = \"\n",
Expand Down Expand Up @@ -544,7 +564,7 @@
"# Check the FHE-compatibility.\n",
"data, _ = next(iter(train_loader_c10))\n",
"\n",
"qmodel = fhe_compatibility(quant_vgg, data)\n",
"qmodel = fhe_compatibility(quant_vgg, data, device=compilation_device)\n",
"\n",
"print(\n",
" f\"With {param_c10['dataset_name']}, the circuit has a maximum bit-width of \"\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"from itertools import chain\n",
"from time import time\n",
"\n",
"import concrete.compiler\n",
"import matplotlib.pylab as plt\n",
"import numpy\n",
"import torch\n",
Expand Down Expand Up @@ -78,6 +79,25 @@
"print(f\"Device Type: {device}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Concrete ML also supports a CUDA-enabled backend. To set it up, follow the instructions in the official [guide](../../../docs/guides/using_gpu.md) for installing the GPU-enabled Concrete compiler."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"compilation_device = \"cuda\" if concrete.compiler.check_gpu_available() else \"cpu\"\n",
"\n",
"print(f\"Is GPU enabled: {concrete.compiler.check_gpu_enabled()}\")\n",
"print(f\"Is GPU available: {concrete.compiler.check_gpu_available()}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -245,7 +265,10 @@
"\n",
" start_time = time()\n",
" qmodel = compile_brevitas_qat_model(\n",
" torch_model=quant_model, torch_inputset=X_calib, p_error=p_error\n",
" torch_model=quant_model,\n",
" torch_inputset=X_calib,\n",
" p_error=p_error,\n",
" device=compilation_device,\n",
" )\n",
" compilation_time.append((time() - start_time) / 60.0)\n",
"\n",
Expand Down Expand Up @@ -353,7 +376,10 @@
"\n",
"# Compile the model with the optimal `p_error`\n",
"qmodel = compile_brevitas_qat_model(\n",
" torch_model=quant_model, torch_inputset=X_calib, p_error=largest_p_error\n",
" torch_model=quant_model,\n",
" torch_inputset=X_calib,\n",
" p_error=largest_p_error,\n",
" device=compilation_device,\n",
")\n",
"\n",
"# Key Generation\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import warnings
from collections import OrderedDict
from pathlib import Path
from time import time
from typing import Callable, Dict, Optional, Tuple

import matplotlib.pyplot as plt
Expand All @@ -14,7 +13,6 @@
from brevitas import config
from concrete.fhe.compilation import Configuration
from models import Fp32VGG11
from sklearn.metrics import top_k_accuracy_score
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
Expand Down Expand Up @@ -441,12 +439,13 @@ def torch_inference(
return np.mean(np.vstack(correct), dtype="float64")


def fhe_compatibility(model: Callable, data: DataLoader) -> Callable:
def fhe_compatibility(model: Callable, data: DataLoader, device: str) -> Callable:
"""Test if the model is FHE-compatible.

Args:
model (Callable): The Brevitas model.
data (DataLoader): The data loader.
device (str): Specifies the device to run during the compilation, either 'cpu' or 'gpu'.

Returns:
Callable: Quantized model.
Expand All @@ -458,6 +457,7 @@ def fhe_compatibility(model: Callable, data: DataLoader) -> Callable:
torch_inputset=data,
show_mlir=False,
output_onnx_file="test.onnx",
device=device,
)

return qmodel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path

import torch
from concrete.compiler import check_gpu_available
from concrete.fhe import Exactness
from concrete.fhe.compilation.configuration import Configuration
from models import cnv_2w2a
Expand All @@ -22,6 +23,8 @@
# observe a decrease in torch's top1 accuracy when using MPS devices
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3953
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
COMPILATION_DEVICE = "cuda" if check_gpu_available() else "cpu"

NUM_SAMPLES = int(os.environ.get("NUM_SAMPLES", 1))
P_ERROR = float(os.environ.get("P_ERROR", 0.01))

Expand Down Expand Up @@ -93,6 +96,7 @@ def wrapper(*args, **kwargs):
configuration=configuration,
rounding_threshold_bits={"method": Exactness.APPROXIMATE, "n_bits": 6},
p_error=P_ERROR,
device=COMPILATION_DEVICE,
)
assert isinstance(quantized_numpy_module, QuantizedModule)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
from pathlib import Path

import concrete.compiler
import numpy as np
import torch
from concrete.fhe import Configuration
Expand Down Expand Up @@ -74,8 +75,14 @@ def main(args):
# observe a decrease in torch's top1 accuracy when using MPS devices
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3953
device = "cuda" if torch.cuda.is_available() else "cpu"
compilation_device = "cuda" if concrete.compiler.check_gpu_available() else "cpu"

print("Device in use:", device)
print("Torch device in use:", device)
print(
"To leverage the CUDA backend, follow the GPU setup guide to install the Concrete ML compiler."
)
print("GPU Enabled:", concrete.compiler.check_gpu_enabled())
print("GPU Available:", concrete.compiler.check_gpu_available())

# Find relative path to this file
dir_path = Path(__file__).parent.absolute()
Expand Down Expand Up @@ -123,6 +130,7 @@ def main(args):
if rounding_threshold_bits is not None
else None
),
device=compilation_device,
)

# Print max bit-width in the circuit
Expand Down
16 changes: 8 additions & 8 deletions use_case_examples/resnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ GPU machine: 8xH100 GPU machine

Summary of the accuracy evaluation on ImageNet (100 images):

| w&a bits | p_error | Accuracy | Top-5 Accuracy | Runtime* | Device |
| -------- | ------- | -------- | -------------- | --------------- | ------ |
| fp32 | - | 67% | 87% | - | - |
| 6/6 | 0.05 | 55% | 78% | 56 min | GPU |
| 6/6 | 0.05 | 55% | 78% | 1 h 31 min | CPU |
| 7/7 | 0.05 | **66%** | **87%** | **2 h 12 min** | CPU |

*Runtime reported to run the inference on a single image
| w&a bits | p_error | Accuracy | Top-5 Accuracy | Runtime\* | Device |
kcelia marked this conversation as resolved.
Show resolved Hide resolved
| -------- | ------- | -------- | -------------- | -------------- | ------ |
| fp32 | - | 67% | 87% | - | - |
| 6/6 | 0.05 | 55% | 78% | 56 min | GPU |
| 6/6 | 0.05 | 55% | 78% | 1 h 31 min | CPU |
| 7/7 | 0.05 | **66%** | **87%** | **2 h 12 min** | CPU |

\*Runtime reported to run the inference on a single image

6/6 `n_bits` configuration: {"model_inputs": 8, "op_inputs": 6, "op_weights": 6, "model_outputs": 9}

Expand Down
Loading
Loading