Skip to content

Commit

Permalink
Merge pull request #38 from BrainLesion/19-feature-request-host-model…
Browse files Browse the repository at this point in the history
…-weights-on-zenodo

19 feature request host model weights on zenodo
  • Loading branch information
neuronflow authored Mar 4, 2024
2 parents a33b69f + d9fd5df commit 4a6613f
Show file tree
Hide file tree
Showing 21 changed files with 58 additions and 53 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,4 @@ dmypy.json
.vscode
poetry.lock
.DS_Store
brainles_aurora/model_weights/*
1 change: 1 addition & 0 deletions brainles_aurora/inferer/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,4 +221,5 @@ def infer(
self.data_handler.save_as_nifti(
postproc_data=out, output_file_mapping=output_file_mapping
)
logger.info(f"{' Finished inference run ':=^80}")
return out
5 changes: 2 additions & 3 deletions brainles_aurora/inferer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
self.lib_path: str = Path(os.path.dirname(os.path.abspath(__file__)))
self.model_weights_folder = self.lib_path.parent / "model_weights"
if not self.model_weights_folder.exists():
download_model_weights(target_folder=str(self.lib_path.parent))
download_model_weights(target_folder=str(self.model_weights_folder))

def load_model(
self, inference_mode: InferenceMode, num_input_modalities: int
Expand Down Expand Up @@ -85,8 +85,7 @@ def _load_model(self, num_input_modalities: int) -> torch.nn.Module:
# load weights
weights_path = os.path.join(
self.model_weights_folder,
self.inference_mode,
f"{self.config.model_selection}.tar",
f"{self.inference_mode}_{self.config.model_selection}.tar",
)
if not os.path.exists(weights_path):
raise NotImplementedError(
Expand Down
Loading

0 comments on commit 4a6613f

Please sign in to comment.