End-to-end example of how to use faimed3d to train a fully 3D U-Net with a pretrained encoder
import os
from pathlib import Path

import pandas as pd
from faimed3d.all import *
from torchvision.models.video import r3d_18
from fastai.callback.all import SaveModelCallback
from torch import nn
REPO_DIR = Path(os.getcwd()).parent
DATA_DIR = Path('/media/ScaleOut/vahldiek/CT/1_Covid-19-CT-Seg')
ct_images = (DATA_DIR/'coronacases').ls()
ct_images.sort()
infect_mask = (DATA_DIR/'infection_mask').ls()
infect_mask.sort()
d = pd.DataFrame({'images': ct_images,
                  'masks': infect_mask})
TensorDicom3D.create(d.images[0]).show()
TensorDicom3D.create(d.images[11]).show()
The Radiopaedia images have already been windowed, which would make it difficult for the model to generalize, so they are excluded.
Training on only six images makes each epoch very short. Because PyTorch and fastai need some time at the beginning of each epoch to prepare the dataloaders and the model (e.g. switching to and from evaluation mode), artificially enlarging the dataset through oversampling can speed up the overall training process.
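d = d[:10].copy()  # exclude all Radiopaedia cases; copy avoids pandas' SettingWithCopyWarning
d['is_valid'] = [1,1,1,1,0,0,0,0,0,0]  # first four cases for validation, remaining six for training
d_oversampled = pd.concat((d, )*5)  # oversample: repeat every row five times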
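As a sanity check on the split and the oversampling above, the resulting item counts can be worked out in plain Python. This is a minimal sketch that only mirrors the is_valid flags and the factor of 5 from the cell above; the helper name count_split is hypothetical and not part of faimed3d.

```python
# Flags copied from the cell above: 1 = validation, 0 = training.
is_valid = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
OVERSAMPLE = 5  # same factor as in pd.concat((d, )*5)


def count_split(flags, factor=1):
    """Return (n_train, n_valid) after repeating every row `factor` times."""
    repeated = list(flags) * factor
    n_valid = sum(repeated)
    n_train = len(repeated) - n_valid
    return n_train, n_valid


original = count_split(is_valid)
oversampled = count_split(is_valid, OVERSAMPLE)
print(original)     # (6, 4)
print(oversampled)  # (30, 20)
```

Because the whole DataFrame is repeated, the validation rows are repeated as well. The number of unique images in each split stays the same.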
Only six images are used for training and four for validation, so data augmentation is applied more heavily.
dls = SegmentationDataLoaders3D.from_df(d_oversampled, path = '/',
                                        item_tfms = ResizeCrop3D((0, 0, 0), (20, 224, 224)),
                                        batch_tfms = [RandomPerspective3D(224, 0.5),
                                                      *aug_transforms_3d(p_all=0.15, noise=False)],
                                        bs = 1,
                                        val_bs = 1,
                                        splitter = ColSplitter('is_valid'))
Combining dice_loss and CrossEntropyLoss can be a helpful technique to get faster convergence.
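def dice(input, target):
    # flatten both tensors and compute 2 * |A ∩ B| / (|A| + |B|)
    iflat = input.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()
    return ((2. * intersection) /
            (iflat.sum() + tflat.sum()))

def dice_score(input, target):
    # metric: Dice on the hard prediction (most likely class per voxel)
    return dice(input.argmax(1), target)

def dice_loss(input, target):
    # loss: soft Dice on the predicted probability of class 1
    return 1 - dice(input.softmax(1)[:, 1], target)

def loss(input, target):
    # target[:, 0] selects index 0 along dim 1 of the mask for CrossEntropyLoss
    return dice_loss(input, target) + nn.CrossEntropyLoss()(input, target[:, 0])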
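To make the Dice computation concrete, here is a small worked example on flattened binary masks in plain Python (no tensors). It follows the same formula as dice above, 2·|A∩B| / (|A| + |B|). The eps argument is an assumption added for illustration and is not part of the original notebook: without it the ratio is 0/0 when both prediction and target are empty, which is one possible source of nan values.

```python
def dice_flat(pred, target, eps=0.0):
    """Dice coefficient of two equally long sequences of 0/1 (or soft) values.

    `eps` is a hypothetical smoothing term; eps=0.0 reproduces the
    unsmoothed formula used by `dice` above.
    """
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    intersection = sum(p * t for p, t in zip(pred, target))
    denominator = sum(pred) + sum(target)
    if denominator + eps == 0:
        return float("nan")  # 0/0: both masks are empty and no smoothing
    return (2.0 * intersection + eps) / (denominator + eps)


pred = [1, 1, 0, 0]
target = [1, 0, 0, 0]
print(dice_flat(pred, target))             # 2*1 / (2+1), about 0.667
print(dice_flat([0, 0], [0, 0]))           # nan
print(dice_flat([0, 0], [0, 0], eps=1.0))  # 1.0
```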
learn = unet_learner_3d(dls, r3d_18, n_out=2,
                        loss_func = loss,
                        metrics = dice_score,
                        model_dir = REPO_DIR/'models',
                        cbs = [SaveModelCallback(monitor='dice_score')]
                        )
learn = learn.to_fp16()
learn.lr_find()
SuggestedLRs(lr_min=0.05248074531555176, lr_steep=0.43651583790779114)
The learning rates suggested by the lr_finder are usually too high, leading to exploding gradients. It is better to divide the suggested LR by 10 or 100. Even then, training is often unstable at the beginning but becomes more stable with more epochs. A nan validation loss in early epochs is not necessarily a reason to stop the training.
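The rule of thumb above can be written down as a small helper. This is only a sketch: scale_down_lr is a hypothetical function, not a fastai or faimed3d API, and the choice of divisor remains a judgment call.

```python
def scale_down_lr(suggested_lr, divisor=10.0):
    """Divide a learning rate suggested by the LR finder by `divisor`."""
    if divisor <= 0:
        raise ValueError("divisor must be positive")
    return suggested_lr / divisor


lr_min = 0.05248074531555176  # lr_min from the lr_find() output above
print(f"{scale_down_lr(lr_min, 10):.5f}")   # 0.00525
print(f"{scale_down_lr(lr_min, 100):.6f}")  # 0.000525
```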
learn.fit_one_cycle(3, 0.01, wd = 1e-4)
epoch | train_loss | valid_loss | dice_score | time |
---|---|---|---|---|
0 | 1.617970 | 7.553207 | 0.022217 | 01:23 |
1 | 1.447020 | 1.348669 | 0.030857 | 01:14 |
2 | 1.348788 | 1.290999 | 0.219593 | 01:14 |
Better model found at epoch 0 with dice_score value: 0.022217031568288803.
Better model found at epoch 1 with dice_score value: 0.03085685335099697.
Better model found at epoch 2 with dice_score value: 0.21959252655506134.
learn.unfreeze()
learn.fit_one_cycle(50, 1e-3, wd = 1e-4)
epoch | train_loss | valid_loss | dice_score | time |
---|---|---|---|---|
0 | 1.277849 | 1.300274 | 0.169884 | 01:16 |
1 | 1.271986 | 1.285651 | 0.259191 | 01:18 |
2 | 1.267972 | 1.266772 | 0.394450 | 01:19 |
3 | 1.262012 | 1.279280 | 0.284623 | 01:18 |
4 | 1.256839 | 1.258349 | 0.492355 | 01:18 |
5 | 1.254410 | 1.248541 | 0.528375 | 01:19 |
6 | 1.243158 | 1.260393 | 0.261894 | 01:19 |
7 | 1.233036 | nan | 0.048495 | 01:18 |
8 | 1.219665 | 1.205178 | 0.570096 | 01:18 |
9 | 1.206917 | 1.301556 | 0.091743 | 01:19 |
10 | 1.186561 | 1.181180 | 0.487740 | 01:18 |
11 | 1.169439 | 1.216021 | 0.182845 | 01:18 |
12 | 1.157408 | 1.165419 | 0.395111 | 01:18 |
13 | 1.139597 | 1.136980 | 0.524025 | 01:18 |
14 | 1.124535 | 1.131342 | 0.527730 | 01:18 |
15 | 1.104634 | 1.211899 | 0.111883 | 01:19 |
16 | 1.097506 | 1.103193 | 0.568950 | 01:18 |
17 | 1.084650 | 1.088717 | 0.610375 | 01:19 |
18 | 1.067548 | 1.070056 | 0.628555 | 01:18 |
19 | 1.048581 | 1.083311 | 0.375368 | 01:18 |
20 | 1.038405 | 1.056531 | 0.538761 | 01:19 |
21 | 1.032926 | 1.091693 | 0.297455 | 01:18 |
22 | 1.027478 | 1.053580 | 0.492599 | 01:19 |
23 | 1.012300 | 1.017731 | 0.667421 | 01:18 |
24 | 0.999201 | 1.014029 | 0.657778 | 01:19 |
25 | 0.984335 | 1.012218 | 0.536109 | 01:18 |
26 | 0.978572 | 1.037715 | 0.311021 | 01:19 |
27 | 0.972869 | 1.004563 | 0.572911 | 01:19 |
28 | 0.962912 | 0.985065 | 0.598763 | 01:19 |
29 | 0.957247 | 0.983570 | 0.642466 | 01:19 |
30 | 0.946731 | 0.974621 | 0.605570 | 01:18 |
31 | 0.944053 | 0.999155 | 0.416248 | 01:19 |
32 | 0.936624 | 0.969453 | 0.601356 | 01:19 |
33 | 0.931428 | 0.954822 | 0.649087 | 01:19 |
34 | 0.916852 | 0.948883 | 0.629919 | 01:19 |
35 | 0.907974 | 0.943074 | 0.655565 | 01:19 |
36 | 0.904585 | 0.958528 | 0.564986 | 01:19 |
37 | 0.903870 | 0.953206 | 0.572820 | 01:18 |
38 | 0.896562 | 0.934182 | 0.703282 | 01:19 |
39 | 0.894336 | 0.929551 | 0.690418 | 01:19 |
40 | 0.885841 | 0.932464 | 0.677648 | 01:19 |
41 | 0.893995 | 0.955278 | 0.528160 | 01:19 |
42 | 0.887511 | 0.928274 | 0.674235 | 01:19 |
43 | 0.882461 | 0.924388 | 0.697535 | 01:19 |
44 | 0.882864 | 0.928446 | 0.689763 | 01:19 |
45 | 0.880625 | 0.922291 | 0.720408 | 01:19 |
46 | 0.885734 | 0.925096 | 0.712881 | 01:19 |
47 | 0.878734 | 0.919903 | 0.715904 | 01:18 |
48 | 0.879444 | 0.932021 | 0.633199 | 01:19 |
49 | 0.873842 | 0.916253 | 0.736103 | 01:18 |
Better model found at epoch 0 with dice_score value: 0.16988438367843628.
Better model found at epoch 1 with dice_score value: 0.25919052958488464.
Better model found at epoch 2 with dice_score value: 0.39445048570632935.
Better model found at epoch 4 with dice_score value: 0.49235478043556213.
Better model found at epoch 5 with dice_score value: 0.5283752679824829.
Better model found at epoch 8 with dice_score value: 0.5700957775115967.
Better model found at epoch 17 with dice_score value: 0.6103752851486206.
Better model found at epoch 18 with dice_score value: 0.6285550594329834.
Better model found at epoch 23 with dice_score value: 0.6674210429191589.
Better model found at epoch 38 with dice_score value: 0.7032817602157593.
Better model found at epoch 45 with dice_score value: 0.7204076647758484.
Better model found at epoch 49 with dice_score value: 0.736102819442749.
Evaluation is done only on the original data.
dls = SegmentationDataLoaders3D.from_df(d, path = '/',
                                        item_tfms = ResizeCrop3D((0, 0, 0), (20, 224, 224)),
                                        batch_tfms = [RandomPerspective3D(224, 0.5),
                                                      *aug_transforms_3d(p_all=0.5)],
                                        bs = 1,
                                        val_bs = 1,
                                        splitter = ColSplitter('is_valid'))
learn = learn.load('model')
inp, pred, target = learn.get_preds(with_input = True)
inp.show(nrow=10, figsize = (30, 30))
pred.argmax(1).show(nrow=10, add_to_existing = True, alpha = 0.25, cmap = 'jet')
inp.show(nrow=10, figsize = (30, 30))
target.show(nrow=10, add_to_existing = True, alpha = 0.25, cmap = 'jet')
dice_score(pred, target)
TensorMask3D(0.7255)
A micro-averaged Dice score of 0.726 is similar to the macro-averaged Dice score of 0.673 reported by Ma et al. The two numbers are not directly comparable, however: micro-averaged scores tend to be slightly higher, and Ma et al. used only four CT examinations instead of six to train the segmentation model. Still, this is close to the published state of the art.
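The difference between the two averaging schemes can be illustrated with a toy example in plain Python. This is a sketch under the common definitions (micro: pool all voxels of all cases before computing one Dice score; macro: compute one Dice score per case and take the mean). The function names are hypothetical and the masks are made up for illustration. Calling dice_score on the whole prediction tensor, as above, corresponds to the micro variant because all voxels are flattened together.

```python
def dice_flat(pred, target):
    """Unsmoothed Dice coefficient of two equally long 0/1 sequences."""
    intersection = sum(p * t for p, t in zip(pred, target))
    return 2.0 * intersection / (sum(pred) + sum(target))


def micro_dice(cases):
    """Pool the voxels of all (pred, target) pairs, then compute Dice once."""
    all_pred = [p for pred, _ in cases for p in pred]
    all_target = [t for _, target in cases for t in target]
    return dice_flat(all_pred, all_target)


def macro_dice(cases):
    """Compute Dice per (pred, target) pair, then average the scores."""
    scores = [dice_flat(pred, target) for pred, target in cases]
    return sum(scores) / len(scores)


# Made-up toy masks: a large lesion segmented perfectly, a small one missed.
cases = [
    ([1, 1, 1, 1], [1, 1, 1, 1]),  # Dice = 1.0
    ([1, 0, 0, 0], [0, 1, 0, 0]),  # Dice = 0.0
]
print(micro_dice(cases))  # 2*4 / (5 + 5) = 0.8
print(macro_dice(cases))  # (1.0 + 0.0) / 2 = 0.5
```

In this toy case the micro average is higher because the well-segmented large lesion dominates the pooled voxel counts.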