Skip to content

Commit

Permalink
[CODEGEN][CONTRIB] CoreML codegen (apache#5634)
Browse files Browse the repository at this point in the history
* [CODEGEN][CONTRIB] CoreML codegen

* import coremltools only when it is necessary

* fix pylint errors

* don't import contrib.coreml when using runtime lib

* skip coreml codegen test in CI

* don't register relay.ext.coremlcompiler in __init__.py

* move tvm/contrib/coreml.py to tvm/contrib/target/coreml.py

* use existing transformers for graph partitioning

* skip test only when coremltools is not available

* add check for annotation

* move _register_coreml_op to python/tvm/relay/op/contrib/coreml.py

* skip compile when xcode is unavailable

* relay.op.Op -> tvm.ir.Op

* set USE_COREML on

* refine test
  • Loading branch information
kazum authored and Trevor Morris committed Jun 18, 2020
1 parent e6b2a08 commit 076723d
Show file tree
Hide file tree
Showing 11 changed files with 840 additions and 98 deletions.
171 changes: 171 additions & 0 deletions apps/ios_rpc/tests/ios_rpc_mobilenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm import rpc, relay
from tvm.contrib.download import download_testdata
from tvm.relay.expr_functor import ExprMutator
from tvm.relay import transform
from tvm.relay.op.annotation import compiler_begin, compiler_end
from tvm.relay.quantize.quantize import prerequisite_optimize
from tvm.contrib import util, xcode, graph_runtime, coreml_runtime
from tvm.contrib.target import coreml as _coreml

import os
import re
import sys
import numpy as np
from mxnet import gluon
from PIL import Image
import coremltools

# Set to be address of tvm proxy.
proxy_host = os.environ["TVM_IOS_RPC_PROXY_HOST"]
# Set your desination via env variable.
# Should in format "platform=iOS,id=<the test device uuid>"
destination = os.environ["TVM_IOS_RPC_DESTINATION"]

if not re.match(r"^platform=.*,id=.*$", destination):
print("Bad format: {}".format(destination))
print("Example of expected string: platform=iOS,id=1234567890abcabcabcabc1234567890abcabcab")
sys.exit(1)

proxy_port = 9090
key = "iphone"

# Change target configuration, this is setting for iphone6s
#arch = "x86_64"
#sdk = "iphonesimulator"
arch = "arm64"
sdk = "iphoneos"
target_host = "llvm -target=%s-apple-darwin" % arch

# override metal compiler to compile to iphone
@tvm.register_func("tvm_callback_metal_compile")
def compile_metal(src):
return xcode.compile_metal(src, sdk=sdk)

def prepare_input():
img_url = 'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true'
img_name = 'cat.png'
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
'4d0b62f3d01426887599d4f7ede23ee5/raw/',
'596b27d23537e5a1b5751d2b0481ef172f58b539/',
'imagenet1000_clsid_to_human.txt'])
synset_name = 'imagenet1000_clsid_to_human.txt'
img_path = download_testdata(img_url, 'cat.png', module='data')
synset_path = download_testdata(synset_url, synset_name, module='data')
with open(synset_path) as f:
synset = eval(f.read())
image = Image.open(img_path).resize((224, 224))

image = np.array(image) - np.array([123., 117., 104.])
image /= np.array([58.395, 57.12, 57.375])
image = image.transpose((2, 0, 1))
image = image[np.newaxis, :]
return image.astype('float32'), synset


def get_model(model_name, data_shape):
gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
# we want a probability so add a softmax operator
func = mod["main"]
func = relay.Function(func.params, relay.nn.softmax(func.body), None, func.type_params, func.attrs)

return func, params


def test_mobilenet():
temp = util.tempdir()
image, synset = prepare_input()
model, params = get_model('mobilenetv2_1.0', image.shape)

def run(mod, target):
with relay.build_config(opt_level=3):
graph, lib, _params = relay.build(mod, target=target,
target_host=target_host, params=params)
path_dso = temp.relpath("deploy.dylib")
lib.export_library(path_dso, xcode.create_dylib, arch=arch, sdk=sdk)
xcode.codesign(path_dso)

# Start RPC test server that contains the compiled library.
xcode.popen_test_rpc(proxy_host, proxy_port, key,
destination=destination, libs=[path_dso])

# connect to the proxy
remote = rpc.connect(proxy_host, proxy_port, key=key)

if target == "metal":
ctx = remote.metal(0)
else:
ctx = remote.cpu(0)
lib = remote.load_module("deploy.dylib")
m = graph_runtime.create(graph, lib, ctx)

m.set_input('data', tvm.nd.array(image, ctx))
m.set_input(**_params)
m.run()
tvm_output = m.get_output(0)
top1 = np.argmax(tvm_output.asnumpy()[0])
print('TVM prediction top-1:', top1, synset[top1])

# evaluate
ftimer = m.module.time_evaluator("run", ctx, number=3, repeat=10)
prof_res = np.array(ftimer().results) * 1000
print("%-19s (%s)" % ("%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)))

def annotate(func, compiler):
"""
An annotator for Core ML.
"""
# Bind free variables to the constant values.
bind_dict = {}
for arg in func.params:
name = arg.name_hint
if name in params:
bind_dict[arg] = relay.const(params[name])

func = relay.bind(func, bind_dict)

# Annotate the entire graph for Core ML
mod = tvm.IRModule()
mod["main"] = func

seq = tvm.transform.Sequential([
transform.SimplifyInference(),
transform.FoldConstant(),
transform.FoldScaleAxis(),
transform.AnnotateTarget(compiler),
transform.MergeCompilerRegions(),
transform.PartitionGraph()
])

with relay.build_config(opt_level=3):
mod = seq(mod)

return mod

# CPU
run(model, target_host)
# Metal
run(model, "metal")
# CoreML
run(annotate(model, "coremlcompiler"), target_host)

if __name__ == "__main__":
test_mobilenet()
10 changes: 4 additions & 6 deletions python/tvm/contrib/coreml_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@
import tvm._ffi
from ..rpc import base as rpc_base

def create(compiled_model_path, output_names, ctx):
def create(model_dir, ctx):
"""Create a runtime executor module given a coreml model and context.
Parameters
----------
compiled_model_path : str
The path of the compiled model to be deployed.
output_names : list of str
The output names of the model.
model_dir : str
The directory where the compiled models are located.
ctx : TVMContext
The context to deploy the module. It can be local or remote when there
is only one TVMContext.
Expand All @@ -42,7 +40,7 @@ def create(compiled_model_path, output_names, ctx):
else:
fcreate = tvm._ffi.get_global_func(runtime_func)

return CoreMLModule(fcreate(compiled_model_path, ctx, *output_names))
return CoreMLModule(fcreate(model_dir))


class CoreMLModule(object):
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/contrib/target/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Codegen and runtime APIs for targets.
"""
Loading

0 comments on commit 076723d

Please sign in to comment.