Skip to content

Commit

Permalink
[RUNTIME][CONTRIB] CoreML Runtime (apache#5283)
Browse files Browse the repository at this point in the history
* [RUNTIME][CONTRIB] CoreML Runtime

* fix lint

* fix CI

* use xcrun to compile coreml model
  • Loading branch information
kazum authored and Trevor Morris committed Jun 18, 2020
1 parent f837857 commit b15194f
Show file tree
Hide file tree
Showing 9 changed files with 507 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
tvm_option(USE_CPP_RPC "Build CPP RPC" OFF)
tvm_option(USE_TFLITE "Build with tflite support" OFF)
tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none)
tvm_option(USE_COREML "Build with coreml support" OFF)

if(USE_CPP_RPC AND UNIX)
message(FATAL_ERROR "USE_CPP_RPC is only supported with WIN32. Use the Makefile for non-Windows.")
Expand Down Expand Up @@ -316,6 +317,7 @@ include(cmake/modules/contrib/TensorRT.cmake)
include(cmake/modules/contrib/HybridDump.cmake)
include(cmake/modules/contrib/TFLite.cmake)
include(cmake/modules/contrib/TF_TVMDSOOP.cmake)
include(cmake/modules/contrib/CoreML.cmake)

if(NOT MSVC)
include(CheckCXXCompilerFlag)
Expand Down
2 changes: 1 addition & 1 deletion apps/ios_rpc/tvmrpc.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@
);
runOnlyForDeploymentPostprocessing = 0;
shellPath = /bin/sh;
shellScript = "libpath=${CONFIGURATION_BUILD_DIR}/${CONTENTS_FOLDER_PATH}/Frameworks/tvm\nmkdir -p ${libpath}\nrm -rf ${libpath}/*\n \nif [ -f ${SRCROOT}/rpc_config.txt ]; then\n head -n 1 ${SRCROOT}/rpc_config.txt > ${libpath}/rpc_config.txt\n tail -n +2 ${SRCROOT}/rpc_config.txt | xargs -J % cp % ${libpath}\nfi\n\n";
shellScript = "libpath=${CONFIGURATION_BUILD_DIR}/${CONTENTS_FOLDER_PATH}/Frameworks/tvm\nmkdir -p ${libpath}\nrm -rf ${libpath}/*\n \nif [ -f ${SRCROOT}/rpc_config.txt ]; then\n head -n 1 ${SRCROOT}/rpc_config.txt > ${libpath}/rpc_config.txt\n tail -n +2 ${SRCROOT}/rpc_config.txt | xargs -J % cp -r % ${libpath}\nfi\n\n";
};
/* End PBXShellScriptBuildPhase section */

Expand Down
2 changes: 2 additions & 0 deletions apps/ios_rpc/tvmrpc/TVMRuntime.mm
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
// Metal
#include "../../../src/runtime/metal/metal_module.mm"
#include "../../../src/runtime/metal/metal_device_api.mm"
// CoreML
#include "../../../src/runtime/contrib/coreml/coreml_runtime.mm"

namespace dmlc {
// Override logging mechanism
Expand Down
25 changes: 25 additions & 0 deletions cmake/modules/contrib/CoreML.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(USE_COREML)
  message(STATUS "Build with contrib.coreml")
  find_library(FOUNDATION_LIB Foundation)
  # Framework names are case-sensitive on case-sensitive file systems;
  # Apple spells the framework "CoreML", not "Coreml".
  find_library(COREML_LIB CoreML)
  # Fail early with a clear message instead of an obscure link error.
  if(NOT FOUNDATION_LIB OR NOT COREML_LIB)
    message(FATAL_ERROR "USE_COREML requires the Foundation and CoreML frameworks (macOS/iOS SDK).")
  endif()
  file(GLOB COREML_CONTRIB_SRC src/runtime/contrib/coreml/*.mm)
  list(APPEND TVM_RUNTIME_LINKER_LIBS ${FOUNDATION_LIB} ${COREML_LIB})
  list(APPEND RUNTIME_SRCS ${COREML_CONTRIB_SRC})
endif(USE_COREML)
71 changes: 71 additions & 0 deletions python/tvm/contrib/coreml_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""CoreML runtime that load and run coreml models."""
import tvm._ffi
from ..rpc import base as rpc_base

def create(compiled_model_path, output_names, ctx):
    """Create a runtime executor module given a coreml model and context.

    Parameters
    ----------
    compiled_model_path : str
        The path of the compiled model to be deployed.

    output_names : list of str
        The output names of the model.

    ctx : TVMContext
        The context to deploy the module. It can be local or remote when there
        is only one TVMContext.

    Returns
    -------
    coreml_runtime : CoreMLModule
        Runtime coreml module that can be used to execute the coreml model.
    """
    runtime_func = "tvm.coreml_runtime.create"

    # A remote context resolves the creator through its RPC session; a local
    # one resolves it from the process-global function registry.
    fcreate = (
        ctx._rpc_sess.get_function(runtime_func)
        if ctx.device_type >= rpc_base.RPC_SESS_MASK
        else tvm._ffi.get_global_func(runtime_func)
    )

    return CoreMLModule(fcreate(compiled_model_path, ctx, *output_names))


class CoreMLModule(object):
    """Wrapper runtime module.

    This is a thin wrapper of the underlying TVM module.
    you can also directly call set_input, run, and get_output
    of underlying module functions

    Parameters
    ----------
    module : Module
        The internal tvm module that holds the actual coreml functions.

    Attributes
    ----------
    module : Module
        The internal tvm module that holds the actual coreml functions.
    """

    def __init__(self, module):
        self.module = module
        # Expose each packed function of the underlying module as a direct
        # attribute so callers can invoke them without indexing the module.
        for fname in ("invoke", "set_input", "get_output", "get_num_outputs"):
            setattr(self, fname, module[fname])
11 changes: 11 additions & 0 deletions python/tvm/contrib/xcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,17 @@ def compile_metal(code, path_target=None, sdk="macosx"):
return libbin


def compile_coreml(model, out_dir=".", model_name="tmp"):
    """Compile coreml model and return the compiled model path.

    Parameters
    ----------
    model : coremltools MLModel
        The model to be saved as ``<model_name>.mlmodel`` and compiled.
    out_dir : str, optional
        Directory where the model and its compiled form are written.
    model_name : str, optional
        Basename (without extension) used for both the saved .mlmodel file
        and the compiled .mlmodelc directory. Defaults to "tmp" for
        backward compatibility.

    Returns
    -------
    str
        Path of the compiled ``<model_name>.mlmodelc`` directory.

    Raises
    ------
    RuntimeError
        If the compiler did not produce the expected .mlmodelc output.
    """
    mlmodel_path = os.path.join(out_dir, model_name + ".mlmodel")
    mlmodelc_path = os.path.join(out_dir, model_name + ".mlmodelc")
    model.save(mlmodel_path)

    # coremlcompiler writes <model_name>.mlmodelc into out_dir.
    xcrun(["coremlcompiler", "compile", mlmodel_path, out_dir])

    # xcrun does not reliably surface a failure; verify the output exists
    # instead of silently returning a path to nothing.
    if not os.path.isdir(mlmodelc_path):
        raise RuntimeError("Compiled model %s does not exist." % mlmodelc_path)

    return mlmodelc_path


class XCodeRPCServer(object):
"""Wrapper for RPC server
Expand Down
116 changes: 116 additions & 0 deletions src/runtime/contrib/coreml/coreml_runtime.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief CoreML runtime that can run coreml model
* containing only tvm PackedFunc.
* \file coreml_runtime.h
*/
#ifndef TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_
#define TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_

#import <Foundation/Foundation.h>
#import <CoreML/CoreML.h>

#include <dlpack/dlpack.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

#include <vector>
#include <string>
#include <memory>

namespace tvm {
namespace runtime {

/*!
 * \brief CoreML runtime.
 *
 * This runtime can be accessed in various language via
 * TVM runtime PackedFunc API.
 */
class CoreMLRuntime : public ModuleNode {
 public:
  /*!
   * \brief Get member function to front-end.
   * \param name The name of the function.
   * \param sptr_to_self The pointer to the module node.
   * \return The corresponding member function.
   */
  virtual PackedFunc GetFunction(const std::string& name,
                                 const ObjectPtr<Object>& sptr_to_self);

  /*!
   * \return The type key of the executor.
   */
  const char* type_key() const {
    return "CoreMLRuntime";
  }

  /*!
   * \brief Invoke the coreml prediction.
   */
  void Invoke();

  /*!
   * \brief Initialize the coreml runtime with coreml model and context.
   * \param model_path The compiled model path.
   * \param ctx The context where the coreml model will be executed on.
   * \param output_names The output names of the model.
   */
  void Init(const std::string& model_path,
            TVMContext ctx,
            const std::vector<NSString *>& output_names);

  /*!
   * \brief set input to the model.
   * \param key The input name.
   * \param data_in The input data.
   */
  void SetInput(const std::string& key, DLTensor* data_in);
  /*!
   * \brief Return NDArray for given output index.
   * \param index The output index.
   *
   * \return NDArray corresponding to given output node index.
   */
  NDArray GetOutput(int index) const;
  /*!
   * \brief Return the number of outputs
   *
   * \return The number of outputs
   */
  int GetNumOutputs() const;

  // NOTE(review): the members below are public with no accessors; the
  // corresponding .mm implementation owns their lifetimes — confirm there.

  // CoreML model loaded by Init() from the compiled model path.
  MLModel *model_;
  // CoreML model input dictionary; keys are input names set via SetInput(),
  // values are presumably CoreML feature values — verify against the .mm.
  NSMutableDictionary<NSString *, id> *input_dict_;
  // CoreML model output of the most recent Invoke().
  id<MLFeatureProvider> output_;
  // List of output names, indexed by GetOutput()'s output index.
  std::vector<NSString *> output_names_;
  // TVM context the model was initialized with.
  TVMContext ctx_;
};

} // namespace runtime
} // namespace tvm

#endif // TVM_RUNTIME_CONTRIB_COREML_COREML_RUNTIME_H_
Loading

0 comments on commit b15194f

Please sign in to comment.