diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc
index 0b3510c85779..ed36fda90a21 100644
--- a/src/relay/backend/contrib/codegen_c/codegen.cc
+++ b/src/relay/backend/contrib/codegen_c/codegen.cc
@@ -68,7 +68,6 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
     runtime::NDArray array = cn->data;
     const auto& shape = array.Shape();
-    const DLTensor& dl_tensor = array.ToDLPack()->dl_tensor;
 
     // Get the number of elements.
     int64_t num_elems = 1;
 
@@ -83,11 +82,11 @@
     // to avoid possible stack overflow.
     buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {";
     if (dtype == "float") {
-      float* p_flt = static_cast<float*>(dl_tensor.data);
+      float* p_flt = static_cast<float*>(array->data);
       for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
       if (num_elems) buf_stream << p_flt[num_elems - 1];
     } else if (dtype == "int") {
-      int* p_flt = static_cast<int*>(dl_tensor.data);
+      int* p_flt = static_cast<int*>(array->data);
       for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
       if (num_elems) buf_stream << p_flt[num_elems - 1];
     } else {
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index 26bc8786902c..5e45e94223f8 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -169,7 +169,7 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
     CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now.";
 
     std::ostringstream buf_stream;
-    const float* ptr = static_cast<float*>(array.ToDLPack()->dl_tensor.data);
+    const float* ptr = static_cast<float*>(array->data);
 
     // Allocate large arrays on the static section to avoid stakc overflow.
     // Note that this would probably increase compilation time as the source
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index cb61a61470c7..9cdd36515297 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -193,35 +193,26 @@ TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) {
   return else_branch;
 }
 
-std::vector<int64_t> ToAllocTensorShape64(NDArray shape) {
+std::vector<int64_t> ToAllocTensorShape(NDArray shape) {
   std::vector<int64_t> raw_shape;
-  DLTensor tensor = shape.ToDLPack()->dl_tensor;
-  CHECK_EQ(tensor.ndim, 1u);
-  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
-
-  // TODO(@jroesch): we really need to standaridize the bit width of
-  // all of the shape manipulating code.
-  CHECK_EQ(tensor.dtype.bits, 64) << "found " << tensor.dtype.bits;
-  int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
-  for (auto i = 0; i < tensor.shape[0]; i++) {
-    raw_shape.push_back(int_ptr[i]);
-  }
-  return raw_shape;
-}
-
-
-std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
-  std::vector<int64_t> raw_shape;
-  DLTensor tensor = shape.ToDLPack()->dl_tensor;
-  CHECK_EQ(tensor.ndim, 1u);
-  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
-
-  // TODO(@jroesch): we really need to standaridize the bit width of
-  // all of the shape manipulating code.
-  CHECK_LE(tensor.dtype.bits, 32) << "found " << tensor.dtype.bits;
-  int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
-  for (auto i = 0; i < tensor.shape[0]; i++) {
-    raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
+  CHECK_EQ(shape->ndim, 1u);
+  CHECK_EQ(shape->dtype.code, 0U)
+      << "The dtype of constant shape must be int32 or int64, but got "
+      << DLDataType2String(shape->dtype);
+  CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32)
+      << "The dtype of constant shape must be int32 or int64, but got"
+      << DLDataType2String(shape->dtype);
+
+  if (shape->dtype.bits == 64) {
+    int64_t* int_ptr = reinterpret_cast<int64_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
+      raw_shape.push_back(int_ptr[i]);
+    }
+  } else {  // int32
+    int32_t* int_ptr = reinterpret_cast<int32_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
+      raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
+    }
   }
   return raw_shape;
 }
@@ -546,17 +537,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr&)> {
 
       if (const_shape) {
         NDArray shape = const_shape->data;
-        std::vector<int64_t> raw_shape;
-        DLTensor tensor = shape.ToDLPack()->dl_tensor;
-        // TODO(@jroesch): we need to get an RFC done to standarize this
-        if (tensor.dtype.bits == 64) {
-          raw_shape = ToAllocTensorShape64(shape);
-        } else if (tensor.dtype.bits == 32) {
-          raw_shape = ToAllocTensorShape32(shape);
-        } else {
-          LOG(FATAL) << "unsupported bitwidth: " << tensor.dtype.bits;
-        }
-
+        // TODO(@jroesch): we need to get an RFC done to standarize shape dtype
+        std::vector<int64_t> raw_shape = ToAllocTensorShape(shape);
         // Add context field.
         Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister()));
       } else {
diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc
index 0a7142df572f..c7ffc95c05d5 100644
--- a/src/relay/op/memory/memory.cc
+++ b/src/relay/op/memory/memory.cc
@@ -23,6 +23,7 @@
  */
 #include <tvm/relay/attrs/memory.h>
+#include <tvm/runtime/data_type.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <topi/elemwise.h>
@@ -107,21 +108,22 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor")
 
 std::vector<int64_t> FromConstShape(Constant konst) {
   runtime::NDArray shape = konst->data;
   std::vector<int64_t> raw_shape;
-  DLTensor tensor = shape.ToDLPack()->dl_tensor;
-  CHECK_EQ(tensor.ndim, 1u);
-  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
-
-  CHECK(tensor.dtype.bits == 64 || tensor.dtype.bits == 32)
-      << "found " << static_cast<int>(tensor.dtype.bits);
-
-  if (tensor.dtype.bits == 32) {
-    const int32_t* int_ptr = reinterpret_cast<const int32_t*>(tensor.data);
-    for (auto i = 0; i < tensor.shape[0]; i++) {
+  CHECK_EQ(shape->ndim, 1u);
+  CHECK_EQ(shape->dtype.code, 0U)
+      << "The dtype of constant shape must be int32 or int64, but got "
+      << runtime::DLDataType2String(shape->dtype);
+  CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32)
+      << "The dtype of constant shape must be int32 or int64, but got"
+      << runtime::DLDataType2String(shape->dtype);
+
+  if (shape->dtype.bits == 32) {
+    const int32_t* int_ptr = reinterpret_cast<const int32_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
       raw_shape.push_back(int_ptr[i]);
     }
-  } else if (tensor.dtype.bits == 64) {
-    const int64_t* int_ptr = reinterpret_cast<const int64_t*>(tensor.data);
-    for (auto i = 0; i < tensor.shape[0]; i++) {
+  } else if (shape->dtype.bits == 64) {
+    const int64_t* int_ptr = reinterpret_cast<const int64_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
       raw_shape.push_back(int_ptr[i]);
     }
   }