diff --git a/brainles_aurora/inferer/constants.py b/brainles_aurora/inferer/constants.py index b43ab68..7baae07 100644 --- a/brainles_aurora/inferer/constants.py +++ b/brainles_aurora/inferer/constants.py @@ -80,3 +80,7 @@ class Device(str, Enum): """Use GPU (CUDA)""" AUTO = "auto" """Attempt to use GPU, fallback to CPU.""" + + +WEIGHTS_DIR = "weights" +"""Directory name to store model weights.""" diff --git a/brainles_aurora/inferer/model.py b/brainles_aurora/inferer/model.py index 43b8bf5..71b69a0 100644 --- a/brainles_aurora/inferer/model.py +++ b/brainles_aurora/inferer/model.py @@ -8,7 +8,7 @@ import numpy as np import torch from brainles_aurora.inferer.config import AuroraInfererConfig -from brainles_aurora.inferer.constants import InferenceMode, Output +from brainles_aurora.inferer.constants import InferenceMode, Output, WEIGHTS_DIR from brainles_aurora.inferer.data import DataHandler from brainles_aurora.utils import download_model_weights from monai.inferers import SlidingWindowInferer @@ -41,7 +41,7 @@ def __init__( self.inference_mode = None # download weights if not present self.lib_path: str = Path(os.path.dirname(os.path.abspath(__file__))) - self.model_weights_folder = self.lib_path.parent / "model_weights" + self.model_weights_folder = self.lib_path.parent / WEIGHTS_DIR if not self.model_weights_folder.exists(): download_model_weights(target_folder=str(self.model_weights_folder))