What is the result of BiSeNetV2 without pretrained model? #173

Closed
MengzhangLI opened this issue Aug 12, 2021 · 38 comments

Comments

@MengzhangLI

MengzhangLI commented Aug 12, 2021

Hi, thanks for your great work!

I just ran this excellent codebase and got 74.41 mIoU on the Cityscapes validation set with single-scale evaluation.

The author's result (mIoU: 73.36) and PaddleSeg's (mIoU: 73.19) were both obtained without a pretrained model. Could you tell me what your codebase achieves without a pretrained model, or show me how to disable loading it?

I would like to rerun several experiments with your code without a pretrained model, and I will report the results ASAP.

Thank you very much!

Best,

@CoinCheung
Owner

It should be between 71 and 74, if I remember correctly.
Without pretrained weights, the results have a larger variance. You can run 5 rounds and compare the results.
To discard the pretrained weights, you can comment out this line:

self.load_pretrain()
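
To put a number on the run-to-run variance, one option is to record the final mIoU of each round and summarize them. This is only a sketch; `summarize_runs` is a hypothetical helper name and is not part of the repository.

```python
import statistics


def summarize_runs(mious):
    """Summarize final mIoU values from several independent training runs."""
    if len(mious) < 2:
        raise ValueError("need at least two runs to estimate the spread")
    return {
        "mean": statistics.mean(mious),
        "stdev": statistics.stdev(mious),  # sample standard deviation
        "min": min(mious),
        "max": max(mious),
    }
```

Passing the five final mIoU values as a list gives the mean and spread to compare against other codebases.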

@NguyenCongPhucBK

Hi @CoinCheung, I am trying to finetune from your trained model. Your Cityscapes checkpoint has 19 classes, and I want to finetune a model with 2 classes, but I get this error:
Can you help me, please?
Thank you!

image

@MengzhangLI
Author

Hi, these are my results without the pretrained model. I'm still running more experiments.

[Screenshot 2021-08-14, 1:14:09 AM]

Do you know how to set up a PyCharm run configuration that avoids -m torch.distributed.launch --nproc_per_node=$NGPUS?

I would like to debug the training process to understand the learning rate schedule and the loss calculation, since you wrote both yourself (nice work indeed) and they differ from those in other public codebases.

Thanks again!

Best,

@CoinCheung
Owner

@MengzhangLI
Hi, good to know that you have run your own experiments. That said, I have no experience with PyCharm, but you can refer to the PyTorch example here. They spawn the processes manually in the code. Or, if you do not want to spend time on that, you can simply print the values as log messages, which would be simpler.
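
Another possibility for stepping through the code in an IDE is to run a single process and set by hand the environment variables that the launcher would normally provide. A minimal sketch follows, assuming the training script initializes the process group with the `env://` method. Whether the script also expects a `--local_rank` argument has not been checked here, and `single_process_dist_env` is a made-up helper name.

```python
import os


def single_process_dist_env(port=29500):
    """Variables read by the env:// init method, set up for one local process."""
    return {
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": str(port),
        "RANK": "0",
        "WORLD_SIZE": "1",
        # Read by scripts that take the local rank from the environment;
        # older launch setups pass --local_rank on the command line instead.
        "LOCAL_RANK": "0",
    }


def apply_single_process_dist_env(port=29500):
    """Export the variables before the script initializes torch.distributed."""
    os.environ.update(single_process_dist_env(port))
```

The same variables can also be entered directly in the IDE's run configuration instead of being set in code.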

@CoinCheung
Owner

@MengzhangLI Hi, as a temporary solution, you can change the following line:

child.load_state_dict(state[name], strict=True)

and change strict=True to strict=False; this would skip loading the parameters of the wrong size.
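
In PyTorch releases where `strict=False` only tolerates missing or unexpected keys, a tensor whose shape differs (for example a 19-class head loaded into a 2-class model) can still raise a size-mismatch error. A possible workaround is to drop such entries before loading. The sketch below assumes both arguments are mappings from parameter names to objects with a `.shape` attribute; `filter_matching_shapes` is a hypothetical helper, not part of the repository.

```python
def filter_matching_shapes(checkpoint_state, model_state):
    """Keep only checkpoint entries whose name and shape match the model."""
    kept, skipped = {}, []
    for name, tensor in checkpoint_state.items():
        target = model_state.get(name)
        if target is not None and tuple(tensor.shape) == tuple(target.shape):
            kept[name] = tensor
        else:
            skipped.append(name)
    return kept, skipped


# Intended use with PyTorch (not executed here):
#   kept, skipped = filter_matching_shapes(state[name], child.state_dict())
#   child.load_state_dict(kept, strict=False)
```

The skipped names are returned so that it is clear which layers keep their random initialization.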

@MengzhangLI
Author

> @MengzhangLI
> Hi, good to know that you have run your own experiments. That said, I have no experience with PyCharm, but you can refer to the PyTorch example here. They spawn the processes manually in the code. Or, if you do not want to spend time on that, you can simply print the values as log messages, which would be simpler.

Hi, thanks for your nice reply.

I am wondering what your pretrained model is. Most pretrained models are backbones trained on large datasets such as ImageNet. But the pretrained model you provide contains the detail branch (which corresponds to the backbone in other tasks, such as Res50 or Res101) as well as the task-specific modules (segment branch and BGA).

When using your pretrained model, is it somewhat like finetuning from one of your previous training runs?

Best,

@CoinCheung
Owner

CoinCheung commented Aug 16, 2021

Hi,

The model is pretrained on ImageNet. The specific training method is tied to someone else's task (I believe he would not want me to talk about it), so I cannot say much. Sorry for not providing more useful information.

@CoinCheung
Owner

@MengzhangLI
Hi, according to the result you posted above (trained without pretrained weights), the result is lower than that of PaddleSeg (I did not notice this difference last time). Do you have any opinion on what causes this gap?

@MengzhangLI
Author

MengzhangLI commented Aug 16, 2021

> @MengzhangLI
> Hi, according to the result you posted above (trained without pretrained weights), the result is lower than that of PaddleSeg (I did not notice this difference last time). Do you have any opinion on what causes this gap?

I guess it is just because the PaddleSeg framework makes the training procedure more robust?
Here are several more numerical results without the pretrained model:
image

I think there are also two differences between PaddleSeg and your codebase:

(1) Learning rate strategy: PaddleSeg uses 0.05 with PolynomialDecay, which differs from your custom LR schedule. BTW, could you tell me how your LR schedule works? I'm a little confused when reading your WarmupPolyLrScheduler code:

WXWorkCapture_16290917308763

(2) Loss function and OHEM: I understood the default OHEM threshold of 0.7 as keeping the 30% hardest samples, whereas your implementation keeps the pixels whose loss is larger than -log(0.7), computed as -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).cuda(), because the loss is cross entropy (LogSoftmax followed by NLLLoss).
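
The loss-threshold reading in (2) can be illustrated in plain Python. For cross entropy, the per-pixel loss is -log(p), where p is the softmax probability of the true class, so "loss larger than -log(0.7)" is the same condition as "p below 0.7". This is only a sketch of that equivalence; `ohem_keep_mask` is a hypothetical name, and any rule for keeping a minimum number of pixels is left out.

```python
import math


def ohem_keep_mask(true_class_probs, thresh=0.7):
    """Mark a pixel as hard (kept) when its loss exceeds -log(thresh).

    Each probability must be in (0, 1]; it is the softmax score that the
    model assigns to the pixel's ground-truth class.
    """
    loss_thresh = -math.log(thresh)
    losses = [-math.log(p) for p in true_class_probs]
    return [loss > loss_thresh for loss in losses]


probs = [0.9, 0.5, 0.69, 0.71]
print(ohem_keep_mask(probs))  # [False, True, True, False]
```

Under this reading the number of kept pixels depends on how confident the model is, rather than being a fixed 30% of the image.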

@MengzhangLI
Author

Added:

The model itself is also different:

(1) PaddleSeg usually uses a bias in conv layers (i.e., bias=True).

(2) Default auxiliary head: PaddleSeg uses only one fully-connected layer, while yours uses two layers whose channel counts also differ. For example, aux4 is 64 -> 128 -> 256 -> 19.

@CoinCheung
Owner

@MengzhangLI
Hi,

Answers to your questions:

  1. My lr_scheduler is just a common lr_scheduler with warmup at the start. You can refer to the code here:

    class WarmupLrScheduler(torch.optim.lr_scheduler._LRScheduler):

    You might observe that the lr value printed in the log differs from the value you set. That is because I log the mean lr over the optimization groups; the code is here:
    lr = sum(lr) / len(lr)

  2. Here 0.7 means that pixels whose softmax score is over 0.7 are discarded, since they are easier than the others. I have not tried "discarding 70% of all pixels", but I guess it would make limited difference (no experiments to prove this).
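
For readers who want to see the shape of a "warmup, then poly decay" schedule, here is a minimal sketch in plain Python. It is not the repository's scheduler: linear warmup, `warmup_iter=1000`, `warmup_ratio=0.1` and `power=0.9` are all assumptions chosen for illustration, and `warmup_poly_lr` is a hypothetical name.

```python
def warmup_poly_lr(base_lr, it, max_iter, power=0.9,
                   warmup_iter=1000, warmup_ratio=0.1):
    """Learning rate at iteration `it`: linear warmup, then polynomial decay."""
    if it < warmup_iter:
        # Ramp linearly from warmup_ratio * base_lr up to base_lr.
        alpha = it / warmup_iter
        return base_lr * (warmup_ratio + (1.0 - warmup_ratio) * alpha)
    # Polynomial decay over the iterations that remain after warmup.
    progress = (it - warmup_iter) / (max_iter - warmup_iter)
    return base_lr * (1.0 - progress) ** power
```

With these assumed settings the rate starts at a tenth of `base_lr`, reaches `base_lr` at the end of warmup, and decays to zero at `max_iter`.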

@CoinCheung
Owner

Also, I did not mention that I used a 0.1x lr value when loading the pretrained weights. If you do not load pretrained weights, you might need a larger lr value (the same value as in the paper).
This would change the results somewhat, but the problem of large variance remains. You will still see the results vary between rounds of the train/val cycle.

@MengzhangLI
Author

OK, thank you. I've ported your code to the MMSegmentation framework.

With lr=0.05 and the poly decay strategy, just as in PaddleSeg, the mIoU is about 72.43, still lower than PaddleSeg's. I'm working on fixing it.

BTW, the author uses 3x3 and 1x1 conv modules in the auxiliary head, which differs from both your setting and PaddleSeg's. Did you try his setting?

===========================================================================================
image

===========================================================================================

Best,

@CoinCheung
Owner

CoinCheung commented Aug 16, 2021

Hi,
Thanks for providing so much useful information. I used an aux head with more conv layers because I found it made fp16 training on the COCO-Stuff dataset more stable. With the same structure as in the paper, one aux loss would become NaN during fp16 training.

Let me share what else I can remember; I hope it helps. When I tried this model quite a long time ago (about one year ago), I found that the multi-GPU configuration can also make a difference. Using 8 GPUs (large batch, fewer total iterations) gives different results from 2 GPUs (8 images per GPU). You have to adjust the training hyper-parameters (mainly lr) if you change the GPU parallel configuration. I did not run enough experiments to draw a conclusion, but my impression is that the learning rate does not scale linearly with batch size as expected.

Have you measured the variance of PaddleSeg? Do the results of different training rounds agree with each other?

@MengzhangLI
Author

Not yet. I'm focusing on rebuilding the BiSeNetV2 module in MMSegmentation, because right now the best result is about 72.4, a little lower than the author's and PaddleSeg's.

I will report the results later. So far I have only found that the learning rate and its schedule are very important.

Best,

@MengzhangLI
Author

> Hi,
> Thanks for providing so much useful information. I used an aux head with more conv layers because I found it made fp16 training on the COCO-Stuff dataset more stable. With the same structure as in the paper, one aux loss would become NaN during fp16 training.
>
> Let me share what else I can remember; I hope it helps. When I tried this model quite a long time ago (about one year ago), I found that the multi-GPU configuration can also make a difference. Using 8 GPUs (large batch, fewer total iterations) gives different results from 2 GPUs (8 images per GPU). You have to adjust the training hyper-parameters (mainly lr) if you change the GPU parallel configuration. I did not run enough experiments to draw a conclusion, but my impression is that the learning rate does not scale linearly with batch size as expected.
>
> Have you measured the variance of PaddleSeg? Do the results of different training rounds agree with each other?

Do you mean

(1) 2 GPUs x 8 images per GPU (total batch size 16), 160K iterations.

(2) 8 GPUs x 8 or 4 images per GPU (total batch size 64 or 32), 40K or 80K iterations.

Do you mean that the results differ even though the product of total batch size and iteration count stays the same?

As many people in the community have discussed, the learning rate is related to the total batch size (maybe linearly). So maybe the learning rate should also be 2 or 4 times larger than before. But this relation is only empirical; I guess a grid search is necessary.
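
The linear relation mentioned above is simple arithmetic. A small sketch, where `scale_lr_linearly` is a hypothetical helper and the rule itself is only the empirical heuristic discussed here, not something verified for this model:

```python
def scale_lr_linearly(ref_lr, ref_batch_size, new_batch_size):
    """Scale a reference learning rate in proportion to the total batch size."""
    return ref_lr * new_batch_size / ref_batch_size


# Reference: lr 0.05 at total batch size 16 (2 GPUs x 8 images).
for total_batch in (16, 32, 64):
    print(total_batch, scale_lr_linearly(0.05, 16, total_batch))
```

The scaled values would only be a starting point for the grid search.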

Best,

@CoinCheung
Owner

Yes, similar configurations. I tried scaling the learning rate linearly with batch size, but there can still be a difference, if I remember correctly.

I trained for 150k iterations as the paper does, rather than 160k.

@MengzhangLI
Author

> Yes, similar configurations. I tried scaling the learning rate linearly with batch size, but there can still be a difference, if I remember correctly.
>
> I trained for 150k iterations as the paper does, rather than 160k.

OK, got it.

I think that after porting to MMSegmentation, I should also increase the number of iterations (for example, to 200K or 240K) to see whether training has converged at 150K/160K iterations.

I will report my results in a couple of days. Thank you.

Best,

@MengzhangLI
Author

> > @MengzhangLI
> > Hi, according to the result you posted above (trained without pretrained weights), the result is lower than that of PaddleSeg (I did not notice this difference last time). Do you have any opinion on what causes this gap?
>
> I guess it is just because the PaddleSeg framework makes the training procedure more robust?
> Here are several more numerical results without the pretrained model:
> image
>
> I think there are also two differences between PaddleSeg and your codebase:
>
> (1) Learning rate strategy: PaddleSeg uses 0.05 with PolynomialDecay, which differs from your custom LR schedule. BTW, could you tell me how your LR schedule works? I'm a little confused when reading your WarmupPolyLrScheduler code:
>
> WXWorkCapture_16290917308763
>
> (2) Loss function and OHEM: I understood the default OHEM threshold of 0.7 as keeping the 30% hardest samples, whereas your implementation keeps the pixels whose loss is larger than -log(0.7), computed as -torch.log(torch.tensor(thresh, requires_grad=False, dtype=torch.float)).cuda(), because the loss is cross entropy (LogSoftmax followed by NLLLoss).

Sorry man, I made a mistake: the 74+ results were obtained with the pretrained model. My bad. (I think it happened because I connected to the remote server via scp from PyCharm last weekend; on connecting, the local folder is uploaded automatically, so the lines I had commented out were overwritten.)

So right now the result without the pretrained model is about 0.70-0.71, which may be caused by the learning rate, as you mentioned before. I will change the learning rate to 0.05.

@MengzhangLI
Author

MengzhangLI commented Aug 17, 2021

Update:
PaddleSeg exactly follows the original author's settings, i.e.,

(1) No ReLU at the end of BGALayer (link).

(2) No BN and ReLU at the end of CEBlock (link).

(3) No ReLU at the end of GELayerS1 and GELayerS2 (link)

(4) Just 3x3 and 1x1 conv for each Segmentation Head (link).

The PaddleSeg result on the Cityscapes validation set is 73.19 without OHEM, exactly the same as in the paper.

image

@CoinCheung
Owner

@MengzhangLI
Hi,
Is this your experiment result? Does PaddleSeg's implementation show variance if we do not use a fixed random seed?
In my opinion, the ReLU in BGA and CEBlock would make little difference, but using fp16 might make some difference (not verified by experiment).

@MengzhangLI
Author

MengzhangLI commented Aug 17, 2021

> @MengzhangLI
> Hi,
> Is this your experiment result? Does PaddleSeg's implementation show variance if we do not use a fixed random seed?
> In my opinion, the ReLU in BGA and CEBlock would make little difference, but using fp16 might make some difference (not verified by experiment).

Nah, the figure is the author's result. I will follow their settings and run some experiments on MMSegmentation over the next few days. My bad, but I don't know how to change your custom learning rate function (if I just set lr=0.05 in the bisenetv2_city config, it becomes ~0.1 during the first several iterations).

Plus, I have not run any training with PaddleSeg yet.

Best,

@CoinCheung
Owner

CoinCheung commented Aug 17, 2021

It is just an average of optimization groups:

lr = lr_schdr.get_lr()

You can change this into:

    lr = lr_schdr.get_lr()[0]

This should be the value in the cfg.
It is not a problem with the lr_scheduler; it is how the log message is computed.
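
To see why the logged value can differ from the configured one, consider a made-up set of parameter groups where some groups run at a multiple of the configured rate. The grouping and the 10x factor below are assumptions for illustration only, and `logged_lr` is a hypothetical helper:

```python
def logged_lr(group_lrs, average=True):
    """Value written to the log: mean over parameter groups, or the first group."""
    if average:
        return sum(group_lrs) / len(group_lrs)
    return group_lrs[0]


# Hypothetical grouping: two groups at the configured rate, two at 10x.
group_lrs = [0.05, 0.05, 0.5, 0.5]
print(logged_lr(group_lrs))                 # mean over groups, larger than 0.05
print(logged_lr(group_lrs, average=False))  # 0.05, the configured value
```

The optimizer uses the per-group values in both cases; only the number shown in the log changes.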

@MengzhangLI
Author

It does not seem to work.

[Screenshot 2021-08-17, 9:39:01 PM]

[Screenshot 2021-08-17, 9:56:21 PM]

@CoinCheung
Owner

You should also comment out the following line:

lr = sum(lr) / len(lr)

@MengzhangLI
Author

It works, thank you.

I will report the results soon.

Best,

@MengzhangLI
Author

Hi, I reran the experiments over the past several days. The results (without the pretrained model) were worse than we expected.

Here are my results; most of them are below 72.0.

[Screenshot 2021-08-21, 3:34:42 PM]

The configuration changes I made are as follows:
[Screenshot 2021-08-21, 3:35:04 PM]

[Screenshot 2021-08-21, 3:35:29 PM]

Best,

@CoinCheung
Owner

CoinCheung commented Aug 21, 2021

That still does not align with my local experiments; maybe I made other changes after adding the pretrained backbone.

One thing that I can think of is the following line:

if hasattr(model, 'get_params'):

The get_params method was added after I started using the pretrained weights. It makes sure that the pretrained weights are trained with the assigned learning rate, while the non-pretrained weights are trained with a 10x learning rate. Maybe we should also comment out get_params in the model definition:
def get_params(self):

There can also be other details that make a difference. If you do not want to restore the previous non-pretrained setup step by step, you can simply roll back the code:

git clone ...
cd BiSeNet
git reset --hard 90de8f9bdb6281f700087a

This commit should be the latest one before the pretrained weights were added.

@MengzhangLI
Author

Hi, sorry for the late reply.

Based on your excellent work, I integrated BiSeNet V2 into MMSegmentation; here is the related PR. The results are very exciting.

mIoU on the Cityscapes val set:

| Paper | Paper (With OHEM) | Paddleseg | Ours  | Ours (FP16) |
|-------|-------------------|-----------|-------|-------------|
| 73.19 | 73.36             | 73.19     | 73.37 | 75.49       |

Thank you very much. Next I will add V1 and more datasets and backbones.

Best,

@MengzhangLI
Author

P.S. All of them are trained from scratch. 75.49 is a little weird (you can find my training log in my PR), maybe because training itself is unstable. But all results are better than those reported by PaddleSeg and the author.

@CoinCheung
Owner

Hi, good to know you have run your experiments and drawn your own conclusions!

I think the difference between 73.37 and 75.49 is due to model variance. When you train with the same configuration, the result is different each time. Adding a pretrained backbone partially reduces this variance and raises the mean result to some extent.

You have done a good job!

@MengzhangLI
Author

MengzhangLI commented Sep 7, 2021

One more question: why can the author's BiSeNet V1 R18 reach 76.28 mIoU on Cityscapes?

With a pretrained model I usually get 75.

https://github.com/ycszen/TorchSeg

Thank you again.

@MengzhangLI
Author

Three possible reasons, I suppose.
(1) Using some tricks he does not highlight in the paper.
(2) More iterations or a larger batch size to make the training process longer (possible, because I could not find a training log).
(3) Using a special pretrained model.

@CoinCheung
Owner

Hi,

I have no experience training with the author's code, though I read it when trying to figure out the details. Their code has not been updated for some time, and newer PyTorch versions can also introduce some differences. Or it is simply model variance. By the way, how did you obtain the result of 75?

@MengzhangLI
Author

> Hi,
>
> I have no experience training with the author's code, though I read it when trying to figure out the details. Their code has not been updated for some time, and newer PyTorch versions can also introduce some differences. Or it is simply model variance. By the way, how did you obtain the result of 75?

With the help of a ResNet18 V1c pretrained model; please see my PR here. Also, there seems to be a gap between its FPS and yours; please see here.

@CoinCheung
Owner

As for the mIoU, maybe some details in the paper account for the gap.
As for the speed, did you use TensorRT for inference?

@CoinCheung
Owner

I used TensorRT 7.2.3 for inference with an input size of 1024x2048; my GPU is a Tesla T4.

I originally used TensorRT 7.0.0, which was a little slower than 7.2.3 (by about 1 fps).

@MengzhangLI
Author

> As for the mIoU, maybe some details in the paper account for the gap.
> As for the speed, did you use TensorRT for inference?

Not yet. ./tools/benchmark.py in MMSegmentation is not very accurate for measuring FPS, and TensorRT support in MMSegmentation is still experimental. But we will measure the FPS, especially with TensorRT, in the near future.
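
For a rough, framework-independent FPS number, a timing loop with warmup iterations is a common approach. The sketch below is generic Python and is not the MMSegmentation benchmark script; `measure_fps` and `run_once` are hypothetical names. For GPU inference, `run_once` would need to wait for the device to finish (for example by calling `torch.cuda.synchronize()`) before returning, otherwise the timing is misleading.

```python
import time


def measure_fps(run_once, warmup=10, iters=100):
    """Average calls per second of `run_once`, after some untimed warmup calls."""
    for _ in range(warmup):
        run_once()
    start = time.perf_counter()
    for _ in range(iters):
        run_once()
    elapsed = time.perf_counter() - start
    return iters / elapsed
```

Numbers measured this way are only comparable when input size, precision, and hardware match.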
