Skip to content

Commit

Permalink
Merge pull request #25 from BrainLesion/20-feature-request-enhanced-a…
Browse files Browse the repository at this point in the history
…utomatic-log-file-naming

20 feature request enhanced automatic log file naming
  • Loading branch information
neuronflow authored Feb 1, 2024
2 parents 694d7c8 + 7eed04b commit e552d57
Show file tree
Hide file tree
Showing 11 changed files with 135 additions and 489 deletions.
10 changes: 10 additions & 0 deletions brainles_aurora/inferer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .constants import (
DataMode,
InferenceMode,
ModelSelection,
Output,
MODALITIES,
IMGS_TO_MODE_DICT,
)
from .dataclasses import BaseConfig, AuroraInfererConfig
from .inferer import AuroraInferer, AuroraGPUInferer
5 changes: 1 addition & 4 deletions brainles_aurora/inferer/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@
from typing import Tuple


from brainles_aurora.inferer.constants import DataMode, ModelSelection
from brainles_aurora.inferer import DataMode, ModelSelection


@dataclass
class BaseConfig:
"""Base configuration for the Aurora model inferer.
Attributes:
output_mode (DataMode, optional): Output mode for the inference results. Defaults to DataMode.NIFTI_FILE.
log_level (int | str, optional): Logging level. Defaults to logging.INFO.
"""

output_mode: DataMode = DataMode.NIFTI_FILE
log_level: int | str = logging.INFO


Expand All @@ -24,7 +22,6 @@ class AuroraInfererConfig(BaseConfig):
"""Configuration for the Aurora model inferer.
Attributes:
output_mode (DataMode, optional): Output mode for the inference results. Defaults to DataMode.NIFTI_FILE.
log_level (int | str, optional): Logging level. Defaults to logging.INFO.
tta (bool, optional): Whether to apply test-time augmentations. Defaults to True.
sliding_window_batch_size (int, optional): Batch size for sliding window inference. Defaults to 1.
Expand Down
92 changes: 51 additions & 41 deletions brainles_aurora/inferer/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,20 @@
from torch.utils.data import DataLoader
import uuid

from brainles_aurora.inferer.constants import (
from brainles_aurora.inferer import (
IMGS_TO_MODE_DICT,
DataMode,
InferenceMode,
Output,
AuroraInfererConfig,
BaseConfig,
)
from brainles_aurora.utils import (
turbo_path,
DualStdErrOutput,
download_model_weights,
remove_path_suffixes,
)
from brainles_aurora.aux import turbo_path, DualStdErrOutput
from brainles_aurora.inferer.dataclasses import AuroraInfererConfig, BaseConfig
from brainles_aurora.download import download_model_weights


class AbstractInferer(ABC):
Expand Down Expand Up @@ -221,12 +226,19 @@ def _get_data_loader(self) -> torch.utils.data.DataLoader:
"""
# init transforms
transforms = [
LoadImageD(keys=["images"])
if self.input_mode == DataMode.NIFTI_FILE
else None,
EnsureChannelFirstd(keys="images")
if len(self._get_not_none_files()) == 1
else None,
(
LoadImageD(keys=["images"])
if self.input_mode == DataMode.NIFTI_FILE
else None
),
(
EnsureChannelFirstd(keys="images")
if (
len(self._get_not_none_files()) == 1
and self.input_mode == DataMode.NIFTI_FILE
)
else None
),
Lambdad(["images"], np.nan_to_num),
ScaleIntensityRangePercentilesd(
keys="images",
Expand Down Expand Up @@ -415,11 +427,11 @@ def _post_process(
Output.METASTASIS_NETWORK: enhancing_out,
}

def _sliding_window_inference(self) -> None | Dict[str, np.ndarray]:
def _sliding_window_inference(self) -> Dict[str, np.ndarray]:
"""Perform sliding window inference using monai.inferers.SlidingWindowInferer.
Returns:
None | Dict[str, np.ndarray]: Post-processed data if output_mode is NUMPY, otherwise the data is saved as a niftis and None is returned.
Dict[str, np.ndarray]: Post-processed data
"""
inferer = SlidingWindowInferer(
roi_size=self.config.crop_size, # = patch_size
Expand All @@ -434,7 +446,7 @@ def _sliding_window_inference(self) -> None | Dict[str, np.ndarray]:
with torch.no_grad():
self.model.eval()
self.model = self.model.to(self.device)
# loop through batches, only 1 batch!
# currently always only 1 batch! TODO: potentialy add support to pass multiple image tuples at once?
for data in self.data_loader:
inputs = data["images"].to(self.device)

Expand All @@ -445,14 +457,18 @@ def _sliding_window_inference(self) -> None | Dict[str, np.ndarray]:
outputs, data, inferer
)

self.log.info("Post-processing data")
postprocessed_data = self._post_process(
onehot_model_outputs_CHWD=outputs,
)
if self.config.output_mode == DataMode.NUMPY:
return postprocessed_data
else:

# save data to fie if paths are provided
if any(self.output_file_mapping.values()):
self.log.info("Saving post-processed data as NIFTI files")
self._save_as_nifti(postproc_data=postprocessed_data)
return None

self.log.info("Returning post-processed data as Dict of Numpy arrays")
return postprocessed_data

def _configure_device(self) -> torch.device:
"""Configure the device for inference.
Expand Down Expand Up @@ -494,22 +510,22 @@ def infer(
log_file (str | Path | None, optional): _description_. Defaults to None.
Returns:
Dict[str, np.ndarray] | None: Post-processed data if output_mode is NUMPY, otherwise the data is saved as a niftis and None is returned.
Dict[str, np.ndarray]: Post-processed data.
"""
# setup logger for inference run
if not log_file:
log_file = (
Path(segmentation_file).with_suffix(".log")
if segmentation_file
else os.path.abspath(f"./{self.__class__.__name__}.log")
if log_file:
self.log = self._setup_logger(log_file=log_file)
else:
# if no log file is provided: set logfile to segmentation filename if provided, else inferer class name
self.log = self._setup_logger(
log_file=(
remove_path_suffixes(segmentation_file).with_suffix(".log")
if segmentation_file
else os.path.abspath(f"./{self.__class__.__name__}.log")
),
)
self.log = self._setup_logger(
log_file=log_file,
)

self.log.info(f"Running inference on {self.device}")

# check inputs and get mode , == prev mode => run inference, else load new model
# check inputs and get mode , if mode == prev mode => run inference, else load new model
prev_mode = self.inference_mode
self.validated_images = self._validate_images(t1=t1, t1c=t1c, t2=t2, fla=fla)
self.inference_mode = self._determine_inference_mode(
Expand All @@ -528,20 +544,14 @@ def infer(
self.data_loader = self._get_data_loader()

# setup output file paths
if self.config.output_mode == DataMode.NIFTI_FILE:
# TODO add error handling to ensure file extensions present
if not segmentation_file:
default_segmentation_path = os.path.abspath("./segmentation.nii.gz")
self.log.warning(
f"No segmentation file name provided, using default path: {default_segmentation_path}"
)
self.output_file_mapping = {
Output.SEGMENTATION: segmentation_file or default_segmentation_path,
Output.WHOLE_NETWORK: whole_tumor_unbinarized_floats_file,
Output.METASTASIS_NETWORK: metastasis_unbinarized_floats_file,
}
self.output_file_mapping = {
Output.SEGMENTATION: segmentation_file,
Output.WHOLE_NETWORK: whole_tumor_unbinarized_floats_file,
Output.METASTASIS_NETWORK: metastasis_unbinarized_floats_file,
}

########
self.log.info(f"Running inference on device := {self.device}")
out = self._sliding_window_inference()
self.log.info(f"Finished inference {os.linesep}")
return out
Expand Down
2 changes: 2 additions & 0 deletions brainles_aurora/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .utils import turbo_path, remove_path_suffixes, DualStdErrOutput
from .download import download_model_weights
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# copied from https://github.com/Nordgaren/Github-Folder-Downloader/blob/master/gitdl.py
import os
import shutil as sh

import requests
from github import ContentFile, Github, Repository
Expand Down
34 changes: 29 additions & 5 deletions brainles_aurora/aux.py → brainles_aurora/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,50 @@
from typing import IO


def turbo_path(the_path):
turbo_path = Path(
def turbo_path(path: str | Path) -> Path:
"""Make path absolute and normed
Args:
path (str | Path): input path
Returns:
Path: absolute and normed path
"""
return Path(
os.path.normpath(
os.path.abspath(
the_path,
path,
)
)
)
return turbo_path


def remove_path_suffixes(path: Path | str) -> Path:
"""Remove all suffixes from a path
Args:
path (Path | str): path to remove suffixes from
Returns:
Path: path without suffixes
"""
path_stem = Path(path)
while path_stem.suffix:
path_stem = path_stem.with_suffix("")
return path_stem


class DualStdErrOutput:
"""Class to write to stderr and a file at the same time"""

def __init__(self, stderr: IO, file_handler_stream: IO = None):
self.stderr = stderr
self.file_handler_stream = file_handler_stream

def set_file_handler_stream(self, file_handler_stream: IO):
self.file_handler_stream = file_handler_stream

def write(self, text):
def write(self, text: str):
self.stderr.write(text)
if self.file_handler_stream:
self.file_handler_stream.write(text)
Expand Down
Loading

0 comments on commit e552d57

Please sign in to comment.