Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ffi/examples/packaging/src/extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
* The library is written in C++ and can be compiled into a shared library.
* The shared library can then be loaded into python and used to call the functions.
*/
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>
Expand All @@ -43,7 +44,7 @@ namespace ffi = tvm::ffi;
*/
void RaiseError(ffi::String msg) { TVM_FFI_THROW(RuntimeError) << msg; }

void AddOne(DLTensor* x, DLTensor* y) {
void AddOne(ffi::Tensor x, ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
Expand Down
4 changes: 2 additions & 2 deletions ffi/examples/quick_start/src/add_one_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function.h>

namespace tvm_ffi_example {

void AddOne(DLTensor* x, DLTensor* y) {
void AddOne(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
Expand Down
3 changes: 2 additions & 1 deletion ffi/examples/quick_start/src/add_one_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
Expand All @@ -30,7 +31,7 @@ __global__ void AddOneKernel(float* x, float* y, int n) {
}
}

void AddOneCUDA(DLTensor* x, DLTensor* y) {
void AddOneCUDA(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
// implementation of a library function
TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
DLDataType f32_dtype{kDLFloat, 32, 1};
Expand Down
18 changes: 15 additions & 3 deletions ffi/include/tvm/ffi/container/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class TensorObj : public Object, public DLTensor {
protected:
// backs up the shape/strides
Optional<Shape> shape_data_;
Optional<Shape> stride_data_;
Optional<Shape> strides_data_;

static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
TensorObj* obj = static_cast<TensorObj*>(tensor->manager_ctx);
Expand Down Expand Up @@ -189,7 +189,7 @@ class TensorObjFromNDAlloc : public TensorObj {
this->strides = const_cast<int64_t*>(strides.data());
this->byte_offset = 0;
this->shape_data_ = std::move(shape);
this->stride_data_ = std::move(strides);
this->strides_data_ = std::move(strides);
alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
}

Expand All @@ -208,7 +208,7 @@ class TensorObjFromDLPack : public TensorObj {
if (tensor_->dl_tensor.strides == nullptr) {
Shape strides = Shape(details::MakeStridesFromShape(ndim, shape));
this->strides = const_cast<int64_t*>(strides.data());
this->stride_data_ = std::move(strides);
this->strides_data_ = std::move(strides);
}
}

Expand Down Expand Up @@ -244,6 +244,18 @@ class Tensor : public ObjectRef {
}
return *(obj->shape_data_);
}
/*!
* \brief Get the strides of the Tensor.
* \return The strides of the Tensor.
*/
tvm::ffi::Shape strides() const {
TensorObj* obj = get_mutable();
TVM_FFI_ICHECK(obj->strides != nullptr);
if (!obj->strides_data_.has_value()) {
obj->strides_data_ = tvm::ffi::Shape(obj->strides, obj->strides + obj->ndim);
}
return *(obj->strides_data_);
}
/*!
* \brief Get the data type of the Tensor.
* \return The data type of the Tensor.
Expand Down
6 changes: 6 additions & 0 deletions ffi/tests/cpp/test_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,15 @@ inline Tensor Empty(Shape shape, DLDataType dtype, DLDevice device) {
TEST(Tensor, Basic) {
Tensor nd = Empty(Shape({1, 2, 3}), DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
Shape shape = nd.shape();
Shape strides = nd.strides();
EXPECT_EQ(shape.size(), 3);
EXPECT_EQ(shape[0], 1);
EXPECT_EQ(shape[1], 2);
EXPECT_EQ(shape[2], 3);
EXPECT_EQ(strides.size(), 3);
EXPECT_EQ(strides[0], 6);
EXPECT_EQ(strides[1], 3);
EXPECT_EQ(strides[2], 1);
EXPECT_EQ(nd.dtype(), DLDataType({kDLFloat, 32, 1}));
for (int64_t i = 0; i < shape.Product(); ++i) {
reinterpret_cast<float*>(nd->data)[i] = static_cast<float>(i);
Expand All @@ -47,6 +52,7 @@ TEST(Tensor, Basic) {
Any any0 = nd;
Tensor nd2 = any0.as<Tensor>().value();
EXPECT_EQ(nd2.shape(), shape);
EXPECT_EQ(nd2.strides(), strides);
EXPECT_EQ(nd2.dtype(), DLDataType({kDLFloat, 32, 1}));
for (int64_t i = 0; i < shape.Product(); ++i) {
EXPECT_EQ(reinterpret_cast<float*>(nd2->data)[i], i);
Expand Down
Loading