Skip to content

Commit

Permalink
feat: Use TF list input op def
Browse files Browse the repository at this point in the history
  • Loading branch information
baoxinqi authored and wrongtest committed Apr 3, 2020
1 parent 8ac182f commit 9343669
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 196 deletions.
2 changes: 1 addition & 1 deletion apps/tf_tvmdsoop/tests/test_tfop_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def export_gpu_add_lib():

def test_add(session, lib_path, tf_device):
"""test add lib with TensorFlow wrapper"""
module = tf_op.Module(lib_path)
module = tf_op.OpModule(lib_path)

left = tf.placeholder("float32", shape=[4])
right = tf.placeholder("float32", shape=[4])
Expand Down
9 changes: 2 additions & 7 deletions python/tvm/contrib/tf_op/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,11 @@ def __init__(self, lib_path, func_name, output_dtype, output_shape):
elif output_shape is not None:
self.dynamic_output_shape = self._pack_shape_tensor(output_shape)

# delay op initialization to where Func.apply() get called first time
self.tvm_dso_op = None
self.module = load_library.load_op_library('tvm_dso_op.so')
self.tvm_dso_op = self.module.tvm_dso_op

def apply(self, *params):
if self.tvm_dso_op is None:
num_inputs = len(params)
self.tvm_dso_op = getattr(self.module, "tvm_dso_op%s" % num_inputs)

return self.tvm_dso_op(*params,
return self.tvm_dso_op(params,
dynamic_output_shape=self.dynamic_output_shape,
static_output_shape=self.static_output_shape,
has_static_output_shape=self.has_static_output_shape,
Expand Down
61 changes: 0 additions & 61 deletions src/contrib/tf_op/index_seq.h

This file was deleted.

52 changes: 26 additions & 26 deletions src/contrib/tf_op/tvm_dso_op_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

#include "index_seq.h"
#include "tensorflow/core/framework/op_kernel.h"

typedef Eigen::ThreadPoolDevice CPUDevice;
Expand All @@ -37,6 +36,10 @@ using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;

using tvm::runtime::TVMArgs;
using tvm::runtime::TVMArgsSetter;
using tvm::runtime::TVMRetValue;

// Op utility trait for diffrent device type template
template <typename DEVICE_TYPE>
class TVMDSOOpTrait;
Expand Down Expand Up @@ -192,7 +195,7 @@ class TVMDSOOpTrait<GPUDevice> {
};
#endif

template <typename DEVICE_TYPE, int NUM_INPUTS>
template <typename DEVICE_TYPE>
class TVMDSOOp : public OpKernel {
private:
tvm::runtime::PackedFunc tvm_func;
Expand Down Expand Up @@ -225,9 +228,12 @@ class TVMDSOOp : public OpKernel {
}

void Compute(tensorflow::OpKernelContext* context) override {
DLTensor args[NUM_INPUTS + 1];
TensorAsBuf buf_info[NUM_INPUTS];
ShapeContainer shapes[NUM_INPUTS];
// the last input is output shape spec
const int num_inputs = context->num_inputs() - 1;
const int num_total_args = num_inputs + 1;
std::vector<DLTensor> args(num_total_args);
std::vector<TensorAsBuf> buf_info(num_inputs);
std::vector<ShapeContainer> shapes(num_inputs);

tensorflow::Status status;
int device_id = TVMDSOOpTrait<DEVICE_TYPE>::device_id(context);
Expand All @@ -237,7 +243,7 @@ class TVMDSOOp : public OpKernel {

// Get output shape
tensorflow::TensorShape output_shape;
auto& output_shape_tensor = context->input(NUM_INPUTS);
auto& output_shape_tensor = context->input(num_inputs);
if (has_static_output_shape) {
// use static output shape
const tensorflow::int64* dims = static_output_shape.data();
Expand All @@ -250,7 +256,7 @@ class TVMDSOOp : public OpKernel {
output_shape = context->input(0).shape();
}

for (int i = 0; i < NUM_INPUTS; ++i) {
for (int i = 0; i < num_inputs; ++i) {
// Grab the input tensor
auto& input_tensor = context->input(i);

Expand Down Expand Up @@ -279,32 +285,26 @@ class TVMDSOOp : public OpKernel {
output.device_type = device_type;
EnsureAlignment(context, *output_tensor, &output);

status = MakeDLTensor(output, dl_ctx, output_shape_ptr, &args[NUM_INPUTS]);
status = MakeDLTensor(output, dl_ctx, output_shape_ptr, &args[num_inputs]);
OP_REQUIRES_OK(context, status);

apply_variadic_by_ptrs(tvm_func, args);
// Prepare PackedFunc arguments
std::vector<TVMValue> tvm_values(num_total_args);
std::vector<int> tvm_type_codes(num_total_args);
TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
for (int k = 0; k < num_total_args; ++k) {
setter(k, &args[k]);
}
TVMRetValue rv;
tvm_func.CallPacked(TVMArgs(tvm_values.data(), tvm_type_codes.data(), num_total_args), &rv);

output.CopyToOrigin();
}
};

#ifdef TF_TVMDSOOP_ENABLE_GPU
#define REGISTER_TFTVM_KERNEL(n) \
REGISTER_KERNEL_BUILDER(Name("TvmDsoOp" #n).Device(tensorflow::DEVICE_CPU), \
TVMDSOOp<CPUDevice, n>); \
REGISTER_KERNEL_BUILDER(Name("TvmDsoOp" #n).Device(tensorflow::DEVICE_GPU), \
TVMDSOOp<GPUDevice, n>);
REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_GPU), TVMDSOOp<GPUDevice>);
#else
#define REGISTER_TFTVM_KERNEL(n) \
REGISTER_KERNEL_BUILDER(Name("TvmDsoOp" #n).Device(tensorflow::DEVICE_CPU), \
TVMDSOOp<CPUDevice, n>);
REGISTER_KERNEL_BUILDER(Name("TvmDsoOp").Device(tensorflow::DEVICE_CPU), TVMDSOOp<CPUDevice>);
#endif

REGISTER_TFTVM_KERNEL(1)
REGISTER_TFTVM_KERNEL(2)
REGISTER_TFTVM_KERNEL(3)
REGISTER_TFTVM_KERNEL(4)
REGISTER_TFTVM_KERNEL(5)
REGISTER_TFTVM_KERNEL(6)
REGISTER_TFTVM_KERNEL(7)
REGISTER_TFTVM_KERNEL(8)
111 changes: 10 additions & 101 deletions src/contrib/tf_op/tvm_dso_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,104 +19,13 @@

#include "tensorflow/core/framework/op.h"

#define REGISTER_TFTVM_OP(n) \
REGISTER_OP("TvmDsoOp" #n) \
.Output("output: output_dtype") \
.Attr("lib_path: string") \
.Attr("func_name: string") \
.Attr("output_dtype: {int32, int64, float} = DT_FLOAT") \
.Attr("static_output_shape: list(int) >= 0 = []") \
.Attr("has_static_output_shape: bool")

REGISTER_TFTVM_OP(1).Input("input: T").Attr("T: type").Input("dynamic_output_shape: int64");

REGISTER_TFTVM_OP(2)
.Input("input1: T1")
.Attr("T1: type")
.Input("input2: T2")
.Attr("T2: type")
.Input("dynamic_output_shape: int64");

REGISTER_TFTVM_OP(3)
.Input("input1: T1")
.Attr("T1: type")
.Input("input2: T2")
.Attr("T2: type")
.Input("input3: T3")
.Attr("T3: type")
.Input("dynamic_output_shape: int64");

REGISTER_TFTVM_OP(4)
.Input("input1: T1")
.Attr("T1: type")
.Input("input2: T2")
.Attr("T2: type")
.Input("input3: T3")
.Attr("T3: type")
.Input("input4: T4")
.Attr("T4: type")
.Input("dynamic_output_shape: int64");

REGISTER_TFTVM_OP(5)
.Input("input1: T1")
.Attr("T1: type")
.Input("input2: T2")
.Attr("T2: type")
.Input("input3: T3")
.Attr("T3: type")
.Input("input4: T4")
.Attr("T4: type")
.Input("input5: T5")
.Attr("T5: type")
.Input("dynamic_output_shape: int64");

REGISTER_TFTVM_OP(6)
.Input("input1: T1")
.Attr("T1: type")
.Input("input2: T2")
.Attr("T2: type")
.Input("input3: T3")
.Attr("T3: type")
.Input("input4: T4")
.Attr("T4: type")
.Input("input5: T5")
.Attr("T5: type")
.Input("input6: T6")
.Attr("T6: type")
.Input("dynamic_output_shape: int64");

REGISTER_TFTVM_OP(7)
.Input("input1: T1")
.Attr("T1: type")
.Input("input2: T2")
.Attr("T2: type")
.Input("input3: T3")
.Attr("T3: type")
.Input("input4: T4")
.Attr("T4: type")
.Input("input5: T5")
.Attr("T5: type")
.Input("input6: T6")
.Attr("T6: type")
.Input("input7: T7")
.Attr("T7: type")
.Input("dynamic_output_shape: int64");

REGISTER_TFTVM_OP(8)
.Input("input1: T1")
.Attr("T1: type")
.Input("input2: T2")
.Attr("T2: type")
.Input("input3: T3")
.Attr("T3: type")
.Input("input4: T4")
.Attr("T4: type")
.Input("input5: T5")
.Attr("T5: type")
.Input("input6: T6")
.Attr("T6: type")
.Input("input7: T7")
.Attr("T7: type")
.Input("input8: T8")
.Attr("T8: type")
.Input("dynamic_output_shape: int64");
REGISTER_OP("TvmDsoOp")
.Input("input_args: ListT")
.Attr("ListT: list({int8, int32, int64, float16, float32})")
.Input("dynamic_output_shape: int64")
.Output("output: output_dtype")
.Attr("lib_path: string")
.Attr("func_name: string")
.Attr("output_dtype: {int8, int32, int64, float16, float32} = DT_FLOAT")
.Attr("static_output_shape: list(int) >= 0 = []")
.Attr("has_static_output_shape: bool");

0 comments on commit 9343669

Please sign in to comment.