Training adjustments #4
Conversation
These augmentations are key to the network's performance in my opinion.
I tried to reproduce nnUNet's data augmentations, which are mainly based on the batchgenerators package. I did not want to rely on either the nnunet or the batchgenerators package, so I implemented all the augmentations with MONAI.
Many of the augmentations were already available in MONAI; others I translated from batchgenerators to MONAI (see also dataloading.transforms.py).
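For readers unfamiliar with MONAI's dictionary transforms, here is a minimal sketch of how such a pipeline can be composed. This is not the actual pipeline from dataloading.transforms.py; the transform choices and parameters below are illustrative assumptions only.

```python
# Illustrative sketch of an nnUNet-style augmentation pipeline built from
# MONAI dictionary transforms (parameters are assumptions, not the real values).
from monai.transforms import (
    Compose,
    RandAdjustContrastd,
    RandFlipd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandRotate90d,
    RandZoomd,
)

train_transforms = Compose(
    [
        # Spatial augmentations applied to image and label together
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=(1, 2)),
        RandZoomd(
            keys=["image", "label"], prob=0.3, min_zoom=0.7, max_zoom=1.4,
            mode=("trilinear", "nearest"),
        ),
        # Intensity augmentations applied to the image only
        RandGaussianNoised(keys=["image"], prob=0.15, std=0.1),
        RandGaussianSmoothd(keys=["image"], prob=0.1),
        RandAdjustContrastd(keys=["image"], prob=0.15, gamma=(0.7, 1.5)),
    ]
)
```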
Just a classical PyTorch dataset with a __getitem__ function that applies the data augmentation transforms to each image separately. I also tried doing it batch-wise in the collate_fn, but I don't think that really makes sense: it makes things more complicated and does not give a large speed improvement.
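For illustration, a minimal sketch of that per-image pattern; the class and attribute names are placeholders, not the actual implementation:

```python
# Sketch: each sample is augmented individually in __getitem__; batching is
# left entirely to the DataLoader.
from torch.utils.data import Dataset


class PatchDataset(Dataset):
    def __init__(self, images, labels, transforms=None):
        self.images = images          # e.g. list of (C, D, H, W) arrays
        self.labels = labels
        self.transforms = transforms  # e.g. a MONAI Compose of dict transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = {"image": self.images[idx], "label": self.labels[idx]}
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample
```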
PyTorch Lightning data module that initializes the datasets and data loaders.
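A rough sketch of this data-module pattern; the class name, the `build_dataset` helper, and all defaults are assumptions rather than the real API:

```python
# Sketch of a LightningDataModule wiring datasets to data loaders.
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset


def build_dataset(data_dir: str, split: str) -> Dataset:
    """Placeholder: the real code builds datasets from imagesTr/labelsTr etc."""
    raise NotImplementedError


class SegmentationDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, batch_size: int = 2, num_workers: int = 4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        self.train_dataset = build_dataset(self.data_dir, split="train")
        self.val_dataset = build_dataset(self.data_dir, split="val")

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          shuffle=False, num_workers=self.num_workers)
```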
good idea, thanks!
src/membrain_seg/segment.py
Outdated
That's the (very preliminary) prediction script.
Most importantly, the parser takes as input the path to the tomogram to be segmented and the path to the model checkpoint that should be used.
Then the script performs the segmentation, also using 8-fold test-time augmentation (flipping along different axes). In my experience this makes the segmentations smoother and more stable. I think it's worth doing, even though it increases prediction time by a lot.
(It was around 10-15 min / tomo in the end using an A100 GPU.)
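For reference, a minimal sketch of 8-fold flip-based test-time augmentation; the actual prediction script may differ in details (e.g. sliding-window inference, thresholding):

```python
# Sketch: average predictions over all 2^3 = 8 flip combinations of the
# three spatial axes, undoing each flip before averaging.
import itertools
import torch


def predict_with_tta(model, volume):
    """volume: tensor of shape (1, C, D, H, W); returns the averaged prediction."""
    model.eval()
    preds = []
    with torch.no_grad():
        for flip_axes in itertools.chain.from_iterable(
            itertools.combinations((2, 3, 4), r) for r in range(4)
        ):
            flipped = torch.flip(volume, dims=flip_axes) if flip_axes else volume
            pred = model(flipped)
            # Undo the flip so all predictions are aligned before averaging.
            pred = torch.flip(pred, dims=flip_axes) if flip_axes else pred
            preds.append(pred)
    return torch.stack(preds).mean(dim=0)
```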
src/membrain_seg/train.py
Outdated
Training script: it initializes model checkpointing as well as the loggers.
Note that this still uses the WandB logger. For development, I'd like to keep it for now, but in the end we should switch to TensorBoard logging (or even just CSV logging to avoid installing tensorflow?). We'll only need to adjust one line for this.
Also, some other stuff, like the learning rate tracking, is probably redundant here.
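A short sketch of that setup, mainly to show that swapping WandB for CSV/TensorBoard logging really is a one-line change of the `logger` argument; the monitored metric name, project name, and other values here are assumptions:

```python
# Sketch of logger + checkpoint + LR-monitor wiring in PyTorch Lightning.
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, WandbLogger

logger = WandbLogger(project="membrain-seg")  # development setup
# logger = CSVLogger("logs")                  # lightweight swap-in later on

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss", save_top_k=3, mode="min"  # metric name is an assumption
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")

trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, lr_monitor],
    max_epochs=1000,
)
```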
- Adjusted the DynUNet architecture to output feature maps from the lower levels directly for deep supervision. Thus, we can compare the low-resolution outputs to downsampled versions of the GT map (instead of upsampling them first and then comparing to the full-resolution GT). That's at least how nnUNet does it -- not sure what's better here, though.
- Defined a deep supervision loss that computes a masked Dice & cross-entropy loss at each downsampling level.
- Defined a combination of Dice and CE loss that accepts ignore labels and does not evaluate the loss there (a rough sketch of this setup follows below).
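A rough sketch of this loss setup, assuming a binary segmentation target with labels {0, 1} plus an ignore value of 2 (an assumption); class names, level weights, and masking details may differ from the actual optim_utils.py:

```python
# Sketch: masked Dice + BCE per level, summed with decreasing weights over
# the deep-supervision outputs, against downsampled ground truth.
import torch
import torch.nn.functional as F
from monai.losses import MaskedDiceLoss


class IgnoreLabelDiceCELoss(torch.nn.Module):
    """Dice + BCE that skips voxels carrying the ignore label (names assumed)."""

    def __init__(self, ignore_label: int = 2):
        super().__init__()
        self.ignore_label = ignore_label
        self.dice = MaskedDiceLoss(sigmoid=True)

    def forward(self, pred, target):
        mask = (target != self.ignore_label).float()
        target_bin = (target == 1).float()
        bce = F.binary_cross_entropy_with_logits(pred, target_bin, weight=mask)
        dice = self.dice(pred, target_bin, mask=mask)
        return bce + dice


class DeepSupervisionLoss(torch.nn.Module):
    """Applies the masked loss at every decoder resolution level."""

    def __init__(self, loss_fn, weights=(1.0, 0.5, 0.25, 0.125, 0.0625)):
        super().__init__()
        self.loss_fn = loss_fn
        self.weights = weights

    def forward(self, preds, targets):
        # preds / targets: lists ordered from full resolution to the coarsest
        # level; the GT is downsampled rather than the outputs upsampled.
        return sum(w * self.loss_fn(p, t)
                   for w, p, t in zip(self.weights, preds, targets))
```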
Super nice!
src/membrain_seg/training/unet.py
Outdated
Using the adjusted DynUNet (see optim_utils.py) with a lot of downsampling (5 times!) and the same channel sizes as nnUNet (a configuration sketch follows below).
Logging the masked Dice score + accuracy (the latter only as a sanity check -- it's not very informative) and the loss for both validation & training (epoch-wise).
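For orientation, a sketch of a stock MONAI DynUNet configured in this spirit (5 downsampling steps, nnUNet-like channel sizes); the PR actually uses an adjusted DynUNet variant, and the exact kernel sizes/strides here are assumptions:

```python
# Sketch of a DynUNet configuration with 5 downsampling steps and
# nnUNet-style channel progression (stock MONAI class, not the adjusted one).
from monai.networks.nets import DynUNet

model = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    kernel_size=[3, 3, 3, 3, 3, 3],
    strides=[1, 2, 2, 2, 2, 2],             # 5 downsampling steps
    upsample_kernel_size=[2, 2, 2, 2, 2],
    filters=[32, 64, 128, 256, 320, 320],   # nnUNet-like channel sizes
    res_block=False,
    deep_supervision=True,
    deep_supr_num=2,
)
```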
Hey, @LorenzLamm! This is great. Thanks for this. I have made some minor comments below. I think we should merge soon so that we can keep moving.
Things I think we should change before merging:
- where possible, I think we should remove the `__main__` code from the various modules. I can see these are useful for testing, but I think we will want users to use our CLI (as we develop it) rather than calling the module files directly.
- Can we keep the `SemanticSegmentationUnet` in the `networks` module? Users/devs may want to import it for inference as well as training, so I think it would be more intuitive there.
Additionally, I think there are a few things we can do in follow-up PRs:
- convert print statements to logging, so that users and developers can easily suppress them
- add docstrings with parameters and return values explained
```python
    os.path.basename(orig_data_path)[:-4] + "_" + ckpt_token + "_segmented.mrc",
)
store_tomogram(out_file_thres, predictions_np_thres)
print("MemBrain has finished segmenting your tomogram.")
```
I think it would be nice to wrap print statements like this under some sort of `verbose` flag. For example, make `verbose` an input parameter and then:
```python
if verbose:
    print(...)
```
```python
out_dir = "/scicore/home/engel0006/GROUP/pool-engel/Lorenz/MemBrain-seg/\
membrain-seg/sanity_imgs"
patch_path1 = "/scicore/home/engel0006/GROUP/pool-engel/Lorenz/nnUNet_training/\
training_dirs/nnUNet_raw_data_base/nnUNet_raw_data/Task527_ChlamyV1/\
imagesTr/tomo02_patch000_0000.nii.gz"
label_path1 = "/scicore/home/engel0006/GROUP/pool-engel/Lorenz/nnUNet_training/\
training_dirs/nnUNet_raw_data_base/nnUNet_raw_data/Task527_ChlamyV1/\
labelsTr/tomo02_patch000.nii.gz"
img = np.expand_dims(read_nifti(patch_path1), 0)
label = np.expand_dims(read_nifti(label_path1), 0)

store_test_images(
    out_dir,
    img,
    label,
    get_training_transforms(prob_to_one=True, return_as_list=True),
)
get_augmentation_timing(
    imgs=[img, img],
    labels=[label, label],
    aug_sequence=get_training_transforms(prob_to_one=False, return_as_list=True),
)
```
Are users expected to run this file directly, or was this just for testing? If it's just for testing, I think we should remove it. If not, I think we should replace the path variables with some sort of input parser, as those paths won't work on the user's system.
Ah yes. Probably makes sense to remove this.
I used this to see the effects of the data augmentations, but we should not require users to do it as well. So I'll just remove it.
```python
self.train_img_dir = os.path.join(self.data_dir, "imagesTr")
self.train_lab_dir = os.path.join(self.data_dir, "labelsTr")
self.val_img_dir = os.path.join(self.data_dir, "imagesVal")
self.val_lab_dir = os.path.join(self.data_dir, "labelsVal")
```
If we are assuming a structure for the data directory, I think it would be nice if we added a note to the docstring explaining the expected data directory structure.
Good point. This is the data structure that I normally use for training, but maybe we should discuss at some point what data structure makes sense.
For now, I'll just add a docstring explaining the assumed data structure as it is right now (see the sketch below).
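For reference, the currently assumed layout would look roughly like this (directory names taken from the snippet above; the file names are only examples):

```
data_dir/
├── imagesTr/    training images,   e.g. tomo02_patch000_0000.nii.gz
├── labelsTr/    training labels,   e.g. tomo02_patch000.nii.gz
├── imagesVal/   validation images
└── labelsVal/   validation labels
```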
```
The image is downsampled using nearest neighbor interpolation using a
random scale factor. Afterwards, it is upsampled again with trilinear
interpolation, imitating the upscaling of low-resolution images.
```
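A minimal sketch of this idea using plain torch.nn.functional.interpolate; the actual MONAI-based transform in transforms.py may be implemented differently, and the zoom range here is an assumption:

```python
# Sketch: simulate low resolution by nearest-neighbor downsampling with a
# random scale factor, then trilinear upsampling back to the original shape.
import torch
import torch.nn.functional as F


def simulate_low_resolution(image, zoom_range=(0.5, 1.0)):
    """image: tensor of shape (C, D, H, W); returns a tensor of the same shape."""
    scale = float(torch.empty(1).uniform_(*zoom_range))
    orig_shape = image.shape[1:]
    small_shape = [max(1, int(round(s * scale))) for s in orig_shape]
    x = image.unsqueeze(0)                                  # (1, C, D, H, W)
    x = F.interpolate(x, size=small_shape, mode="nearest")  # downsample (NN)
    x = F.interpolate(x, size=list(orig_shape), mode="trilinear",
                      align_corners=False)                  # upsample (trilinear)
    return x.squeeze(0)
```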
neat idea!
good idea, thanks!
```python
def test_loss_fn_correctness():
    from membrain_seg.training.optim_utils import test_loss_fn

    test_loss_fn()
```
Can you please move the test function here? While this works, I think it's better to separate the test code from the user-facing code. Also, it is helpful to test importing the functions here as they would be imported when the library is used.
Yep, makes sense. Will do! :)
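For instance, the relocated test could look roughly like this; the imported class name and its call signature are assumptions, not the real optim_utils API:

```python
# tests/test_optim_utils.py -- sketch of a test living in the tests folder
# and exercising the loss through the public import path.
import torch


def test_ignore_label_loss_runs():
    # Hypothetical class name; replace with the actual loss exported by the module.
    from membrain_seg.training.optim_utils import IgnoreLabelDiceCELoss

    loss_fn = IgnoreLabelDiceCELoss()
    pred = torch.randn(2, 1, 8, 8, 8)
    target = torch.randint(0, 3, (2, 1, 8, 8, 8)).float()  # 2 = assumed ignore label
    loss = loss_fn(pred, target)
    assert torch.isfinite(loss)
```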
Corrected an error here: the script was saving score maps instead of segmentation masks if the --store_probabilities flag was set to True.
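In other words, the intended behaviour is roughly the following; store_tomogram and the *_thres names come from the snippet above, everything else is a placeholder:

```python
# Hypothetical sketch of the corrected output logic (names are placeholders).
def save_outputs(predictions_np, threshold, out_file_probs, out_file_thres,
                 store_probabilities=False):
    if store_probabilities:
        # Optionally keep the raw score map as an extra file...
        store_tomogram(out_file_probs, predictions_np)
    # ...but always write the thresholded segmentation mask.
    predictions_np_thres = (predictions_np > threshold).astype("uint8")
    store_tomogram(out_file_thres, predictions_np_thres)
```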
I removed the testing functions for checking the effects of the data augmentations.
Removed the collate function (which was only used in a previous version with batch-wise augmentations).
Also removed the test functions for the dataset -- these stored some test images to check whether data loading and augmentations worked well.
src/membrain_seg/networks/unet.py
Outdated
Removed some redundant lines of code and renamed to /networks/unet.py again.
Moved the test functions to the tests folder instead of having them in the script itself.
Thanks a lot for your feedback, Kevin! I tried to work in your suggested changes and removed most of the main() functions + changed the unet folder back to "networks". As you suggest, it's probably best to replace the print statements and add more information to the docstrings in later PRs. I'll need to invest some time in this, I think :-D
Looks good to me! Thank you for the quick turnaround, @LorenzLamm . I will merge this and make a couple of issues to track some of the things we are punting on.
Great work!
I tried to reproduce the performance of our best model, which is based on nnUNet.
To this end, I added a lot of data augmentations, adjusted the model architecture, including deep supervision, and refined the training procedure (e.g. PolyLR scheduler, slight weight decay).
Now I'm happy with the performance. Please have a look and see if everything is fine :)
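For context, a minimal sketch of an nnUNet-style polynomial LR decay combined with slight weight decay, as mentioned above; the exact hyperparameters here are assumptions:

```python
# Sketch: PolyLR schedule lr(t) = lr0 * (1 - t / max_epochs) ** 0.9
# on top of SGD with a small weight decay.
import torch

model = torch.nn.Conv3d(1, 1, 3)  # stand-in for the actual network
max_epochs = 1000

optimizer = torch.optim.SGD(
    model.parameters(), lr=1e-2, momentum=0.99, nesterov=True, weight_decay=3e-5
)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs) ** 0.9
)
```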