Torch version affects the network's training performance #8
Comments
Hi, do you know if this is still an issue in PyTorch 1.8? Thank you!
Based on my experiment, yes, this is still an issue in PyTorch 1.8. If you can only use PyTorch 1.8 due to hardware restrictions (e.g. CUDA version requirements), you can replace all BatchNorm layers with InstanceNorm, which should avoid the problem.
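A minimal sketch of such a swap (a hypothetical helper, not part of this repo; it simply replaces every `nn.BatchNorm2d` with an affine `nn.InstanceNorm2d`):

```python
import torch.nn as nn

def replace_bn_with_in(module: nn.Module) -> None:
    """Recursively swap every BatchNorm2d for an affine InstanceNorm2d.

    Hypothetical helper for illustration; not code from this repository.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.InstanceNorm2d(child.num_features, affine=True))
        else:
            replace_bn_with_in(child)
```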
Will it work with PyTorch 1.6?
Hi @EhrazImam, it looks like the answer is no. Please find the implementation of BN from 1.5.1 (which is the one I was using) here and BN from 1.6.0 here. You will see the change in the function signature I mentioned above.
Hi, thank you for bringing this issue up front.
So I am in a dilemma: use torch 1.5.1 and run out of memory, or use an A6000 with enough memory that cannot run torch 1.5.1. For your information, the new generation of GPUs like the RTX 3090, A6000, etc. will run on torch 1.10.0 with CUDA 11.2 or later (which supports sm_86). I understand that it is almost impossible to support every version of PyTorch, but how about selectively supporting at least one version compatible with the "future" generation of GPUs, such as PyTorch 1.10 with CUDA 11.2 or later? What do you think? Thanks a lot for your help in advance!
Hi @mli0603 @ynjiun, I found a way to resolve this problem. Following pytorch/pytorch#37823 (comment) and https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/66, I modified the code in `_disable_batchnorm_tracking`, setting the running mean and variance buffers of the batch norm layers to None, which resolves the problem:

```python
def _disable_batchnorm_tracking(self):
    """
    Disable BatchNorm tracking stats to reduce dependency on the dataset
    (this acts as InstanceNorm with affine parameters when batch size is 1).
    """
    for m in self.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None
```
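As a quick sanity check of why this works, here is a small standalone snippet (assuming a recent PyTorch version; not part of the patch itself): once the running buffers are `None`, a `BatchNorm2d` layer normalizes with the current batch's statistics even in eval mode, so train- and eval-mode outputs on the same input should match.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)
bn.track_running_stats = False
bn.running_mean = None
bn.running_var = None

x = torch.randn(1, 4, 8, 8)
bn.train()
y_train = bn(x)
bn.eval()
y_eval = bn(x)

# With the running buffers removed, eval mode falls back to batch statistics,
# so both outputs should be numerically identical.
print(torch.allclose(y_train, y_eval, atol=1e-6))
```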
Oh nice! Thank you very much for this patch. Let me test it on my end too.
fix BN issue due to torch version #8
I am opening this issue because, depending on which version of PyTorch you are using, the training results will differ. Here are the 3px error evaluation curves from a minimal example of overfitting the network on a single image for 300 epochs:

The purple line is trained with PyTorch 1.7.0 and the orange line is trained with PyTorch 1.5.1. As you can see, with version 1.7.0 the error rate stays flat at 100%, while with version 1.5.1 the error rate drops. The reason is that the BatchNorm implementation changed between version 1.5.1 and version 1.7.0. In version 1.5.1, if I disable `track_running_stats` here, both evaluation and training use batch stats. However, in PyTorch 1.7.0 the layer is forced to use `running_mean` and `running_var` in evaluation mode, while batch stats are used in training. With `track_running_stats` disabled, `running_mean` is 0 and `running_var` is 1, which is clearly different from the batch stats.

Therefore, instead of trying to work against torch's implementation, I recommend using PyTorch 1.5.1 if you want to retrain from scratch. Otherwise, if you want to use another PyTorch version, you can replace all BatchNorm layers with InstanceNorm and port the learnt values from BatchNorm (i.e. weight and bias). This is a `wontfix` problem because it is quite hard to accommodate all torch versions.