diff --git a/CMakeLists.txt b/CMakeLists.txt
index 74288029d020..6993f6727871 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -41,6 +41,7 @@ tvm_option(USE_MSVC_MT "Build with MT" OFF)
 tvm_option(USE_MICRO "Build with Micro" OFF)
 tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF)
 tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF)
+tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF)
 
 # 3rdparty libraries
 tvm_option(DLPACK_PATH "Path to DLPACK" "3rdparty/dlpack/include")
@@ -259,6 +260,7 @@ include(cmake/modules/contrib/Sort.cmake)
 include(cmake/modules/contrib/NNPack.cmake)
 include(cmake/modules/contrib/HybridDump.cmake)
 include(cmake/modules/contrib/TFLite.cmake)
+include(cmake/modules/contrib/TF_TVMDSOOP.cmake)
 
 if(NOT MSVC)
   include(CheckCXXCompilerFlag)
diff --git a/apps/tf_tvmdsoop/CMakeLists.txt b/apps/tf_tvmdsoop/CMakeLists.txt
new file mode 100644
index 000000000000..cb601ef6d30d
--- /dev/null
+++ b/apps/tf_tvmdsoop/CMakeLists.txt
@@ -0,0 +1,34 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+cmake_minimum_required(VERSION 3.2)
+project(tf_tvmdsoop C CXX)
+
+set(TFTVM_COMPILE_FLAGS -std=c++11)
+set(BUILD_TVMDSOOP_ONLY ON)
+set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT})
+set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build)
+
+include_directories(${TVM_ROOT}/3rdparty/dlpack/include/)
+include_directories(${TVM_ROOT}/3rdparty/dmlc-core/include/)
+include_directories(${TVM_ROOT}/include)
+
+link_directories(${TVM_ROOT}/build)
+
+include(${TVM_ROOT}/cmake/util/FindCUDA.cmake)
+include(${TVM_ROOT}/cmake/modules/CUDA.cmake)
+
+include(${TVM_ROOT}/cmake/modules/contrib/TF_TVMDSOOP.cmake)
diff --git a/apps/tf_tvmdsoop/prepare_and_test_tfop_module.sh b/apps/tf_tvmdsoop/prepare_and_test_tfop_module.sh
new file mode 100644
index 000000000000..2bde4f87c84e
--- /dev/null
+++ b/apps/tf_tvmdsoop/prepare_and_test_tfop_module.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+TVM_ROOT=$(cd $(dirname $0)/../..; pwd)
+echo "TVM_ROOT=${TVM_ROOT}"
+
+export PYTHONPATH=${TVM_ROOT}/python
+
+python3 -c "import tvm; print(tvm.runtime.enabled('gpu'))" | grep -e 1
+if [ "$?" -eq 0 ]; then
+  echo "Build TF_TVMDSOOP with gpu support and execute tests"
+  CMAKE_OPTIONS="-DUSE_CUDA=ON -DPython3_EXECUTABLE=python3 -DTVM_ROOT=${TVM_ROOT}"
+
+  mkdir -p build
+  cd build; cmake .. ${CMAKE_OPTIONS} && make
+  cd ..
+
+  LD_LIBRARY_PATH=${TVM_ROOT}/build:./build:$LD_LIBRARY_PATH python3 -m pytest -v ./tests
+fi
+
diff --git a/apps/tf_tvmdsoop/tests/test_tfop_module.py b/apps/tf_tvmdsoop/tests/test_tfop_module.py
new file mode 100644
index 000000000000..1672b58fd60a
--- /dev/null
+++ b/apps/tf_tvmdsoop/tests/test_tfop_module.py
@@ -0,0 +1,118 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test script for tf op module"""
+import tempfile
+import os
+import logging
+import tensorflow as tf
+import numpy as np
+import tvm
+from tvm import te
+from tvm.contrib import tf_op
+
+
+def test_use_tvmdso_op():
+    """main test function"""
+
+    def export_cpu_add_lib():
+        """create cpu add op lib"""
+        n = te.var("n")
+        ph_a = te.placeholder((n,), name='ph_a')
+        ph_b = te.placeholder((n,), name='ph_b')
+        ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
+        sched = te.create_schedule(ph_c.op)
+        fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "c", name="vector_add")
+        lib_path = tempfile.mktemp("tvm_add_dll.so")
+        fadd_dylib.export_library(lib_path)
+        return lib_path
+
+
+    def export_gpu_add_lib():
+        """create gpu add op lib"""
+        n = te.var("n")
+        ph_a = te.placeholder((n,), name='ph_a')
+        ph_b = te.placeholder((n,), name='ph_b')
+        ph_c = te.compute(ph_a.shape, lambda i: ph_a[i] + ph_b[i], name='ph_c')
+        sched = te.create_schedule(ph_c.op)
+        b_axis, t_axis = sched[ph_c].split(ph_c.op.axis[0], factor=64)
+        sched[ph_c].bind(b_axis, te.thread_axis("blockIdx.x"))
+        sched[ph_c].bind(t_axis, te.thread_axis("threadIdx.x"))
+        fadd_dylib = tvm.build(sched, [ph_a, ph_b, ph_c], "cuda", name="vector_add")
+        lib_path = tempfile.mktemp("tvm_add_cuda_dll.so")
+        fadd_dylib.export_library(lib_path)
+        return lib_path
+
+
+    def test_add(session, lib_path, tf_device):
+        """test add lib with TensorFlow wrapper"""
+        module = tf_op.OpModule(lib_path)
+
+        left = tf.placeholder("float32", shape=[4])
+        right = tf.placeholder("float32", shape=[4])
+
+        feed_dict = {left: [1.0, 2.0, 3.0, 4.0], right: [5.0, 6.0, 7.0, 8.0]}
+        expect = np.asarray([6.0, 8.0, 10.0, 12.0])
+
+        add1 = module.func("vector_add", output_shape=[4], output_dtype="float")
+        add2 = module.func("vector_add", output_shape=tf.shape(left), output_dtype="float")
+        add3 = module.func("vector_add", output_shape=[tf.shape(left)[0]], output_dtype="float")
+
+        with tf.device(tf_device):
+            output1 = session.run(add1(left, right), feed_dict)
+            np.testing.assert_equal(output1, expect)
+
+            output2 = session.run(add2(left, right), feed_dict)
+            np.testing.assert_equal(output2, expect)
+
+            output3 = session.run(add3(left, right), feed_dict)
+            np.testing.assert_equal(output3, expect)
+
+
+    def cpu_test(session):
+        """test function for cpu"""
+        cpu_lib = None
+        try:
+            cpu_lib = export_cpu_add_lib()
+            test_add(session, cpu_lib, "/cpu:0")
+        finally:
+            if cpu_lib is not None:
+                os.remove(cpu_lib)
+
+
+    def gpu_test(session):
+        """test function for gpu"""
+        gpu_lib = None
+        try:
+            gpu_lib = export_gpu_add_lib()
+            test_add(session, gpu_lib, "/gpu:0")
+        finally:
+            if gpu_lib is not None:
+                os.remove(gpu_lib)
+
+    with tf.Session() as session:
+        if tvm.runtime.enabled("cpu"):
+            logging.info("Test TensorFlow op on cpu kernel")
+            cpu_test(session)
+        if tvm.runtime.enabled("gpu"):
+            logging.info("Test TensorFlow op on gpu kernel")
+            gpu_test(session)
+
+
+if __name__ == "__main__":
+    test_use_tvmdso_op()
diff --git a/cmake/config.cmake b/cmake/config.cmake
index 6ab362c899a3..04b6c3b66ea4 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -204,3 +204,7 @@ set(USE_EXAMPLE_EXT_RUNTIME OFF)
 
 # Whether use Thrust
 set(USE_THRUST OFF)
+
+# Whether to build the TensorFlow TVMDSOOp module
+set(USE_TF_TVMDSOOP OFF)
+
diff --git a/cmake/modules/contrib/TF_TVMDSOOP.cmake b/cmake/modules/contrib/TF_TVMDSOOP.cmake
new file mode 100644
index 000000000000..e92822a397ae
--- /dev/null
+++ b/cmake/modules/contrib/TF_TVMDSOOP.cmake
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(NOT USE_TF_TVMDSOOP STREQUAL "OFF")
+  find_package(Python3 COMPONENTS Interpreter)
+
+  execute_process(COMMAND ${Python3_EXECUTABLE} -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_compile_flags()))"
+    OUTPUT_VARIABLE TF_COMPILE_FLAGS_STR
+    RESULT_VARIABLE TF_STATUS)
+  if (NOT ${TF_STATUS} EQUAL 0)
+    message(FATAL_ERROR "Failed to get TensorFlow compile flags")
+  endif()
+
+  if(NOT USE_CUDA STREQUAL "OFF")
+    add_definitions(-DTF_TVMDSOOP_ENABLE_GPU)
+  endif()
+
+  execute_process(COMMAND ${Python3_EXECUTABLE} -c "import tensorflow as tf; print(' '.join(tf.sysconfig.get_link_flags()))"
+    OUTPUT_VARIABLE TF_LINK_FLAGS_STR
+    RESULT_VARIABLE TF_STATUS)
+  if (NOT ${TF_STATUS} EQUAL 0)
+    message(FATAL_ERROR "Failed to get TensorFlow link flags")
+  endif()
+
+  string(REGEX REPLACE "\n" " " TF_COMPILE_FLAGS_STR "${TF_COMPILE_FLAGS_STR}")
+  string(REGEX REPLACE "\n" " " TF_LINK_FLAGS_STR "${TF_LINK_FLAGS_STR}")
+  separate_arguments(TF_COMPILE_FLAGS UNIX_COMMAND ${TF_COMPILE_FLAGS_STR})
+  separate_arguments(TF_LINK_FLAGS UNIX_COMMAND ${TF_LINK_FLAGS_STR})
+
+
+  set(OP_LIBRARY_NAME tvm_dso_op)
+  file(GLOB_RECURSE TFTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/tf_op/*.cc)
+  add_library(${OP_LIBRARY_NAME} SHARED ${TFTVM_SRCS})
+  set_target_properties(${OP_LIBRARY_NAME} PROPERTIES PREFIX "")
+  set(TFTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR})
+
+  if (NOT BUILD_TVMDSOOP_ONLY STREQUAL "ON")
+    add_dependencies(${OP_LIBRARY_NAME} tvm)
+  endif()
+
+  target_compile_options(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_COMPILE_FLAGS} ${TF_COMPILE_FLAGS})
+  target_link_libraries(${OP_LIBRARY_NAME} PUBLIC ${TFTVM_LINK_FLAGS} ${TF_LINK_FLAGS})
+
+endif()
+
diff --git a/python/tvm/contrib/tf_op/__init__.py b/python/tvm/contrib/tf_op/__init__.py
new file mode 100644
index 000000000000..05d0ecc1ddc1
--- /dev/null
+++ b/python/tvm/contrib/tf_op/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Module container of TensorFlow TVMDSO op"""
+from . import module
+
+OpModule = module.OpModule
diff --git a/python/tvm/contrib/tf_op/module.py b/python/tvm/contrib/tf_op/module.py
new file mode 100644
index 000000000000..f13670e39895
--- /dev/null
+++ b/python/tvm/contrib/tf_op/module.py
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Module container of TensorFlow TVMDSO op"""
+import tensorflow as tf
+from tensorflow.python.framework import load_library
+
+
+class OpModule:
+    """Module container of TensorFlow TVMDSO op which wraps an exported
+    TVM op implementation library so it can be called on the TensorFlow side"""
+
+    def __init__(self, lib_path):
+        self.lib_path = lib_path
+
+    def func(self, name, output_dtype=None, output_shape=None):
+        """Get a TVM op function wrapped as a TensorFlow tensor-to-tensor function
+
+        Parameters
+        ----------
+        name: str
+            function name
+        output_dtype: str or TensorFlow datatype
+            Output datatype, defaults to float32
+        output_shape: list of integers/tf scalar tensors, or a tf shape tensor
+            Output shape, defaults to the shape of the first input
+
+        Returns
+        -------
+        Func object that acts as a TensorFlow tensor-to-tensor function.
+        """
+        return TensorFunc(self.lib_path, name, output_dtype, output_shape)
+
+    def __getitem__(self, func_name):
+        return self.func(func_name)
+
+
+class TensorFunc:
+    """Function object that acts as a TensorFlow tensor-to-tensor function."""
+
+    def __init__(self, lib_path, func_name, output_dtype, output_shape):
+        self.lib_path = lib_path
+        self.func_name = func_name
+        self.output_dtype = output_dtype
+
+        # const(0) indicates an invalid dynamic shape
+        self.dynamic_output_shape = tf.constant(0, tf.int64)
+        self.static_output_shape = None
+        self.has_static_output_shape = False  # extra flag is required
+
+        if self._is_static_shape(output_shape):
+            self.static_output_shape = output_shape
+            self.has_static_output_shape = True
+        elif output_shape is not None:
+            self.dynamic_output_shape = self._pack_shape_tensor(output_shape)
+
+        self.module = load_library.load_op_library('tvm_dso_op.so')
+        self.tvm_dso_op = self.module.tvm_dso_op
+
+    def apply(self, *params):
+        return self.tvm_dso_op(params,
+                               dynamic_output_shape=self.dynamic_output_shape,
+                               static_output_shape=self.static_output_shape,
+                               has_static_output_shape=self.has_static_output_shape,
+                               lib_path=self.lib_path,
+                               func_name=self.func_name,
+                               output_dtype=self.output_dtype)
+
+    def __call__(self, *params):
+        return self.apply(*params)
+
+    def _is_static_shape(self, shape):
+        if shape is None or not isinstance(shape, list):
+            return False
+        for dim_value in shape:
+            if not isinstance(dim_value, int):
+                return False
+            if dim_value < 0:
+                raise Exception("Negative dimension is illegal: %d" % dim_value)
+        return True
+
+    def _pack_shape_tensor(self, shape):
+        if isinstance(shape, tf.Tensor):
+            if shape.dtype == tf.int32:
+                shape = tf.cast(shape, tf.int64)
+        elif isinstance(shape, list):
+            shape_dims = []
+            for dim_value in shape:
+                if isinstance(dim_value, int):
+                    shape_dims.append(tf.constant(dim_value, tf.int64))
+                elif isinstance(dim_value, tf.Tensor) and dim_value.shape.rank == 0:
+                    if dim_value.dtype == tf.int32:
+                        dim_value = tf.cast(dim_value, tf.int64)
+                    shape_dims.append(dim_value)
+                else:
+                    raise TypeError("Input shape dimension is neither a scalar tensor nor an int")
+            shape = tf.stack(shape_dims)
+        else:
+            raise TypeError("Input shape is neither a tensor nor a list")
+        return shape
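For reference, a minimal usage sketch of the wrapper above, mirroring the tests earlier in this diff. It assumes a TVM library already exported with function name "vector_add" and that tvm_dso_op.so is discoverable on the loader path; the library path is hypothetical:

    import tensorflow as tf
    from tvm.contrib import tf_op

    # hypothetical path to a library produced by tvm.build(...).export_library(...)
    module = tf_op.OpModule("/tmp/tvm_add_dll.so")
    # static output shape; module["vector_add"] would use the defaults instead
    add = module.func("vector_add", output_shape=[4], output_dtype="float")

    left = tf.placeholder("float32", shape=[4])
    right = tf.placeholder("float32", shape=[4])
    with tf.Session() as sess:
        result = sess.run(add(left, right),
                          {left: [1.0, 2.0, 3.0, 4.0], right: [5.0, 6.0, 7.0, 8.0]})
        # result == [6.0, 8.0, 10.0, 12.0]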
diff --git a/src/contrib/tf_op/tvm_dso_op_kernels.cc b/src/contrib/tf_op/tvm_dso_op_kernels.cc
new file mode 100644
index 000000000000..d74d8fb917e5
--- /dev/null
+++ b/src/contrib/tf_op/tvm_dso_op_kernels.cc
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+#include <cuda_runtime.h>
+#endif
+#include <dlpack/dlpack.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+typedef tensorflow::gtl::InlinedVector<tensorflow::int64, 4> ShapeContainer;
+
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+
+using tvm::runtime::TVMArgs;
+using tvm::runtime::TVMArgsSetter;
+using tvm::runtime::TVMRetValue;
+
+// Op utility trait for different device types
+template <typename DEVICE_TYPE>
+class TVMDSOOpTrait;
+
+// Buffer information used for actual computation.
+// Each buffer is associated with one TensorFlow tensor
+// whose underlying buffer is recorded into "origin_buf".
+// For an input tensor, we copy data from origin_buf to buf;
+// for an output tensor, we copy data from buf to origin_buf.
+class TensorAsBuf {
+ public:
+  tensorflow::Tensor inline_tensor;
+  tensorflow::Tensor* tensor;
+
+  size_t size;
+  size_t offset;
+
+  int device_type;
+
+  char* origin_buf;
+  char* buf;
+
+  void CopyToOrigin() {
+    if (buf == origin_buf) {
+      return;
+    }
+    if (device_type == kDLCPU) {
+      memcpy(origin_buf, buf + offset, size);
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+    } else if (device_type == kDLGPU) {
+      cudaMemcpy(origin_buf, buf + offset, size, cudaMemcpyDeviceToDevice);
+#endif
+    } else {
+      LOG(FATAL) << "Only support CPU and CUDA now. Device " << device_type
+                 << " is not implemented currently";
+    }
+  }
+
+  void CopyFromOrigin() {
+    if (buf == origin_buf) {
+      return;
+    }
+    if (device_type == kDLCPU) {
+      memcpy(buf + offset, origin_buf, size);
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+    } else if (device_type == kDLGPU) {
+      cudaMemcpy(buf + offset, origin_buf, size, cudaMemcpyDeviceToDevice);
+#endif
+    } else {
Device " << device_type + << " is not implemented currently"; + } + } +}; + +tensorflow::Status GetDLPackDtype(const tensorflow::Tensor& tf_tensor, DLDataType* res) { + auto dtype = tf_tensor.dtype(); + if (dtype == tensorflow::DT_FLOAT) { + *res = {kDLFloat, 32, 1}; + } else if (dtype == tensorflow::DT_INT64) { + *res = {kDLInt, 64, 1}; + } else if (dtype == tensorflow::DT_INT32) { + *res = {kDLInt, 32, 1}; + } else { + return tensorflow::Status(tensorflow::error::INTERNAL, "Fail to get dlpack datatype"); + } + return tensorflow::Status::OK(); +} + +// Ensure buffer used for actual computation take 64byte alignment +void EnsureAlignment(OpKernelContext* ctx, const tensorflow::Tensor& tensor, TensorAsBuf* out) { + char* buf = const_cast(tensor.tensor_data().data()); + out->origin_buf = buf; + out->size = tensor.TotalBytes(); + + int alignment = 64; + char* aligned = reinterpret_cast(((uint64_t)buf + alignment - 1) & (~(alignment - 1))); + if (buf == aligned) { + out->tensor = const_cast(&tensor); + out->buf = buf; + out->offset = 0; + } else { + tensorflow::TensorShape buf_shape; + tensorflow::int64 dims[1] = {(tensorflow::int64)(tensor.TotalBytes() + alignment)}; + tensorflow::TensorShapeUtils::MakeShape(dims, 1, &buf_shape); + + out->tensor = &out->inline_tensor; + ctx->allocate_temp(tensor.dtype(), buf_shape, out->tensor); + + buf = const_cast(out->tensor->tensor_data().data()); + char* buf_aligned = reinterpret_cast(((uint64_t)buf + alignment) & (~(alignment - 1))); + out->buf = buf; + out->offset = buf_aligned - buf; + } +} + +// Create DLPack tensor from TensorFlow tensor +tensorflow::Status MakeDLTensor(const TensorAsBuf& src, const DLContext& ctx, int64_t* tf_shape, + DLTensor* out) { + DLDataType dlpack_type; + const tensorflow::Tensor& tensor = *src.tensor; + + auto status = GetDLPackDtype(tensor, &dlpack_type); + if (!status.ok()) { + return status; + } + out->ctx = ctx; + out->ndim = tensor.shape().dims(); + out->shape = tf_shape; + out->strides = nullptr; + out->byte_offset = 0; + out->dtype = dlpack_type; + out->data = src.buf + src.offset; + return tensorflow::Status::OK(); +} + +template <> +class TVMDSOOpTrait { + public: + static const int device_type = kDLCPU; + + static int device_id(OpKernelContext* context) { return 0; } + + static void make_shape_from_tensor(const tensorflow::Tensor& shape_tensor, + tensorflow::TensorShape* output_shape) { + tensorflow::int64 num_dims = shape_tensor.NumElements(); + const tensorflow::int64* dims = shape_tensor.flat().data(); + tensorflow::TensorShapeUtils::MakeShape(dims, num_dims, output_shape); + } +}; + +#ifdef TF_TVMDSOOP_ENABLE_GPU +template <> +class TVMDSOOpTrait { + public: + static const int device_type = kDLGPU; + + static int device_id(OpKernelContext* context) { + auto device_base = context->device(); + auto gpu_device_info = device_base->tensorflow_gpu_device_info(); + return gpu_device_info->gpu_id; + } + + static void make_shape_from_tensor(const tensorflow::Tensor& shape_tensor, + tensorflow::TensorShape* output_shape) { + tensorflow::int64 num_dims = shape_tensor.NumElements(); + const tensorflow::int64* flat = shape_tensor.flat().data(); + tensorflow::int64* dims = new tensorflow::int64[num_dims]; + cudaMemcpy(dims, flat, sizeof(tensorflow::int64) * num_dims, cudaMemcpyDeviceToHost); + tensorflow::TensorShapeUtils::MakeShape(dims, num_dims, output_shape); + delete dims; + } +}; +#endif + +template +class TVMDSOOp : public OpKernel { + private: + tvm::runtime::PackedFunc tvm_func; + std::string lib_path; + 
+  std::string func_name;
+
+  tensorflow::DataType output_dtype;
+
+  bool has_static_output_shape;
+  std::vector<tensorflow::int64> static_output_shape;
+
+  void initAttributes(OpKernelConstruction* context) {
+    context->GetAttr("lib_path", &lib_path);
+    context->GetAttr("func_name", &func_name);
+    context->GetAttr("output_dtype", &output_dtype);
+
+    context->GetAttr("has_static_output_shape", &has_static_output_shape);
+    context->GetAttr("static_output_shape", &static_output_shape);
+  }
+
+ public:
+  explicit TVMDSOOp(OpKernelConstruction* context) : OpKernel(context) {
+    // Get attrs
+    initAttributes(context);
+
+    // Load the TVM function from the dynamic library
+    tvm::runtime::Module mod_dylib = tvm::runtime::Module::LoadFromFile(lib_path);
+    tvm_func = mod_dylib.GetFunction(func_name);
+    CHECK(tvm_func != nullptr);
+  }
+
+  void Compute(tensorflow::OpKernelContext* context) override {
+    // the last input is the output shape spec
+    const int num_inputs = context->num_inputs() - 1;
+    const int num_total_args = num_inputs + 1;
+    std::vector<DLTensor> args(num_total_args);
+    std::vector<TensorAsBuf> buf_info(num_inputs);
+    std::vector<ShapeContainer> shapes(num_inputs);
+
+    tensorflow::Status status;
+    int device_id = TVMDSOOpTrait<DEVICE_TYPE>::device_id(context);
+    int device_type = TVMDSOOpTrait<DEVICE_TYPE>::device_type;
+
+    DLContext dl_ctx = {DLDeviceType(device_type), device_id};
+
+    // Get the output shape
+    tensorflow::TensorShape output_shape;
+    auto& output_shape_tensor = context->input(num_inputs);
+    if (has_static_output_shape) {
+      // use the static output shape
+      const tensorflow::int64* dims = static_output_shape.data();
+      tensorflow::TensorShapeUtils::MakeShape(dims, static_output_shape.size(), &output_shape);
+    } else if (output_shape_tensor.dims() == 1) {
+      // use shape tensor values as the output shape
+      TVMDSOOpTrait<DEVICE_TYPE>::make_shape_from_tensor(output_shape_tensor, &output_shape);
+    } else {
+      // use the first input tensor's shape by default
+      output_shape = context->input(0).shape();
+    }
+
+    for (int i = 0; i < num_inputs; ++i) {
+      // Grab the input tensor
+      auto& input_tensor = context->input(i);
+
+      // Create the shape container; its reference must stay alive during execution
+      shapes[i] = input_tensor.shape().dim_sizes();
+      auto shape_ptr = reinterpret_cast<int64_t*>(shapes[i].data());
+
+      TensorAsBuf& input = buf_info[i];
+      input.device_type = device_type;
+
+      EnsureAlignment(context, input_tensor, &input);
+      input.CopyFromOrigin();
+
+      status = MakeDLTensor(input, dl_ctx, shape_ptr, &args[i]);
+      OP_REQUIRES_OK(context, status);
+    }
+
+    // Allocate the output tensor
+    tensorflow::Tensor* output_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));
+    // the shape dimension buffer should be kept alive on the stack
+    auto output_shape_dim_buf = output_tensor->shape().dim_sizes();
+    auto output_shape_ptr = reinterpret_cast<int64_t*>(output_shape_dim_buf.data());
+
+    TensorAsBuf output;
+    output.device_type = device_type;
+    EnsureAlignment(context, *output_tensor, &output);
+
+    status = MakeDLTensor(output, dl_ctx, output_shape_ptr, &args[num_inputs]);
+    OP_REQUIRES_OK(context, status);
+
+    // Prepare PackedFunc arguments
+    std::vector<TVMValue> tvm_values(num_total_args);
+    std::vector<int> tvm_type_codes(num_total_args);
+    TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
+    for (int k = 0; k < num_total_args; ++k) {
+      setter(k, &args[k]);
+    }
+    TVMRetValue rv;
+    tvm_func.CallPacked(TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_total_args), &rv);
+
+    output.CopyToOrigin();
+  }
+};
+
+#ifdef TF_TVMDSOOP_ENABLE_GPU
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_GPU), TVMDSOOp<GPUDevice>);
+#else
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
+#endif
diff --git a/src/contrib/tf_op/tvm_dso_ops.cc b/src/contrib/tf_op/tvm_dso_ops.cc
new file mode 100644
index 000000000000..1183b2ef34b5
--- /dev/null
+++ b/src/contrib/tf_op/tvm_dso_ops.cc
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "tensorflow/core/framework/op.h"
+
+REGISTER_OP("TvmDsoOp")
+    .Input("input_args: ListT")
+    .Attr("ListT: list({int8, int32, int64, float16, float32})")
+    .Input("dynamic_output_shape: int64")
+    .Output("output: output_dtype")
+    .Attr("lib_path: string")
+    .Attr("func_name: string")
+    .Attr("output_dtype: {int8, int32, int64, float16, float32} = DT_FLOAT")
+    .Attr("static_output_shape: list(int) >= 0 = []")
+    .Attr("has_static_output_shape: bool");
diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh
index 5c00fd9c8896..dcd8139abd81 100755
--- a/tests/scripts/task_python_integration.sh
+++ b/tests/scripts/task_python_integration.sh
@@ -53,6 +53,9 @@ cd ../..
 TVM_FFI=cython python3 -m pytest -v apps/dso_plugin_module
 TVM_FFI=ctypes python3 -m pytest -v apps/dso_plugin_module
+# Do not enable the TensorFlow op for now
+# TVM_FFI=cython sh prepare_and_test_tfop_module.sh
+# TVM_FFI=ctypes sh prepare_and_test_tfop_module.sh
 TVM_FFI=ctypes python3 -m pytest -v tests/python/integration
 TVM_FFI=ctypes python3 -m pytest -v tests/python/contrib
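To make the REGISTER_OP signature above concrete: TensorFunc.apply in module.py effectively issues the raw call sketched below, where the trailing dynamic_output_shape input and the static_output_shape/has_static_output_shape attrs select one of the three output-shape paths handled in TVMDSOOp::Compute. This is a sketch only; the input tensors and library path are hypothetical, and tvm_dso_op.so is assumed to be on the loader path:

    import tensorflow as tf
    from tensorflow.python.framework import load_library

    lib = load_library.load_op_library("tvm_dso_op.so")
    x = tf.constant([1.0, 2.0, 3.0, 4.0])
    y = tf.constant([5.0, 6.0, 7.0, 8.0])
    out = lib.tvm_dso_op([x, y],                                        # input_args: ListT
                         dynamic_output_shape=tf.constant(0, tf.int64),  # ignored: static shape set
                         static_output_shape=[4],
                         has_static_output_shape=True,
                         lib_path="/tmp/tvm_add_dll.so",                 # hypothetical exported lib
                         func_name="vector_add",
                         output_dtype="float")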