Example - Train a 3d U-Net for Segmentation of Pulmonary Infiltrates Associated with COVID-19

An end-to-end example of how to use faimed3d to train a fully 3D U-Net with a pretrained encoder.

from faimed3d.all import *
from torchvision.models.video import r3d_18
from fastai.callback.all import SaveModelCallback
from torch import nn

REPO_DIR = Path(os.getcwd()).parent
DATA_DIR = Path('/media/ScaleOut/vahldiek/CT/1_Covid-19-CT-Seg')

# collect the CT volumes and the corresponding infection masks;
# sorting both lists keeps images and masks in the same order
ct_images = (DATA_DIR/'coronacases').ls()
ct_images.sort()
infect_mask = (DATA_DIR/'infection_mask').ls()
infect_mask.sort()

d = pd.DataFrame({'images': ct_images, 
                 'masks': infect_mask})
TensorDicom3D.create(d.images[0]).show()

(figure: preview of d.images[0])

TensorDicom3D.create(d.images[11]).show()

(figure: preview of d.images[11])
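Images and masks are paired purely by the sorted file order, so a quick look at a few pairs helps catch silent mismatches (a small sketch, not part of the original notebook; the mask naming scheme is an assumption):

# print the first few image/mask pairs produced by sorting
for img, msk in zip(d.images[:3], d.masks[:3]):
    print(img.name, '->', msk.name)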

The Radiopaedia images have already been windowed, which would make it difficult for the model to generalize, so they are excluded (a quick look at the intensity ranges, as sketched below, makes the windowed cases easy to spot).
Training on only six images makes each epoch very short. Because PyTorch and fastai need some time at the start of every epoch to prepare the dataloaders and model (e.g. switching to/from evaluation mode), artificially enlarging the dataset through oversampling speeds up the overall training process.
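To see which volumes look pre-windowed, one option (a sketch, not part of the original notebook) is to print the intensity range of every case: unwindowed CT spans roughly -1000 HU to well above +1000 HU, while windowed exports are clipped to a much narrower range.

# print the value range of each volume; a narrow range hints at pre-windowed data
for fn in ct_images:
    vol = TensorDicom3D.create(fn)
    print(fn.name, float(vol.min()), float(vol.max()))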

d = d[:10] # exclude all radiopaedia cases
d['is_valid'] = [1,1,1,1,0,0,0,0,0,0] # first four cases form the validation set, the last six the training set
d_oversampled = pd.concat((d, )*5) # oversample by repeating the dataframe five times

Six images will be used for training and four for validation. Because the training set is so small, relatively strong data augmentation is applied.

dls = SegmentationDataLoaders3D.from_df(d_oversampled, path = '/',
                                item_tfms = ResizeCrop3D((0, 0, 0), (20, 224, 224)), 
                                batch_tfms = [RandomPerspective3D(224, 0.5), 
                                              *aug_transforms_3d(p_all=0.15, noise=False)],
                                bs = 1, 
                                val_bs = 1,
                                splitter = ColSplitter('is_valid'))
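Before training, it can be worth pulling a single batch to confirm that the transforms produce the expected shapes (a quick sketch; the exact shapes depend on the faimed3d version):

xb, yb = dls.one_batch()   # one training batch
print(xb.shape, yb.shape)  # roughly (1, 1, 20, 224, 224) for image and mask
dls.show_batch()           # visual check of the images with their masks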

Combining dice_loss and CrossEntropyLoss can be a helpful technique to get faster convergence.

def dice(input, target):
    # soft Dice coefficient over all voxels
    iflat = input.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    return ((2. * intersection) /
           (iflat.sum() + tflat.sum()))

def dice_score(input, target):
    # metric: Dice of the hard (argmax) prediction
    return dice(input.argmax(1), target)

def dice_loss(input, target):
    # loss: 1 - soft Dice of the foreground probability
    return 1 - dice(input.softmax(1)[:, 1], target)

def loss(input, target):
    # combined loss: soft Dice plus cross entropy
    return dice_loss(input, target) + nn.CrossEntropyLoss()(input, target[:, 0])
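
# Quick shape check for the combined loss (a sketch, not part of the original
# notebook): the model outputs logits of shape (batch, 2, depth, height, width)
# and the mask keeps a singleton channel dimension, which is why CrossEntropyLoss
# receives target[:, 0].
import torch
dummy_pred = torch.randn(1, 2, 20, 224, 224)
dummy_mask = torch.randint(0, 2, (1, 1, 20, 224, 224))
print(loss(dummy_pred, dummy_mask))
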
learn = unet_learner_3d(dls, r3d_18, n_out=2, 
                        loss_func = loss,
                        metrics = dice_score,
                        model_dir = REPO_DIR/'models',
                        cbs = [SaveModelCallback(monitor='dice_score')]
                       )
learn = learn.to_fp16()
learn.lr_find()
SuggestedLRs(lr_min=0.05248074531555176, lr_steep=0.43651583790779114)

(figure: learning rate finder plot)

The learning rates suggested by the lr_finder are usually too high and can lead to exploding gradients. It is better to divide the suggested LR by 10 or 100. Even then, training is often unstable at the beginning but becomes more stable over the course of more epochs; a nan validation loss in early epochs is not necessarily a reason to stop training.
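As a rough illustration of this rule of thumb (the exact divisor is an assumption, not taken from the notebook), the learning rate used below is close to the suggested lr_min divided by five:

suggestion = learn.lr_find()  # same call as above, returns SuggestedLRs
lr = suggestion.lr_min / 5    # ~0.01 here; dividing by 10-100 is even more conservative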

learn.fit_one_cycle(3, 0.01, wd = 1e-4)
epoch train_loss valid_loss dice_score time
0 1.617970 7.553207 0.022217 01:23
1 1.447020 1.348669 0.030857 01:14
2 1.348788 1.290999 0.219593 01:14
Better model found at epoch 0 with dice_score value: 0.022217031568288803.
Better model found at epoch 1 with dice_score value: 0.03085685335099697.
Better model found at epoch 2 with dice_score value: 0.21959252655506134.
learn.unfreeze()
learn.fit_one_cycle(50, 1e-3, wd = 1e-4)
epoch train_loss valid_loss dice_score time
0 1.277849 1.300274 0.169884 01:16
1 1.271986 1.285651 0.259191 01:18
2 1.267972 1.266772 0.394450 01:19
3 1.262012 1.279280 0.284623 01:18
4 1.256839 1.258349 0.492355 01:18
5 1.254410 1.248541 0.528375 01:19
6 1.243158 1.260393 0.261894 01:19
7 1.233036 nan 0.048495 01:18
8 1.219665 1.205178 0.570096 01:18
9 1.206917 1.301556 0.091743 01:19
10 1.186561 1.181180 0.487740 01:18
11 1.169439 1.216021 0.182845 01:18
12 1.157408 1.165419 0.395111 01:18
13 1.139597 1.136980 0.524025 01:18
14 1.124535 1.131342 0.527730 01:18
15 1.104634 1.211899 0.111883 01:19
16 1.097506 1.103193 0.568950 01:18
17 1.084650 1.088717 0.610375 01:19
18 1.067548 1.070056 0.628555 01:18
19 1.048581 1.083311 0.375368 01:18
20 1.038405 1.056531 0.538761 01:19
21 1.032926 1.091693 0.297455 01:18
22 1.027478 1.053580 0.492599 01:19
23 1.012300 1.017731 0.667421 01:18
24 0.999201 1.014029 0.657778 01:19
25 0.984335 1.012218 0.536109 01:18
26 0.978572 1.037715 0.311021 01:19
27 0.972869 1.004563 0.572911 01:19
28 0.962912 0.985065 0.598763 01:19
29 0.957247 0.983570 0.642466 01:19
30 0.946731 0.974621 0.605570 01:18
31 0.944053 0.999155 0.416248 01:19
32 0.936624 0.969453 0.601356 01:19
33 0.931428 0.954822 0.649087 01:19
34 0.916852 0.948883 0.629919 01:19
35 0.907974 0.943074 0.655565 01:19
36 0.904585 0.958528 0.564986 01:19
37 0.903870 0.953206 0.572820 01:18
38 0.896562 0.934182 0.703282 01:19
39 0.894336 0.929551 0.690418 01:19
40 0.885841 0.932464 0.677648 01:19
41 0.893995 0.955278 0.528160 01:19
42 0.887511 0.928274 0.674235 01:19
43 0.882461 0.924388 0.697535 01:19
44 0.882864 0.928446 0.689763 01:19
45 0.880625 0.922291 0.720408 01:19
46 0.885734 0.925096 0.712881 01:19
47 0.878734 0.919903 0.715904 01:18
48 0.879444 0.932021 0.633199 01:19
49 0.873842 0.916253 0.736103 01:18
Better model found at epoch 0 with dice_score value: 0.16988438367843628.
Better model found at epoch 1 with dice_score value: 0.25919052958488464.
Better model found at epoch 2 with dice_score value: 0.39445048570632935.
Better model found at epoch 4 with dice_score value: 0.49235478043556213.
Better model found at epoch 5 with dice_score value: 0.5283752679824829.
Better model found at epoch 8 with dice_score value: 0.5700957775115967.
Better model found at epoch 17 with dice_score value: 0.6103752851486206.
Better model found at epoch 18 with dice_score value: 0.6285550594329834.
Better model found at epoch 23 with dice_score value: 0.6674210429191589.
Better model found at epoch 38 with dice_score value: 0.7032817602157593.
Better model found at epoch 45 with dice_score value: 0.7204076647758484.
Better model found at epoch 49 with dice_score value: 0.736102819442749.
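
The SaveModelCallback has been tracking dice_score and stores the weights of the best epoch as 'model.pth' in model_dir, which is what learn.load('model') restores during evaluation below. For standalone inference the learner can additionally be exported (a sketch; the file name is hypothetical):

learn.export(REPO_DIR/'models'/'covid_unet.pkl')  # hypothetical export path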

Evaluation

Evaluation is done only on the original (non-oversampled) data.

dls = SegmentationDataLoaders3D.from_df(d, path = '/',
                                item_tfms = ResizeCrop3D((0, 0, 0), (20, 224, 224)), 
                                batch_tfms = [RandomPerspective3D(224, 0.5), 
                                              *aug_transforms_3d(p_all=0.5)],
                                bs = 1, 
                                val_bs = 1,
                                splitter = ColSplitter('is_valid'))
learn.dls = dls # evaluate on the non-oversampled data
learn = learn.load('model') # restore the best weights saved by SaveModelCallback
inp, pred, target = learn.get_preds(with_input = True)
inp.show(nrow=10, figsize = (30, 30))
pred.argmax(1).show(nrow=10, add_to_existing = True, alpha = 0.25, cmap = 'jet')

(figure: validation volumes with the predicted masks overlaid)

inp.show(nrow=10, figsize = (30, 30))
target.show(nrow=10, add_to_existing = True, alpha = 0.25, cmap = 'jet')

(figure: validation volumes with the ground-truth masks overlaid)

dice_score(pred, target)
TensorMask3D(0.7255)

A micro-averaged Dice score of 0.726 is similar to the macro-averaged Dice score of 0.673 reported by Ma et al., although micro-averaged scores tend to be slightly higher and Ma et al. used only four CT examinations instead of six to train their segmentation model. Still, this is close to the published state of the art.
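For comparison, a macro-averaged score can be computed from the same predictions by scoring each validation volume separately with the dice helper defined above and averaging the results (a sketch; the exact value will differ from the pooled score):

# one Dice score per validation volume, then the plain average (macro)
per_case = [dice(p.argmax(0), t) for p, t in zip(pred, target)]
print(sum(per_case) / len(per_case))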