[RUNTIME] Implement TVMDSOOp(TensorFlow custom op) for TVM runtime #4459

Merged Apr 7, 2020
Changes from 5 commits
Commits
44 commits
73726a8
Add implementation of TVMDSOOp
tobegit3hub Dec 3, 2019
9c3f732
feat: Update cmake script to work with c++11 and in-repo build
Dec 4, 2019
8361217
feat: Use libtvm as oplib dependency
Dec 4, 2019
ef387fd
fix: Add missing link dependency to libtvm
Dec 4, 2019
41d9b38
Merge branch 'master' into add_tvmdsoop
tobegit3hub Dec 5, 2019
056d4b1
feat: Update tf tvmdso op by review comments
Jan 2, 2020
9fb271c
Fix conflict of master
tobegit3hub Jan 6, 2020
a6d9c28
Merge branch 'fix/add_tvmdsoop_fix_review' into 'feat/add_tvmdsoop'
Jan 6, 2020
69df549
fix: Update with pr comments
Jan 16, 2020
23ce5b2
Merge branch 'feat/UpdatePR' into 'feat/add_tvmdsoop'
Jan 16, 2020
c2e72a6
fix: Fix lint
Jan 16, 2020
e390c49
Merge branch 'fix/FixLint' into 'feat/add_tvmdsoop'
Jan 16, 2020
729a5dc
feat: Add test script and fix gpu shape
Mar 5, 2020
502b03f
Merge branch 'fix/AddTestCaseAndFixGpuShape' into 'feat/add_tvmdsoop'
Mar 5, 2020
7f1b5b3
feat: Add test script and fix gpu shape
Mar 5, 2020
0debcd6
fix: Conditional build tftvm op for gpu
Mar 6, 2020
9a9cf41
Merge branch 'feat/AddTfOpGpuConditionalBuild' into 'feat/add_tvmdsoop'
Mar 6, 2020
f5be2f3
fix: Conditional build tftvm op for gpu
Mar 6, 2020
7eae33d
fix: Fix pylint of tf_op module.py
Mar 9, 2020
b064685
Merge branch 'fix/FixTfdsoOpLint' into 'feat/add_tvmdsoop'
Mar 9, 2020
cd8fd80
fix: Fix pylint of tf_op module.py
Mar 9, 2020
1fc54ec
feat: Conditional enable gpu test for tftvm op
Mar 10, 2020
cfcdc41
Merge branch 'feat/EnableConditionalGpuTfOpTest' into 'feat/add_tvmds…
Mar 10, 2020
b4b9f96
feat: Conditional enable gpu test for tftvm op
Mar 10, 2020
74c3d3d
Merge branch upstream master into feat/add_tvmdsoop
Mar 16, 2020
864c4a5
feat: Add tf_tvmdsoop test script as an app test
Mar 16, 2020
29003e9
Merge branch 'feat/AddTFOpToIntegrateTests' into 'feat/add_tvmdsoop'
Mar 16, 2020
03fbe4c
Merge branch '4pd_add_tvmdsoop' into add_tvmdsoop
Mar 16, 2020
ec05511
fix: Fix gpu/cpu enabled check on tvm in test script
Mar 17, 2020
51ed779
fix: Make tf tvmdso op test script runnable with pytest
Mar 17, 2020
685f7d0
remove unused test script test_tfop_module.py
Mar 17, 2020
ea6328b
fix: Remove pushd & popd in tfdsoop test script
Mar 18, 2020
0ae0942
fix: Upgrade tftvmop use python3 to find TensorFlow
Mar 18, 2020
9594700
fix: Upgrade tftvmop use python3 to find TensorFlow
Mar 18, 2020
8a5d2fb
fix: Change target_link_options to target_link_libraries
Mar 18, 2020
b8fbd2e
fix: Add tftvmop build script's c++ option
Mar 19, 2020
380e1d7
fix: Add tvm library path to tf op test library path
Mar 19, 2020
0b3884d
fix: Debug ci build for tftvm dso op
Mar 19, 2020
9dc89c6
Merge branch 'master' into add_tvmdsoop
tobegit3hub Mar 25, 2020
9fd18b8
fix: Fix cmake error and skip tfop test
Apr 1, 2020
9967748
Merge branch 'add_tvmdsoop' of github.com:4paradigm/incubator-tvm int…
Apr 1, 2020
8ac182f
fix: Fix typo and indentation issues
Apr 2, 2020
9343669
feat: Use TF list input op def
Apr 3, 2020
38af1ce
fix: Fix style and unexpected changes
Apr 7, 2020
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -42,6 +42,7 @@ tvm_option(USE_MSVC_MT "Build with MT" OFF)
tvm_option(USE_MICRO "Build with Micro" OFF)
tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF)
tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF)
tvm_option(USE_TFOP "Build with TensorFlow TVMDSOOp" OFF)

# 3rdparty libraries
tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include")
@@ -260,6 +261,7 @@ include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/HybridDump.cmake)
include(cmake/modules/contrib/TFLite.cmake)
include(cmake/modules/contrib/TFOP.cmake)

if(NOT MSVC)
include(CheckCXXCompilerFlag)
3 changes: 3 additions & 0 deletions cmake/config.cmake
@@ -193,3 +193,6 @@ set(USE_VTA_FPGA OFF)

# Whether to build the example external runtime module
set(USE_EXAMPLE_EXT_RUNTIME OFF)

# Whether to build the TensorFlow TVMDSOOp module
set(USE_TFOP OFF)
46 changes: 46 additions & 0 deletions cmake/modules/contrib/TFOP.cmake
@@ -0,0 +1,46 @@

if(NOT USE_TFOP STREQUAL "OFF")

# To build this module standalone, uncomment the lines below.
# if ("${TVM_HOME}" STREQUAL "")
# message(FATAL_ERROR "TVM_HOME is not defined")
# else()
# message("Use TVM_HOME=\"${TVM_HOME}\"")
#endif()
# include_directories(${TVM_HOME}/include)
# include_directories(${TVM_HOME}/3rdparty/dlpack/include)
# include_directories(${TVM_HOME}/3rdparty/dmlc-core/include)
# set(TFTVM_LINK_FLAGS -ltvm_runtime -L${TVM_HOME}/build)

execute_process(COMMAND python -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_compile_flags()))"
OUTPUT_VARIABLE TF_COMPILE_FLAGS_STR
RESULT_VARIABLE TF_STATUS)
if (NOT ${TF_STATUS} EQUAL 0)
message(FATAL_ERROR "Fail to get TensorFlow compile flags")
endif()

execute_process(COMMAND python -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_link_flags()))"
OUTPUT_VARIABLE TF_LINK_FLAGS_STR
RESULT_VARIABLE TF_STATUS)
if (NOT ${TF_STATUS} EQUAL 0)
message(FATAL_ERROR "Fail to get TensorFlow link flags")
endif()

string(REGEX REPLACE "\n" " " TF_COMPILE_FLAGS_STR "${TF_COMPILE_FLAGS_STR}")
string(REGEX REPLACE "\n" " " TF_LINK_FLAGS_STR "${TF_LINK_FLAGS_STR}")
separate_arguments(TF_COMPILE_FLAGS UNIX_COMMAND ${TF_COMPILE_FLAGS_STR})
separate_arguments(TF_LINK_FLAGS UNIX_COMMAND ${TF_LINK_FLAGS_STR})


set(OP_LIBRARY_NAME tvm_dso_op)
file(GLOB_RECURSE TFTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/tf_op/*.cc)
add_library(${OP_LIBRARY_NAME} SHARED ${TFTVM_SRCS})
set_target_properties(${OP_LIBRARY_NAME} PROPERTIES PREFIX "")
set(TFTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR})
add_dependencies(${OP_LIBRARY_NAME} tvm)

set(TFTVM_COMPILE_FLAGS -O2 -ldl -g)
target_compile_options(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_COMPILE_FLAGS} ${TF_COMPILE_FLAGS})
target_link_options(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_LINK_FLAGS} ${TF_LINK_FLAGS})

endif()
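The module is only built when `USE_TFOP` is switched to `ON` in `cmake/config.cmake`, and it relies on TensorFlow reporting its own compile and link flags. As a quick sanity check, the same queries that the `execute_process` calls above perform can be run by hand with the Python interpreter CMake will find; a minimal sketch (the paths shown in the comments are only indicative):

```python
# Print the compile and link flags that TFOP.cmake queries from TensorFlow.
import tensorflow as tf

print(" ".join(tf.sysconfig.get_compile_flags()))
# e.g. -I/usr/lib/python3/dist-packages/tensorflow/include -D_GLIBCXX_USE_CXX11_ABI=0
print(" ".join(tf.sysconfig.get_link_flags()))
# e.g. -L/usr/lib/python3/dist-packages/tensorflow -l:libtensorflow_framework.so.1
```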

20 changes: 20 additions & 0 deletions python/tvm/contrib/tf_op/__init__.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from . import module

Module = module.Module
104 changes: 104 additions & 0 deletions python/tvm/contrib/tf_op/module.py
@@ -0,0 +1,104 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tensorflow as tf
from tensorflow.python.framework import load_library


class Module():
    """Wrapper around a TVM-exported dynamic library, exposing its functions as TensorFlow-callable Func objects."""

def __init__(self, lib_path):
self.lib_path = lib_path

def func(self, name, output_dtype=None, output_shape=None):
return Func(self.lib_path, name, output_dtype, output_shape)

def __getitem__(self, func_name):
return self.func(func_name)


class Func():
    """Callable wrapper for a single function in the TVM library, executed through the TVMDSOOp custom op."""

def __init__(self, lib_path, func_name, output_dtype, output_shape):
self.lib_path = lib_path
self.func_name = func_name
self.output_dtype = output_dtype

# constant 0 marks the dynamic output shape as unset
self.dynamic_output_shape = tf.constant(0, tf.int64)
self.static_output_shape = None
self.has_static_output_shape = False # extra flag is required

if self._is_static_shape(output_shape):
self.static_output_shape = output_shape
self.has_static_output_shape = True
elif output_shape is not None:
self.dynamic_output_shape = self._pack_shape_tensor(output_shape)

# TODO: support non-xpu device
# self.device = device
# Delay resolving the concrete tvm_dso_op until the first call, when the number of input arguments is known.
self.tvm_dso_op = None
self.module = load_library.load_op_library('tvm_dso_op.so')

def apply(self, *params):
if self.tvm_dso_op is None:
num_inputs = len(params)
self.tvm_dso_op = getattr(self.module, "tvm_dso_op%s" % num_inputs)

return self.tvm_dso_op(*params,
dynamic_output_shape=self.dynamic_output_shape,
static_output_shape=self.static_output_shape,
has_static_output_shape=self.has_static_output_shape,
lib_path=self.lib_path,
func_name=self.func_name,
output_dtype=self.output_dtype)

def __call__(self, *params):
return self.apply(*params)

def _is_static_shape(self, shape):
if shape is None or not isinstance(shape, list):
return False
for d in shape:
if not isinstance(d, int):
return False
if d < 0:
    raise ValueError("Negative dimension is illegal: %d" % d)
return True

def _pack_shape_tensor(self, shape):
if isinstance(shape, tf.Tensor):
if shape.dtype == tf.int32:
shape = tf.cast(shape, tf.int64)
return shape
elif isinstance(shape, list):
shape_dims = []
for d in shape:
if isinstance(d, int):
shape_dims.append(tf.constant(d, tf.int64))
elif isinstance(d, tf.Tensor) and len(d.shape) == 0:
if d.dtype == tf.int32:
d = tf.cast(d, tf.int64)
shape_dims.append(d)
else:
raise TypeError("Input shape dimension is neither scala tensor nor int")
tobegit3hub marked this conversation as resolved.
Show resolved Hide resolved
return tf.stack(shape_dims)
else:
raise TypeError("Input shape is neither tensor nor list")



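To make the intended workflow concrete, here is a minimal sketch of how `tf_op.Module` and `Func` above are meant to be used: a TVM kernel is compiled and exported as a shared library, then wrapped and called on TensorFlow tensors. It assumes TVM was built with `USE_TFOP ON` so that `tvm_dso_op.so` is discoverable by `load_op_library`; the kernel name `vector_add`, the file name `tvm_add.so`, and the exact spelling of `output_dtype` are illustrative rather than taken from this PR.

```python
import tensorflow as tf
import tvm
from tvm import te
from tvm.contrib import tf_op

# Build a trivial element-wise add kernel with TVM and export it as a shared library.
n = te.var("n")
A = te.placeholder((n,), name="A", dtype="float32")
B = te.placeholder((n,), name="B", dtype="float32")
C = te.compute(A.shape, lambda i: A[i] + B[i], name="C")
sched = te.create_schedule(C.op)
tvm.build(sched, [A, B, C], target="llvm", name="vector_add").export_library("tvm_add.so")

# Wrap the exported library; each named function becomes callable on TF tensors.
mod = tf_op.Module("tvm_add.so")
add = mod.func("vector_add", output_shape=[4], output_dtype="float")  # illustrative parameters

x = tf.constant([1.0, 2.0, 3.0, 4.0])
y = tf.constant([5.0, 6.0, 7.0, 8.0])
print(add(x, y))  # expected to print [6. 8. 10. 12.]
```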
44 changes: 44 additions & 0 deletions src/contrib/tf_op/index_seq.h
@@ -0,0 +1,44 @@
/**
 * Minimal backport of std::index_sequence (available since C++14) for C++11 builds,
 * plus utilities to invoke a variadic function with the elements of a fixed-size array.
 */
#ifndef TVM_CONTRIB_TF_OP_INDEX_SEQ_H_
#define TVM_CONTRIB_TF_OP_INDEX_SEQ_H_

#include <cstddef>

template <std::size_t ...>
struct IndexSeq {};

template <std::size_t N, std::size_t ... Tail>
struct IndexSeqHelper : public IndexSeqHelper<N-1U, N-1U, Tail...> {};

template <std::size_t ... Tail>
struct IndexSeqHelper<0U, Tail ...> {
using type = IndexSeq<Tail ...>;
};

template <std::size_t N>
using make_index_sequence = typename IndexSeqHelper<N>::type;


template <typename F, typename T, std::size_t N, std::size_t... Idx>
void apply_variadic_impl(F f, T(&t)[N], IndexSeq<Idx...>) {
f(t[Idx]...);
}

// Invoke f(t[0], t[1], ..., t[N-1]) for a fixed-size array t of length N.
template <typename F, typename T, std::size_t N>
void apply_variadic(F f, T(&t)[N]) {
apply_variadic_impl(f, t, make_index_sequence<N>{});
}

template <typename F, typename T, std::size_t N, std::size_t... Idx>
void apply_variadic_by_ptrs_impl(F f, T(&t)[N], IndexSeq<Idx...>) {
f(&t[Idx]...);
}

// Invoke f(&t[0], &t[1], ..., &t[N-1]), passing pointers to the array elements.
template <typename F, typename T, std::size_t N>
void apply_variadic_by_ptrs(F f, T(&t)[N]) {
apply_variadic_by_ptrs_impl(f, t, make_index_sequence<N>{});
}

#endif
