
Commit

Comment out unused code and finish inference
rudolphpienaar committed Apr 22, 2024
1 parent 6e4c369 commit 3425b52
Showing 1 changed file with 111 additions and 110 deletions.
221 changes: 111 additions & 110 deletions spleenseg/core/neuralnet.py
@@ -1,57 +1,21 @@
#!/usr/bin/env python

# from collections.abc import Iterable
# from pathlib import Path
from argparse import Namespace
from collections.abc import Callable
from pathlib import Path

# from dataclasses import dataclass, field

# import os, sys
# from monai.transforms import transform
from monai.transforms.compose import Compose

# from monai.utils import first, set_determinism

# from monai.transforms import (
# AsDiscrete,
# AsDiscreted,
# EnsureChannelFirstd,
# Compose,
# CropForegroundd,
# LoadImaged,
# Orientationd,
# RandCropByPosNegLabeld,
# RandAffined,
# SaveImaged,
# ScaleIntensityRanged,
# Spacingd,
# Invertd,
# )
# from monai.handlers.utils import from_engine
# from monai.networks.nets.unet import UNet

# from monai.transforms import LoadImage
# from monai.networks.layers.factories import Norm
# from monai.metrics.meandice import DiceMetric
# from monai.losses.dice import DiceLoss
from monai.inferers.utils import sliding_window_inference
from monai.data.dataset import CacheDataset
from monai.data.dataloader import DataLoader

# from monai.data.dataset import Dataset
from monai.data.utils import decollate_batch
from monai.data.meta_tensor import MetaTensor
from monai.handlers.utils import from_engine

# from monai.apps.utils import download_and_extract
import torch

# import matplotlib.pyplot as plt
# import tempfile
# import shutil
# import glob
# import pudb
from typing import Any, Sequence
import numpy as np

@@ -60,65 +24,6 @@
from spleenseg.plotting import plotting


# @dataclass
# class LoaderCache:
# loader: DataLoader
# cache: CacheDataset
#
#
# @dataclass
# class TrainingParams:
# max_epochs: int = 600
# val_interval = 2
# best_metric: float = -1.0
# best_metric_epoch = -1
# modelPth: Path = Path("")
# modelONNX: Path = Path("")
# determinismSeed: int = 0
#
# def __init__(self, options: Namespace):
# self.options = options
# if options is not None:
# self.max_epochs = self.options.maxEpochs
# self.modelPth = Path(options.outputdir) / "model.pth"
# self.modelONNX = Path(options.outputdir) / "model.onnx"
# self.determinismSeed = self.options.determinismSeed
# set_determinism(seed=self.determinismSeed)
#
#
# @dataclass
# class TrainingLog:
# loss_per_epoch: list[float] = field(default_factory=list)
# metric_per_epoch: list[float] = field(default_factory=list)
#
#
# @dataclass
# class ModelParams:
# optimizer: torch.optim.Adam
# device: torch.device = torch.device("cpu")
# model: UNet = UNet(
# spatial_dims=3,
# in_channels=1,
# out_channels=2,
# channels=(16, 32, 64, 128, 256),
# strides=(2, 2, 2, 2),
# num_res_units=2,
# norm=Norm.BATCH,
# )
# fn_loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = DiceLoss(
# to_onehot_y=True, softmax=True
# )
# dice_metric: DiceMetric = DiceMetric(include_background=False, reduction="mean")
#
# def __init__(self, options: Namespace):
# self.options = options
# if options is not None:
# self.device = torch.device(self.options.device)
# torch.manual_seed(42)
# self.model = self.model.to(self.device)
# self.optimizer = torch.optim.Adam(self.model.parameters(), 1e-4)


def tensor_desc(
T: torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor], **kwargs
) -> torch.Tensor:
@@ -133,16 +38,18 @@ def tensor_desc(
for k, v in kwargs.items():
if k.lower() == "desc":
strAs = v
T = torch.as_tensor(T)
Tt: torch.Tensor = torch.as_tensor(T)
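    # Reduce the tensor to a small summary selected by the "desc" keyword:
    # mean/std, L1/L2 norms, min/max, or a spatially averaged "simplified" view.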
match strAs:
case "meanstd":
tensor = torch.Tensor([T.mean().item(), T.std().item()])
tensor = torch.Tensor([Tt.mean().item(), Tt.std().item()])
case "l1l2":
tensor = torch.Tensor([T.abs().sum().item(), T.pow(2).sum().sqrt().item()])
tensor = torch.Tensor(
[Tt.abs().sum().item(), Tt.pow(2).sum().sqrt().item()]
)
case "minmax":
tensor = torch.Tensor([T.min().item(), T.max().item()])
tensor = torch.Tensor([Tt.min().item(), Tt.max().item()])
case "simplified":
tensor = T.mean(dim=(1, 2), keepdim=True)
tensor = Tt.mean(dim=(1, 2), keepdim=True)
return tensor


@@ -161,9 +68,13 @@ def __init__(self, options: Namespace):
self.f_outputPost: Compose
self.f_labelPost: Compose

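        # File manifests for each phase; every entry maps "image" and "label" to paths.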
self.trainingFileSet: list[dict[str, str]]
self.validationFileSet: list[dict[str, str]]
self.testingFileSet: list[dict[str, str]]

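        # LoaderCache handles pairing a CacheDataset with its DataLoader for each phase.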
self.trainingSpace: data.LoaderCache
self.validationSpace: data.LoaderCache
self.novelSpace: data.LoaderCache
self.testingSpace: data.LoaderCache

def loaderCache_create(
self, fileList: list[dict[str, str]], transforms: Compose, batch_size: int = 2
@@ -199,6 +110,36 @@ def loaderCache_create(
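        # Pair the cached dataset with its DataLoader in a single LoaderCache.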
loaderCache: data.LoaderCache = data.LoaderCache(cache=ds, loader=loader)
return loaderCache

def trainingTransformsAndSpace_setup(self) -> bool:
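        # Build the training and validation transform pipelines, check the validation
        # transforms against the validation file set, then wrap both file sets in
        # LoaderCaches (validation at batch size 1).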
setupOK: bool = True
trainingTransforms: Compose
validationTransforms: Compose
trainingTransforms, validationTransforms = (
transforms.trainingAndValidation_transformsSetup()
)
if not transforms.transforms_check(
Path(self.network.options.outputdir),
self.validationFileSet,
validationTransforms,
):
return False

self.trainingSpace = self.loaderCache_create(
self.trainingFileSet, trainingTransforms
)
self.validationSpace = self.loaderCache_create(
self.validationFileSet, validationTransforms, 1
)
return setupOK

def testingTransformsAndSpace_setup(self) -> bool:
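        # Build the inference-time transforms and wrap the testing file set in a
        # LoaderCache at batch size 1.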
testingTransforms: Compose
testingTransforms = transforms.inferenceUse_transforms()
self.testingSpace = self.loaderCache_create(
self.testingFileSet, testingTransforms, 1
)
return True

def tensor_assign(
self,
to: str,
@@ -249,6 +190,9 @@ def train_overSampleSpace_retLoss(self, trainingSpace: data.LoaderCache) -> float
trainingInstance["image"].to(self.network.device),
trainingInstance["label"].to(self.network.device),
)
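            # On the first sample of each epoch, log the training image and label shapes.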
if sample == 1:
print(f"training image shape: {self.input.shape}")
print(f"training label shape: {self.target.shape}")
sample_loss = self.evalAndCorrect()
total_loss += sample_loss
if (
@@ -272,15 +216,17 @@ def train(
[transforms.f_AsDiscreteArgMax()]
)
self.f_labelPost = transforms.transforms_build([transforms.f_AsDiscrete()])
self.trainingEpoch: int = 0
self.trainingEpoch = 0
epoch_loss: float = 0.0
if trainingSpace:
self.trainingSpace = trainingSpace
if validationSpace:
self.validationSpace = validationSpace
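        # Epoch loop: train over the sample space and report the average loss per epoch.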
for self.trainingEpoch in range(self.trainingParams.max_epochs):
print("-" * 10)
print(f"epoch {self.trainingEpoch+1:03} / {self.trainingParams.max_epochs}")
print(
f"epoch {self.trainingEpoch+1:03} / {self.trainingParams.max_epochs:03}"
)
self.network.model.train()
epoch_loss = self.train_overSampleSpace_retLoss(self.trainingSpace)
print(f"epoch {self.trainingEpoch+1:03}, average loss: {epoch_loss:.4f}")
@@ -291,8 +237,8 @@
print("-" * 10)
print(
"Training complete: "
"best metric: {self.trainingLog.best_metric:.4f} "
"at epoch: {self.trainingLog.best_metric_epoch}"
f"best metric: {self.trainingLog.best_metric:.4f} "
f"at epoch: {self.trainingLog.best_metric_epoch}"
)

def inference_metricsProcess(self) -> float:
@@ -388,6 +334,9 @@ def slidingWindowInference_do(
input, roi_size, sw_batch_size, self.network.model
)
)
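            # On the first sample, log the inference input and raw output shapes.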
if index == 1:
print(f"inference input shape: {input.shape}")
print(f"inference output shape: {outputRaw.shape}")
if f_callback is not None:
metric = f_callback(sample, inferSpace, index, outputRaw)
return metric
@@ -399,12 +348,12 @@ def plot_bestModel(
index: int,
result: torch.Tensor,
) -> float:
print(f"Saving best model applied to validation sample {index}")
print(f"Plotting output of best model applied to validation sample {index}")
plotting.plot_bestModelOnValidate(
sample,
result,
str(index),
self.trainingParams.outputDir / f"bestModel-val-{index}.png",
self.trainingParams.outputDir / "validation" / f"bestModel-val-{index}.png",
)
return 0.0

@@ -441,17 +390,69 @@ def diceMetric_onValidationSpacing(
print(f"metric on original image spacing: {metric}")
return metric

def bestModel_evaluateImageSpacings(self, validationTransforms: Compose):
def bestModel_evaluateImageSpacings(self):
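        # Reload the best model weights, rebuild the validation LoaderCache with
        # transforms on the original image spacing, invert those transforms in the
        # output post-processing, and evaluate with the Dice metric callback.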
self.network.model.load_state_dict(
torch.load(str(self.trainingParams.modelPth))
)
validationOnOrigTransforms: Compose = (
transforms.validation_transformsOnOriginal()
)
self.validationSpace = self.loaderCache_create(
self.validationFileSet, validationOnOrigTransforms, 1
)
self.f_outputPost = transforms.transforms_build(
[
transforms.f_Invertd(validationTransforms),
transforms.f_Invertd(validationOnOrigTransforms),
transforms.f_predAsDiscreted(),
transforms.f_labelAsDiscreted(),
]
)
self.slidingWindowInference_do(
self.validationSpace, self.diceMetric_onValidationSpacing
)

def inference_post(
self,
sample: dict[str, MetaTensor | torch.Tensor],
space: data.LoaderCache,
index: int,
result: torch.Tensor,
) -> float:
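        # Attach the raw prediction to the sample, run the output post-transforms over
        # the decollated batch, reload the original input image from its source file,
        # and plot input versus prediction into the "inference" output directory.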
sample["pred"] = result
sample = [self.f_outputPost(i) for i in decollate_batch(sample)]
prediction = from_engine(["pred"])(sample)
fi = transforms.f_LoadImage()
input = fi(prediction[0].meta["filename_or_obj"])
Ti = torch.as_tensor(input)
plotting.plot_infer(
Ti,
prediction,
f"{index}",
Path(
Path(self.network.options.outputdir)
/ "inference"
/ f"infer-{index}.png"
),
)

return 0.0

def infer_usingModel(self, model: Path):
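        # Load the trained weights, build the testing transforms and LoaderCache,
        # configure post-transforms that invert the preprocessing, discretize the
        # prediction, and save each result to the "inference" output directory, then
        # run sliding-window inference with inference_post as the per-sample callback.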
# Check if model exists...
self.network.model.load_state_dict(torch.load(str(model)))
testingTransforms: Compose
testingTransforms = transforms.inferenceUse_transforms()
self.testingSpace = self.loaderCache_create(
self.testingFileSet, testingTransforms, 1
)
self.f_outputPost = transforms.transforms_build(
[
transforms.f_Invertd(testingTransforms),
transforms.f_predAsDiscreted(),
transforms.f_SaveImaged(
Path(self.network.options.outputdir) / "inference"
),
]
)

self.slidingWindowInference_do(self.testingSpace, self.inference_post)
