Skip to content

Commit

Permalink
add dilations field to onnx importer
Browse files Browse the repository at this point in the history
blacking files

black file
  • Loading branch information
AndrewZhaoLuo committed Apr 29, 2021
1 parent 7f85111 commit 634253a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
30 changes: 21 additions & 9 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""ONNX: Open Neural Network Exchange frontend for Relay."""
import copy
import warnings

import numpy as np
import tvm
from tvm.ir import IRModule
Expand All @@ -28,16 +29,23 @@
from .. import analysis
from .. import expr as _expr
from .. import function as _function
from .. import loops as _loops
from .. import op as _op
from .. import qnn as _qnn
from .. import vision as _vision
from .. import loops as _loops
from .. import ty as _ty

from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels, infer_value, fold_constant
from .common import infer_type, get_name

from .. import vision as _vision
from .common import (
AttrCvt,
Renamer,
fold_constant,
get_name,
get_relay_op,
infer_channels,
infer_shape,
infer_type,
infer_value,
new_var,
)

__all__ = ["from_onnx"]

Expand Down Expand Up @@ -312,8 +320,12 @@ def _impl_v1(cls, inputs, attr, params):

return AttrCvt(
op_name=dimension_picker(cls.name),
transforms={"kernel_shape": "pool_size", "pads": ("padding", 0)},
ignores=["dilations", "storage_order"],
transforms={
"kernel_shape": "pool_size",
"pads": ("padding", 0),
"dilations": ("dilation", 1),
},
ignores=["storage_order"],
custom_check=dimension_constraint(),
)([data], attr, params)

Expand Down
16 changes: 8 additions & 8 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
import onnx
from onnx import helper, TensorProto, mapping, numpy_helper
import pytest
import scipy
import torch
import torchvision
import pytest
import tvm.topi.testing
import tvm
import tvm.testing
import tvm.topi.testing
from tvm import relay
from tvm.contrib import graph_executor
import scipy
import tvm.testing

import onnx
from onnx import TensorProto, helper, mapping, numpy_helper


def get_input_data_shape_dict(graph_def, input_data):
Expand Down Expand Up @@ -2696,7 +2697,7 @@ def repeat(N, D):

@tvm.testing.uses_gpu
def test_unsqueeze_constant():
from torch.nn import Linear, Sequential, Module
from torch.nn import Linear, Module, Sequential

class Flatten(Module):
def forward(self, input):
Expand Down Expand Up @@ -4212,7 +4213,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
"test_isinf_negative/",
"test_isinf_positive/",
"test_matmulinteger/",
"test_maxpool_2d_dilations/",
"test_maxpool_2d_same_lower/",
"test_maxpool_2d_same_upper/",
"test_maxpool_with_argmax_2d_precomputed_pads/",
Expand Down

0 comments on commit 634253a

Please sign in to comment.