Skip to content

Commit

Permalink
move _register_coreml_op to python/tvm/relay/op/contrib/coreml.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kazum committed Jun 1, 2020
1 parent 2d77623 commit 958168b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 26 deletions.
26 changes: 0 additions & 26 deletions python/tvm/contrib/target/coreml.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

import tvm._ffi
from ...relay.expr_functor import ExprVisitor
from ...relay.expr import Constant
from ...relay import op as _op
from .. import xcode, coreml_runtime

def _convert_add(builder, name, inputs, outputs, args, attrs):
Expand Down Expand Up @@ -226,27 +224,3 @@ def coreml_compiler(ref):

ctx = tvm.cpu(0)
return coreml_runtime.create(model_dir, ctx).module


def _register_coreml_op(op_name):
"""Register a function to check the given operator is supported by Core ML.
Paramters
---------
op_name : Str
The name of operator that will be registered.
"""
def _check_supported(attrs, args):
if op_name == 'nn.conv2d':
if not isinstance(args[1], Constant):
return False
if attrs['kernel_layout'] not in ['HWIO', 'OIHW']:
return False
return True

_op.register(op_name, "target.coremlcompiler", _check_supported)


for op in _convert_map:
_register_coreml_op(op)
1 change: 1 addition & 0 deletions python/tvm/relay/op/contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from .register import get_pattern_table, register_pattern_table

from .dnnl import *
from .coreml import *
45 changes: 45 additions & 0 deletions python/tvm/relay/op/contrib/coreml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""CoreML codegen supported operators."""
from ... import op as _op
from ...expr import Constant
from ....contrib.target.coreml import _convert_map


def _register_coreml_op(op_name):
"""Register a function to check the given operator is supported by Core ML.
Paramters
---------
op_name : Str
The name of operator that will be registered.
"""
def _check_supported(attrs, args):
if op_name == 'nn.conv2d':
if not isinstance(args[1], Constant):
return False
if attrs['kernel_layout'] not in ['HWIO', 'OIHW']:
return False
return True

_op.register(op_name, "target.coremlcompiler", _check_supported)


for op in _convert_map:
_register_coreml_op(op)

0 comments on commit 958168b

Please sign in to comment.