diff --git a/apps/tf_tvmdsoop/tests/test_tfop_module.py b/apps/tf_tvmdsoop/tests/test_tfop_module.py index f2dee98ee01c..1672b58fd60a 100644 --- a/apps/tf_tvmdsoop/tests/test_tfop_module.py +++ b/apps/tf_tvmdsoop/tests/test_tfop_module.py @@ -61,7 +61,7 @@ def export_gpu_add_lib(): def test_add(session, lib_path, tf_device): """test add lib with TensorFlow wrapper""" - module = tf_op.Module(lib_path) + module = tf_op.OpModule(lib_path) left = tf.placeholder("float32", shape=[4]) right = tf.placeholder("float32", shape=[4]) diff --git a/python/tvm/contrib/tf_op/module.py b/python/tvm/contrib/tf_op/module.py index 446800c82a03..f67f715a1be6 100644 --- a/python/tvm/contrib/tf_op/module.py +++ b/python/tvm/contrib/tf_op/module.py @@ -67,16 +67,11 @@ def __init__(self, lib_path, func_name, output_dtype, output_shape): elif output_shape is not None: self.dynamic_output_shape = self._pack_shape_tensor(output_shape) - # delay op initialization to where Func.apply() get called first time - self.tvm_dso_op = None self.module = load_library.load_op_library('tvm_dso_op.so') + self.tvm_dso_op = self.module.tvm_dso_op def apply(self, *params): - if self.tvm_dso_op is None: - num_inputs = len(params) - self.tvm_dso_op = getattr(self.module, "tvm_dso_op%s" % num_inputs) - - return self.tvm_dso_op(*params, + return self.tvm_dso_op(params, dynamic_output_shape=self.dynamic_output_shape, static_output_shape=self.static_output_shape, has_static_output_shape=self.has_static_output_shape, diff --git a/src/contrib/tf_op/index_seq.h b/src/contrib/tf_op/index_seq.h deleted file mode 100644 index 5448c1f5f42d..000000000000 --- a/src/contrib/tf_op/index_seq.h +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*
- * Refer to std::index_sequence (since c++14)
- * Utilities to invoke variadic function with template
- */
-#ifndef TVM_CONTRIB_TF_OP_INDEX_SEQ_H_
-#define TVM_CONTRIB_TF_OP_INDEX_SEQ_H_
-
-template <unsigned int... Idx>
-struct IndexSeq {};
-
-template <unsigned int N, unsigned int... Tail>
-struct IndexSeqHelper : public IndexSeqHelper<N - 1, N - 1, Tail...> {};
-
-template <unsigned int... Tail>
-struct IndexSeqHelper<0U, Tail...> {
-  using type = IndexSeq<Tail...>;
-};
-
-template <unsigned int N>
-using make_index_sequence = typename IndexSeqHelper<N>::type;
-
-template <typename F, typename T, unsigned int N, unsigned int... Idx>
-void apply_variadic_impl(F f, T (&t)[N], IndexSeq<Idx...>) {
-  f(t[Idx]...);
-}
-
-template <typename F, typename T, unsigned int N>
-void apply_variadic(F f, T (&t)[N]) {
-  apply_variadic_impl(f, t, make_index_sequence<N>{});
-}
-
-template <typename F, typename T, unsigned int N, unsigned int... Idx>
-void apply_variadic_by_ptrs_impl(F f, T (&t)[N], IndexSeq<Idx...>) {
-  f(&t[Idx]...);
-}
-
-template <typename F, typename T, unsigned int N>
-void apply_variadic_by_ptrs(F f, T (&t)[N]) {
-  apply_variadic_by_ptrs_impl(f, t, make_index_sequence<N>{});
-}
-
-#endif  // TVM_CONTRIB_TF_OP_INDEX_SEQ_H_
diff --git a/src/contrib/tf_op/tvm_dso_op_kernels.cc b/src/contrib/tf_op/tvm_dso_op_kernels.cc
index 03024da9a478..d74d8fb917e5 100644
--- a/src/contrib/tf_op/tvm_dso_op_kernels.cc
+++ b/src/contrib/tf_op/tvm_dso_op_kernels.cc
@@ -26,7 +26,6 @@
 #include
 #include
-#include "index_seq.h"
 #include "tensorflow/core/framework/op_kernel.h"
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -37,6 +36,10 @@
 using tensorflow::OpKernel;
 using tensorflow::OpKernelConstruction;
 using tensorflow::OpKernelContext;
 
+using
tvm::runtime::TVMArgs;
+using tvm::runtime::TVMArgsSetter;
+using tvm::runtime::TVMRetValue;
+
 // Op utility trait for diffrent device type template
 template <typename DEVICE_TYPE>
 class TVMDSOOpTrait;
@@ -192,7 +195,7 @@ class TVMDSOOpTrait<GPUDevice> {
 };
 #endif
 
-template <typename DEVICE_TYPE, int NUM_INPUTS>
+template <typename DEVICE_TYPE>
 class TVMDSOOp : public OpKernel {
  private:
   tvm::runtime::PackedFunc tvm_func;
@@ -225,9 +228,12 @@ class TVMDSOOp : public OpKernel {
   }
 
   void Compute(tensorflow::OpKernelContext* context) override {
-    DLTensor args[NUM_INPUTS + 1];
-    TensorAsBuf buf_info[NUM_INPUTS];
-    ShapeContainer shapes[NUM_INPUTS];
+    // the last input is output shape spec
+    const int num_inputs = context->num_inputs() - 1;
+    const int num_total_args = num_inputs + 1;
+    std::vector<DLTensor> args(num_total_args);
+    std::vector<TensorAsBuf> buf_info(num_inputs);
+    std::vector<ShapeContainer> shapes(num_inputs);
     tensorflow::Status status;
 
     int device_id = TVMDSOOpTrait<DEVICE_TYPE>::device_id(context);
@@ -237,7 +243,7 @@
 
     // Get output shape
     tensorflow::TensorShape output_shape;
-    auto& output_shape_tensor = context->input(NUM_INPUTS);
+    auto& output_shape_tensor = context->input(num_inputs);
     if (has_static_output_shape) {
       // use static output shape
       const tensorflow::int64* dims = static_output_shape.data();
@@ -250,7 +256,7 @@
       output_shape = context->input(0).shape();
     }
 
-    for (int i = 0; i < NUM_INPUTS; ++i) {
+    for (int i = 0; i < num_inputs; ++i) {
       // Grab the input tensor
       auto& input_tensor = context->input(i);
@@ -279,32 +285,26 @@
     output.device_type = device_type;
     EnsureAlignment(context, *output_tensor, &output);
 
-    status = MakeDLTensor(output, dl_ctx, output_shape_ptr, &args[NUM_INPUTS]);
+    status = MakeDLTensor(output, dl_ctx, output_shape_ptr, &args[num_inputs]);
     OP_REQUIRES_OK(context, status);
 
-    apply_variadic_by_ptrs(tvm_func, args);
+    // Prepare PackedFunc arguments
+    std::vector<TVMValue> tvm_values(num_total_args);
+    std::vector<int> tvm_type_codes(num_total_args);
+    TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
+    for (int k = 0; k < num_total_args; ++k) {
+      setter(k, &args[k]);
+    }
+    TVMRetValue rv;
+    tvm_func.CallPacked(TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_total_args), &rv);
 
     output.CopyToOrigin();
   }
 };
 
 #ifdef TF_TVMDSOOP_ENABLE_GPU
-#define REGISTER_TFTVM_KERNEL(n)                                              \
-  REGISTER_KERNEL_BUILDER(Name("TvmDsoOp" #n).Device(tensorflow::DEVICE_CPU), \
-                          TVMDSOOp<CPUDevice, n>);                            \
-  REGISTER_KERNEL_BUILDER(Name("TvmDsoOp" #n).Device(tensorflow::DEVICE_GPU), \
-                          TVMDSOOp<GPUDevice, n>);
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_GPU), TVMDSOOp<GPUDevice>);
#else
-#define REGISTER_TFTVM_KERNEL(n)                                              \
-  REGISTER_KERNEL_BUILDER(Name("TvmDsoOp" #n).Device(tensorflow::DEVICE_CPU), \
-                          TVMDSOOp<CPUDevice, n>);
+REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
 #endif
-
-REGISTER_TFTVM_KERNEL(1)
-REGISTER_TFTVM_KERNEL(2)
-REGISTER_TFTVM_KERNEL(3)
-REGISTER_TFTVM_KERNEL(4)
-REGISTER_TFTVM_KERNEL(5)
-REGISTER_TFTVM_KERNEL(6)
-REGISTER_TFTVM_KERNEL(7)
-REGISTER_TFTVM_KERNEL(8)
diff --git a/src/contrib/tf_op/tvm_dso_ops.cc b/src/contrib/tf_op/tvm_dso_ops.cc
index f228313949cb..1183b2ef34b5 100644
--- a/src/contrib/tf_op/tvm_dso_ops.cc
+++ b/src/contrib/tf_op/tvm_dso_ops.cc
@@ -19,104 +19,13 @@
 #include "tensorflow/core/framework/op.h"
 
-#define REGISTER_TFTVM_OP(n)                                    \
-  REGISTER_OP("TvmDsoOp" #n)                                    \
-      .Output("output: output_dtype")                           \
-      .Attr("lib_path: string")                                 \
-      .Attr("func_name: string")                                \
-      .Attr("output_dtype: {int32, int64, float} = DT_FLOAT")   \
-      .Attr("static_output_shape: list(int) >= 0 = []")         \
-      .Attr("has_static_output_shape: bool")
-
-REGISTER_TFTVM_OP(1).Input("input: T").Attr("T: type").Input("dynamic_output_shape: int64");
-
-REGISTER_TFTVM_OP(2)
-    .Input("input1: T1")
-    .Attr("T1: type")
-    .Input("input2: T2")
-    .Attr("T2: type")
-    .Input("dynamic_output_shape: int64");
-
-REGISTER_TFTVM_OP(3)
-    .Input("input1: T1")
-    .Attr("T1: type")
-    .Input("input2: T2")
-    .Attr("T2: type")
-    .Input("input3: T3")
-    .Attr("T3: type")
-    .Input("dynamic_output_shape: int64");
-
-REGISTER_TFTVM_OP(4)
-    .Input("input1: T1")
-    .Attr("T1: type")
-    .Input("input2: T2")
-    .Attr("T2: type")
-    .Input("input3: T3")
-    .Attr("T3: type")
-    .Input("input4: T4")
-    .Attr("T4: type")
-    .Input("dynamic_output_shape: int64");
-
-REGISTER_TFTVM_OP(5)
-    .Input("input1: T1")
-    .Attr("T1: type")
-    .Input("input2: T2")
-    .Attr("T2: type")
-    .Input("input3: T3")
-    .Attr("T3: type")
-    .Input("input4: T4")
-    .Attr("T4: type")
-    .Input("input5: T5")
-    .Attr("T5: type")
-    .Input("dynamic_output_shape: int64");
-
-REGISTER_TFTVM_OP(6)
-    .Input("input1: T1")
-    .Attr("T1: type")
-    .Input("input2: T2")
-    .Attr("T2: type")
-    .Input("input3: T3")
-    .Attr("T3: type")
-    .Input("input4: T4")
-    .Attr("T4: type")
-    .Input("input5: T5")
-    .Attr("T5: type")
-    .Input("input6: T6")
-    .Attr("T6: type")
-    .Input("dynamic_output_shape: int64");
-
-REGISTER_TFTVM_OP(7)
-    .Input("input1: T1")
-    .Attr("T1: type")
-    .Input("input2: T2")
-    .Attr("T2: type")
-    .Input("input3: T3")
-    .Attr("T3: type")
-    .Input("input4: T4")
-    .Attr("T4: type")
-    .Input("input5: T5")
-    .Attr("T5: type")
-    .Input("input6: T6")
-    .Attr("T6: type")
-    .Input("input7: T7")
-    .Attr("T7: type")
-    .Input("dynamic_output_shape: int64");
-
-REGISTER_TFTVM_OP(8)
-    .Input("input1: T1")
-    .Attr("T1: type")
-    .Input("input2: T2")
-    .Attr("T2: type")
-    .Input("input3: T3")
-    .Attr("T3: type")
-    .Input("input4: T4")
-    .Attr("T4: type")
-    .Input("input5: T5")
-    .Attr("T5: type")
-    .Input("input6: T6")
-    .Attr("T6: type")
-    .Input("input7: T7")
-    .Attr("T7: type")
-    .Input("input8: T8")
-    .Attr("T8: type")
-    .Input("dynamic_output_shape: int64");
+REGISTER_OP("TvmDsoOp")
+    .Input("input_args: ListT")
+    .Attr("ListT: list({int8, int32, int64, float16, float32})")
+    .Input("dynamic_output_shape: int64")
+    .Output("output: output_dtype")
+    .Attr("lib_path: string")
+    .Attr("func_name: string")
+    .Attr("output_dtype: {int8, int32, int64, float16, float32} = DT_FLOAT")
+    .Attr("static_output_shape: list(int) >= 0 = []")
+    .Attr("has_static_output_shape: bool");