diff --git a/ffi/examples/packaging/src/extension.cc b/ffi/examples/packaging/src/extension.cc
index eb4be8508dc6..7a2eb1514851 100644
--- a/ffi/examples/packaging/src/extension.cc
+++ b/ffi/examples/packaging/src/extension.cc
@@ -24,6 +24,7 @@
  * The library is written in C++ and can be compiled into a shared library.
  * The shared library can then be loaded into python and used to call the functions.
  */
+#include <tvm/ffi/container/tensor.h>
 #include <tvm/ffi/dtype.h>
 #include <tvm/ffi/error.h>
 #include <tvm/ffi/function.h>
@@ -43,7 +44,7 @@ namespace ffi = tvm::ffi;
  */
 void RaiseError(ffi::String msg) { TVM_FFI_THROW(RuntimeError) << msg; }
 
-void AddOne(DLTensor* x, DLTensor* y) {
+void AddOne(ffi::Tensor x, ffi::Tensor y) {
   // implementation of a library function
   TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
   DLDataType f32_dtype{kDLFloat, 32, 1};
diff --git a/ffi/examples/quick_start/src/add_one_cpu.cc b/ffi/examples/quick_start/src/add_one_cpu.cc
index 2499510c5394..76b9b3752c88 100644
--- a/ffi/examples/quick_start/src/add_one_cpu.cc
+++ b/ffi/examples/quick_start/src/add_one_cpu.cc
@@ -16,14 +16,14 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
+#include <tvm/ffi/container/tensor.h>
 #include <tvm/ffi/dtype.h>
 #include <tvm/ffi/error.h>
 #include <tvm/ffi/function.h>
 
 namespace tvm_ffi_example {
 
-void AddOne(DLTensor* x, DLTensor* y) {
+void AddOne(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
   // implementation of a library function
   TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
   DLDataType f32_dtype{kDLFloat, 32, 1};
diff --git a/ffi/examples/quick_start/src/add_one_cuda.cu b/ffi/examples/quick_start/src/add_one_cuda.cu
index 282395fe01d6..ead2ec89a95c 100644
--- a/ffi/examples/quick_start/src/add_one_cuda.cu
+++ b/ffi/examples/quick_start/src/add_one_cuda.cu
@@ -16,6 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
 */
+#include <tvm/ffi/container/tensor.h>
 #include <tvm/ffi/dtype.h>
 #include <tvm/ffi/error.h>
 #include <tvm/ffi/function.h>
@@ -30,7 +31,7 @@ __global__ void AddOneKernel(float* x, float* y, int n) {
   }
 }
 
-void AddOneCUDA(DLTensor* x, DLTensor* y) {
+void AddOneCUDA(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
   // implementation of a library function
   TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
   DLDataType f32_dtype{kDLFloat, 32, 1};
diff --git a/ffi/include/tvm/ffi/container/tensor.h b/ffi/include/tvm/ffi/container/tensor.h
index 93526e5c2a5d..8a8134d86020 100644
--- a/ffi/include/tvm/ffi/container/tensor.h
+++ b/ffi/include/tvm/ffi/container/tensor.h
@@ -151,7 +151,7 @@ class TensorObj : public Object, public DLTensor {
  protected:
   // backs up the shape/strides
   Optional<Shape> shape_data_;
-  Optional<Shape> stride_data_;
+  Optional<Shape> strides_data_;
 
   static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
     TensorObj* obj = static_cast<TensorObj*>(tensor->manager_ctx);
@@ -189,7 +189,7 @@ class TensorObjFromNDAlloc : public TensorObj {
     this->strides = const_cast<int64_t*>(strides.data());
     this->byte_offset = 0;
     this->shape_data_ = std::move(shape);
-    this->stride_data_ = std::move(strides);
+    this->strides_data_ = std::move(strides);
     alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
   }
 
@@ -208,7 +208,7 @@ class TensorObjFromDLPack : public TensorObj {
     if (tensor_->dl_tensor.strides == nullptr) {
       Shape strides = Shape(details::MakeStridesFromShape(ndim, shape));
       this->strides = const_cast<int64_t*>(strides.data());
-      this->stride_data_ = std::move(strides);
+      this->strides_data_ = std::move(strides);
     }
   }
 
@@ -244,6 +244,18 @@ class Tensor : public ObjectRef {
     }
     return *(obj->shape_data_);
   }
+  /*!
+   * \brief Get the strides of the Tensor.
+   * \return The strides of the Tensor.
+   */
+  tvm::ffi::Shape strides() const {
+    TensorObj* obj = get_mutable();
+    TVM_FFI_ICHECK(obj->strides != nullptr);
+    if (!obj->strides_data_.has_value()) {
+      obj->strides_data_ = tvm::ffi::Shape(obj->strides, obj->strides + obj->ndim);
+    }
+    return *(obj->strides_data_);
+  }
   /*!
    * \brief Get the data type of the Tensor.
    * \return The data type of the Tensor.
diff --git a/ffi/tests/cpp/test_tensor.cc b/ffi/tests/cpp/test_tensor.cc
index 17a6427af35c..3ad182d844f0 100644
--- a/ffi/tests/cpp/test_tensor.cc
+++ b/ffi/tests/cpp/test_tensor.cc
@@ -35,10 +35,15 @@ inline Tensor Empty(Shape shape, DLDataType dtype, DLDevice device) {
 TEST(Tensor, Basic) {
   Tensor nd = Empty(Shape({1, 2, 3}), DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
   Shape shape = nd.shape();
+  Shape strides = nd.strides();
   EXPECT_EQ(shape.size(), 3);
   EXPECT_EQ(shape[0], 1);
   EXPECT_EQ(shape[1], 2);
   EXPECT_EQ(shape[2], 3);
+  EXPECT_EQ(strides.size(), 3);
+  EXPECT_EQ(strides[0], 6);
+  EXPECT_EQ(strides[1], 3);
+  EXPECT_EQ(strides[2], 1);
   EXPECT_EQ(nd.dtype(), DLDataType({kDLFloat, 32, 1}));
   for (int64_t i = 0; i < shape.Product(); ++i) {
     reinterpret_cast<float*>(nd->data)[i] = static_cast<float>(i);
@@ -47,6 +52,7 @@ TEST(Tensor, Basic) {
   Any any0 = nd;
   Tensor nd2 = any0.as<Tensor>().value();
   EXPECT_EQ(nd2.shape(), shape);
+  EXPECT_EQ(nd2.strides(), strides);
   EXPECT_EQ(nd2.dtype(), DLDataType({kDLFloat, 32, 1}));
   for (int64_t i = 0; i < shape.Product(); ++i) {
     EXPECT_EQ(reinterpret_cast<float*>(nd2->data)[i], i);