Skip to content

Commit

Permalink
Support more dtypes for TVMDSOOp (#5694)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobegit3hub authored May 29, 2020
1 parent 06bf8b0 commit a9ce2f7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
23 changes: 20 additions & 3 deletions src/contrib/tf_op/tvm_dso_op_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,29 @@ class TensorAsBuf {

tensorflow::Status GetDLPackDtype(const tensorflow::Tensor& tf_tensor, DLDataType* res) {
auto dtype = tf_tensor.dtype();
if (dtype == tensorflow::DT_FLOAT) {

if (dtype == tensorflow::DT_HALF) {
*res = {kDLFloat, 16, 1};
} else if (dtype == tensorflow::DT_FLOAT) {
*res = {kDLFloat, 32, 1};
} else if (dtype == tensorflow::DT_INT64) {
*res = {kDLInt, 64, 1};
} else if (dtype == tensorflow::DT_DOUBLE) {
*res = {kDLFloat, 64, 1};
} else if (dtype == tensorflow::DT_INT8) {
*res = {kDLInt, 8, 1};
} else if (dtype == tensorflow::DT_INT16) {
*res = {kDLInt, 16, 1};
} else if (dtype == tensorflow::DT_INT32) {
*res = {kDLInt, 32, 1};
} else if (dtype == tensorflow::DT_INT64) {
*res = {kDLInt, 64, 1};
} else if (dtype == tensorflow::DT_UINT8) {
*res = {kDLUInt, 8, 1};
} else if (dtype == tensorflow::DT_UINT16) {
*res = {kDLUInt, 16, 1};
} else if (dtype == tensorflow::DT_UINT32) {
*res = {kDLUInt, 32, 1};
} else if (dtype == tensorflow::DT_UINT64) {
*res = {kDLUInt, 64, 1};
} else {
return tensorflow::Status(tensorflow::error::INTERNAL, "Fail to get dlpack datatype");
}
Expand Down
8 changes: 6 additions & 2 deletions src/contrib/tf_op/tvm_dso_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@

REGISTER_OP("TvmDsoOp")
.Input("input_args: ListT")
.Attr("ListT: list({int8, int32, int64, float16, float32})")
.Attr(
"ListT: list({float16, float32, float64, int8, int16, int32, int64, uint8, uint16,"
"uint32, uint64})")
.Input("dynamic_output_shape: int64")
.Output("output: output_dtype")
.Attr("lib_path: string")
.Attr("func_name: string")
.Attr("output_dtype: {int8, int32, int64, float16, float32} = DT_FLOAT")
.Attr(
"output_dtype: {float16, float32, float64, int8, int16, int32, int64, uint8, uint16,"
"uint32, uint64} = DT_FLOAT")
.Attr("static_output_shape: list(int) >= 0 = []")
.Attr("has_static_output_shape: bool");

0 comments on commit a9ce2f7

Please sign in to comment.