From e4934c867135dea9adb6d30ec1c80d0a6bf4e98f Mon Sep 17 00:00:00 2001 From: kcelia Date: Wed, 25 Sep 2024 14:26:11 +0200 Subject: [PATCH] chore: update --- .../CifarQuantizationAwareTraining.ipynb | 1 - .../cifar_brevitas_training/evaluate_torch_cml.py | 11 +++++++++-- use_case_examples/resnet/run_resnet18_fhe.py | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/use_case_examples/cifar/cifar_brevitas_finetuning/CifarQuantizationAwareTraining.ipynb b/use_case_examples/cifar/cifar_brevitas_finetuning/CifarQuantizationAwareTraining.ipynb index 12321a7dc..42702b6c7 100644 --- a/use_case_examples/cifar/cifar_brevitas_finetuning/CifarQuantizationAwareTraining.ipynb +++ b/use_case_examples/cifar/cifar_brevitas_finetuning/CifarQuantizationAwareTraining.ipynb @@ -88,7 +88,6 @@ ], "source": [ "bit = 5\n", - "seed = 42\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", diff --git a/use_case_examples/cifar/cifar_brevitas_training/evaluate_torch_cml.py b/use_case_examples/cifar/cifar_brevitas_training/evaluate_torch_cml.py index 5838eba00..04c7dce52 100644 --- a/use_case_examples/cifar/cifar_brevitas_training/evaluate_torch_cml.py +++ b/use_case_examples/cifar/cifar_brevitas_training/evaluate_torch_cml.py @@ -1,6 +1,7 @@ import argparse from pathlib import Path +import concrete.compiler import numpy as np import torch from concrete.fhe import Configuration @@ -74,8 +75,14 @@ def main(args): # observe a decrease in torch's top1 accuracy when using MPS devices # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3953 device = "cuda" if torch.cuda.is_available() else "cpu" + compilation_device = "cuda" if concrete.compiler.check_gpu_available() else "cpu" - print("Device in use:", device) + print("Torch device in use:", device) + print( + "To leverage the CUDA backend, follow the GPU setup guide to install the Concrete ML compiler." + ) + print("GPU Enabled:", concrete.compiler.check_gpu_enabled()) + print("GPU Available:", concrete.compiler.check_gpu_available()) # Find relative path to this file dir_path = Path(__file__).parent.absolute() @@ -123,7 +130,7 @@ def main(args): if rounding_threshold_bits is not None else None ), - device=COMPILATION_DEVICE, + device=compilation_device, ) # Print max bit-width in the circuit diff --git a/use_case_examples/resnet/run_resnet18_fhe.py b/use_case_examples/resnet/run_resnet18_fhe.py index f04a6d5cd..2e7a3a74d 100644 --- a/use_case_examples/resnet/run_resnet18_fhe.py +++ b/use_case_examples/resnet/run_resnet18_fhe.py @@ -276,7 +276,7 @@ def main(): "--export_statistics", action="store_true", help="Export the circuit statistics." ) parser.add_argument( - "--use_gpu", type=bool, action="store_true", help="Use the available GPU at FHE runtime." + "--use_gpu", action="store_true", help="Use the available GPU at FHE runtime." ) parser.add_argument( "--run_experiment",