diff --git a/include/tvm/runtime/container/shape_tuple.h b/include/tvm/runtime/container/shape_tuple.h new file mode 100644 index 000000000000..774077fc3d5e --- /dev/null +++ b/include/tvm/runtime/container/shape_tuple.h @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/shape_tuple.h + * \brief Runtime ShapeTuple container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ +#define TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ + +#include +#include + +#include "./base.h" + +namespace tvm { +namespace runtime { + +/*! \brief An object representing a shape tuple. */ +class ShapeTupleObj : public Object { + public: + /*! \brief The type of shape index element. */ + using index_type = int64_t; + /*! \brief The pointer to shape tuple data. */ + index_type* data; + /*! \brief The size of the shape tuple object. */ + uint64_t size; + + static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeShapeTuple; + static constexpr const char* _type_key = "runtime.ShapeTuple"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTupleObj, Object); + + private: + /*! \brief ShapeTuple object which is moved from std::vector container. */ + class FromStd; + + friend class ShapeTuple; +}; + +/*! \brief An object representing shape tuple moved from std::vector. */ +class ShapeTupleObj::FromStd : public ShapeTupleObj { + public: + /*! \brief The type of shape index element. */ + using index_type = ShapeTupleObj::index_type; + /*! + * \brief Construct a new FromStd object + * + * \param other The moved/copied std::vector object + * + * \note If user passes const reference, it will trigger copy. If it's rvalue, + * it will be moved into other. + */ + explicit FromStd(std::vector other) : data_container{other} {} + + private: + /*! \brief Container that holds the memory. */ + std::vector data_container; + + friend class ShapeTuple; +}; + +/*! + * \brief Reference to shape tuple objects. + */ +class ShapeTuple : public ObjectRef { + public: + /*! \brief The type of shape index element. */ + using index_type = ShapeTupleObj::index_type; + + /*! + * \brief Construct an empty shape tuple. + */ + ShapeTuple() : ShapeTuple(std::vector()) {} + + /*! + * \brief Constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + ShapeTuple(IterType begin, IterType end) : ShapeTuple(std::vector(begin, end)) {} + + /*! + * \brief constructor from initializer list + * \param shape The initializer list + */ + ShapeTuple(std::initializer_list shape) : ShapeTuple(shape.begin(), shape.end()) {} + + /*! + * \brief Construct a new ShapeTuple object + * + * \param shape The moved/copied std::vector object + * + * \note If user passes const reference, it will trigger copy. If it's rvalue, + * it will be moved into other. + */ + ShapeTuple(std::vector shape); // NOLINT(*) + + /*! + * \brief Return the data pointer + * + * \return const index_type* data pointer + */ + const index_type* data() const { return get()->data; } + + /*! + * \brief Return the size of the shape tuple + * + * \return size_t shape tuple size + */ + size_t size() const { return get()->size; } + + /*! + * \brief Immutably read i-th element from the shape tuple. + * \param idx The index + * \return the i-th element. + */ + index_type operator[](size_t idx) const { + ICHECK(0 <= idx && idx < this->size()) + << "IndexError: indexing " << idx << " on an array of size " << this->size(); + return this->data()[idx]; + } + + /*! + * \brief Immutably read i-th element from the shape tuple. + * \param idx The index + * \return the i-th element. + */ + index_type at(size_t idx) const { return this->operator[](idx); } + + /*! \return Whether shape tuple is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the shape tuple */ + index_type front() const { return this->at(0); } + + /*! \return The last element of the shape tuple */ + index_type back() const { return this->at(this->size() - 1); } + + /*! \return begin iterator */ + const index_type* begin() const { return get()->data; } + + /*! \return end iterator */ + const index_type* end() const { return (get()->data + size()); } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeTuple, ObjectRef, ShapeTupleObj); +}; + +inline ShapeTuple::ShapeTuple(std::vector shape) { + auto ptr = make_object(std::move(shape)); + ptr->size = ptr->data_container.size(); + ptr->data = ptr->data_container.data(); + data_ = std::move(ptr); +} + +} // namespace runtime + +// expose the functions to the root namespace. +using runtime::ShapeTuple; +using runtime::ShapeTupleObj; +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index bfc681e24418..1127a9ae732c 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -25,8 +25,8 @@ #define TVM_RUNTIME_NDARRAY_H_ #include -#include #include +#include #include #include #include @@ -128,7 +128,7 @@ class NDArray : public ObjectRef { * \param dtype The data type of the new array. * \note The memory size of new array must be smaller than the current one. */ - TVM_DLL NDArray CreateView(std::vector shape, DLDataType dtype); + TVM_DLL NDArray CreateView(ShapeTuple shape, DLDataType dtype); /*! * \brief Create a reference view of NDArray that * represents as DLManagedTensor. @@ -143,7 +143,7 @@ class NDArray : public ObjectRef { * \param mem_scope The memory scope of the array. * \return The created Array */ - TVM_DLL static NDArray Empty(std::vector shape, DLDataType dtype, Device dev, + TVM_DLL static NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional mem_scope = NullOpt); /*! * \brief Create a NDArray backed by a dlpack tensor. @@ -166,7 +166,7 @@ class NDArray : public ObjectRef { TVM_DLL static void CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr); - TVM_DLL std::vector Shape() const; + TVM_DLL ShapeTuple Shape() const; TVM_DLL runtime::DataType DataType() const; // internal namespace struct Internal; @@ -241,7 +241,7 @@ class NDArray::ContainerBase { * \brief The shape container, * can be used used for shape data. */ - std::vector shape_; + ShapeTuple shape_; }; /*! @@ -261,13 +261,13 @@ class NDArray::Container : public Object, public NDArray::ContainerBase { dl_tensor.byte_offset = 0; } - Container(void* data, std::vector shape, DLDataType dtype, Device dev) { + Container(void* data, ShapeTuple shape, DLDataType dtype, Device dev) { // Initialize the type index. type_index_ = Container::RuntimeTypeIndex(); dl_tensor.data = data; shape_ = std::move(shape); dl_tensor.ndim = static_cast(shape_.size()); - dl_tensor.shape = dmlc::BeginPtr(shape_); + dl_tensor.shape = const_cast(shape_.data()); dl_tensor.dtype = dtype; dl_tensor.strides = nullptr; dl_tensor.byte_offset = 0; @@ -357,8 +357,7 @@ inline void NDArray::CopyTo(const NDArray& other) const { inline NDArray NDArray::CopyTo(const Device& dev) const { ICHECK(data_ != nullptr); const DLTensor* dptr = operator->(); - NDArray ret = - Empty(std::vector(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev); + NDArray ret = Empty(ShapeTuple(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev); this->CopyTo(ret); return ret; } @@ -460,7 +459,7 @@ inline bool NDArray::Load(dmlc::Stream* strm) { if (ndim != 0) { ICHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format"; } - NDArray ret = NDArray::Empty(shape, dtype, dev); + NDArray ret = NDArray::Empty(ShapeTuple(shape), dtype, dev); int64_t num_elems = 1; int elem_bytes = (ret->dtype.bits + 7) / 8; for (int i = 0; i < ret->ndim; ++i) { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index f13bdee09f87..0ed61177e65a 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -68,6 +68,8 @@ struct TypeIndex { kRuntimeArray = 4, /*! \brief runtime::Map. */ kRuntimeMap = 5, + /*! \brief runtime::ShapeTuple. */ + kRuntimeShapeTuple = 6, // static assignments that may subject to change. kRuntimeClosure, kRuntimeADT, diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 3e8f23b755f9..9bfe379a3d77 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 63383e7710f5..7f83693292ba 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -137,3 +137,27 @@ def __from_tvm_object__(cls, obj): val = str.__new__(cls, content) val.__tvm_object__ = obj return val + + +@tvm._ffi.register_object("runtime.ShapeTuple") +class ShapeTuple(Object): + """TVM runtime ShapeTuple object. + Parameters + ---------- + shape : list[int] + The shape list used to construct the object. + """ + + def __init__(self, shape): + assert isinstance(shape, (list, tuple)), "Expect list of tuple, but received : {0}".format( + type(shape) + ) + for x in shape: + assert isinstance(x, int), "Expect int type, but received : {0}".format(type(x)) + self.__init_handle_by_constructor__(_ffi_api.ShapeTuple, *shape) + + def __len__(self): + return _ffi_api.GetShapeTupleSize(self) + + def __getitem__(self, idx): + return getitem_helper(self, _ffi_api.GetShapeTupleElem, len(self), idx) diff --git a/src/node/container_printing.cc b/src/node/container_printing.cc index 7b972966bef8..1565630cc6ac 100644 --- a/src/node/container_printing.cc +++ b/src/node/container_printing.cc @@ -60,4 +60,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << '}'; }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << '['; + for (size_t i = 0; i < op->size; ++i) { + if (i != 0) { + p->stream << ", "; + } + p->stream << op->data[i]; + } + p->stream << ']'; + }); } // namespace tvm diff --git a/src/runtime/container.cc b/src/runtime/container.cc index 9d648dcb9a5f..159404be5351 100644 --- a/src/runtime/container.cc +++ b/src/runtime/container.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -108,7 +109,6 @@ TVM_REGISTER_GLOBAL("runtime.ADT").set_body([](TVMArgs args, TVMRetValue* rv) { }); // String - TVM_REGISTER_OBJECT_TYPE(StringObj); TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) { @@ -120,7 +120,6 @@ TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) { }); // Map - TVM_REGISTER_OBJECT_TYPE(MapNode); TVM_REGISTER_GLOBAL("runtime.Map").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -185,7 +184,27 @@ TVM_REGISTER_GLOBAL("runtime.MapItems").set_body([](TVMArgs args, TVMRetValue* r TVM_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[]; #endif +// Closure TVM_REGISTER_OBJECT_TYPE(ClosureObj); +// ShapeTuple +TVM_REGISTER_OBJECT_TYPE(ShapeTupleObj); + +TVM_REGISTER_GLOBAL("runtime.ShapeTuple").set_body([](TVMArgs args, TVMRetValue* rv) { + std::vector shape; + for (int i = 0; i < args.size(); i++) { + shape.push_back(args[i]); + } + *rv = ShapeTuple(shape); +}); + +TVM_REGISTER_GLOBAL("runtime.GetShapeTupleSize").set_body_typed([](ShapeTuple shape) { + return static_cast(shape.size()); +}); + +TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple shape, int idx) { + ICHECK_LT(idx, shape.size()); + return shape[idx]; +}); } // namespace runtime } // namespace tvm diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 3d3466bed47c..968a4488bbcf 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -123,7 +123,7 @@ struct NDArray::Internal { } // Local create function which allocates tensor metadata // but does not allocate space for the data. - static NDArray Create(std::vector shape, DLDataType dtype, Device dev) { + static NDArray Create(ShapeTuple shape, DLDataType dtype, Device dev) { VerifyDataType(dtype); // critical zone: construct header @@ -134,7 +134,7 @@ struct NDArray::Internal { NDArray ret(GetObjectPtr(data)); // setup shape data->shape_ = std::move(shape); - data->dl_tensor.shape = dmlc::BeginPtr(data->shape_); + data->dl_tensor.shape = const_cast(data->shape_.data()); data->dl_tensor.ndim = static_cast(data->shape_.size()); // setup dtype data->dl_tensor.dtype = dtype; @@ -172,7 +172,7 @@ struct NDArray::Internal { } }; -NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { +NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype) { ICHECK(data_ != nullptr); ICHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor"; NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.device); @@ -190,8 +190,7 @@ NDArray NDArray::CreateView(std::vector shape, DLDataType dtype) { DLManagedTensor* NDArray::ToDLPack() const { return Internal::ToDLPack(get_mutable()); } -NDArray NDArray::Empty(std::vector shape, DLDataType dtype, Device dev, - Optional mem_scope) { +NDArray NDArray::Empty(ShapeTuple shape, DLDataType dtype, Device dev, Optional mem_scope) { NDArray ret = Internal::Create(shape, dtype, dev); ret.get_mutable()->dl_tensor.data = DeviceAPI::Get(ret->device) @@ -207,9 +206,11 @@ NDArray NDArray::FromDLPack(DLManagedTensor* tensor) { data->manager_ctx = tensor; data->dl_tensor = tensor->dl_tensor; // update shape_ - data->shape_.resize(data->dl_tensor.ndim); - data->shape_.assign(data->dl_tensor.shape, data->dl_tensor.shape + data->dl_tensor.ndim); - data->dl_tensor.shape = data->shape_.data(); + std::vector shape; + shape.resize(data->dl_tensor.ndim); + shape.assign(data->dl_tensor.shape, data->dl_tensor.shape + data->dl_tensor.ndim); + data->shape_ = ShapeTuple(shape); + data->dl_tensor.shape = const_cast(data->shape_.data()); return NDArray(GetObjectPtr(data)); } @@ -242,7 +243,7 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str DeviceAPI::Get(dev)->CopyDataFromTo(const_cast(from), to, stream); } -std::vector NDArray::Shape() const { return get_mutable()->shape_; } +ShapeTuple NDArray::Shape() const { return get_mutable()->shape_; } runtime::DataType NDArray::DataType() const { return runtime::DataType(get_mutable()->dl_tensor.dtype); } @@ -274,7 +275,7 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_ Device dev; dev.device_type = static_cast(device_type); dev.device_id = device_id; - auto ndarray = NDArray::Empty(std::vector(shape, shape + ndim), dtype, dev); + auto ndarray = NDArray::Empty(ShapeTuple(shape, shape + ndim), dtype, dev); *out = NDArray::Internal::MoveToFFIHandle(ndarray); API_END(); @@ -283,7 +284,7 @@ int TVMArrayAlloc(const tvm_index_t* shape, int ndim, int dtype_code, int dtype_ TVM_REGISTER_GLOBAL("runtime.TVMArrayAllocWithScope").set_body([](TVMArgs args, TVMRetValue* ret) { int64_t* shape_ptr = static_cast(static_cast(args[0])); int ndim = args[1]; - std::vector shape(shape_ptr, shape_ptr + ndim); + ShapeTuple shape(shape_ptr, shape_ptr + ndim); DataType dtype = args[2]; Device dev = args[3]; Optional mem_scope = args[4]; diff --git a/tests/python/unittest/test_runtime_container.py b/tests/python/unittest/test_runtime_container.py index 39fd575ff6d8..781fd7f93886 100644 --- a/tests/python/unittest/test_runtime_container.py +++ b/tests/python/unittest/test_runtime_container.py @@ -78,7 +78,16 @@ def test_string(): assert s == z +def test_shape_tuple(): + shape = [random.randint(-10, 10) for _ in range(5)] + stuple = _container.ShapeTuple(shape) + len(stuple) == len(shape) + for a, b in zip(stuple, shape): + assert a == b + + if __name__ == "__main__": test_string() test_adt_constructor() test_tuple_object() + test_shape_tuple()