[Bug] Batchnorm running_var behaves differently when using gpu vs. cpu #14357
Comments
Hey, this is the MXNet Label Bot. |
I previously submitted a PR |
@mxnet-label-bot update [Bug, Gluon] |
This bug is still present in 1.4.1: the batchnorm operator does not support training on CPU. That is quite critical in my opinion. |
Wow, this is pretty crazy, all sorts of frequently used models will fail on CPU |
To clarify, it is only an issue during training: the running_var won't be computed correctly. But if you load a GPU-trained model on CPU, then inference is still correct. |
I have just encountered the same or a similar problem. I am training on a GPU using the Python API. However, my inference target machine does not have a GPU and uses the C++ API on a CPU. My inference results are all NaN. When I grouped all symbols to the output, I was able to see that the problem starts after the first batch normalization layer. When I run the same C++ code on a machine with a GPU, everything is OK. |
Well, batch normalization also makes sense when training smaller MLPs, and some people do that (we have a project aiming to do exactly that). And people will try to debug bigger models on CPU before pushing them out to expensive GPU instances. I really hope this is not something very hard to fix? |
@PatricZhao @TaoLv is there any success using BN for training on CPU? |
Sure. We have training benchmarks for the MKL-DNN backend. @juliusshufan Can you share some recent training trends here? It would be better to have BN in the model. Thanks. |
@TaoLv @eric-haibin-lin Per the nightly convergence tracking, ResNet50 on CIFAR-10 indicates that the training converges as expected. |
The problem is not the training performance, the problem is the learnt running mean and variance. |
@ThomasDelteil Thanks for your prompt response. So far training and pretrained-model-based inference are covered; may I ask for a (minimal) reproduction case? Thanks. |
Inference with a network trained from scratch on CPU is likely to be erroneous, though not guaranteed; inference of a pretrained-on-GPU model on CPU is fine because the running values have been computed correctly. I'll share with you tomorrow an example of a training script that works on GPU but not on CPU. |
import mxnet as mx
from mxnet import autograd, gluon

for ctx in [mx.cpu(), mx.gpu()]:
    layer = gluon.nn.BatchNorm()
    layer.initialize(ctx=ctx)
    for i in range(100):
        data = mx.nd.random.normal(loc=10, scale=2, shape=(1, 3, 224, 224), ctx=ctx)
        with autograd.record():
            out = layer(data)
    print(ctx, layer.running_var.data().asnumpy(), layer.running_mean.data().asnumpy())

As you can see, the variance and mean are erroneous on CPU: the variance is actually always ones, and the mean is always 0.

edit: it seems that the running_mean and running_var are computed in the forward pass on GPU but not on CPU. Here is the same test with a backward pass added:

for ctx in [mx.cpu(), mx.gpu()]:
    layer = gluon.nn.BatchNorm()
    layer.initialize(ctx=ctx)
    for i in range(100):
        data = mx.nd.random.normal(loc=10, scale=2, shape=(1, 3, 224, 224), ctx=ctx)
        with autograd.record():
            out = layer(data)
        out.backward()
    print(ctx, layer.running_var.data().asnumpy(), layer.running_mean.data().asnumpy())
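For reference, gluon's BatchNorm documents its running statistics as an exponential moving average of the per-batch statistics, updated on every training-mode forward pass. A rough NumPy sketch of that expected update (using the layer's default momentum of 0.9; this is only a sketch of the documented behaviour, not MXNet's internal code):

import numpy as np

momentum = 0.9               # gluon.nn.BatchNorm default
running_mean = np.zeros(3)   # BatchNorm initializes running_mean to 0
running_var = np.ones(3)     # and running_var to 1

for _ in range(100):
    batch = np.random.normal(loc=10, scale=2, size=(1, 3, 224, 224))
    batch_mean = batch.mean(axis=(0, 2, 3))
    batch_var = batch.var(axis=(0, 2, 3))
    # expected exponential-moving-average update on each training iteration
    running_mean = momentum * running_mean + (1 - momentum) * batch_mean
    running_var = momentum * running_var + (1 - momentum) * batch_var

# After 100 batches these land near ~10 and ~4, which matches what the GPU
# prints above; the broken CPU path leaves them at their initial 0 and 1.
print(running_mean, running_var)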
|
import mxnet as mx
from mxnet import nd, autograd, gluon
import numpy as np

def transform(data, label):
    return nd.transpose(data.astype(np.float32), (2, 0, 1)) / 255, label.astype(np.float32)

trainset = gluon.data.vision.FashionMNIST(train=True, transform=transform)
train_data = gluon.data.DataLoader(dataset=trainset, batch_size=50, shuffle=True)
SCE = gluon.loss.SoftmaxCrossEntropyLoss()

for ctx in [mx.cpu(), mx.gpu()]:
    net = gluon.model_zoo.vision.get_model('resnet18_v1', pretrained=False, classes=10)

    # Parameter initialization
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx, force_reinit=True)
    trainer = gluon.Trainer(params=net.collect_params(), optimizer='sgd',
                            optimizer_params={'learning_rate': .01, 'wd': 0.0001, 'momentum': 0.9})

    # Training
    for i, (data, label) in enumerate(train_data):
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        with autograd.record():
            output = net(data)
            loss = SCE(output, label)
        loss.backward()
        trainer.step(data.shape[0])
        if i == 20:
            break

    # Training accuracy under autograd
    accuracy = mx.metric.Accuracy()
    for i, (data, label) in enumerate(train_data):
        with autograd.record():
            output = net(data.as_in_context(ctx))
        accuracy.update(label, output)
        if i == 5:
            break
    print("Train accuracy so far, under autograd training scope evaluation:", accuracy.get(), ctx)

    # Training accuracy outside autograd
    accuracy = mx.metric.Accuracy()
    for i, (data, label) in enumerate(train_data):
        output = net(data.as_in_context(ctx))
        accuracy.update(label, output)
        if i == 5:
            break
    print("Train accuracy so far, outside autograd training scope evaluation:", accuracy.get(), ctx)

Being inside or outside the autograd scope should only affect the batch norm operator: inside it uses the local batch statistics, outside it uses the computed running mean and variance. We can see that the CPU version is a lot worse than the GPU one.
|
@ThomasDelteil thanks for the example. Will follow up on the case. @wuxun-zhang @juliusshufan |
Thanks @pengzhao-intel, I was planning to look into this issue. Please let me know if you already have work in progress. |
@sandeep-krishnamurthy Thanks for asking. From a previous quick pass through this issue, the difference seems to be caused by different computation methods in the CPU and GPU backends: the running statistics are updated iteration by iteration on one versus accumulated on the other. I also tried ResNet50 training on CIFAR; the training and validation accuracy are all okay on CPU.
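If I understand that correctly, the two schemes only agree when the batch statistics are stationary. A toy NumPy sketch (not the actual backend code) of how a per-iteration moving average and a plain accumulation drift apart when the activation variance changes over the course of training:

import numpy as np

momentum = 0.9
ema_var = 1.0      # per-iteration exponential moving average
batch_vars = []    # plain accumulation, averaged once at the end

for step in range(200):
    # pretend the activation variance drifts from 1.0 towards 4.0 during training
    true_var = 1.0 + 3.0 * step / 199
    batch_var = true_var * np.random.chisquare(49) / 49  # noisy sample estimate
    ema_var = momentum * ema_var + (1 - momentum) * batch_var
    batch_vars.append(batch_var)

acc_var = np.mean(batch_vars)
# The moving average tracks the most recent statistics (~4), the plain average
# covers the whole run (~2.5), so the two schemes can end up with different
# running variances even on identical data.
print(ema_var, acc_var)
|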
Has this bug been fixed in the latest versions? |
Sorry for the late response. We will look into this issue now and get back soon :) |
Hi @adrianloy,
|
Hi @ThomasDelteil,
So, for comparison purposes I moved tensor generation to NumPy.
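A sketch of that kind of comparison, generating the data once with NumPy so that both backends see identical inputs (assumes a GPU is available; the exact script behind the numbers below may differ):

import mxnet as mx
import numpy as np
from mxnet import autograd, gluon

# one fixed set of batches generated in NumPy, shared by both contexts
batches = [np.random.normal(loc=10, scale=2, size=(1, 3, 224, 224)).astype(np.float32)
           for _ in range(100)]

for ctx in [mx.cpu(), mx.gpu()]:
    layer = gluon.nn.BatchNorm()
    layer.initialize(ctx=ctx)
    for batch in batches:
        data = mx.nd.array(batch, ctx=ctx)
        with autograd.record():
            out = layer(data)
        out.backward()
    print(ctx, layer.running_mean.data().asnumpy(), layer.running_var.data().asnumpy())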
For the test above I receive almost the same values for both backends:
The difference is so small that I guess it can be neglected (a difference in rounding between the two backends). |
Hi @ThomasDelteil,
It shows that statistically GPU and CPU give similar results:
Please see the log for full data: test_03_master_fixed_sync.txt |
@adrianloy, @PatricZhao, |
I do see that there is no difference between CPU and GPU anymore, so I guess I can close this. What I don't really understand is why the accuracy outside autograd is so much lower than under autograd (for both contexts). Is this expected behaviour? It still sounds like a bug to me, but maybe I am missing something obvious. |
Hi @adrianloy,
you will receive the same accuracy under and outside autograd. |
I have an issue in my model where the running_var parameters of some batchnorm layers are NaN after my dummy forward pass to initialize the parameters. While debugging, I discovered that the value of running_var depends on the context I use. I assume this is a bug, as a model should behave the same no matter which context is used. Here is a minimal reproducible example:
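(The sketch below is along those lines; the exact network and the num_gpus handling are assumed for illustration.)

import mxnet as mx
from mxnet import gluon

num_gpus = 0  # 0 uses the CPU, 1 or 2 use GPUs
ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]

net = gluon.model_zoo.vision.get_model('resnet18_v1', pretrained=False, classes=10)
net.initialize(mx.init.Xavier(), ctx=ctx)

# dummy forward pass to trigger the deferred parameter initialization
dummy = mx.nd.random.uniform(shape=(1, 3, 224, 224), ctx=ctx[0])
net(dummy)

# inspect the running_var of every batchnorm layer
for name, param in net.collect_params().items():
    if name.endswith('running_var'):
        print(name, param.data(ctx[0]).asnumpy()[:4])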
If I set num_gpus in this code to 0, it uses the CPU and the output shows that all running_var values are 1.
If I set num_gpus to 1 or 2, it uses GPUs and all running_var values are 2.6561329e-05.
Can anyone reproduce this? I am using MXNet 1.3.1 on Ubuntu, built from source.