
Mindspore.nn.Conv2d lacks checks on convolutional kernel shapes #291

Open

PhyllisJi opened this issue Jun 5, 2024 · 2 comments

PhyllisJi commented Jun 5, 2024

Software Environment:

  • MindSpore version (source or binary): binary
  • Python version (e.g., Python 3.7.5): 3.9
  • OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 20.04
  • GCC/Compiler version (if compiled from source):

Describe the current behavior

The layer accepts an invalid parameter combination (a 1×1 kernel together with padding of 7 on every side): embedded in a full model it runs without complaint, while the same configuration run on its own crashes with a cuDNN error.

Describe the expected behavior

The kernel size and padding should be validated when the layer is constructed, and an invalid combination should raise an informative exception up front rather than failing later inside cuDNN.
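
For reference, a minimal sketch of the kind of construction-time check being requested (check_conv2d_transpose_args is a hypothetical helper, not part of the MindSpore API; it applies the standard transposed-convolution output-size formula):

# Hypothetical validation helper (sketch). For Conv2dTranspose with
# pad_mode="pad", the output spatial size along one dimension is
#   out = (in - 1) * stride - (pad_before + pad_after)
#         + dilation * (kernel - 1) + output_padding + 1
# A non-positive result should be rejected before any kernel is launched.
def check_conv2d_transpose_args(in_size, kernel, stride, pads,
                                dilation=1, output_padding=0):
    out = ((in_size - 1) * stride - sum(pads)
           + dilation * (kernel - 1) + output_padding + 1)
    if out <= 0:
        raise ValueError(
            f"Conv2dTranspose would produce an output size of {out} <= 0 "
            f"(input={in_size}, kernel={kernel}, stride={stride}, padding={pads})")
    return out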

Steps to reproduce the issue

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
import numpy as np

# Define the Conv2dTranspose layer with specified parameters
conv_transpose = nn.Conv2dTranspose(
    in_channels=8, 
    out_channels=8, 
    kernel_size=(1, 1), 
    stride=(1, 1), 
    pad_mode="pad", 
    padding=(7, 7, 7, 7), 
    output_padding=(0, 0), 
    dilation=(1, 1), 
    group=1, 
    has_bias=True
)

# Create an input tensor with shape (1, 8, 14, 14)
input_tensor = mindspore.Tensor(np.random.randn(1, 8, 14, 14), mindspore.float32)

# Apply the Conv2dTranspose layer to the input tensor
output_tensor = conv_transpose(input_tensor)

# Print the shape of the output tensor
print("Output shape:", output_tensor.shape)
The second script embeds the same Conv2dTranspose configuration in a full model, where, per the report, it runs without complaint:

import mindspore
import numpy as np
import os


mindspore.context.set_context(device_target='CPU')


class Model_A5cKhrbOS9qjJeV7sRO4l2ukdpmxk6kx(mindspore.nn.Cell):
    def __init__(self):
        super().__init__()
        self.conv1 = mindspore.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=(3, 3), stride=(2, 2), pad_mode="pad", padding=(1, 1, 1, 1), dilation=(1, 1), group=1, has_bias=True, data_format="NCHW")
        self.conv2 = mindspore.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=(3, 3), stride=(1, 1), pad_mode="pad", padding=(1, 1, 1, 1), dilation=(1, 1), group=3, has_bias=True, data_format="NCHW")
        self.relu1 = mindspore.nn.ReLU()
        self.conv3 = mindspore.nn.Conv2d(in_channels=3, out_channels=4, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), dilation=(1, 1), group=1, has_bias=True, data_format="NCHW")
        self.relu2 = mindspore.nn.ReLU()
        self.conv4 = mindspore.nn.Conv2d(in_channels=4, out_channels=4, kernel_size=(3, 3), stride=(2, 2), pad_mode="pad", padding=(1, 1, 1, 1), dilation=(1, 1), group=4, has_bias=True, data_format="NCHW")
        self.relu3 = mindspore.nn.ReLU()
        self.conv5 = mindspore.nn.Conv2d(in_channels=4, out_channels=5, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), dilation=(1, 1), group=1, has_bias=True, data_format="NCHW")
        self.relu4 = mindspore.nn.ReLU()
        self.conv6 = mindspore.nn.Conv2d(in_channels=5, out_channels=5, kernel_size=(3, 3), stride=(1, 1), pad_mode="pad", padding=(1, 1, 1, 1), dilation=(1, 1), group=5, has_bias=True, data_format="NCHW")
        self.relu5 = mindspore.nn.ReLU()
        self.conv7 = mindspore.nn.Conv2d(in_channels=5, out_channels=6, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), dilation=(1, 1), group=1, has_bias=True, data_format="NCHW")
        self.relu6 = mindspore.nn.ReLU()
        self.conv8 = mindspore.nn.Conv2d(in_channels=6, out_channels=6, kernel_size=(3, 3), stride=(2, 2), pad_mode="pad", padding=(1, 1, 1, 1), dilation=(1, 1), group=6, has_bias=True, data_format="NCHW")
        self.relu7 = mindspore.nn.ReLU()
        self.conv9 = mindspore.nn.Conv2d(in_channels=6, out_channels=7, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), dilation=(1, 1), group=1, has_bias=True, data_format="NCHW")
        self.relu8 = mindspore.nn.ReLU()
        self.conv10 = mindspore.nn.Conv2d(in_channels=7, out_channels=7, kernel_size=(3, 3), stride=(1, 1), pad_mode="pad", padding=(1, 1, 1, 1), dilation=(1, 1), group=7, has_bias=True, data_format="NCHW")
        self.relu9 = mindspore.nn.ReLU()
        self.conv11 = mindspore.nn.Conv2d(in_channels=7, out_channels=7, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), dilation=(1, 1), group=1, has_bias=True, data_format="NCHW")
        self.relu10 = mindspore.nn.ReLU()
        self.conv12 = mindspore.nn.Conv2d(in_channels=7, out_channels=7, kernel_size=(3, 3), stride=(2, 2), pad_mode="pad", padding=(1, 1, 1, 1), dilation=(1, 1), group=7, has_bias=True, data_format="NCHW")
        self.relu11 = mindspore.nn.ReLU()
        self.conv13 = mindspore.nn.Conv2d(in_channels=7, out_channels=8, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), dilation=(1, 1), group=1, has_bias=True, data_format="NCHW")
        self.relu12 = mindspore.nn.ReLU()
        self.conv14 = mindspore.nn.Conv2d(in_channels=8, out_channels=8, kernel_size=(3, 3), stride=(1, 1), pad_mode="pad", padding=(1, 1, 1, 1), dilation=(1, 1), group=8, has_bias=True, data_format="NCHW")
        self.relu13 = mindspore.nn.ReLU()
        self.conv15 = mindspore.nn.Conv2d(in_channels=8, out_channels=8, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), dilation=(1, 1), group=1, has_bias=True, data_format="NCHW")
        self.relu14 = mindspore.nn.ReLU()
        self.conv16 = mindspore.nn.Conv2d(in_channels=8, out_channels=8, kernel_size=(3, 3), stride=(1, 1), pad_mode="pad", padding=(1, 1, 1, 1), dilation=(1, 1), group=8, has_bias=True, data_format="NCHW")
        self.relu15 = mindspore.nn.ReLU()
        self.conv17 = mindspore.nn.Conv2d(in_channels=8, out_channels=8, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), dilation=(1, 1), group=1, has_bias=True, data_format="NCHW")
        self.relu16 = mindspore.nn.ReLU()
        self.conv18 = mindspore.nn.Conv2d(in_channels=8, out_channels=8, kernel_size=(3, 3), stride=(1, 1), pad_mode="pad", padding=(1, 1, 1, 1), dilation=(1, 1), group=8, has_bias=True, data_format="NCHW")
        self.relu17 = mindspore.nn.ReLU()
        self.conv19 = mindspore.nn.Conv2d(in_channels=8, out_channels=8, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(0, 0, 0, 0), dilation=(1, 1), group=1, has_bias=True, data_format="NCHW")
        self.relu18 = mindspore.nn.ReLU()
        self.conv20 = mindspore.nn.Conv2d(in_channels=8, out_channels=8, kernel_size=(3, 3), stride=(1, 1), pad_mode="pad", padding=(1, 1, 1, 1), dilation=(1, 1), group=8, has_bias=True, data_format="NCHW")
        self.relu19 = mindspore.nn.ReLU()
        self.conv21_mutated = mindspore.nn.Conv2dTranspose(in_channels=8, out_channels=8, kernel_size=(1, 1), stride=(1, 1), pad_mode="pad", padding=(7, 7, 7, 7), output_padding=(0, 0), dilation=(1, 1), group=1, has_bias=True)

    def construct(self, input):
        conv1_output = self.conv1(input)
        conv2_output = self.conv2(conv1_output)
        relu1_output = self.relu1(conv2_output)
        conv3_output = self.conv3(relu1_output)
        relu2_output = self.relu2(conv3_output)
        conv4_output = self.conv4(relu2_output)
        relu3_output = self.relu3(conv4_output)
        conv5_output = self.conv5(relu3_output)
        relu4_output = self.relu4(conv5_output)
        conv6_output = self.conv6(relu4_output)
        relu5_output = self.relu5(conv6_output)
        conv7_output = self.conv7(relu5_output)
        relu6_output = self.relu6(conv7_output)
        conv8_output = self.conv8(relu6_output)
        relu7_output = self.relu7(conv8_output)
        conv9_output = self.conv9(relu7_output)
        relu8_output = self.relu8(conv9_output)
        conv10_output = self.conv10(relu8_output)
        relu9_output = self.relu9(conv10_output)
        conv11_output = self.conv11(relu9_output)
        relu10_output = self.relu10(conv11_output)
        conv12_output = self.conv12(relu10_output)
        relu11_output = self.relu11(conv12_output)
        conv13_output = self.conv13(relu11_output)
        relu12_output = self.relu12(conv13_output)
        conv14_output = self.conv14(relu12_output)
        relu13_output = self.relu13(conv14_output)
        conv15_output = self.conv15(relu13_output)
        relu14_output = self.relu14(conv15_output)
        conv16_output = self.conv16(relu14_output)
        relu15_output = self.relu15(conv16_output)
        conv17_output = self.conv17(relu15_output)
        relu16_output = self.relu16(conv17_output)
        conv18_output = self.conv18(relu16_output)
        relu17_output = self.relu17(conv18_output)
        conv19_output = self.conv19(relu17_output)
        relu18_output = self.relu18(conv19_output)
        conv20_output = self.conv20(relu18_output)
        relu19_output = self.relu19(conv20_output)
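        # Four stride-2 convolutions (conv1, conv4, conv8, conv12) reduce the
        # 224x224 input to 14x14, so relu19_output has shape (1, 8, 14, 14) --
        # the same shape the standalone script feeds to Conv2dTranspose.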
        print(relu19_output.shape)
        conv21_output = self.conv21_mutated(relu19_output)
        print(conv21_output.shape)
        fc_output = conv21_output
        return fc_output


def go():
    model = Model_A5cKhrbOS9qjJeV7sRO4l2ukdpmxk6kx()
    x = mindspore.Tensor(np.random.randn(1, 3, 224, 224).astype(np.float32))
    y = model(x)
    flag = True


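# Note: train() and initialize() below are part of the reporter's harness; they
# load pre-saved weights from ../initializer/ and are not called by go().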
def train(inp, label):
    ms_model = Model_A5cKhrbOS9qjJeV7sRO4l2ukdpmxk6kx()
    initialize(ms_model)
    ms_input = mindspore.Tensor(inp.astype(np.float32))
    def forward_fn(label):
        ms_output = ms_model(ms_input)
        label = label.astype(np.int32)
        ms_targets = mindspore.Tensor(label)
        loss = mindspore.nn.CrossEntropyLoss(reduction='mean')(ms_output, ms_targets)
        return loss, ms_output

    (ms_loss, ms_output), ms_gradients = mindspore.value_and_grad(forward_fn, None, ms_model.trainable_params(), has_aux=True)(label)
    ms_gradients_dic = {}
    for var, gradient in zip(ms_model.trainable_params(), ms_gradients):
        ms_gradients_dic.setdefault(var.name, gradient.numpy())
    return ms_gradients_dic, ms_loss.numpy().item(), ms_output.numpy()

def initialize(model):
    module_dir = os.path.dirname(__file__)
    for name, param in model.parameters_and_names():
        layer_name, matrix_name = name.rsplit('.', 1)
        matrix_path = module_dir + '/../initializer/' + layer_name + '/' + matrix_name + '.npz'
        data = np.load(matrix_path)
        data = data['matrix']
        weight_tensor = mindspore.Tensor(data).float()
        param.set_data(weight_tensor)

go()

Related log / screenshot

Traceback (most recent call last):
  File "/mnt/AA_MoCoDiff/BR_MoCoDiff/result/extract/test.py", line 143, in <module>
    output_tensor = conv_transpose(input_tensor)
  File "/root/miniconda3/envs/mocodiff/lib/python3.9/site-packages/mindspore/nn/cell.py", line 705, in __call__
    raise err
  File "/root/miniconda3/envs/mocodiff/lib/python3.9/site-packages/mindspore/nn/cell.py", line 702, in __call__
    _pynative_executor.end_graph(self, output, *args, **kwargs)
  File "/root/miniconda3/envs/mocodiff/lib/python3.9/site-packages/mindspore/common/api.py", line 1215, in end_graph
    self._executor.end_graph(obj, output, *args, *(kwargs.values()))
RuntimeError: 
----------------------------------------------------
- cuDNN Error:
----------------------------------------------------
cudnnSetTensorNdDescriptor failed | Error Number: 3 CUDNN_STATUS_BAD_PARAM

----------------------------------------------------
- C++ Call Stack: (For framework developers)
----------------------------------------------------
mindspore/ccsrc/plugin/device/gpu/kernel/nn/convolution/conv2d_cudnn_gpu_kernel.h:574 Set4DDesc

Special notes for this issue

VectorSL (Contributor) commented Sep 9, 2024

same as issue228

PhyllisJi (Author) commented

> same as issue228

MindSpore 2.2.14
