Skip to content

Commit

Permalink
[RELAY][FRONTEND][CAFFE2] add Mul and ConvTranspose operator (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Huacong Yang authored and Trevor Morris committed Apr 16, 2020
1 parent bb0856c commit 6321fa7
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions python/tvm/relay/frontend/caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ class Add(Elemwise):
name = 'add'


class Mul(Elemwise):
""" Operator converter for Mul.
"""
name = 'multiply'


class Pool(Caffe2OpConverter):
""" A helper class for pool op converters.
"""
Expand Down Expand Up @@ -233,6 +239,33 @@ def _impl(cls, inputs, args, params):
return out


class ConvTranspose(Caffe2OpConverter):
""" Operator converter for ConvTranspose.
"""

@classmethod
def _impl(cls, inputs, args, params):
# get number of channels
channels = infer_channels(inputs[1], True)
args['channels'] = channels
_clean_up_pool_args(args)
out = AttrCvt(
op_name=dimension_picker('conv', '_transpose'),
transforms={
'kernel_shape': 'kernel_size',
'pads': ('padding', (0, 0), revert_caffe2_pad),
'dilations': ('dilation', (1, 1)),
'order': ('data_layout', ("NCHW"), lambda x: x if isinstance(x, str) else x.decode('UTF-8')),
},
excludes=[],
ignores=_caffe2_internal_args,
custom_check=dimension_constraint())(inputs[:2], args, params)
use_bias = len(inputs) == 3
if use_bias:
out = _op.nn.bias_add(out, inputs[2])
return out


class Concat(Caffe2OpConverter):
""" Operator converter for Concat.
"""
Expand Down Expand Up @@ -353,12 +386,14 @@ def _get_convert_map():
# caffe2 common operators
'Add': Add.get_converter(),
'Sum': Sum.get_converter(),
'Mul': Mul.get_converter(),
'Softmax': Softmax.get_converter(),

# nn
'AveragePool': AveragePool.get_converter(),
'MaxPool': MaxPool.get_converter(),
'Conv': Conv.get_converter(),
'ConvTranspose': ConvTranspose.get_converter(),
'Concat': Concat.get_converter(),
'FC': FC.get_converter(),
'SpatialBN': SpatialBN.get_converter(),
Expand Down

0 comments on commit 6321fa7

Please sign in to comment.