diff --git a/src/contrib/tf_op/tvm_dso_op_kernels.cc b/src/contrib/tf_op/tvm_dso_op_kernels.cc index d74d8fb917e5..705a3347b68c 100644 --- a/src/contrib/tf_op/tvm_dso_op_kernels.cc +++ b/src/contrib/tf_op/tvm_dso_op_kernels.cc @@ -97,12 +97,29 @@ class TensorAsBuf { tensorflow::Status GetDLPackDtype(const tensorflow::Tensor& tf_tensor, DLDataType* res) { auto dtype = tf_tensor.dtype(); - if (dtype == tensorflow::DT_FLOAT) { + + if (dtype == tensorflow::DT_HALF) { + *res = {kDLFloat, 16, 1}; + } else if (dtype == tensorflow::DT_FLOAT) { *res = {kDLFloat, 32, 1}; - } else if (dtype == tensorflow::DT_INT64) { - *res = {kDLInt, 64, 1}; + } else if (dtype == tensorflow::DT_DOUBLE) { + *res = {kDLFloat, 64, 1}; + } else if (dtype == tensorflow::DT_INT8) { + *res = {kDLInt, 8, 1}; + } else if (dtype == tensorflow::DT_INT16) { + *res = {kDLInt, 16, 1}; } else if (dtype == tensorflow::DT_INT32) { *res = {kDLInt, 32, 1}; + } else if (dtype == tensorflow::DT_INT64) { + *res = {kDLInt, 64, 1}; + } else if (dtype == tensorflow::DT_UINT8) { + *res = {kDLUInt, 8, 1}; + } else if (dtype == tensorflow::DT_UINT16) { + *res = {kDLUInt, 16, 1}; + } else if (dtype == tensorflow::DT_UINT32) { + *res = {kDLUInt, 32, 1}; + } else if (dtype == tensorflow::DT_UINT64) { + *res = {kDLUInt, 64, 1}; } else { return tensorflow::Status(tensorflow::error::INTERNAL, "Fail to get dlpack datatype"); } diff --git a/src/contrib/tf_op/tvm_dso_ops.cc b/src/contrib/tf_op/tvm_dso_ops.cc index 1183b2ef34b5..794494298d71 100644 --- a/src/contrib/tf_op/tvm_dso_ops.cc +++ b/src/contrib/tf_op/tvm_dso_ops.cc @@ -21,11 +21,15 @@ REGISTER_OP("TvmDsoOp") .Input("input_args: ListT") - .Attr("ListT: list({int8, int32, int64, float16, float32})") + .Attr( + "ListT: list({float16, float32, float64, int8, int16, int32, int64, uint8, uint16," + "uint32, uint64})") .Input("dynamic_output_shape: int64") .Output("output: output_dtype") .Attr("lib_path: string") .Attr("func_name: string") - .Attr("output_dtype: {int8, int32, int64, float16, float32} = DT_FLOAT") + .Attr( + "output_dtype: {float16, float32, float64, int8, int16, int32, int64, uint8, uint16," + "uint32, uint64} = DT_FLOAT") .Attr("static_output_shape: list(int) >= 0 = []") .Attr("has_static_output_shape: bool");