
mindspore.ops.reciprocal causes the exact same code to produce drastically different results on CPU and GPU #282

Open
PhyllisJi opened this issue May 17, 2024 · 1 comment

Comments


PhyllisJi commented May 17, 2024

Environment

Hardware Environment (Ascend/GPU/CPU):


/device gpu

/device cpu

Software Environment:

  • MindSpore version (source or binary): 2.2.14 binary
  • Python version (e.g., Python 3.7.5): Python 3.9
  • OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04
  • GCC/Compiler version (if compiled from source):

Describe the current behavior

For the exact same model code, the results of a single training step differ enormously between CPU and GPU:

"loss diff": 0.23193359375,
"output diff": 0.12652587890625,
"grad diff": 
 tail_fc.bias 5.960464477539063e-08
 tail_fc.weight 2.6376953125
 conv1_mutated.weight 11981.5
 conv1_mutated.bias 2888.5
 conv2_mutated.bias 4349.0
 conv2_mutated.weight 3055.0
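
Each diff above appears to be a maximum absolute difference between the CPU run and the GPU run. A minimal sketch of how such a comparison could be computed from the (gradients, loss, output) triples returned by the train() function in the reproduction code below (compare_runs is a hypothetical helper; the exact metric behind these numbers is an assumption):

import numpy as np

def compare_runs(cpu_result, gpu_result):
    # Each argument is the (gradients_dict, loss, output) triple returned by train().
    cpu_grads, cpu_loss, cpu_out = cpu_result
    gpu_grads, gpu_loss, gpu_out = gpu_result
    print("loss diff:", abs(cpu_loss - gpu_loss))
    print("output diff:", np.max(np.abs(cpu_out - gpu_out)))
    print("grad diff:")
    for name, cpu_grad in cpu_grads.items():
        print(" ", name, np.max(np.abs(cpu_grad - gpu_grads[name])))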

Describe the expected behavior

The CPU and GPU runs should not differ significantly.

Steps to reproduce the issue

import os

import numpy as np
import mindspore


class Model_cxStrEsnTdlpdKLtuzJoAANEvSNEqVch(mindspore.nn.Cell):
    def __init__(self):
        super(Model_cxStrEsnTdlpdKLtuzJoAANEvSNEqVch, self).__init__()
        self.conv1_mutated = mindspore.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=(7, 7), stride=(2, 2), pad_mode="pad", padding=(3, 3, 3, 3), dilation=(8, 2), group=1, has_bias=True, data_format="NCHW")
        
        self.pool1_mutated = mindspore.nn.MaxPool2d(kernel_size=(8, 8), stride=(2, 2), pad_mode="pad", padding=(0, 0), dilation=(1, 1), return_indices=False, ceil_mode=True, data_format="NCHW")
        self.conv2_mutated = mindspore.nn.Conv2dTranspose(in_channels=4, out_channels=4, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), output_padding=(0, 0), dilation=(1, 1), group=1, has_bias=True)
        self.relu2_mutated = mindspore.ops.reciprocal
        self.tail_flatten = mindspore.nn.Flatten(start_dim=1, end_dim=-1)
        self.tail_fc = mindspore.nn.Dense(in_channels=8944, out_channels=1000)

    def construct(self, input):
        conv1_output = self.conv1_mutated(input)
        relu1_output = mindspore.ops.where(conv1_output >= 0, conv1_output, 0.1 * (mindspore.ops.exp(conv1_output) - 1))
        pool1_output = self.pool1_mutated(relu1_output)
        conv2_output = self.conv2_mutated(pool1_output)
        relu2_output = self.relu2_mutated(conv2_output)
        tail_flatten_output = self.tail_flatten(relu2_output)
        tail_fc_output = self.tail_fc(tail_flatten_output)

        tail_fc_output = tail_fc_output
        return tail_fc_output


def go():
    try:
        ms_model = Model_cxStrEsnTdlpdKLtuzJoAANEvSNEqVch()
        ms_input = mindspore.Tensor(np.random.randn(1, 3, 224, 224).astype(np.float32))
        ms_output = ms_model(ms_input)
        flag = True
    except Exception:
        flag = False
    return flag


def train(inp, label):
    ms_model = Model_cxStrEsnTdlpdKLtuzJoAANEvSNEqVch()
    initialize(ms_model)
    ms_input = mindspore.Tensor(inp.astype(np.float32))
    def forward_fn(label):
        ms_output = ms_model(ms_input)
        label = label.astype(np.int32)
        ms_targets = mindspore.Tensor(label)
        loss = mindspore.nn.CrossEntropyLoss(reduction='mean')(ms_output, ms_targets)
        return loss, ms_output

    (ms_loss, ms_output), ms_gradients = mindspore.value_and_grad(forward_fn, None, ms_model.trainable_params(), has_aux=True)(label)
    ms_gradients_dic = {}
    for var, gradient in zip(ms_model.trainable_params(), ms_gradients):
        ms_gradients_dic.setdefault(var.name, gradient.numpy())
    return ms_gradients_dic, ms_loss.numpy().item(), ms_output.numpy()

def initialize(model):
    module_dir = os.path.dirname(__file__)
    for name, param in model.parameters_and_names():
        layer_name, matrix_name = name.rsplit('.', 1)
        matrix_path = module_dir + '/../initializer/' + layer_name + '/' + matrix_name + '.npz'
        data = np.load(matrix_path)
        data = data['matrix']
        weight_tensor = mindspore.Tensor(data).float()
        param.set_data(weight_tensor)

To run on GPU, change only mindspore.context.set_context(device_target='CPU') to mindspore.context.set_context(device_target='GPU'); everything else stays the same.
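
A hedged sketch of a driver appended to the reproduction code above (assumed saved as repro.py) and launched once per device target — the device target is normally set before any operator runs, so the CPU and GPU results are collected in separate invocations and compared offline. The file names and the use of pickle are my own choices, not part of the original report:

import sys
import pickle
import numpy as np
import mindspore

if __name__ == "__main__":
    # Assumed usage: python repro.py CPU   and   python repro.py GPU
    target = sys.argv[1] if len(sys.argv) > 1 else "CPU"
    mindspore.context.set_context(device_target=target)
    np.random.seed(0)  # identical input and labels on both devices
    inp = np.random.randn(1, 3, 224, 224)
    label = np.random.randint(0, 1000, size=(1,))
    grads, loss, output = train(inp, label)
    with open("result_%s.pkl" % target, "wb") as f:
        pickle.dump({"grads": grads, "loss": loss, "output": output}, f)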

Related log / screenshot

"loss diff": 0.23193359375,
"output diff": 0.12652587890625,
"grad diff": 
 tail_fc.bias 5.960464477539063e-08
 tail_fc.weight 2.6376953125
 conv1_mutated.weight 11981.5
 conv1_mutated.bias 2888.5
 conv2_mutated.bias 4349.0
 conv2_mutated.weight 3055.0


@singularity6033

We used the official MindSpore 2.2.14 to reproduce the result; here is the diff between CPU and GPU:
[screenshot: MindSpore CPU vs. GPU diff]
Our benchmark is PyTorch; here is the diff between PyTorch CPU and GPU:
[screenshot: PyTorch CPU vs. GPU diff]
PyTorch shows the same kind of gap between CPU and GPU.
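
For what it's worth, using reciprocal as an activation makes this kind of gap easy to trigger: wherever the layer output is close to zero, 1/x turns a tiny backend rounding difference into a large one, and that error then propagates into the gradients. A small plain-NumPy illustration (not MindSpore-specific; the 1e-4 perturbation is an exaggerated stand-in for ordinary float32 rounding noise):

import numpy as np

x_cpu = np.float32(1e-3)          # value as computed on one backend
x_gpu = np.float32(1e-3 + 1e-4)   # same value with a tiny numeric difference
print(abs(x_cpu - x_gpu))                # ~1e-4: negligible before the reciprocal
print(abs(1.0 / x_cpu - 1.0 / x_gpu))    # ~91: amplified after the reciprocal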
