Skip to content

Commit

Permalink
fix onnx conversion of maxunpool2d (#243)
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire authored Apr 10, 2021
1 parent 5b0be70 commit 3c26d34
Showing 1 changed file with 56 additions and 5 deletions.
61 changes: 56 additions & 5 deletions mmedit/models/backbones/encoder_decoders/decoders/plain_decoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -36,9 +34,62 @@ def forward(ctx, input, indices, kernel_size, stride, padding,

@staticmethod
def symbolic(g, input, indices, kernel_size, stride, padding, output_size):
warnings.warn(
'The definitions of indices are different between Pytorch and ONNX'
', so the outputs between Pytorch and ONNX maybe different')
# get shape
input_shape = g.op('Shape', input)
const_0 = g.op('Constant', value_t=torch.tensor(0))
const_1 = g.op('Constant', value_t=torch.tensor(1))
batch_size = g.op('Gather', input_shape, const_0, axis_i=0)
channel = g.op('Gather', input_shape, const_1, axis_i=0)

# height = (height - 1) * stride + kernel_size
height = g.op(
'Gather',
input_shape,
g.op('Constant', value_t=torch.tensor(2)),
axis_i=0)
height = g.op('Sub', height, const_1)
height = g.op('Mul', height,
g.op('Constant', value_t=torch.tensor(stride[1])))
height = g.op('Add', height,
g.op('Constant', value_t=torch.tensor(kernel_size[1])))

# width = (width - 1) * stride + kernel_size
width = g.op(
'Gather',
input_shape,
g.op('Constant', value_t=torch.tensor(3)),
axis_i=0)
width = g.op('Sub', width, const_1)
width = g.op('Mul', width,
g.op('Constant', value_t=torch.tensor(stride[0])))
width = g.op('Add', width,
g.op('Constant', value_t=torch.tensor(kernel_size[0])))

# step of channel
channel_step = g.op('Mul', height, width)
# step of batch
batch_step = g.op('Mul', channel_step, channel)

# channel offset
range_channel = g.op('Range', const_0, channel, const_1)
range_channel = g.op(
'Reshape', range_channel,
g.op('Constant', value_t=torch.tensor([1, -1, 1, 1])))
range_channel = g.op('Mul', range_channel, channel_step)
range_channel = g.op('Cast', range_channel, to_i=7) # 7 is int64

# batch offset
range_batch = g.op('Range', const_0, batch_size, const_1)
range_batch = g.op(
'Reshape', range_batch,
g.op('Constant', value_t=torch.tensor([-1, 1, 1, 1])))
range_batch = g.op('Mul', range_batch, batch_step)
range_batch = g.op('Cast', range_batch, to_i=7) # 7 is int64

# update indices
indices = g.op('Add', indices, range_channel)
indices = g.op('Add', indices, range_batch)

return g.op(
'MaxUnpool',
input,
Expand Down

0 comments on commit 3c26d34

Please sign in to comment.