PyTorch vs. TensorFlow model #22

Open
mdeleeuw1 opened this issue Mar 1, 2024 · 2 comments

@mdeleeuw1
Contributor

Dear Mart,

I hope all is well.

I have been using the PyTorch model of HookNet in a similar fashion to what HookNetPracticalGuideTorch.ipynb prescribes. However, when comparing performance with a 'normal' U-Net, I am finding that the PyTorch version is outperformed by the normal U-Net. To best compare the U-Net to the configuration of HookNet in your paper, i.e. the TensorFlow model, I was wondering whether there are any major differences between the two versions. The various operations used in the encoders and decoders of both models look very similar. If there are any differences, why did you choose to make them?

Also, I would like to implement the multi-loss approach as performed in your paper, instead of the single loss used in the PyTorch practical guide. Is this as simple as defining a separate loss in the training loop for each of the outputs of the model's three branches and calling loss.backward() on the combined loss? I have also made a few minor changes to the HookNet class to make this possible:

        # a separate 1x1 output convolution per branch, so every branch produces class logits
        self.high_last_conv = nn.Conv2d(self.high_mag_branch.decoder._out_channels[0], n_classes, 1)
        self.mid_last_conv = nn.Conv2d(self.mid_mag_branch.decoder._out_channels[0], n_classes, 1)
        self.low_last_conv = nn.Conv2d(self.low_mag_branch.decoder._out_channels[0], n_classes, 1)

    def forward(self, high_input, mid_input, low_input):
        # run the branches from low to high resolution, passing the hooked feature maps upward
        low_out, low_hook_out = self.low_mag_branch(low_input)
        mid_out, mid_hook_out = self.mid_mag_branch(mid_input, low_hook_out)
        high_out, high_hook_out = self.high_mag_branch(high_input, mid_hook_out)
        # return the logits of all three branches so each can receive its own loss
        return {'high_out': self.high_last_conv(high_out),
                'mid_out': self.mid_last_conv(mid_out),
                'low_out': self.low_last_conv(low_out)}
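For context, the training step I have in mind looks roughly like the sketch below; the criterion, loss weights, optimizer, dataloader, and batch keys are just placeholders of mine, not something from the repository:

    from torch import nn

    criterion = nn.CrossEntropyLoss()
    loss_weights = {'high_out': 1.0, 'mid_out': 0.5, 'low_out': 0.5}  # placeholder weights

    for batch in dataloader:
        optimizer.zero_grad()
        outputs = model(batch['high'], batch['mid'], batch['low'])
        # one loss per branch; the targets must match the spatial size of each branch's output
        loss = (loss_weights['high_out'] * criterion(outputs['high_out'], batch['high_mask'])
                + loss_weights['mid_out'] * criterion(outputs['mid_out'], batch['mid_mask'])
                + loss_weights['low_out'] * criterion(outputs['low_out'], batch['low_mask']))
        loss.backward()
        optimizer.step()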

I have also changed the forward function of the Decoder class, such that it outputs both the feature maps relevant for the hooking mechanism ('hook_out') and the feature maps at the end of the decoder ('out'); previously, the final output of the two lower-resolution branches was cut short because of the second if-statement:

    def forward(self, x, residuals, hook_in=None):
        out = x
        hook_true = False
        for d in reversed(range(self._depth)):
            # inject the hooked feature maps coming from the lower-resolution branch
            if hook_in is not None and d == self._hook_to_index:
                out = concatenator(out, hook_in)

            out = self._decode_path[f"upsample{d}"](out)
            out = concatenator(out, residuals[d])
            out = self._decode_path[f"convblock{d}"](out)

            # keep the feature maps that the next branch will hook into
            if self._hook_from_index is not None and d == self._hook_from_index:
                hook_out = out
                hook_true = True

        # branches that do not feed a hook simply return the final feature maps twice
        if not hook_true:
            hook_out = out
        return out, hook_out

By doing so, the model now has an output for all three branches. I hope my approach makes sense (and is correct), do you think these changes suffice to realize a multi-loss approach for the PyTorch model? Would love to hear your insights!

Thanks in advance for your help!

Cheers,

Mike

@martvanrijthoven
Collaborator

Dear Mike,

Sorry for my late reply.

HookNet (both the TensorFlow and PyTorch versions) consists of U-Net models with very basic encoders and decoders. It might be that, for specific tasks, a U-Net with a more advanced backbone, e.g. ResNet-50, performs better, especially on tasks where multiresolution is less important.

The only differences between the TensorFlow and PyTorch versions are:

  • TensorFlow: 2 branches, PyTorch: 3 branches
  • TensorFlow: multi-loss, PyTorch: no multi-loss

The differences exist because I thought 3 branches would be interesting to implement and I did not have time to implement the multi-loss. I think a multi-loss for the PyTorch version makes a lot of sense. Your modifications look good to me, but it is good to double-check all the output shapes. At first sight I don't see any issues.
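For example, you could push dummy tensors through the modified model and print the three output shapes; the constructor arguments and patch size below are just an example, adjust them to your configuration:

    import torch

    # example only: adjust constructor arguments and patch size to your setup
    model = HookNet(n_classes=2)
    high = torch.randn(1, 3, 284, 284)
    mid = torch.randn(1, 3, 284, 284)
    low = torch.randn(1, 3, 284, 284)
    outputs = model(high, mid, low)
    for name, out in outputs.items():
        print(name, out.shape)  # each branch should produce (1, n_classes, H, W)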

It is really great to hear that you are working on improving the PyTorch version. Please feel free to open PRs if you would like your code to be merged :D

Best wishes,
Mart

@mdeleeuw1
Contributor Author

Dear Mart,

Thank you for elaborating on my questions. The differences between the two versions are indeed what I expected :).

I will soon start trying to run the model on an entire WSI and evaluating its performance. If you have any tips on how to best configure the patch iterator etc. to do inference on entire WSIs and stitch the masks, they would be more than welcome! Thanks.

I will open a PR for the changes to the PyTorch version.

Kind regards,

Mike
