Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add focal loss #1024

Merged
merged 20 commits into from
Dec 3, 2021
Merged

[Feature] Add focal loss #1024

merged 20 commits into from
Dec 3, 2021

Conversation

RockeyCoss
Copy link
Contributor

@RockeyCoss RockeyCoss commented Nov 8, 2021

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

Add focal loss

Modification

Add focal loss

Experiment Results

We can see through the experiments that:

  1. When applying UNet on the medical dataset(including chase, stare, hrf and drive), using focal loss with gamma between 0 to 1 yields relatively good results.
  2. Even with the best choice of gamma, focal loss still can't beat cross entropy loss when comparing their best dice scores during training. When comparing their last dice scores, focal loss outperforms cross entropy loss on stare dataset.
  3. Dice loss performs quite good that it outperforms cross entropy loss and focal loss on three datasets.
  4. With proper weight, training with multiple losses can yield clear and consistent improvement. For example, focal loss + cross entropy with weight 1:3 or 3:1 and focal loss + dice loss with weight 1:4 yield good results on four datasets.
exp_num method vessel dice best best index vessel dice last last index
1 pspnet_unet_s5-d16_128x128_40k_chase_db1 0.804 10 0.804 10
2 pspnet_unet_s5-d16_128x128_40k_chase_db1_dice 0.797 10 0.797 10
3 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0 0.8009 10 0.8009 10
4 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5 0.8001 10 0.8001 10
5 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_ce1to1 0.8039 10 0.8039 10
6 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_ce1to2 0.8054 10 0.8054 10
7 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_ce1to3 0.8061 10 0.8061 10
8 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_ce1to4 0.8037 11 0.8037 11
9 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_ce2to1 0.8048 10 0.8048 10
10 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_ce3to1 0.8052 11 0.8052 11
11 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_ce4to1 0.8048 9 0.8037 10
12 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_dice1to1 0.8026 10 0.8026 10
13 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_dice1to2 0.8005 10 0.8005 10
14 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_dice1to3 0.7979 11 0.7979 11
15 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_dice1to4 0.7986 12 0.7986 12
16 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_dice2to1 0.8012 10 0.801 11
17 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_dice3to1 0.8052 10 0.8031 11
18 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma0.5_dice4to1 0.8076 10 0.8076 10
19 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma1 0.7976 10 0.7976 10
20 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma1.5 0.793 10 0.793 10
21 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma2 0.7877 10 0.7877 10
22 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma4 0.7604 10 0.7604 10
23 pspnet_unet_s5-d16_128x128_40k_chase_db1_gamma6 0.6452 10 0.6452 10
24 pspnet_unet_s5-d16_128x128_40k_stare 0.8186 7 0.8097 10
25 pspnet_unet_s5-d16_128x128_40k_stare_dice 0.8258 10 0.8258 10
26 pspnet_unet_s5-d16_128x128_40k_stare_gamma0 0.8149 7 0.8123 10
27 pspnet_unet_s5-d16_128x128_40k_stare_gamma0.5 0.8138 7 0.8129 10
28 pspnet_unet_s5-d16_128x128_40k_stare_gamma1 0.8151 10 0.8151 10
29 pspnet_unet_s5-d16_128x128_40k_stare_gamma1.5 0.8129 9 0.8095 10
30 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_ce1to1 0.8199 4 0.8146 10
31 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_ce1to2 0.8195 4 0.8149 10
32 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_ce1to3 0.8233 4 0.816 10
33 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_ce1to4 0.8205 7 0.8149 10
34 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_ce2to1 0.8159 8 0.8115 10
35 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_ce3to1 0.8218 7 0.8134 11
36 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_ce4to1 0.8221 5 0.8141 10
37 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_dice1to1 0.8263 10 0.8263 10
38 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_dice1to2 0.8274 10 0.8274 10
39 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_dice1to3 0.8298 10 0.8239 11
40 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_dice1to4 0.8262 11 0.8185 12
41 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_dice2to1 0.825 10 0.825 10
42 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_dice3to1 0.8212 10 0.8204 11
43 pspnet_unet_s5-d16_128x128_40k_stare_gamma1_dice4to1 0.8248 8 0.8221 11
44 pspnet_unet_s5-d16_128x128_40k_stare_gamma2 0.8123 9 0.8092 10
45 pspnet_unet_s5-d16_128x128_40k_stare_gamma4 0.7885 10 0.7885 10
46 pspnet_unet_s5-d16_128x128_40k_stare_gamma6 0.6574 10 0.6574 10
47 pspnet_unet_s5-d16_256x256_40k_hrf 0.7992 8 0.7946 12
48 pspnet_unet_s5-d16_256x256_40k_hrf_dice 0.8082 10 0.8082 10
49 pspnet_unet_s5-d16_256x256_40k_hrf_gamma0 0.7947 8 0.7875 10
50 pspnet_unet_s5-d16_256x256_40k_hrf_gamma0.5 0.7959 8 0.79 10
51 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1 0.7962 8 0.7889 10
52 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1.5 0.7942 8 0.789 10
53 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_ce1to1 0.8059 8 0.7964 10
54 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_ce1to2 0.804 8 0.7938 10
55 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_ce1to3 0.8078 8 0.7962 10
56 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_ce1to4 0.8023 10 0.7936 11
57 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_ce2to1 0.8022 8 0.7897 10
58 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_ce3to1 0.8061 8 0.7917 10
59 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_ce4to1 0.8006 8 0.7937 10
60 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_dice1to1 0.8083 8 0.8076 10
61 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_dice1to2 0.8091 10 0.8091 10
62 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_dice1to3 0.8063 8 0.8061 11
63 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_dice1to4 0.8055 12 0.8055 12
64 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_dice2to1 0.8082 11 0.8063 12
65 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_dice3to1 0.8105 8 0.8103 10
66 pspnet_unet_s5-d16_256x256_40k_hrf_gamma1_dice4to1 0.8097 8 0.8082 10
67 pspnet_unet_s5-d16_256x256_40k_hrf_gamma2 0.7881 9 0.7837 10
68 pspnet_unet_s5-d16_256x256_40k_hrf_gamma4 0.7745 9 0.7628 10
69 pspnet_unet_s5-d16_256x256_40k_hrf_gamma6 0.6749 9 0.6145 10
70 pspnet_unet_s5-d16_64x64_40k_drive 0.7926 4 0.7829 10
71 pspnet_unet_s5-d16_64x64_40k_drive_dice 0.7951 9 0.795 10
72 pspnet_unet_s5-d16_64x64_40k_drive_gamma0 0.7912 4 0.7804 10
73 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5 0.7844 4 0.7751 10
74 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_cd1to1 0.7865 4 0.7833 10
75 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_cd1to2 0.7906 4 0.7807 10
76 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_ce1to3 0.792 4 0.7823 10
77 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_ce1to4 0.7926 4 0.7839 10
78 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_ce2to1 0.796 7 0.785 10
79 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_ce3to1 0.7891 6 0.7834 10
80 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_ce4to1 0.7942 4 0.7845 10
81 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_dice1to1 0.7935 7 0.7924 10
82 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_dice1to2 0.7968 10 0.7968 10
83 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_dice1to3 0.796 11 0.796 11
84 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_dice1to4 0.7964 10 0.7964 10
85 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_dice2to1 0.7946 4 0.794 10
86 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_dice3to1 0.7968 4 0.7899 12
87 pspnet_unet_s5-d16_64x64_40k_drive_gamma0.5_dice4to1 0.8014 9 0.7894 11
88 pspnet_unet_s5-d16_64x64_40k_drive_gamma1 0.7811 5 0.7799 10
89 pspnet_unet_s5-d16_64x64_40k_drive_gamma1.5 0.7794 5 0.7747 10
90 pspnet_unet_s5-d16_64x64_40k_drive_gamma2 0.775 10 0.775 10
91 pspnet_unet_s5-d16_64x64_40k_drive_gamma4 0.7579 8 0.7555 10
92 pspnet_unet_s5-d16_64x64_40k_drive_gamma6 0.7156 8 0.6892 10

@codecov
Copy link

codecov bot commented Nov 8, 2021

Codecov Report

Attention: Patch coverage is 75.72816% with 25 lines in your changes missing coverage. Please review.

Project coverage is 89.53%. Comparing base (f0e6201) to head (5e91872).
Report is 282 commits behind head on master.

Files Patch % Lines
mmseg/models/losses/focal_loss.py 75.49% 23 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1024      +/-   ##
==========================================
- Coverage   89.74%   89.53%   -0.21%     
==========================================
  Files         121      122       +1     
  Lines        6828     6931     +103     
  Branches     1139     1156      +17     
==========================================
+ Hits         6128     6206      +78     
- Misses        496      519      +23     
- Partials      204      206       +2     
Flag Coverage Δ
unittests 89.53% <75.72%> (-0.21%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@Junjun2016
Copy link
Collaborator

Hi @RockeyCoss
Please fix the lint error.

@Junjun2016 Junjun2016 changed the title [Feature] add focal loss [Feature] Add focal loss Nov 8, 2021
@Junjun2016
Copy link
Collaborator

Besides, some lines are not covered by unittests.

@RockeyCoss
Copy link
Contributor Author

Besides, some lines are not covered by unittests.
Can't fix that because the uncovered lines can only be executed when cuda is available.

mmseg/models/losses/focal_loss.py Outdated Show resolved Hide resolved
mmseg/models/losses/focal_loss.py Outdated Show resolved Hide resolved
mmseg/models/losses/focal_loss.py Outdated Show resolved Hide resolved
mmseg/models/losses/focal_loss.py Outdated Show resolved Hide resolved
@RockeyCoss
Copy link
Contributor Author

RockeyCoss commented Nov 15, 2021

speed comparision
图片
图片

mmseg/models/losses/focal_loss.py Outdated Show resolved Hide resolved
mmseg/models/losses/focal_loss.py Outdated Show resolved Hide resolved
mmseg/models/losses/focal_loss.py Outdated Show resolved Hide resolved
@MengzhangLI
Copy link
Contributor

@RockeyCoss It seems like some lines are not covered by unit test. Do we need to add unit test?

@MengzhangLI MengzhangLI mentioned this pull request Nov 24, 2021
@RockeyCoss
Copy link
Contributor Author

@RockeyCoss It seems like some lines are not covered by unit test. Do we need to add unit test?

Those codes need cuda to run, so it's not possible to test them. However, I have tested them with cuda available enviroment and the test_same() in unitests tests if the code will produce the same results when running in cuda available enviroment and cuda unavailable environment.

@Junjun2016 Junjun2016 requested a review from xvjiarui November 30, 2021 15:10
Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
@xvjiarui xvjiarui merged commit 1b41989 into open-mmlab:master Dec 3, 2021
bowenroom pushed a commit to bowenroom/mmsegmentation that referenced this pull request Feb 25, 2022
* [Feature] add focal loss

* fix the bug of 'non' reduction type

* refine the implementation

* add class_weight and ignore_index; support different alpha values for different classes

* fixed some bugs

* fix bugs

* add comments

* modify test

* Update mmseg/models/losses/focal_loss.py

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>

* update test_focal_loss.py

* modified the implementation

* Update mmseg/models/losses/focal_loss.py

Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>

* update focal_loss.py

Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
Co-authored-by: Jerry Jiarui XU <xvjiarui0826@gmail.com>
@Zihang-Wei
Copy link

Hello, how should I use the focal loss inside the config file? It pops up this strange message for me...
image
In the config file:

model = dict(
    test_cfg=dict(crop_size=(128, 128), stride=(85, 85)), 
    auxiliary_head=dict(
        num_classes=3,
        loss_decode=[
            # dict(type='CrossEntropyLoss', loss_name='loss_ce'),
            # dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0),
            dict(type='FocalLoss', loss_name='loss_focal', loss_weight=3.0, gamma=2.0, alpha=[0.2, 0.1, 0.7])
            ]),
    decode_head=dict(
        num_classes=3, 
        loss_decode=[
            # dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
            # dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0),
            dict(type='FocalLoss', loss_name='loss_focal', loss_weight=3.0, gamma=2.0, alpha=[0.2, 0.1, 0.7]),
            # dict(type='LovaszLoss', loss_name='loss_lovasz', loss_weight=3.0)
            ]))

@RockeyCoss @Junjun2016 @MengzhangLI @xvjiarui @kahkeng

@RockeyCoss
Copy link
Contributor Author

Hello, how should I use the focal loss inside the config file? It pops up this strange message for me... image In the config file:

model = dict(
    test_cfg=dict(crop_size=(128, 128), stride=(85, 85)), 
    auxiliary_head=dict(
        num_classes=3,
        loss_decode=[
            # dict(type='CrossEntropyLoss', loss_name='loss_ce'),
            # dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0),
            dict(type='FocalLoss', loss_name='loss_focal', loss_weight=3.0, gamma=2.0, alpha=[0.2, 0.1, 0.7])
            ]),
    decode_head=dict(
        num_classes=3, 
        loss_decode=[
            # dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
            # dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0),
            dict(type='FocalLoss', loss_name='loss_focal', loss_weight=3.0, gamma=2.0, alpha=[0.2, 0.1, 0.7]),
            # dict(type='LovaszLoss', loss_name='loss_lovasz', loss_weight=3.0)
            ]))

@RockeyCoss @Junjun2016 @MengzhangLI @xvjiarui @kahkeng

Hello @Zihang-Wei ! Did you install mmcv-full? By the way, we would appreciate it if you could report this problem in the issue. Because by doing this, the problem can be viewed by more people.

@Zihang-Wei
Copy link

Yes, inside my dockerfile, there is:
image

@Zihang-Wei
Copy link

However since inside VScode, the mmcv package is inside /opt/venv/packages/... therefore the callbacks won't step into the function imported from mmcv.ops

aravind-h-v pushed a commit to aravind-h-v/mmsegmentation that referenced this pull request Mar 27, 2023
…pen-mmlab#1024)

* document cpu offloading method

* address review comments

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants