
Training adjustments #4

Merged
merged 14 commits into from
Apr 20, 2023

Conversation

LorenzLamm
Collaborator

I tried to reproduce the performance of our best model, which is based on nnUNet.
To this end, I added a lot of data augmentations, adjusted the model architecture, including deep supervision, and refined the training procedure (e.g. PolyLR scheduler, slight weight decay).
Now I'm happy with the performance. Please have a look and see if everything is fine :)

Collaborator Author

These augmentations are key to the network's performance, in my opinion.
I tried to reproduce nnUNet's data augmentations, which are mainly based on the batchgenerators package. Since I didn't want to depend on either the nnunet or batchgenerators package, I implemented all the augmentations with MONAI.
Many of the augmentations were already available in MONAI; the others I translated from batchgenerators (see also dataloading.transforms.py).

Collaborator Author

Just a classical PyTorch dataset with a `__getitem__` function that applies the data augmentation transforms to each image separately. I also tried doing it batch-wise in the collate_fn, but I don't think that really makes sense: it makes things more complicated and does not give a large time improvement.
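As a minimal sketch of that per-item pattern (plain Python stand-ins, not the actual MemBrain-seg classes or MONAI transforms; all names here are illustrative):

```python
class PatchDataset:
    """Minimal sketch of the pattern described above: __getitem__ loads one
    image/label pair and applies the augmentation transform to that sample
    alone, rather than batch-wise in a collate_fn."""

    def __init__(self, items, transform=None):
        self.items = items          # list of (image, label) pairs
        self.transform = transform  # callable applied per sample

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        image, label = self.items[idx]
        sample = {"image": image, "label": label}
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


# Toy "augmentation" that doubles intensities:
ds = PatchDataset(
    [([1, 2], [0, 1])],
    transform=lambda s: {**s, "image": [2 * v for v in s["image"]]},
)
print(ds[0]["image"])  # [2, 4]
```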

Collaborator Author

PyTorch Lightning data module that initializes the datasets and data loaders.

Collaborator

good idea, thanks!

Collaborator Author

That's the (very preliminary) prediction script.
Most importantly, the parser takes as input the path to the tomogram to be segmented and the path to the model checkpoint to use.
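The two required inputs could be wired up roughly like this (the flag names are hypothetical, not the actual script's interface):

```python
import argparse

# Hypothetical sketch of the parser described above; the real script's
# argument names may differ.
parser = argparse.ArgumentParser(description="Segment a tomogram with a trained model")
parser.add_argument("--tomogram_path", required=True, help="tomogram to segment")
parser.add_argument("--ckpt_path", required=True, help="model checkpoint to use")

args = parser.parse_args(["--tomogram_path", "tomo.mrc", "--ckpt_path", "best.ckpt"])
print(args.tomogram_path, args.ckpt_path)  # tomo.mrc best.ckpt
```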

Then the script performs the segmentation, also using 8-fold test-time augmentation (flipping along different axes). In my experience this makes the segmentations smoother and more stable. I think it's worth doing, even though it increases prediction time by a lot.
(It was around 10-15 min / tomo in the end on an A100 GPU.)
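The 8-fold count comes from flipping along every subset of the three spatial axes; a small sketch of the enumeration (not the actual implementation):

```python
import itertools

# Every subset of the three spatial axes gives one flipped variant:
# the empty subset is the identity, so there are 2**3 = 8 variants.
axes = (0, 1, 2)
flip_combos = [
    combo for r in range(len(axes) + 1) for combo in itertools.combinations(axes, r)
]
print(len(flip_combos))  # 8
# In test-time augmentation, each variant is predicted, flipped back
# along the same axes, and the 8 predictions are averaged.
```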

Collaborator Author

Training script: it initializes model checkpointing, as well as the loggers.

Note that this is still the WandB logger. For development, I'd like to keep this for now, but in the end we should switch to TensorBoard logging (or even just CSV logging, to avoid installing tensorflow?). We'll only need to adjust one line for this.

Also, some other stuff, like learning rate tracking, is probably redundant here.

Collaborator Author

  1. Adjusted the DynUNet architecture to output feature maps from lower levels directly for deep supervision. Thus, we can compare the low-resolution outputs to downsampled versions of the GT map (instead of upsampling them first and then comparing to the full-res GT). That's at least how nnUNet does it -- not sure what's better here, though.
  2. Define Deep Supervision loss that computes masked DICE & Cross entropy loss for each downsampling level.
  3. Define combination of Dice and CE loss that accepts ignore labels and does not evaluate the loss there.
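To illustrate points 2 and 3, here is a toy sketch with plain Python lists in place of tensors (function names and the exact weighting are assumptions, though the geometrically decaying level weights follow nnUNet's convention):

```python
def masked_dice_loss(pred, gt, ignore):
    # Voxels flagged in `ignore` contribute to neither numerator nor denominator.
    keep = [(p, g) for p, g, m in zip(pred, gt, ignore) if not m]
    inter = sum(p * g for p, g in keep)
    denom = sum(p + g for p, g in keep)
    return 1.0 - (2.0 * inter / denom if denom else 1.0)


def deep_supervision_loss(level_losses):
    # Weight the levels 1, 1/2, 1/4, ... (finest first), normalised to sum to 1.
    weights = [0.5 ** i for i in range(len(level_losses))]
    total = sum(weights)
    return sum(w / total * lvl for w, lvl in zip(weights, level_losses))


# Toy example: the last voxel carries an ignore label and is skipped.
loss = masked_dice_loss([1, 1, 0, 1], [1, 0, 0, 0], [0, 0, 0, 1])
print(round(loss, 4))  # 0.3333
```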

Collaborator

Super nice!

Collaborator Author

Using the adjusted DynUNet (see optim_utils.py) with a lot of downsampling (5 times!) and the same channel sizes as nnUNet.
Logging the masked DICE score + accuracy (the latter only as a sanity check -- it's not very informative) and the loss for both validation & training (epoch-wise).
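For reference, nnUNet's 3D convention doubles the channel count per level from a base of 32 and caps it at 320; with 5 downsamplings the encoder channels would look like this (a sketch of that convention, not the repo's actual config):

```python
# nnUNet 3D convention: start at 32 channels, double per downsampling,
# cap at 320. Five downsamplings give six resolution levels.
base, cap, num_levels = 32, 320, 6
channels = [min(base * 2 ** i, cap) for i in range(num_levels)]
print(channels)  # [32, 64, 128, 256, 320, 320]
```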

Collaborator

@kevinyamauchi left a comment

Hey, @LorenzLamm ! This is great. Thanks for this. I have made some minor comments below. I think we should merge soon so that we can keep moving.

Things I think we should change before merging:

  • where possible, I think we should remove the `__main__` code from the various modules. I can see these are useful for testing, but I think we will want users to use our CLI (as we develop it) rather than calling the module files directly.
  • Can we keep the SemanticSegmentationUnet in the networks module? Users/devs may want to import it for inference as well as training, so I think it would be more intuitive there.

Additionally, I think there are a few things we can do in follow-up PRs

  • convert print statements to logging so that users and developers can easily suppress them.
  • add docstrings with parameters and return values explained

os.path.basename(orig_data_path)[:-4] + "_" + ckpt_token + "_segmented.mrc",
)
store_tomogram(out_file_thres, predictions_np_thres)
print("MemBrain has finished segmenting your tomogram.")
Collaborator

I think it would be nice to wrap print statements like this under some sort of verbose flag. For example, make verbose an input parameter and then

if verbose:
    print(...)

Comment on lines 340 to 361
out_dir = "/scicore/home/engel0006/GROUP/pool-engel/Lorenz/MemBrain-seg/\
membrain-seg/sanity_imgs"
patch_path1 = "/scicore/home/engel0006/GROUP/pool-engel/Lorenz/nnUNet_training/\
training_dirs/nnUNet_raw_data_base/nnUNet_raw_data/Task527_ChlamyV1/\
imagesTr/tomo02_patch000_0000.nii.gz"
label_path1 = "/scicore/home/engel0006/GROUP/pool-engel/Lorenz/nnUNet_training/\
training_dirs/nnUNet_raw_data_base/nnUNet_raw_data/Task527_ChlamyV1/\
labelsTr/tomo02_patch000.nii.gz"
img = np.expand_dims(read_nifti(patch_path1), 0)
label = np.expand_dims(read_nifti(label_path1), 0)

store_test_images(
    out_dir,
    img,
    label,
    get_training_transforms(prob_to_one=True, return_as_list=True),
)
get_augmentation_timing(
    imgs=[img, img],
    labels=[label, label],
    aug_sequence=get_training_transforms(prob_to_one=False, return_as_list=True),
)
Collaborator

@kevinyamauchi Apr 20, 2023

Are users expected to run this file directly, or was this just from testing? If it's just for testing, I think we should remove it. If not, I think we should replace the path variables with some sort of input parser, as those paths won't work on the user's system.

Collaborator Author

Ah yes. Probably makes sense to remove this.
I used this to see the effects of the data augmentations, but we should not require users to do it as well. So I'll just remove it.

Comment on lines +15 to +18
self.train_img_dir = os.path.join(self.data_dir, "imagesTr")
self.train_lab_dir = os.path.join(self.data_dir, "labelsTr")
self.val_img_dir = os.path.join(self.data_dir, "imagesVal")
self.val_lab_dir = os.path.join(self.data_dir, "labelsVal")
Collaborator

If we are assuming a structure for the data directory, I think it would be nice if we added a note to the docstring explaining the expected data directory structure.

Collaborator Author

Good point. This is the data structure that I normally use for training, but maybe we should discuss at some point what data structure makes sense.
For now, I'll just add a docstring explaining the assumed data structure as it is right now.
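Such a docstring might look roughly like this (the class name is a placeholder for illustration; the directory names match the snippet above):

```python
class MemBrainDataModule:  # placeholder name for illustration
    """Data module that initializes the datasets and data loaders.

    Expected data directory structure::

        data_dir/
        ├── imagesTr/    training images
        ├── labelsTr/    training labels
        ├── imagesVal/   validation images
        └── labelsVal/   validation labels
    """


print("imagesTr" in MemBrainDataModule.__doc__)  # True
```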

Comment on lines +159 to +161
The image is downsampled using nearest neighbor interpolation using a
random scale factor. Afterwards, it is upsampled again with trilinear
interpolation, imitating the upscaling of low-resolution images.
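The idea can be sketched in 1D with plain Python (the real transform works on 3D image tensors with a random scale factor; this toy version uses a fixed integer factor):

```python
def simulate_low_res(signal, factor):
    """Nearest-neighbour downsample by `factor`, then linearly
    interpolate back to the original length (1D toy version)."""
    coarse = signal[::factor]  # nearest-neighbour downsample
    out = []
    for i in range(len(signal)):  # linear upsample back to full length
        pos = i / factor
        lo = min(int(pos), len(coarse) - 1)
        hi = min(lo + 1, len(coarse) - 1)
        frac = pos - lo
        out.append((1 - frac) * coarse[lo] + frac * coarse[hi])
    return out


print(simulate_low_res([0, 2, 4, 6], 2))  # [0.0, 2.0, 4.0, 4.0]
```

Note how the high-frequency detail (the odd samples) is lost and only smoothly reconstructed, which is exactly what the augmentation wants to imitate.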
Collaborator

neat idea!

Collaborator

good idea, thanks!

def test_loss_fn_correctness():
    from membrain_seg.training.optim_utils import test_loss_fn

    test_loss_fn()
Collaborator

Can you please move the test function here? While this works, I think it's better to separate the test code from the user-facing code. Also, it is helpful to test importing the functions here as they would be imported when the library is used.

Collaborator Author

Yep, makes sense. Will do! :)

Collaborator Author

Corrected an error here: the script was saving score maps instead of segmentation masks if the --store_probabilities flag was set to True.
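The distinction in question, sketched as a toy thresholding step (the function name and threshold are illustrative, not the script's actual code):

```python
def scores_to_mask(scores, threshold=0.0):
    # Binarise a score map into a segmentation mask. The fixed script
    # writes out the mask (or the raw scores, when probabilities are
    # requested) rather than mixing the two up.
    return [1 if s > threshold else 0 for s in scores]


print(scores_to_mask([-1.2, 0.3, 2.5]))  # [0, 1, 1]
```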

Collaborator Author

I removed the testing functions for checking the effects of the data augmentations.

Collaborator Author

@LorenzLamm Apr 20, 2023

Removed the collate function (which was only used in a previous version with batch-wise augmentations).
Also removed the test functions for the dataset -- these stored some test images to check whether data loading and augmentations worked well.

Collaborator Author

Removed some redundant lines of code and renamed to /networks/unet.py again.

Collaborator Author

Moved the test functions to the tests folder instead of having them in the script itself.

@LorenzLamm
Collaborator Author


Thanks a lot for your feedback, Kevin! I tried to work in your suggested changes and removed most of the main() functions + changed the unet folder back to "networks".
Also cleaned up some lines of code & removed some redundant functions.

As you suggest, it's probably best to replace the print statements and add more information to the docstrings in later PRs. I'll need to invest some time in this, I think :-D

Collaborator

@kevinyamauchi left a comment

Looks good to me! Thank you for the quick turnaround, @LorenzLamm . I will merge this and make a couple of issues to track some of the things we are punting on.

Great work!

@kevinyamauchi kevinyamauchi merged commit 3e07955 into main Apr 20, 2023