From ec3134c8d87862e1044867ae279e8cf82fcddc1c Mon Sep 17 00:00:00 2001
From: shota_mizusaki <nrxg129@gmail.com>
Date: Wed, 22 May 2024 00:48:19 +0900
Subject: [PATCH 1/4] first commit

---
 build/lib/pytorch3dunet/__init__.py           |   1 +
 build/lib/pytorch3dunet/__version__.py        |   1 +
 build/lib/pytorch3dunet/augment/__init__.py   |   0
 build/lib/pytorch3dunet/augment/transforms.py | 761 ++++++++++++++++++
 build/lib/pytorch3dunet/datasets/__init__.py  |   0
 build/lib/pytorch3dunet/datasets/dsb.py       | 108 +++
 build/lib/pytorch3dunet/datasets/hdf5.py      | 293 +++++++
 build/lib/pytorch3dunet/datasets/utils.py     | 361 +++++++++
 build/lib/pytorch3dunet/predict.py            |  59 ++
 build/lib/pytorch3dunet/train.py              |  35 +
 build/lib/pytorch3dunet/unet3d/__init__.py    |   0
 .../pytorch3dunet/unet3d/buildingblocks.py    | 545 +++++++++++++
 build/lib/pytorch3dunet/unet3d/config.py      |  79 ++
 build/lib/pytorch3dunet/unet3d/losses.py      | 345 ++++++++
 build/lib/pytorch3dunet/unet3d/metrics.py     | 445 ++++++++++
 build/lib/pytorch3dunet/unet3d/model.py       | 249 ++++++
 build/lib/pytorch3dunet/unet3d/predictor.py   | 281 +++++++
 build/lib/pytorch3dunet/unet3d/se.py          | 113 +++
 build/lib/pytorch3dunet/unet3d/seg_metrics.py | 123 +++
 build/lib/pytorch3dunet/unet3d/trainer.py     | 404 ++++++++++
 build/lib/pytorch3dunet/unet3d/utils.py       | 366 +++++++++
 21 files changed, 4569 insertions(+)
 create mode 100644 build/lib/pytorch3dunet/__init__.py
 create mode 100644 build/lib/pytorch3dunet/__version__.py
 create mode 100644 build/lib/pytorch3dunet/augment/__init__.py
 create mode 100644 build/lib/pytorch3dunet/augment/transforms.py
 create mode 100644 build/lib/pytorch3dunet/datasets/__init__.py
 create mode 100644 build/lib/pytorch3dunet/datasets/dsb.py
 create mode 100644 build/lib/pytorch3dunet/datasets/hdf5.py
 create mode 100644 build/lib/pytorch3dunet/datasets/utils.py
 create mode 100644 build/lib/pytorch3dunet/predict.py
 create mode 100644 build/lib/pytorch3dunet/train.py
 create mode 100644 build/lib/pytorch3dunet/unet3d/__init__.py
 create mode 100644 build/lib/pytorch3dunet/unet3d/buildingblocks.py
 create mode 100644 build/lib/pytorch3dunet/unet3d/config.py
 create mode 100644 build/lib/pytorch3dunet/unet3d/losses.py
 create mode 100644 build/lib/pytorch3dunet/unet3d/metrics.py
 create mode 100644 build/lib/pytorch3dunet/unet3d/model.py
 create mode 100644 build/lib/pytorch3dunet/unet3d/predictor.py
 create mode 100644 build/lib/pytorch3dunet/unet3d/se.py
 create mode 100644 build/lib/pytorch3dunet/unet3d/seg_metrics.py
 create mode 100644 build/lib/pytorch3dunet/unet3d/trainer.py
 create mode 100644 build/lib/pytorch3dunet/unet3d/utils.py

diff --git a/build/lib/pytorch3dunet/__init__.py b/build/lib/pytorch3dunet/__init__.py
new file mode 100644
index 00000000..9226fe7e
--- /dev/null
+++ b/build/lib/pytorch3dunet/__init__.py
@@ -0,0 +1 @@
+from .__version__ import __version__
diff --git a/build/lib/pytorch3dunet/__version__.py b/build/lib/pytorch3dunet/__version__.py
new file mode 100644
index 00000000..655be529
--- /dev/null
+++ b/build/lib/pytorch3dunet/__version__.py
@@ -0,0 +1 @@
+__version__ = '1.8.7'
diff --git a/build/lib/pytorch3dunet/augment/__init__.py b/build/lib/pytorch3dunet/augment/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/build/lib/pytorch3dunet/augment/transforms.py b/build/lib/pytorch3dunet/augment/transforms.py
new file mode 100644
index 00000000..527d596b
--- /dev/null
+++ b/build/lib/pytorch3dunet/augment/transforms.py
@@ -0,0 +1,761 @@
+import importlib
+import random
+
+import numpy as np
+import torch
+from scipy.ndimage import rotate, map_coordinates, gaussian_filter, convolve
+from skimage import measure
+from skimage.filters import gaussian
+from skimage.segmentation import find_boundaries
+
+# WARN: a fixed random state is used for reproducibility; if you want different augmentations on each run, seed with e.g. `time.time()`.
+GLOBAL_RANDOM_STATE = np.random.RandomState(47)
+
+
+class Compose(object):
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, m):
+        for t in self.transforms:
+            m = t(m)
+        return m
+
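+# A minimal usage sketch (hypothetical values): share one RandomState between the
+# raw and label pipelines so both apply identical random choices, e.g.
+#   rs = np.random.RandomState(47)
+#   augment = Compose([RandomFlip(rs), RandomRotate90(rs)])
+#   patch = augment(np.zeros((32, 64, 64), dtype='float32'))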
+
+class RandomFlip:
+    """
+    Randomly flips the image across the given axes. Image can be either 3D (DxHxW) or 4D (CxDxHxW).
+
+    When creating this transform, make sure that the RandomStates provided for the raw and labeled datasets
+    are consistent, otherwise the model won't converge.
+    """
+
+    def __init__(self, random_state, axis_prob=0.5, **kwargs):
+        assert random_state is not None, 'RandomState cannot be None'
+        self.random_state = random_state
+        self.axes = (0, 1, 2)
+        self.axis_prob = axis_prob
+
+    def __call__(self, m):
+        assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images'
+
+        for axis in self.axes:
+            if self.random_state.uniform() > self.axis_prob:
+                if m.ndim == 3:
+                    m = np.flip(m, axis)
+                else:
+                    channels = [np.flip(m[c], axis) for c in range(m.shape[0])]
+                    m = np.stack(channels, axis=0)
+
+        return m
+
+
+class RandomRotate90:
+    """
+    Rotate an array by 90 degrees around a randomly chosen plane. Image can be either 3D (DxHxW) or 4D (CxDxHxW).
+
+    When creating this transform, make sure that the RandomStates provided for the raw and labeled datasets
+    are consistent, otherwise the model won't converge.
+
+    IMPORTANT: assumes DHW axis order (that's why the rotation is performed across the (1, 2) axes)
+    """
+
+    def __init__(self, random_state, **kwargs):
+        self.random_state = random_state
+        # always rotate around z-axis
+        self.axis = (1, 2)
+
+    def __call__(self, m):
+        assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images'
+
+        # pick number of rotations at random
+        k = self.random_state.randint(0, 4)
+        # rotate k times around a given plane
+        if m.ndim == 3:
+            m = np.rot90(m, k, self.axis)
+        else:
+            channels = [np.rot90(m[c], k, self.axis) for c in range(m.shape[0])]
+            m = np.stack(channels, axis=0)
+
+        return m
+
+
+class RandomRotate:
+    """
+    Rotate an array by a random angle from the (-angle_spectrum, angle_spectrum) interval.
+    Rotation axis is picked at random from the list of provided axes.
+    """
+
+    def __init__(self, random_state, angle_spectrum=30, axes=None, mode='reflect', order=0, **kwargs):
+        if axes is None:
+            axes = [(1, 0), (2, 1), (2, 0)]
+        else:
+            assert isinstance(axes, list) and len(axes) > 0
+
+        self.random_state = random_state
+        self.angle_spectrum = angle_spectrum
+        self.axes = axes
+        self.mode = mode
+        self.order = order
+
+    def __call__(self, m):
+        axis = self.axes[self.random_state.randint(len(self.axes))]
+        angle = self.random_state.randint(-self.angle_spectrum, self.angle_spectrum)
+
+        if m.ndim == 3:
+            m = rotate(m, angle, axes=axis, reshape=False, order=self.order, mode=self.mode, cval=-1)
+        else:
+            channels = [rotate(m[c], angle, axes=axis, reshape=False, order=self.order, mode=self.mode, cval=-1) for c
+                        in range(m.shape[0])]
+            m = np.stack(channels, axis=0)
+
+        return m
+
+
+class RandomContrast:
+    """
+    Adjust contrast by scaling each voxel to `mean + alpha * (v - mean)`.
+    """
+
+    def __init__(self, random_state, alpha=(0.5, 1.5), mean=0.0, execution_probability=0.1, **kwargs):
+        self.random_state = random_state
+        assert len(alpha) == 2
+        self.alpha = alpha
+        self.mean = mean
+        self.execution_probability = execution_probability
+
+    def __call__(self, m):
+        if self.random_state.uniform() < self.execution_probability:
+            alpha = self.random_state.uniform(self.alpha[0], self.alpha[1])
+            result = self.mean + alpha * (m - self.mean)
+            return np.clip(result, -1, 1)
+
+        return m
+
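+# Worked example: with mean=0.0 and a sampled alpha of 1.5, a voxel value of 0.4
+# becomes 0.0 + 1.5 * (0.4 - 0.0) = 0.6, well within the [-1, 1] clip range.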
+
+# it's relatively slow, i.e. ~1s per patch of size 64x200x200, so use multiple workers in the DataLoader
+# remember to use spline_order=0 when transforming the labels
+class ElasticDeformation:
+    """
+    Apply elastic deformations of 3D patches on a per-voxel mesh. Assumes ZYX axis order (or CZYX if the data is 4D).
+    Based on: https://github.com/fcalvet/image_tools/blob/master/image_augmentation.py#L62
+    """
+
+    def __init__(self, random_state, spline_order, alpha=2000, sigma=50, execution_probability=0.1, apply_3d=True,
+                 **kwargs):
+        """
+        :param spline_order: the order of spline interpolation (use 0 for labeled images)
+        :param alpha: scaling factor for deformations
+        :param sigma: smoothing factor for Gaussian filter
+        :param execution_probability: probability of executing this transform
+        :param apply_3d: if True apply deformations in each axis
+        """
+        self.random_state = random_state
+        self.spline_order = spline_order
+        self.alpha = alpha
+        self.sigma = sigma
+        self.execution_probability = execution_probability
+        self.apply_3d = apply_3d
+
+    def __call__(self, m):
+        if self.random_state.uniform() < self.execution_probability:
+            assert m.ndim in [3, 4]
+
+            if m.ndim == 3:
+                volume_shape = m.shape
+            else:
+                volume_shape = m[0].shape
+
+            if self.apply_3d:
+                dz = gaussian_filter(self.random_state.randn(*volume_shape), self.sigma, mode="reflect") * self.alpha
+            else:
+                dz = np.zeros_like(m)
+
+            dy, dx = [
+                gaussian_filter(
+                    self.random_state.randn(*volume_shape),
+                    self.sigma, mode="reflect"
+                ) * self.alpha for _ in range(2)
+            ]
+
+            z_dim, y_dim, x_dim = volume_shape
+            z, y, x = np.meshgrid(np.arange(z_dim), np.arange(y_dim), np.arange(x_dim), indexing='ij')
+            indices = z + dz, y + dy, x + dx
+
+            if m.ndim == 3:
+                return map_coordinates(m, indices, order=self.spline_order, mode='reflect')
+            else:
+                channels = [map_coordinates(c, indices, order=self.spline_order, mode='reflect') for c in m]
+                return np.stack(channels, axis=0)
+
+        return m
+
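+# Usage sketch (hypothetical seeds): build two identically seeded instances so raw
+# and label patches deform the same way, with spline_order=0 for the labels:
+#   rs_raw, rs_label = np.random.RandomState(7), np.random.RandomState(7)
+#   deform_raw = ElasticDeformation(rs_raw, spline_order=3)
+#   deform_label = ElasticDeformation(rs_label, spline_order=0)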
+
+class CropToFixed:
+    def __init__(self, random_state, size=(256, 256), centered=False, **kwargs):
+        self.random_state = random_state
+        self.crop_y, self.crop_x = size
+        self.centered = centered
+
+    def __call__(self, m):
+        def _padding(pad_total):
+            half_total = pad_total // 2
+            return (half_total, pad_total - half_total)
+
+        def _rand_range_and_pad(crop_size, max_size):
+            """
+            Returns a tuple:
+                max_value (int) for the corner dimension. The corner dimension is chosen as `self.random_state(max_value)`
+                pad (int): padding in both directions; if crop_size is lt max_size the pad is 0
+            """
+            if crop_size < max_size:
+                return max_size - crop_size, (0, 0)
+            else:
+                return 1, _padding(crop_size - max_size)
+
+        def _start_and_pad(crop_size, max_size):
+            if crop_size < max_size:
+                return (max_size - crop_size) // 2, (0, 0)
+            else:
+                return 0, _padding(crop_size - max_size)
+
+        assert m.ndim in (3, 4)
+        if m.ndim == 3:
+            _, y, x = m.shape
+        else:
+            _, _, y, x = m.shape
+
+        if not self.centered:
+            y_range, y_pad = _rand_range_and_pad(self.crop_y, y)
+            x_range, x_pad = _rand_range_and_pad(self.crop_x, x)
+
+            y_start = self.random_state.randint(y_range)
+            x_start = self.random_state.randint(x_range)
+
+        else:
+            y_start, y_pad = _start_and_pad(self.crop_y, y)
+            x_start, x_pad = _start_and_pad(self.crop_x, x)
+
+        if m.ndim == 3:
+            result = m[:, y_start:y_start + self.crop_y, x_start:x_start + self.crop_x]
+            return np.pad(result, pad_width=((0, 0), y_pad, x_pad), mode='reflect')
+        else:
+            channels = []
+            for c in range(m.shape[0]):
+                result = m[c][:, y_start:y_start + self.crop_y, x_start:x_start + self.crop_x]
+                channels.append(np.pad(result, pad_width=((0, 0), y_pad, x_pad), mode='reflect'))
+            return np.stack(channels, axis=0)
+
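+# Example: CropToFixed(rs, size=(256, 256)) picks a random YX corner when the
+# input is larger than 256x256 and reflect-pads symmetrically when it is smaller.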
+
+class AbstractLabelToBoundary:
+    AXES_TRANSPOSE = [
+        (0, 1, 2),  # X
+        (0, 2, 1),  # Y
+        (2, 0, 1)  # Z
+    ]
+
+    def __init__(self, ignore_index=None, aggregate_affinities=False, append_label=False, **kwargs):
+        """
+        :param ignore_index: label to be ignored in the output, i.e. after computing the boundary the label ignore_index
+            will be restored where it was in the patch originally
+        :param aggregate_affinities: aggregate affinities with the same offset across the Z, Y, X axes
+        :param append_label: if True, append the original ground truth labels to the last channel
+        """
+        self.ignore_index = ignore_index
+        self.aggregate_affinities = aggregate_affinities
+        self.append_label = append_label
+
+    def __call__(self, m):
+        """
+        Extract boundaries from a given 3D label tensor.
+        :param m: input 3D tensor
+        :return: binary mask, with 1-label corresponding to the boundary and 0-label corresponding to the background
+        """
+        assert m.ndim == 3
+
+        kernels = self.get_kernels()
+        boundary_arr = [np.where(np.abs(convolve(m, kernel)) > 0, 1, 0) for kernel in kernels]
+        channels = np.stack(boundary_arr)
+        results = []
+        if self.aggregate_affinities:
+            assert len(kernels) % 3 == 0, "Number of kernels must be divisible by 3 (one kernel per offset per Z, Y, X axis)"
+            # aggregate affinities with the same offset
+            for i in range(0, len(kernels), 3):
+                # merge across X,Y,Z axes (logical OR)
+                xyz_aggregated_affinities = np.logical_or.reduce(channels[i:i + 3, ...]).astype(np.int32)
+                # recover ignore index
+                xyz_aggregated_affinities = _recover_ignore_index(xyz_aggregated_affinities, m, self.ignore_index)
+                results.append(xyz_aggregated_affinities)
+        else:
+            results = [_recover_ignore_index(channels[i], m, self.ignore_index) for i in range(channels.shape[0])]
+
+        if self.append_label:
+            # append original input data
+            results.append(m)
+
+        # stack across channel dim
+        return np.stack(results, axis=0)
+
+    @staticmethod
+    def create_kernel(axis, offset):
+        # create conv kernel
+        k_size = offset + 1
+        k = np.zeros((1, 1, k_size), dtype=np.int32)
+        k[0, 0, 0] = 1
+        k[0, 0, offset] = -1
+        return np.transpose(k, axis)
+
+    def get_kernels(self):
+        raise NotImplementedError
+
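+# Worked example: create_kernel((0, 1, 2), 2) returns a (1, 1, 3) kernel with
+# k[0, 0, 0] = 1 and k[0, 0, 2] = -1; convolving it with a label volume is
+# non-zero exactly where two voxels at that offset carry different labels,
+# which is what __call__ thresholds into a binary boundary mask.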
+
+class StandardLabelToBoundary:
+    def __init__(self, ignore_index=None, append_label=False, mode='thick', foreground=False,
+                 **kwargs):
+        self.ignore_index = ignore_index
+        self.append_label = append_label
+        self.mode = mode
+        self.foreground = foreground
+
+    def __call__(self, m):
+        assert m.ndim == 3
+
+        boundaries = find_boundaries(m, connectivity=2, mode=self.mode)
+        boundaries = boundaries.astype('int32')
+
+        results = []
+        if self.foreground:
+            foreground = (m > 0).astype('uint8')
+            results.append(_recover_ignore_index(foreground, m, self.ignore_index))
+
+        results.append(_recover_ignore_index(boundaries, m, self.ignore_index))
+
+        if self.append_label:
+            # append original input data
+            results.append(m)
+
+        return np.stack(results, axis=0)
+
+
+class BlobsToMask:
+    """
+    Returns binary mask from labeled image, i.e. every label greater than 0 is treated as foreground.
+
+    """
+
+    def __init__(self, append_label=False, boundary=False, cross_entropy=False, **kwargs):
+        self.cross_entropy = cross_entropy
+        self.boundary = boundary
+        self.append_label = append_label
+
+    def __call__(self, m):
+        assert m.ndim == 3
+
+        # get the segmentation mask
+        mask = (m > 0).astype('uint8')
+        results = [mask]
+
+        if self.boundary:
+            outer = find_boundaries(m, connectivity=2, mode='outer')
+            if self.cross_entropy:
+                # boundary is class 2
+                mask[outer > 0] = 2
+                results = [mask]
+            else:
+                results.append(outer)
+
+        if self.append_label:
+            results.append(m)
+
+        return np.stack(results, axis=0)
+
+
+class RandomLabelToAffinities(AbstractLabelToBoundary):
+    """
+    Converts a given volumetric label array to binary mask corresponding to borders between labels.
+    One specify the max_offset (thickness) of the border. Then the offset is picked at random every time you call
+    the transformer (offset is picked form the range 1:max_offset) for each axis and the boundary computed.
+    One may use this scheme in order to make the network more robust against various thickness of borders in the ground
+    truth  (think of it as a boundary denoising scheme).
+    """
+
+    def __init__(self, random_state, max_offset=10, ignore_index=None, append_label=False, z_offset_scale=2, **kwargs):
+        super().__init__(ignore_index=ignore_index, append_label=append_label, aggregate_affinities=False)
+        self.random_state = random_state
+        self.offsets = tuple(range(1, max_offset + 1))
+        self.z_offset_scale = z_offset_scale
+
+    def get_kernels(self):
+        rand_offset = self.random_state.choice(self.offsets)
+        axis_ind = self.random_state.randint(3)
+        # scale down z-affinities due to anisotropy
+        if axis_ind == 2:
+            rand_offset = max(1, rand_offset // self.z_offset_scale)
+
+        rand_axis = self.AXES_TRANSPOSE[axis_ind]
+        # return a single kernel
+        return [self.create_kernel(rand_axis, rand_offset)]
+
+
+class LabelToAffinities(AbstractLabelToBoundary):
+    """
+    Converts a given volumetric label array to a binary mask corresponding to borders between labels (which can be
+    seen as an affinity graph: https://arxiv.org/pdf/1706.00120.pdf).
+    One specifies the offsets (thickness) of the border. The boundary is computed via the convolution operator.
+    """
+
+    def __init__(self, offsets, ignore_index=None, append_label=False, aggregate_affinities=False, z_offsets=None,
+                 **kwargs):
+        super().__init__(ignore_index=ignore_index, append_label=append_label,
+                         aggregate_affinities=aggregate_affinities)
+
+        assert isinstance(offsets, list) or isinstance(offsets, tuple), 'offsets must be a list or a tuple'
+        assert all(a > 0 for a in offsets), "'offsets' must be positive"
+        assert len(set(offsets)) == len(offsets), "'offsets' must be unique"
+        if z_offsets is not None:
+            assert len(offsets) == len(z_offsets), 'z_offsets length must be the same as the length of offsets'
+        else:
+            # if z_offsets is None just use the offsets for z-affinities
+            z_offsets = list(offsets)
+        self.z_offsets = z_offsets
+
+        self.kernels = []
+        # create kernel for every axis-offset pair
+        for xy_offset, z_offset in zip(offsets, z_offsets):
+            for axis_ind, axis in enumerate(self.AXES_TRANSPOSE):
+                final_offset = xy_offset
+                if axis_ind == 2:
+                    final_offset = z_offset
+                # create kernels for a given offset in every direction
+                self.kernels.append(self.create_kernel(axis, final_offset))
+
+    def get_kernels(self):
+        return self.kernels
+
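+# Example: LabelToAffinities(offsets=[1, 4], z_offsets=[1, 2]) builds 2 * 3 = 6
+# kernels (one per X/Y/Z axis for each offset pair), so __call__ produces a
+# 6-channel output (plus one extra channel when append_label=True).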
+
+class LabelToZAffinities(AbstractLabelToBoundary):
+    """
+    Converts a given volumetric label array to a binary mask corresponding to borders between labels (which can be
+    seen as an affinity graph: https://arxiv.org/pdf/1706.00120.pdf).
+    One specifies the offsets (thickness) of the border. The boundary is computed via the convolution operator.
+    """
+
+    def __init__(self, offsets, ignore_index=None, append_label=False, **kwargs):
+        super().__init__(ignore_index=ignore_index, append_label=append_label)
+
+        assert isinstance(offsets, list) or isinstance(offsets, tuple), 'offsets must be a list or a tuple'
+        assert all(a > 0 for a in offsets), "'offsets' must be positive"
+        assert len(set(offsets)) == len(offsets), "'offsets' must be unique"
+
+        self.kernels = []
+        z_axis = self.AXES_TRANSPOSE[2]
+        # create kernels
+        for z_offset in offsets:
+            self.kernels.append(self.create_kernel(z_axis, z_offset))
+
+    def get_kernels(self):
+        return self.kernels
+
+
+class LabelToBoundaryAndAffinities:
+    """
+    Combines the StandardLabelToBoundary and LabelToAffinities in the hope
+    that training the network to predict both will improve the main task: boundary prediction.
+    """
+
+    def __init__(self, xy_offsets, z_offsets, append_label=False, blur=False, sigma=1, ignore_index=None, mode='thick',
+                 foreground=False, **kwargs):
+        # blur only StandardLabelToBoundary results; we don't want to blur the affinities
+        self.l2b = StandardLabelToBoundary(blur=blur, sigma=sigma, ignore_index=ignore_index, mode=mode,
+                                           foreground=foreground)
+        self.l2a = LabelToAffinities(offsets=xy_offsets, z_offsets=z_offsets, append_label=append_label,
+                                     ignore_index=ignore_index)
+
+    def __call__(self, m):
+        boundary = self.l2b(m)
+        affinities = self.l2a(m)
+        return np.concatenate((boundary, affinities), axis=0)
+
+
+class LabelToMaskAndAffinities:
+    def __init__(self, xy_offsets, z_offsets, append_label=False, background=0, ignore_index=None, **kwargs):
+        self.background = background
+        self.l2a = LabelToAffinities(offsets=xy_offsets, z_offsets=z_offsets, append_label=append_label,
+                                     ignore_index=ignore_index)
+
+    def __call__(self, m):
+        mask = m > self.background
+        mask = np.expand_dims(mask.astype(np.uint8), axis=0)
+        affinities = self.l2a(m)
+        return np.concatenate((mask, affinities), axis=0)
+
+
+class Standardize:
+    """
+    Apply Z-score normalization to a given input tensor, i.e. re-scaling the values to be 0-mean and 1-std.
+    """
+
+    def __init__(self, eps=1e-10, mean=None, std=None, channelwise=False, **kwargs):
+        if mean is not None or std is not None:
+            assert mean is not None and std is not None
+        self.mean = mean
+        self.std = std
+        self.eps = eps
+        self.channelwise = channelwise
+
+    def __call__(self, m):
+        if self.mean is not None:
+            mean, std = self.mean, self.std
+        else:
+            if self.channelwise:
+                # normalize per-channel
+                axes = list(range(m.ndim))
+                # average across channels
+                axes = tuple(axes[1:])
+                mean = np.mean(m, axis=axes, keepdims=True)
+                std = np.std(m, axis=axes, keepdims=True)
+            else:
+                mean = np.mean(m)
+                std = np.std(m)
+
+        return (m - mean) / np.clip(std, a_min=self.eps, a_max=None)
+
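+# Example: for a CxDxHxW input with channelwise=True, mean/std are computed with
+# keepdims over axes (1, 2, 3), i.e. they have shape (C, 1, 1, 1) and broadcast.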
+
+class PercentileNormalizer:
+    def __init__(self, pmin=1, pmax=99.6, channelwise=False, eps=1e-10, **kwargs):
+        self.eps = eps
+        self.pmin = pmin
+        self.pmax = pmax
+        self.channelwise = channelwise
+
+    def __call__(self, m):
+        if self.channelwise:
+            axes = list(range(m.ndim))
+            # average across channels
+            axes = tuple(axes[1:])
+            pmin = np.percentile(m, self.pmin, axis=axes, keepdims=True)
+            pmax = np.percentile(m, self.pmax, axis=axes, keepdims=True)
+        else:
+            pmin = np.percentile(m, self.pmin)
+            pmax = np.percentile(m, self.pmax)
+
+        return (m - pmin) / (pmax - pmin + self.eps)
+
+
+class Normalize:
+    """
+    Apply simple min-max scaling to a given input tensor, i.e. shrink the range of the data
+    to the fixed range [-1, 1] (or to [0, 1] when norm01=True). In addition, the data can be
+    clipped by specifying min_value/max_value, either globally via single values or per channel
+    via a list/tuple when channelwise is enabled.
+    """
+
+    def __init__(self, min_value=None, max_value=None, norm01=False, channelwise=False,
+                 eps=1e-10, **kwargs):
+        if min_value is not None and max_value is not None:
+            assert max_value > min_value
+        self.min_value = min_value
+        self.max_value = max_value
+        self.norm01 = norm01
+        self.channelwise = channelwise
+        self.eps = eps
+
+    def __call__(self, m):
+        if self.channelwise:
+            # get min/max channelwise
+            axes = list(range(m.ndim))
+            axes = tuple(axes[1:])
+            # start from the per-channel data min/max...
+            min_value = np.min(m, axis=axes, keepdims=True)
+            max_value = np.max(m, axis=axes, keepdims=True)
+
+            # ...then override the channels for which an explicit (non-'None')
+            # value was given in self.min_value/self.max_value
+            if self.min_value is not None:
+                for i, v in enumerate(self.min_value):
+                    if v != 'None':
+                        min_value[i] = v
+
+            if self.max_value is not None:
+                for i, v in enumerate(self.max_value):
+                    if v != 'None':
+                        max_value[i] = v
+        else:
+            if self.min_value is None:
+                min_value = np.min(m)
+            else:
+                min_value = self.min_value
+
+            if self.max_value is None:
+                max_value = np.max(m)
+            else:
+                max_value = self.max_value
+
+        # min_value/max_value have the same dimensionality as m in the channelwise
+        # case, so they broadcast correctly here
+        norm_0_1 = (m - min_value) / (max_value - min_value + self.eps)
+
+        if self.norm01:
+            return np.clip(norm_0_1, 0, 1)
+        else:
+            return np.clip(2 * norm_0_1 - 1, -1, 1)
+
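+# Usage sketch for the channelwise case (hypothetical values): min_value/max_value
+# are per-channel lists where the string 'None' requests the data min/max, e.g.
+#   Normalize(min_value=['None', 0], max_value=['None', 255], channelwise=True)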
+
+class AdditiveGaussianNoise:
+    def __init__(self, random_state, scale=(0.0, 1.0), execution_probability=0.1, **kwargs):
+        self.execution_probability = execution_probability
+        self.random_state = random_state
+        self.scale = scale
+
+    def __call__(self, m):
+        if self.random_state.uniform() < self.execution_probability:
+            std = self.random_state.uniform(self.scale[0], self.scale[1])
+            gaussian_noise = self.random_state.normal(0, std, size=m.shape)
+            return m + gaussian_noise
+        return m
+
+
+class AdditivePoissonNoise:
+    def __init__(self, random_state, lam=(0.0, 1.0), execution_probability=0.1, **kwargs):
+        self.execution_probability = execution_probability
+        self.random_state = random_state
+        self.lam = lam
+
+    def __call__(self, m):
+        if self.random_state.uniform() < self.execution_probability:
+            lam = self.random_state.uniform(self.lam[0], self.lam[1])
+            poisson_noise = self.random_state.poisson(lam, size=m.shape)
+            return m + poisson_noise
+        return m
+
+
+class ToTensor:
+    """
+    Converts a given input numpy.ndarray into torch.Tensor.
+
+    Args:
+        expand_dims (bool): if True, adds a channel dimension to the input data
+        dtype (np.dtype): the desired output data type
+    """
+
+    def __init__(self, expand_dims, dtype=np.float32, **kwargs):
+        self.expand_dims = expand_dims
+        self.dtype = dtype
+
+    def __call__(self, m):
+        assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images'
+        # add channel dimension
+        if self.expand_dims and m.ndim == 3:
+            m = np.expand_dims(m, axis=0)
+
+        return torch.from_numpy(m.astype(dtype=self.dtype))
+
+
+class Relabel:
+    """
+    Relabel a numpy array of labels into consecutive numbers, e.g.
+    [10, 10, 0, 6, 6] -> [2, 2, 0, 1, 1]. Useful when one has an instance segmentation volume
+    at hand and would like to create a one-hot encoding of it. Without consecutive labels the task would be harder.
+    """
+
+    def __init__(self, append_original=False, run_cc=True, ignore_label=None, **kwargs):
+        self.append_original = append_original
+        self.ignore_label = ignore_label
+        self.run_cc = run_cc
+
+        if ignore_label is not None:
+            assert append_original, "ignore_label is present, so append_original must be True so that the ignore region can be localized"
+
+    def __call__(self, m):
+        orig = m
+        if self.run_cc:
+            # assign 0 to the ignore region
+            m = measure.label(m, background=self.ignore_label)
+
+        _, unique_labels = np.unique(m, return_inverse=True)
+        result = unique_labels.reshape(m.shape)
+        if self.append_original:
+            result = np.stack([result, orig])
+        return result
+
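+# Worked example: [10, 10, 0, 6, 6] relabels to [2, 2, 0, 1, 1]; np.unique with
+# return_inverse=True maps each voxel to the index of its sorted unique label.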
+
+class Identity:
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, m):
+        return m
+
+
+class RgbToLabel:
+    def __call__(self, img):
+        img = np.array(img)
+        assert img.ndim == 3 and img.shape[2] == 3
+        result = img[..., 0] * 65536 + img[..., 1] * 256 + img[..., 2]
+        return result
+
+
+class LabelToTensor:
+    def __call__(self, m):
+        m = np.array(m)
+        return torch.from_numpy(m.astype(dtype='int64'))
+
+
+class GaussianBlur3D:
+    def __init__(self, sigma=(0.1, 2.0), execution_probability=0.5, **kwargs):
+        self.sigma = sigma
+        self.execution_probability = execution_probability
+
+    def __call__(self, x):
+        if random.random() < self.execution_probability:
+            sigma = random.uniform(self.sigma[0], self.sigma[1])
+            x = gaussian(x, sigma=sigma)
+            return x
+        return x
+
+
+class Transformer:
+    def __init__(self, phase_config, base_config):
+        self.phase_config = phase_config
+        self.config_base = base_config
+        self.seed = GLOBAL_RANDOM_STATE.randint(10000000)
+
+    def raw_transform(self):
+        return self._create_transform('raw')
+
+    def label_transform(self):
+        return self._create_transform('label')
+
+    def weight_transform(self):
+        return self._create_transform('weight')
+
+    @staticmethod
+    def _transformer_class(class_name):
+        m = importlib.import_module('pytorch3dunet.augment.transforms')
+        clazz = getattr(m, class_name)
+        return clazz
+
+    def _create_transform(self, name):
+        assert name in self.phase_config, f'Could not find {name} transform'
+        return Compose([
+            self._create_augmentation(c) for c in self.phase_config[name]
+        ])
+
+    def _create_augmentation(self, c):
+        config = dict(self.config_base)
+        config.update(c)
+        config['random_state'] = np.random.RandomState(self.seed)
+        aug_class = self._transformer_class(config['name'])
+        return aug_class(**config)
+
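+# Usage sketch (hypothetical config): every augmentation listed under a phase is
+# instantiated with the same seed, which is what keeps raw/label in sync, e.g.
+#   t = Transformer({'raw': [{'name': 'RandomFlip'}],
+#                    'label': [{'name': 'RandomFlip'}]}, base_config={})
+#   raw_t, label_t = t.raw_transform(), t.label_transform()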
+
+def _recover_ignore_index(input, orig, ignore_index):
+    if ignore_index is not None:
+        mask = orig == ignore_index
+        input[mask] = ignore_index
+
+    return input
diff --git a/build/lib/pytorch3dunet/datasets/__init__.py b/build/lib/pytorch3dunet/datasets/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/build/lib/pytorch3dunet/datasets/dsb.py b/build/lib/pytorch3dunet/datasets/dsb.py
new file mode 100644
index 00000000..5d0cde86
--- /dev/null
+++ b/build/lib/pytorch3dunet/datasets/dsb.py
@@ -0,0 +1,108 @@
+import collections.abc
+import os
+
+import imageio
+import numpy as np
+import torch
+
+from pytorch3dunet.augment import transforms
+from pytorch3dunet.datasets.utils import ConfigDataset, calculate_stats
+from pytorch3dunet.unet3d.utils import get_logger
+
+logger = get_logger('DSB2018Dataset')
+
+
+def dsb_prediction_collate(batch):
+    """
+    Forms a mini-batch of (images, paths) during test time for the DSB-like datasets.
+    """
+    error_msg = "batch must contain tensors or str; found {}"
+    if isinstance(batch[0], torch.Tensor):
+        return torch.stack(batch, 0)
+    elif isinstance(batch[0], str):
+        return list(batch)
+    elif isinstance(batch[0], collections.abc.Sequence):
+        # transpose tuples, i.e. [[1, 2], ['a', 'b']] to be [[1, 'a'], [2, 'b']]
+        transposed = zip(*batch)
+        return [dsb_prediction_collate(samples) for samples in transposed]
+
+    raise TypeError((error_msg.format(type(batch[0]))))
+
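+# Worked example: a batch [(img1, 'a.png'), (img2, 'b.png')] collates to
+# [stacked_images, ['a.png', 'b.png']] via the Sequence branch above.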
+
+class DSB2018Dataset(ConfigDataset):
+    def __init__(self, root_dir, phase, transformer_config, expand_dims=True):
+        assert os.path.isdir(root_dir), f'{root_dir} is not a directory'
+        assert phase in ['train', 'val', 'test']
+
+        self.phase = phase
+
+        # load raw images
+        images_dir = os.path.join(root_dir, 'images')
+        assert os.path.isdir(images_dir)
+        self.images, self.paths = self._load_files(images_dir, expand_dims)
+        self.file_path = images_dir
+
+        stats = calculate_stats(self.images, True)
+
+        transformer = transforms.Transformer(transformer_config, stats)
+
+        # load raw images transformer
+        self.raw_transform = transformer.raw_transform()
+
+        if phase != 'test':
+            # load labeled images
+            masks_dir = os.path.join(root_dir, 'masks')
+            assert os.path.isdir(masks_dir)
+            self.masks, _ = self._load_files(masks_dir, expand_dims)
+            assert len(self.images) == len(self.masks)
+            # load label images transformer
+            self.masks_transform = transformer.label_transform()
+        else:
+            self.masks = None
+            self.masks_transform = None
+
+    def __getitem__(self, idx):
+        if idx >= len(self):
+            raise StopIteration
+
+        img = self.images[idx]
+        if self.phase != 'test':
+            mask = self.masks[idx]
+            return self.raw_transform(img), self.masks_transform(mask)
+        else:
+            return self.raw_transform(img), self.paths[idx]
+
+    def __len__(self):
+        return len(self.images)
+
+    @classmethod
+    def prediction_collate(cls, batch):
+        return dsb_prediction_collate(batch)
+
+    @classmethod
+    def create_datasets(cls, dataset_config, phase):
+        phase_config = dataset_config[phase]
+        # load data augmentation configuration
+        transformer_config = phase_config['transformer']
+        # load files to process
+        file_paths = phase_config['file_paths']
+        expand_dims = dataset_config.get('expand_dims', True)
+        return [cls(file_paths[0], phase, transformer_config, expand_dims)]
+
+    @staticmethod
+    def _load_files(dir, expand_dims):
+        files_data = []
+        paths = []
+        for file in os.listdir(dir):
+            path = os.path.join(dir, file)
+            img = np.asarray(imageio.imread(path))
+            if expand_dims:
+                dims = img.ndim
+                img = np.expand_dims(img, axis=0)
+                if dims == 3:
+                    img = np.transpose(img, (3, 0, 1, 2))
+
+            files_data.append(img)
+            paths.append(path)
+
+        return files_data, paths
diff --git a/build/lib/pytorch3dunet/datasets/hdf5.py b/build/lib/pytorch3dunet/datasets/hdf5.py
new file mode 100644
index 00000000..040adb85
--- /dev/null
+++ b/build/lib/pytorch3dunet/datasets/hdf5.py
@@ -0,0 +1,293 @@
+import glob
+import os
+from abc import abstractmethod
+from itertools import chain
+
+import h5py
+
+import pytorch3dunet.augment.transforms as transforms
+from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad
+from pytorch3dunet.unet3d.utils import get_logger
+
+logger = get_logger('HDF5Dataset')
+
+
+def _create_padded_indexes(indexes, halo_shape):
+    return tuple(slice(index.start, index.stop + 2 * halo) for index, halo in zip(indexes, halo_shape))
+
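+# Worked example: slice(10, 20) with halo 4 becomes slice(10, 28); since the
+# volume itself is mirror-padded by the halo, the padded slice covers the
+# original patch plus a halo on each side.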
+
+def traverse_h5_paths(file_paths):
+    assert isinstance(file_paths, list)
+    results = []
+    for file_path in file_paths:
+        if os.path.isdir(file_path):
+            # if file path is a directory take all H5 files in that directory
+            iters = [glob.glob(os.path.join(file_path, ext)) for ext in ['*.h5', '*.hdf', '*.hdf5', '*.hd5']]
+            for fp in chain(*iters):
+                results.append(fp)
+        else:
+            results.append(file_path)
+    return results
+
+
+class AbstractHDF5Dataset(ConfigDataset):
+    """
+    Implementation of torch.utils.data.Dataset backed by the HDF5 files, which iterates over the raw and label datasets
+    patch by patch with a given stride.
+
+    Args:
+        file_path (str): path to H5 file containing raw data as well as labels and per pixel weights (optional)
+        phase (str): 'train' for training, 'val' for validation, 'test' for testing
+        slice_builder_config (dict): configuration of the SliceBuilder
+        transformer_config (dict): data augmentation configuration
+        raw_internal_path (str or list): H5 internal path to the raw dataset
+        label_internal_path (str or list): H5 internal path to the label dataset
+        weight_internal_path (str or list): H5 internal path to the per pixel weights (optional)
+        global_normalization (bool): if True, the mean and std of the raw data will be calculated over the whole dataset
+    """
+
+    def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw',
+                 label_internal_path='label', weight_internal_path=None, global_normalization=True):
+        assert phase in ['train', 'val', 'test']
+
+        self.phase = phase
+        self.file_path = file_path
+        self.raw_internal_path = raw_internal_path
+        self.label_internal_path = label_internal_path
+        self.weight_internal_path = weight_internal_path
+
+        self.halo_shape = slice_builder_config.get('halo_shape', [0, 0, 0])
+
+        if global_normalization:
+            logger.info('Calculating mean and std of the raw data...')
+            with h5py.File(file_path, 'r') as f:
+                raw = f[raw_internal_path][:]
+                stats = calculate_stats(raw)
+        else:
+            stats = calculate_stats(None, True)
+
+        self.transformer = transforms.Transformer(transformer_config, stats)
+        self.raw_transform = self.transformer.raw_transform()
+
+        if phase != 'test':
+            # create label/weight transform only in train/val phase
+            self.label_transform = self.transformer.label_transform()
+
+            if weight_internal_path is not None:
+                self.weight_transform = self.transformer.weight_transform()
+            else:
+                self.weight_transform = None
+
+            self._check_volume_sizes()
+        else:
+            # 'test' phase used only for predictions so ignore the label dataset
+            self.label = None
+            self.weight_map = None
+
+            # compare patch and stride configuration
+            patch_shape = slice_builder_config.get('patch_shape')
+            stride_shape = slice_builder_config.get('stride_shape')
+            if sum(self.halo_shape) != 0 and patch_shape != stride_shape:
+                logger.warning(f'Found non-zero halo shape {self.halo_shape}. '
+                               f'For optimal prediction performance the patch shape and stride shape should be '
+                               f'equal in this case, but found patch_shape: {patch_shape} and stride_shape: {stride_shape}!')
+
+        with h5py.File(file_path, 'r') as f:
+            raw = f[raw_internal_path]
+            label = f[label_internal_path] if phase != 'test' else None
+            weight_map = f[weight_internal_path] if weight_internal_path is not None else None
+            # build slice indices for raw and label data sets
+            slice_builder = get_slice_builder(raw, label, weight_map, slice_builder_config)
+            self.raw_slices = slice_builder.raw_slices
+            self.label_slices = slice_builder.label_slices
+            self.weight_slices = slice_builder.weight_slices
+
+        self.patch_count = len(self.raw_slices)
+        logger.info(f'Number of patches: {self.patch_count}')
+
+    @abstractmethod
+    def get_raw_patch(self, idx):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_label_patch(self, idx):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_weight_patch(self, idx):
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_raw_padded_patch(self, idx):
+        raise NotImplementedError
+
+    def volume_shape(self):
+        with h5py.File(self.file_path, 'r') as f:
+            raw = f[self.raw_internal_path]
+            if raw.ndim == 3:
+                return raw.shape
+            else:
+                return raw.shape[1:]
+
+    def __getitem__(self, idx):
+        if idx >= len(self):
+            raise StopIteration
+
+        raw_idx = self.raw_slices[idx]
+
+        if self.phase == 'test':
+            if len(raw_idx) == 4:
+                # discard the channel dimension in the slices: predictor requires only the spatial dimensions of the volume
+                raw_idx = raw_idx[1:]
+                raw_idx_padded = (slice(None),) + _create_padded_indexes(raw_idx, self.halo_shape)
+            else:
+                raw_idx_padded = _create_padded_indexes(raw_idx, self.halo_shape)
+
+            raw_patch_transformed = self.raw_transform(self.get_raw_padded_patch(raw_idx_padded))
+            return raw_patch_transformed, raw_idx
+        else:
+            raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx))
+
+            # get the slice for a given index 'idx'
+            label_idx = self.label_slices[idx]
+            label_patch_transformed = self.label_transform(self.get_label_patch(label_idx))
+            if self.weight_internal_path is not None:
+                weight_idx = self.weight_slices[idx]
+                weight_patch_transformed = self.weight_transform(self.get_weight_patch(weight_idx))
+                return raw_patch_transformed, label_patch_transformed, weight_patch_transformed
+            # return the transformed raw and label patches
+            return raw_patch_transformed, label_patch_transformed
+
+    def __len__(self):
+        return self.patch_count
+
+    def _check_volume_sizes(self):
+        def _volume_shape(volume):
+            if volume.ndim == 3:
+                return volume.shape
+            return volume.shape[1:]
+
+        with h5py.File(self.file_path, 'r') as f:
+            raw = f[self.raw_internal_path]
+            label = f[self.label_internal_path]
+            assert raw.ndim in [3, 4], 'Raw dataset must be 3D (DxHxW) or 4D (CxDxHxW)'
+            assert label.ndim in [3, 4], 'Label dataset must be 3D (DxHxW) or 4D (CxDxHxW)'
+            assert _volume_shape(raw) == _volume_shape(label), 'Raw and labels have to be of the same size'
+            if self.weight_internal_path is not None:
+                weight_map = f[self.weight_internal_path]
+                assert weight_map.ndim in [3, 4], 'Weight map dataset must be 3D (DxHxW) or 4D (CxDxHxW)'
+                assert _volume_shape(raw) == _volume_shape(weight_map), 'Raw and weight map have to be of the same size'
+
+    @classmethod
+    def create_datasets(cls, dataset_config, phase):
+        phase_config = dataset_config[phase]
+
+        # load data augmentation configuration
+        transformer_config = phase_config['transformer']
+        # load slice builder config
+        slice_builder_config = phase_config['slice_builder']
+        # load files to process
+        file_paths = phase_config['file_paths']
+        # file_paths may contain both files and directories; if the file_path is a directory all H5 files inside
+        # are going to be included in the final file_paths
+        file_paths = traverse_h5_paths(file_paths)
+
+        datasets = []
+        for file_path in file_paths:
+            try:
+                logger.info(f'Loading {phase} set from: {file_path}...')
+                dataset = cls(file_path=file_path,
+                              phase=phase,
+                              slice_builder_config=slice_builder_config,
+                              transformer_config=transformer_config,
+                              raw_internal_path=dataset_config.get('raw_internal_path', 'raw'),
+                              label_internal_path=dataset_config.get('label_internal_path', 'label'),
+                              weight_internal_path=dataset_config.get('weight_internal_path', None),
+                              global_normalization=dataset_config.get('global_normalization', None))
+                datasets.append(dataset)
+            except Exception:
+                logger.error(f'Skipping {phase} set: {file_path}', exc_info=True)
+        return datasets
+
+
+class StandardHDF5Dataset(AbstractHDF5Dataset):
+    """
+    Implementation of the HDF5 dataset which loads the data from the H5 files into memory.
+    Fast, but might consume a lot of memory.
+    """
+
+    def __init__(self, file_path, phase, slice_builder_config, transformer_config,
+                 raw_internal_path='raw', label_internal_path='label', weight_internal_path=None,
+                 global_normalization=True):
+        super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,
+                         transformer_config=transformer_config, raw_internal_path=raw_internal_path,
+                         label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
+                         global_normalization=global_normalization)
+        self._raw = None
+        self._raw_padded = None
+        self._label = None
+        self._weight_map = None
+
+    def get_raw_patch(self, idx):
+        if self._raw is None:
+            with h5py.File(self.file_path, 'r') as f:
+                assert self.raw_internal_path in f, f'Dataset {self.raw_internal_path} not found in {self.file_path}'
+                self._raw = f[self.raw_internal_path][:]
+        return self._raw[idx]
+
+    def get_label_patch(self, idx):
+        if self._label is None:
+            with h5py.File(self.file_path, 'r') as f:
+                assert self.label_internal_path in f, f'Dataset {self.label_internal_path} not found in {self.file_path}'
+                self._label = f[self.label_internal_path][:]
+        return self._label[idx]
+
+    def get_weight_patch(self, idx):
+        if self._weight_map is None:
+            with h5py.File(self.file_path, 'r') as f:
+                assert self.weight_internal_path in f, f'Dataset {self.weight_internal_path} not found in {self.file_path}'
+                self._weight_map = f[self.weight_internal_path][:]
+        return self._weight_map[idx]
+
+    def get_raw_padded_patch(self, idx):
+        if self._raw_padded is None:
+            with h5py.File(self.file_path, 'r') as f:
+                assert self.raw_internal_path in f, f'Dataset {self.raw_internal_path} not found in {self.file_path}'
+                self._raw_padded = mirror_pad(f[self.raw_internal_path][:], self.halo_shape)
+        return self._raw_padded[idx]
+
+
+class LazyHDF5Dataset(AbstractHDF5Dataset):
+    """Implementation of the HDF5 dataset which loads the data lazily. It's slower, but has a low memory footprint."""
+
+    def __init__(self, file_path, phase, slice_builder_config, transformer_config,
+                 raw_internal_path='raw', label_internal_path='label', weight_internal_path=None,
+                 global_normalization=False):
+        super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,
+                         transformer_config=transformer_config, raw_internal_path=raw_internal_path,
+                         label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
+                         global_normalization=global_normalization)
+
+        logger.info("Using LazyHDF5Dataset")
+
+    def get_raw_patch(self, idx):
+        with h5py.File(self.file_path, 'r') as f:
+            return f[self.raw_internal_path][idx]
+
+    def get_label_patch(self, idx):
+        with h5py.File(self.file_path, 'r') as f:
+            return f[self.label_internal_path][idx]
+
+    def get_weight_patch(self, idx):
+        with h5py.File(self.file_path, 'r') as f:
+            return f[self.weight_internal_path][idx]
+
+    def get_raw_padded_patch(self, idx):
+        with h5py.File(self.file_path, 'r+') as f:
+            if 'raw_padded' in f:
+                return f['raw_padded'][idx]
+
+            raw = f[self.raw_internal_path][:]
+            raw_padded = mirror_pad(raw, self.halo_shape)
+            f.create_dataset('raw_padded', data=raw_padded, compression='gzip')
+            return raw_padded[idx]
diff --git a/build/lib/pytorch3dunet/datasets/utils.py b/build/lib/pytorch3dunet/datasets/utils.py
new file mode 100644
index 00000000..1ffeefe4
--- /dev/null
+++ b/build/lib/pytorch3dunet/datasets/utils.py
@@ -0,0 +1,361 @@
+import collections
+from typing import Any
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, ConcatDataset, Dataset
+
+from pytorch3dunet.unet3d.utils import get_logger, get_class
+
+logger = get_logger('Dataset')
+
+
+class ConfigDataset(Dataset):
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+    @classmethod
+    def create_datasets(cls, dataset_config, phase):
+        """
+        Factory method for creating a list of datasets based on the provided config.
+
+        Args:
+            dataset_config (dict): dataset configuration
+            phase (str): one of ['train', 'val', 'test']
+
+        Returns:
+            list of `Dataset` instances
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def prediction_collate(cls, batch):
+        """Default collate_fn. Override in child class for non-standard datasets."""
+        return default_prediction_collate(batch)
+
+
+class SliceBuilder:
+    """
+    Builds the position of the patches in a given raw/label/weight ndarray based on the patch and stride shape.
+
+    Args:
+        raw_dataset (ndarray): raw data
+        label_dataset (ndarray): ground truth labels
+        weight_dataset (ndarray): weights for the labels
+        patch_shape (tuple): the shape of the patch DxHxW
+        stride_shape (tuple): the shape of the stride DxHxW
+        kwargs: additional metadata
+    """
+
+    def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, **kwargs):
+        patch_shape = tuple(patch_shape)
+        stride_shape = tuple(stride_shape)
+        skip_shape_check = kwargs.get('skip_shape_check', False)
+        if not skip_shape_check:
+            self._check_patch_shape(patch_shape)
+
+        self._raw_slices = self._build_slices(raw_dataset, patch_shape, stride_shape)
+        if label_dataset is None:
+            self._label_slices = None
+        else:
+            self._label_slices = self._build_slices(label_dataset, patch_shape, stride_shape)
+            assert len(self._raw_slices) == len(self._label_slices)
+        if weight_dataset is None:
+            self._weight_slices = None
+        else:
+            self._weight_slices = self._build_slices(weight_dataset, patch_shape, stride_shape)
+            assert len(self.raw_slices) == len(self._weight_slices)
+
+    @property
+    def raw_slices(self):
+        return self._raw_slices
+
+    @property
+    def label_slices(self):
+        return self._label_slices
+
+    @property
+    def weight_slices(self):
+        return self._weight_slices
+
+    @staticmethod
+    def _build_slices(dataset, patch_shape, stride_shape):
+        """Iterates over a given n-dim dataset patch-by-patch with a given stride
+        and builds an array of slice positions.
+
+        Returns:
+            list of slices, i.e.
+            [(slice, slice, slice, slice), ...] if len(shape) == 4
+            [(slice, slice, slice), ...] if len(shape) == 3
+        """
+        slices = []
+        if dataset.ndim == 4:
+            in_channels, i_z, i_y, i_x = dataset.shape
+        else:
+            i_z, i_y, i_x = dataset.shape
+
+        k_z, k_y, k_x = patch_shape
+        s_z, s_y, s_x = stride_shape
+        z_steps = SliceBuilder._gen_indices(i_z, k_z, s_z)
+        for z in z_steps:
+            y_steps = SliceBuilder._gen_indices(i_y, k_y, s_y)
+            for y in y_steps:
+                x_steps = SliceBuilder._gen_indices(i_x, k_x, s_x)
+                for x in x_steps:
+                    slice_idx = (
+                        slice(z, z + k_z),
+                        slice(y, y + k_y),
+                        slice(x, x + k_x),
+                    )
+                    if dataset.ndim == 4:
+                        slice_idx = (slice(0, in_channels),) + slice_idx
+                    slices.append(slice_idx)
+        return slices
+
+    @staticmethod
+    def _gen_indices(i, k, s):
+        assert i >= k, 'Sample size has to be greater than or equal to the patch size'
+        for j in range(0, i - k + 1, s):
+            yield j
+        if j + k < i:
+            yield i - k
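+    # Worked example: i=100, k=64, s=32 yields starts 0 and 32, then a final
+    # window snapped to i - k = 36 so the tail of the volume is covered too.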
+
+    @staticmethod
+    def _check_patch_shape(patch_shape):
+        assert len(patch_shape) == 3, 'patch_shape must be a 3D tuple'
+        assert patch_shape[1] >= 64 and patch_shape[2] >= 64, 'Height and Width must be greater than or equal to 64'
+
+
+class FilterSliceBuilder(SliceBuilder):
+    """
+    Filter patches containing more than `1 - threshold` of ignore_index label
+    """
+
+    def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, ignore_index=None,
+                 threshold=0.6, slack_acceptance=0.01, **kwargs):
+        super().__init__(raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, **kwargs)
+        if label_dataset is None:
+            return
+
+        rand_state = np.random.RandomState(47)
+
+        def ignore_predicate(raw_label_idx):
+            label_idx = raw_label_idx[1]
+            patch = label_dataset[label_idx]
+            if ignore_index is not None:
+                patch = np.copy(patch)
+                patch[patch == ignore_index] = 0
+            non_ignore_counts = np.count_nonzero(patch != 0)
+            non_ignore_counts = non_ignore_counts / patch.size
+            return non_ignore_counts > threshold or rand_state.rand() < slack_acceptance
+
+        zipped_slices = zip(self.raw_slices, self.label_slices)
+        # ignore slices containing too much ignore_index
+        logger.info('Filtering slices...')
+        filtered_slices = list(filter(ignore_predicate, zipped_slices))
+        # unzip and save slices
+        raw_slices, label_slices = zip(*filtered_slices)
+        self._raw_slices = list(raw_slices)
+        self._label_slices = list(label_slices)
+
+
+def _loader_classes(class_name):
+    modules = [
+        'pytorch3dunet.datasets.hdf5',
+        'pytorch3dunet.datasets.dsb',
+        'pytorch3dunet.datasets.utils'
+    ]
+    return get_class(class_name, modules)
+
+
+def get_slice_builder(raws, labels, weight_maps, config):
+    assert 'name' in config
+    logger.info(f"Slice builder config: {config}")
+    slice_builder_cls = _loader_classes(config['name'])
+    return slice_builder_cls(raws, labels, weight_maps, **config)
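+
+# Illustrative slice builder configuration consumed by get_slice_builder
+# (values below are hypothetical placeholders):
+#   slice_builder:
+#     name: FilterSliceBuilder
+#     patch_shape: [80, 170, 170]
+#     stride_shape: [40, 90, 90]
+#     threshold: 0.6
+#     slack_acceptance: 0.01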
+
+
+def get_train_loaders(config):
+    """
+    Returns dictionary containing the training and validation loaders (torch.utils.data.DataLoader).
+
+    :param config: a top level configuration object containing the 'loaders' key
+    :return: dict {
+        'train': <train_loader>,
+        'val': <val_loader>
+    }
+    """
+    assert 'loaders' in config, 'Could not find data loaders configuration'
+    loaders_config = config['loaders']
+
+    logger.info('Creating training and validation set loaders...')
+
+    # get dataset class
+    dataset_cls_str = loaders_config.get('dataset', None)
+    if dataset_cls_str is None:
+        dataset_cls_str = 'StandardHDF5Dataset'
+        logger.warning(f"Cannot find dataset class in the config. Using default '{dataset_cls_str}'.")
+    dataset_class = _loader_classes(dataset_cls_str)
+
+    assert set(loaders_config['train']['file_paths']).isdisjoint(loaders_config['val']['file_paths']), \
+        "Train and validation 'file_paths' overlap. One cannot use validation data for training!"
+
+    train_datasets = dataset_class.create_datasets(loaders_config, phase='train')
+
+    val_datasets = dataset_class.create_datasets(loaders_config, phase='val')
+
+    num_workers = loaders_config.get('num_workers', 1)
+    logger.info(f'Number of workers for train/val dataloader: {num_workers}')
+    batch_size = loaders_config.get('batch_size', 1)
+    if torch.cuda.device_count() > 1 and config['device'] != 'cpu':
+        logger.info(
+            f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}')
+        batch_size = batch_size * torch.cuda.device_count()
+
+    logger.info(f'Batch size for train/val loader: {batch_size}')
+    # when training with volumetric data use batch_size of 1 due to GPU memory constraints
+    return {
+        'train': DataLoader(ConcatDataset(train_datasets), batch_size=batch_size, shuffle=True, pin_memory=True,
+                            num_workers=num_workers),
+        # don't shuffle during validation: useful when showing how predictions for a given batch get better over time
+        'val': DataLoader(ConcatDataset(val_datasets), batch_size=batch_size, shuffle=False, pin_memory=True,
+                          num_workers=num_workers)
+    }
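+
+# Minimal shape of the 'loaders' section expected by get_train_loaders
+# (illustrative; all values are placeholders):
+#   loaders:
+#     dataset: StandardHDF5Dataset
+#     batch_size: 1
+#     num_workers: 4
+#     train:
+#       file_paths: [/path/to/train.h5]
+#     val:
+#       file_paths: [/path/to/val.h5]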
+
+
+def get_test_loaders(config):
+    """
+    Returns test DataLoader.
+
+    :return: generator of DataLoader objects
+    """
+
+    assert 'loaders' in config, 'Could not find data loaders configuration'
+    loaders_config = config['loaders']
+
+    logger.info('Creating test set loaders...')
+
+    # get dataset class
+    dataset_cls_str = loaders_config.get('dataset', None)
+    if dataset_cls_str is None:
+        dataset_cls_str = 'StandardHDF5Dataset'
+        logger.warning(f"Cannot find dataset class in the config. Using default '{dataset_cls_str}'.")
+    dataset_class = _loader_classes(dataset_cls_str)
+
+    test_datasets = dataset_class.create_datasets(loaders_config, phase='test')
+
+    num_workers = loaders_config.get('num_workers', 1)
+    logger.info(f'Number of workers for the dataloader: {num_workers}')
+
+    batch_size = loaders_config.get('batch_size', 1)
+    if torch.cuda.device_count() > 1 and config['device'] != 'cpu':
+        logger.info(
+            f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}')
+        batch_size = batch_size * torch.cuda.device_count()
+
+    logger.info(f'Batch size for dataloader: {batch_size}')
+
+    # use generator in order to create data loaders lazily one by one
+    for test_dataset in test_datasets:
+        logger.info(f'Loading test set from: {test_dataset.file_path}...')
+        if hasattr(test_dataset, 'prediction_collate'):
+            collate_fn = test_dataset.prediction_collate
+        else:
+            collate_fn = default_prediction_collate
+
+        yield DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
+                         collate_fn=collate_fn)
+
+
+def default_prediction_collate(batch):
+    """
+    Default collate_fn to form a mini-batch of Tensor(s) for HDF5 based datasets
+    """
+    error_msg = "batch must contain tensors or slice; found {}"
+    if isinstance(batch[0], torch.Tensor):
+        return torch.stack(batch, 0)
+    elif isinstance(batch[0], tuple) and isinstance(batch[0][0], slice):
+        return batch
+    elif isinstance(batch[0], collections.abc.Sequence):
+        transposed = zip(*batch)
+        return [default_prediction_collate(samples) for samples in transposed]
+
+    raise TypeError((error_msg.format(type(batch[0]))))
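+
+# Illustrative behaviour on a batch of tensors (hypothetical shapes):
+#   default_prediction_collate([torch.zeros(1, 8, 8, 8), torch.zeros(1, 8, 8, 8)]).shape
+#   -> torch.Size([2, 1, 8, 8, 8])
+# A batch of (tensor, slice-tuple) pairs is transposed into a stacked tensor
+# while the slice tuples are passed through unchanged.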
+
+
+def calculate_stats(img: np.ndarray, skip: bool = False) -> dict[str, Any]:
+    """
+    Calculates the minimum percentile, maximum percentile, mean, and standard deviation of the image.
+
+    Args:
+        img: The input image array.
+        skip: if True, skip the calculation and return None for all values.
+
+    Returns:
+        dict[str, Any]: dictionary with the 'pmin', 'pmax', 'mean' and 'std' keys
+    """
+    if not skip:
+        pmin, pmax, mean, std = np.percentile(img, 1), np.percentile(img, 99.6), np.mean(img), np.std(img)
+    else:
+        pmin, pmax, mean, std = None, None, None, None
+
+    return {
+        'pmin': pmin,
+        'pmax': pmax,
+        'mean': mean,
+        'std': std
+    }
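+
+# Illustrative output on a toy array (hypothetical input):
+#   calculate_stats(np.arange(1000, dtype=np.float32))['mean'] -> 499.5
+#   calculate_stats(img, skip=True) -> {'pmin': None, 'pmax': None, 'mean': None, 'std': None}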
+
+
+def mirror_pad(image, padding_shape):
+    """
+    Pad the image with a mirror reflection of itself.
+
+    This function is used on data in its original shape before it is split into patches.
+
+    Args:
+        image (np.ndarray): The input image array to be padded.
+        padding_shape (tuple of int): Specifies the amount of padding for each dimension; must be a 3-tuple (ZYX).
+
+    Returns:
+        np.ndarray: The mirror-padded image.
+
+    Raises:
+        ValueError: If any element of padding_shape is negative.
+    """
+    assert len(padding_shape) == 3, "Padding shape must be specified for each dimension: ZYX"
+
+    if any(p < 0 for p in padding_shape):
+        raise ValueError("padding_shape must be non-negative")
+
+    if all(p == 0 for p in padding_shape):
+        return image
+
+    pad_width = [(p, p) for p in padding_shape]
+
+    if image.ndim == 4:
+        pad_width = [(0, 0)] + pad_width
+    return np.pad(image, pad_width, mode='reflect')
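+
+# Illustrative shapes (hypothetical input): padding a (C, Z, Y, X) volume with
+# padding_shape=(2, 4, 4) leaves the channel axis untouched:
+#   mirror_pad(np.zeros((2, 10, 20, 20)), (2, 4, 4)).shape -> (2, 14, 28, 28)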
+
+
+def remove_padding(m, padding_shape):
+    """
+    Removes padding from the margins of a multi-dimensional array.
+
+    Args:
+        m (np.ndarray): The input array to be unpadded.
+        padding_shape (tuple of int, optional): The amount of padding to remove from each dimension.
+            Assumes the tuple length matches the array dimensions.
+
+    Returns:
+        np.ndarray: The unpadded array.
+    """
+    if padding_shape is None:
+        return m
+
+    # Correctly construct slice objects for each dimension in padding_shape and apply them to m.
+    return m[(..., *(slice(p, -p or None) for p in padding_shape))]
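+
+# Note the `-p or None` idiom above: for p == 0 the slice becomes slice(0, None),
+# i.e. the full axis is kept. Illustrative round trip (hypothetical shapes):
+#   padded = mirror_pad(np.zeros((10, 20, 20)), (2, 4, 4))
+#   remove_padding(padded, (2, 4, 4)).shape -> (10, 20, 20)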
diff --git a/build/lib/pytorch3dunet/predict.py b/build/lib/pytorch3dunet/predict.py
new file mode 100644
index 00000000..cc54fcf7
--- /dev/null
+++ b/build/lib/pytorch3dunet/predict.py
@@ -0,0 +1,59 @@
+import importlib
+import os
+
+import torch
+import torch.nn as nn
+
+from pytorch3dunet.datasets.utils import get_test_loaders
+from pytorch3dunet.unet3d import utils
+from pytorch3dunet.unet3d.config import load_config
+from pytorch3dunet.unet3d.model import get_model
+
+logger = utils.get_logger('UNet3DPredict')
+
+
+def get_predictor(model, config):
+    output_dir = config['loaders'].get('output_dir', None)
+    # override output_dir if provided in the 'predictor' section of the config
+    output_dir = config.get('predictor', {}).get('output_dir', output_dir)
+    if output_dir is not None:
+        os.makedirs(output_dir, exist_ok=True)
+
+    predictor_config = config.get('predictor', {})
+    class_name = predictor_config.get('name', 'StandardPredictor')
+
+    m = importlib.import_module('pytorch3dunet.unet3d.predictor')
+    predictor_class = getattr(m, class_name)
+    out_channels = config['model'].get('out_channels')
+    return predictor_class(model, output_dir, out_channels, **predictor_config)
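+
+# Illustrative 'predictor' section of the config (the path is a placeholder):
+#   predictor:
+#     name: StandardPredictor
+#     output_dir: /path/to/predictions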
+
+
+def main():
+    # Load configuration
+    config, _ = load_config()
+
+    # Create the model
+    model = get_model(config['model'])
+
+    # Load model state
+    model_path = config['model_path']
+    logger.info(f'Loading model from {model_path}...')
+    utils.load_checkpoint(model_path, model)
+
+    # use DataParallel if more than 1 GPU available
+    if torch.cuda.device_count() > 1 and config['device'] != 'cpu':
+        model = nn.DataParallel(model)
+        logger.info(f'Using {torch.cuda.device_count()} GPUs for prediction')
+    if torch.cuda.is_available() and config['device'] != 'cpu':
+        model = model.cuda()
+
+    # create predictor instance
+    predictor = get_predictor(model, config)
+
+    for test_loader in get_test_loaders(config):
+        # run the model prediction on the test_loader and save the results in the output_dir
+        predictor(test_loader)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/build/lib/pytorch3dunet/train.py b/build/lib/pytorch3dunet/train.py
new file mode 100644
index 00000000..eceaf719
--- /dev/null
+++ b/build/lib/pytorch3dunet/train.py
@@ -0,0 +1,35 @@
+import random
+
+import torch
+
+from pytorch3dunet.unet3d.config import load_config, copy_config
+from pytorch3dunet.unet3d.trainer import create_trainer
+from pytorch3dunet.unet3d.utils import get_logger
+
+logger = get_logger('TrainingSetup')
+
+
+def main():
+    # Load and log experiment configuration
+    config, config_path = load_config()
+    logger.info(config)
+
+    manual_seed = config.get('manual_seed', None)
+    if manual_seed is not None:
+        logger.info(f'Seed the RNG for all devices with {manual_seed}')
+        logger.warning('Using CuDNN deterministic setting. This may slow down the training!')
+        random.seed(manual_seed)
+        torch.manual_seed(manual_seed)
+        # see https://pytorch.org/docs/stable/notes/randomness.html
+        torch.backends.cudnn.deterministic = True
+
+    # Create trainer
+    trainer = create_trainer(config)
+    # Copy config file
+    copy_config(config, config_path)
+    # Start training
+    trainer.fit()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/build/lib/pytorch3dunet/unet3d/__init__.py b/build/lib/pytorch3dunet/unet3d/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/build/lib/pytorch3dunet/unet3d/buildingblocks.py b/build/lib/pytorch3dunet/unet3d/buildingblocks.py
new file mode 100644
index 00000000..25679c24
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/buildingblocks.py
@@ -0,0 +1,545 @@
+from functools import partial
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from pytorch3dunet.unet3d.se import ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D
+
+
+def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding,
+                dropout_prob, is3d):
+    """
+    Create a list of modules which together constitute a single conv layer with non-linearity
+    and optional batchnorm/groupnorm.
+
+    Args:
+        in_channels (int): number of input channels
+        out_channels (int): number of output channels
+        kernel_size(int or tuple): size of the convolving kernel
+        order (string): order of the operations, e.g.
+            'cr' -> conv + ReLU
+            'gcr' -> groupnorm + conv + ReLU
+            'cl' -> conv + LeakyReLU
+            'ce' -> conv + ELU
+            'bcr' -> batchnorm + conv + ReLU
+            'cbrd' -> conv + batchnorm + ReLU + dropout
+            'cbrD' -> conv + batchnorm + ReLU + dropout2d
+        num_groups (int): number of groups for the GroupNorm
+        padding (int or tuple): zero-padding added to all three sides of the input
+        dropout_prob (float): dropout probability
+        is3d (bool): if True use Conv3d, otherwise use Conv2d
+    Return:
+        list of tuple (name, module)
+    """
+    assert 'c' in order, "Conv layer MUST be present"
+    assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
+
+    modules = []
+    for i, char in enumerate(order):
+        if char == 'r':
+            modules.append(('ReLU', nn.ReLU(inplace=True)))
+        elif char == 'l':
+            modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True)))
+        elif char == 'e':
+            modules.append(('ELU', nn.ELU(inplace=True)))
+        elif char == 'c':
+            # add learnable bias only in the absence of batchnorm/groupnorm
+            bias = not ('g' in order or 'b' in order)
+            if is3d:
+                conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
+            else:
+                conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
+
+            modules.append(('conv', conv))
+        elif char == 'g':
+            is_before_conv = i < order.index('c')
+            if is_before_conv:
+                num_channels = in_channels
+            else:
+                num_channels = out_channels
+
+            # use only one group if the given number of groups is greater than the number of channels
+            if num_channels < num_groups:
+                num_groups = 1
+
+            assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
+            modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
+        elif char == 'b':
+            is_before_conv = i < order.index('c')
+            if is3d:
+                bn = nn.BatchNorm3d
+            else:
+                bn = nn.BatchNorm2d
+
+            if is_before_conv:
+                modules.append(('batchnorm', bn(in_channels)))
+            else:
+                modules.append(('batchnorm', bn(out_channels)))
+        elif char == 'd':
+            modules.append(('dropout', nn.Dropout(p=dropout_prob)))
+        elif char == 'D':
+            modules.append(('dropout2d', nn.Dropout2d(p=dropout_prob)))
+        else:
+            raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 'd', 'D']")
+
+    return modules
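+
+# Illustrative (hypothetical arguments): the default 'gcr' order with 8 groups yields
+#   [name for name, _ in create_conv(16, 32, 3, 'gcr', 8, 1, 0.1, True)]
+#   -> ['groupnorm', 'conv', 'ReLU']
+# with GroupNorm normalizing the 16 input channels, since it precedes the conv.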
+
+
+class SingleConv(nn.Sequential):
+    """
+    Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
+    of operations can be specified via the `order` parameter
+
+    Args:
+        in_channels (int): number of input channels
+        out_channels (int): number of output channels
+        kernel_size (int or tuple): size of the convolving kernel
+        order (string): determines the order of layers, e.g.
+            'cr' -> conv + ReLU
+            'crg' -> conv + ReLU + groupnorm
+            'cl' -> conv + LeakyReLU
+            'ce' -> conv + ELU
+        num_groups (int): number of groups for the GroupNorm
+        padding (int or tuple): add zero-padding
+        dropout_prob (float): dropout probability, default 0.1
+        is3d (bool): if True use Conv3d, otherwise use Conv2d
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8,
+                 padding=1, dropout_prob=0.1, is3d=True):
+        super(SingleConv, self).__init__()
+
+        for name, module in create_conv(in_channels, out_channels, kernel_size, order,
+                                        num_groups, padding, dropout_prob, is3d):
+            self.add_module(name, module)
+
+
+class DoubleConv(nn.Sequential):
+    """
+    A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
+    We use (GroupNorm+Conv3d+ReLU), i.e. order='gcr', by default.
+    This can be changed however by providing the 'order' argument, e.g. in order
+    to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
+    Use padded convolutions to make sure that the output (H_out, W_out) is the same
+    as (H_in, W_in), so that you don't have to crop in the decoder path.
+
+    Args:
+        in_channels (int): number of input channels
+        out_channels (int): number of output channels
+        encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
+        kernel_size (int or tuple): size of the convolving kernel
+        order (string): determines the order of layers, e.g.
+            'cr' -> conv + ReLU
+            'crg' -> conv + ReLU + groupnorm
+            'cl' -> conv + LeakyReLU
+            'ce' -> conv + ELU
+        num_groups (int): number of groups for the GroupNorm
+        padding (int or tuple): zero-padding added to all three sides of the input
+        upscale (int): channel upscale factor between the two convolutions in the encoder
+            (if 1, the first convolution already outputs out_channels), default: 2
+        dropout_prob (float or tuple): dropout probability for each convolution, default 0.1
+        is3d (bool): if True use Conv3d instead of Conv2d layers
+    """
+
+    def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr',
+                 num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
+        super(DoubleConv, self).__init__()
+        if encoder:
+            # we're in the encoder path
+            conv1_in_channels = in_channels
+            if upscale == 1:
+                conv1_out_channels = out_channels
+            else:
+                conv1_out_channels = out_channels // 2
+            if conv1_out_channels < in_channels:
+                conv1_out_channels = in_channels
+            conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
+        else:
+            # we're in the decoder path, decrease the number of channels in the 1st convolution
+            conv1_in_channels, conv1_out_channels = in_channels, out_channels
+            conv2_in_channels, conv2_out_channels = out_channels, out_channels
+
+        # if dropout_prob is a list/tuple, use a separate dropout probability
+        # for each of the two convolutions
+        if isinstance(dropout_prob, (list, tuple)):
+            dropout_prob1 = dropout_prob[0]
+            dropout_prob2 = dropout_prob[1]
+        else:
+            dropout_prob1 = dropout_prob2 = dropout_prob
+
+        # conv1
+        self.add_module('SingleConv1',
+                        SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups,
+                                   padding=padding, dropout_prob=dropout_prob1, is3d=is3d))
+        # conv2
+        self.add_module('SingleConv2',
+                        SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups,
+                                   padding=padding, dropout_prob=dropout_prob2, is3d=is3d))
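+
+    # Illustrative channel flow (hypothetical sizes): in the encoder with
+    # in_channels=64, out_channels=128 and upscale=2, SingleConv1 maps
+    # 64 -> 64 and SingleConv2 maps 64 -> 128; in the decoder both
+    # convolutions output out_channels.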
+
+
+class ResNetBlock(nn.Module):
+    """
+    Residual block that can be used instead of standard DoubleConv in the Encoder module.
+    Motivated by: https://arxiv.org/pdf/1706.00120.pdf
+
+    Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, is3d=True, **kwargs):
+        super(ResNetBlock, self).__init__()
+
+        if in_channels != out_channels:
+            # conv1x1 for increasing the number of channels
+            if is3d:
+                self.conv1 = nn.Conv3d(in_channels, out_channels, 1)
+            else:
+                self.conv1 = nn.Conv2d(in_channels, out_channels, 1)
+        else:
+            self.conv1 = nn.Identity()
+
+        # residual block
+        self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups,
+                                is3d=is3d)
+        # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
+        n_order = order
+        for c in 'rel':
+            n_order = n_order.replace(c, '')
+        self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
+                                num_groups=num_groups, is3d=is3d)
+
+        # create non-linearity separately
+        if 'l' in order:
+            self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        elif 'e' in order:
+            self.non_linearity = nn.ELU(inplace=True)
+        else:
+            self.non_linearity = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        # apply first convolution to bring the number of channels to out_channels
+        residual = self.conv1(x)
+
+        # residual block
+        out = self.conv2(residual)
+        out = self.conv3(out)
+
+        out += residual
+        out = self.non_linearity(out)
+
+        return out
+
+
+class ResNetBlockSE(ResNetBlock):
+    def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, se_module='scse', **kwargs):
+        super(ResNetBlockSE, self).__init__(
+            in_channels, out_channels, kernel_size=kernel_size, order=order,
+            num_groups=num_groups, **kwargs)
+        assert se_module in ['scse', 'cse', 'sse']
+        if se_module == 'scse':
+            self.se_module = ChannelSpatialSELayer3D(num_channels=out_channels, reduction_ratio=1)
+        elif se_module == 'cse':
+            self.se_module = ChannelSELayer3D(num_channels=out_channels, reduction_ratio=1)
+        elif se_module == 'sse':
+            self.se_module = SpatialSELayer3D(num_channels=out_channels)
+
+    def forward(self, x):
+        out = super().forward(x)
+        out = self.se_module(out)
+        return out
+
+
+class Encoder(nn.Module):
+    """
+    A single module from the encoder path consisting of an optional max
+    pooling layer (one may specify a MaxPool kernel_size different from the
+    standard (2,2,2), e.g. if the volumetric data is anisotropic; make sure to
+    use a complementary scale_factor in the decoder path) followed by
+    a basic module (DoubleConv or ResNetBlock).
+
+    Args:
+        in_channels (int): number of input channels
+        out_channels (int): number of output channels
+        conv_kernel_size (int or tuple): size of the convolving kernel
+        apply_pooling (bool): if True use MaxPool3d before DoubleConv
+        pool_kernel_size (int or tuple): the size of the pooling window
+        pool_type (str): pooling layer: 'max' or 'avg'
+        basic_module(nn.Module): either ResNetBlock or DoubleConv
+        conv_layer_order (string): determines the order of layers
+            in `DoubleConv` module. See `DoubleConv` for more info.
+        num_groups (int): number of groups for the GroupNorm
+        padding (int or tuple): zero-padding added to all three sides of the input
+        upscale (int): channel upscale factor between the two convolutions in the encoder
+            (if 1, the first convolution already outputs out_channels), default: 2
+        dropout_prob (float or tuple): dropout probability, default 0.1
+        is3d (bool): use 3d or 2d convolutions/pooling operation
+    """
+
+    def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
+                 pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
+                 num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
+        super(Encoder, self).__init__()
+        assert pool_type in ['max', 'avg']
+        if apply_pooling:
+            if pool_type == 'max':
+                if is3d:
+                    self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
+                else:
+                    self.pooling = nn.MaxPool2d(kernel_size=pool_kernel_size)
+            else:
+                if is3d:
+                    self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
+                else:
+                    self.pooling = nn.AvgPool2d(kernel_size=pool_kernel_size)
+        else:
+            self.pooling = None
+
+        self.basic_module = basic_module(in_channels, out_channels,
+                                         encoder=True,
+                                         kernel_size=conv_kernel_size,
+                                         order=conv_layer_order,
+                                         num_groups=num_groups,
+                                         padding=padding,
+                                         upscale=upscale,
+                                         dropout_prob=dropout_prob,
+                                         is3d=is3d)
+
+    def forward(self, x):
+        if self.pooling is not None:
+            x = self.pooling(x)
+        x = self.basic_module(x)
+        return x
+
+
+class Decoder(nn.Module):
+    """
+    A single module for decoder path consisting of the upsampling layer
+    (either learned ConvTranspose3d or nearest neighbor interpolation)
+    followed by a basic module (DoubleConv or ResNetBlock).
+
+    Args:
+        in_channels (int): number of input channels
+        out_channels (int): number of output channels
+        conv_kernel_size (int or tuple): size of the convolving kernel
+        scale_factor (int or tuple): used as the multiplier for the image H/W/D in
+            case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
+            from the corresponding encoder
+        basic_module(nn.Module): either ResNetBlock or DoubleConv
+        conv_layer_order (string): determines the order of layers
+            in `DoubleConv` module. See `DoubleConv` for more info.
+        num_groups (int): number of groups for the GroupNorm
+        padding (int or tuple): zero-padding added to all three sides of the input
+        upsample (str): algorithm used for upsampling:
+            InterpolateUpsampling:   'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
+            TransposeConvUpsampling: 'deconv'
+            No upsampling:           None
+            Default: 'default' (chooses automatically)
+        dropout_prob (float or tuple): dropout probability, default 0.1
+    """
+
+    def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2, basic_module=DoubleConv,
+                 conv_layer_order='gcr', num_groups=8, padding=1, upsample='default',
+                 dropout_prob=0.1, is3d=True):
+        super(Decoder, self).__init__()
+
+        # perform concat joining per default
+        concat = True
+
+        # don't adapt channels after join operation
+        adapt_channels = False
+
+        if upsample is not None and upsample != 'none':
+            if upsample == 'default':
+                if basic_module == DoubleConv:
+                    upsample = 'nearest'  # use nearest neighbor interpolation for upsampling
+                    concat = True  # use concat joining
+                    adapt_channels = False  # don't adapt channels
+                elif basic_module == ResNetBlock or basic_module == ResNetBlockSE:
+                    upsample = 'deconv'  # use deconvolution upsampling
+                    concat = False  # use summation joining
+                    adapt_channels = True  # adapt channels after joining
+
+            # perform deconvolution upsampling if mode is deconv
+            if upsample == 'deconv':
+                self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels,
+                                                          kernel_size=conv_kernel_size, scale_factor=scale_factor,
+                                                          is3d=is3d)
+            else:
+                self.upsampling = InterpolateUpsampling(mode=upsample)
+        else:
+            # no upsampling; concat stays True, so concat joining is used
+            self.upsampling = NoUpsampling()
+
+        # perform joining operation
+        self.joining = partial(self._joining, concat=concat)
+
+        # adapt the number of in_channels for the ResNetBlock
+        if adapt_channels is True:
+            in_channels = out_channels
+
+        self.basic_module = basic_module(in_channels, out_channels,
+                                         encoder=False,
+                                         kernel_size=conv_kernel_size,
+                                         order=conv_layer_order,
+                                         num_groups=num_groups,
+                                         padding=padding,
+                                         dropout_prob=dropout_prob,
+                                         is3d=is3d)
+
+    def forward(self, encoder_features, x):
+        x = self.upsampling(encoder_features=encoder_features, x=x)
+        x = self.joining(encoder_features, x)
+        x = self.basic_module(x)
+        return x
+
+    @staticmethod
+    def _joining(encoder_features, x, concat):
+        if concat:
+            return torch.cat((encoder_features, x), dim=1)
+        else:
+            return encoder_features + x
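+
+    # Illustrative shapes (hypothetical): with encoder_features and x both of
+    # shape (N, C, D, H, W), concat=True yields (N, 2*C, D, H, W) while
+    # concat=False (summation joining) keeps (N, C, D, H, W).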
+
+
+def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
+                    conv_upscale, dropout_prob,
+                    layer_order, num_groups, pool_kernel_size, is3d):
+    # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
+    encoders = []
+    for i, out_feature_num in enumerate(f_maps):
+        if i == 0:
+            # apply conv_coord only in the first encoder if any
+            encoder = Encoder(in_channels, out_feature_num,
+                              apply_pooling=False,  # skip pooling in the first encoder
+                              basic_module=basic_module,
+                              conv_layer_order=layer_order,
+                              conv_kernel_size=conv_kernel_size,
+                              num_groups=num_groups,
+                              padding=conv_padding,
+                              upscale=conv_upscale,
+                              dropout_prob=dropout_prob,
+                              is3d=is3d)
+        else:
+            encoder = Encoder(f_maps[i - 1], out_feature_num,
+                              basic_module=basic_module,
+                              conv_layer_order=layer_order,
+                              conv_kernel_size=conv_kernel_size,
+                              num_groups=num_groups,
+                              pool_kernel_size=pool_kernel_size,
+                              padding=conv_padding,
+                              upscale=conv_upscale,
+                              dropout_prob=dropout_prob,
+                              is3d=is3d)
+
+        encoders.append(encoder)
+
+    return nn.ModuleList(encoders)
+
+
+def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
+                    num_groups, upsample, dropout_prob, is3d):
+    # create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1`
+    decoders = []
+    reversed_f_maps = list(reversed(f_maps))
+    for i in range(len(reversed_f_maps) - 1):
+        if basic_module == DoubleConv and upsample != 'deconv':
+            in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
+        else:
+            in_feature_num = reversed_f_maps[i]
+
+        out_feature_num = reversed_f_maps[i + 1]
+
+        decoder = Decoder(in_feature_num, out_feature_num,
+                          basic_module=basic_module,
+                          conv_layer_order=layer_order,
+                          conv_kernel_size=conv_kernel_size,
+                          num_groups=num_groups,
+                          padding=conv_padding,
+                          upsample=upsample,
+                          dropout_prob=dropout_prob,
+                          is3d=is3d)
+        decoders.append(decoder)
+    return nn.ModuleList(decoders)
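+
+# Illustrative channel arithmetic (hypothetical f_maps): for f_maps=[64, 128, 256]
+# with DoubleConv and interpolation upsampling, the skip connections are
+# concatenated, so the two decoders map 256+128 -> 128 and 128+64 -> 64 channels.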
+
+
+class AbstractUpsampling(nn.Module):
+    """
+    Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either
+    interpolation or learned transposed convolution.
+    """
+
+    def __init__(self, upsample):
+        super(AbstractUpsampling, self).__init__()
+        self.upsample = upsample
+
+    def forward(self, encoder_features, x):
+        # get the spatial dimensions of the output given the encoder_features
+        output_size = encoder_features.size()[2:]
+        # upsample the input and return
+        return self.upsample(x, output_size)
+
+
+class InterpolateUpsampling(AbstractUpsampling):
+    """
+    Args:
+        mode (str): algorithm used for upsampling:
+            'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
+            used only if transposed_conv is False
+    """
+
+    def __init__(self, mode='nearest'):
+        upsample = partial(self._interpolate, mode=mode)
+        super().__init__(upsample)
+
+    @staticmethod
+    def _interpolate(x, size, mode):
+        return F.interpolate(x, size=size, mode=mode)
+
+
+class TransposeConvUpsampling(AbstractUpsampling):
+    """
+    Args:
+        in_channels (int): number of input channels for transposed conv
+            used only if transposed_conv is True
+        out_channels (int): number of output channels for transpose conv
+            used only if transposed_conv is True
+        kernel_size (int or tuple): size of the convolving kernel
+            used only if transposed_conv is True
+        scale_factor (int or tuple): stride of the convolution
+            used only if transposed_conv is True
+        is3d (bool): if True use ConvTranspose3d, otherwise use ConvTranspose2d
+    """
+
+    class Upsample(nn.Module):
+        """
+        Workaround for the 'ValueError: requested an output size...' raised in the `_output_padding` method of
+        transposed convolution. It performs the transposed conv followed by interpolation to the correct size if necessary.
+        """
+
+        def __init__(self, conv_transposed, is3d):
+            super().__init__()
+            self.conv_transposed = conv_transposed
+            self.is3d = is3d
+
+        def forward(self, x, size):
+            x = self.conv_transposed(x)
+            return F.interpolate(x, size=size)
+
+    def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=2, is3d=True):
+        # make sure that the output size reverses the MaxPool3d from the corresponding encoder
+        if is3d is True:
+            conv_transposed = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size,
+                                                 stride=scale_factor, padding=1, bias=False)
+        else:
+            conv_transposed = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
+                                                 stride=scale_factor, padding=1, bias=False)
+        upsample = self.Upsample(conv_transposed, is3d)
+        super().__init__(upsample)
+
+
+class NoUpsampling(AbstractUpsampling):
+    def __init__(self):
+        super().__init__(self._no_upsampling)
+
+    @staticmethod
+    def _no_upsampling(x, size):
+        return x
diff --git a/build/lib/pytorch3dunet/unet3d/config.py b/build/lib/pytorch3dunet/unet3d/config.py
new file mode 100644
index 00000000..bb011632
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/config.py
@@ -0,0 +1,79 @@
+import argparse
+import os
+import shutil
+
+import torch
+import yaml
+
+from pytorch3dunet.unet3d import utils
+
+logger = utils.get_logger('ConfigLoader')
+
+
+def _override_config(args, config):
+    """Overrides config params with the ones given in command line."""
+
+    args_dict = vars(args)
+    # remove the first argument which is the config file path
+    args_dict.pop('config')
+
+    for key, value in args_dict.items():
+        if value is None:
+            continue
+        c = config
+        for k in key.split('.'):
+            if k not in c:
+                raise ValueError(f'Invalid config key: {key}')
+            if isinstance(c[k], dict):
+                c = c[k]
+            else:
+                c[k] = value
+
+
+def load_config():
+    parser = argparse.ArgumentParser(description='UNet3D')
+    parser.add_argument('--config', type=str, help='Path to the YAML config file', required=True)
+    # add additional command line arguments for the prediction that override the ones in the config file
+    parser.add_argument('--model_path', type=str, required=False)
+    parser.add_argument('--loaders.output_dir', type=str, required=False)
+    parser.add_argument('--loaders.test.file_paths', type=str, nargs="+", required=False)
+    parser.add_argument('--loaders.test.slice_builder.patch_shape', type=int, nargs="+", required=False)
+    parser.add_argument('--loaders.test.slice_builder.stride_shape', type=int, nargs="+", required=False)
+
+    args = parser.parse_args()
+    config_path = args.config
+    with open(config_path, 'r') as f:
+        config = yaml.safe_load(f)
+    _override_config(args, config)
+
+    device = config.get('device', None)
+    if device == 'cpu':
+        logger.warning('CPU mode forced in config, this will likely result in slow training/prediction')
+        config['device'] = 'cpu'
+        return config, config_path
+
+    if torch.cuda.is_available():
+        config['device'] = 'cuda'
+    else:
+        logger.warning('CUDA not available, using CPU')
+        config['device'] = 'cpu'
+    return config, config_path
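+
+# Illustrative invocation overriding config values from the command line
+# (all paths are placeholders):
+#   python predict.py --config test_config.yml \
+#       --model_path /path/to/best_checkpoint.pytorch \
+#       --loaders.test.file_paths /path/to/test.h5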
+
+
+def copy_config(config, config_path):
+    """Copies the config file to the checkpoint folder."""
+
+    def _get_last_subfolder_path(path):
+        subfolders = [f.path for f in os.scandir(path) if f.is_dir()]
+        return max(subfolders, default=None)
+
+    checkpoint_dir = os.path.join(
+        config['trainer'].pop('checkpoint_dir'), 'logs')
+    last_run_dir = _get_last_subfolder_path(checkpoint_dir)
+    config_file_name = os.path.basename(config_path)
+
+    if last_run_dir:
+        shutil.copy2(config_path, os.path.join(last_run_dir, config_file_name))
+
+
+def _load_config_yaml(config_file):
+    with open(config_file, 'r') as f:
+        return yaml.safe_load(f)
diff --git a/build/lib/pytorch3dunet/unet3d/losses.py b/build/lib/pytorch3dunet/unet3d/losses.py
new file mode 100644
index 00000000..6a53966f
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/losses.py
@@ -0,0 +1,345 @@
+import torch
+import torch.nn.functional as F
+from torch import nn as nn
+from torch.nn import MSELoss, SmoothL1Loss, L1Loss
+
+
+def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
+    """
+    Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given  a multi channel input and target.
+    Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function.
+
+    Args:
+         input (torch.Tensor): NxCxSpatial input tensor
+         target (torch.Tensor): NxCxSpatial target tensor
+         epsilon (float): prevents division by zero
+         weight (torch.Tensor): Cx1 tensor of weight per channel/class
+    """
+
+    # input and target shapes must match
+    assert input.size() == target.size(), "'input' and 'target' must have the same shape"
+
+    input = flatten(input)
+    target = flatten(target)
+    target = target.float()
+
+    # compute per channel Dice Coefficient
+    intersect = (input * target).sum(-1)
+    if weight is not None:
+        intersect = weight * intersect
+
+    # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1)
+    denominator = (input * input).sum(-1) + (target * target).sum(-1)
+    return 2 * (intersect / denominator.clamp(min=epsilon))
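+
+# Minimal numeric sketch (hypothetical tensors): a perfect match scores 1 per channel:
+#   x = torch.ones(1, 1, 2, 2)   # NxCxSpatial
+#   compute_per_channel_dice(x, x) -> tensor([1.])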
+
+
+class _MaskingLossWrapper(nn.Module):
+    """
+    Loss wrapper which prevents the gradient of the loss to be computed where target is equal to `ignore_index`.
+    """
+
+    def __init__(self, loss, ignore_index):
+        super(_MaskingLossWrapper, self).__init__()
+        assert ignore_index is not None, 'ignore_index cannot be None'
+        self.loss = loss
+        self.ignore_index = ignore_index
+
+    def forward(self, input, target):
+        mask = target.clone().ne_(self.ignore_index)
+        mask.requires_grad = False
+
+        # mask out input/target so that the gradient is zero where target == ignore_index
+        input = input * mask
+        target = target * mask
+
+        # forward masked input and target to the loss
+        return self.loss(input, target)
+
+
+class SkipLastTargetChannelWrapper(nn.Module):
+    """
+    Loss wrapper which removes additional target channel
+    """
+
+    def __init__(self, loss, squeeze_channel=False):
+        super(SkipLastTargetChannelWrapper, self).__init__()
+        self.loss = loss
+        self.squeeze_channel = squeeze_channel
+
+    def forward(self, input, target, weight=None):
+        assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel'
+
+        # skips last target channel if needed
+        target = target[:, :-1, ...]
+
+        if self.squeeze_channel:
+            # squeeze channel dimension
+            target = torch.squeeze(target, dim=1)
+        if weight is not None:
+            return self.loss(input, target, weight)
+        return self.loss(input, target)
+
+
+class _AbstractDiceLoss(nn.Module):
+    """
+    Base class for different implementations of Dice loss.
+    """
+
+    def __init__(self, weight=None, normalization='sigmoid'):
+        super(_AbstractDiceLoss, self).__init__()
+        self.register_buffer('weight', weight)
+        # The output from the network during training is assumed to be un-normalized probabilities and we would
+        # like to normalize the logits. Since Dice (or soft Dice in this case) is usually used for binary data,
+        # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems.
+        # However if one would like to apply Softmax in order to get the proper probability distribution from the
+        # output, just specify `normalization='softmax'`
+        assert normalization in ['sigmoid', 'softmax', 'none']
+        if normalization == 'sigmoid':
+            self.normalization = nn.Sigmoid()
+        elif normalization == 'softmax':
+            self.normalization = nn.Softmax(dim=1)
+        else:
+            self.normalization = lambda x: x
+
+    def dice(self, input, target, weight):
+        # actual Dice score computation; to be implemented by the subclass
+        raise NotImplementedError
+
+    def forward(self, input, target):
+        # get probabilities from logits
+        input = self.normalization(input)
+
+        # compute per channel Dice coefficient
+        per_channel_dice = self.dice(input, target, weight=self.weight)
+
+        # average Dice score across all channels/classes
+        return 1. - torch.mean(per_channel_dice)
+
+
+class DiceLoss(_AbstractDiceLoss):
+    """Computes Dice Loss according to https://arxiv.org/abs/1606.04797.
+    For multi-class segmentation `weight` parameter can be used to assign different weights per class.
+    The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function.
+    """
+
+    def __init__(self, weight=None, normalization='sigmoid'):
+        super().__init__(weight, normalization)
+
+    def dice(self, input, target, weight):
+        return compute_per_channel_dice(input, target, weight=self.weight)
+
+
+class GeneralizedDiceLoss(_AbstractDiceLoss):
+    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.
+    """
+
+    def __init__(self, normalization='sigmoid', epsilon=1e-6):
+        super().__init__(weight=None, normalization=normalization)
+        self.epsilon = epsilon
+
+    def dice(self, input, target, weight):
+        assert input.size() == target.size(), "'input' and 'target' must have the same shape"
+
+        input = flatten(input)
+        target = flatten(target)
+        target = target.float()
+
+        if input.size(0) == 1:
+            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
+            # put foreground and background voxels in separate channels
+            input = torch.cat((input, 1 - input), dim=0)
+            target = torch.cat((target, 1 - target), dim=0)
+
+        # GDL weighting: the contribution of each label is corrected by the inverse of its volume
+        w_l = target.sum(-1)
+        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
+        w_l.requires_grad = False
+
+        intersect = (input * target).sum(-1)
+        intersect = intersect * w_l
+
+        denominator = (input + target).sum(-1)
+        denominator = (denominator * w_l).clamp(min=self.epsilon)
+
+        return 2 * (intersect.sum() / denominator.sum())
+
+
+class BCEDiceLoss(nn.Module):
+    """Linear combination of BCE and Dice losses"""
+
+    def __init__(self, alpha, beta):
+        super(BCEDiceLoss, self).__init__()
+        self.alpha = alpha
+        self.bce = nn.BCEWithLogitsLoss()
+        self.beta = beta
+        self.dice = DiceLoss()
+
+    def forward(self, input, target):
+        return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target)
+
+
+class WeightedCrossEntropyLoss(nn.Module):
+    """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
+    """
+
+    def __init__(self, ignore_index=-1):
+        super(WeightedCrossEntropyLoss, self).__init__()
+        self.ignore_index = ignore_index
+
+    def forward(self, input, target):
+        weight = self._class_weights(input)
+        return F.cross_entropy(input, target, weight=weight, ignore_index=self.ignore_index)
+
+    @staticmethod
+    def _class_weights(input):
+        # normalize the input first
+        input = F.softmax(input, dim=1)
+        flattened = flatten(input)
+        numerator = (1. - flattened).sum(-1)
+        denominator = flattened.sum(-1)
+        class_weights = numerator / denominator
+        return class_weights.detach()
+
+
+class PixelWiseCrossEntropyLoss(nn.Module):
+    def __init__(self, ignore_index=None):
+        super(PixelWiseCrossEntropyLoss, self).__init__()
+        self.ignore_index = ignore_index
+        self.log_softmax = nn.LogSoftmax(dim=1)
+
+    def forward(self, input, target, weights):
+        assert target.size() == weights.size()
+        # normalize the input
+        log_probabilities = self.log_softmax(input)
+        # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW)
+        if self.ignore_index is not None:
+            mask = target == self.ignore_index
+            target[mask] = 0
+        else:
+            mask = torch.zeros_like(target)
+        # add channel dimension and invert the mask
+        mask = 1 - mask.unsqueeze(1)
+        # convert target to one-hot encoding; take the number of classes from the input
+        # so the encoding does not depend on which labels happen to be present
+        target = F.one_hot(target.long(), num_classes=input.size(1))
+        if target.ndim == 5:
+            # permute target to (NxCxDxHxW)
+            target = target.permute(0, 4, 1, 2, 3).contiguous()
+        else:
+            target = target.permute(0, 3, 1, 2).contiguous()
+        # apply the mask on the target
+        target = target * mask
+        # add channel dimension to the weights
+        weights = weights.unsqueeze(1)
+        # compute the losses
+        result = -weights * target * log_probabilities
+        return result.mean()
+
+
+class WeightedSmoothL1Loss(nn.SmoothL1Loss):
+    def __init__(self, threshold, initial_weight, apply_below_threshold=True):
+        super().__init__(reduction="none")
+        self.threshold = threshold
+        self.apply_below_threshold = apply_below_threshold
+        self.weight = initial_weight
+
+    def forward(self, input, target):
+        l1 = super().forward(input, target)
+
+        if self.apply_below_threshold:
+            mask = target < self.threshold
+        else:
+            mask = target >= self.threshold
+
+        l1[mask] = l1[mask] * self.weight
+
+        return l1.mean()
+
+
+def flatten(tensor):
+    """Flattens a given tensor such that the channel axis is first.
+    The shapes are transformed as follows:
+       (N, C, D, H, W) -> (C, N * D * H * W)
+    """
+    # number of channels
+    C = tensor.size(1)
+    # new axis order
+    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
+    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
+    transposed = tensor.permute(axis_order)
+    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
+    return transposed.contiguous().view(C, -1)
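+
+# Illustrative (hypothetical shape):
+#   flatten(torch.zeros(2, 3, 4, 5, 6)).shape -> torch.Size([3, 240])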
+
+
+def get_loss_criterion(config):
+    """
+    Returns the loss function based on provided configuration
+    :param config: (dict) a top level configuration object containing the 'loss' key
+    :return: an instance of the loss function
+    """
+    assert 'loss' in config, 'Could not find loss function configuration'
+    loss_config = config['loss']
+    name = loss_config.pop('name')
+
+    ignore_index = loss_config.pop('ignore_index', None)
+    skip_last_target = loss_config.pop('skip_last_target', False)
+    weight = loss_config.pop('weight', None)
+
+    if weight is not None:
+        weight = torch.tensor(weight)
+
+    pos_weight = loss_config.pop('pos_weight', None)
+    if pos_weight is not None:
+        pos_weight = torch.tensor(pos_weight)
+
+    loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight)
+
+    if not (ignore_index is None or name in ['CrossEntropyLoss', 'WeightedCrossEntropyLoss']):
+        # use MaskingLossWrapper only for non-cross-entropy losses, since CE losses allow specifying 'ignore_index' directly
+        loss = _MaskingLossWrapper(loss, ignore_index)
+
+    if skip_last_target:
+        loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False))
+
+    if torch.cuda.is_available():
+        loss = loss.cuda()
+
+    return loss
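+
+# Illustrative 'loss' section of the config consumed above (values are hypothetical):
+#   loss:
+#     name: BCEDiceLoss
+#     alpha: 1.0
+#     beta: 1.0
+#     skip_last_target: false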
+
+
+#######################################################################################################################
+
+def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
+    if name == 'BCEWithLogitsLoss':
+        return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
+    elif name == 'BCEDiceLoss':
+        alpha = loss_config.get('alpha', 1.)
+        beta = loss_config.get('beta', 1.)
+        return BCEDiceLoss(alpha, beta)
+    elif name == 'CrossEntropyLoss':
+        if ignore_index is None:
+            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
+        return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
+    elif name == 'WeightedCrossEntropyLoss':
+        if ignore_index is None:
+            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
+        return WeightedCrossEntropyLoss(ignore_index=ignore_index)
+    elif name == 'PixelWiseCrossEntropyLoss':
+        return PixelWiseCrossEntropyLoss(ignore_index=ignore_index)
+    elif name == 'GeneralizedDiceLoss':
+        normalization = loss_config.get('normalization', 'sigmoid')
+        return GeneralizedDiceLoss(normalization=normalization)
+    elif name == 'DiceLoss':
+        normalization = loss_config.get('normalization', 'sigmoid')
+        return DiceLoss(weight=weight, normalization=normalization)
+    elif name == 'MSELoss':
+        return MSELoss()
+    elif name == 'SmoothL1Loss':
+        return SmoothL1Loss()
+    elif name == 'L1Loss':
+        return L1Loss()
+    elif name == 'WeightedSmoothL1Loss':
+        return WeightedSmoothL1Loss(threshold=loss_config['threshold'],
+                                    initial_weight=loss_config['initial_weight'],
+                                    apply_below_threshold=loss_config.get('apply_below_threshold', True))
+    else:
+        raise RuntimeError(f"Unsupported loss function: '{name}'")
diff --git a/build/lib/pytorch3dunet/unet3d/metrics.py b/build/lib/pytorch3dunet/unet3d/metrics.py
new file mode 100644
index 00000000..2b60b4b7
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/metrics.py
@@ -0,0 +1,445 @@
+import importlib
+
+import numpy as np
+import torch
+from skimage import measure
+from skimage.metrics import adapted_rand_error, peak_signal_noise_ratio, mean_squared_error
+
+from pytorch3dunet.unet3d.losses import compute_per_channel_dice
+from pytorch3dunet.unet3d.seg_metrics import AveragePrecision, Accuracy
+from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot, convert_to_numpy
+
+logger = get_logger('EvalMetric')
+
+
+class DiceCoefficient:
+    """Computes Dice Coefficient.
+    Generalized to multiple channels by computing per-channel Dice Score
+    (as described in https://arxiv.org/pdf/1707.03237.pdf) and then simply taking the average.
+    Input is expected to be probabilities instead of logits.
+    This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets).
+    DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss.
+    """
+
+    def __init__(self, epsilon=1e-6, **kwargs):
+        self.epsilon = epsilon
+
+    def __call__(self, input, target):
+        # Average across channels in order to get the final score
+        return torch.mean(compute_per_channel_dice(input, target, epsilon=self.epsilon))
+
+
+class MeanIoU:
+    """
+    Computes IoU for each class separately and then averages over all classes.
+    """
+
+    def __init__(self, skip_channels=(), ignore_index=None, **kwargs):
+        """
+        :param skip_channels: list/tuple of channels to be ignored from the IoU computation
+        :param ignore_index: id of the label to be ignored from IoU computation
+        """
+        self.ignore_index = ignore_index
+        self.skip_channels = skip_channels
+
+    def __call__(self, input, target):
+        """
+        :param input: 5D probability maps torch float tensor (NxCxDxHxW)
+        :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
+        :return: intersection over union averaged over all channels
+        """
+        assert input.dim() == 5
+
+        n_classes = input.size()[1]
+
+        if target.dim() == 4:
+            target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)
+
+        assert input.size() == target.size()
+
+        per_batch_iou = []
+        for _input, _target in zip(input, target):
+            binary_prediction = self._binarize_predictions(_input, n_classes)
+
+            if self.ignore_index is not None:
+                # zero out ignore_index
+                mask = _target == self.ignore_index
+                binary_prediction[mask] = 0
+                _target[mask] = 0
+
+            # convert to uint8 just in case
+            binary_prediction = binary_prediction.byte()
+            _target = _target.byte()
+
+            per_channel_iou = []
+            for c in range(n_classes):
+                if c in self.skip_channels:
+                    continue
+
+                per_channel_iou.append(self._jaccard_index(binary_prediction[c], _target[c]))
+
+            assert per_channel_iou, "All channels were ignored from the computation"
+            mean_iou = torch.mean(torch.tensor(per_channel_iou))
+            per_batch_iou.append(mean_iou)
+
+        return torch.mean(torch.tensor(per_batch_iou))
+
+    def _binarize_predictions(self, input, n_classes):
+        """
+        Puts 1 for the class/channel with the highest probability and 0 in other channels. Returns byte tensor of the
+        same size as the input tensor.
+        """
+        if n_classes == 1:
+            # for single channel input just threshold the probability map
+            result = input > 0.5
+            return result.long()
+
+        _, max_index = torch.max(input, dim=0, keepdim=True)
+        return torch.zeros_like(input, dtype=torch.uint8).scatter_(0, max_index, 1)
+
+    def _jaccard_index(self, prediction, target):
+        """
+        Computes IoU for a given target and prediction tensors
+        """
+        return torch.sum(prediction & target).float() / torch.clamp(torch.sum(prediction | target).float(), min=1e-8)
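+
+# Usage sketch (illustrative): a 4D integer label volume is one-hot encoded
+# internally, so the metric can be called directly with class labels:
+#
+#   iou = MeanIoU()
+#   pred = torch.softmax(torch.randn(1, 3, 8, 16, 16), dim=1)  # NxCxDxHxW
+#   labels = torch.randint(0, 3, (1, 8, 16, 16))               # NxDxHxW
+#   score = iou(pred, labels)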
+
+
+class AdaptedRandError:
+    """
+    A functor which computes an Adapted Rand error as defined by the SNEMI3D contest
+    (http://brainiac2.mit.edu/SNEMI3D/evaluation).
+
+    This is a generic implementation which takes the input, converts it to the segmentation image (see `input_to_segm()`)
+    and then computes the ARand between the segmentation and the ground truth target. Depending on one's use case
+    it's usually enough to extend this class and implement the `input_to_segm` method.
+
+    Args:
+        use_last_target (bool): if true, use the last channel from the target to compute the ARand, otherwise the first.
+    """
+
+    def __init__(self, use_last_target=False, ignore_index=None, **kwargs):
+        self.use_last_target = use_last_target
+        self.ignore_index = ignore_index
+
+    def __call__(self, input, target):
+        """
+        Compute ARand Error for each input, target pair in the batch and return the mean value.
+
+        Args:
+            input (torch.tensor):  5D (NCDHW) output from the network
+            target (torch.tensor): 5D (NCDHW) ground truth segmentation
+
+        Returns:
+            average ARand Error across the batch
+        """
+
+        # converts input and target to numpy arrays
+        input, target = convert_to_numpy(input, target)
+        if self.use_last_target:
+            target = target[:, -1, ...]  # 4D
+        else:
+            # use 1st target channel
+            target = target[:, 0, ...]  # 4D
+
+        # ensure target is of integer type
+        target = target.astype(np.int32)
+
+        if self.ignore_index is not None:
+            target[target == self.ignore_index] = 0
+
+        per_batch_arand = []
+        for _input, _target in zip(input, target):
+            if np.all(_target == _target.flat[0]):  # skip ARand eval if there is only one label in the patch due to zero-division
+                logger.info('Skipping ARandError computation: only 1 label present in the ground truth')
+                per_batch_arand.append(0.)
+                continue
+
+            # convert _input to segmentation CDHW
+            segm = self.input_to_segm(_input)
+            assert segm.ndim == 4
+
+            # compute per channel arand and return the minimum value
+            per_channel_arand = [adapted_rand_error(_target, channel_segm)[0] for channel_segm in segm]
+            per_batch_arand.append(np.min(per_channel_arand))
+
+        # return mean arand error
+        mean_arand = torch.mean(torch.tensor(per_batch_arand))
+        logger.info(f'ARand: {mean_arand.item()}')
+        return mean_arand
+
+    def input_to_segm(self, input):
+        """
+        Converts input tensor (output from the network) to the segmentation image. E.g. if the input is the boundary
+        pmaps then one option would be to threshold it and run connected components in order to return the segmentation.
+
+        :param input: 4D tensor (CDHW)
+        :return: 4D segmentation volume (one segmentation per channel)
+        """
+        # by default assume that input is a segmentation volume itself
+        return input
+
+
+class BoundaryAdaptedRandError(AdaptedRandError):
+    """
+    Compute ARand between the input boundary map and target segmentation.
+    The boundary map is thresholded and connected component labeling is run to get the predicted segmentation.
+    """
+
+    def __init__(self, thresholds=None, use_last_target=True, ignore_index=None, input_channel=None, invert_pmaps=True,
+                 save_plots=False, plots_dir='.', **kwargs):
+        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, save_plots=save_plots,
+                         plots_dir=plots_dir, **kwargs)
+
+        if thresholds is None:
+            thresholds = [0.3, 0.4, 0.5, 0.6]
+        assert isinstance(thresholds, list)
+        self.thresholds = thresholds
+        self.input_channel = input_channel
+        self.invert_pmaps = invert_pmaps
+
+    def input_to_segm(self, input):
+        if self.input_channel is not None:
+            input = np.expand_dims(input[self.input_channel], axis=0)
+
+        segs = []
+        for predictions in input:
+            for th in self.thresholds:
+                # threshold the probability map; use a separate name so the original
+                # probabilities are preserved for the next threshold in the loop
+                binary_pred = predictions > th
+
+                if self.invert_pmaps:
+                    # for connected component analysis we need to treat boundary signal as background
+                    # assign 0-label to boundary mask
+                    binary_pred = np.logical_not(binary_pred)
+
+                binary_pred = binary_pred.astype(np.uint8)
+                # run connected components on the predicted mask; consider only 1-connectivity
+                seg = measure.label(binary_pred, background=0, connectivity=1)
+                segs.append(seg)
+
+        return np.stack(segs)
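+
+# Processing sketch: each threshold yields one candidate segmentation, so for a
+# single-channel boundary map and thresholds [0.3, 0.4, 0.5, 0.6] the result
+# stacks 4 candidates; __call__ then keeps the minimum ARand across them:
+#
+#   pmap > th  ->  logical_not (boundary becomes background)  ->  measure.label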
+
+
+class GenericAdaptedRandError(AdaptedRandError):
+    def __init__(self, input_channels, thresholds=None, use_last_target=True, ignore_index=None, invert_channels=None,
+                 **kwargs):
+
+        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, **kwargs)
+        assert isinstance(input_channels, list) or isinstance(input_channels, tuple)
+        self.input_channels = input_channels
+        if thresholds is None:
+            thresholds = [0.3, 0.4, 0.5, 0.6]
+        assert isinstance(thresholds, list)
+        self.thresholds = thresholds
+        if invert_channels is None:
+            invert_channels = []
+        self.invert_channels = invert_channels
+
+    def input_to_segm(self, input):
+        # pick only the channels specified in the input_channels
+        results = []
+        for i in self.input_channels:
+            c = input[i]
+            # invert channel if necessary
+            if i in self.invert_channels:
+                c = 1 - c
+            results.append(c)
+
+        input = np.stack(results)
+
+        segs = []
+        for predictions in input:
+            for th in self.thresholds:
+                # run connected components on the predicted mask; consider only 1-connectivity
+                seg = measure.label((predictions > th).astype(np.uint8), background=0, connectivity=1)
+                segs.append(seg)
+
+        return np.stack(segs)
+
+
+class GenericAveragePrecision:
+    def __init__(self, min_instance_size=None, use_last_target=False, metric='ap', **kwargs):
+        self.min_instance_size = min_instance_size
+        self.use_last_target = use_last_target
+        assert metric in ['ap', 'acc']
+        if metric == 'ap':
+            # use AveragePrecision
+            self.metric = AveragePrecision()
+        else:
+            # use Accuracy at 0.5 IoU
+            self.metric = Accuracy(iou_threshold=0.5)
+
+    def __call__(self, input, target):
+        if target.dim() == 5:
+            if self.use_last_target:
+                target = target[:, -1, ...]  # 4D
+            else:
+                # use 1st target channel
+                target = target[:, 0, ...]  # 4D
+
+        input1 = input2 = input
+        multi_head = isinstance(input, tuple)
+        if multi_head:
+            input1, input2 = input
+
+        input1, input2, target = convert_to_numpy(input1, input2, target)
+
+        batch_aps = []
+        # iterate over the batch
+        for i_batch, (inp1, inp2, tar) in enumerate(zip(input1, input2, target)):
+            if multi_head:
+                inp = (inp1, inp2)
+            else:
+                inp = inp1
+
+            segs = self.input_to_seg(inp, tar)  # expects 4D
+            assert segs.ndim == 4
+            # convert target to seg
+            tar = self.target_to_seg(tar)
+
+            # filter small instances if necessary
+            tar = self._filter_instances(tar)
+
+            # compute average precision per channel
+            segs_aps = [self.metric(self._filter_instances(seg), tar) for seg in segs]
+
+            logger.info(f'Batch: {i_batch}. Max Average Precision for channel: {np.argmax(segs_aps)}')
+            # save max AP
+            batch_aps.append(np.max(segs_aps))
+
+        return torch.tensor(batch_aps).mean()
+
+    def _filter_instances(self, input):
+        """
+        Filters instances smaller than 'min_instance_size' by overriding them with 0-index
+        :param input: input instance segmentation
+        """
+        if self.min_instance_size is not None:
+            labels, counts = np.unique(input, return_counts=True)
+            for label, count in zip(labels, counts):
+                if count < self.min_instance_size:
+                    input[input == label] = 0
+        return input
+
+    def input_to_seg(self, input, target=None):
+        raise NotImplementedError
+
+    def target_to_seg(self, target):
+        return target
+
+
+class BlobsAveragePrecision(GenericAveragePrecision):
+    """
+    Computes Average Precision given foreground prediction and ground truth instance segmentation.
+    """
+
+    def __init__(self, thresholds=None, metric='ap', min_instance_size=None, input_channel=0, **kwargs):
+        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
+        if thresholds is None:
+            thresholds = [0.4, 0.5, 0.6, 0.7, 0.8]
+        assert isinstance(thresholds, list)
+        self.thresholds = thresholds
+        self.input_channel = input_channel
+
+    def input_to_seg(self, input, target=None):
+        input = input[self.input_channel]
+        segs = []
+        for th in self.thresholds:
+            # threshold and run connected components
+            mask = (input > th).astype(np.uint8)
+            seg = measure.label(mask, background=0, connectivity=1)
+            segs.append(seg)
+        return np.stack(segs)
+
+
+class BlobsBoundaryAveragePrecision(GenericAveragePrecision):
+    """
+    Computes Average Precision given foreground prediction, boundary prediction and ground truth instance segmentation.
+    Segmentation mask is computed as (P_mask - P_boundary) > th, followed by connected component labeling.
+    """
+
+    def __init__(self, thresholds=None, metric='ap', min_instance_size=None, **kwargs):
+        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
+        if thresholds is None:
+            thresholds = [0.3, 0.4, 0.5, 0.6, 0.7]
+        assert isinstance(thresholds, list)
+        self.thresholds = thresholds
+
+    def input_to_seg(self, input, target=None):
+        # input = P_mask - P_boundary
+        input = input[0] - input[1]
+        segs = []
+        for th in self.thresholds:
+            # threshold and run connected components
+            mask = (input > th).astype(np.uint8)
+            seg = measure.label(mask, background=0, connectivity=1)
+            segs.append(seg)
+        return np.stack(segs)
+
+
+class BoundaryAveragePrecision(GenericAveragePrecision):
+    """
+    Computes Average Precision given boundary prediction and ground truth instance segmentation.
+    """
+
+    def __init__(self, thresholds=None, min_instance_size=None, input_channel=0, **kwargs):
+        super().__init__(min_instance_size=min_instance_size, use_last_target=True)
+        if thresholds is None:
+            thresholds = [0.3, 0.4, 0.5, 0.6]
+        assert isinstance(thresholds, list)
+        self.thresholds = thresholds
+        self.input_channel = input_channel
+
+    def input_to_seg(self, input, target=None):
+        input = input[self.input_channel]
+        segs = []
+        for th in self.thresholds:
+            seg = measure.label(np.logical_not(input > th).astype(np.uint8), background=0, connectivity=1)
+            segs.append(seg)
+        return np.stack(segs)
+
+
+class PSNR:
+    """
+    Computes Peak Signal to Noise Ratio. Useful e.g. as an eval metric for a denoising task.
+    """
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, input, target):
+        input, target = convert_to_numpy(input, target)
+        return peak_signal_noise_ratio(target, input)
+
+
+class MSE:
+    """
+    Computes MSE between input and target
+    """
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, input, target):
+        input, target = convert_to_numpy(input, target)
+        return mean_squared_error(input, target)
+
+
+def get_evaluation_metric(config):
+    """
+    Returns the evaluation metric function based on provided configuration
+    :param config: (dict) a top level configuration object containing the 'eval_metric' key
+    :return: an instance of the evaluation metric
+    """
+
+    def _metric_class(class_name):
+        m = importlib.import_module('pytorch3dunet.unet3d.metrics')
+        clazz = getattr(m, class_name)
+        return clazz
+
+    assert 'eval_metric' in config, 'Could not find evaluation metric configuration'
+    metric_config = config['eval_metric']
+    metric_class = _metric_class(metric_config['name'])
+    return metric_class(**metric_config)
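+
+# Configuration sketch (hypothetical values): the class named under
+# 'eval_metric' is instantiated with the remaining keys of that section;
+# extra keys such as 'name' itself are absorbed by the metrics' **kwargs:
+#
+#   config = {'eval_metric': {'name': 'MeanIoU', 'ignore_index': -1}}
+#   metric = get_evaluation_metric(config)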
diff --git a/build/lib/pytorch3dunet/unet3d/model.py b/build/lib/pytorch3dunet/unet3d/model.py
new file mode 100644
index 00000000..e4de49a7
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/model.py
@@ -0,0 +1,249 @@
+import torch.nn as nn
+
+from pytorch3dunet.unet3d.buildingblocks import DoubleConv, ResNetBlock, ResNetBlockSE, \
+    create_decoders, create_encoders
+from pytorch3dunet.unet3d.utils import get_class, number_of_features_per_level
+
+
+class AbstractUNet(nn.Module):
+    """
+    Base class for standard and residual UNet.
+
+    Args:
+        in_channels (int): number of input channels
+        out_channels (int): number of output segmentation masks;
+            Note that out_channels might correspond to either
+            different semantic classes or to different binary segmentation masks.
+            It's up to the user of the class to interpret the out_channels and
+            use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class)
+            or BCEWithLogitsLoss (two-class) respectively)
+        f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number
+            of feature maps is given by the geometric progression: f_maps * 2^k, k=0,1,...,num_levels-1
+        final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the final 1x1 convolution,
+            otherwise apply nn.Softmax. In effect only if `self.training == False`, i.e. during validation/testing
+        basic_module: basic building block for the encoder/decoder (DoubleConv, ResNetBlock, ....)
+        layer_order (string): determines the order of layers in `SingleConv` module.
+            E.g. 'gcr' stands for GroupNorm3d+Conv3d+ReLU. See `SingleConv` for more info
+        num_groups (int): number of groups for the GroupNorm
+        num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int)
+            default: 4
+        is_segmentation (bool): if True and the model is in eval mode, Sigmoid/Softmax normalization is applied
+            after the final convolution; if False (regression problem) the normalization layer is skipped
+        conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module
+        pool_kernel_size (int or tuple): the size of the pooling window
+        conv_padding (int or tuple): zero-padding added to all three sides of the input
+        conv_upscale (int): upscale factor for the number of channels inside the encoder's DoubleConv, default: 2
+        upsample (str): algorithm used for decoder upsampling:
+            InterpolateUpsampling:   'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
+            TransposeConvUpsampling: 'deconv'
+            No upsampling:           None
+            Default: 'default' (chooses automatically)
+        dropout_prob (float or tuple): dropout probability, default: 0.1
+        is3d (bool): if True the model is 3D, otherwise 2D, default: True
+    """
+
+    def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
+                 num_groups=8, num_levels=4, is_segmentation=True, conv_kernel_size=3, pool_kernel_size=2,
+                 conv_padding=1, conv_upscale=2, upsample='default', dropout_prob=0.1, is3d=True):
+        super(AbstractUNet, self).__init__()
+
+        if isinstance(f_maps, int):
+            f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)
+
+        assert isinstance(f_maps, list) or isinstance(f_maps, tuple)
+        assert len(f_maps) > 1, "Required at least 2 levels in the U-Net"
+        if 'g' in layer_order:
+            assert num_groups is not None, "num_groups must be specified if GroupNorm is used"
+
+        # create encoder path
+        self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size,
+                                        conv_padding, conv_upscale, dropout_prob,
+                                        layer_order, num_groups, pool_kernel_size, is3d)
+
+        # create decoder path
+        self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding,
+                                        layer_order, num_groups, upsample, dropout_prob,
+                                        is3d)
+
+        # in the last layer a 1×1 convolution reduces the number of output channels to the number of labels
+        if is3d:
+            self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
+        else:
+            self.final_conv = nn.Conv2d(f_maps[0], out_channels, 1)
+
+        if is_segmentation:
+            # semantic segmentation problem
+            if final_sigmoid:
+                self.final_activation = nn.Sigmoid()
+            else:
+                self.final_activation = nn.Softmax(dim=1)
+        else:
+            # regression problem
+            self.final_activation = None
+
+    def forward(self, x):
+        # encoder part
+        encoders_features = []
+        for encoder in self.encoders:
+            x = encoder(x)
+            # reverse the encoder outputs to be aligned with the decoder
+            encoders_features.insert(0, x)
+
+        # remove the last encoder's output from the list
+        # !!remember: it's the 1st in the list
+        encoders_features = encoders_features[1:]
+
+        # decoder part
+        for decoder, encoder_features in zip(self.decoders, encoders_features):
+            # pass the output from the corresponding encoder and the output
+            # of the previous decoder
+            x = decoder(encoder_features, x)
+
+        x = self.final_conv(x)
+
+        # apply final_activation (i.e. Sigmoid or Softmax) only during prediction.
+        # During training the network outputs logits
+        if not self.training and self.final_activation is not None:
+            x = self.final_activation(x)
+
+        return x
+
+
+class UNet3D(AbstractUNet):
+    """
+    3DUnet model from
+    `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
+        <https://arxiv.org/pdf/1606.06650.pdf>`.
+
+    Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
+    """
+
+    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
+                 num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1,
+                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
+        super(UNet3D, self).__init__(in_channels=in_channels,
+                                     out_channels=out_channels,
+                                     final_sigmoid=final_sigmoid,
+                                     basic_module=DoubleConv,
+                                     f_maps=f_maps,
+                                     layer_order=layer_order,
+                                     num_groups=num_groups,
+                                     num_levels=num_levels,
+                                     is_segmentation=is_segmentation,
+                                     conv_padding=conv_padding,
+                                     conv_upscale=conv_upscale,
+                                     upsample=upsample,
+                                     dropout_prob=dropout_prob,
+                                     is3d=True)
+
+
+class ResidualUNet3D(AbstractUNet):
+    """
+    Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
+    Uses ResNetBlock as a basic building block, summation joining instead
+    of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
+    Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
+    """
+
+    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
+                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
+                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
+        super(ResidualUNet3D, self).__init__(in_channels=in_channels,
+                                             out_channels=out_channels,
+                                             final_sigmoid=final_sigmoid,
+                                             basic_module=ResNetBlock,
+                                             f_maps=f_maps,
+                                             layer_order=layer_order,
+                                             num_groups=num_groups,
+                                             num_levels=num_levels,
+                                             is_segmentation=is_segmentation,
+                                             conv_padding=conv_padding,
+                                             conv_upscale=conv_upscale,
+                                             upsample=upsample,
+                                             dropout_prob=dropout_prob,
+                                             is3d=True)
+
+
+class ResidualUNetSE3D(AbstractUNet):
+    """_summary_
+    Residual 3DUnet model implementation with squeeze and excitation based on 
+    https://arxiv.org/pdf/1706.00120.pdf.
+    Uses ResNetBlockSE as a basic building block, summation joining instead
+    of concatenation joining and transposed convolutions for upsampling (watch
+    out for block artifacts). Since the model effectively becomes a residual
+    net, in theory it allows for deeper UNet.
+    """
+
+    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
+                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
+                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
+        super(ResidualUNetSE3D, self).__init__(in_channels=in_channels,
+                                               out_channels=out_channels,
+                                               final_sigmoid=final_sigmoid,
+                                               basic_module=ResNetBlockSE,
+                                               f_maps=f_maps,
+                                               layer_order=layer_order,
+                                               num_groups=num_groups,
+                                               num_levels=num_levels,
+                                               is_segmentation=is_segmentation,
+                                               conv_padding=conv_padding,
+                                               conv_upscale=conv_upscale,
+                                               upsample=upsample,
+                                               dropout_prob=dropout_prob,
+                                               is3d=True)
+
+
+class UNet2D(AbstractUNet):
+    """
+    2DUnet model from
+    `"U-Net: Convolutional Networks for Biomedical Image Segmentation" <https://arxiv.org/abs/1505.04597>`
+    """
+
+    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
+                 num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1,
+                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
+        super(UNet2D, self).__init__(in_channels=in_channels,
+                                     out_channels=out_channels,
+                                     final_sigmoid=final_sigmoid,
+                                     basic_module=DoubleConv,
+                                     f_maps=f_maps,
+                                     layer_order=layer_order,
+                                     num_groups=num_groups,
+                                     num_levels=num_levels,
+                                     is_segmentation=is_segmentation,
+                                     conv_padding=conv_padding,
+                                     conv_upscale=conv_upscale,
+                                     upsample=upsample,
+                                     dropout_prob=dropout_prob,
+                                     is3d=False)
+
+
+class ResidualUNet2D(AbstractUNet):
+    """
+    Residual 2DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
+    """
+
+    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
+                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
+                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
+        super(ResidualUNet2D, self).__init__(in_channels=in_channels,
+                                             out_channels=out_channels,
+                                             final_sigmoid=final_sigmoid,
+                                             basic_module=ResNetBlock,
+                                             f_maps=f_maps,
+                                             layer_order=layer_order,
+                                             num_groups=num_groups,
+                                             num_levels=num_levels,
+                                             is_segmentation=is_segmentation,
+                                             conv_padding=conv_padding,
+                                             conv_upscale=conv_upscale,
+                                             upsample=upsample,
+                                             dropout_prob=dropout_prob,
+                                             is3d=False)
+
+
+def get_model(model_config):
+    model_class = get_class(model_config['name'], modules=[
+        'pytorch3dunet.unet3d.model'
+    ])
+    return model_class(**model_config)
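+
+# Configuration sketch (hypothetical values): the model class is looked up by
+# name in pytorch3dunet.unet3d.model and receives the whole section as kwargs
+# ('name' itself is absorbed by the models' **kwargs):
+#
+#   model_config = {'name': 'UNet3D', 'in_channels': 1, 'out_channels': 2,
+#                   'f_maps': 32, 'final_sigmoid': False}
+#   model = get_model(model_config)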
diff --git a/build/lib/pytorch3dunet/unet3d/predictor.py b/build/lib/pytorch3dunet/unet3d/predictor.py
new file mode 100644
index 00000000..c9b4f6eb
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/predictor.py
@@ -0,0 +1,281 @@
+import os
+import time
+from concurrent import futures
+from pathlib import Path
+
+import h5py
+import numpy as np
+import torch
+from skimage import measure
+from torch import nn
+from tqdm import tqdm
+
+from pytorch3dunet.datasets.hdf5 import AbstractHDF5Dataset
+from pytorch3dunet.datasets.utils import SliceBuilder, remove_padding
+from pytorch3dunet.unet3d.model import UNet2D
+from pytorch3dunet.unet3d.utils import get_logger
+
+logger = get_logger('UNetPredictor')
+
+
+def _get_output_file(dataset, suffix='_predictions', output_dir=None):
+    input_dir, file_name = os.path.split(dataset.file_path)
+    if output_dir is None:
+        output_dir = input_dir
+    output_filename = os.path.splitext(file_name)[0] + suffix + '.h5'
+    return Path(output_dir) / output_filename
+
+
+def _is_2d_model(model):
+    if isinstance(model, nn.DataParallel):
+        model = model.module
+    return isinstance(model, UNet2D)
+
+
+class _AbstractPredictor:
+    def __init__(self,
+                 model: nn.Module,
+                 output_dir: str,
+                 out_channels: int,
+                 output_dataset: str = 'predictions',
+                 save_segmentation: bool = False,
+                 prediction_channel: int = None,
+                 **kwargs):
+        """
+        Base class for predictors.
+        Args:
+            model: segmentation model
+            output_dir: directory where the predictions will be saved
+            out_channels: number of output channels of the model
+            output_dataset: name of the dataset in the H5 file where the predictions will be saved
+            save_segmentation: if true the segmentation will be saved instead of the probability maps
+            prediction_channel: save only the specified channel from the network output
+        """
+        self.model = model
+        self.output_dir = output_dir
+        self.out_channels = out_channels
+        self.output_dataset = output_dataset
+        self.save_segmentation = save_segmentation
+        self.prediction_channel = prediction_channel
+
+    def __call__(self, test_loader):
+        raise NotImplementedError
+
+
+class StandardPredictor(_AbstractPredictor):
+    """
+    Applies the model on the given dataset and saves the result as H5 file.
+    Predictions from the network are kept in memory. If the results from the network don't fit into RAM
+    use `LazyPredictor` instead.
+
+    The output dataset name inside the H5 file is given by the `output_dataset` config argument.
+    """
+
+    def __init__(self,
+                 model: nn.Module,
+                 output_dir: str,
+                 out_channels: int,
+                 output_dataset: str = 'predictions',
+                 save_segmentation: bool = False,
+                 prediction_channel: int = None,
+                 **kwargs):
+        super().__init__(model, output_dir, out_channels, output_dataset, save_segmentation, prediction_channel,
+                         **kwargs)
+
+    def __call__(self, test_loader):
+        assert isinstance(test_loader.dataset, AbstractHDF5Dataset)
+        logger.info(f"Processing '{test_loader.dataset.file_path}'...")
+        start = time.perf_counter()
+
+        logger.info(f'Running inference on {len(test_loader)} batches')
+        # dimensionality of the output predictions
+        volume_shape = test_loader.dataset.volume_shape()
+        if self.prediction_channel is not None:
+            # single channel prediction map
+            prediction_maps_shape = (1,) + volume_shape
+        else:
+            prediction_maps_shape = (self.out_channels,) + volume_shape
+
+        # create destination H5 file
+        output_file = _get_output_file(dataset=test_loader.dataset, output_dir=self.output_dir)
+        with h5py.File(output_file, 'w') as h5_output_file:
+            # allocate prediction and normalization arrays
+            logger.info('Allocating prediction and normalization arrays...')
+            prediction_map, normalization_mask = self._allocate_prediction_maps(prediction_maps_shape, h5_output_file)
+
+            # determine halo used for padding
+            patch_halo = test_loader.dataset.halo_shape
+
+            # Sets the module in evaluation mode explicitly
+            # It is necessary for batchnorm/dropout layers if present as well as final Sigmoid/Softmax to be applied
+            self.model.eval()
+            # Run predictions on the entire input dataset
+            with torch.no_grad():
+                for input, indices in tqdm(test_loader):
+                    # send batch to gpu
+                    if torch.cuda.is_available():
+                        input = input.pin_memory().cuda(non_blocking=True)
+
+                    if _is_2d_model(self.model):
+                        # remove the singleton z-dimension from the input
+                        input = torch.squeeze(input, dim=-3)
+                        # forward pass
+                        prediction = self.model(input)
+                        # add the singleton z-dimension to the output
+                        prediction = torch.unsqueeze(prediction, dim=-3)
+                    else:
+                        # forward pass
+                        prediction = self.model(input)
+
+                    # unpad the predicted patch
+                    prediction = remove_padding(prediction, patch_halo)
+                    # convert to numpy array
+                    prediction = prediction.cpu().numpy()
+                    # for each batch sample
+                    for pred, index in zip(prediction, indices):
+                        # save patch index: (C,D,H,W)
+                        if self.prediction_channel is None:
+                            channel_slice = slice(0, self.out_channels)
+                        else:
+                            # use only the specified channel
+                            channel_slice = slice(0, 1)
+                            pred = np.expand_dims(pred[self.prediction_channel], axis=0)
+
+                        # add channel dimension to the index
+                        index = (channel_slice,) + tuple(index)
+                        # accumulate probabilities into the output prediction array
+                        prediction_map[index] += pred
+                        # count voxel visits for normalization
+                        normalization_mask[index] += 1
+
+            logger.info(f'Finished inference in {time.perf_counter() - start:.2f} seconds')
+            # save results
+            output_type = 'segmentation' if self.save_segmentation else 'probability maps'
+            logger.info(f'Saving {output_type} to: {output_file}')
+            self._save_results(prediction_map, normalization_mask, h5_output_file, test_loader.dataset)
+
+    def _allocate_prediction_maps(self, output_shape, output_file):
+        # initialize the output prediction arrays
+        prediction_map = np.zeros(output_shape, dtype='float32')
+        # initialize normalization mask in order to average out probabilities of overlapping patches
+        normalization_mask = np.zeros(output_shape, dtype='uint8')
+        return prediction_map, normalization_mask
+
+    def _save_results(self, prediction_map, normalization_mask, output_file, dataset):
+        result = prediction_map / normalization_mask
+        if self.save_segmentation:
+            result = np.argmax(result, axis=0).astype('uint16')
+        output_file.create_dataset(self.output_dataset, data=result, compression="gzip")
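+
+# Normalization sketch: overlapping patches accumulate probabilities in
+# `prediction_map` while `normalization_mask` counts per-voxel visits, so a
+# voxel covered by two patches with probabilities 0.8 and 0.6 is stored as
+# (0.8 + 0.6) / 2 = 0.7 after the division above.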
+
+
+class LazyPredictor(StandardPredictor):
+    """
+        Applies the model on the given dataset and saves the result in the `output_file` in the H5 format.
+        Predicted patches are directly saved into the H5 and they won't be stored in memory. Since this predictor
+        is slower than the `StandardPredictor` it should only be used when the predicted volume does not fit into RAM.
+        """
+
+    def __init__(self,
+                 model: nn.Module,
+                 output_dir: str,
+                 out_channels: int,
+                 output_dataset: str = 'predictions',
+                 save_segmentation: bool = False,
+                 prediction_channel: int = None,
+                 **kwargs):
+        super().__init__(model, output_dir, out_channels, output_dataset, save_segmentation, prediction_channel,
+                         **kwargs)
+
+    def _allocate_prediction_maps(self, output_shape, output_file):
+        # allocate datasets for probability maps
+        prediction_map = output_file.create_dataset(self.output_dataset,
+                                                    shape=output_shape,
+                                                    dtype='float32',
+                                                    chunks=True,
+                                                    compression='gzip')
+        # allocate datasets for normalization masks
+        normalization_mask = output_file.create_dataset('normalization',
+                                                        shape=output_shape,
+                                                        dtype='uint8',
+                                                        chunks=True,
+                                                        compression='gzip')
+        return prediction_map, normalization_mask
+
+    def _save_results(self, prediction_map, normalization_mask, output_file, dataset):
+        z, y, x = prediction_map.shape[1:]
+        # take slices which are 1/27 of the original volume
+        patch_shape = (z // 3, y // 3, x // 3)
+        if self.save_segmentation:
+            output_file.create_dataset('segmentation', shape=(z, y, x), dtype='uint16', chunks=True, compression='gzip')
+
+        for index in SliceBuilder._build_slices(prediction_map, patch_shape=patch_shape, stride_shape=patch_shape):
+            logger.info(f'Normalizing slice: {index}')
+            prediction_map[index] /= normalization_mask[index]
+            # make sure to reset the slice that has been visited already in order to avoid 'double' normalization
+            # when the patches overlap with each other
+            normalization_mask[index] = 1
+            # save segmentation
+            if self.save_segmentation:
+                output_file['segmentation'][index[1:]] = np.argmax(prediction_map[index], axis=0).astype('uint16')
+
+        del output_file['normalization']
+        if self.save_segmentation:
+            del output_file[self.output_dataset]
+
+
+class DSB2018Predictor(_AbstractPredictor):
+    def __init__(self, model, output_dir, config, save_segmentation=True, pmaps_threshold=0.5, **kwargs):
+        super().__init__(model, output_dir, config, **kwargs)
+        self.pmaps_threshold = pmaps_threshold
+        self.save_segmentation = save_segmentation
+
+    def _slice_from_pad(self, pad):
+        if pad == 0:
+            return slice(None, None)
+        else:
+            return slice(pad, -pad)
+
+    def __call__(self, test_loader):
+        # Sets the module in evaluation mode explicitly
+        self.model.eval()
+        # initialize a process pool for saving results to disk
+        executor = futures.ProcessPoolExecutor(max_workers=32)
+        # Run predictions on the entire input dataset
+        with torch.no_grad():
+            for img, path in test_loader:
+                # send batch to gpu
+                if torch.cuda.is_available():
+                    img = img.cuda(non_blocking=True)
+                # forward pass
+                pred = self.model(img)
+
+                executor.submit(
+                    dsb_save_batch,
+                    self.output_dir,
+                    path,
+                    # move predictions to the CPU and convert to numpy so the
+                    # batch can be pickled and sent to the worker process
+                    pred.cpu().numpy()
+                )
+
+        print('Waiting for all predictions to be saved to disk...')
+        executor.shutdown(wait=True)
+
+
+def dsb_save_batch(output_dir, path, pred, save_segmentation=True, pmaps_threshold=0.5):
+    def _pmaps_to_seg(pred):
+        mask = (pred > pmaps_threshold)
+        return measure.label(mask).astype('uint16')
+
+    # iterate over the batch; `pred` is already a numpy array
+    for single_pred, single_path in zip(pred, path):
+        logger.info(f'Processing {single_path}')
+        single_pred = single_pred.squeeze()
+
+        # save to h5 file
+        out_file = os.path.splitext(single_path)[0] + '_predictions.h5'
+        if output_dir is not None:
+            out_file = os.path.join(output_dir, os.path.split(out_file)[1])
+
+        with h5py.File(out_file, 'w') as f:
+            # logger.info(f'Saving output to {out_file}')
+            f.create_dataset('predictions', data=single_pred, compression='gzip')
+            if save_segmentation:
+                f.create_dataset('segmentation', data=_pmaps_to_seg(single_pred), compression='gzip')
diff --git a/build/lib/pytorch3dunet/unet3d/se.py b/build/lib/pytorch3dunet/unet3d/se.py
new file mode 100644
index 00000000..23fac3d7
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/se.py
@@ -0,0 +1,113 @@
+"""
+3D Squeeze and Excitation Modules
+*********************************
+3D Extensions of the following 2D squeeze and excitation blocks:
+    1. `Channel Squeeze and Excitation <https://arxiv.org/abs/1709.01507>`_
+    2. `Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_
+    3. `Channel and Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_
+New Project & Excite block, designed specifically for 3D inputs
+    Coded by -- Anne-Marie Rickmann (https://github.com/arickm)
+"""
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+class ChannelSELayer3D(nn.Module):
+    """
+    3D extension of Squeeze-and-Excitation (SE) block described in:
+        *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
+        *Zhu et al., AnatomyNet, arXiv:1808.05238*
+    """
+
+    def __init__(self, num_channels, reduction_ratio=2):
+        """
+        Args:
+            num_channels (int): number of input channels
+            reduction_ratio (int): factor by which num_channels should be reduced
+        """
+        super(ChannelSELayer3D, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool3d(1)
+        num_channels_reduced = num_channels // reduction_ratio
+        self.reduction_ratio = reduction_ratio
+        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
+        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        batch_size, num_channels, D, H, W = x.size()
+        # Average along each channel
+        squeeze_tensor = self.avg_pool(x)
+
+        # channel excitation
+        fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels)))
+        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))
+
+        output_tensor = torch.mul(x, fc_out_2.view(batch_size, num_channels, 1, 1, 1))
+
+        return output_tensor
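+
+# Shape sketch (illustrative): channel SE rescales each channel by a learned
+# gate in [0, 1] and keeps the input shape:
+#
+#   se = ChannelSELayer3D(num_channels=16)
+#   x = torch.randn(2, 16, 8, 32, 32)
+#   y = se(x)  # (2, 16, 8, 32, 32)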
+
+
+class SpatialSELayer3D(nn.Module):
+    """
+    3D extension of SE block -- squeezing spatially and exciting channel-wise described in:
+        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
+    """
+
+    def __init__(self, num_channels):
+        """
+        Args:
+            num_channels (int): number of input channels
+        """
+        super(SpatialSELayer3D, self).__init__()
+        self.conv = nn.Conv3d(num_channels, 1, 1)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x, weights=None):
+        """
+        Args:
+            x (torch.Tensor): input tensor, shape = (batch_size, num_channels, D, H, W)
+            weights (torch.Tensor): optional per-channel weights for few-shot learning
+
+        Returns:
+            (torch.Tensor): output_tensor
+        """
+        # channel squeeze
+        batch_size, channel, D, H, W = x.size()
+
+        if weights is not None:
+            # few-shot path: use the provided per-channel weights as a 1x1x1 kernel
+            weights = weights.view(1, channel, 1, 1, 1)
+            out = F.conv3d(x, weights)
+        else:
+            out = self.conv(x)
+
+        squeeze_tensor = self.sigmoid(out)
+
+        # spatial excitation
+        output_tensor = torch.mul(x, squeeze_tensor.view(batch_size, 1, D, H, W))
+
+        return output_tensor
+
+
+class ChannelSpatialSELayer3D(nn.Module):
+    """
+       3D extension of concurrent spatial and channel squeeze & excitation:
+           *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, arXiv:1803.02579*
+       """
+
+    def __init__(self, num_channels, reduction_ratio=2):
+        """
+        Args:
+            num_channels (int): number of input channels
+            reduction_ratio (int): factor by which num_channels should be reduced
+        """
+        super(ChannelSpatialSELayer3D, self).__init__()
+        self.cSE = ChannelSELayer3D(num_channels, reduction_ratio)
+        self.sSE = SpatialSELayer3D(num_channels)
+
+    def forward(self, input_tensor):
+        output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor))
+        return output_tensor
diff --git a/build/lib/pytorch3dunet/unet3d/seg_metrics.py b/build/lib/pytorch3dunet/unet3d/seg_metrics.py
new file mode 100644
index 00000000..e713ea23
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/seg_metrics.py
@@ -0,0 +1,123 @@
+import numpy as np
+from skimage.metrics import contingency_table
+
+
+def precision(tp, fp, fn):
+    return tp / (tp + fp) if tp > 0 else 0
+
+
+def recall(tp, fp, fn):
+    return tp / (tp + fn) if tp > 0 else 0
+
+
+def accuracy(tp, fp, fn):
+    return tp / (tp + fp + fn) if tp > 0 else 0
+
+
+def f1(tp, fp, fn):
+    return (2 * tp) / (2 * tp + fp + fn) if tp > 0 else 0
+
+
+def _relabel(input):
+    _, unique_labels = np.unique(input, return_inverse=True)
+    return unique_labels.reshape(input.shape)
+
+
+def _iou_matrix(gt, seg):
+    # relabel gt and seg for smaller memory footprint of contingency table
+    gt = _relabel(gt)
+    seg = _relabel(seg)
+
+    # get number of overlapping pixels between GT and SEG
+    n_inter = contingency_table(gt, seg).toarray()
+
+    # number of pixels for GT instances
+    n_gt = n_inter.sum(axis=1, keepdims=True)
+    # number of pixels for SEG instances
+    n_seg = n_inter.sum(axis=0, keepdims=True)
+
+    # number of pixels in the union between GT and SEG instances
+    n_union = n_gt + n_seg - n_inter
+
+    iou_matrix = n_inter / n_union
+    # make sure that the values are within [0,1] range
+    assert 0 <= np.min(iou_matrix) <= np.max(iou_matrix) <= 1
+
+    return iou_matrix
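+
+# Worked sketch: for gt = [[0, 1], [0, 1]] and seg = [[0, 1], [1, 1]] the
+# contingency table gives 1 overlapping pixel for label 0 and 2 for label 1;
+# with |gt_0|=2, |seg_0|=1, |gt_1|=2, |seg_1|=3 the diagonal IoUs are
+# 1 / (2 + 1 - 1) = 0.5 and 2 / (2 + 3 - 2) = 2/3.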
+
+
+class SegmentationMetrics:
+    """
+    Computes precision, recall, accuracy, f1 score for a given ground truth and predicted segmentation.
+    Contingency table for a given ground truth and predicted segmentation is computed eagerly upon construction
+    of the instance of `SegmentationMetrics`.
+
+    Args:
+        gt (ndarray): ground truth segmentation
+        seg (ndarray): predicted segmentation
+    """
+
+    def __init__(self, gt, seg):
+        self.iou_matrix = _iou_matrix(gt, seg)
+
+    def metrics(self, iou_threshold):
+        """
+        Computes precision, recall, accuracy, f1 score at a given IoU threshold
+        """
+        # ignore background
+        iou_matrix = self.iou_matrix[1:, 1:]
+        detection_matrix = (iou_matrix > iou_threshold).astype(np.uint8)
+        n_gt, n_seg = detection_matrix.shape
+
+        # if the iou_matrix is empty or all values are 0
+        trivial = min(n_gt, n_seg) == 0 or np.all(detection_matrix == 0)
+        if trivial:
+            tp = fp = fn = 0
+        else:
+            # count non-zero rows to get the number of TP
+            tp = np.count_nonzero(detection_matrix.sum(axis=1))
+            # count zero rows to get the number of FN
+            fn = n_gt - tp
+            # count zero columns to get the number of FP
+            fp = n_seg - np.count_nonzero(detection_matrix.sum(axis=0))
+
+        return {
+            'precision': precision(tp, fp, fn),
+            'recall': recall(tp, fp, fn),
+            'accuracy': accuracy(tp, fp, fn),
+            'f1': f1(tp, fp, fn)
+        }
+
+
+class Accuracy:
+    """
+    Computes accuracy between ground truth and predicted segmentation at a given IoU threshold value.
+    Defined as: AC = TP / (TP + FP + FN).
+    Kaggle DSB2018 calls it Precision, see:
+    https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric.
+    """
+
+    def __init__(self, iou_threshold):
+        self.iou_threshold = iou_threshold
+
+    def __call__(self, input_seg, gt_seg):
+        metrics = SegmentationMetrics(gt_seg, input_seg).metrics(self.iou_threshold)
+        return metrics['accuracy']
+
+
+class AveragePrecision:
+    """
+    Average precision taken for the IoU range (0.5, 0.95) with a step of 0.05 as defined in:
+    https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric
+    """
+
+    def __init__(self):
+        self.iou_range = np.linspace(0.50, 0.95, 10)
+
+    def __call__(self, input_seg, gt_seg):
+        # compute contingency_table
+        sm = SegmentationMetrics(gt_seg, input_seg)
+        # compute accuracy for each threshold
+        acc = [sm.metrics(iou)['accuracy'] for iou in self.iou_range]
+        # return the average
+        return np.mean(acc)
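+
+# Usage sketch (illustrative names): both metrics consume integer instance
+# segmentations where 0 is background:
+#
+#   ap = AveragePrecision()
+#   score = ap(predicted_instances, gt_instances)  # mean accuracy over IoU 0.50..0.95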
diff --git a/build/lib/pytorch3dunet/unet3d/trainer.py b/build/lib/pytorch3dunet/unet3d/trainer.py
new file mode 100644
index 00000000..4b59d568
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/trainer.py
@@ -0,0 +1,404 @@
+import os
+import torch
+import torch.nn as nn
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.utils.tensorboard import SummaryWriter
+from datetime import datetime
+
+from pytorch3dunet.datasets.utils import get_train_loaders
+from pytorch3dunet.unet3d.losses import get_loss_criterion
+from pytorch3dunet.unet3d.metrics import get_evaluation_metric
+from pytorch3dunet.unet3d.model import get_model, UNet2D
+from pytorch3dunet.unet3d.utils import get_logger, get_tensorboard_formatter, create_optimizer, \
+    create_lr_scheduler, get_number_of_learnable_parameters
+from . import utils
+
+logger = get_logger('UNetTrainer')
+
+
+def create_trainer(config):
+    # Create the model
+    model = get_model(config['model'])
+
+    if torch.cuda.device_count() > 1 and not config['device'] == 'cpu':
+        model = nn.DataParallel(model)
+        logger.info(f'Using {torch.cuda.device_count()} GPUs for training')
+    if torch.cuda.is_available() and not config['device'] == 'cpu':
+        model = model.cuda()
+
+    # Log the number of learnable parameters
+    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')
+
+    # Create loss criterion
+    loss_criterion = get_loss_criterion(config)
+    # Create evaluation metric
+    eval_criterion = get_evaluation_metric(config)
+
+    # Create data loaders
+    loaders = get_train_loaders(config)
+
+    # Create the optimizer
+    optimizer = create_optimizer(config['optimizer'], model)
+
+    # Create learning rate adjustment strategy
+    lr_scheduler = create_lr_scheduler(config.get('lr_scheduler', None), optimizer)
+
+    trainer_config = config['trainer']
+    # Create tensorboard formatter
+    tensorboard_formatter = get_tensorboard_formatter(trainer_config.pop('tensorboard_formatter', None))
+    # Create trainer
+    resume = trainer_config.pop('resume', None)
+    pre_trained = trainer_config.pop('pre_trained', None)
+
+    return UNetTrainer(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, loss_criterion=loss_criterion,
+                       eval_criterion=eval_criterion, loaders=loaders, tensorboard_formatter=tensorboard_formatter,
+                       resume=resume, pre_trained=pre_trained, **trainer_config)
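+
+# Configuration sketch (hypothetical YAML shown as a comment): the 'trainer'
+# section maps directly onto UNetTrainer's keyword arguments, e.g.
+#
+#   trainer:
+#     checkpoint_dir: CHECKPOINT_DIR
+#     max_num_epochs: 100
+#     max_num_iterations: 150000
+#     validate_after_iters: 200
+#     log_after_iters: 100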
+
+
+class UNetTrainer:
+    """UNet trainer.
+
+    Args:
+        model (Unet3D): UNet 3D model to be trained
+        optimizer (nn.optim.Optimizer): optimizer used for training
+        lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler
+            WARN: bear in mind that lr_scheduler.step() is invoked after every validation step
+            (i.e. validate_after_iters) not after every epoch. So e.g. if one uses StepLR with step_size=30
+            the learning rate will be adjusted after every 30 * validate_after_iters iterations.
+        loss_criterion (callable): loss function
+        eval_criterion (callable): used to compute training/validation metric (such as Dice, IoU, AP or Rand score)
+            saving the best checkpoint is based on the result of this function on the validation set
+        loaders (dict): 'train' and 'val' loaders
+        checkpoint_dir (string): dir for saving checkpoints and tensorboard logs
+        max_num_epochs (int): maximum number of epochs
+        max_num_iterations (int): maximum number of iterations
+        validate_after_iters (int): validate after that many iterations
+        log_after_iters (int): number of iterations before logging to tensorboard
+        validate_iters (int): number of validation iterations, if None validate
+            on the whole validation set
+        eval_score_higher_is_better (bool): if True higher eval scores are considered better
+        best_eval_score (float): best validation score so far (higher better)
+        num_iterations (int): useful when loading the model from the checkpoint
+        num_epoch (int): useful when loading the model from the checkpoint
+        tensorboard_formatter (callable): converts a given batch of input/output/target image to a series of images
+            that can be displayed in tensorboard
+        skip_train_validation (bool): if True eval_criterion is not evaluated on the training set (used mostly when
+            evaluation is expensive)
+    """
+
+    def __init__(self, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, checkpoint_dir,
+                 max_num_epochs, max_num_iterations, validate_after_iters=200, log_after_iters=100, validate_iters=None,
+                 num_iterations=1, num_epoch=0, eval_score_higher_is_better=True, tensorboard_formatter=None,
+                 skip_train_validation=False, resume=None, pre_trained=None, **kwargs):
+
+        self.model = model
+        self.optimizer = optimizer
+        self.scheduler = lr_scheduler
+        self.loss_criterion = loss_criterion
+        self.eval_criterion = eval_criterion
+        self.loaders = loaders
+        self.checkpoint_dir = checkpoint_dir
+        self.max_num_epochs = max_num_epochs
+        self.max_num_iterations = max_num_iterations
+        self.validate_after_iters = validate_after_iters
+        self.log_after_iters = log_after_iters
+        self.validate_iters = validate_iters
+        self.eval_score_higher_is_better = eval_score_higher_is_better
+
+        logger.info(model)
+        logger.info(f'eval_score_higher_is_better: {eval_score_higher_is_better}')
+
+        # initialize the best_eval_score
+        if eval_score_higher_is_better:
+            self.best_eval_score = float('-inf')
+        else:
+            self.best_eval_score = float('+inf')
+
+        self.writer = SummaryWriter(
+            log_dir=os.path.join(checkpoint_dir, 'logs', datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
+        )
+
+        assert tensorboard_formatter is not None, 'TensorboardFormatter must be provided'
+        self.tensorboard_formatter = tensorboard_formatter
+
+        self.num_iterations = num_iterations
+        self.num_epochs = num_epoch
+        self.skip_train_validation = skip_train_validation
+
+        if resume is not None:
+            logger.info(f"Loading checkpoint '{resume}'...")
+            state = utils.load_checkpoint(resume, self.model, self.optimizer)
+            logger.info(
+                f"Checkpoint loaded from '{resume}'. Epoch: {state['num_epochs']}.  Iteration: {state['num_iterations']}. "
+                f"Best val score: {state['best_eval_score']}."
+            )
+            self.best_eval_score = state['best_eval_score']
+            self.num_iterations = state['num_iterations']
+            self.num_epochs = state['num_epochs']
+            self.checkpoint_dir = os.path.split(resume)[0]
+        elif pre_trained is not None:
+            logger.info(f"Logging pre-trained model from '{pre_trained}'...")
+            utils.load_checkpoint(pre_trained, self.model, None)
+            if 'checkpoint_dir' not in kwargs:
+                self.checkpoint_dir = os.path.split(pre_trained)[0]
+
+    def fit(self):
+        for _ in range(self.num_epochs, self.max_num_epochs):
+            # train for one epoch
+            should_terminate = self.train()
+
+            if should_terminate:
+                logger.info('Stopping criterion is satisfied. Finishing training')
+                return
+
+            self.num_epochs += 1
+        logger.info(f"Reached maximum number of epochs: {self.max_num_epochs}. Finishing training...")
+
+    def train(self):
+        """Trains the model for 1 epoch.
+
+        Returns:
+            True if the training should be terminated immediately, False otherwise
+        """
+        train_losses = utils.RunningAverage()
+        train_eval_scores = utils.RunningAverage()
+
+        # sets the model in training mode
+        self.model.train()
+
+        for t in self.loaders['train']:
+            logger.info(f'Training iteration [{self.num_iterations}/{self.max_num_iterations}]. '
+                        f'Epoch [{self.num_epochs}/{self.max_num_epochs - 1}]')
+
+            input, target, weight = self._split_training_batch(t)
+
+            output, loss = self._forward_pass(input, target, weight)
+
+            train_losses.update(loss.item(), self._batch_size(input))
+
+            # compute gradients and update parameters
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+            if self.num_iterations % self.validate_after_iters == 0:
+                # set the model in eval mode
+                self.model.eval()
+                # evaluate on validation set
+                eval_score = self.validate()
+                # set the model back to training mode
+                self.model.train()
+
+                # adjust learning rate if necessary
+                if isinstance(self.scheduler, ReduceLROnPlateau):
+                    self.scheduler.step(eval_score)
+                elif self.scheduler is not None:
+                    self.scheduler.step()
+
+                # log current learning rate in tensorboard
+                self._log_lr()
+                # remember best validation metric
+                is_best = self._is_best_eval_score(eval_score)
+
+                # save checkpoint
+                self._save_checkpoint(is_best)
+
+            if self.num_iterations % self.log_after_iters == 0:
+                # compute eval criterion
+                if not self.skip_train_validation:
+                    # apply final activation before calculating eval score
+                    if isinstance(self.model, nn.DataParallel):
+                        final_activation = self.model.module.final_activation
+                    else:
+                        final_activation = self.model.final_activation
+
+                    if final_activation is not None:
+                        act_output = final_activation(output)
+                    else:
+                        act_output = output
+                    eval_score = self.eval_criterion(act_output, target)
+                    train_eval_scores.update(eval_score.item(), self._batch_size(input))
+
+                # log stats, params and images
+                logger.info(
+                    f'Training stats. Loss: {train_losses.avg}. Evaluation score: {train_eval_scores.avg}')
+                self._log_stats('train', train_losses.avg, train_eval_scores.avg)
+                # self._log_params()
+                self._log_images(input, target, output, 'train_')
+
+            if self.should_stop():
+                return True
+
+            self.num_iterations += 1
+
+        return False
+
+    def should_stop(self):
+        """
+        Training terminates if the maximum number of iterations is exceeded or if the learning rate drops below
+        a predefined threshold (1e-6 in our case).
+        """
+        if self.max_num_iterations < self.num_iterations:
+            logger.info(f'Maximum number of iterations {self.max_num_iterations} exceeded.')
+            return True
+
+        min_lr = 1e-6
+        lr = self.optimizer.param_groups[0]['lr']
+        if lr < min_lr:
+            logger.info(f'Learning rate below the minimum {min_lr}.')
+            return True
+
+        return False
+
+    def validate(self):
+        logger.info('Validating...')
+
+        val_losses = utils.RunningAverage()
+        val_scores = utils.RunningAverage()
+
+        with torch.no_grad():
+            for i, t in enumerate(self.loaders['val']):
+                logger.info(f'Validation iteration {i}')
+
+                input, target, weight = self._split_training_batch(t)
+
+                output, loss = self._forward_pass(input, target, weight)
+                val_losses.update(loss.item(), self._batch_size(input))
+
+                if i % 100 == 0:
+                    self._log_images(input, target, output, 'val_')
+
+                eval_score = self.eval_criterion(output, target)
+                val_scores.update(eval_score.item(), self._batch_size(input))
+
+                if self.validate_iters is not None and self.validate_iters <= i:
+                    # stop validation
+                    break
+
+            self._log_stats('val', val_losses.avg, val_scores.avg)
+            logger.info(f'Validation finished. Loss: {val_losses.avg}. Evaluation score: {val_scores.avg}')
+            return val_scores.avg
+
+    def _split_training_batch(self, t):
+        def _move_to_gpu(input):
+            if isinstance(input, (tuple, list)):
+                return tuple(_move_to_gpu(x) for x in input)
+            else:
+                if torch.cuda.is_available():
+                    input = input.cuda(non_blocking=True)
+                return input
+
+        t = _move_to_gpu(t)
+        weight = None
+        if len(t) == 2:
+            input, target = t
+        else:
+            input, target, weight = t
+        return input, target, weight
+
+    def _forward_pass(self, input, target, weight=None):
+        if isinstance(self.model, UNet2D):
+            # remove the singleton z-dimension from the input
+            input = torch.squeeze(input, dim=-3)
+            # forward pass
+            output = self.model(input)
+            # add the singleton z-dimension to the output
+            output = torch.unsqueeze(output, dim=-3)
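+            # shape sketch (sizes illustrative): NxCx1xHxW -> NxCxHxW -> model -> NxC'xHxW -> NxC'x1xHxW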
+        else:
+            # forward pass
+            output = self.model(input)
+
+        # compute the loss
+        if weight is None:
+            loss = self.loss_criterion(output, target)
+        else:
+            loss = self.loss_criterion(output, target, weight)
+
+        return output, loss
+
+    def _is_best_eval_score(self, eval_score):
+        if self.eval_score_higher_is_better:
+            is_best = eval_score > self.best_eval_score
+        else:
+            is_best = eval_score < self.best_eval_score
+
+        if is_best:
+            logger.info(f'Saving new best evaluation metric: {eval_score}')
+            self.best_eval_score = eval_score
+
+        return is_best
+
+    def _save_checkpoint(self, is_best):
+        # remove `module` prefix from layer names when using `nn.DataParallel`
+        # see: https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/20
+        if isinstance(self.model, nn.DataParallel):
+            state_dict = self.model.module.state_dict()
+        else:
+            state_dict = self.model.state_dict()
+
+        last_file_path = os.path.join(self.checkpoint_dir, 'last_checkpoint.pytorch')
+        logger.info(f"Saving checkpoint to '{last_file_path}'")
+
+        utils.save_checkpoint({
+            'num_epochs': self.num_epochs + 1,
+            'num_iterations': self.num_iterations,
+            'model_state_dict': state_dict,
+            'best_eval_score': self.best_eval_score,
+            'optimizer_state_dict': self.optimizer.state_dict(),
+        }, is_best, checkpoint_dir=self.checkpoint_dir)
+
+    def _log_lr(self):
+        lr = self.optimizer.param_groups[0]['lr']
+        self.writer.add_scalar('learning_rate', lr, self.num_iterations)
+
+    def _log_stats(self, phase, loss_avg, eval_score_avg):
+        tag_value = {
+            f'{phase}_loss_avg': loss_avg,
+            f'{phase}_eval_score_avg': eval_score_avg
+        }
+
+        for tag, value in tag_value.items():
+            self.writer.add_scalar(tag, value, self.num_iterations)
+
+    def _log_params(self):
+        logger.info('Logging model parameters and gradients')
+        for name, value in self.model.named_parameters():
+            self.writer.add_histogram(name, value.data.cpu().numpy(), self.num_iterations)
+            self.writer.add_histogram(name + '/grad', value.grad.data.cpu().numpy(), self.num_iterations)
+
+    def _log_images(self, input, target, prediction, prefix=''):
+
+        if isinstance(self.model, nn.DataParallel):
+            net = self.model.module
+        else:
+            net = self.model
+
+        if net.final_activation is not None:
+            prediction = net.final_activation(prediction)
+
+        inputs_map = {
+            'inputs': input,
+            'targets': target,
+            'predictions': prediction
+        }
+        img_sources = {}
+        for name, batch in inputs_map.items():
+            if isinstance(batch, (list, tuple)):
+                for i, b in enumerate(batch):
+                    img_sources[f'{name}{i}'] = b.data.cpu().numpy()
+            else:
+                img_sources[name] = batch.data.cpu().numpy()
+
+        for name, batch in img_sources.items():
+            for tag, image in self.tensorboard_formatter(name, batch):
+                self.writer.add_image(prefix + tag, image, self.num_iterations)
+
+    @staticmethod
+    def _batch_size(input):
+        if isinstance(input, (list, tuple)):
+            return input[0].size(0)
+        else:
+            return input.size(0)
diff --git a/build/lib/pytorch3dunet/unet3d/utils.py b/build/lib/pytorch3dunet/unet3d/utils.py
new file mode 100644
index 00000000..01d5559c
--- /dev/null
+++ b/build/lib/pytorch3dunet/unet3d/utils.py
@@ -0,0 +1,366 @@
+import importlib
+import logging
+import os
+import shutil
+import sys
+
+import h5py
+import numpy as np
+import torch
+from torch import optim
+
+
+def save_checkpoint(state, is_best, checkpoint_dir):
+    """Saves model and training parameters at '{checkpoint_dir}/last_checkpoint.pytorch'.
+    If is_best==True saves '{checkpoint_dir}/best_checkpoint.pytorch' as well.
+
+    Args:
+        state (dict): contains model's state_dict, optimizer's state_dict, epoch
+            and best evaluation metric value so far
+        is_best (bool): if True state contains the best model seen so far
+        checkpoint_dir (string): directory where the checkpoints are to be saved
+    """
+
+    if not os.path.exists(checkpoint_dir):
+        os.mkdir(checkpoint_dir)
+
+    last_file_path = os.path.join(checkpoint_dir, 'last_checkpoint.pytorch')
+    torch.save(state, last_file_path)
+    if is_best:
+        best_file_path = os.path.join(checkpoint_dir, 'best_checkpoint.pytorch')
+        shutil.copyfile(last_file_path, best_file_path)
+
+
+def load_checkpoint(checkpoint_path, model, optimizer=None,
+                    model_key='model_state_dict', optimizer_key='optimizer_state_dict'):
+    """Loads model and training parameters from a given checkpoint_path
+    If optimizer is provided, loads optimizer's state_dict of as well.
+
+    Args:
+        checkpoint_path (string): path to the checkpoint to be loaded
+        model (torch.nn.Module): model into which the parameters are to be copied
+        optimizer (torch.optim.Optimizer, optional): optimizer instance into
+            which the parameters are to be copied
+
+    Returns:
+        state (dict): the loaded checkpoint state
+    """
+    if not os.path.exists(checkpoint_path):
+        raise IOError(f"Checkpoint '{checkpoint_path}' does not exist")
+
+    state = torch.load(checkpoint_path, map_location='cpu')
+    model.load_state_dict(state[model_key])
+
+    if optimizer is not None:
+        optimizer.load_state_dict(state[optimizer_key])
+
+    return state
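+
+# A minimal resume sketch (paths illustrative):
+#   state = load_checkpoint('checkpoints/last_checkpoint.pytorch', model, optimizer)
+#   start_epoch, best_score = state['num_epochs'], state['best_eval_score']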
+
+
+def save_network_output(output_path, output, logger=None):
+    if logger is not None:
+        logger.info(f'Saving network output to: {output_path}...')
+    output = output.detach().cpu()[0]
+    with h5py.File(output_path, 'w') as f:
+        f.create_dataset('predictions', data=output, compression='gzip')
+
+
+loggers = {}
+
+
+def get_logger(name, level=logging.INFO):
+    global loggers
+    if loggers.get(name) is not None:
+        return loggers[name]
+    else:
+        logger = logging.getLogger(name)
+        logger.setLevel(level)
+        # Logging to console
+        stream_handler = logging.StreamHandler(sys.stdout)
+        formatter = logging.Formatter(
+            '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s')
+        stream_handler.setFormatter(formatter)
+        logger.addHandler(stream_handler)
+
+        loggers[name] = logger
+
+        return logger
+
+
+def get_number_of_learnable_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+class RunningAverage:
+    """Computes and stores the average
+    """
+
+    def __init__(self):
+        self.count = 0
+        self.sum = 0
+        self.avg = 0
+
+    def update(self, value, n=1):
+        self.count += n
+        self.sum += value * n
+        self.avg = self.sum / self.count
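+
+# Example (illustrative): averaging per-batch losses weighted by batch size:
+#   avg = RunningAverage(); avg.update(0.5, n=4); avg.update(1.0, n=4)  # avg.avg == 0.75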
+
+
+def number_of_features_per_level(init_channel_number, num_levels):
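+    # e.g. number_of_features_per_level(64, 4) -> [64, 128, 256, 512]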
+    return [init_channel_number * 2 ** k for k in range(num_levels)]
+
+
+class _TensorboardFormatter:
+    """
+    A tensorboard formatter converts a given batch of images (be it the input/output of the network or the target
+    segmentation image) into a series of images that can be displayed in tensorboard. This is the parent class for all
+    tensorboard formatters; it ensures that the returned images are in the 'CHW' format.
+    """
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, name, batch):
+        """
+        Transform a batch to a series of tuples of the form (tag, img), where `tag` corresponds to the image tag
+        and `img` is the image itself.
+
+        Args:
+             name (str): one of 'inputs'/'targets'/'predictions'
+             batch (torch.tensor): 4D or 5D torch tensor
+        """
+
+        def _check_img(tag_img):
+            tag, img = tag_img
+
+            assert img.ndim == 2 or img.ndim == 3, 'Only 2D (HW) and 3D (CHW) images are accepted for display'
+
+            if img.ndim == 2:
+                img = np.expand_dims(img, axis=0)
+            else:
+                C = img.shape[0]
+                assert C == 1 or C == 3, 'Only (1, H, W) or (3, H, W) images are supported'
+
+            return tag, img
+
+        tagged_images = self.process_batch(name, batch)
+
+        return list(map(_check_img, tagged_images))
+
+    def process_batch(self, name, batch):
+        raise NotImplementedError
+
+
+class DefaultTensorboardFormatter(_TensorboardFormatter):
+    def __init__(self, skip_last_target=False, **kwargs):
+        super().__init__(**kwargs)
+        self.skip_last_target = skip_last_target
+
+    def process_batch(self, name, batch):
+        if name == 'targets' and self.skip_last_target:
+            batch = batch[:, :-1, ...]
+
+        tag_template = '{}/batch_{}/channel_{}/slice_{}'
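+        # e.g. 'predictions/batch_0/channel_1/slice_32'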
+
+        tagged_images = []
+
+        if batch.ndim == 5:
+            # NCDHW
+            slice_idx = batch.shape[2] // 2  # get the middle slice
+            for batch_idx in range(batch.shape[0]):
+                for channel_idx in range(batch.shape[1]):
+                    tag = tag_template.format(name, batch_idx, channel_idx, slice_idx)
+                    img = batch[batch_idx, channel_idx, slice_idx, ...]
+                    tagged_images.append((tag, self._normalize_img(img)))
+        else:
+            # batch has no channel dim: NDHW
+            slice_idx = batch.shape[1] // 2  # get the middle slice
+            for batch_idx in range(batch.shape[0]):
+                tag = tag_template.format(name, batch_idx, 0, slice_idx)
+                img = batch[batch_idx, slice_idx, ...]
+                tagged_images.append((tag, self._normalize_img(img)))
+
+        return tagged_images
+
+    @staticmethod
+    def _normalize_img(img):
+        return np.nan_to_num((img - np.min(img)) / np.ptp(img))
+
+
+def _find_masks(batch, min_size=10):
+    """Center the z-slice in the 'middle' of a given instance, given a batch of instances
+
+    Args:
+        batch (ndarray): 5d numpy tensor (NCDHW)
+    """
+    result = []
+    for b in batch:
+        assert b.shape[0] == 1
+        patch = b[0]
+        z_sum = patch.sum(axis=(1, 2))
+        coords = np.where(z_sum > min_size)[0]
+        if len(coords) > 0:
+            ind = coords[len(coords) // 2]
+            result.append(b[:, ind:ind + 1, ...])
+        else:
+            ind = b.shape[1] // 2
+            result.append(b[:, ind:ind + 1, ...])
+
+    return np.stack(result, axis=0)
+
+
+def get_tensorboard_formatter(formatter_config):
+    if formatter_config is None:
+        return DefaultTensorboardFormatter()
+
+    class_name = formatter_config['name']
+    m = importlib.import_module('pytorch3dunet.unet3d.utils')
+    clazz = getattr(m, class_name)
+    return clazz(**formatter_config)
+
+
+def expand_as_one_hot(input, C, ignore_index=None):
+    """
+    Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
+    It is assumed that the batch dimension is present.
+    Args:
+        input (torch.Tensor): 4D input label image (NxDxHxW)
+        C (int): number of channels/labels
+        ignore_index (int): ignore index to be kept during the expansion
+    Returns:
+        5D output torch.Tensor (NxCxDxHxW)
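+
+    Example (a minimal sketch; shapes are illustrative):
+        >>> labels = torch.zeros(1, 2, 3, 3, dtype=torch.long)  # NxDxHxW
+        >>> expand_as_one_hot(labels, C=4).shape
+        torch.Size([1, 4, 2, 3, 3])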
+    """
+    assert input.dim() == 4
+
+    # expand the input tensor to Nx1xSPATIAL before scattering
+    input = input.unsqueeze(1)
+    # create output tensor shape (NxCxSPATIAL)
+    shape = list(input.size())
+    shape[1] = C
+
+    if ignore_index is not None:
+        # create ignore_index mask for the result
+        mask = input.expand(shape) == ignore_index
+        # clone the src tensor and zero out ignore_index in the input
+        input = input.clone()
+        input[input == ignore_index] = 0
+        # scatter to get the one-hot tensor
+        result = torch.zeros(shape).to(input.device).scatter_(1, input, 1)
+        # bring back the ignore_index in the result
+        result[mask] = ignore_index
+        return result
+    else:
+        # scatter to get the one-hot tensor
+        return torch.zeros(shape).to(input.device).scatter_(1, input, 1)
+
+
+def convert_to_numpy(*inputs):
+    """
+    Converts input tensors to numpy ndarrays
+
+    Args:
+        inputs (iterable of torch.Tensor): input tensors to convert
+
+    Returns:
+        tuple of ndarrays
+    """
+
+    def _to_numpy(i):
+        assert isinstance(i, torch.Tensor), "Expected input to be torch.Tensor"
+        return i.detach().cpu().numpy()
+
+    # return a tuple rather than a generator so the result matches the docstring
+    return tuple(_to_numpy(i) for i in inputs)
+
+
+def create_optimizer(optimizer_config, model):
+    optim_name = optimizer_config.get('name', 'Adam')
+    # common optimizer settings
+    learning_rate = optimizer_config.get('learning_rate', 1e-3)
+    weight_decay = optimizer_config.get('weight_decay', 0)
+
+    # grab optimizer specific settings and init
+    # optimizer
+    if optim_name == 'Adadelta':
+        rho = optimizer_config.get('rho', 0.9)
+        optimizer = optim.Adadelta(model.parameters(), lr=learning_rate, rho=rho,
+                                   weight_decay=weight_decay)
+    elif optim_name == 'Adagrad':
+        lr_decay = optimizer_config.get('lr_decay', 0)
+        optimizer = optim.Adagrad(model.parameters(), lr=learning_rate, lr_decay=lr_decay,
+                                  weight_decay=weight_decay)
+    elif optim_name == 'AdamW':
+        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
+        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=betas,
+                                weight_decay=weight_decay)
+    elif optim_name == 'SparseAdam':
+        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
+        optimizer = optim.SparseAdam(model.parameters(), lr=learning_rate, betas=betas)
+    elif optim_name == 'Adamax':
+        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
+        optimizer = optim.Adamax(model.parameters(), lr=learning_rate, betas=betas,
+                                 weight_decay=weight_decay)
+    elif optim_name == 'ASGD':
+        lambd = optimizer_config.get('lambd', 0.0001)
+        alpha = optimizer_config.get('alpha', 0.75)
+        t0 = optimizer_config.get('t0', 1e6)
+        optimizer = optim.ASGD(model.parameters(), lr=learning_rate, lambd=lambd,
+                               alpha=alpha, t0=t0, weight_decay=weight_decay)
+    elif optim_name == 'LBFGS':
+        max_iter = optimizer_config.get('max_iter', 20)
+        max_eval = optimizer_config.get('max_eval', None)
+        tolerance_grad = optimizer_config.get('tolerance_grad', 1e-7)
+        tolerance_change = optimizer_config.get('tolerance_change', 1e-9)
+        history_size = optimizer_config.get('history_size', 100)
+        optimizer = optim.LBFGS(model.parameters(), lr=learning_rate, max_iter=max_iter,
+                                max_eval=max_eval, tolerance_grad=tolerance_grad,
+                                tolerance_change=tolerance_change, history_size=history_size)
+    elif optim_name == 'NAdam':
+        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
+        momentum_decay = optimizer_config.get('momentum_decay', 4e-3)
+        optimizer = optim.NAdam(model.parameters(), lr=learning_rate, betas=betas,
+                                momentum_decay=momentum_decay,
+                                weight_decay=weight_decay)
+    elif optim_name == 'RAdam':
+        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
+        optimizer = optim.RAdam(model.parameters(), lr=learning_rate, betas=betas,
+                                weight_decay=weight_decay)
+    elif optim_name == 'RMSprop':
+        alpha = optimizer_config.get('alpha', 0.99)
+        optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, alpha=alpha,
+                                  weight_decay=weight_decay)
+    elif optim_name == 'Rprop':
+        etas = tuple(optimizer_config.get('etas', (0.5, 1.2)))
+        step_sizes = tuple(optimizer_config.get('step_sizes', (1e-6, 50)))
+        # note: Rprop supports neither momentum nor weight_decay
+        optimizer = optim.Rprop(model.parameters(), lr=learning_rate, etas=etas, step_sizes=step_sizes)
+    elif optim_name == 'SGD':
+        momentum = optimizer_config.get('momentum', 0)
+        dampening = optimizer_config.get('dampening', 0)
+        nesterov = optimizer_config.get('nesterov', False)
+        optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum,
+                              dampening=dampening, nesterov=nesterov,
+                              weight_decay=weight_decay)
+    else:  # Adam is default
+        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
+        optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=betas,
+                               weight_decay=weight_decay)
+
+    return optimizer
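+
+# Example (hypothetical config; unrecognized names fall back to Adam):
+#   optimizer = create_optimizer({'name': 'SGD', 'learning_rate': 1e-2, 'momentum': 0.9}, model)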
+
+
+def create_lr_scheduler(lr_config, optimizer):
+    if lr_config is None:
+        return None
+    class_name = lr_config.pop('name')
+    m = importlib.import_module('torch.optim.lr_scheduler')
+    clazz = getattr(m, class_name)
+    # add optimizer to the config
+    lr_config['optimizer'] = optimizer
+    return clazz(**lr_config)
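+
+# Example (hypothetical config; every key except 'name' is forwarded to the scheduler class):
+#   scheduler = create_lr_scheduler({'name': 'ReduceLROnPlateau', 'mode': 'max', 'factor': 0.5, 'patience': 10}, optimizer)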
+
+
+def get_class(class_name, modules):
+    for module in modules:
+        m = importlib.import_module(module)
+        clazz = getattr(m, class_name, None)
+        if clazz is not None:
+            return clazz
+    raise RuntimeError(f'Unsupported dataset class: {class_name}')

From 0987f24d8d7c1e1ac45865a5402f2bab09f98de1 Mon Sep 17 00:00:00 2001
From: Shota Mizusaki <nrxg129@gmail.com>
Date: Fri, 12 Jul 2024 14:35:29 +0900
Subject: [PATCH 2/4] Fix: Add missing return statement

---
 pytorch3dunet/unet3d/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch3dunet/unet3d/config.py b/pytorch3dunet/unet3d/config.py
index bb011632..0dbecffd 100644
--- a/pytorch3dunet/unet3d/config.py
+++ b/pytorch3dunet/unet3d/config.py
@@ -49,7 +49,7 @@ def load_config():
     if device == 'cpu':
         logger.warning('CPU mode forced in config, this will likely result in slow training/prediction')
         config['device'] = 'cpu'
-        return config
+        return config, config_path
 
     if torch.cuda.is_available():
         config['device'] = 'cuda'

From 6b00041129871bff9790ea75c18e201828d2761a Mon Sep 17 00:00:00 2001
From: Shota Mizusaki <nrxg129@gmail.com>
Date: Fri, 12 Jul 2024 14:49:07 +0900
Subject: [PATCH 3/4] delete build file

---
 build/lib/pytorch3dunet/__init__.py           |   1 -
 build/lib/pytorch3dunet/__version__.py        |   1 -
 build/lib/pytorch3dunet/augment/__init__.py   |   0
 build/lib/pytorch3dunet/augment/transforms.py | 761 ------------------
 build/lib/pytorch3dunet/datasets/__init__.py  |   0
 build/lib/pytorch3dunet/datasets/dsb.py       | 108 ---
 build/lib/pytorch3dunet/datasets/hdf5.py      | 293 -------
 build/lib/pytorch3dunet/datasets/utils.py     | 361 ---------
 build/lib/pytorch3dunet/predict.py            |  59 --
 build/lib/pytorch3dunet/train.py              |  35 -
 build/lib/pytorch3dunet/unet3d/__init__.py    |   0
 .../pytorch3dunet/unet3d/buildingblocks.py    | 545 -------------
 build/lib/pytorch3dunet/unet3d/config.py      |  79 --
 build/lib/pytorch3dunet/unet3d/losses.py      | 345 --------
 build/lib/pytorch3dunet/unet3d/metrics.py     | 445 ----------
 build/lib/pytorch3dunet/unet3d/model.py       | 249 ------
 build/lib/pytorch3dunet/unet3d/predictor.py   | 281 -------
 build/lib/pytorch3dunet/unet3d/se.py          | 113 ---
 build/lib/pytorch3dunet/unet3d/seg_metrics.py | 123 ---
 build/lib/pytorch3dunet/unet3d/trainer.py     | 404 ----------
 build/lib/pytorch3dunet/unet3d/utils.py       | 366 ---------
 21 files changed, 4569 deletions(-)
 delete mode 100644 build/lib/pytorch3dunet/__init__.py
 delete mode 100644 build/lib/pytorch3dunet/__version__.py
 delete mode 100644 build/lib/pytorch3dunet/augment/__init__.py
 delete mode 100644 build/lib/pytorch3dunet/augment/transforms.py
 delete mode 100644 build/lib/pytorch3dunet/datasets/__init__.py
 delete mode 100644 build/lib/pytorch3dunet/datasets/dsb.py
 delete mode 100644 build/lib/pytorch3dunet/datasets/hdf5.py
 delete mode 100644 build/lib/pytorch3dunet/datasets/utils.py
 delete mode 100644 build/lib/pytorch3dunet/predict.py
 delete mode 100644 build/lib/pytorch3dunet/train.py
 delete mode 100644 build/lib/pytorch3dunet/unet3d/__init__.py
 delete mode 100644 build/lib/pytorch3dunet/unet3d/buildingblocks.py
 delete mode 100644 build/lib/pytorch3dunet/unet3d/config.py
 delete mode 100644 build/lib/pytorch3dunet/unet3d/losses.py
 delete mode 100644 build/lib/pytorch3dunet/unet3d/metrics.py
 delete mode 100644 build/lib/pytorch3dunet/unet3d/model.py
 delete mode 100644 build/lib/pytorch3dunet/unet3d/predictor.py
 delete mode 100644 build/lib/pytorch3dunet/unet3d/se.py
 delete mode 100644 build/lib/pytorch3dunet/unet3d/seg_metrics.py
 delete mode 100644 build/lib/pytorch3dunet/unet3d/trainer.py
 delete mode 100644 build/lib/pytorch3dunet/unet3d/utils.py

diff --git a/build/lib/pytorch3dunet/__init__.py b/build/lib/pytorch3dunet/__init__.py
deleted file mode 100644
index 9226fe7e..00000000
--- a/build/lib/pytorch3dunet/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .__version__ import __version__
diff --git a/build/lib/pytorch3dunet/__version__.py b/build/lib/pytorch3dunet/__version__.py
deleted file mode 100644
index 655be529..00000000
--- a/build/lib/pytorch3dunet/__version__.py
+++ /dev/null
@@ -1 +0,0 @@
-__version__ = '1.8.7'
diff --git a/build/lib/pytorch3dunet/augment/__init__.py b/build/lib/pytorch3dunet/augment/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/build/lib/pytorch3dunet/augment/transforms.py b/build/lib/pytorch3dunet/augment/transforms.py
deleted file mode 100644
index 527d596b..00000000
--- a/build/lib/pytorch3dunet/augment/transforms.py
+++ /dev/null
@@ -1,761 +0,0 @@
-import importlib
-import random
-
-import numpy as np
-import torch
-from scipy.ndimage import rotate, map_coordinates, gaussian_filter, convolve
-from skimage import measure
-from skimage.filters import gaussian
-from skimage.segmentation import find_boundaries
-
-# WARN: uses a fixed random state for reproducibility; seed with e.g. `time.time()` to randomize on each run
-GLOBAL_RANDOM_STATE = np.random.RandomState(47)
-
-
-class Compose(object):
-    def __init__(self, transforms):
-        self.transforms = transforms
-
-    def __call__(self, m):
-        for t in self.transforms:
-            m = t(m)
-        return m
-
-
-class RandomFlip:
-    """
-    Randomly flips the image across the given axes. Image can be either 3D (DxHxW) or 4D (CxDxHxW).
-
-    When creating an instance, make sure that the provided RandomStates are consistent between the raw and labeled
-    datasets, otherwise the models won't converge.
-    """
-
-    def __init__(self, random_state, axis_prob=0.5, **kwargs):
-        assert random_state is not None, 'RandomState cannot be None'
-        self.random_state = random_state
-        self.axes = (0, 1, 2)
-        self.axis_prob = axis_prob
-
-    def __call__(self, m):
-        assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images'
-
-        for axis in self.axes:
-            if self.random_state.uniform() > self.axis_prob:
-                if m.ndim == 3:
-                    m = np.flip(m, axis)
-                else:
-                    channels = [np.flip(m[c], axis) for c in range(m.shape[0])]
-                    m = np.stack(channels, axis=0)
-
-        return m
-
-
-class RandomRotate90:
-    """
-    Rotate an array by 90 degrees around a randomly chosen plane. Image can be either 3D (DxHxW) or 4D (CxDxHxW).
-
-    When creating an instance, make sure that the provided RandomStates are consistent between the raw and labeled
-    datasets, otherwise the models won't converge.
-
-    IMPORTANT: assumes DHW axis order (that's why rotation is performed across (1,2) axis)
-    """
-
-    def __init__(self, random_state, **kwargs):
-        self.random_state = random_state
-        # always rotate around z-axis
-        self.axis = (1, 2)
-
-    def __call__(self, m):
-        assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images'
-
-        # pick number of rotations at random
-        k = self.random_state.randint(0, 4)
-        # rotate k times around a given plane
-        if m.ndim == 3:
-            m = np.rot90(m, k, self.axis)
-        else:
-            channels = [np.rot90(m[c], k, self.axis) for c in range(m.shape[0])]
-            m = np.stack(channels, axis=0)
-
-        return m
-
-
-class RandomRotate:
-    """
-    Rotate an array by a random angle taken from the (-angle_spectrum, angle_spectrum) interval.
-    Rotation axis is picked at random from the list of provided axes.
-    """
-
-    def __init__(self, random_state, angle_spectrum=30, axes=None, mode='reflect', order=0, **kwargs):
-        if axes is None:
-            axes = [(1, 0), (2, 1), (2, 0)]
-        else:
-            assert isinstance(axes, list) and len(axes) > 0
-
-        self.random_state = random_state
-        self.angle_spectrum = angle_spectrum
-        self.axes = axes
-        self.mode = mode
-        self.order = order
-
-    def __call__(self, m):
-        axis = self.axes[self.random_state.randint(len(self.axes))]
-        angle = self.random_state.randint(-self.angle_spectrum, self.angle_spectrum)
-
-        if m.ndim == 3:
-            m = rotate(m, angle, axes=axis, reshape=False, order=self.order, mode=self.mode, cval=-1)
-        else:
-            channels = [rotate(m[c], angle, axes=axis, reshape=False, order=self.order, mode=self.mode, cval=-1) for c
-                        in range(m.shape[0])]
-            m = np.stack(channels, axis=0)
-
-        return m
-
-
-class RandomContrast:
-    """
-    Adjust contrast by scaling each voxel to `mean + alpha * (v - mean)`.
-    """
-
-    def __init__(self, random_state, alpha=(0.5, 1.5), mean=0.0, execution_probability=0.1, **kwargs):
-        self.random_state = random_state
-        assert len(alpha) == 2
-        self.alpha = alpha
-        self.mean = mean
-        self.execution_probability = execution_probability
-
-    def __call__(self, m):
-        if self.random_state.uniform() < self.execution_probability:
-            alpha = self.random_state.uniform(self.alpha[0], self.alpha[1])
-            result = self.mean + alpha * (m - self.mean)
-            return np.clip(result, -1, 1)
-
-        return m
-
-
-# it's relatively slow, i.e. ~1s per patch of size 64x200x200, so use multiple workers in the DataLoader
-# remember to use spline_order=0 when transforming the labels
-class ElasticDeformation:
-    """
-    Apply elastic deformations to 3D patches on a per-voxel mesh. Assumes ZYX axis order (or CZYX if the data is 4D).
-    Based on: https://github.com/fcalvet/image_tools/blob/master/image_augmentation.py#L62
-    """
-
-    def __init__(self, random_state, spline_order, alpha=2000, sigma=50, execution_probability=0.1, apply_3d=True,
-                 **kwargs):
-        """
-        :param spline_order: the order of spline interpolation (use 0 for labeled images)
-        :param alpha: scaling factor for deformations
-        :param sigma: smoothing factor for Gaussian filter
-        :param execution_probability: probability of executing this transform
-        :param apply_3d: if True apply deformations in each axis
-        """
-        self.random_state = random_state
-        self.spline_order = spline_order
-        self.alpha = alpha
-        self.sigma = sigma
-        self.execution_probability = execution_probability
-        self.apply_3d = apply_3d
-
-    def __call__(self, m):
-        if self.random_state.uniform() < self.execution_probability:
-            assert m.ndim in [3, 4]
-
-            if m.ndim == 3:
-                volume_shape = m.shape
-            else:
-                volume_shape = m[0].shape
-
-            if self.apply_3d:
-                dz = gaussian_filter(self.random_state.randn(*volume_shape), self.sigma, mode="reflect") * self.alpha
-            else:
-                # use the volume shape (not m's shape) so 4D (CZYX) inputs broadcast correctly
-                dz = np.zeros(volume_shape)
-
-            dy, dx = [
-                gaussian_filter(
-                    self.random_state.randn(*volume_shape),
-                    self.sigma, mode="reflect"
-                ) * self.alpha for _ in range(2)
-            ]
-
-            z_dim, y_dim, x_dim = volume_shape
-            z, y, x = np.meshgrid(np.arange(z_dim), np.arange(y_dim), np.arange(x_dim), indexing='ij')
-            indices = z + dz, y + dy, x + dx
-
-            if m.ndim == 3:
-                return map_coordinates(m, indices, order=self.spline_order, mode='reflect')
-            else:
-                channels = [map_coordinates(c, indices, order=self.spline_order, mode='reflect') for c in m]
-                return np.stack(channels, axis=0)
-
-        return m
-
-
-class CropToFixed:
-    def __init__(self, random_state, size=(256, 256), centered=False, **kwargs):
-        self.random_state = random_state
-        self.crop_y, self.crop_x = size
-        self.centered = centered
-
-    def __call__(self, m):
-        def _padding(pad_total):
-            half_total = pad_total // 2
-            return (half_total, pad_total - half_total)
-
-        def _rand_range_and_pad(crop_size, max_size):
-            """
-            Returns a tuple:
-                max_value (int): upper bound for the random corner coordinate, i.e. the corner
-                    is chosen as `self.random_state.randint(max_value)`
-                pad (tuple): padding in both directions; (0, 0) if crop_size is less than max_size
-            """
-            if crop_size < max_size:
-                return max_size - crop_size, (0, 0)
-            else:
-                return 1, _padding(crop_size - max_size)
-
-        def _start_and_pad(crop_size, max_size):
-            if crop_size < max_size:
-                return (max_size - crop_size) // 2, (0, 0)
-            else:
-                return 0, _padding(crop_size - max_size)
-
-        assert m.ndim in (3, 4)
-        if m.ndim == 3:
-            _, y, x = m.shape
-        else:
-            _, _, y, x = m.shape
-
-        if not self.centered:
-            y_range, y_pad = _rand_range_and_pad(self.crop_y, y)
-            x_range, x_pad = _rand_range_and_pad(self.crop_x, x)
-
-            y_start = self.random_state.randint(y_range)
-            x_start = self.random_state.randint(x_range)
-
-        else:
-            y_start, y_pad = _start_and_pad(self.crop_y, y)
-            x_start, x_pad = _start_and_pad(self.crop_x, x)
-
-        if m.ndim == 3:
-            result = m[:, y_start:y_start + self.crop_y, x_start:x_start + self.crop_x]
-            return np.pad(result, pad_width=((0, 0), y_pad, x_pad), mode='reflect')
-        else:
-            channels = []
-            for c in range(m.shape[0]):
-                result = m[c][:, y_start:y_start + self.crop_y, x_start:x_start + self.crop_x]
-                channels.append(np.pad(result, pad_width=((0, 0), y_pad, x_pad), mode='reflect'))
-            return np.stack(channels, axis=0)
-
-
-class AbstractLabelToBoundary:
-    AXES_TRANSPOSE = [
-        (0, 1, 2),  # X
-        (0, 2, 1),  # Y
-        (2, 0, 1)  # Z
-    ]
-
-    def __init__(self, ignore_index=None, aggregate_affinities=False, append_label=False, **kwargs):
-        """
-        :param ignore_index: label to be ignored in the output, i.e. after computing the boundary the label ignore_index
-            will be restored where it was in the patch originally
-        :param aggregate_affinities: aggregate affinities with the same offset across the Z, Y, X axes
-        :param append_label: if True append the original ground truth labels to the last channel
-        :param blur: Gaussian blur the boundaries
-        :param sigma: standard deviation for Gaussian kernel
-        """
-        self.ignore_index = ignore_index
-        self.aggregate_affinities = aggregate_affinities
-        self.append_label = append_label
-
-    def __call__(self, m):
-        """
-        Extract boundaries from a given 3D label tensor.
-        :param m: input 3D tensor
-        :return: binary mask, with 1-label corresponding to the boundary and 0-label corresponding to the background
-        """
-        assert m.ndim == 3
-
-        kernels = self.get_kernels()
-        boundary_arr = [np.where(np.abs(convolve(m, kernel)) > 0, 1, 0) for kernel in kernels]
-        channels = np.stack(boundary_arr)
-        results = []
-        if self.aggregate_affinities:
-            assert len(kernels) % 3 == 0, "Number of kernels must be divisible by 3 (one kernel per offset per Z,Y,X axis)"
-            # aggregate affinities with the same offset
-            for i in range(0, len(kernels), 3):
-                # merge across X,Y,Z axes (logical OR)
-                xyz_aggregated_affinities = np.logical_or.reduce(channels[i:i + 3, ...]).astype(np.int32)
-                # recover ignore index
-                xyz_aggregated_affinities = _recover_ignore_index(xyz_aggregated_affinities, m, self.ignore_index)
-                results.append(xyz_aggregated_affinities)
-        else:
-            results = [_recover_ignore_index(channels[i], m, self.ignore_index) for i in range(channels.shape[0])]
-
-        if self.append_label:
-            # append original input data
-            results.append(m)
-
-        # stack across channel dim
-        return np.stack(results, axis=0)
-
-    @staticmethod
-    def create_kernel(axis, offset):
-        # create conv kernel
-        k_size = offset + 1
-        k = np.zeros((1, 1, k_size), dtype=np.int32)
-        k[0, 0, 0] = 1
-        k[0, 0, offset] = -1
-        return np.transpose(k, axis)
-
-    def get_kernels(self):
-        raise NotImplementedError
-
-
-class StandardLabelToBoundary:
-    def __init__(self, ignore_index=None, append_label=False, mode='thick', foreground=False,
-                 **kwargs):
-        self.ignore_index = ignore_index
-        self.append_label = append_label
-        self.mode = mode
-        self.foreground = foreground
-
-    def __call__(self, m):
-        assert m.ndim == 3
-
-        boundaries = find_boundaries(m, connectivity=2, mode=self.mode)
-        boundaries = boundaries.astype('int32')
-
-        results = []
-        if self.foreground:
-            foreground = (m > 0).astype('uint8')
-            results.append(_recover_ignore_index(foreground, m, self.ignore_index))
-
-        results.append(_recover_ignore_index(boundaries, m, self.ignore_index))
-
-        if self.append_label:
-            # append original input data
-            results.append(m)
-
-        return np.stack(results, axis=0)
-
-
-class BlobsToMask:
-    """
-    Returns binary mask from labeled image, i.e. every label greater than 0 is treated as foreground.
-
-    """
-
-    def __init__(self, append_label=False, boundary=False, cross_entropy=False, **kwargs):
-        self.cross_entropy = cross_entropy
-        self.boundary = boundary
-        self.append_label = append_label
-
-    def __call__(self, m):
-        assert m.ndim == 3
-
-        # get the segmentation mask
-        mask = (m > 0).astype('uint8')
-        results = [mask]
-
-        if self.boundary:
-            outer = find_boundaries(m, connectivity=2, mode='outer')
-            if self.cross_entropy:
-                # boundary is class 2
-                mask[outer > 0] = 2
-                results = [mask]
-            else:
-                results.append(outer)
-
-        if self.append_label:
-            results.append(m)
-
-        return np.stack(results, axis=0)
-
-
-class RandomLabelToAffinities(AbstractLabelToBoundary):
-    """
-    Converts a given volumetric label array to a binary mask corresponding to borders between labels.
-    One specifies the max_offset (thickness) of the border. The offset is then picked at random every time the
-    transformer is called (from the range 1:max_offset) for each axis, and the boundary is computed.
-    One may use this scheme to make the network more robust against borders of various thickness in the ground
-    truth (think of it as a boundary denoising scheme).
-    """
-
-    def __init__(self, random_state, max_offset=10, ignore_index=None, append_label=False, z_offset_scale=2, **kwargs):
-        super().__init__(ignore_index=ignore_index, append_label=append_label, aggregate_affinities=False)
-        self.random_state = random_state
-        self.offsets = tuple(range(1, max_offset + 1))
-        self.z_offset_scale = z_offset_scale
-
-    def get_kernels(self):
-        rand_offset = self.random_state.choice(self.offsets)
-        axis_ind = self.random_state.randint(3)
-        # scale down z-affinities due to anisotropy
-        if axis_ind == 2:
-            rand_offset = max(1, rand_offset // self.z_offset_scale)
-
-        rand_axis = self.AXES_TRANSPOSE[axis_ind]
-        # return a single kernel
-        return [self.create_kernel(rand_axis, rand_offset)]
-
-
-class LabelToAffinities(AbstractLabelToBoundary):
-    """
-    Converts a given volumetric label array to a binary mask corresponding to borders between labels (which can be seen
-    as an affinity graph: https://arxiv.org/pdf/1706.00120.pdf).
-    One specifies the offsets (thicknesses) of the borders. The boundary will be computed via the convolution operator.
-    """
-
-    def __init__(self, offsets, ignore_index=None, append_label=False, aggregate_affinities=False, z_offsets=None,
-                 **kwargs):
-        super().__init__(ignore_index=ignore_index, append_label=append_label,
-                         aggregate_affinities=aggregate_affinities)
-
-        assert isinstance(offsets, (list, tuple)), 'offsets must be a list or a tuple'
-        assert all(a > 0 for a in offsets), "'offsets' must be positive"
-        assert len(set(offsets)) == len(offsets), "'offsets' must be unique"
-        if z_offsets is not None:
-            assert len(offsets) == len(z_offsets), 'z_offsets length must be the same as the length of offsets'
-        else:
-            # if z_offsets is None just use the offsets for z-affinities
-            z_offsets = list(offsets)
-        self.z_offsets = z_offsets
-
-        self.kernels = []
-        # create kernel for every axis-offset pair
-        for xy_offset, z_offset in zip(offsets, z_offsets):
-            for axis_ind, axis in enumerate(self.AXES_TRANSPOSE):
-                final_offset = xy_offset
-                if axis_ind == 2:
-                    final_offset = z_offset
-                # create kernels for a given offset in every direction
-                self.kernels.append(self.create_kernel(axis, final_offset))
-
-    def get_kernels(self):
-        return self.kernels
-
-
-class LabelToZAffinities(AbstractLabelToBoundary):
-    """
-    Converts a given volumetric label array to a binary mask corresponding to borders between labels (which can be seen
-    as an affinity graph: https://arxiv.org/pdf/1706.00120.pdf).
-    One specifies the offsets (thicknesses) of the borders. The boundary will be computed via the convolution operator.
-    """
-
-    def __init__(self, offsets, ignore_index=None, append_label=False, **kwargs):
-        super().__init__(ignore_index=ignore_index, append_label=append_label)
-
-        assert isinstance(offsets, (list, tuple)), 'offsets must be a list or a tuple'
-        assert all(a > 0 for a in offsets), "'offsets' must be positive"
-        assert len(set(offsets)) == len(offsets), "'offsets' must be unique"
-
-        self.kernels = []
-        z_axis = self.AXES_TRANSPOSE[2]
-        # create kernels
-        for z_offset in offsets:
-            self.kernels.append(self.create_kernel(z_axis, z_offset))
-
-    def get_kernels(self):
-        return self.kernels
-
-
-class LabelToBoundaryAndAffinities:
-    """
-    Combines the StandardLabelToBoundary and LabelToAffinities in the hope
-    that training the network to predict both would improve the main task: boundary prediction.
-    """
-
-    def __init__(self, xy_offsets, z_offsets, append_label=False, blur=False, sigma=1, ignore_index=None, mode='thick',
-                 foreground=False, **kwargs):
-        # blur only StandardLabelToBoundary results; we don't want to blur the affinities
-        self.l2b = StandardLabelToBoundary(blur=blur, sigma=sigma, ignore_index=ignore_index, mode=mode,
-                                           foreground=foreground)
-        self.l2a = LabelToAffinities(offsets=xy_offsets, z_offsets=z_offsets, append_label=append_label,
-                                     ignore_index=ignore_index)
-
-    def __call__(self, m):
-        boundary = self.l2b(m)
-        affinities = self.l2a(m)
-        return np.concatenate((boundary, affinities), axis=0)
-
-
-class LabelToMaskAndAffinities:
-    def __init__(self, xy_offsets, z_offsets, append_label=False, background=0, ignore_index=None, **kwargs):
-        self.background = background
-        self.l2a = LabelToAffinities(offsets=xy_offsets, z_offsets=z_offsets, append_label=append_label,
-                                     ignore_index=ignore_index)
-
-    def __call__(self, m):
-        mask = m > self.background
-        mask = np.expand_dims(mask.astype(np.uint8), axis=0)
-        affinities = self.l2a(m)
-        return np.concatenate((mask, affinities), axis=0)
-
-
-class Standardize:
-    """
-    Apply Z-score normalization to a given input tensor, i.e. re-scaling the values to be 0-mean and 1-std.
-    """
-
-    def __init__(self, eps=1e-10, mean=None, std=None, channelwise=False, **kwargs):
-        if mean is not None or std is not None:
-            assert mean is not None and std is not None
-        self.mean = mean
-        self.std = std
-        self.eps = eps
-        self.channelwise = channelwise
-
-    def __call__(self, m):
-        if self.mean is not None:
-            mean, std = self.mean, self.std
-        else:
-            if self.channelwise:
-                # normalize per-channel
-                axes = list(range(m.ndim))
-                # average across channels
-                axes = tuple(axes[1:])
-                mean = np.mean(m, axis=axes, keepdims=True)
-                std = np.std(m, axis=axes, keepdims=True)
-            else:
-                mean = np.mean(m)
-                std = np.std(m)
-
-        return (m - mean) / np.clip(std, a_min=self.eps, a_max=None)
-
-
-class PercentileNormalizer:
-    def __init__(self, pmin=1, pmax=99.6, channelwise=False, eps=1e-10, **kwargs):
-        self.eps = eps
-        self.pmin = pmin
-        self.pmax = pmax
-        self.channelwise = channelwise
-
-    def __call__(self, m):
-        if self.channelwise:
-            axes = list(range(m.ndim))
-            # average across channels
-            axes = tuple(axes[1:])
-            pmin = np.percentile(m, self.pmin, axis=axes, keepdims=True)
-            pmax = np.percentile(m, self.pmax, axis=axes, keepdims=True)
-        else:
-            pmin = np.percentile(m, self.pmin)
-            pmax = np.percentile(m, self.pmax)
-
-        return (m - pmin) / (pmax - pmin + self.eps)
-
-
-class Normalize:
-    """
-    Apply simple min-max scaling to a given input tensor, i.e. shrink the range of the data
-    to a fixed range of [-1, 1], or to [0, 1] when norm01==True. In addition, the data can be
-    clipped by specifying min_value/max_value, either globally using single values or, when
-    channelwise is enabled, per channel via a list/tuple.
-    """
-
-    def __init__(self, min_value=None, max_value=None, norm01=False, channelwise=False,
-                 eps=1e-10, **kwargs):
-        if min_value is not None and max_value is not None:
-            assert max_value > min_value
-        self.min_value = min_value
-        self.max_value = max_value
-        self.norm01 = norm01
-        self.channelwise = channelwise
-        self.eps = eps
-
-    def __call__(self, m):
-        if self.channelwise:
-            # get min/max channelwise
-            axes = list(range(m.ndim))
-            axes = tuple(axes[1:])
-            # start from the data-derived min/max, then override channels for which the user
-            # provided an explicit value ('None' entries keep the data-derived value);
-            # this also guards against min_value/max_value being unbound below
-            min_value = np.min(m, axis=axes, keepdims=True)
-            max_value = np.max(m, axis=axes, keepdims=True)
-
-            if self.min_value is not None:
-                for i, v in enumerate(self.min_value):
-                    if v != 'None':
-                        min_value[i] = v
-
-            if self.max_value is not None:
-                for i, v in enumerate(self.max_value):
-                    if v != 'None':
-                        max_value[i] = v
-        else:
-            if self.min_value is None:
-                min_value = np.min(m)
-            else:
-                min_value = self.min_value
-
-            if self.max_value is None:
-                max_value = np.max(m)
-            else:
-                max_value = self.max_value
-
-        # calculate norm_0_1 with min_value / max_value with the same dimension
-        # in case of channelwise application
-        norm_0_1 = (m - min_value) / (max_value - min_value + self.eps)
-
-        if self.norm01 is True:
-            return np.clip(norm_0_1, 0, 1)
-        else:
-            return np.clip(2 * norm_0_1 - 1, -1, 1)
-
-
-class AdditiveGaussianNoise:
-    def __init__(self, random_state, scale=(0.0, 1.0), execution_probability=0.1, **kwargs):
-        self.execution_probability = execution_probability
-        self.random_state = random_state
-        self.scale = scale
-
-    def __call__(self, m):
-        if self.random_state.uniform() < self.execution_probability:
-            std = self.random_state.uniform(self.scale[0], self.scale[1])
-            gaussian_noise = self.random_state.normal(0, std, size=m.shape)
-            return m + gaussian_noise
-        return m
-
-
-class AdditivePoissonNoise:
-    def __init__(self, random_state, lam=(0.0, 1.0), execution_probability=0.1, **kwargs):
-        self.execution_probability = execution_probability
-        self.random_state = random_state
-        self.lam = lam
-
-    def __call__(self, m):
-        if self.random_state.uniform() < self.execution_probability:
-            lam = self.random_state.uniform(self.lam[0], self.lam[1])
-            poisson_noise = self.random_state.poisson(lam, size=m.shape)
-            return m + poisson_noise
-        return m
-
-
-class ToTensor:
-    """
-    Converts a given input numpy.ndarray into torch.Tensor.
-
-    Args:
-        expand_dims (bool): if True, adds a channel dimension to the input data
-        dtype (np.dtype): the desired output data type
-    """
-
-    def __init__(self, expand_dims, dtype=np.float32, **kwargs):
-        self.expand_dims = expand_dims
-        self.dtype = dtype
-
-    def __call__(self, m):
-        assert m.ndim in [3, 4], 'Supports only 3D (DxHxW) or 4D (CxDxHxW) images'
-        # add channel dimension
-        if self.expand_dims and m.ndim == 3:
-            m = np.expand_dims(m, axis=0)
-
-        return torch.from_numpy(m.astype(dtype=self.dtype))
-
-
-class Relabel:
-    """
-    Relabel a numpy array of labels into consecutive numbers, e.g.
-    [10, 10, 0, 6, 6] -> [2, 2, 0, 1, 1]. Useful when one has an instance segmentation volume
-    at hand and would like to create a one-hot encoding for it. Without consecutive labeling the task would be harder.
-    """
-
-    def __init__(self, append_original=False, run_cc=True, ignore_label=None, **kwargs):
-        self.append_original = append_original
-        self.ignore_label = ignore_label
-        self.run_cc = run_cc
-
-        if ignore_label is not None:
-            assert append_original, "ignore_label present, so append_original must be true, so that one can localize the ignore region"
-
-    def __call__(self, m):
-        orig = m
-        if self.run_cc:
-            # assign 0 to the ignore region
-            m = measure.label(m, background=self.ignore_label)
-
-        _, unique_labels = np.unique(m, return_inverse=True)
-        result = unique_labels.reshape(m.shape)
-        if self.append_original:
-            result = np.stack([result, orig])
-        return result
-
-
-class Identity:
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, m):
-        return m
-
-
-class RgbToLabel:
-    def __call__(self, img):
-        img = np.array(img)
-        assert img.ndim == 3 and img.shape[2] == 3
-        result = img[..., 0] * 65536 + img[..., 1] * 256 + img[..., 2]
-        return result
-
-
-class LabelToTensor:
-    def __call__(self, m):
-        m = np.array(m)
-        return torch.from_numpy(m.astype(dtype='int64'))
-
-
-class GaussianBlur3D:
-    def __init__(self, sigma=(0.1, 2.0), execution_probability=0.5, **kwargs):
-        self.sigma = sigma
-        self.execution_probability = execution_probability
-
-    def __call__(self, x):
-        if random.random() < self.execution_probability:
-            sigma = random.uniform(self.sigma[0], self.sigma[1])
-            x = gaussian(x, sigma=sigma)
-            return x
-        return x
-
-
-class Transformer:
-    def __init__(self, phase_config, base_config):
-        self.phase_config = phase_config
-        self.config_base = base_config
-        self.seed = GLOBAL_RANDOM_STATE.randint(10000000)
-
-    def raw_transform(self):
-        return self._create_transform('raw')
-
-    def label_transform(self):
-        return self._create_transform('label')
-
-    def weight_transform(self):
-        return self._create_transform('weight')
-
-    @staticmethod
-    def _transformer_class(class_name):
-        m = importlib.import_module('pytorch3dunet.augment.transforms')
-        clazz = getattr(m, class_name)
-        return clazz
-
-    def _create_transform(self, name):
-        assert name in self.phase_config, f'Could not find {name} transform'
-        return Compose([
-            self._create_augmentation(c) for c in self.phase_config[name]
-        ])
-
-    def _create_augmentation(self, c):
-        config = dict(self.config_base)
-        config.update(c)
-        config['random_state'] = np.random.RandomState(self.seed)
-        aug_class = self._transformer_class(config['name'])
-        return aug_class(**config)
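-
-# Example phase config (a hypothetical sketch of the dict parsed from the YAML config):
-#   {'raw': [{'name': 'Standardize'}, {'name': 'RandomFlip'}, {'name': 'ToTensor', 'expand_dims': True}]}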
-
-
-def _recover_ignore_index(input, orig, ignore_index):
-    if ignore_index is not None:
-        mask = orig == ignore_index
-        input[mask] = ignore_index
-
-    return input
diff --git a/build/lib/pytorch3dunet/datasets/__init__.py b/build/lib/pytorch3dunet/datasets/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/build/lib/pytorch3dunet/datasets/dsb.py b/build/lib/pytorch3dunet/datasets/dsb.py
deleted file mode 100644
index 5d0cde86..00000000
--- a/build/lib/pytorch3dunet/datasets/dsb.py
+++ /dev/null
@@ -1,108 +0,0 @@
-import collections
-import os
-
-import imageio
-import numpy as np
-import torch
-
-from pytorch3dunet.augment import transforms
-from pytorch3dunet.datasets.utils import ConfigDataset, calculate_stats
-from pytorch3dunet.unet3d.utils import get_logger
-
-logger = get_logger('DSB2018Dataset')
-
-
-def dsb_prediction_collate(batch):
-    """
-    Forms a mini-batch of (images, paths) during test time for the DSB-like datasets.
-    """
-    error_msg = "batch must contain tensors or str; found {}"
-    if isinstance(batch[0], torch.Tensor):
-        return torch.stack(batch, 0)
-    elif isinstance(batch[0], str):
-        return list(batch)
-    elif isinstance(batch[0], collections.abc.Sequence):
-        # transpose tuples, i.e. [[1, 2], ['a', 'b']] to be [[1, 'a'], [2, 'b']]
-        transposed = zip(*batch)
-        return [dsb_prediction_collate(samples) for samples in transposed]
-
-    raise TypeError((error_msg.format(type(batch[0]))))
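-
-# Illustrative sketch (not part of the original file): for a batch of (tensor, path)
-# pairs the Sequence branch transposes the pairs first, e.g.
-#   dsb_prediction_collate([(t1, 'a.png'), (t2, 'b.png')])
-#   # -> [torch.stack([t1, t2], 0), ['a.png', 'b.png']]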
-
-
-class DSB2018Dataset(ConfigDataset):
-    def __init__(self, root_dir, phase, transformer_config, expand_dims=True):
-        assert os.path.isdir(root_dir), f'{root_dir} is not a directory'
-        assert phase in ['train', 'val', 'test']
-
-        self.phase = phase
-
-        # load raw images
-        images_dir = os.path.join(root_dir, 'images')
-        assert os.path.isdir(images_dir)
-        self.images, self.paths = self._load_files(images_dir, expand_dims)
-        self.file_path = images_dir
-
-        stats = calculate_stats(self.images, True)
-
-        transformer = transforms.Transformer(transformer_config, stats)
-
-        # load raw images transformer
-        self.raw_transform = transformer.raw_transform()
-
-        if phase != 'test':
-            # load labeled images
-            masks_dir = os.path.join(root_dir, 'masks')
-            assert os.path.isdir(masks_dir)
-            self.masks, _ = self._load_files(masks_dir, expand_dims)
-            assert len(self.images) == len(self.masks)
-            # load label images transformer
-            self.masks_transform = transformer.label_transform()
-        else:
-            self.masks = None
-            self.masks_transform = None
-
-    def __getitem__(self, idx):
-        if idx >= len(self):
-            raise StopIteration
-
-        img = self.images[idx]
-        if self.phase != 'test':
-            mask = self.masks[idx]
-            return self.raw_transform(img), self.masks_transform(mask)
-        else:
-            return self.raw_transform(img), self.paths[idx]
-
-    def __len__(self):
-        return len(self.images)
-
-    @classmethod
-    def prediction_collate(cls, batch):
-        return dsb_prediction_collate(batch)
-
-    @classmethod
-    def create_datasets(cls, dataset_config, phase):
-        phase_config = dataset_config[phase]
-        # load data augmentation configuration
-        transformer_config = phase_config['transformer']
-        # load files to process
-        file_paths = phase_config['file_paths']
-        expand_dims = dataset_config.get('expand_dims', True)
-        return [cls(file_paths[0], phase, transformer_config, expand_dims)]
-
-    @staticmethod
-    def _load_files(dir_path, expand_dims):
-        files_data = []
-        paths = []
-        # sort for a deterministic order, so that images and masks loaded from the
-        # sibling 'images'/'masks' directories pair up by index
-        for file in sorted(os.listdir(dir_path)):
-            path = os.path.join(dir_path, file)
-            img = np.asarray(imageio.imread(path))
-            if expand_dims:
-                dims = img.ndim
-                img = np.expand_dims(img, axis=0)
-                if dims == 3:
-                    img = np.transpose(img, (3, 0, 1, 2))
-
-            files_data.append(img)
-            paths.append(path)
-
-        return files_data, paths
diff --git a/build/lib/pytorch3dunet/datasets/hdf5.py b/build/lib/pytorch3dunet/datasets/hdf5.py
deleted file mode 100644
index 040adb85..00000000
--- a/build/lib/pytorch3dunet/datasets/hdf5.py
+++ /dev/null
@@ -1,293 +0,0 @@
-import glob
-import os
-from abc import abstractmethod
-from itertools import chain
-
-import h5py
-
-import pytorch3dunet.augment.transforms as transforms
-from pytorch3dunet.datasets.utils import get_slice_builder, ConfigDataset, calculate_stats, mirror_pad
-from pytorch3dunet.unet3d.utils import get_logger
-
-logger = get_logger('HDF5Dataset')
-
-
-def _create_padded_indexes(indexes, halo_shape):
-    return tuple(slice(index.start, index.stop + 2 * halo) for index, halo in zip(indexes, halo_shape))
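-
-# Illustrative sketch (not part of the original file): with halo_shape (8, 8, 8) the
-# patch slices (slice(0, 64), slice(32, 96), slice(0, 64)) become
-# (slice(0, 80), slice(32, 112), slice(0, 80)) when indexing the mirror-padded volume:
-# each stop grows by 2 * halo while the start stays fixed.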
-
-
-def traverse_h5_paths(file_paths):
-    assert isinstance(file_paths, list)
-    results = []
-    for file_path in file_paths:
-        if os.path.isdir(file_path):
-            # if file path is a directory take all H5 files in that directory
-            iters = [glob.glob(os.path.join(file_path, ext)) for ext in ['*.h5', '*.hdf', '*.hdf5', '*.hd5']]
-            for fp in chain(*iters):
-                results.append(fp)
-        else:
-            results.append(file_path)
-    return results
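-
-# Illustrative sketch (not part of the original file), with hypothetical paths: given
-# file_paths = ['vol0.h5', '/data/train'], the directory entry expands to every H5
-# file it contains, e.g. ['vol0.h5', '/data/train/a.h5', '/data/train/b.hdf5'].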
-
-
-class AbstractHDF5Dataset(ConfigDataset):
-    """
-    Implementation of torch.utils.data.Dataset backed by the HDF5 files, which iterates over the raw and label datasets
-    patch by patch with a given stride.
-
-    Args:
-        file_path (str): path to H5 file containing raw data as well as labels and per pixel weights (optional)
-        phase (str): 'train' for training, 'val' for validation, 'test' for testing
-        slice_builder_config (dict): configuration of the SliceBuilder
-        transformer_config (dict): data augmentation configuration
-        raw_internal_path (str or list): H5 internal path to the raw dataset
-        label_internal_path (str or list): H5 internal path to the label dataset
-        weight_internal_path (str or list): H5 internal path to the per pixel weights (optional)
-        global_normalization (bool): if True, the mean and std of the raw data will be calculated over the whole dataset
-    """
-
-    def __init__(self, file_path, phase, slice_builder_config, transformer_config, raw_internal_path='raw',
-                 label_internal_path='label', weight_internal_path=None, global_normalization=True):
-        assert phase in ['train', 'val', 'test']
-
-        self.phase = phase
-        self.file_path = file_path
-        self.raw_internal_path = raw_internal_path
-        self.label_internal_path = label_internal_path
-        self.weight_internal_path = weight_internal_path
-
-        self.halo_shape = slice_builder_config.get('halo_shape', [0, 0, 0])
-
-        if global_normalization:
-            logger.info('Calculating mean and std of the raw data...')
-            with h5py.File(file_path, 'r') as f:
-                raw = f[raw_internal_path][:]
-                stats = calculate_stats(raw)
-        else:
-            stats = calculate_stats(None, True)
-
-        self.transformer = transforms.Transformer(transformer_config, stats)
-        self.raw_transform = self.transformer.raw_transform()
-
-        if phase != 'test':
-            # create label/weight transform only in train/val phase
-            self.label_transform = self.transformer.label_transform()
-
-            if weight_internal_path is not None:
-                self.weight_transform = self.transformer.weight_transform()
-            else:
-                self.weight_transform = None
-
-            self._check_volume_sizes()
-        else:
-            # 'test' phase used only for predictions so ignore the label dataset
-            self.label = None
-            self.weight_map = None
-
-            # compare patch and stride configuration
-            patch_shape = slice_builder_config.get('patch_shape')
-            stride_shape = slice_builder_config.get('stride_shape')
-            if sum(self.halo_shape) != 0 and patch_shape != stride_shape:
-                logger.warning(f'Found non-zero halo shape {self.halo_shape}: patch_shape and stride_shape '
-                               f'should be equal for optimal prediction performance, but found '
-                               f'patch_shape: {patch_shape} and stride_shape: {stride_shape}!')
-
-        with h5py.File(file_path, 'r') as f:
-            raw = f[raw_internal_path]
-            label = f[label_internal_path] if phase != 'test' else None
-            weight_map = f[weight_internal_path] if weight_internal_path is not None else None
-            # build slice indices for raw and label data sets
-            slice_builder = get_slice_builder(raw, label, weight_map, slice_builder_config)
-            self.raw_slices = slice_builder.raw_slices
-            self.label_slices = slice_builder.label_slices
-            self.weight_slices = slice_builder.weight_slices
-
-        self.patch_count = len(self.raw_slices)
-        logger.info(f'Number of patches: {self.patch_count}')
-
-    @abstractmethod
-    def get_raw_patch(self, idx):
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_label_patch(self, idx):
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_weight_patch(self, idx):
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_raw_padded_patch(self, idx):
-        raise NotImplementedError
-
-    def volume_shape(self):
-        with h5py.File(self.file_path, 'r') as f:
-            raw = f[self.raw_internal_path]
-            if raw.ndim == 3:
-                return raw.shape
-            else:
-                return raw.shape[1:]
-
-    def __getitem__(self, idx):
-        if idx >= len(self):
-            raise StopIteration
-
-        raw_idx = self.raw_slices[idx]
-
-        if self.phase == 'test':
-            if len(raw_idx) == 4:
-                # discard the channel dimension in the slices: the predictor needs only
-                # the spatial dimensions of the volume
-                raw_idx = raw_idx[1:]
-                raw_idx_padded = (slice(None),) + _create_padded_indexes(raw_idx, self.halo_shape)
-            else:
-                raw_idx_padded = _create_padded_indexes(raw_idx, self.halo_shape)
-
-            raw_patch_transformed = self.raw_transform(self.get_raw_padded_patch(raw_idx_padded))
-            return raw_patch_transformed, raw_idx
-        else:
-            raw_patch_transformed = self.raw_transform(self.get_raw_patch(raw_idx))
-
-            # get the slice for a given index 'idx'
-            label_idx = self.label_slices[idx]
-            label_patch_transformed = self.label_transform(self.get_label_patch(label_idx))
-            if self.weight_internal_path is not None:
-                weight_idx = self.weight_slices[idx]
-                weight_patch_transformed = self.weight_transform(self.get_weight_patch(weight_idx))
-                return raw_patch_transformed, label_patch_transformed, weight_patch_transformed
-            # return the transformed raw and label patches
-            return raw_patch_transformed, label_patch_transformed
-
-    def __len__(self):
-        return self.patch_count
-
-    def _check_volume_sizes(self):
-        def _volume_shape(volume):
-            if volume.ndim == 3:
-                return volume.shape
-            return volume.shape[1:]
-
-        with h5py.File(self.file_path, 'r') as f:
-            raw = f[self.raw_internal_path]
-            label = f[self.label_internal_path]
-            assert raw.ndim in [3, 4], 'Raw dataset must be 3D (DxHxW) or 4D (CxDxHxW)'
-            assert label.ndim in [3, 4], 'Label dataset must be 3D (DxHxW) or 4D (CxDxHxW)'
-            assert _volume_shape(raw) == _volume_shape(label), 'Raw and labels have to be of the same size'
-            if self.weight_internal_path is not None:
-                weight_map = f[self.weight_internal_path]
-                assert weight_map.ndim in [3, 4], 'Weight map dataset must be 3D (DxHxW) or 4D (CxDxHxW)'
-                assert _volume_shape(raw) == _volume_shape(weight_map), 'Raw and weight map have to be of the same size'
-
-    @classmethod
-    def create_datasets(cls, dataset_config, phase):
-        phase_config = dataset_config[phase]
-
-        # load data augmentation configuration
-        transformer_config = phase_config['transformer']
-        # load slice builder config
-        slice_builder_config = phase_config['slice_builder']
-        # load files to process
-        file_paths = phase_config['file_paths']
-        # file_paths may contain both files and directories; if the file_path is a directory all H5 files inside
-        # are going to be included in the final file_paths
-        file_paths = traverse_h5_paths(file_paths)
-
-        datasets = []
-        for file_path in file_paths:
-            try:
-                logger.info(f'Loading {phase} set from: {file_path}...')
-                dataset = cls(file_path=file_path,
-                              phase=phase,
-                              slice_builder_config=slice_builder_config,
-                              transformer_config=transformer_config,
-                              raw_internal_path=dataset_config.get('raw_internal_path', 'raw'),
-                              label_internal_path=dataset_config.get('label_internal_path', 'label'),
-                              weight_internal_path=dataset_config.get('weight_internal_path', None),
-                              global_normalization=dataset_config.get('global_normalization', None))
-                datasets.append(dataset)
-            except Exception:
-                logger.error(f'Skipping {phase} set: {file_path}', exc_info=True)
-        return datasets
-
-
-class StandardHDF5Dataset(AbstractHDF5Dataset):
-    """
-    Implementation of the HDF5 dataset which loads the data from the H5 files into the memory.
-    Fast but might consume a lot of memory.
-    """
-
-    def __init__(self, file_path, phase, slice_builder_config, transformer_config,
-                 raw_internal_path='raw', label_internal_path='label', weight_internal_path=None,
-                 global_normalization=True):
-        super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,
-                         transformer_config=transformer_config, raw_internal_path=raw_internal_path,
-                         label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
-                         global_normalization=global_normalization)
-        self._raw = None
-        self._raw_padded = None
-        self._label = None
-        self._weight_map = None
-
-    def get_raw_patch(self, idx):
-        if self._raw is None:
-            with h5py.File(self.file_path, 'r') as f:
-                assert self.raw_internal_path in f, f'Dataset {self.raw_internal_path} not found in {self.file_path}'
-                self._raw = f[self.raw_internal_path][:]
-        return self._raw[idx]
-
-    def get_label_patch(self, idx):
-        if self._label is None:
-            with h5py.File(self.file_path, 'r') as f:
-                assert self.label_internal_path in f, f'Dataset {self.label_internal_path} not found in {self.file_path}'
-                self._label = f[self.label_internal_path][:]
-        return self._label[idx]
-
-    def get_weight_patch(self, idx):
-        if self._weight_map is None:
-            with h5py.File(self.file_path, 'r') as f:
-                assert self.weight_internal_path in f, f'Dataset {self.weight_internal_path} not found in {self.file_path}'
-                self._weight_map = f[self.weight_internal_path][:]
-        return self._weight_map[idx]
-
-    def get_raw_padded_patch(self, idx):
-        if self._raw_padded is None:
-            with h5py.File(self.file_path, 'r') as f:
-                assert self.raw_internal_path in f, f'Dataset {self.raw_internal_path} not found in {self.file_path}'
-                self._raw_padded = mirror_pad(f[self.raw_internal_path][:], self.halo_shape)
-        return self._raw_padded[idx]
-
-
-class LazyHDF5Dataset(AbstractHDF5Dataset):
-    """Implementation of the HDF5 dataset which loads the data lazily. It's slower, but has a low memory footprint."""
-
-    def __init__(self, file_path, phase, slice_builder_config, transformer_config,
-                 raw_internal_path='raw', label_internal_path='label', weight_internal_path=None,
-                 global_normalization=False):
-        super().__init__(file_path=file_path, phase=phase, slice_builder_config=slice_builder_config,
-                         transformer_config=transformer_config, raw_internal_path=raw_internal_path,
-                         label_internal_path=label_internal_path, weight_internal_path=weight_internal_path,
-                         global_normalization=global_normalization)
-
-        logger.info("Using LazyHDF5Dataset")
-
-    def get_raw_patch(self, idx):
-        with h5py.File(self.file_path, 'r') as f:
-            return f[self.raw_internal_path][idx]
-
-    def get_label_patch(self, idx):
-        with h5py.File(self.file_path, 'r') as f:
-            return f[self.label_internal_path][idx]
-
-    def get_weight_patch(self, idx):
-        with h5py.File(self.file_path, 'r') as f:
-            return f[self.weight_internal_path][idx]
-
-    def get_raw_padded_patch(self, idx):
-        # open in 'r+' mode so the mirror-padded volume can be cached inside the H5 file
-        with h5py.File(self.file_path, 'r+') as f:
-            if 'raw_padded' in f:
-                return f['raw_padded'][idx]
-
-            raw = f[self.raw_internal_path][:]
-            raw_padded = mirror_pad(raw, self.halo_shape)
-            f.create_dataset('raw_padded', data=raw_padded, compression='gzip')
-            return raw_padded[idx]
diff --git a/build/lib/pytorch3dunet/datasets/utils.py b/build/lib/pytorch3dunet/datasets/utils.py
deleted file mode 100644
index 1ffeefe4..00000000
--- a/build/lib/pytorch3dunet/datasets/utils.py
+++ /dev/null
@@ -1,361 +0,0 @@
-import collections
-from typing import Any
-
-import numpy as np
-import torch
-from torch.utils.data import DataLoader, ConcatDataset, Dataset
-
-from pytorch3dunet.unet3d.utils import get_logger, get_class
-
-logger = get_logger('Dataset')
-
-
-class ConfigDataset(Dataset):
-    def __getitem__(self, index):
-        raise NotImplementedError
-
-    def __len__(self):
-        raise NotImplementedError
-
-    @classmethod
-    def create_datasets(cls, dataset_config, phase):
-        """
-        Factory method for creating a list of datasets based on the provided config.
-
-        Args:
-            dataset_config (dict): dataset configuration
-            phase (str): one of ['train', 'val', 'test']
-
-        Returns:
-            list of `Dataset` instances
-        """
-        raise NotImplementedError
-
-    @classmethod
-    def prediction_collate(cls, batch):
-        """Default collate_fn. Override in child class for non-standard datasets."""
-        return default_prediction_collate(batch)
-
-
-class SliceBuilder:
-    """
-    Builds the positions of the patches in a given raw/label/weight ndarray based on the patch and stride shape.
-
-    Args:
-        raw_dataset (ndarray): raw data
-        label_dataset (ndarray): ground truth labels
-        weight_dataset (ndarray): weights for the labels
-        patch_shape (tuple): the shape of the patch DxHxW
-        stride_shape (tuple): the shape of the stride DxHxW
-        kwargs: additional metadata
-    """
-
-    def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, **kwargs):
-        patch_shape = tuple(patch_shape)
-        stride_shape = tuple(stride_shape)
-        skip_shape_check = kwargs.get('skip_shape_check', False)
-        if not skip_shape_check:
-            self._check_patch_shape(patch_shape)
-
-        self._raw_slices = self._build_slices(raw_dataset, patch_shape, stride_shape)
-        if label_dataset is None:
-            self._label_slices = None
-        else:
-            # build slices for the label dataset and check that they match the raw slices
-            self._label_slices = self._build_slices(label_dataset, patch_shape, stride_shape)
-            assert len(self._raw_slices) == len(self._label_slices)
-        if weight_dataset is None:
-            self._weight_slices = None
-        else:
-            self._weight_slices = self._build_slices(weight_dataset, patch_shape, stride_shape)
-            assert len(self.raw_slices) == len(self._weight_slices)
-
-    @property
-    def raw_slices(self):
-        return self._raw_slices
-
-    @property
-    def label_slices(self):
-        return self._label_slices
-
-    @property
-    def weight_slices(self):
-        return self._weight_slices
-
-    @staticmethod
-    def _build_slices(dataset, patch_shape, stride_shape):
-        """Iterates over a given n-dim dataset patch-by-patch with a given stride
-        and builds an array of slice positions.
-
-        Returns:
-            list of slices, i.e.
-            [(slice, slice, slice, slice), ...] if len(shape) == 4
-            [(slice, slice, slice), ...] if len(shape) == 3
-        """
-        slices = []
-        if dataset.ndim == 4:
-            in_channels, i_z, i_y, i_x = dataset.shape
-        else:
-            i_z, i_y, i_x = dataset.shape
-
-        k_z, k_y, k_x = patch_shape
-        s_z, s_y, s_x = stride_shape
-        z_steps = SliceBuilder._gen_indices(i_z, k_z, s_z)
-        for z in z_steps:
-            y_steps = SliceBuilder._gen_indices(i_y, k_y, s_y)
-            for y in y_steps:
-                x_steps = SliceBuilder._gen_indices(i_x, k_x, s_x)
-                for x in x_steps:
-                    slice_idx = (
-                        slice(z, z + k_z),
-                        slice(y, y + k_y),
-                        slice(x, x + k_x),
-                    )
-                    if dataset.ndim == 4:
-                        slice_idx = (slice(0, in_channels),) + slice_idx
-                    slices.append(slice_idx)
-        return slices
-
-    @staticmethod
-    def _gen_indices(i, k, s):
-        assert i >= k, 'Sample size has to be at least as big as the patch size'
-        for j in range(0, i - k + 1, s):
-            yield j
-        if j + k < i:
-            yield i - k
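-
-    # Illustrative sketch (not part of the original file): _gen_indices(11, 4, 3)
-    # yields 0, 3, 6 from the strided range plus the tail index 7, so the final
-    # patch slice(7, 11) is flush with the end of the volume.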
-
-    @staticmethod
-    def _check_patch_shape(patch_shape):
-        assert len(patch_shape) == 3, 'patch_shape must be a 3D tuple'
-        assert patch_shape[1] >= 64 and patch_shape[2] >= 64, 'Height and Width must be greater than or equal to 64'
-
-
-class FilterSliceBuilder(SliceBuilder):
-    """
-    Filters out patches in which the fraction of non-ignore (non-zero) voxels is below `threshold`;
-    a `slack_acceptance` fraction of the rejected patches is kept at random anyway
-    """
-
-    def __init__(self, raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, ignore_index=None,
-                 threshold=0.6, slack_acceptance=0.01, **kwargs):
-        super().__init__(raw_dataset, label_dataset, weight_dataset, patch_shape, stride_shape, **kwargs)
-        if label_dataset is None:
-            return
-
-        rand_state = np.random.RandomState(47)
-
-        def ignore_predicate(raw_label_idx):
-            label_idx = raw_label_idx[1]
-            patch = label_dataset[label_idx]
-            if ignore_index is not None:
-                patch = np.copy(patch)
-                patch[patch == ignore_index] = 0
-            non_ignore_fraction = np.count_nonzero(patch) / patch.size
-            return non_ignore_fraction > threshold or rand_state.rand() < slack_acceptance
-
-        zipped_slices = zip(self.raw_slices, self.label_slices)
-        # ignore slices containing too much ignore_index
-        logger.info('Filtering slices...')
-        filtered_slices = list(filter(ignore_predicate, zipped_slices))
-        # unzip and save slices
-        raw_slices, label_slices = zip(*filtered_slices)
-        self._raw_slices = list(raw_slices)
-        self._label_slices = list(label_slices)
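-
-    # Illustrative sketch (not part of the original file): with threshold=0.6 a patch is
-    # kept only if more than 60% of its voxels carry a non-zero, non-ignore label;
-    # about slack_acceptance (default 1%) of the rejected patches are kept at random anyway.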
-
-
-def _loader_classes(class_name):
-    modules = [
-        'pytorch3dunet.datasets.hdf5',
-        'pytorch3dunet.datasets.dsb',
-        'pytorch3dunet.datasets.utils'
-    ]
-    return get_class(class_name, modules)
-
-
-def get_slice_builder(raws, labels, weight_maps, config):
-    assert 'name' in config
-    logger.info(f"Slice builder config: {config}")
-    slice_builder_cls = _loader_classes(config['name'])
-    return slice_builder_cls(raws, labels, weight_maps, **config)
-
-
-def get_train_loaders(config):
-    """
-    Returns dictionary containing the training and validation loaders (torch.utils.data.DataLoader).
-
-    :param config: a top level configuration object containing the 'loaders' key
-    :return: dict {
-        'train': <train_loader>
-        'val': <val_loader>
-    }
-    """
-    assert 'loaders' in config, 'Could not find data loaders configuration'
-    loaders_config = config['loaders']
-
-    logger.info('Creating training and validation set loaders...')
-
-    # get dataset class
-    dataset_cls_str = loaders_config.get('dataset', None)
-    if dataset_cls_str is None:
-        dataset_cls_str = 'StandardHDF5Dataset'
-        logger.warning(f"Cannot find dataset class in the config. Using default '{dataset_cls_str}'.")
-    dataset_class = _loader_classes(dataset_cls_str)
-
-    assert set(loaders_config['train']['file_paths']).isdisjoint(loaders_config['val']['file_paths']), \
-        "Train and validation 'file_paths' overlap. One cannot use validation data for training!"
-
-    train_datasets = dataset_class.create_datasets(loaders_config, phase='train')
-
-    val_datasets = dataset_class.create_datasets(loaders_config, phase='val')
-
-    num_workers = loaders_config.get('num_workers', 1)
-    logger.info(f'Number of workers for train/val dataloader: {num_workers}')
-    batch_size = loaders_config.get('batch_size', 1)
-    if torch.cuda.device_count() > 1 and not config['device'] == 'cpu':
-        logger.info(
-            f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}')
-        batch_size = batch_size * torch.cuda.device_count()
-
-    logger.info(f'Batch size for train/val loader: {batch_size}')
-    # when training with volumetric data use batch_size of 1 due to GPU memory constraints
-    return {
-        'train': DataLoader(ConcatDataset(train_datasets), batch_size=batch_size, shuffle=True, pin_memory=True,
-                            num_workers=num_workers),
-        # don't shuffle during validation: useful when showing how predictions for a given batch get better over time
-        'val': DataLoader(ConcatDataset(val_datasets), batch_size=batch_size, shuffle=False, pin_memory=True,
-                          num_workers=num_workers)
-    }
-
-
-def get_test_loaders(config):
-    """
-    Returns test DataLoader.
-
-    :return: generator of DataLoader objects
-    """
-
-    assert 'loaders' in config, 'Could not find data loaders configuration'
-    loaders_config = config['loaders']
-
-    logger.info('Creating test set loaders...')
-
-    # get dataset class
-    dataset_cls_str = loaders_config.get('dataset', None)
-    if dataset_cls_str is None:
-        dataset_cls_str = 'StandardHDF5Dataset'
-        logger.warning(f"Cannot find dataset class in the config. Using default '{dataset_cls_str}'.")
-    dataset_class = _loader_classes(dataset_cls_str)
-
-    test_datasets = dataset_class.create_datasets(loaders_config, phase='test')
-
-    num_workers = loaders_config.get('num_workers', 1)
-    logger.info(f'Number of workers for the dataloader: {num_workers}')
-
-    batch_size = loaders_config.get('batch_size', 1)
-    if torch.cuda.device_count() > 1 and not config['device'] == 'cpu':
-        logger.info(
-            f'{torch.cuda.device_count()} GPUs available. Using batch_size = {torch.cuda.device_count()} * {batch_size}')
-        batch_size = batch_size * torch.cuda.device_count()
-
-    logger.info(f'Batch size for dataloader: {batch_size}')
-
-    # use generator in order to create data loaders lazily one by one
-    for test_dataset in test_datasets:
-        logger.info(f'Loading test set from: {test_dataset.file_path}...')
-        if hasattr(test_dataset, 'prediction_collate'):
-            collate_fn = test_dataset.prediction_collate
-        else:
-            collate_fn = default_prediction_collate
-
-        yield DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
-                         collate_fn=collate_fn)
-
-
-def default_prediction_collate(batch):
-    """
-    Default collate_fn to form a mini-batch of Tensor(s) for HDF5 based datasets
-    """
-    error_msg = "batch must contain tensors or slice; found {}"
-    if isinstance(batch[0], torch.Tensor):
-        return torch.stack(batch, 0)
-    elif isinstance(batch[0], tuple) and isinstance(batch[0][0], slice):
-        return batch
-    elif isinstance(batch[0], collections.abc.Sequence):
-        transposed = zip(*batch)
-        return [default_prediction_collate(samples) for samples in transposed]
-
-    raise TypeError((error_msg.format(type(batch[0]))))
-
-
-def calculate_stats(img: np.ndarray, skip: bool = False) -> dict[str, Any]:
-    """
-    Calculates the 1st percentile, 99.6th percentile, mean and standard deviation of the image.
-
-    Args:
-        img: the input image array
-        skip: if True, skip the calculation and return None for all values
-
-    Returns:
-        dict with keys 'pmin', 'pmax', 'mean' and 'std' (all None if skip is True)
-    """
-    if not skip:
-        pmin, pmax, mean, std = np.percentile(img, 1), np.percentile(img, 99.6), np.mean(img), np.std(img)
-    else:
-        pmin, pmax, mean, std = None, None, None, None
-
-    return {
-        'pmin': pmin,
-        'pmax': pmax,
-        'mean': mean,
-        'std': std
-    }
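-
-# Illustrative sketch (not part of the original file):
-#   stats = calculate_stats(np.random.rand(32, 64, 64))
-#   stats['pmin'] < stats['mean'] < stats['pmax']  # 1st / 99.6th percentiles bracket the mean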
-
-
-def mirror_pad(image, padding_shape):
-    """
-    Pad the image with a mirror reflection of itself.
-
-    This function is used on data in its original shape before it is split into patches.
-
-    Args:
-        image (np.ndarray): The input image array to be padded.
-        padding_shape (tuple of int): Specifies the amount of padding for each spatial dimension; must be ZYX (three values).
-
-    Returns:
-        np.ndarray: The mirror-padded image.
-
-    Raises:
-        ValueError: If any element of padding_shape is negative.
-    """
-    assert len(padding_shape) == 3, "Padding shape must be specified for each dimension: ZYX"
-
-    if any(p < 0 for p in padding_shape):
-        raise ValueError("padding_shape must be non-negative")
-
-    if all(p == 0 for p in padding_shape):
-        return image
-
-    pad_width = [(p, p) for p in padding_shape]
-
-    if image.ndim == 4:
-        pad_width = [(0, 0)] + pad_width
-    return np.pad(image, pad_width, mode='reflect')
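-
-# Illustrative sketch (not part of the original file): a (64, 128, 128) volume padded with
-# padding_shape (8, 8, 8) becomes (80, 144, 144); a 4D (C, D, H, W) input keeps its channel
-# dimension unpadded thanks to the leading (0, 0) pad width.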
-
-
-def remove_padding(m, padding_shape):
-    """
-    Removes padding from the margins of a multi-dimensional array.
-
-    Args:
-        m (np.ndarray): The input array to be unpadded.
-        padding_shape (tuple of int, optional): The amount of padding to remove from each dimension.
-            Assumes the tuple length matches the array dimensions.
-
-    Returns:
-        np.ndarray: The unpadded array.
-    """
-    if padding_shape is None:
-        return m
-
-    # Correctly construct slice objects for each dimension in padding_shape and apply them to m.
-    return m[(..., *(slice(p, -p or None) for p in padding_shape))]
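-
-# Illustrative sketch (not part of the original file): slice(p, -p or None) degrades
-# gracefully for p == 0, because -0 is falsy and slice(0, None) selects the whole axis;
-# remove_padding(mirror_pad(m, (8, 8, 8)), (8, 8, 8)) recovers the original m.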
diff --git a/build/lib/pytorch3dunet/predict.py b/build/lib/pytorch3dunet/predict.py
deleted file mode 100644
index cc54fcf7..00000000
--- a/build/lib/pytorch3dunet/predict.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import importlib
-import os
-
-import torch
-import torch.nn as nn
-
-from pytorch3dunet.datasets.utils import get_test_loaders
-from pytorch3dunet.unet3d import utils
-from pytorch3dunet.unet3d.config import load_config
-from pytorch3dunet.unet3d.model import get_model
-
-logger = utils.get_logger('UNet3DPredict')
-
-
-def get_predictor(model, config):
-    output_dir = config['loaders'].get('output_dir', None)
-    # override output_dir if provided in the 'predictor' section of the config
-    output_dir = config.get('predictor', {}).get('output_dir', output_dir)
-    if output_dir is not None:
-        os.makedirs(output_dir, exist_ok=True)
-
-    predictor_config = config.get('predictor', {})
-    class_name = predictor_config.get('name', 'StandardPredictor')
-
-    m = importlib.import_module('pytorch3dunet.unet3d.predictor')
-    predictor_class = getattr(m, class_name)
-    out_channels = config['model'].get('out_channels')
-    return predictor_class(model, output_dir, out_channels, **predictor_config)
-
-
-def main():
-    # Load configuration
-    config, _ = load_config()
-
-    # Create the model
-    model = get_model(config['model'])
-
-    # Load model state
-    model_path = config['model_path']
-    logger.info(f'Loading model from {model_path}...')
-    utils.load_checkpoint(model_path, model)
-    # use DataParallel if more than 1 GPU available
-
-    if torch.cuda.device_count() > 1 and not config['device'] == 'cpu':
-        model = nn.DataParallel(model)
-        logger.info(f'Using {torch.cuda.device_count()} GPUs for prediction')
-    if torch.cuda.is_available() and not config['device'] == 'cpu':
-        model = model.cuda()
-
-    # create predictor instance
-    predictor = get_predictor(model, config)
-
-    for test_loader in get_test_loaders(config):
-        # run the model prediction on the test_loader and save the results in the output_dir
-        predictor(test_loader)
-
-
-if __name__ == '__main__':
-    main()
diff --git a/build/lib/pytorch3dunet/train.py b/build/lib/pytorch3dunet/train.py
deleted file mode 100644
index eceaf719..00000000
--- a/build/lib/pytorch3dunet/train.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import random
-
-import torch
-
-from pytorch3dunet.unet3d.config import load_config, copy_config
-from pytorch3dunet.unet3d.trainer import create_trainer
-from pytorch3dunet.unet3d.utils import get_logger
-
-logger = get_logger('TrainingSetup')
-
-
-def main():
-    # Load and log experiment configuration
-    config, config_path = load_config()
-    logger.info(config)
-
-    manual_seed = config.get('manual_seed', None)
-    if manual_seed is not None:
-        logger.info(f'Seed the RNG for all devices with {manual_seed}')
-        logger.warning('Using CuDNN deterministic setting. This may slow down the training!')
-        random.seed(manual_seed)
-        torch.manual_seed(manual_seed)
-        # see https://pytorch.org/docs/stable/notes/randomness.html
-        torch.backends.cudnn.deterministic = True
-
-    # Create trainer
-    trainer = create_trainer(config)
-    # Copy config file
-    copy_config(config, config_path)
-    # Start training
-    trainer.fit()
-
-
-if __name__ == '__main__':
-    main()
diff --git a/build/lib/pytorch3dunet/unet3d/__init__.py b/build/lib/pytorch3dunet/unet3d/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/build/lib/pytorch3dunet/unet3d/buildingblocks.py b/build/lib/pytorch3dunet/unet3d/buildingblocks.py
deleted file mode 100644
index 25679c24..00000000
--- a/build/lib/pytorch3dunet/unet3d/buildingblocks.py
+++ /dev/null
@@ -1,545 +0,0 @@
-from functools import partial
-
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-from pytorch3dunet.unet3d.se import ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D
-
-
-def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding,
-                dropout_prob, is3d):
-    """
-    Create a list of modules which together constitute a single conv layer with non-linearity
-    and optional batchnorm/groupnorm.
-
-    Args:
-        in_channels (int): number of input channels
-        out_channels (int): number of output channels
-        kernel_size(int or tuple): size of the convolving kernel
-        order (string): order of things, e.g.
-            'cr' -> conv + ReLU
-            'gcr' -> groupnorm + conv + ReLU
-            'cl' -> conv + LeakyReLU
-            'ce' -> conv + ELU
-            'bcr' -> batchnorm + conv + ReLU
-            'cbrd' -> conv + batchnorm + ReLU + dropout
-            'cbrD' -> conv + batchnorm + ReLU + dropout2d
-        num_groups (int): number of groups for the GroupNorm
-        padding (int or tuple): add zero-padding added to all three sides of the input
-        dropout_prob (float): dropout probability
-        is3d (bool): if True use Conv3d, otherwise use Conv2d
-    Return:
-        list of tuple (name, module)
-    """
-    assert 'c' in order, "Conv layer MUST be present"
-    assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
-
-    modules = []
-    for i, char in enumerate(order):
-        if char == 'r':
-            modules.append(('ReLU', nn.ReLU(inplace=True)))
-        elif char == 'l':
-            modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True)))
-        elif char == 'e':
-            modules.append(('ELU', nn.ELU(inplace=True)))
-        elif char == 'c':
-            # add learnable bias only in the absence of batchnorm/groupnorm
-            bias = not ('g' in order or 'b' in order)
-            if is3d:
-                conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
-            else:
-                conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
-
-            modules.append(('conv', conv))
-        elif char == 'g':
-            is_before_conv = i < order.index('c')
-            if is_before_conv:
-                num_channels = in_channels
-            else:
-                num_channels = out_channels
-
-            # use only one group if the given number of groups is greater than the number of channels
-            if num_channels < num_groups:
-                num_groups = 1
-
-            assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
-            modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
-        elif char == 'b':
-            is_before_conv = i < order.index('c')
-            if is3d:
-                bn = nn.BatchNorm3d
-            else:
-                bn = nn.BatchNorm2d
-
-            if is_before_conv:
-                modules.append(('batchnorm', bn(in_channels)))
-            else:
-                modules.append(('batchnorm', bn(out_channels)))
-        elif char == 'd':
-            modules.append(('dropout', nn.Dropout(p=dropout_prob)))
-        elif char == 'D':
-            modules.append(('dropout2d', nn.Dropout2d(p=dropout_prob)))
-        else:
-            raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 'd', 'D']")
-
-    return modules
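-
-# Illustrative sketch (not part of the original file): order='gcr' with num_groups=8 yields
-# [('groupnorm', GroupNorm(8, in_channels)), ('conv', Conv3d(..., bias=False)), ('ReLU', ReLU())];
-# the conv omits its bias whenever a norm layer ('g' or 'b') appears anywhere in the order.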
-
-
-class SingleConv(nn.Sequential):
-    """
-    Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
-    of operations can be specified via the `order` parameter
-
-    Args:
-        in_channels (int): number of input channels
-        out_channels (int): number of output channels
-        kernel_size (int or tuple): size of the convolving kernel
-        order (string): determines the order of layers, e.g.
-            'cr' -> conv + ReLU
-            'crg' -> conv + ReLU + groupnorm
-            'cl' -> conv + LeakyReLU
-            'ce' -> conv + ELU
-        num_groups (int): number of groups for the GroupNorm
-        padding (int or tuple): add zero-padding
-        dropout_prob (float): dropout probability, default 0.1
-        is3d (bool): if True use Conv3d, otherwise use Conv2d
-    """
-
-    def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8,
-                 padding=1, dropout_prob=0.1, is3d=True):
-        super(SingleConv, self).__init__()
-
-        for name, module in create_conv(in_channels, out_channels, kernel_size, order,
-                                        num_groups, padding, dropout_prob, is3d):
-            self.add_module(name, module)
-
-
-class DoubleConv(nn.Sequential):
-    """
-    A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
-    We use (GroupNorm3d+Conv3d+ReLU) by default (order='gcr').
-    This can be changed however by providing the 'order' argument, e.g. in order
-    to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
-    Use padded convolutions to make sure that the output (H_out, W_out) is the same
-    as (H_in, W_in), so that you don't have to crop in the decoder path.
-
-    Args:
-        in_channels (int): number of input channels
-        out_channels (int): number of output channels
-        encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
-        kernel_size (int or tuple): size of the convolving kernel
-        order (string): determines the order of layers, e.g.
-            'cr' -> conv + ReLU
-            'crg' -> conv + ReLU + groupnorm
-            'cl' -> conv + LeakyReLU
-            'ce' -> conv + ELU
-        num_groups (int): number of groups for the GroupNorm
-        padding (int or tuple): add zero-padding added to all three sides of the input
-        upscale (int): channel upscale factor for the first convolution in the encoder; if 1,
-            the first convolution already outputs `out_channels`, otherwise `out_channels // 2`. Default: 2
-        dropout_prob (float or tuple): dropout probability for each convolution, default 0.1
-        is3d (bool): if True use Conv3d instead of Conv2d layers
-    """
-
-    def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr',
-                 num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
-        super(DoubleConv, self).__init__()
-        if encoder:
-            # we're in the encoder path
-            conv1_in_channels = in_channels
-            if upscale == 1:
-                conv1_out_channels = out_channels
-            else:
-                conv1_out_channels = out_channels // 2
-            if conv1_out_channels < in_channels:
-                conv1_out_channels = in_channels
-            conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
-        else:
-            # we're in the decoder path, decrease the number of channels in the 1st convolution
-            conv1_in_channels, conv1_out_channels = in_channels, out_channels
-            conv2_in_channels, conv2_out_channels = out_channels, out_channels
-
-        # dropout_prob may be a pair with a separate probability for each convolution
-        if isinstance(dropout_prob, (list, tuple)):
-            dropout_prob1, dropout_prob2 = dropout_prob
-        else:
-            dropout_prob1 = dropout_prob2 = dropout_prob
-
-        # conv1
-        self.add_module('SingleConv1',
-                        SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups,
-                                   padding=padding, dropout_prob=dropout_prob1, is3d=is3d))
-        # conv2
-        self.add_module('SingleConv2',
-                        SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups,
-                                   padding=padding, dropout_prob=dropout_prob2, is3d=is3d))
-
-
-class ResNetBlock(nn.Module):
-    """
-    Residual block that can be used instead of standard DoubleConv in the Encoder module.
-    Motivated by: https://arxiv.org/pdf/1706.00120.pdf
-
-    Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
-    """
-
-    def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, is3d=True, **kwargs):
-        super(ResNetBlock, self).__init__()
-
-        if in_channels != out_channels:
-            # conv1x1 for increasing the number of channels
-            if is3d:
-                self.conv1 = nn.Conv3d(in_channels, out_channels, 1)
-            else:
-                self.conv1 = nn.Conv2d(in_channels, out_channels, 1)
-        else:
-            self.conv1 = nn.Identity()
-
-        # residual block
-        self.conv2 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups,
-                                is3d=is3d)
-        # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
-        n_order = order
-        for c in 'rel':
-            n_order = n_order.replace(c, '')
-        self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
-                                num_groups=num_groups, is3d=is3d)
-
-        # create non-linearity separately
-        if 'l' in order:
-            self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
-        elif 'e' in order:
-            self.non_linearity = nn.ELU(inplace=True)
-        else:
-            self.non_linearity = nn.ReLU(inplace=True)
-
-    def forward(self, x):
-        # apply first convolution to bring the number of channels to out_channels
-        residual = self.conv1(x)
-
-        # residual block
-        out = self.conv2(residual)
-        out = self.conv3(out)
-
-        out += residual
-        out = self.non_linearity(out)
-
-        return out
-
-
-class ResNetBlockSE(ResNetBlock):
-    def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, se_module='scse', **kwargs):
-        super(ResNetBlockSE, self).__init__(
-            in_channels, out_channels, kernel_size=kernel_size, order=order,
-            num_groups=num_groups, **kwargs)
-        assert se_module in ['scse', 'cse', 'sse']
-        if se_module == 'scse':
-            self.se_module = ChannelSpatialSELayer3D(num_channels=out_channels, reduction_ratio=1)
-        elif se_module == 'cse':
-            self.se_module = ChannelSELayer3D(num_channels=out_channels, reduction_ratio=1)
-        elif se_module == 'sse':
-            self.se_module = SpatialSELayer3D(num_channels=out_channels)
-
-    def forward(self, x):
-        out = super().forward(x)
-        out = self.se_module(out)
-        return out
-
-
-class Encoder(nn.Module):
-    """
-    A single module from the encoder path consisting of an optional max
-    pooling layer followed by a basic module (DoubleConv or ResNetBlock).
-    The MaxPool kernel_size may differ from the standard (2,2,2), e.g. if the
-    volumetric data is anisotropic; make sure to use a complementary
-    scale_factor in the decoder path.
-
-    Args:
-        in_channels (int): number of input channels
-        out_channels (int): number of output channels
-        conv_kernel_size (int or tuple): size of the convolving kernel
-        apply_pooling (bool): if True use MaxPool3d before DoubleConv
-        pool_kernel_size (int or tuple): the size of the window
-        pool_type (str): pooling layer: 'max' or 'avg'
-        basic_module(nn.Module): either ResNetBlock or DoubleConv
-        conv_layer_order (string): determines the order of layers
-            in `DoubleConv` module. See `DoubleConv` for more info.
-        num_groups (int): number of groups for the GroupNorm
-        padding (int or tuple): add zero-padding added to all three sides of the input
-        upscale (int): channel upscale factor passed to the encoder's `DoubleConv` (see `DoubleConv`), default: 2
-        dropout_prob (float or tuple): dropout probability, default 0.1
-        is3d (bool): use 3d or 2d convolutions/pooling operation
-    """
-
-    def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
-                 pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
-                 num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
-        super(Encoder, self).__init__()
-        assert pool_type in ['max', 'avg']
-        if apply_pooling:
-            if pool_type == 'max':
-                if is3d:
-                    self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
-                else:
-                    self.pooling = nn.MaxPool2d(kernel_size=pool_kernel_size)
-            else:
-                if is3d:
-                    self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
-                else:
-                    self.pooling = nn.AvgPool2d(kernel_size=pool_kernel_size)
-        else:
-            self.pooling = None
-
-        self.basic_module = basic_module(in_channels, out_channels,
-                                         encoder=True,
-                                         kernel_size=conv_kernel_size,
-                                         order=conv_layer_order,
-                                         num_groups=num_groups,
-                                         padding=padding,
-                                         upscale=upscale,
-                                         dropout_prob=dropout_prob,
-                                         is3d=is3d)
-
-    def forward(self, x):
-        if self.pooling is not None:
-            x = self.pooling(x)
-        x = self.basic_module(x)
-        return x
-
-
-class Decoder(nn.Module):
-    """
-    A single module for decoder path consisting of the upsampling layer
-    (either learned ConvTranspose3d or nearest neighbor interpolation)
-    followed by a basic module (DoubleConv or ResNetBlock).
-
-    Args:
-        in_channels (int): number of input channels
-        out_channels (int): number of output channels
-        conv_kernel_size (int or tuple): size of the convolving kernel
-        scale_factor (int or tuple): used as the multiplier for the image H/W/D in
-            case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
-            from the corresponding encoder
-        basic_module(nn.Module): either ResNetBlock or DoubleConv
-        conv_layer_order (string): determines the order of layers
-            in `DoubleConv` module. See `DoubleConv` for more info.
-        num_groups (int): number of groups for the GroupNorm
-        padding (int or tuple): add zero-padding added to all three sides of the input
-        upsample (str): algorithm used for upsampling:
-            InterpolateUpsampling:   'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
-            TransposeConvUpsampling: 'deconv'
-            No upsampling:           None
-            Default: 'default' (chooses automatically)
-        dropout_prob (float or tuple): dropout probability, default 0.1
-    """
-
-    def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2, basic_module=DoubleConv,
-                 conv_layer_order='gcr', num_groups=8, padding=1, upsample='default',
-                 dropout_prob=0.1, is3d=True):
-        super(Decoder, self).__init__()
-
-        # perform concat joining per default
-        concat = True
-
-        # don't adapt channels after join operation
-        adapt_channels = False
-
-        if upsample is not None and upsample != 'none':
-            if upsample == 'default':
-                if basic_module == DoubleConv:
-                    upsample = 'nearest'  # use nearest neighbor interpolation for upsampling
-                    concat = True  # use concat joining
-                    adapt_channels = False  # don't adapt channels
-                elif basic_module == ResNetBlock or basic_module == ResNetBlockSE:
-                    upsample = 'deconv'  # use deconvolution upsampling
-                    concat = False  # use summation joining
-                    adapt_channels = True  # adapt channels after joining
-
-            # perform deconvolution upsampling if mode is deconv
-            if upsample == 'deconv':
-                self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels,
-                                                          kernel_size=conv_kernel_size, scale_factor=scale_factor,
-                                                          is3d=is3d)
-            else:
-                self.upsampling = InterpolateUpsampling(mode=upsample)
-        else:
-            # no upsampling; the default concat joining is used
-            self.upsampling = NoUpsampling()
-
-        # perform joining operation (concatenation or summation)
-        self.joining = partial(self._joining, concat=concat)
-
-        # adapt the number of in_channels for the ResNetBlock
-        if adapt_channels is True:
-            in_channels = out_channels
-
-        self.basic_module = basic_module(in_channels, out_channels,
-                                         encoder=False,
-                                         kernel_size=conv_kernel_size,
-                                         order=conv_layer_order,
-                                         num_groups=num_groups,
-                                         padding=padding,
-                                         dropout_prob=dropout_prob,
-                                         is3d=is3d)
-
-    def forward(self, encoder_features, x):
-        x = self.upsampling(encoder_features=encoder_features, x=x)
-        x = self.joining(encoder_features, x)
-        x = self.basic_module(x)
-        return x
-
-    @staticmethod
-    def _joining(encoder_features, x, concat):
-        if concat:
-            return torch.cat((encoder_features, x), dim=1)
-        else:
-            return encoder_features + x
-
-
-def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
-                    conv_upscale, dropout_prob,
-                    layer_order, num_groups, pool_kernel_size, is3d):
-    # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
-    encoders = []
-    for i, out_feature_num in enumerate(f_maps):
-        if i == 0:
-            # the first encoder operates at full resolution, so skip pooling
-            encoder = Encoder(in_channels, out_feature_num,
-                              apply_pooling=False,
-                              basic_module=basic_module,
-                              conv_layer_order=layer_order,
-                              conv_kernel_size=conv_kernel_size,
-                              num_groups=num_groups,
-                              padding=conv_padding,
-                              upscale=conv_upscale,
-                              dropout_prob=dropout_prob,
-                              is3d=is3d)
-        else:
-            encoder = Encoder(f_maps[i - 1], out_feature_num,
-                              basic_module=basic_module,
-                              conv_layer_order=layer_order,
-                              conv_kernel_size=conv_kernel_size,
-                              num_groups=num_groups,
-                              pool_kernel_size=pool_kernel_size,
-                              padding=conv_padding,
-                              upscale=conv_upscale,
-                              dropout_prob=dropout_prob,
-                              is3d=is3d)
-
-        encoders.append(encoder)
-
-    return nn.ModuleList(encoders)
-
-
-def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
-                    num_groups, upsample, dropout_prob, is3d):
-    # create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1`
-    decoders = []
-    reversed_f_maps = list(reversed(f_maps))
-    for i in range(len(reversed_f_maps) - 1):
-        if basic_module == DoubleConv and upsample != 'deconv':
-            in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
-        else:
-            in_feature_num = reversed_f_maps[i]
-
-        out_feature_num = reversed_f_maps[i + 1]
-
-        decoder = Decoder(in_feature_num, out_feature_num,
-                          basic_module=basic_module,
-                          conv_layer_order=layer_order,
-                          conv_kernel_size=conv_kernel_size,
-                          num_groups=num_groups,
-                          padding=conv_padding,
-                          upsample=upsample,
-                          dropout_prob=dropout_prob,
-                          is3d=is3d)
-        decoders.append(decoder)
-    return nn.ModuleList(decoders)
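-
-# Illustrative sketch (not part of the original file): for f_maps = (64, 128, 256) with
-# DoubleConv and interpolation upsampling, the decoders map 256 + 128 -> 128 and
-# 128 + 64 -> 64 channels, since each upsampled feature map is concatenated with the
-# skip connection from the corresponding encoder.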
-
-
-class AbstractUpsampling(nn.Module):
-    """
-    Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either
-    interpolation or learned transposed convolution.
-    """
-
-    def __init__(self, upsample):
-        super(AbstractUpsampling, self).__init__()
-        self.upsample = upsample
-
-    def forward(self, encoder_features, x):
-        # get the spatial dimensions of the output given the encoder_features
-        output_size = encoder_features.size()[2:]
-        # upsample the input and return
-        return self.upsample(x, output_size)
-
-
-class InterpolateUpsampling(AbstractUpsampling):
-    """
-    Args:
-        mode (str): algorithm used for upsampling:
-            'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
-            used only if transposed_conv is False
-    """
-
-    def __init__(self, mode='nearest'):
-        upsample = partial(self._interpolate, mode=mode)
-        super().__init__(upsample)
-
-    @staticmethod
-    def _interpolate(x, size, mode):
-        return F.interpolate(x, size=size, mode=mode)
-
-
-class TransposeConvUpsampling(AbstractUpsampling):
-    """
-    Args:
-        in_channels (int): number of input channels for transposed conv
-            used only if transposed_conv is True
-        out_channels (int): number of output channels for transpose conv
-            used only if transposed_conv is True
-        kernel_size (int or tuple): size of the convolving kernel
-            used only if transposed_conv is True
-        scale_factor (int or tuple): stride of the convolution
-            used only if transposed_conv is True
-        is3d (bool): if True use ConvTranspose3d, otherwise use ConvTranspose2d
-    """
-
-    class Upsample(nn.Module):
-        """
-        Works around the 'ValueError: requested an output size...' raised by the `_output_padding` method of
-        transposed convolution. It performs the transposed conv followed by interpolation to the correct size if necessary.
-        """
-
-        def __init__(self, conv_transposed, is3d):
-            super().__init__()
-            self.conv_transposed = conv_transposed
-            self.is3d = is3d
-
-        def forward(self, x, size):
-            x = self.conv_transposed(x)
-            return F.interpolate(x, size=size)
-
-    def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=2, is3d=True):
-        # make sure that the output size reverses the MaxPool3d from the corresponding encoder
-        if is3d is True:
-            conv_transposed = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size,
-                                                 stride=scale_factor, padding=1, bias=False)
-        else:
-            conv_transposed = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
-                                                 stride=scale_factor, padding=1, bias=False)
-        upsample = self.Upsample(conv_transposed, is3d)
-        super().__init__(upsample)
-
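Why the interpolation workaround above is needed can be seen with a minimal sketch (hypothetical shapes): a stride-2 transposed convolution does not exactly invert MaxPool3d for odd input sizes.

import torch
import torch.nn as nn
import torch.nn.functional as F

# A depth of 9 pooled by MaxPool3d(2) gives 4; ConvTranspose3d(kernel_size=3,
# stride=2, padding=1) maps 4 back to (4-1)*2 - 2 + 3 = 7, not 9, so the
# trailing F.interpolate snaps the output to the exact encoder size.
deconv = nn.ConvTranspose3d(8, 4, kernel_size=3, stride=2, padding=1, bias=False)
y = deconv(torch.randn(1, 8, 4, 4, 4))
print(y.shape)                                 # torch.Size([1, 4, 7, 7, 7])
print(F.interpolate(y, size=(9, 9, 9)).shape)  # torch.Size([1, 4, 9, 9, 9])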
-
-class NoUpsampling(AbstractUpsampling):
-    def __init__(self):
-        super().__init__(self._no_upsampling)
-
-    @staticmethod
-    def _no_upsampling(x, size):
-        return x
diff --git a/build/lib/pytorch3dunet/unet3d/config.py b/build/lib/pytorch3dunet/unet3d/config.py
deleted file mode 100644
index bb011632..00000000
--- a/build/lib/pytorch3dunet/unet3d/config.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import argparse
-import os
-import shutil
-
-import torch
-import yaml
-
-from pytorch3dunet.unet3d import utils
-
-logger = utils.get_logger('ConfigLoader')
-
-
-def _override_config(args, config):
-    """Overrides config params with the ones given in command line."""
-
-    args_dict = vars(args)
-    # remove the first argument which is the config file path
-    args_dict.pop('config')
-
-    for key, value in args_dict.items():
-        if value is None:
-            continue
-        c = config
-        for k in key.split('.'):
-            if k not in c:
-                raise ValueError(f'Invalid config key: {key}')
-            if isinstance(c[k], dict):
-                c = c[k]
-            else:
-                c[k] = value
-
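A minimal sketch (hypothetical config values) of how the dotted CLI keys handled above walk the nested config dict:

# '--loaders.output_dir /tmp/new' is split on '.' and walked down the config
# until a non-dict leaf is reached, which is then overwritten.
config = {'loaders': {'output_dir': '/tmp/old', 'test': {'file_paths': []}}}
key, value = 'loaders.output_dir', '/tmp/new'
c = config
for k in key.split('.'):
    if isinstance(c[k], dict):
        c = c[k]
    else:
        c[k] = value
print(config['loaders']['output_dir'])  # /tmp/new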
-
-def load_config():
-    parser = argparse.ArgumentParser(description='UNet3D')
-    parser.add_argument('--config', type=str, help='Path to the YAML config file', required=True)
-    # add additional command line arguments for the prediction that override the ones in the config file
-    parser.add_argument('--model_path', type=str, required=False)
-    parser.add_argument('--loaders.output_dir', type=str, required=False)
-    parser.add_argument('--loaders.test.file_paths', type=str, nargs="+", required=False)
-    parser.add_argument('--loaders.test.slice_builder.patch_shape', type=int, nargs="+", required=False)
-    parser.add_argument('--loaders.test.slice_builder.stride_shape', type=int, nargs="+", required=False)
-
-    args = parser.parse_args()
-    config_path = args.config
-    with open(config_path, 'r') as f:
-        config = yaml.safe_load(f)
-    _override_config(args, config)
-
-    device = config.get('device', None)
-    if device == 'cpu':
-        logger.warning('CPU mode forced in config, this will likely result in slow training/prediction')
-        config['device'] = 'cpu'
-        return config, config_path
-
-    if torch.cuda.is_available():
-        config['device'] = 'cuda'
-    else:
-        logger.warning('CUDA not available, using CPU')
-        config['device'] = 'cpu'
-    return config, config_path
-
-
-def copy_config(config, config_path):
-    """Copies the config file to the checkpoint folder."""
-
-    def _get_last_subfolder_path(path):
-        subfolders = [f.path for f in os.scandir(path) if f.is_dir()]
-        return max(subfolders, default=None)
-
-    checkpoint_dir = os.path.join(
-        config['trainer'].pop('checkpoint_dir'), 'logs')
-    last_run_dir = _get_last_subfolder_path(checkpoint_dir)
-    config_file_name = os.path.basename(config_path)
-
-    if last_run_dir:
-        shutil.copy2(config_path, os.path.join(last_run_dir, config_file_name))
-
-
-def _load_config_yaml(config_file):
-    with open(config_file, 'r') as f:
-        return yaml.safe_load(f)
diff --git a/build/lib/pytorch3dunet/unet3d/losses.py b/build/lib/pytorch3dunet/unet3d/losses.py
deleted file mode 100644
index 6a53966f..00000000
--- a/build/lib/pytorch3dunet/unet3d/losses.py
+++ /dev/null
@@ -1,345 +0,0 @@
-import torch
-import torch.nn.functional as F
-from torch import nn as nn
-from torch.nn import MSELoss, SmoothL1Loss, L1Loss
-
-
-def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
-    """
-    Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi-channel input and target.
-    Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function.
-
-    Args:
-         input (torch.Tensor): NxCxSpatial input tensor
-         target (torch.Tensor): NxCxSpatial target tensor
-         epsilon (float): prevents division by zero
-         weight (torch.Tensor): Cx1 tensor of weight per channel/class
-    """
-
-    # input and target shapes must match
-    assert input.size() == target.size(), "'input' and 'target' must have the same shape"
-
-    input = flatten(input)
-    target = flatten(target)
-    target = target.float()
-
-    # compute per channel Dice Coefficient
-    intersect = (input * target).sum(-1)
-    if weight is not None:
-        intersect = weight * intersect
-
-    # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1)
-    denominator = (input * input).sum(-1) + (target * target).sum(-1)
-    return 2 * (intersect / denominator.clamp(min=epsilon))
-
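A toy worked example (illustrative tensors only) of the V-Net-style Dice above, 2*|A∩B| / (|A|^2 + |B|^2), skipping the per-channel flattening for brevity:

import torch

x = torch.tensor([0.9, 0.8, 0.1, 0.0])  # predicted probabilities
t = torch.tensor([1.0, 1.0, 0.0, 0.0])  # binary target
intersect = (x * t).sum(-1)                      # 1.7
denominator = (x * x).sum(-1) + (t * t).sum(-1)  # 1.46 + 2.0 = 3.46
print(2 * intersect / denominator.clamp(min=1e-6))  # tensor(0.9827), close to 1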
-
-class _MaskingLossWrapper(nn.Module):
-    """
-    Loss wrapper which prevents the loss gradient from being computed where the target is equal to `ignore_index`.
-    """
-
-    def __init__(self, loss, ignore_index):
-        super(_MaskingLossWrapper, self).__init__()
-        assert ignore_index is not None, 'ignore_index cannot be None'
-        self.loss = loss
-        self.ignore_index = ignore_index
-
-    def forward(self, input, target):
-        mask = target.clone().ne_(self.ignore_index)
-        mask.requires_grad = False
-
-        # mask out input/target so that the gradient is zero where the mask is zero
-        input = input * mask
-        target = target * mask
-
-        # forward masked input and target to the loss
-        return self.loss(input, target)
-
-
-class SkipLastTargetChannelWrapper(nn.Module):
-    """
-    Loss wrapper which removes additional target channel
-    """
-
-    def __init__(self, loss, squeeze_channel=False):
-        super(SkipLastTargetChannelWrapper, self).__init__()
-        self.loss = loss
-        self.squeeze_channel = squeeze_channel
-
-    def forward(self, input, target, weight=None):
-        assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel'
-
-        # skips last target channel if needed
-        target = target[:, :-1, ...]
-
-        if self.squeeze_channel:
-            # squeeze channel dimension
-            target = torch.squeeze(target, dim=1)
-        if weight is not None:
-            return self.loss(input, target, weight)
-        return self.loss(input, target)
-
-
-class _AbstractDiceLoss(nn.Module):
-    """
-    Base class for different implementations of Dice loss.
-    """
-
-    def __init__(self, weight=None, normalization='sigmoid'):
-        super(_AbstractDiceLoss, self).__init__()
-        self.register_buffer('weight', weight)
-        # The output from the network during training is assumed to be raw logits, which we normalize into
-        # probabilities here. Since Dice (or soft Dice in this case) is usually used for binary data,
-        # normalizing the channels with Sigmoid is the default choice even for multi-class segmentation problems.
-        # However, if one would like to apply Softmax in order to get a proper probability distribution over the
-        # output channels, just specify `normalization='softmax'`
-        assert normalization in ['sigmoid', 'softmax', 'none']
-        if normalization == 'sigmoid':
-            self.normalization = nn.Sigmoid()
-        elif normalization == 'softmax':
-            self.normalization = nn.Softmax(dim=1)
-        else:
-            self.normalization = lambda x: x
-
-    def dice(self, input, target, weight):
-        # actual Dice score computation; to be implemented by the subclass
-        raise NotImplementedError
-
-    def forward(self, input, target):
-        # get probabilities from logits
-        input = self.normalization(input)
-
-        # compute per channel Dice coefficient
-        per_channel_dice = self.dice(input, target, weight=self.weight)
-
-        # average Dice score across all channels/classes
-        return 1. - torch.mean(per_channel_dice)
-
-
-class DiceLoss(_AbstractDiceLoss):
-    """Computes Dice Loss according to https://arxiv.org/abs/1606.04797.
-    For multi-class segmentation `weight` parameter can be used to assign different weights per class.
-    The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function.
-    """
-
-    def __init__(self, weight=None, normalization='sigmoid'):
-        super().__init__(weight, normalization)
-
-    def dice(self, input, target, weight):
-        return compute_per_channel_dice(input, target, weight=self.weight)
-
-
-class GeneralizedDiceLoss(_AbstractDiceLoss):
-    """Computes Generalized Dice Loss (GDL) as described in https://arxiv.org/pdf/1707.03237.pdf.
-    """
-
-    def __init__(self, normalization='sigmoid', epsilon=1e-6):
-        super().__init__(weight=None, normalization=normalization)
-        self.epsilon = epsilon
-
-    def dice(self, input, target, weight):
-        assert input.size() == target.size(), "'input' and 'target' must have the same shape"
-
-        input = flatten(input)
-        target = flatten(target)
-        target = target.float()
-
-        if input.size(0) == 1:
-            # for GDL to make sense we need at least 2 channels (see https://arxiv.org/pdf/1707.03237.pdf)
-            # put foreground and background voxels in separate channels
-            input = torch.cat((input, 1 - input), dim=0)
-            target = torch.cat((target, 1 - target), dim=0)
-
-        # GDL weighting: the contribution of each label is corrected by the inverse of its volume
-        w_l = target.sum(-1)
-        w_l = 1 / (w_l * w_l).clamp(min=self.epsilon)
-        w_l.requires_grad = False
-
-        intersect = (input * target).sum(-1)
-        intersect = intersect * w_l
-
-        denominator = (input + target).sum(-1)
-        denominator = (denominator * w_l).clamp(min=self.epsilon)
-
-        return 2 * (intersect.sum() / denominator.sum())
-
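To see the effect of the inverse-volume weighting above, a toy sketch (illustrative numbers): a rare class receives a much larger weight than a dominant one.

import torch

# w_l = 1 / volume_l^2: channel 0 has 90 foreground voxels, channel 1 has 10.
target = torch.tensor([[1.0] * 90 + [0.0] * 10,
                       [0.0] * 90 + [1.0] * 10])
w_l = 1 / (target.sum(-1) ** 2).clamp(min=1e-6)
print(w_l)  # tensor([1.2346e-04, 1.0000e-02]) -> rare class weighted 81x more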
-
-class BCEDiceLoss(nn.Module):
-    """Linear combination of BCE and Dice losses"""
-
-    def __init__(self, alpha, beta):
-        super(BCEDiceLoss, self).__init__()
-        self.alpha = alpha
-        self.bce = nn.BCEWithLogitsLoss()
-        self.beta = beta
-        self.dice = DiceLoss()
-
-    def forward(self, input, target):
-        return self.alpha * self.bce(input, target) + self.beta * self.dice(input, target)
-
-
-class WeightedCrossEntropyLoss(nn.Module):
-    """WeightedCrossEntropyLoss (WCE) as described in https://arxiv.org/pdf/1707.03237.pdf
-    """
-
-    def __init__(self, ignore_index=-1):
-        super(WeightedCrossEntropyLoss, self).__init__()
-        self.ignore_index = ignore_index
-
-    def forward(self, input, target):
-        weight = self._class_weights(input)
-        return F.cross_entropy(input, target, weight=weight, ignore_index=self.ignore_index)
-
-    @staticmethod
-    def _class_weights(input):
-        # normalize the input first
-        input = F.softmax(input, dim=1)
-        flattened = flatten(input)
-        numerator = (1. - flattened).sum(-1)
-        denominator = flattened.sum(-1)
-        class_weights = numerator / denominator
-        return class_weights.detach()
-
-
-class PixelWiseCrossEntropyLoss(nn.Module):
-    def __init__(self, ignore_index=None):
-        super(PixelWiseCrossEntropyLoss, self).__init__()
-        self.ignore_index = ignore_index
-        self.log_softmax = nn.LogSoftmax(dim=1)
-
-    def forward(self, input, target, weights):
-        assert target.size() == weights.size()
-        # normalize the input
-        log_probabilities = self.log_softmax(input)
-        # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW)
-        if self.ignore_index is not None:
-            mask = target == self.ignore_index
-            target[mask] = 0
-        else:
-            mask = torch.zeros_like(target)
-        # add channel dimension and invert the mask
-        mask = 1 - mask.unsqueeze(1)
-        # convert target to one-hot encoding
-        target = F.one_hot(target.long())
-        if target.ndim == 5:
-            # permute target to (NxCxDxHxW)
-            target = target.permute(0, 4, 1, 2, 3).contiguous()
-        else:
-            target = target.permute(0, 3, 1, 2).contiguous()
-        # apply the mask on the target
-        target = target * mask
-        # add channel dimension to the weights
-        weights = weights.unsqueeze(1)
-        # compute the losses
-        result = -weights * target * log_probabilities
-        return result.mean()
-
-
-class WeightedSmoothL1Loss(nn.SmoothL1Loss):
-    def __init__(self, threshold, initial_weight, apply_below_threshold=True):
-        super().__init__(reduction="none")
-        self.threshold = threshold
-        self.apply_below_threshold = apply_below_threshold
-        self.weight = initial_weight
-
-    def forward(self, input, target):
-        l1 = super().forward(input, target)
-
-        if self.apply_below_threshold:
-            mask = target < self.threshold
-        else:
-            mask = target >= self.threshold
-
-        l1[mask] = l1[mask] * self.weight
-
-        return l1.mean()
-
-
-def flatten(tensor):
-    """Flattens a given tensor such that the channel axis is first.
-    The shapes are transformed as follows:
-       (N, C, D, H, W) -> (C, N * D * H * W)
-    """
-    # number of channels
-    C = tensor.size(1)
-    # new axis order
-    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
-    # Transpose: (N, C, D, H, W) -> (C, N, D, H, W)
-    transposed = tensor.permute(axis_order)
-    # Flatten: (C, N, D, H, W) -> (C, N * D * H * W)
-    return transposed.contiguous().view(C, -1)
-
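A quick shape check (illustrative dimensions) for flatten():

import torch

t = torch.zeros(2, 3, 4, 5, 6)  # (N=2, C=3, D=4, H=5, W=6)
C = t.size(1)
flat = t.permute(1, 0, 2, 3, 4).contiguous().view(C, -1)
print(flat.shape)  # torch.Size([3, 240]) since 240 = 2 * 4 * 5 * 6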
-
-def get_loss_criterion(config):
-    """
-    Returns the loss function based on provided configuration
-    :param config: (dict) a top level configuration object containing the 'loss' key
-    :return: an instance of the loss function
-    """
-    assert 'loss' in config, 'Could not find loss function configuration'
-    loss_config = config['loss']
-    name = loss_config.pop('name')
-
-    ignore_index = loss_config.pop('ignore_index', None)
-    skip_last_target = loss_config.pop('skip_last_target', False)
-    weight = loss_config.pop('weight', None)
-
-    if weight is not None:
-        weight = torch.tensor(weight)
-
-    pos_weight = loss_config.pop('pos_weight', None)
-    if pos_weight is not None:
-        pos_weight = torch.tensor(pos_weight)
-
-    loss = _create_loss(name, loss_config, weight, ignore_index, pos_weight)
-
-    if not (ignore_index is None or name in ['CrossEntropyLoss', 'WeightedCrossEntropyLoss']):
-        # use MaskingLossWrapper only for non-cross-entropy losses, since CE losses allow specifying 'ignore_index' directly
-        loss = _MaskingLossWrapper(loss, ignore_index)
-
-    if skip_last_target:
-        loss = SkipLastTargetChannelWrapper(loss, loss_config.get('squeeze_channel', False))
-
-    if torch.cuda.is_available():
-        loss = loss.cuda()
-
-    return loss
-
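For reference, a minimal hypothetical loss configuration accepted by the factory above (keys other than 'name' are consumed by the chosen loss):

# Sketch only; the exact keys depend on the loss selected.
config = {'loss': {'name': 'BCEDiceLoss', 'alpha': 1.0, 'beta': 1.0}}
# criterion = get_loss_criterion(config)  # -> BCEDiceLoss(alpha=1.0, beta=1.0)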
-
-#######################################################################################################################
-
-def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
-    if name == 'BCEWithLogitsLoss':
-        return nn.BCEWithLogitsLoss(pos_weight=pos_weight)
-    elif name == 'BCEDiceLoss':
-        alpha = loss_config.get('alpha', 1.)
-        beta = loss_config.get('beta', 1.)
-        return BCEDiceLoss(alpha, beta)
-    elif name == 'CrossEntropyLoss':
-        if ignore_index is None:
-            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
-        return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)
-    elif name == 'WeightedCrossEntropyLoss':
-        if ignore_index is None:
-            ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
-        return WeightedCrossEntropyLoss(ignore_index=ignore_index)
-    elif name == 'PixelWiseCrossEntropyLoss':
-        return PixelWiseCrossEntropyLoss(ignore_index=ignore_index)
-    elif name == 'GeneralizedDiceLoss':
-        normalization = loss_config.get('normalization', 'sigmoid')
-        return GeneralizedDiceLoss(normalization=normalization)
-    elif name == 'DiceLoss':
-        normalization = loss_config.get('normalization', 'sigmoid')
-        return DiceLoss(weight=weight, normalization=normalization)
-    elif name == 'MSELoss':
-        return MSELoss()
-    elif name == 'SmoothL1Loss':
-        return SmoothL1Loss()
-    elif name == 'L1Loss':
-        return L1Loss()
-    elif name == 'WeightedSmoothL1Loss':
-        return WeightedSmoothL1Loss(threshold=loss_config['threshold'],
-                                    initial_weight=loss_config['initial_weight'],
-                                    apply_below_threshold=loss_config.get('apply_below_threshold', True))
-    else:
-        raise RuntimeError(f"Unsupported loss function: '{name}'")
diff --git a/build/lib/pytorch3dunet/unet3d/metrics.py b/build/lib/pytorch3dunet/unet3d/metrics.py
deleted file mode 100644
index 2b60b4b7..00000000
--- a/build/lib/pytorch3dunet/unet3d/metrics.py
+++ /dev/null
@@ -1,445 +0,0 @@
-import importlib
-
-import numpy as np
-import torch
-from skimage import measure
-from skimage.metrics import adapted_rand_error, peak_signal_noise_ratio, mean_squared_error
-
-from pytorch3dunet.unet3d.losses import compute_per_channel_dice
-from pytorch3dunet.unet3d.seg_metrics import AveragePrecision, Accuracy
-from pytorch3dunet.unet3d.utils import get_logger, expand_as_one_hot, convert_to_numpy
-
-logger = get_logger('EvalMetric')
-
-
-class DiceCoefficient:
-    """Computes Dice Coefficient.
-    Generalized to multiple channels by computing per-channel Dice Score
-    (as described in https://arxiv.org/pdf/1707.03237.pdf) and then simply taking the average.
-    Input is expected to be probabilities instead of logits.
-    This metric is mostly useful when channels contain the same semantic class (e.g. affinities computed with different offsets).
-    DO NOT USE this metric when training with DiceLoss, otherwise the results will be biased towards the loss.
-    """
-
-    def __init__(self, epsilon=1e-6, **kwargs):
-        self.epsilon = epsilon
-
-    def __call__(self, input, target):
-        # Average across channels in order to get the final score
-        return torch.mean(compute_per_channel_dice(input, target, epsilon=self.epsilon))
-
-
-class MeanIoU:
-    """
-    Computes IoU for each class separately and then averages over all classes.
-    """
-
-    def __init__(self, skip_channels=(), ignore_index=None, **kwargs):
-        """
-        :param skip_channels: list/tuple of channels to be ignored from the IoU computation
-        :param ignore_index: id of the label to be ignored from IoU computation
-        """
-        self.ignore_index = ignore_index
-        self.skip_channels = skip_channels
-
-    def __call__(self, input, target):
-        """
-        :param input: 5D probability maps torch float tensor (NxCxDxHxW)
-        :param target: 4D or 5D ground truth torch tensor. 4D (NxDxHxW) tensor will be expanded to 5D as one-hot
-        :return: intersection over union averaged over all channels
-        """
-        assert input.dim() == 5
-
-        n_classes = input.size()[1]
-
-        if target.dim() == 4:
-            target = expand_as_one_hot(target, C=n_classes, ignore_index=self.ignore_index)
-
-        assert input.size() == target.size()
-
-        per_batch_iou = []
-        for _input, _target in zip(input, target):
-            binary_prediction = self._binarize_predictions(_input, n_classes)
-
-            if self.ignore_index is not None:
-                # zero out ignore_index
-                mask = _target == self.ignore_index
-                binary_prediction[mask] = 0
-                _target[mask] = 0
-
-            # convert to uint8 just in case
-            binary_prediction = binary_prediction.byte()
-            _target = _target.byte()
-
-            per_channel_iou = []
-            for c in range(n_classes):
-                if c in self.skip_channels:
-                    continue
-
-                per_channel_iou.append(self._jaccard_index(binary_prediction[c], _target[c]))
-
-            assert per_channel_iou, "All channels were ignored from the computation"
-            mean_iou = torch.mean(torch.tensor(per_channel_iou))
-            per_batch_iou.append(mean_iou)
-
-        return torch.mean(torch.tensor(per_batch_iou))
-
-    def _binarize_predictions(self, input, n_classes):
-        """
-        Puts 1 for the class/channel with the highest probability and 0 in other channels. Returns a binary tensor of the
-        same size as the input tensor.
-        """
-        if n_classes == 1:
-            # for single channel input just threshold the probability map
-            result = input > 0.5
-            return result.long()
-
-        _, max_index = torch.max(input, dim=0, keepdim=True)
-        return torch.zeros_like(input, dtype=torch.uint8).scatter_(0, max_index, 1)
-
-    def _jaccard_index(self, prediction, target):
-        """
-        Computes IoU for a given target and prediction tensors
-        """
-        return torch.sum(prediction & target).float() / torch.clamp(torch.sum(prediction | target).float(), min=1e-8)
-
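A toy Jaccard/IoU computation (illustrative masks) matching `_jaccard_index` above:

import torch

pred = torch.tensor([1, 1, 1, 0], dtype=torch.uint8)
target = torch.tensor([0, 1, 1, 1], dtype=torch.uint8)
# intersection has 2 voxels, union has 4 -> IoU = 0.5
print(torch.sum(pred & target).float() / torch.sum(pred | target).float().clamp(min=1e-8))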
-
-class AdaptedRandError:
-    """
-    A functor which computes an Adapted Rand error as defined by the SNEMI3D contest
-    (http://brainiac2.mit.edu/SNEMI3D/evaluation).
-
-    This is a generic implementation which takes the input, converts it to the segmentation image (see `input_to_segm()`)
-    and then computes the ARand between the segmentation and the ground truth target. Depending on one's use case
-    it's enough to extend this class and implement the `input_to_segm` method.
-
-    Args:
-        use_last_target (bool): if true, use the last channel from the target to compute the ARand, otherwise the first.
-    """
-
-    def __init__(self, use_last_target=False, ignore_index=None, **kwargs):
-        self.use_last_target = use_last_target
-        self.ignore_index = ignore_index
-
-    def __call__(self, input, target):
-        """
-        Compute ARand Error for each input, target pair in the batch and return the mean value.
-
-        Args:
-            input (torch.tensor):  5D (NCDHW) output from the network
-            target (torch.tensor): 5D (NCDHW) ground truth segmentation
-
-        Returns:
-            average ARand Error across the batch
-        """
-
-        # converts input and target to numpy arrays
-        input, target = convert_to_numpy(input, target)
-        if self.use_last_target:
-            target = target[:, -1, ...]  # 4D
-        else:
-            # use 1st target channel
-            target = target[:, 0, ...]  # 4D
-
-        # ensure target is of integer type
-        target = target.astype(np.int32)
-
-        if self.ignore_index is not None:
-            target[target == self.ignore_index] = 0
-
-        per_batch_arand = []
-        for _input, _target in zip(input, target):
-            if np.all(_target == _target.flat[0]):  # skip ARand eval if there is only one label in the patch due to zero-division
-                logger.info('Skipping ARandError computation: only 1 label present in the ground truth')
-                per_batch_arand.append(0.)
-                continue
-
-            # convert _input to segmentation CDHW
-            segm = self.input_to_segm(_input)
-            assert segm.ndim == 4
-
-            # compute per channel arand and return the minimum value
-            per_channel_arand = [adapted_rand_error(_target, channel_segm)[0] for channel_segm in segm]
-            per_batch_arand.append(np.min(per_channel_arand))
-
-        # return mean arand error
-        mean_arand = torch.mean(torch.tensor(per_batch_arand))
-        logger.info(f'ARand: {mean_arand.item()}')
-        return mean_arand
-
-    def input_to_segm(self, input):
-        """
-        Converts input tensor (output from the network) to the segmentation image. E.g. if the input is the boundary
-        pmaps then one option would be to threshold it and run connected components in order to return the segmentation.
-
-        :param input: 4D tensor (CDHW)
-        :return: 4D segmentation volume (one segmentation per channel)
-        """
-        # by default assume that the input is a segmentation volume itself
-        return input
-
-
-class BoundaryAdaptedRandError(AdaptedRandError):
-    """
-    Compute ARand between the input boundary map and target segmentation.
-    Boundary map is thresholded, and connected components is run to get the predicted segmentation
-    """
-
-    def __init__(self, thresholds=None, use_last_target=True, ignore_index=None, input_channel=None, invert_pmaps=True,
-                 save_plots=False, plots_dir='.', **kwargs):
-        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, save_plots=save_plots,
-                         plots_dir=plots_dir, **kwargs)
-
-        if thresholds is None:
-            thresholds = [0.3, 0.4, 0.5, 0.6]
-        assert isinstance(thresholds, list)
-        self.thresholds = thresholds
-        self.input_channel = input_channel
-        self.invert_pmaps = invert_pmaps
-
-    def input_to_segm(self, input):
-        if self.input_channel is not None:
-            input = np.expand_dims(input[self.input_channel], axis=0)
-
-        segs = []
-        for predictions in input:
-            for th in self.thresholds:
-                # threshold probability maps
-                predictions = predictions > th
-
-                if self.invert_pmaps:
-                    # for connected component analysis we need to treat boundary signal as background
-                    # assign 0-label to boundary mask
-                    predictions = np.logical_not(predictions)
-
-                predictions = predictions.astype(np.uint8)
-                # run connected components on the predicted mask; consider only 1-connectivity
-                seg = measure.label(predictions, background=0, connectivity=1)
-                segs.append(seg)
-
-        return np.stack(segs)
-
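A minimal sketch (toy boundary map) of the threshold-invert-label pipeline used by `input_to_segm` above:

import numpy as np
from skimage import measure

pmap = np.zeros((1, 5, 5), dtype=np.float32)
pmap[:, :, 2] = 1.0                        # a vertical boundary wall
mask = np.logical_not(pmap > 0.5)          # invert: boundary becomes background
seg = measure.label(mask.astype(np.uint8), background=0, connectivity=1)
print(np.unique(seg))                      # [0 1 2]: wall + two separated regions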
-
-class GenericAdaptedRandError(AdaptedRandError):
-    def __init__(self, input_channels, thresholds=None, use_last_target=True, ignore_index=None, invert_channels=None,
-                 **kwargs):
-
-        super().__init__(use_last_target=use_last_target, ignore_index=ignore_index, **kwargs)
-        assert isinstance(input_channels, list) or isinstance(input_channels, tuple)
-        self.input_channels = input_channels
-        if thresholds is None:
-            thresholds = [0.3, 0.4, 0.5, 0.6]
-        assert isinstance(thresholds, list)
-        self.thresholds = thresholds
-        if invert_channels is None:
-            invert_channels = []
-        self.invert_channels = invert_channels
-
-    def input_to_segm(self, input):
-        # pick only the channels specified in the input_channels
-        results = []
-        for i in self.input_channels:
-            c = input[i]
-            # invert channel if necessary
-            if i in self.invert_channels:
-                c = 1 - c
-            results.append(c)
-
-        input = np.stack(results)
-
-        segs = []
-        for predictions in input:
-            for th in self.thresholds:
-                # run connected components on the predicted mask; consider only 1-connectivity
-                seg = measure.label((predictions > th).astype(np.uint8), background=0, connectivity=1)
-                segs.append(seg)
-
-        return np.stack(segs)
-
-
-class GenericAveragePrecision:
-    def __init__(self, min_instance_size=None, use_last_target=False, metric='ap', **kwargs):
-        self.min_instance_size = min_instance_size
-        self.use_last_target = use_last_target
-        assert metric in ['ap', 'acc']
-        if metric == 'ap':
-            # use AveragePrecision
-            self.metric = AveragePrecision()
-        else:
-            # use Accuracy at 0.5 IoU
-            self.metric = Accuracy(iou_threshold=0.5)
-
-    def __call__(self, input, target):
-        if target.dim() == 5:
-            if self.use_last_target:
-                target = target[:, -1, ...]  # 4D
-            else:
-                # use 1st target channel
-                target = target[:, 0, ...]  # 4D
-
-        input1 = input2 = input
-        multi_head = isinstance(input, tuple)
-        if multi_head:
-            input1, input2 = input
-
-        input1, input2, target = convert_to_numpy(input1, input2, target)
-
-        batch_aps = []
-        # iterate over the batch
-        for i_batch, (inp1, inp2, tar) in enumerate(zip(input1, input2, target)):
-            if multi_head:
-                inp = (inp1, inp2)
-            else:
-                inp = inp1
-
-            segs = self.input_to_seg(inp, tar)  # expects 4D
-            assert segs.ndim == 4
-            # convert target to seg
-            tar = self.target_to_seg(tar)
-
-            # filter small instances if necessary
-            tar = self._filter_instances(tar)
-
-            # compute average precision per channel
-            segs_aps = [self.metric(self._filter_instances(seg), tar) for seg in segs]
-
-            logger.info(f'Batch: {i_batch}. Max Average Precision for channel: {np.argmax(segs_aps)}')
-            # save max AP
-            batch_aps.append(np.max(segs_aps))
-
-        return torch.tensor(batch_aps).mean()
-
-    def _filter_instances(self, input):
-        """
-        Filters instances smaller than 'min_instance_size' by overriding them with 0-index
-        :param input: input instance segmentation
-        """
-        if self.min_instance_size is not None:
-            labels, counts = np.unique(input, return_counts=True)
-            for label, count in zip(labels, counts):
-                if count < self.min_instance_size:
-                    input[input == label] = 0
-        return input
-
-    def input_to_seg(self, input, target=None):
-        raise NotImplementedError
-
-    def target_to_seg(self, target):
-        return target
-
-
-class BlobsAveragePrecision(GenericAveragePrecision):
-    """
-    Computes Average Precision given foreground prediction and ground truth instance segmentation.
-    """
-
-    def __init__(self, thresholds=None, metric='ap', min_instance_size=None, input_channel=0, **kwargs):
-        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
-        if thresholds is None:
-            thresholds = [0.4, 0.5, 0.6, 0.7, 0.8]
-        assert isinstance(thresholds, list)
-        self.thresholds = thresholds
-        self.input_channel = input_channel
-
-    def input_to_seg(self, input, target=None):
-        input = input[self.input_channel]
-        segs = []
-        for th in self.thresholds:
-            # threshold and run connected components
-            mask = (input > th).astype(np.uint8)
-            seg = measure.label(mask, background=0, connectivity=1)
-            segs.append(seg)
-        return np.stack(segs)
-
-
-class BlobsBoundaryAveragePrecision(GenericAveragePrecision):
-    """
-    Computes Average Precision given foreground prediction, boundary prediction and ground truth instance segmentation.
-    Segmentation mask is computed as (P_mask - P_boundary) > th followed by a connected component
-    """
-
-    def __init__(self, thresholds=None, metric='ap', min_instance_size=None, **kwargs):
-        super().__init__(min_instance_size=min_instance_size, use_last_target=True, metric=metric)
-        if thresholds is None:
-            thresholds = [0.3, 0.4, 0.5, 0.6, 0.7]
-        assert isinstance(thresholds, list)
-        self.thresholds = thresholds
-
-    def input_to_seg(self, input, target=None):
-        # input = P_mask - P_boundary
-        input = input[0] - input[1]
-        segs = []
-        for th in self.thresholds:
-            # threshold and run connected components
-            mask = (input > th).astype(np.uint8)
-            seg = measure.label(mask, background=0, connectivity=1)
-            segs.append(seg)
-        return np.stack(segs)
-
-
-class BoundaryAveragePrecision(GenericAveragePrecision):
-    """
-    Computes Average Precision given boundary prediction and ground truth instance segmentation.
-    """
-
-    def __init__(self, thresholds=None, min_instance_size=None, input_channel=0, **kwargs):
-        super().__init__(min_instance_size=min_instance_size, use_last_target=True)
-        if thresholds is None:
-            thresholds = [0.3, 0.4, 0.5, 0.6]
-        assert isinstance(thresholds, list)
-        self.thresholds = thresholds
-        self.input_channel = input_channel
-
-    def input_to_seg(self, input, target=None):
-        input = input[self.input_channel]
-        segs = []
-        for th in self.thresholds:
-            seg = measure.label(np.logical_not(input > th).astype(np.uint8), background=0, connectivity=1)
-            segs.append(seg)
-        return np.stack(segs)
-
-
-class PSNR:
-    """
-    Computes Peak Signal to Noise Ratio. Useful e.g. as an eval metric for denoising tasks.
-    """
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, input, target):
-        input, target = convert_to_numpy(input, target)
-        return peak_signal_noise_ratio(target, input)
-
-
-class MSE:
-    """
-    Computes MSE between input and target
-    """
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, input, target):
-        input, target = convert_to_numpy(input, target)
-        return mean_squared_error(input, target)
-
-
-def get_evaluation_metric(config):
-    """
-    Returns the evaluation metric function based on provided configuration
-    :param config: (dict) a top level configuration object containing the 'eval_metric' key
-    :return: an instance of the evaluation metric
-    """
-
-    def _metric_class(class_name):
-        m = importlib.import_module('pytorch3dunet.unet3d.metrics')
-        clazz = getattr(m, class_name)
-        return clazz
-
-    assert 'eval_metric' in config, 'Could not find evaluation metric configuration'
-    metric_config = config['eval_metric']
-    metric_class = _metric_class(metric_config['name'])
-    return metric_class(**metric_config)
diff --git a/build/lib/pytorch3dunet/unet3d/model.py b/build/lib/pytorch3dunet/unet3d/model.py
deleted file mode 100644
index e4de49a7..00000000
--- a/build/lib/pytorch3dunet/unet3d/model.py
+++ /dev/null
@@ -1,249 +0,0 @@
-import torch.nn as nn
-
-from pytorch3dunet.unet3d.buildingblocks import DoubleConv, ResNetBlock, ResNetBlockSE, \
-    create_decoders, create_encoders
-from pytorch3dunet.unet3d.utils import get_class, number_of_features_per_level
-
-
-class AbstractUNet(nn.Module):
-    """
-    Base class for standard and residual UNet.
-
-    Args:
-        in_channels (int): number of input channels
-        out_channels (int): number of output segmentation masks;
-            Note that out_channels might correspond either to
-            different semantic classes or to different binary segmentation masks.
-            It's up to the user of the class to interpret the out_channels and
-            use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class)
-            or BCEWithLogitsLoss (two-class) respectively)
-        f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number
-            of feature maps is given by the geometric progression: f_maps * 2^k, k=0,1,...,num_levels-1
-        final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the final 1x1 convolution,
-            otherwise apply nn.Softmax. Takes effect only when `self.training == False`, i.e. during validation/testing
-        basic_module: basic model for the encoder/decoder (DoubleConv, ResNetBlock, ....)
-        layer_order (string): determines the order of layers in `SingleConv` module.
-            E.g. 'crg' stands for GroupNorm3d+Conv3d+ReLU. See `SingleConv` for more info
-        num_groups (int): number of groups for the GroupNorm
-        num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int)
-            default: 4
-        is_segmentation (bool): if True and the model is in eval mode, Sigmoid/Softmax normalization is applied
-            after the final convolution; if False (regression problem) the normalization layer is skipped
-        conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module
-        pool_kernel_size (int or tuple): the size of the pooling window
-        conv_padding (int or tuple): zero-padding added to all sides of the input
-        conv_upscale (int): channel upscale factor used inside the encoder's DoubleConv, default: 2
-        upsample (str): algorithm used for decoder upsampling:
-            InterpolateUpsampling:   'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
-            TransposeConvUpsampling: 'deconv'
-            No upsampling:           None
-            Default: 'default' (chooses automatically)
-        dropout_prob (float or tuple): dropout probability, default: 0.1
-        is3d (bool): if True the model is 3D, otherwise 2D, default: True
-    """
-
-    def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=4, is_segmentation=True, conv_kernel_size=3, pool_kernel_size=2,
-                 conv_padding=1, conv_upscale=2, upsample='default', dropout_prob=0.1, is3d=True):
-        super(AbstractUNet, self).__init__()
-
-        if isinstance(f_maps, int):
-            f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)
-
-        assert isinstance(f_maps, list) or isinstance(f_maps, tuple)
-        assert len(f_maps) > 1, "At least 2 levels are required in the U-Net"
-        if 'g' in layer_order:
-            assert num_groups is not None, "num_groups must be specified if GroupNorm is used"
-
-        # create encoder path
-        self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size,
-                                        conv_padding, conv_upscale, dropout_prob,
-                                        layer_order, num_groups, pool_kernel_size, is3d)
-
-        # create decoder path
-        self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding,
-                                        layer_order, num_groups, upsample, dropout_prob,
-                                        is3d)
-
-        # in the last layer a 1×1 convolution reduces the number of output channels to the number of labels
-        if is3d:
-            self.final_conv = nn.Conv3d(f_maps[0], out_channels, 1)
-        else:
-            self.final_conv = nn.Conv2d(f_maps[0], out_channels, 1)
-
-        if is_segmentation:
-            # semantic segmentation problem
-            if final_sigmoid:
-                self.final_activation = nn.Sigmoid()
-            else:
-                self.final_activation = nn.Softmax(dim=1)
-        else:
-            # regression problem
-            self.final_activation = None
-
-    def forward(self, x):
-        # encoder part
-        encoders_features = []
-        for encoder in self.encoders:
-            x = encoder(x)
-            # reverse the encoder outputs to be aligned with the decoder
-            encoders_features.insert(0, x)
-
-        # remove the last encoder's output from the list
-        # !!remember: it's the 1st in the list
-        encoders_features = encoders_features[1:]
-
-        # decoder part
-        for decoder, encoder_features in zip(self.decoders, encoders_features):
-            # pass the output from the corresponding encoder and the output
-            # of the previous decoder
-            x = decoder(encoder_features, x)
-
-        x = self.final_conv(x)
-
-        # apply final_activation (i.e. Sigmoid or Softmax) only during prediction.
-        # During training the network outputs logits
-        if not self.training and self.final_activation is not None:
-            x = self.final_activation(x)
-
-        return x
-
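A hypothetical smoke test of the forward pass above, using the UNet3D subclass defined just below (assumes the package is installed; shapes are illustrative):

import torch
from pytorch3dunet.unet3d.model import UNet3D

# Spatial sizes should be divisible by 2**(num_levels - 1) so that pooling
# and upsampling round-trip cleanly through the skip connections.
model = UNet3D(in_channels=1, out_channels=2, f_maps=16, num_levels=3)
model.eval()
with torch.no_grad():
    y = model(torch.randn(1, 1, 32, 64, 64))  # Sigmoid applied in eval mode
print(y.shape)  # torch.Size([1, 2, 32, 64, 64])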
-
-class UNet3D(AbstractUNet):
-    """
-    3DUnet model from
-    `"3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation"
-        <https://arxiv.org/pdf/1606.06650.pdf>`.
-
-    Uses `DoubleConv` as a basic_module and nearest neighbor upsampling in the decoder
-    """
-
-    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1,
-                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
-        super(UNet3D, self).__init__(in_channels=in_channels,
-                                     out_channels=out_channels,
-                                     final_sigmoid=final_sigmoid,
-                                     basic_module=DoubleConv,
-                                     f_maps=f_maps,
-                                     layer_order=layer_order,
-                                     num_groups=num_groups,
-                                     num_levels=num_levels,
-                                     is_segmentation=is_segmentation,
-                                     conv_padding=conv_padding,
-                                     conv_upscale=conv_upscale,
-                                     upsample=upsample,
-                                     dropout_prob=dropout_prob,
-                                     is3d=True)
-
-
-class ResidualUNet3D(AbstractUNet):
-    """
-    Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
-    Uses ResNetBlock as a basic building block, summation joining instead
-    of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
-    Since the model effectively becomes a residual net, in theory it allows for a deeper UNet.
-    """
-
-    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
-                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
-        super(ResidualUNet3D, self).__init__(in_channels=in_channels,
-                                             out_channels=out_channels,
-                                             final_sigmoid=final_sigmoid,
-                                             basic_module=ResNetBlock,
-                                             f_maps=f_maps,
-                                             layer_order=layer_order,
-                                             num_groups=num_groups,
-                                             num_levels=num_levels,
-                                             is_segmentation=is_segmentation,
-                                             conv_padding=conv_padding,
-                                             conv_upscale=conv_upscale,
-                                             upsample=upsample,
-                                             dropout_prob=dropout_prob,
-                                             is3d=True)
-
-
-class ResidualUNetSE3D(AbstractUNet):
-    """_summary_
-    Residual 3DUnet model implementation with squeeze and excitation based on 
-    https://arxiv.org/pdf/1706.00120.pdf.
-    Uses ResNetBlockSE as a basic building block, summation joining instead
-    of concatenation joining and transposed convolutions for upsampling (watch
-    out for block artifacts). Since the model effectively becomes a residual
-    net, in theory it allows for deeper UNet.
-    """
-
-    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
-                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
-        super(ResidualUNetSE3D, self).__init__(in_channels=in_channels,
-                                               out_channels=out_channels,
-                                               final_sigmoid=final_sigmoid,
-                                               basic_module=ResNetBlockSE,
-                                               f_maps=f_maps,
-                                               layer_order=layer_order,
-                                               num_groups=num_groups,
-                                               num_levels=num_levels,
-                                               is_segmentation=is_segmentation,
-                                               conv_padding=conv_padding,
-                                               conv_upscale=conv_upscale,
-                                               upsample=upsample,
-                                               dropout_prob=dropout_prob,
-                                               is3d=True)
-
-
-class UNet2D(AbstractUNet):
-    """
-    2DUnet model from
-    `"U-Net: Convolutional Networks for Biomedical Image Segmentation" <https://arxiv.org/abs/1505.04597>`
-    """
-
-    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=4, is_segmentation=True, conv_padding=1,
-                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
-        super(UNet2D, self).__init__(in_channels=in_channels,
-                                     out_channels=out_channels,
-                                     final_sigmoid=final_sigmoid,
-                                     basic_module=DoubleConv,
-                                     f_maps=f_maps,
-                                     layer_order=layer_order,
-                                     num_groups=num_groups,
-                                     num_levels=num_levels,
-                                     is_segmentation=is_segmentation,
-                                     conv_padding=conv_padding,
-                                     conv_upscale=conv_upscale,
-                                     upsample=upsample,
-                                     dropout_prob=dropout_prob,
-                                     is3d=False)
-
-
-class ResidualUNet2D(AbstractUNet):
-    """
-    Residual 2DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
-    """
-
-    def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=64, layer_order='gcr',
-                 num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
-                 conv_upscale=2, upsample='default', dropout_prob=0.1, **kwargs):
-        super(ResidualUNet2D, self).__init__(in_channels=in_channels,
-                                             out_channels=out_channels,
-                                             final_sigmoid=final_sigmoid,
-                                             basic_module=ResNetBlock,
-                                             f_maps=f_maps,
-                                             layer_order=layer_order,
-                                             num_groups=num_groups,
-                                             num_levels=num_levels,
-                                             is_segmentation=is_segmentation,
-                                             conv_padding=conv_padding,
-                                             conv_upscale=conv_upscale,
-                                             upsample=upsample,
-                                             dropout_prob=dropout_prob,
-                                             is3d=False)
-
-
-def get_model(model_config):
-    model_class = get_class(model_config['name'], modules=[
-        'pytorch3dunet.unet3d.model'
-    ])
-    return model_class(**model_config)
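A hypothetical model config for the factory above ('name' selects the class; all keys, including 'name', are forwarded as keyword arguments, which the models absorb via **kwargs):

model_config = {
    'name': 'ResidualUNet3D',  # resolved from pytorch3dunet.unet3d.model
    'in_channels': 1,
    'out_channels': 2,
    'f_maps': 32,
    'num_levels': 4,
}
# model = get_model(model_config)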
diff --git a/build/lib/pytorch3dunet/unet3d/predictor.py b/build/lib/pytorch3dunet/unet3d/predictor.py
deleted file mode 100644
index c9b4f6eb..00000000
--- a/build/lib/pytorch3dunet/unet3d/predictor.py
+++ /dev/null
@@ -1,281 +0,0 @@
-import os
-import time
-from concurrent import futures
-from pathlib import Path
-
-import h5py
-import numpy as np
-import torch
-from skimage import measure
-from torch import nn
-from tqdm import tqdm
-
-from pytorch3dunet.datasets.hdf5 import AbstractHDF5Dataset
-from pytorch3dunet.datasets.utils import SliceBuilder, remove_padding
-from pytorch3dunet.unet3d.model import UNet2D
-from pytorch3dunet.unet3d.utils import get_logger
-
-logger = get_logger('UNetPredictor')
-
-
-def _get_output_file(dataset, suffix='_predictions', output_dir=None):
-    input_dir, file_name = os.path.split(dataset.file_path)
-    if output_dir is None:
-        output_dir = input_dir
-    output_filename = os.path.splitext(file_name)[0] + suffix + '.h5'
-    return Path(output_dir) / output_filename
-
-
-def _is_2d_model(model):
-    if isinstance(model, nn.DataParallel):
-        model = model.module
-    return isinstance(model, UNet2D)
-
-
-class _AbstractPredictor:
-    def __init__(self,
-                 model: nn.Module,
-                 output_dir: str,
-                 out_channels: int,
-                 output_dataset: str = 'predictions',
-                 save_segmentation: bool = False,
-                 prediction_channel: int = None,
-                 **kwargs):
-        """
-        Base class for predictors.
-        Args:
-            model: segmentation model
-            output_dir: directory where the predictions will be saved
-            out_channels: number of output channels of the model
-            output_dataset: name of the dataset in the H5 file where the predictions will be saved
-            save_segmentation: if true the segmentation will be saved instead of the probability maps
-            prediction_channel: save only the specified channel from the network output
-        """
-        self.model = model
-        self.output_dir = output_dir
-        self.out_channels = out_channels
-        self.output_dataset = output_dataset
-        self.save_segmentation = save_segmentation
-        self.prediction_channel = prediction_channel
-
-    def __call__(self, test_loader):
-        raise NotImplementedError
-
-
-class StandardPredictor(_AbstractPredictor):
-    """
-    Applies the model on the given dataset and saves the result as an H5 file.
-    Predictions from the network are kept in memory. If the results from the network don't fit into RAM,
-    use `LazyPredictor` instead.
-
-    The output dataset name inside the H5 file is given by the `output_dataset` config argument.
-    """
-
-    def __init__(self,
-                 model: nn.Module,
-                 output_dir: str,
-                 out_channels: int,
-                 output_dataset: str = 'predictions',
-                 save_segmentation: bool = False,
-                 prediction_channel: int = None,
-                 **kwargs):
-        super().__init__(model, output_dir, out_channels, output_dataset, save_segmentation, prediction_channel,
-                         **kwargs)
-
-    def __call__(self, test_loader):
-        assert isinstance(test_loader.dataset, AbstractHDF5Dataset)
-        logger.info(f"Processing '{test_loader.dataset.file_path}'...")
-        start = time.perf_counter()
-
-        logger.info(f'Running inference on {len(test_loader)} batches')
-        # dimensionality of the output predictions
-        volume_shape = test_loader.dataset.volume_shape()
-        if self.prediction_channel is not None:
-            # single channel prediction map
-            prediction_maps_shape = (1,) + volume_shape
-        else:
-            prediction_maps_shape = (self.out_channels,) + volume_shape
-
-        # create destination H5 file
-        output_file = _get_output_file(dataset=test_loader.dataset, output_dir=self.output_dir)
-        with h5py.File(output_file, 'w') as h5_output_file:
-            # allocate prediction and normalization arrays
-            logger.info('Allocating prediction and normalization arrays...')
-            prediction_map, normalization_mask = self._allocate_prediction_maps(prediction_maps_shape, h5_output_file)
-
-            # determine halo used for padding
-            patch_halo = test_loader.dataset.halo_shape
-
-            # Sets the module in evaluation mode explicitly
-            # It is necessary for batchnorm/dropout layers if present as well as final Sigmoid/Softmax to be applied
-            self.model.eval()
-            # Run predictions on the entire input dataset
-            with torch.no_grad():
-                for input, indices in tqdm(test_loader):
-                    # send batch to gpu
-                    if torch.cuda.is_available():
-                        input = input.pin_memory().cuda(non_blocking=True)
-
-                    if _is_2d_model(self.model):
-                        # remove the singleton z-dimension from the input
-                        input = torch.squeeze(input, dim=-3)
-                        # forward pass
-                        prediction = self.model(input)
-                        # add the singleton z-dimension to the output
-                        prediction = torch.unsqueeze(prediction, dim=-3)
-                    else:
-                        # forward pass
-                        prediction = self.model(input)
-
-                    # unpad the predicted patch
-                    prediction = remove_padding(prediction, patch_halo)
-                    # convert to numpy array
-                    prediction = prediction.cpu().numpy()
-                    # for each batch sample
-                    for pred, index in zip(prediction, indices):
-                        # each pred has shape (C, D, H, W); build the index into the output volume
-                        if self.prediction_channel is None:
-                            channel_slice = slice(0, self.out_channels)
-                        else:
-                            # use only the specified channel
-                            channel_slice = slice(0, 1)
-                            pred = np.expand_dims(pred[self.prediction_channel], axis=0)
-
-                        # add channel dimension to the index
-                        index = (channel_slice,) + tuple(index)
-                        # accumulate probabilities into the output prediction array
-                        prediction_map[index] += pred
-                        # count voxel visits for normalization
-                        normalization_mask[index] += 1
-
-            logger.info(f'Finished inference in {time.perf_counter() - start:.2f} seconds')
-            # save results
-            output_type = 'segmentation' if self.save_segmentation else 'probability maps'
-            logger.info(f'Saving {output_type} to: {output_file}')
-            self._save_results(prediction_map, normalization_mask, h5_output_file, test_loader.dataset)
-
-    def _allocate_prediction_maps(self, output_shape, output_file):
-        # initialize the output prediction arrays
-        prediction_map = np.zeros(output_shape, dtype='float32')
-        # initialize normalization mask in order to average out probabilities of overlapping patches
-        normalization_mask = np.zeros(output_shape, dtype='uint8')
-        return prediction_map, normalization_mask
-
-    def _save_results(self, prediction_map, normalization_mask, output_file, dataset):
-        result = prediction_map / normalization_mask
-        if self.save_segmentation:
-            result = np.argmax(result, axis=0).astype('uint16')
-        output_file.create_dataset(self.output_dataset, data=result, compression="gzip")
-
-
-class LazyPredictor(StandardPredictor):
-    """
-    Applies the model on the given dataset and saves the result to `output_file` in H5 format.
-    Predicted patches are written directly to the H5 file instead of being kept in memory. Since this
-    predictor is slower than `StandardPredictor`, use it only when the predicted volume does not fit into RAM.
-    """
-
-    def __init__(self,
-                 model: nn.Module,
-                 output_dir: str,
-                 out_channels: int,
-                 output_dataset: str = 'predictions',
-                 save_segmentation: bool = False,
-                 prediction_channel: int = None,
-                 **kwargs):
-        super().__init__(model, output_dir, out_channels, output_dataset, save_segmentation, prediction_channel,
-                         **kwargs)
-
-    def _allocate_prediction_maps(self, output_shape, output_file):
-        # allocate datasets for probability maps
-        prediction_map = output_file.create_dataset(self.output_dataset,
-                                                    shape=output_shape,
-                                                    dtype='float32',
-                                                    chunks=True,
-                                                    compression='gzip')
-        # allocate datasets for normalization masks
-        normalization_mask = output_file.create_dataset('normalization',
-                                                        shape=output_shape,
-                                                        dtype='uint8',
-                                                        chunks=True,
-                                                        compression='gzip')
-        return prediction_map, normalization_mask
-
-    def _save_results(self, prediction_map, normalization_mask, output_file, dataset):
-        z, y, x = prediction_map.shape[1:]
-        # take slices which are 1/27 of the original volume
-        patch_shape = (z // 3, y // 3, x // 3)
-        if self.save_segmentation:
-            output_file.create_dataset('segmentation', shape=(z, y, x), dtype='uint16', chunks=True, compression='gzip')
-
-        for index in SliceBuilder._build_slices(prediction_map, patch_shape=patch_shape, stride_shape=patch_shape):
-            logger.info(f'Normalizing slice: {index}')
-            prediction_map[index] /= normalization_mask[index]
-            # make sure to reset the slice that has been visited already in order to avoid 'double' normalization
-            # when the patches overlap with each other
-            normalization_mask[index] = 1
-            # save segmentation
-            if self.save_segmentation:
-                output_file['segmentation'][index[1:]] = np.argmax(prediction_map[index], axis=0).astype('uint16')
-
-        del output_file['normalization']
-        if self.save_segmentation:
-            del output_file[self.output_dataset]
-
-
-class DSB2018Predictor(_AbstractPredictor):
-    def __init__(self, model, output_dir, config, save_segmentation=True, pmaps_threshold=0.5, **kwargs):
-        super().__init__(model, output_dir, config, **kwargs)
-        self.pmaps_threshold = pmaps_threshold
-        self.save_segmentation = save_segmentation
-
-    def _slice_from_pad(self, pad):
-        if pad == 0:
-            return slice(None, None)
-        else:
-            return slice(pad, -pad)
-
-    def __call__(self, test_loader):
-        # Sets the module in evaluation mode explicitly
-        self.model.eval()
-        # initialize a process pool for saving results to disk
-        executor = futures.ProcessPoolExecutor(max_workers=32)
-        # Run predictions on the entire input dataset
-        with torch.no_grad():
-            for img, path in test_loader:
-                # send batch to gpu
-                if torch.cuda.is_available():
-                    img = img.cuda(non_blocking=True)
-                # forward pass
-                pred = self.model(img)
-                # move predictions to the CPU and convert to numpy before
-                # handing them to the worker processes
-                pred = pred.cpu().numpy()
-
-                executor.submit(
-                    dsb_save_batch,
-                    self.output_dir,
-                    path,
-                    pred
-                )
-
-        logger.info('Waiting for all predictions to be saved to disk...')
-        executor.shutdown(wait=True)
-
-
-def dsb_save_batch(output_dir, path, pred, save_segmentation=True, pmaps_threshold=0.5):
-    def _pmaps_to_seg(pred):
-        # threshold the probability maps and label connected components
-        mask = (pred > pmaps_threshold)
-        return measure.label(mask).astype('uint16')
-
-    # save each prediction in the batch to its own H5 file
-    for single_pred, single_path in zip(pred, path):
-        logger.info(f'Processing {single_path}')
-        single_pred = single_pred.squeeze()
-
-        # save to h5 file
-        out_file = os.path.splitext(single_path)[0] + '_predictions.h5'
-        if output_dir is not None:
-            out_file = os.path.join(output_dir, os.path.split(out_file)[1])
-
-        with h5py.File(out_file, 'w') as f:
-            logger.info(f'Saving output to {out_file}')
-            f.create_dataset('predictions', data=single_pred, compression='gzip')
-            if save_segmentation:
-                f.create_dataset('segmentation', data=_pmaps_to_seg(single_pred), compression='gzip')
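
For reference, the accumulate-then-normalize scheme used by StandardPredictor above can be reproduced in isolation. A minimal sketch with synthetic data; array names are illustrative, not part of the library API:

    import numpy as np

    # two overlapping patches written into a shared output volume
    out_channels, volume_shape = 2, (8, 16, 16)
    prediction_map = np.zeros((out_channels,) + volume_shape, dtype='float32')
    normalization_mask = np.zeros((out_channels,) + volume_shape, dtype='uint8')

    patch_a = np.random.rand(out_channels, 6, 16, 16).astype('float32')
    patch_b = np.random.rand(out_channels, 6, 16, 16).astype('float32')
    index_a = (slice(0, out_channels), slice(0, 6), slice(0, 16), slice(0, 16))
    index_b = (slice(0, out_channels), slice(2, 8), slice(0, 16), slice(0, 16))

    for pred, index in [(patch_a, index_a), (patch_b, index_b)]:
        prediction_map[index] += pred   # accumulate probabilities
        normalization_mask[index] += 1  # count how often each voxel was visited

    # the two patches cover the whole volume, so the division is safe
    result = prediction_map / normalization_mask
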
diff --git a/build/lib/pytorch3dunet/unet3d/se.py b/build/lib/pytorch3dunet/unet3d/se.py
deleted file mode 100644
index 23fac3d7..00000000
--- a/build/lib/pytorch3dunet/unet3d/se.py
+++ /dev/null
@@ -1,113 +0,0 @@
-"""
-3D Squeeze and Excitation Modules
-*********************************
-3D Extensions of the following 2D squeeze and excitation blocks:
-    1. `Channel Squeeze and Excitation <https://arxiv.org/abs/1709.01507>`_
-    2. `Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_
-    3. `Channel and Spatial Squeeze and Excitation <https://arxiv.org/abs/1803.02579>`_
-New Project & Excite block, designed specifically for 3D inputs
-    Coded by -- Anne-Marie Rickmann (https://github.com/arickm)
-"""
-
-import torch
-from torch import nn as nn
-from torch.nn import functional as F
-
-
-class ChannelSELayer3D(nn.Module):
-    """
-    3D extension of Squeeze-and-Excitation (SE) block described in:
-        *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
-        *Zhu et al., AnatomyNet, arXiv:1808.05238*
-    """
-
-    def __init__(self, num_channels, reduction_ratio=2):
-        """
-        Args:
-            num_channels (int): number of input channels
-            reduction_ratio (int): factor by which num_channels is reduced in the excitation bottleneck
-        """
-        super(ChannelSELayer3D, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool3d(1)
-        num_channels_reduced = num_channels // reduction_ratio
-        self.reduction_ratio = reduction_ratio
-        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
-        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
-        self.relu = nn.ReLU()
-        self.sigmoid = nn.Sigmoid()
-
-    def forward(self, x):
-        batch_size, num_channels, D, H, W = x.size()
-        # Average along each channel
-        squeeze_tensor = self.avg_pool(x)
-
-        # channel excitation
-        fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels)))
-        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))
-
-        output_tensor = torch.mul(x, fc_out_2.view(batch_size, num_channels, 1, 1, 1))
-
-        return output_tensor
-
-
-class SpatialSELayer3D(nn.Module):
-    """
-    3D extension of SE block -- squeezing spatially and exciting channel-wise described in:
-        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, MICCAI 2018*
-    """
-
-    def __init__(self, num_channels):
-        """
-        Args:
-            num_channels (int): number of input channels
-        """
-        super(SpatialSELayer3D, self).__init__()
-        self.conv = nn.Conv3d(num_channels, 1, 1)
-        self.sigmoid = nn.Sigmoid()
-
-    def forward(self, x, weights=None):
-        """
-        Args:
-            x (torch.Tensor): input tensor of shape (batch_size, num_channels, D, H, W)
-            weights (torch.Tensor): optional 1x1x1 convolution weights, e.g. for few-shot learning
-
-        Returns:
-            (torch.Tensor): output_tensor
-        """
-        # channel squeeze
-        batch_size, channel, D, H, W = x.size()
-
-        if weights is not None:
-            # `if weights:` is ambiguous for multi-element tensors; also, the input
-            # is 5D (N, C, D, H, W), so use a 3D convolution with a 5D kernel
-            weights = weights.view(1, channel, 1, 1, 1)
-            out = F.conv3d(x, weights)
-        else:
-            out = self.conv(x)
-
-        squeeze_tensor = self.sigmoid(out)
-
-        # spatial excitation
-        output_tensor = torch.mul(x, squeeze_tensor.view(batch_size, 1, D, H, W))
-
-        return output_tensor
-
-
-class ChannelSpatialSELayer3D(nn.Module):
-    """
-    3D extension of concurrent spatial and channel squeeze & excitation:
-        *Roy et al., Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks, arXiv:1803.02579*
-    """
-
-    def __init__(self, num_channels, reduction_ratio=2):
-        """
-        Args:
-            num_channels (int): number of input channels
-            reduction_ratio (int): factor by which num_channels is reduced in the excitation bottleneck
-        """
-        super(ChannelSpatialSELayer3D, self).__init__()
-        self.cSE = ChannelSELayer3D(num_channels, reduction_ratio)
-        self.sSE = SpatialSELayer3D(num_channels)
-
-    def forward(self, input_tensor):
-        output_tensor = torch.max(self.cSE(input_tensor), self.sSE(input_tensor))
-        return output_tensor
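
As a quick shape-level sanity check (assuming the three classes above are in scope), each SE variant only rescales activations, so a 5D (N, C, D, H, W) batch keeps its shape:

    import torch

    x = torch.randn(2, 16, 8, 32, 32)  # (N, C, D, H, W)
    for block in (ChannelSELayer3D(16), SpatialSELayer3D(16), ChannelSpatialSELayer3D(16)):
        assert block(x).shape == x.shape  # recalibration preserves the shape
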
diff --git a/build/lib/pytorch3dunet/unet3d/seg_metrics.py b/build/lib/pytorch3dunet/unet3d/seg_metrics.py
deleted file mode 100644
index e713ea23..00000000
--- a/build/lib/pytorch3dunet/unet3d/seg_metrics.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import numpy as np
-from skimage.metrics import contingency_table
-
-
-def precision(tp, fp, fn):
-    return tp / (tp + fp) if tp > 0 else 0
-
-
-def recall(tp, fp, fn):
-    return tp / (tp + fn) if tp > 0 else 0
-
-
-def accuracy(tp, fp, fn):
-    return tp / (tp + fp + fn) if tp > 0 else 0
-
-
-def f1(tp, fp, fn):
-    return (2 * tp) / (2 * tp + fp + fn) if tp > 0 else 0
-
-
-def _relabel(input):
-    _, unique_labels = np.unique(input, return_inverse=True)
-    return unique_labels.reshape(input.shape)
-
-
-def _iou_matrix(gt, seg):
-    # relabel gt and seg for smaller memory footprint of contingency table
-    gt = _relabel(gt)
-    seg = _relabel(seg)
-
-    # get number of overlapping pixels between GT and SEG
-    n_inter = contingency_table(gt, seg).A
-
-    # number of pixels for GT instances
-    n_gt = n_inter.sum(axis=1, keepdims=True)
-    # number of pixels for SEG instances
-    n_seg = n_inter.sum(axis=0, keepdims=True)
-
-    # number of pixels in the union between GT and SEG instances
-    n_union = n_gt + n_seg - n_inter
-
-    iou_matrix = n_inter / n_union
-    # make sure that the values are within [0,1] range
-    assert 0 <= np.min(iou_matrix) <= np.max(iou_matrix) <= 1
-
-    return iou_matrix
-
-
-class SegmentationMetrics:
-    """
-    Computes precision, recall, accuracy, f1 score for a given ground truth and predicted segmentation.
-    Contingency table for a given ground truth and predicted segmentation is computed eagerly upon construction
-    of the instance of `SegmentationMetrics`.
-
-    Args:
-        gt (ndarray): ground truth segmentation
-        seg (ndarray): predicted segmentation
-    """
-
-    def __init__(self, gt, seg):
-        self.iou_matrix = _iou_matrix(gt, seg)
-
-    def metrics(self, iou_threshold):
-        """
-        Computes precision, recall, accuracy, f1 score at a given IoU threshold
-        """
-        # ignore background
-        iou_matrix = self.iou_matrix[1:, 1:]
-        detection_matrix = (iou_matrix > iou_threshold).astype(np.uint8)
-        n_gt, n_seg = detection_matrix.shape
-
-        # if the iou_matrix is empty or all values are 0
-        trivial = min(n_gt, n_seg) == 0 or np.all(detection_matrix == 0)
-        if trivial:
-            tp = fp = fn = 0
-        else:
-            # count non-zero rows to get the number of TP
-            tp = np.count_nonzero(detection_matrix.sum(axis=1))
-            # count zero rows to get the number of FN
-            fn = n_gt - tp
-            # count zero columns to get the number of FP
-            fp = n_seg - np.count_nonzero(detection_matrix.sum(axis=0))
-
-        return {
-            'precision': precision(tp, fp, fn),
-            'recall': recall(tp, fp, fn),
-            'accuracy': accuracy(tp, fp, fn),
-            'f1': f1(tp, fp, fn)
-        }
-
-
-class Accuracy:
-    """
-    Computes accuracy between ground truth and predicted segmentation at a given IoU threshold.
-    Defined as: AC = TP / (TP + FP + FN).
-    Kaggle DSB2018 calls it Precision, see:
-    https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric.
-    """
-
-    def __init__(self, iou_threshold):
-        self.iou_threshold = iou_threshold
-
-    def __call__(self, input_seg, gt_seg):
-        metrics = SegmentationMetrics(gt_seg, input_seg).metrics(self.iou_threshold)
-        return metrics['accuracy']
-
-
-class AveragePrecision:
-    """
-    Average precision taken for the IoU range (0.5, 0.95) with a step of 0.05 as defined in:
-    https://www.kaggle.com/stkbailey/step-by-step-explanation-of-scoring-metric
-    """
-
-    def __init__(self):
-        self.iou_range = np.linspace(0.50, 0.95, 10)
-
-    def __call__(self, input_seg, gt_seg):
-        # compute contingency_table
-        sm = SegmentationMetrics(gt_seg, input_seg)
-        # compute accuracy for each threshold
-        acc = [sm.metrics(iou)['accuracy'] for iou in self.iou_range]
-        # return the average
-        return np.mean(acc)
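
A toy example of the instance metrics above (hypothetical labels; assumes SegmentationMetrics and AveragePrecision are in scope). A predicted instance overlapping a ground-truth instance with IoU of roughly 0.62 counts as a true positive at threshold 0.5 but not at stricter thresholds:

    import numpy as np

    gt = np.zeros((32, 32), dtype='uint16')
    gt[4:12, 4:12] = 1            # one ground-truth instance
    seg = np.zeros_like(gt)
    seg[5:13, 5:13] = 1           # one slightly shifted predicted instance

    sm = SegmentationMetrics(gt, seg)
    print(sm.metrics(iou_threshold=0.5))  # {'precision': 1.0, 'recall': 1.0, ...}
    print(AveragePrecision()(seg, gt))    # 0.3: the match survives only IoU 0.50-0.60
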
diff --git a/build/lib/pytorch3dunet/unet3d/trainer.py b/build/lib/pytorch3dunet/unet3d/trainer.py
deleted file mode 100644
index 4b59d568..00000000
--- a/build/lib/pytorch3dunet/unet3d/trainer.py
+++ /dev/null
@@ -1,404 +0,0 @@
-import os
-import torch
-import torch.nn as nn
-from torch.optim.lr_scheduler import ReduceLROnPlateau
-from torch.utils.tensorboard import SummaryWriter
-from datetime import datetime
-
-from pytorch3dunet.datasets.utils import get_train_loaders
-from pytorch3dunet.unet3d.losses import get_loss_criterion
-from pytorch3dunet.unet3d.metrics import get_evaluation_metric
-from pytorch3dunet.unet3d.model import get_model, UNet2D
-from pytorch3dunet.unet3d.utils import get_logger, get_tensorboard_formatter, create_optimizer, \
-    create_lr_scheduler, get_number_of_learnable_parameters
-from . import utils
-
-logger = get_logger('UNetTrainer')
-
-
-def create_trainer(config):
-    # Create the model
-    model = get_model(config['model'])
-
-    if torch.cuda.device_count() > 1 and not config['device'] == 'cpu':
-        model = nn.DataParallel(model)
-        logger.info(f'Using {torch.cuda.device_count()} GPUs for training')
-    if torch.cuda.is_available() and not config['device'] == 'cpu':
-        model = model.cuda()
-
-    # Log the number of learnable parameters
-    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')
-
-    # Create loss criterion
-    loss_criterion = get_loss_criterion(config)
-    # Create evaluation metric
-    eval_criterion = get_evaluation_metric(config)
-
-    # Create data loaders
-    loaders = get_train_loaders(config)
-
-    # Create the optimizer
-    optimizer = create_optimizer(config['optimizer'], model)
-
-    # Create learning rate adjustment strategy
-    lr_scheduler = create_lr_scheduler(config.get('lr_scheduler', None), optimizer)
-
-    trainer_config = config['trainer']
-    # Create tensorboard formatter
-    tensorboard_formatter = get_tensorboard_formatter(trainer_config.pop('tensorboard_formatter', None))
-    # Create trainer
-    resume = trainer_config.pop('resume', None)
-    pre_trained = trainer_config.pop('pre_trained', None)
-
-    return UNetTrainer(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, loss_criterion=loss_criterion,
-                       eval_criterion=eval_criterion, loaders=loaders, tensorboard_formatter=tensorboard_formatter,
-                       resume=resume, pre_trained=pre_trained, **trainer_config)
-
-
-class UNetTrainer:
-    """UNet trainer.
-
-    Args:
-        model (Unet3D): UNet 3D model to be trained
-        optimizer (nn.optim.Optimizer): optimizer used for training
-        lr_scheduler (torch.optim.lr_scheduler._LRScheduler): learning rate scheduler
-            WARN: bear in mind that lr_scheduler.step() is invoked after every validation step
-            (i.e. validate_after_iters) not after every epoch. So e.g. if one uses StepLR with step_size=30
-            the learning rate will be adjusted after every 30 * validate_after_iters iterations.
-        loss_criterion (callable): loss function
-        eval_criterion (callable): used to compute training/validation metric (such as Dice, IoU, AP or Rand score)
-            saving the best checkpoint is based on the result of this function on the validation set
-        loaders (dict): 'train' and 'val' loaders
-        checkpoint_dir (string): dir for saving checkpoints and tensorboard logs
-        max_num_epochs (int): maximum number of epochs
-        max_num_iterations (int): maximum number of iterations
-        validate_after_iters (int): validate after that many iterations
-        log_after_iters (int): number of iterations before logging to tensorboard
-        validate_iters (int): number of validation iterations, if None validate
-            on the whole validation set
-        eval_score_higher_is_better (bool): if True higher eval scores are considered better
-        best_eval_score (float): best validation score so far (higher better)
-        num_iterations (int): useful when loading the model from the checkpoint
-        num_epoch (int): useful when loading the model from the checkpoint
-        tensorboard_formatter (callable): converts a given batch of input/output/target image to a series of images
-            that can be displayed in tensorboard
-        skip_train_validation (bool): if True eval_criterion is not evaluated on the training set (used mostly when
-            evaluation is expensive)
-    """
-
-    def __init__(self, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, checkpoint_dir,
-                 max_num_epochs, max_num_iterations, validate_after_iters=200, log_after_iters=100, validate_iters=None,
-                 num_iterations=1, num_epoch=0, eval_score_higher_is_better=True, tensorboard_formatter=None,
-                 skip_train_validation=False, resume=None, pre_trained=None, **kwargs):
-
-        self.model = model
-        self.optimizer = optimizer
-        self.scheduler = lr_scheduler
-        self.loss_criterion = loss_criterion
-        self.eval_criterion = eval_criterion
-        self.loaders = loaders
-        self.checkpoint_dir = checkpoint_dir
-        self.max_num_epochs = max_num_epochs
-        self.max_num_iterations = max_num_iterations
-        self.validate_after_iters = validate_after_iters
-        self.log_after_iters = log_after_iters
-        self.validate_iters = validate_iters
-        self.eval_score_higher_is_better = eval_score_higher_is_better
-
-        logger.info(model)
-        logger.info(f'eval_score_higher_is_better: {eval_score_higher_is_better}')
-
-        # initialize the best_eval_score
-        if eval_score_higher_is_better:
-            self.best_eval_score = float('-inf')
-        else:
-            self.best_eval_score = float('+inf')
-
-        self.writer = SummaryWriter(
-            log_dir=os.path.join(
-                checkpoint_dir, 'logs',
-                datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-            )
-        )
-
-        assert tensorboard_formatter is not None, 'TensorboardFormatter must be provided'
-        self.tensorboard_formatter = tensorboard_formatter
-
-        self.num_iterations = num_iterations
-        self.num_epochs = num_epoch
-        self.skip_train_validation = skip_train_validation
-
-        if resume is not None:
-            logger.info(f"Loading checkpoint '{resume}'...")
-            state = utils.load_checkpoint(resume, self.model, self.optimizer)
-            logger.info(
-                f"Checkpoint loaded from '{resume}'. Epoch: {state['num_epochs']}.  Iteration: {state['num_iterations']}. "
-                f"Best val score: {state['best_eval_score']}."
-            )
-            self.best_eval_score = state['best_eval_score']
-            self.num_iterations = state['num_iterations']
-            self.num_epochs = state['num_epochs']
-            self.checkpoint_dir = os.path.split(resume)[0]
-        elif pre_trained is not None:
-            logger.info(f"Logging pre-trained model from '{pre_trained}'...")
-            utils.load_checkpoint(pre_trained, self.model, None)
-            if 'checkpoint_dir' not in kwargs:
-                self.checkpoint_dir = os.path.split(pre_trained)[0]
-
-    def fit(self):
-        for _ in range(self.num_epochs, self.max_num_epochs):
-            # train for one epoch
-            should_terminate = self.train()
-
-            if should_terminate:
-                logger.info('Stopping criterion is satisfied. Finishing training')
-                return
-
-            self.num_epochs += 1
-        logger.info(f"Reached maximum number of epochs: {self.max_num_epochs}. Finishing training...")
-
-    def train(self):
-        """Trains the model for 1 epoch.
-
-        Returns:
-            True if the training should be terminated immediately, False otherwise
-        """
-        train_losses = utils.RunningAverage()
-        train_eval_scores = utils.RunningAverage()
-
-        # sets the model in training mode
-        self.model.train()
-
-        for t in self.loaders['train']:
-            logger.info(f'Training iteration [{self.num_iterations}/{self.max_num_iterations}]. '
-                        f'Epoch [{self.num_epochs}/{self.max_num_epochs - 1}]')
-
-            input, target, weight = self._split_training_batch(t)
-
-            output, loss = self._forward_pass(input, target, weight)
-
-            train_losses.update(loss.item(), self._batch_size(input))
-
-            # compute gradients and update parameters
-            self.optimizer.zero_grad()
-            loss.backward()
-            self.optimizer.step()
-
-            if self.num_iterations % self.validate_after_iters == 0:
-                # set the model in eval mode
-                self.model.eval()
-                # evaluate on validation set
-                eval_score = self.validate()
-                # set the model back to training mode
-                self.model.train()
-
-                # adjust learning rate if necessary
-                if isinstance(self.scheduler, ReduceLROnPlateau):
-                    self.scheduler.step(eval_score)
-                elif self.scheduler is not None:
-                    self.scheduler.step()
-
-                # log current learning rate in tensorboard
-                self._log_lr()
-                # remember best validation metric
-                is_best = self._is_best_eval_score(eval_score)
-
-                # save checkpoint
-                self._save_checkpoint(is_best)
-
-            if self.num_iterations % self.log_after_iters == 0:
-                # compute eval criterion
-                if not self.skip_train_validation:
-                    # apply final activation before calculating eval score
-                    if isinstance(self.model, nn.DataParallel):
-                        final_activation = self.model.module.final_activation
-                    else:
-                        final_activation = self.model.final_activation
-
-                    if final_activation is not None:
-                        act_output = final_activation(output)
-                    else:
-                        act_output = output
-                    eval_score = self.eval_criterion(act_output, target)
-                    train_eval_scores.update(eval_score.item(), self._batch_size(input))
-
-                # log stats, params and images
-                logger.info(
-                    f'Training stats. Loss: {train_losses.avg}. Evaluation score: {train_eval_scores.avg}')
-                self._log_stats('train', train_losses.avg, train_eval_scores.avg)
-                # self._log_params()
-                self._log_images(input, target, output, 'train_')
-
-            if self.should_stop():
-                return True
-
-            self.num_iterations += 1
-
-        return False
-
-    def should_stop(self):
-        """
-        Training will terminate if maximum number of iterations is exceeded or the learning rate drops below
-        some predefined threshold (1e-6 in our case)
-        """
-        if self.max_num_iterations < self.num_iterations:
-            logger.info(f'Maximum number of iterations {self.max_num_iterations} exceeded.')
-            return True
-
-        min_lr = 1e-6
-        lr = self.optimizer.param_groups[0]['lr']
-        if lr < min_lr:
-            logger.info(f'Learning rate below the minimum {min_lr}.')
-            return True
-
-        return False
-
-    def validate(self):
-        logger.info('Validating...')
-
-        val_losses = utils.RunningAverage()
-        val_scores = utils.RunningAverage()
-
-        with torch.no_grad():
-            for i, t in enumerate(self.loaders['val']):
-                logger.info(f'Validation iteration {i}')
-
-                input, target, weight = self._split_training_batch(t)
-
-                output, loss = self._forward_pass(input, target, weight)
-                val_losses.update(loss.item(), self._batch_size(input))
-
-                if i % 100 == 0:
-                    self._log_images(input, target, output, 'val_')
-
-                eval_score = self.eval_criterion(output, target)
-                val_scores.update(eval_score.item(), self._batch_size(input))
-
-                if self.validate_iters is not None and self.validate_iters <= i:
-                    # stop validation
-                    break
-
-            self._log_stats('val', val_losses.avg, val_scores.avg)
-            logger.info(f'Validation finished. Loss: {val_losses.avg}. Evaluation score: {val_scores.avg}')
-            return val_scores.avg
-
-    def _split_training_batch(self, t):
-        def _move_to_gpu(input):
-            if isinstance(input, tuple) or isinstance(input, list):
-                return tuple([_move_to_gpu(x) for x in input])
-            else:
-                if torch.cuda.is_available():
-                    input = input.cuda(non_blocking=True)
-                return input
-
-        t = _move_to_gpu(t)
-        weight = None
-        if len(t) == 2:
-            input, target = t
-        else:
-            input, target, weight = t
-        return input, target, weight
-
-    def _forward_pass(self, input, target, weight=None):
-        if isinstance(self.model, UNet2D):
-            # remove the singleton z-dimension from the input
-            input = torch.squeeze(input, dim=-3)
-            # forward pass
-            output = self.model(input)
-            # add the singleton z-dimension to the output
-            output = torch.unsqueeze(output, dim=-3)
-        else:
-            # forward pass
-            output = self.model(input)
-
-        # compute the loss
-        if weight is None:
-            loss = self.loss_criterion(output, target)
-        else:
-            loss = self.loss_criterion(output, target, weight)
-
-        return output, loss
-
-    def _is_best_eval_score(self, eval_score):
-        if self.eval_score_higher_is_better:
-            is_best = eval_score > self.best_eval_score
-        else:
-            is_best = eval_score < self.best_eval_score
-
-        if is_best:
-            logger.info(f'Saving new best evaluation metric: {eval_score}')
-            self.best_eval_score = eval_score
-
-        return is_best
-
-    def _save_checkpoint(self, is_best):
-        # remove `module` prefix from layer names when using `nn.DataParallel`
-        # see: https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/20
-        if isinstance(self.model, nn.DataParallel):
-            state_dict = self.model.module.state_dict()
-        else:
-            state_dict = self.model.state_dict()
-
-        last_file_path = os.path.join(self.checkpoint_dir, 'last_checkpoint.pytorch')
-        logger.info(f"Saving checkpoint to '{last_file_path}'")
-
-        utils.save_checkpoint({
-            'num_epochs': self.num_epochs + 1,
-            'num_iterations': self.num_iterations,
-            'model_state_dict': state_dict,
-            'best_eval_score': self.best_eval_score,
-            'optimizer_state_dict': self.optimizer.state_dict(),
-        }, is_best, checkpoint_dir=self.checkpoint_dir)
-
-    def _log_lr(self):
-        lr = self.optimizer.param_groups[0]['lr']
-        self.writer.add_scalar('learning_rate', lr, self.num_iterations)
-
-    def _log_stats(self, phase, loss_avg, eval_score_avg):
-        tag_value = {
-            f'{phase}_loss_avg': loss_avg,
-            f'{phase}_eval_score_avg': eval_score_avg
-        }
-
-        for tag, value in tag_value.items():
-            self.writer.add_scalar(tag, value, self.num_iterations)
-
-    def _log_params(self):
-        logger.info('Logging model parameters and gradients')
-        for name, value in self.model.named_parameters():
-            self.writer.add_histogram(name, value.data.cpu().numpy(), self.num_iterations)
-            self.writer.add_histogram(name + '/grad', value.grad.data.cpu().numpy(), self.num_iterations)
-
-    def _log_images(self, input, target, prediction, prefix=''):
-
-        if isinstance(self.model, nn.DataParallel):
-            net = self.model.module
-        else:
-            net = self.model
-
-        if net.final_activation is not None:
-            prediction = net.final_activation(prediction)
-
-        inputs_map = {
-            'inputs': input,
-            'targets': target,
-            'predictions': prediction
-        }
-        img_sources = {}
-        for name, batch in inputs_map.items():
-            if isinstance(batch, list) or isinstance(batch, tuple):
-                for i, b in enumerate(batch):
-                    img_sources[f'{name}{i}'] = b.data.cpu().numpy()
-            else:
-                img_sources[name] = batch.data.cpu().numpy()
-
-        for name, batch in img_sources.items():
-            for tag, image in self.tensorboard_formatter(name, batch):
-                self.writer.add_image(prefix + tag, image, self.num_iterations)
-
-    @staticmethod
-    def _batch_size(input):
-        if isinstance(input, list) or isinstance(input, tuple):
-            return input[0].size(0)
-        else:
-            return input.size(0)
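
One subtlety worth illustrating from the UNetTrainer docstring above: scheduler.step() fires once per validation run (every validate_after_iters iterations), not once per epoch. A standalone sketch of that interaction; the model, optimizer and iteration counts are illustrative:

    from torch import nn, optim
    from torch.optim.lr_scheduler import StepLR

    model = nn.Linear(4, 2)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

    validate_after_iters = 200
    for iteration in range(1, 6001):
        # ... training step would go here ...
        if iteration % validate_after_iters == 0:
            # ... validation would go here ...
            scheduler.step()  # StepLR(step_size=30) thus decays every 30 * 200 iterations

    print(optimizer.param_groups[0]['lr'])  # ~1e-4: one decay after 30 validations
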
diff --git a/build/lib/pytorch3dunet/unet3d/utils.py b/build/lib/pytorch3dunet/unet3d/utils.py
deleted file mode 100644
index 01d5559c..00000000
--- a/build/lib/pytorch3dunet/unet3d/utils.py
+++ /dev/null
@@ -1,366 +0,0 @@
-import importlib
-import logging
-import os
-import shutil
-import sys
-
-import h5py
-import numpy as np
-import torch
-from torch import optim
-
-
-def save_checkpoint(state, is_best, checkpoint_dir):
-    """Saves model and training parameters at '{checkpoint_dir}/last_checkpoint.pytorch'.
-    If is_best==True saves '{checkpoint_dir}/best_checkpoint.pytorch' as well.
-
-    Args:
-        state (dict): contains model's state_dict, optimizer's state_dict, epoch
-            and best evaluation metric value so far
-        is_best (bool): if True state contains the best model seen so far
-        checkpoint_dir (string): directory where the checkpoints are to be saved
-    """
-
-    # create parent directories as well and tolerate concurrent creation
-    os.makedirs(checkpoint_dir, exist_ok=True)
-
-    last_file_path = os.path.join(checkpoint_dir, 'last_checkpoint.pytorch')
-    torch.save(state, last_file_path)
-    if is_best:
-        best_file_path = os.path.join(checkpoint_dir, 'best_checkpoint.pytorch')
-        shutil.copyfile(last_file_path, best_file_path)
-
-
-def load_checkpoint(checkpoint_path, model, optimizer=None,
-                    model_key='model_state_dict', optimizer_key='optimizer_state_dict'):
-    """Loads model and training parameters from a given checkpoint_path
-    If optimizer is provided, loads optimizer's state_dict of as well.
-
-    Args:
-        checkpoint_path (string): path to the checkpoint to be loaded
-        model (torch.nn.Module): model into which the parameters are to be copied
-        optimizer (torch.optim.Optimizer) optional: optimizer instance into
-            which the parameters are to be copied
-
-    Returns:
-        state
-    """
-    if not os.path.exists(checkpoint_path):
-        raise IOError(f"Checkpoint '{checkpoint_path}' does not exist")
-
-    state = torch.load(checkpoint_path, map_location='cpu')
-    model.load_state_dict(state[model_key])
-
-    if optimizer is not None:
-        optimizer.load_state_dict(state[optimizer_key])
-
-    return state
-
-
-def save_network_output(output_path, output, logger=None):
-    if logger is not None:
-        logger.info(f'Saving network output to: {output_path}...')
-    output = output.detach().cpu()[0]
-    with h5py.File(output_path, 'w') as f:
-        f.create_dataset('predictions', data=output, compression='gzip')
-
-
-loggers = {}
-
-
-def get_logger(name, level=logging.INFO):
-    global loggers
-    if loggers.get(name) is not None:
-        return loggers[name]
-    else:
-        logger = logging.getLogger(name)
-        logger.setLevel(level)
-        # Logging to console
-        stream_handler = logging.StreamHandler(sys.stdout)
-        formatter = logging.Formatter(
-            '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s')
-        stream_handler.setFormatter(formatter)
-        logger.addHandler(stream_handler)
-
-        loggers[name] = logger
-
-        return logger
-
-
-def get_number_of_learnable_parameters(model):
-    return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-
-class RunningAverage:
-    """Computes and stores the average
-    """
-
-    def __init__(self):
-        self.count = 0
-        self.sum = 0
-        self.avg = 0
-
-    def update(self, value, n=1):
-        self.count += n
-        self.sum += value * n
-        self.avg = self.sum / self.count
-
-
-def number_of_features_per_level(init_channel_number, num_levels):
-    return [init_channel_number * 2 ** k for k in range(num_levels)]
-
-
-class _TensorboardFormatter:
-    """
-    A tensorboard formatter converts a given batch of images (be it input/output of the network or the target
-    segmentation) to a series of images that can be displayed in tensorboard. This is the parent class for all
-    tensorboard formatters and it ensures that the returned images are in the 'CHW' format.
-    """
-
-    def __init__(self, **kwargs):
-        pass
-
-    def __call__(self, name, batch):
-        """
-        Transform a batch to a series of tuples of the form (tag, img), where `tag` corresponds to the image tag
-        and `img` is the image itself.
-
-        Args:
-             name (str): one of 'inputs'/'targets'/'predictions'
-             batch (torch.tensor): 4D or 5D torch tensor
-        """
-
-        def _check_img(tag_img):
-            tag, img = tag_img
-
-            assert img.ndim == 2 or img.ndim == 3, 'Only 2D (HW) and 3D (CHW) images are accepted for display'
-
-            if img.ndim == 2:
-                img = np.expand_dims(img, axis=0)
-            else:
-                C = img.shape[0]
-                assert C == 1 or C == 3, 'Only (1, H, W) or (3, H, W) images are supported'
-
-            return tag, img
-
-        tagged_images = self.process_batch(name, batch)
-
-        return list(map(_check_img, tagged_images))
-
-    def process_batch(self, name, batch):
-        raise NotImplementedError
-
-
-class DefaultTensorboardFormatter(_TensorboardFormatter):
-    def __init__(self, skip_last_target=False, **kwargs):
-        super().__init__(**kwargs)
-        self.skip_last_target = skip_last_target
-
-    def process_batch(self, name, batch):
-        if name == 'targets' and self.skip_last_target:
-            batch = batch[:, :-1, ...]
-
-        tag_template = '{}/batch_{}/channel_{}/slice_{}'
-
-        tagged_images = []
-
-        if batch.ndim == 5:
-            # NCDHW
-            slice_idx = batch.shape[2] // 2  # get the middle slice
-            for batch_idx in range(batch.shape[0]):
-                for channel_idx in range(batch.shape[1]):
-                    tag = tag_template.format(name, batch_idx, channel_idx, slice_idx)
-                    img = batch[batch_idx, channel_idx, slice_idx, ...]
-                    tagged_images.append((tag, self._normalize_img(img)))
-        else:
-            # batch has no channel dim: NDHW
-            slice_idx = batch.shape[1] // 2  # get the middle slice
-            for batch_idx in range(batch.shape[0]):
-                tag = tag_template.format(name, batch_idx, 0, slice_idx)
-                img = batch[batch_idx, slice_idx, ...]
-                tagged_images.append((tag, self._normalize_img(img)))
-
-        return tagged_images
-
-    @staticmethod
-    def _normalize_img(img):
-        return np.nan_to_num((img - np.min(img)) / np.ptp(img))
-
-
-def _find_masks(batch, min_size=10):
-    """Center the z-slice in the 'middle' of a given instance, given a batch of instances
-
-    Args:
-        batch (ndarray): 5d numpy tensor (NCDHW)
-    """
-    result = []
-    for b in batch:
-        assert b.shape[0] == 1
-        patch = b[0]
-        z_sum = patch.sum(axis=(1, 2))
-        coords = np.where(z_sum > min_size)[0]
-        if len(coords) > 0:
-            ind = coords[len(coords) // 2]
-            result.append(b[:, ind:ind + 1, ...])
-        else:
-            ind = b.shape[1] // 2
-            result.append(b[:, ind:ind + 1, ...])
-
-    return np.stack(result, axis=0)
-
-
-def get_tensorboard_formatter(formatter_config):
-    if formatter_config is None:
-        return DefaultTensorboardFormatter()
-
-    class_name = formatter_config['name']
-    m = importlib.import_module('pytorch3dunet.unet3d.utils')
-    clazz = getattr(m, class_name)
-    return clazz(**formatter_config)
-
-
-def expand_as_one_hot(input, C, ignore_index=None):
-    """
-    Converts NxSPATIAL label image to NxCxSPATIAL, where each label gets converted to its corresponding one-hot vector.
-    It is assumed that the batch dimension is present.
-    Args:
-        input (torch.Tensor): 4D input label image (NxDxHxW)
-        C (int): number of channels/labels
-        ignore_index (int): ignore index to be kept during the expansion
-    Returns:
-        5D output torch.Tensor (NxCxDxHxW)
-    """
-    assert input.dim() == 4
-
-    # expand the input tensor to Nx1xSPATIAL before scattering
-    input = input.unsqueeze(1)
-    # create output tensor shape (NxCxSPATIAL)
-    shape = list(input.size())
-    shape[1] = C
-
-    if ignore_index is not None:
-        # create ignore_index mask for the result
-        mask = input.expand(shape) == ignore_index
-        # clone the src tensor and zero out ignore_index in the input
-        input = input.clone()
-        input[input == ignore_index] = 0
-        # scatter to get the one-hot tensor
-        result = torch.zeros(shape).to(input.device).scatter_(1, input, 1)
-        # bring back the ignore_index in the result
-        result[mask] = ignore_index
-        return result
-    else:
-        # scatter to get the one-hot tensor
-        return torch.zeros(shape).to(input.device).scatter_(1, input, 1)
-
-
-def convert_to_numpy(*inputs):
-    """
-    Converts input tensors to numpy ndarrays
-
-    Args:
-        inputs (iterable of torch.Tensor): torch tensors
-
-    Returns:
-        tuple of ndarrays
-    """
-
-    def _to_numpy(i):
-        assert isinstance(i, torch.Tensor), "Expected input to be torch.Tensor"
-        return i.detach().cpu().numpy()
-
-    return (_to_numpy(i) for i in inputs)
-
-
-def create_optimizer(optimizer_config, model):
-    optim_name = optimizer_config.get('name', 'Adam')
-    # common optimizer settings
-    learning_rate = optimizer_config.get('learning_rate', 1e-3)
-    weight_decay = optimizer_config.get('weight_decay', 0)
-
-    # grab optimizer specific settings and init
-    # optimizer
-    if optim_name == 'Adadelta':
-        rho = optimizer_config.get('rho', 0.9)
-        optimizer = optim.Adadelta(model.parameters(), lr=learning_rate, rho=rho,
-                                   weight_decay=weight_decay)
-    elif optim_name == 'Adagrad':
-        lr_decay = optimizer_config.get('lr_decay', 0)
-        optimizer = optim.Adagrad(model.parameters(), lr=learning_rate, lr_decay=lr_decay,
-                                  weight_decay=weight_decay)
-    elif optim_name == 'AdamW':
-        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
-        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=betas,
-                                weight_decay=weight_decay)
-    elif optim_name == 'SparseAdam':
-        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
-        optimizer = optim.SparseAdam(model.parameters(), lr=learning_rate, betas=betas)
-    elif optim_name == 'Adamax':
-        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
-        optimizer = optim.Adamax(model.parameters(), lr=learning_rate, betas=betas,
-                                 weight_decay=weight_decay)
-    elif optim_name == 'ASGD':
-        lambd = optimizer_config.get('lambd', 0.0001)
-        alpha = optimizer_config.get('alpha', 0.75)
-        t0 = optimizer_config.get('t0', 1e6)
-        # this branch should construct ASGD; optim.Adamax does not accept lambd/alpha/t0
-        optimizer = optim.ASGD(model.parameters(), lr=learning_rate, lambd=lambd,
-                               alpha=alpha, t0=t0, weight_decay=weight_decay)
-    elif optim_name == 'LBFGS':
-        max_iter = optimizer_config.get('max_iter', 20)
-        max_eval = optimizer_config.get('max_eval', None)
-        tolerance_grad = optimizer_config.get('tolerance_grad', 1e-7)
-        tolerance_change = optimizer_config.get('tolerance_change', 1e-9)
-        history_size = optimizer_config.get('history_size', 100)
-        optimizer = optim.LBFGS(model.parameters(), lr=learning_rate, max_iter=max_iter,
-                                max_eval=max_eval, tolerance_grad=tolerance_grad,
-                                tolerance_change=tolerance_change, history_size=history_size)
-    elif optim_name == 'NAdam':
-        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
-        momentum_decay = optimizer_config.get('momentum_decay', 4e-3)
-        optimizer = optim.NAdam(model.parameters(), lr=learning_rate, betas=betas,
-                                momentum_decay=momentum_decay,
-                                weight_decay=weight_decay)
-    elif optim_name == 'RAdam':
-        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
-        optimizer = optim.RAdam(model.parameters(), lr=learning_rate, betas=betas,
-                                weight_decay=weight_decay)
-    elif optim_name == 'RMSprop':
-        alpha = optimizer_config.get('alpha', 0.99)
-        optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, alpha=alpha,
-                                  weight_decay=weight_decay)
-    elif optim_name == 'Rprop':
-        # use optim.Rprop here; Rprop has no momentum/weight_decay arguments
-        etas = tuple(optimizer_config.get('etas', (0.5, 1.2)))
-        step_sizes = tuple(optimizer_config.get('step_sizes', (1e-6, 50)))
-        optimizer = optim.Rprop(model.parameters(), lr=learning_rate, etas=etas, step_sizes=step_sizes)
-    elif optim_name == 'SGD':
-        momentum = optimizer_config.get('momentum', 0)
-        dampening = optimizer_config.get('dampening', 0)
-        nesterov = optimizer_config.get('nesterov', False)
-        optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum,
-                              dampening=dampening, nesterov=nesterov,
-                              weight_decay=weight_decay)
-    else:  # Adam is default
-        betas = tuple(optimizer_config.get('betas', (0.9, 0.999)))
-        optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=betas,
-                               weight_decay=weight_decay)
-
-    return optimizer
-
-
-def create_lr_scheduler(lr_config, optimizer):
-    if lr_config is None:
-        return None
-    class_name = lr_config.pop('name')
-    m = importlib.import_module('torch.optim.lr_scheduler')
-    clazz = getattr(m, class_name)
-    # add optimizer to the config
-    lr_config['optimizer'] = optimizer
-    return clazz(**lr_config)
-
-
-def get_class(class_name, modules):
-    for module in modules:
-        m = importlib.import_module(module)
-        clazz = getattr(m, class_name, None)
-        if clazz is not None:
-            return clazz
-    raise RuntimeError(f'Unsupported dataset class: {class_name}')
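
To make the expand_as_one_hot contract above concrete, a small check (assuming the function is in scope): a 4D NxDxHxW label volume becomes a 5D NxCxDxHxW one-hot volume with exactly one active channel per voxel:

    import torch

    labels = torch.tensor([[[[0, 1], [2, 1]]]])  # (1, 1, 2, 2): N=1, D=1, H=2, W=2
    one_hot = expand_as_one_hot(labels, C=3)
    assert one_hot.shape == (1, 3, 1, 2, 2)
    assert torch.all(one_hot.sum(dim=1) == 1)    # one active channel per voxel
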

From e2143ecd491bfd694ecb41d6aa03491ec8d882d9 Mon Sep 17 00:00:00 2001
From: Shota Mizusaki <nrxg129@gmail.com>
Date: Tue, 30 Jul 2024 20:32:38 +0900
Subject: [PATCH 4/4] Since csr_matrix does not have .A, change to toarray().

---
 pytorch3dunet/unet3d/seg_metrics.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch3dunet/unet3d/seg_metrics.py b/pytorch3dunet/unet3d/seg_metrics.py
index e713ea23..ddafe59d 100644
--- a/pytorch3dunet/unet3d/seg_metrics.py
+++ b/pytorch3dunet/unet3d/seg_metrics.py
@@ -29,7 +29,7 @@ def _iou_matrix(gt, seg):
     seg = _relabel(seg)
 
     # get number of overlapping pixels between GT and SEG
-    n_inter = contingency_table(gt, seg).A
+    n_inter = contingency_table(gt, seg).toarray()
 
     # number of pixels for GT instances
     n_gt = n_inter.sum(axis=1, keepdims=True)
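
The motivation for this one-liner, sketched with a plain SciPy sparse matrix (exact behavior is version-dependent; recent SciPy releases dropped the .A shorthand from sparse types, while .toarray() works across versions):

    import numpy as np
    from scipy.sparse import csr_matrix

    m = csr_matrix(np.eye(3))
    dense = m.toarray()               # portable densification
    assert isinstance(dense, np.ndarray)
    # m.A raises AttributeError on SciPy versions where the shorthand was removed
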