Skip to content

Commit

Permalink
Merge pull request #45 from BrainLesion/bug/legacy-weights-used
Browse files Browse the repository at this point in the history
Fix issue that leads to using old model weights
  • Loading branch information
MarcelRosier authored Mar 10, 2024
2 parents c446b61 + 48f8151 commit 1de48b7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 4 additions & 0 deletions brainles_aurora/inferer/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,7 @@ class Device(str, Enum):
"""Use GPU (CUDA)"""
AUTO = "auto"
"""Attempt to use GPU, fallback to CPU."""


WEIGHTS_DIR = "weights"
"""Directory name to store model weights."""
4 changes: 2 additions & 2 deletions brainles_aurora/inferer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import torch
from brainles_aurora.inferer.config import AuroraInfererConfig
from brainles_aurora.inferer.constants import InferenceMode, Output
from brainles_aurora.inferer.constants import InferenceMode, Output, WEIGHTS_DIR
from brainles_aurora.inferer.data import DataHandler
from brainles_aurora.utils import download_model_weights
from monai.inferers import SlidingWindowInferer
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(
self.inference_mode = None
# download weights if not present
self.lib_path: str = Path(os.path.dirname(os.path.abspath(__file__)))
self.model_weights_folder = self.lib_path.parent / "model_weights"
self.model_weights_folder = self.lib_path.parent / WEIGHTS_DIR
if not self.model_weights_folder.exists():
download_model_weights(target_folder=str(self.model_weights_folder))

Expand Down

0 comments on commit 1de48b7

Please sign in to comment.