diff --git a/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py b/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py index a07a829dd1..4265734830 100644 --- a/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py +++ b/mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py @@ -1,5 +1,3 @@ -import warnings - import torch import torch.nn as nn import torch.nn.functional as F @@ -36,9 +34,62 @@ def forward(ctx, input, indices, kernel_size, stride, padding, @staticmethod def symbolic(g, input, indices, kernel_size, stride, padding, output_size): - warnings.warn( - 'The definitions of indices are different between Pytorch and ONNX' - ', so the outputs between Pytorch and ONNX maybe different') + # get shape + input_shape = g.op('Shape', input) + const_0 = g.op('Constant', value_t=torch.tensor(0)) + const_1 = g.op('Constant', value_t=torch.tensor(1)) + batch_size = g.op('Gather', input_shape, const_0, axis_i=0) + channel = g.op('Gather', input_shape, const_1, axis_i=0) + + # height = (height - 1) * stride + kernel_size + height = g.op( + 'Gather', + input_shape, + g.op('Constant', value_t=torch.tensor(2)), + axis_i=0) + height = g.op('Sub', height, const_1) + height = g.op('Mul', height, + g.op('Constant', value_t=torch.tensor(stride[1]))) + height = g.op('Add', height, + g.op('Constant', value_t=torch.tensor(kernel_size[1]))) + + # width = (width - 1) * stride + kernel_size + width = g.op( + 'Gather', + input_shape, + g.op('Constant', value_t=torch.tensor(3)), + axis_i=0) + width = g.op('Sub', width, const_1) + width = g.op('Mul', width, + g.op('Constant', value_t=torch.tensor(stride[0]))) + width = g.op('Add', width, + g.op('Constant', value_t=torch.tensor(kernel_size[0]))) + + # step of channel + channel_step = g.op('Mul', height, width) + # step of batch + batch_step = g.op('Mul', channel_step, channel) + + # channel offset + range_channel = g.op('Range', const_0, channel, const_1) + range_channel = g.op( + 'Reshape', range_channel, + g.op('Constant', value_t=torch.tensor([1, -1, 1, 1]))) + range_channel = g.op('Mul', range_channel, channel_step) + range_channel = g.op('Cast', range_channel, to_i=7) # 7 is int64 + + # batch offset + range_batch = g.op('Range', const_0, batch_size, const_1) + range_batch = g.op( + 'Reshape', range_batch, + g.op('Constant', value_t=torch.tensor([-1, 1, 1, 1]))) + range_batch = g.op('Mul', range_batch, batch_step) + range_batch = g.op('Cast', range_batch, to_i=7) # 7 is int64 + + # update indices + indices = g.op('Add', indices, range_channel) + indices = g.op('Add', indices, range_batch) + return g.op( 'MaxUnpool', input,