diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index ff3db4367a30..6c401e242c59 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit ff3db4367a30f542aafb83b4af45e685b80102d0 +Subproject commit 6c401e242c59a1f4c913918246591bb13fd714e7 diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index 2b6645fa165b..0b0887c484d9 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -37,6 +37,9 @@ namespace tvm { +using runtime::Array; +using runtime::ArrayNode; +using runtime::IterAdapter; using runtime::make_object; using runtime::Object; using runtime::ObjectEqual; @@ -46,16 +49,6 @@ using runtime::ObjectRef; using runtime::String; using runtime::StringObj; -/*! \brief array node content in array */ -class ArrayNode : public Object { - public: - /*! \brief the data content */ - std::vector data; - - static constexpr const char* _type_key = "Array"; - TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); -}; - /*! \brief map node content */ class MapNode : public Object { public: @@ -82,273 +75,6 @@ class StrMapNode : public Object { TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Object); }; -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. - */ -template -class IterAdapter { - public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = typename Converter::ResultType*; - using reference = typename Converter::ResultType&; // NOLINT(*) - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit IterAdapter(TIter iter) : iter_(iter) {} - inline IterAdapter& operator++() { - ++iter_; - return *this; - } - inline IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } - - template - typename std::enable_if::value, - typename T::difference_type>::type inline - operator-(const IterAdapter& rhs) const { - return iter_ - rhs.iter_; - } - - inline bool operator==(IterAdapter other) const { return iter_ == other.iter_; } - inline bool operator!=(IterAdapter other) const { return !(*this == other); } - inline const value_type operator*() const { return Converter::convert(*iter_); } - - private: - TIter iter_; -}; - -/*! - * \brief Array container of NodeRef in DSL graph. - * Array implements copy on write semantics, which means array is mutable - * but copy will happen when array is referenced in more than two places. - * - * operator[] only provide const acces, use Set to mutate the content. - * \tparam T The content NodeRef type. - */ -template ::value>::type> -class Array : public ObjectRef { - public: - /*! - * \brief default constructor - */ - Array() { data_ = make_object(); } - /*! - * \brief move constructor - * \param other source - */ - Array(Array&& other) : ObjectRef() { // NOLINT(*) - data_ = std::move(other.data_); - } - /*! - * \brief copy constructor - * \param other source - */ - Array(const Array& other) : ObjectRef() { // NOLINT(*) - data_ = std::move(other.data_); - } - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Array(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Array(IterType begin, IterType end) { - assign(begin, end); - } - /*! - * \brief constructor from initializer list - * \param init The initalizer list - */ - Array(std::initializer_list init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - /*! - * \brief constructor from vector - * \param init The vector - */ - Array(const std::vector& init) { // NOLINT(*) - assign(init.begin(), init.end()); - } - /*! - * \brief Constructs a container with n elements. Each element is a copy of val - * \param n The size of the container - * \param val The init value - */ - explicit Array(size_t n, const T& val) { - auto tmp_node = make_object(); - for (size_t i = 0; i < n; ++i) { - tmp_node->data.push_back(val); - } - data_ = std::move(tmp_node); - } - /*! - * \brief move assign operator - * \param other The source of assignment - * \return reference to self. - */ - Array& operator=(Array&& other) { - data_ = std::move(other.data_); - return *this; - } - /*! - * \brief copy assign operator - * \param other The source of assignment - * \return reference to self. - */ - Array& operator=(const Array& other) { - data_ = other.data_; - return *this; - } - /*! - * \brief reset the array to content from iterator. - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - void assign(IterType begin, IterType end) { - auto n = make_object(); - for (IterType it = begin; it != end; ++it) { - n->data.push_back(T(*it)); - } - data_ = std::move(n); - } - /*! - * \brief Read i-th element from array. - * \param i The index - * \return the i-th element. - */ - inline const T operator[](size_t i) const { - return DowncastNoCheck(static_cast(data_.get())->data[i]); - } - /*! \return The size of the array */ - inline size_t size() const { - if (data_.get() == nullptr) return 0; - return static_cast(data_.get())->data.size(); - } - /*! - * \brief copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - inline ArrayNode* CopyOnWrite() { - if (data_.get() == nullptr || !data_.unique()) { - ObjectPtr n = make_object(); - n->data = static_cast(data_.get())->data; - ObjectPtr(std::move(n)).swap(data_); - } - return static_cast(data_.get()); - } - /*! - * \brief push a new item to the back of the list - * \param item The item to be pushed. - */ - inline void push_back(const T& item) { - ArrayNode* n = this->CopyOnWrite(); - n->data.push_back(item); - } - /*! - * \brief Resize the array. - * \param size The new size. - */ - inline void resize(size_t size) { - ArrayNode* n = this->CopyOnWrite(); - n->data.resize(size); - } - /*! - * \brief set i-th element of the array. - * \param i The index - * \param value The value to be setted. - */ - inline void Set(size_t i, const T& value) { - ArrayNode* n = this->CopyOnWrite(); - n->data[i] = value; - } - /*! \return whether array is empty */ - inline bool empty() const { return size() == 0; } - /*! - * \brief Helper function to apply fmutate to mutate an array. - * \param fmutate The transformation function T -> T. - * \tparam F the type of the mutation function. - * \note This function performs copy on write optimization. - */ - template - inline void MutateByApply(F fmutate) { - ArrayNode* ptr = static_cast(data_.get()); - if (ptr == nullptr) return; - if (data_.unique()) { - // Copy on write optimization. - // Perform inplace update because this is an unique copy. - for (size_t i = 0; i < ptr->data.size(); ++i) { - // It is important to use move here - // to make prevent the element's ref count from increasing - // so fmutate itself can perform copy-on-write optimization - T old_elem = DowncastNoCheck(std::move(ptr->data[i])); - T new_elem = fmutate(std::move(old_elem)); - ptr->data[i] = std::move(new_elem); - } - } else { - // lazily trigger copy if there is element change. - ObjectPtr copy; - for (size_t i = 0; i < ptr->data.size(); ++i) { - T old_elem = DowncastNoCheck(ptr->data[i]); - T new_elem = fmutate(old_elem); - if (!new_elem.same_as(ptr->data[i])) { - // copy the old array - if (copy == nullptr) { - copy = runtime::make_object(*ptr); - } - copy->data[i] = std::move(new_elem); - } - } - // replace the data with the new copy. - if (copy != nullptr) { - data_ = std::move(copy); - } - } - } - - /*! \brief specify container node */ - using ContainerType = ArrayNode; - - struct ValueConverter { - using ResultType = T; - static inline T convert(const ObjectRef& n) { return DowncastNoCheck(n); } - }; - using iterator = IterAdapter::const_iterator>; - - using reverse_iterator = - IterAdapter::const_reverse_iterator>; - - /*! \return begin iterator */ - inline iterator begin() const { - return iterator(static_cast(data_.get())->data.begin()); - } - /*! \return end iterator */ - inline iterator end() const { - return iterator(static_cast(data_.get())->data.end()); - } - /*! \return rbegin iterator */ - inline reverse_iterator rbegin() const { - return reverse_iterator(static_cast(data_.get())->data.rbegin()); - } - /*! \return rend iterator */ - inline reverse_iterator rend() const { - return reverse_iterator(static_cast(data_.get())->data.rend()); - } -}; - /*! * \brief Map container of NodeRef->NodeRef in DSL graph. * Map implements copy on write semantics, which means map is mutable @@ -404,8 +130,8 @@ class Map : public ObjectRef { assign(init.begin(), init.end()); } /*! - * \brief constructor from vector - * \param init The vector + * \brief constructor from unordered_map + * \param init The unordered_map */ template Map(const std::unordered_map& init) { // NOLINT(*) @@ -625,7 +351,7 @@ struct ObjectTypeChecker > { if (ptr == nullptr) return true; if (!ptr->IsInstance()) return false; const ArrayNode* n = static_cast(ptr); - for (const auto& p : n->data) { + for (const ObjectRef& p : *n) { if (!ObjectTypeChecker::Check(p.get())) { return false; } diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index c0e227291ada..7fb7f3add58c 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -188,7 +188,7 @@ struct SqueezeAttrs : public tvm::AttrsNode { "If `axis = None`, all axis of dimension 1 get squeezed;" "Else, the dimension in axes get squeezed." "It is an error if an axis does not has dimension 1.") - .set_default(NullValue >()); + .set_default(NullValue>()); } }; // struct SqueezeAttrs diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index e2f2453933a5..a52e99791132 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -29,8 +29,10 @@ #include #include +#include #include #include +#include #include // We use c++14 std::experimental::string_view for optimizing hash computation // only right now, its usage is limited in this file. Any broader usage of @@ -160,7 +162,6 @@ class InplaceArrayBase { new (field_ptr) ElemType(std::forward(args)...); } - private: /*! * \brief Return the self object for the array. * @@ -189,6 +190,777 @@ class InplaceArrayBase { } }; +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template +class IterAdapter { + public: + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) + using iterator_category = typename std::iterator_traits::iterator_category; + + explicit IterAdapter(TIter iter) : iter_(iter) {} + IterAdapter& operator++() { + ++iter_; + return *this; + } + IterAdapter& operator--() { + --iter_; + return *this; + } + IterAdapter& operator++(int) { + IterAdapter copy = *this; + ++iter_; + return copy; + } + IterAdapter& operator--(int) { + IterAdapter copy = *this; + --iter_; + return copy; + } + + IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } + + template + typename std::enable_if::value, + typename T::difference_type>::type inline + operator-(const IterAdapter& rhs) const { + return iter_ - rhs.iter_; + } + + bool operator==(IterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(IterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. + */ +template +class ReverseIterAdapter { + public: + using difference_type = typename std::iterator_traits::difference_type; + using value_type = typename Converter::ResultType; + using pointer = typename Converter::ResultType*; + using reference = typename Converter::ResultType&; // NOLINT(*) + using iterator_category = typename std::iterator_traits::iterator_category; + + explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} + ReverseIterAdapter& operator++() { + --iter_; + return *this; + } + ReverseIterAdapter& operator--() { + ++iter_; + return *this; + } + ReverseIterAdapter& operator++(int) { + ReverseIterAdapter copy = *this; + --iter_; + return copy; + } + ReverseIterAdapter& operator--(int) { + ReverseIterAdapter copy = *this; + ++iter_; + return copy; + } + ReverseIterAdapter operator+(difference_type offset) const { + return ReverseIterAdapter(iter_ - offset); + } + + template + typename std::enable_if::value, + typename T::difference_type>::type inline + operator-(const ReverseIterAdapter& rhs) const { + return rhs.iter_ - iter_; + } + + bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } + bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } + const value_type operator*() const { return Converter::convert(*iter_); } + + private: + TIter iter_; +}; + +/*! \brief array node content in array */ +class ArrayNode : public Object, public InplaceArrayBase { + public: + /*! \return The size of the array */ + size_t size() const { return this->size_; } + + /*! + * \brief Read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const ObjectRef at(int64_t i) const { return this->operator[](i); } + + /*! \return begin constant iterator */ + const ObjectRef* begin() const { return static_cast(InplaceArrayBase::AddressOf(0)); } + + /*! \return end constant iterator */ + const ObjectRef* end() const { return begin() + size_; } + + /*! \brief Release reference to all the elements */ + void clear() { ShrinkBy(size_); } + + /*! + * \brief Set i-th element of the array in-place + * \param i The index + * \param item The value to be set + */ + void SetItem(int64_t i, ObjectRef item) { this->operator[](i) = std::move(item); } + + /*! + * \brief Constructs a container and copy from another + * \param cap The capacity of the container + * \param from Source of the copy + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr CopyFrom(int64_t cap, ArrayNode* from) { + int64_t size = from->size_; + CHECK_GE(cap, size) << "ValueError: not enough capacity"; + ObjectPtr p = ArrayNode::Empty(cap); + ObjectRef* write = p->MutableBegin(); + ObjectRef* read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) ObjectRef(*read++); + } + return p; + } + + /*! + * \brief Constructs a container and move from another + * \param cap The capacity of the container + * \param from Source of the move + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr MoveFrom(int64_t cap, ArrayNode* from) { + int64_t size = from->size_; + CHECK_GE(cap, size) << "ValueError: not enough capacity"; + ObjectPtr p = ArrayNode::Empty(cap); + ObjectRef* write = p->MutableBegin(); + ObjectRef* read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t& i = p->size_ = 0; i < size; ++i) { + new (write++) ObjectRef(std::move(*read++)); + } + from->size_ = 0; + return p; + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr CreateRepeated(int64_t n, const ObjectRef& val) { + ObjectPtr p = ArrayNode::Empty(n); + ObjectRef* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < n; ++i) { + new (itr++) ObjectRef(val); + } + return p; + } + + static constexpr const uint32_t _type_index = TypeIndex::kRuntimeArray; + static constexpr const char* _type_key = "Array"; + TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object); + + private: + /*! \return Size of initialized memory, used by InplaceArrayBase. */ + size_t GetSize() const { return this->size_; } + + /*! \return begin mutable iterator */ + ObjectRef* MutableBegin() const { + return static_cast(InplaceArrayBase::AddressOf(0)); + } + + /*! \return end mutable iterator */ + ObjectRef* MutableEnd() const { return MutableBegin() + size_; } + + /*! + * \brief Create an ArrayNode with the given capacity. + * \param n Required capacity + * \return Ref-counted ArrayNode requested + */ + static ObjectPtr Empty(int64_t n = kInitSize) { + CHECK_GE(n, 0); + ObjectPtr p = make_inplace_array_object(n); + p->capacity_ = n; + p->size_ = 0; + return p; + } + + /*! + * \brief Inplace-initialize the elements starting idx from [first, last) + * \param idx The starting point + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return Self + */ + template + ArrayNode* InitRange(int64_t idx, IterType first, IterType last) { + ObjectRef* itr = MutableBegin() + idx; + for (; first != last; ++first) { + ObjectRef ref = *first; + new (itr++) ObjectRef(std::move(ref)); + } + return this; + } + + /*! + * \brief Move elements from right to left, requires src_begin > dst + * \param dst Destination + * \param src_begin The start point of copy (inclusive) + * \param src_end The end point of copy (exclusive) + * \return Self + */ + ArrayNode* MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { + ObjectRef* from = MutableBegin() + src_begin; + ObjectRef* to = MutableBegin() + dst; + while (src_begin++ != src_end) { + *to++ = std::move(*from++); + } + return this; + } + + /*! + * \brief Move elements from left to right, requires src_begin < dst + * \param dst Destination + * \param src_begin The start point of move (inclusive) + * \param src_end The end point of move (exclusive) + * \return Self + */ + ArrayNode* MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { + ObjectRef* from = MutableBegin() + src_end; + ObjectRef* to = MutableBegin() + (src_end - src_begin + dst); + while (src_begin++ != src_end) { + *--to = std::move(*--from); + } + return this; + } + + /*! + * \brief Enlarges the size of the array + * \param delta Size enlarged, should be positive + * \param val Default value + * \return Self + */ + ArrayNode* EnlargeBy(int64_t delta, const ObjectRef& val = ObjectRef(nullptr)) { + ObjectRef* itr = MutableEnd(); + while (delta-- > 0) { + new (itr++) ObjectRef(val); + ++size_; + } + return this; + } + + /*! + * \brief Shrinks the size of the array + * \param delta Size shrinked, should be positive + * \return Self + */ + ArrayNode* ShrinkBy(int64_t delta) { + ObjectRef* itr = MutableEnd(); + while (delta-- > 0) { + (--itr)->ObjectRef::~ObjectRef(); + --size_; + } + return this; + } + + /*! \brief Number of elements used */ + int64_t size_; + + /*! \brief Number of elements allocated */ + int64_t capacity_; + + /*! \brief Initial size of ArrayNode */ + static const constexpr int64_t kInitSize = 4; + + /*! \brief Expansion factor of the Array */ + static const constexpr int64_t kIncFactor = 2; + + // CRTP parent class + friend InplaceArrayBase; + + // Reference class + template + friend class Array; + + // To specialize make_object + friend ObjectPtr make_object<>(); +}; + +/*! + * \brief Array container of ObjectRef in DSL graph. + * Array implements copy-on-write semantics, which means array is mutable + * but copy will happen when array is referenced in more than two places. + * + * operator[] only provide const access, use Set to mutate the content. + * \tparam T The content ObjectRef type. + */ +template ::value>::type> +class Array : public ObjectRef { + public: + // constructors + /*! + * \brief default constructor + */ + Array() { data_ = ArrayNode::Empty(); } + + /*! + * \brief move constructor + * \param other source + */ + Array(Array&& other) : ObjectRef() { // NOLINT(*) + data_ = std::move(other.data_); + } + + /*! + * \brief copy constructor + * \param other source + */ + Array(const Array& other) : ObjectRef() { // NOLINT(*) + data_ = other.data_; + } + + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Array(ObjectPtr n) : ObjectRef(n) {} + + /*! + * \brief Constructor from iterator + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template + Array(IterType first, IterType last) { + Assign(first, last); + } + + /*! + * \brief constructor from initializer list + * \param init The initializer list + */ + Array(std::initializer_list init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief constructor from vector + * \param init The vector + */ + Array(const std::vector& init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief Constructs a container with n elements. Each element is a copy of val + * \param n The size of the container + * \param val The init value + */ + explicit Array(const size_t n, const T& val) { data_ = ArrayNode::CreateRepeated(n, val); } + + /*! + * \brief move assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array& operator=(Array&& other) { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief copy assign operator + * \param other The source of assignment + * \return reference to self. + */ + Array& operator=(const Array& other) { + data_ = other.data_; + return *this; + } + + public: + // iterators + struct ValueConverter { + using ResultType = T; + static T convert(const ObjectRef& n) { return DowncastNoCheck(n); } + }; + + using iterator = IterAdapter; + using reverse_iterator = ReverseIterAdapter; + + /*! \return begin iterator */ + iterator begin() const { return iterator(GetArrayNode()->begin()); } + + /*! \return end iterator */ + iterator end() const { return iterator(GetArrayNode()->end()); } + + /*! \return rbegin iterator */ + reverse_iterator rbegin() const { + // ArrayNode::end() is never nullptr + return reverse_iterator(GetArrayNode()->end() - 1); + } + + /*! \return rend iterator */ + reverse_iterator rend() const { + // ArrayNode::begin() is never nullptr + return reverse_iterator(GetArrayNode()->begin() - 1); + } + + public: + // const methods in std::vector + /*! + * \brief Immutably read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const T operator[](int64_t i) const { + ArrayNode* p = GetArrayNode(); + CHECK(p != nullptr) << "ValueError: cannot index a null array"; + CHECK(0 <= i && i < p->size_) << "IndexError: indexing " << i << " on an array of size " + << p->size_; + return DowncastNoCheck(*(p->begin() + i)); + } + + /*! \return The size of the array */ + size_t size() const { + ArrayNode* p = GetArrayNode(); + return p == nullptr ? 0 : GetArrayNode()->size_; + } + + /*! \return The capacity of the array */ + size_t capacity() const { + ArrayNode* p = GetArrayNode(); + return p == nullptr ? 0 : GetArrayNode()->capacity_; + } + + /*! \return Whether array is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the array */ + const T front() const { + ArrayNode* p = GetArrayNode(); + CHECK(p != nullptr) << "ValueError: cannot index a null array"; + CHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; + return DowncastNoCheck(*(p->begin())); + } + + /*! \return The last element of the array */ + const T back() const { + ArrayNode* p = GetArrayNode(); + CHECK(p != nullptr) << "ValueError: cannot index a null array"; + CHECK_GT(p->size_, 0) << "IndexError: cannot index an empty array"; + return DowncastNoCheck(*(p->end() - 1)); + } + + public: + // mutation in std::vector, implements copy-on-write + + /*! + * \brief push a new item to the back of the list + * \param item The item to be pushed. + */ + void push_back(const T& item) { + ArrayNode* p = CopyOnWrite(1); + p->EmplaceInit(p->size_++, item); + } + + /*! + * \brief Insert an element into the given position + * \param position An iterator pointing to the insertion point + * \param val The element to insert + */ + void insert(iterator position, const T& val) { + CHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + auto addr = CopyOnWrite(1) // + ->EnlargeBy(1) // + ->MoveElementsRight(idx + 1, idx, size) // + ->MutableBegin(); + new (addr + idx) ObjectRef(val); + } + + /*! + * \brief Insert a range of elements into the given position + * \param position An iterator pointing to the insertion point + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + template + void insert(iterator position, IterType first, IterType last) { + if (first == last) { + return; + } + CHECK(data_ != nullptr) << "ValueError: cannot insert a null array"; + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + int64_t numel = std::distance(first, last); + CopyOnWrite(numel) + ->EnlargeBy(numel) + ->MoveElementsRight(idx + numel, idx, size) + ->InitRange(idx, first, last); + } + + /*! \brief Remove the last item of the list */ + void pop_back() { + CHECK(data_ != nullptr) << "ValueError: cannot pop_back because array is null"; + int64_t size = GetArrayNode()->size_; + CHECK_GT(size, 0) << "ValueError: cannot pop_back because array is empty"; + CopyOnWrite()->ShrinkBy(1); + } + + /*! + * \brief Erase an element on the given position + * \param position An iterator pointing to the element to be erased + */ + void erase(iterator position) { + CHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; + int64_t st = std::distance(begin(), position); + int64_t size = GetArrayNode()->size_; + CHECK(0 <= st && st < size) << "ValueError: cannot erase at index " << st + << ", because Array size is " << size; + CopyOnWrite() // + ->MoveElementsLeft(st, st + 1, size) // + ->ShrinkBy(1); + } + + /*! + * \brief Erase a given range of elements + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + void erase(iterator first, iterator last) { + if (first == last) { + return; + } + CHECK(data_ != nullptr) << "ValueError: cannot erase a null array"; + int64_t size = GetArrayNode()->size_; + int64_t st = std::distance(begin(), first); + int64_t ed = std::distance(begin(), last); + CHECK_LT(st, ed) << "ValueError: cannot erase array in range [" << st << ", " << ed << ")"; + CHECK(0 <= st && st <= size && 0 <= ed && ed <= size) + << "ValueError: cannot erase array in range [" << st << ", " << ed << ")" + << ", because array size is " << size; + CopyOnWrite() // + ->MoveElementsLeft(st, ed, size) // + ->ShrinkBy(ed - st); + } + + /*! + * \brief Resize the array. + * \param n The new size. + */ + void resize(int64_t n) { + CHECK_GE(n, 0) << "ValueError: cannot resize an Array to negative size"; + if (data_ == nullptr) { + SwitchContainer(n); + return; + } + int64_t size = GetArrayNode()->size_; + if (size < n) { + CopyOnWrite(n - size)->EnlargeBy(n - size); + } else if (size > n) { + CopyOnWrite()->ShrinkBy(size - n); + } + } + + /*! + * \brief Make sure the list has the capacity of at least n + * \param n lower bound of the capacity + */ + void reserve(int64_t n) { + if (data_ == nullptr || n > GetArrayNode()->capacity_) { + SwitchContainer(n); + } + } + + /*! \brief Release reference to all the elements */ + void clear() { + if (data_ != nullptr) { + ArrayNode* p = CopyOnWrite(); + p->clear(); + } + } + + public: + // Array's own methods + + /*! + * \brief set i-th element of the array. + * \param i The index + * \param value The value to be setted. + */ + void Set(int64_t i, T value) { + ArrayNode* p = this->CopyOnWrite(); + CHECK(0 <= i && i < p->size_) << "IndexError: indexing " << i << " on an array of size " + << p->size_; + *(p->MutableBegin() + i) = std::move(value); + } + + /*! \return The underlying ArrayNode */ + ArrayNode* GetArrayNode() const { return static_cast(data_.get()); } + + /*! + * \brief Helper function to apply fmutate to mutate an array. + * \param fmutate The transformation function T -> T. + * \tparam F the type of the mutation function. + * \note This function performs copy on write optimization. + */ + template + void MutateByApply(F fmutate) { + if (data_ == nullptr) { + return; + } + struct StackFrame { + ArrayNode* p; + ObjectRef* itr; + int64_t i; + int64_t size; + }; + std::unique_ptr s = std::make_unique(); + s->p = GetArrayNode(); + s->itr = s->p->MutableBegin(); + s->i = 0; + s->size = s->p->size_; + if (!data_.unique()) { + // Loop invariant: keeps iterating when + // 1) data is not unique + // 2) no elements are actually mutated yet + for (; s->i < s->size; ++s->i, ++s->itr) { + T new_elem = fmutate(DowncastNoCheck(*s->itr)); + // do nothing when there is no mutation + if (new_elem.same_as(*s->itr)) { + continue; + } + // loop invariant breaks when the first real mutation happens + // we copy the elements into a new unique array + ObjectPtr copy = ArrayNode::CopyFrom(s->p->capacity_, s->p); + s->itr = copy->MutableBegin() + (s->i++); + *s->itr++ = std::move(new_elem); + data_ = std::move(copy); + // make sure `data_` is unique and break + break; + } + } + // when execution comes to this line, it is guaranteed that either + // 1) i == size + // or 2) data_.unique() is true + for (; s->i < s->size; ++s->i, ++s->itr) { + *s->itr = std::move(fmutate(std::move(DowncastNoCheck(std::move(*s->itr))))); + } + } + + /*! + * \brief reset the array to content from iterator. + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template + void Assign(IterType first, IterType last) { + int64_t cap = std::distance(first, last); + CHECK_GE(cap, 0) << "ValueError: cannot construct an Array of negative size"; + ArrayNode* p = GetArrayNode(); + if (p != nullptr && data_.unique() && p->capacity_ >= cap) { + // do not have to make new space + p->clear(); + } else { + // create new space + data_ = ArrayNode::Empty(cap); + p = GetArrayNode(); + } + // To ensure exception safety, size is only incremented after the initialization succeeds + ObjectRef* itr = p->MutableBegin(); + for (int64_t& i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { + new (itr) ObjectRef(*first); + } + } + + /*! + * \brief Copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + ArrayNode* CopyOnWrite() { + if (data_ == nullptr) { + return SwitchContainer(ArrayNode::kInitSize); + } + if (!data_.unique()) { + return SwitchContainer(capacity()); + } + return static_cast(data_.get()); + } + + /*! \brief specify container node */ + using ContainerType = ArrayNode; + + private: + /*! + * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. + * \param reserve_extra Number of extra slots needed + * \return ArrayNode pointer to the unique copy + */ + ArrayNode* CopyOnWrite(int64_t reserve_extra) { + ArrayNode* p = GetArrayNode(); + if (p == nullptr) { + return SwitchContainer(std::max(ArrayNode::kInitSize, reserve_extra)); + } + if (p->capacity_ >= p->size_ + reserve_extra) { + return CopyOnWrite(); + } + int64_t cap = p->capacity_ * ArrayNode::kIncFactor; + cap = std::max(cap, p->size_ + reserve_extra); + return SwitchContainer(cap); + } + + /*! + * \brief Move or copy the ArrayNode to new address with the given capacity + * \param capacity The capacity requirement of the new address + */ + ArrayNode* SwitchContainer(int64_t capacity) { + if (data_ == nullptr) { + data_ = ArrayNode::Empty(capacity); + } else if (data_.unique()) { + data_ = ArrayNode::MoveFrom(capacity, GetArrayNode()); + } else { + data_ = ArrayNode::CopyFrom(capacity, GetArrayNode()); + } + return static_cast(data_.get()); + } +}; + +// Specialize make_object to make sure it is correct. +template <> +inline ObjectPtr make_object() { + return ArrayNode::Empty(); +} + /*! \brief An object representing a structure or enumeration. */ class ADTObj : public Object, public InplaceArrayBase { public: diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 51b1372e9ff2..cb07b3dfc62b 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -64,6 +64,8 @@ struct TypeIndex { kRuntimeNDArray = 2, /*! \brief runtime::String. */ kRuntimeString = 3, + /*! \brief runtime::Array. */ + kRuntimeArray = 4, // static assignments that may subject to change. kRuntimeClosure, kRuntimeADT, diff --git a/python/tvm/ir/container.py b/python/tvm/ir/container.py index 11ef107f5514..c415454ae35f 100644 --- a/python/tvm/ir/container.py +++ b/python/tvm/ir/container.py @@ -22,7 +22,7 @@ from tvm.runtime import _ffi_node_api -@tvm._ffi.register_object +@tvm._ffi.register_object("Array") class Array(Object): """Array container of TVM. diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 24b108e628e8..1a9283de0cfc 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -158,11 +158,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << '['; - for (size_t i = 0; i < op->data.size(); ++i) { + for (size_t i = 0; i < op->size(); ++i) { if (i != 0) { p->stream << ", "; } - p->Print(op->data[i]); + p->Print(op->at(i)); } p->stream << ']'; }); diff --git a/src/node/container.cc b/src/node/container.cc index a5e7669fc66d..6f737435714d 100644 --- a/src/node/container.cc +++ b/src/node/container.cc @@ -137,16 +137,16 @@ struct ArrayNodeTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; static void SHashReduce(const ArrayNode* key, SHashReducer hash_reduce) { - hash_reduce(static_cast(key->data.size())); - for (size_t i = 0; i < key->data.size(); ++i) { - hash_reduce(key->data[i]); + hash_reduce(static_cast(key->size())); + for (size_t i = 0; i < key->size(); ++i) { + hash_reduce(key->at(i)); } } static bool SEqualReduce(const ArrayNode* lhs, const ArrayNode* rhs, SEqualReducer equal) { - if (lhs->data.size() != rhs->data.size()) return false; - for (size_t i = 0; i < lhs->data.size(); ++i) { - if (!equal(lhs->data[i], rhs->data[i])) return false; + if (lhs->size() != rhs->size()) return false; + for (size_t i = 0; i < lhs->size(); ++i) { + if (!equal(lhs->at(i), rhs->at(i))) return false; } return true; } @@ -167,9 +167,7 @@ TVM_REGISTER_GLOBAL("node.Array").set_body([](TVMArgs args, TVMRetValue* ret) { data.push_back(ObjectRef(nullptr)); } } - auto node = make_object(); - node->data = std::move(data); - *ret = Array(node); + *ret = Array(data); }); TVM_REGISTER_GLOBAL("node.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -178,15 +176,15 @@ TVM_REGISTER_GLOBAL("node.ArrayGetItem").set_body([](TVMArgs args, TVMRetValue* Object* ptr = static_cast(args[0].value().v_handle); CHECK(ptr->IsInstance()); auto* n = static_cast(ptr); - CHECK_LT(static_cast(i), n->data.size()) << "out of bound of array"; - *ret = n->data[static_cast(i)]; + CHECK_LT(static_cast(i), n->size()) << "out of bound of array"; + *ret = n->at(i); }); TVM_REGISTER_GLOBAL("node.ArraySize").set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args[0].type_code(), kTVMObjectHandle); Object* ptr = static_cast(args[0].value().v_handle); CHECK(ptr->IsInstance()); - *ret = static_cast(static_cast(ptr)->data.size()); + *ret = static_cast(static_cast(ptr)->size()); }); struct MapNodeTrait { @@ -368,20 +366,20 @@ TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) if (ptr->IsInstance()) { auto* n = static_cast(ptr); - auto rkvs = make_object(); + Array rkvs; for (const auto& kv : n->data) { - rkvs->data.push_back(kv.first); - rkvs->data.push_back(kv.second); + rkvs.push_back(kv.first); + rkvs.push_back(kv.second); } - *ret = Array(rkvs); + *ret = std::move(rkvs); } else { auto* n = static_cast(ptr); - auto rkvs = make_object(); + Array rkvs; for (const auto& kv : n->data) { - rkvs->data.push_back(tir::StringImmNode::make(kv.first)); - rkvs->data.push_back(kv.second); + rkvs.push_back(tir::StringImmNode::make(kv.first)); + rkvs.push_back(kv.second); } - *ret = Array(rkvs); + *ret = std::move(rkvs); } }); } // namespace tvm diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 4675c5339f8d..9bd94f0741d0 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -105,7 +105,7 @@ class NodeIndexer : public AttrVisitor { if (node->IsInstance()) { ArrayNode* n = static_cast(node); - for (const auto& sp : n->data) { + for (const auto& sp : *n) { MakeIndex(const_cast(sp.get())); } } else if (node->IsInstance()) { @@ -244,8 +244,8 @@ class JSONAttrGetter : public AttrVisitor { if (node->IsInstance()) { ArrayNode* n = static_cast(node); - for (size_t i = 0; i < n->data.size(); ++i) { - node_->data.push_back(node_index_->at(const_cast(n->data[i].get()))); + for (size_t i = 0; i < n->size(); ++i) { + node_->data.push_back(node_index_->at(const_cast(n->at(i).get()))); } } else if (node->IsInstance()) { MapNode* n = static_cast(node); @@ -270,7 +270,7 @@ class JSONAttrGetter : public AttrVisitor { // from given json node. class JSONAttrSetter : public AttrVisitor { public: - const std::vector >* node_list_; + const std::vector>* node_list_; const std::vector* tensor_list_; JSONNode* node_; @@ -322,9 +322,10 @@ class JSONAttrSetter : public AttrVisitor { if (node->IsInstance()) { ArrayNode* n = static_cast(node); - n->data.clear(); + CHECK_EQ(n->size(), node_->data.size()); + int64_t i = 0; for (size_t index : node_->data) { - n->data.push_back(ObjectRef(node_list_->at(index))); + n->SetItem(i++, ObjectRef(node_list_->at(index))); } } else if (node->IsInstance()) { MapNode* n = static_cast(node); @@ -414,21 +415,23 @@ std::string SaveJSON(const ObjectRef& n) { } ObjectRef LoadJSON(std::string json_str) { - std::istringstream is(json_str); - dmlc::JSONReader reader(&is); JSONGraph jgraph; - // load in json graph. - jgraph.Load(&reader); - std::vector > nodes; + std::vector> nodes; std::vector tensors; - // load in tensors - for (const std::string& blob : jgraph.b64ndarrays) { - dmlc::MemoryStringStream mstrm(const_cast(&blob)); - support::Base64InStream b64strm(&mstrm); - b64strm.InitPosition(); - runtime::NDArray temp; - CHECK(temp.Load(&b64strm)); - tensors.emplace_back(temp); + { + // load in json graph. + std::istringstream is(json_str); + dmlc::JSONReader reader(&is); + jgraph.Load(&reader); + // load in tensors + for (const std::string& blob : jgraph.b64ndarrays) { + dmlc::MemoryStringStream mstrm(const_cast(&blob)); + support::Base64InStream b64strm(&mstrm); + b64strm.InitPosition(); + runtime::NDArray temp; + CHECK(temp.Load(&b64strm)); + tensors.emplace_back(temp); + } } ReflectionVTable* reflection = ReflectionVTable::Global(); @@ -436,9 +439,12 @@ ObjectRef LoadJSON(std::string json_str) { nodes.reserve(jgraph.nodes.size()); for (const JSONNode& jnode : jgraph.nodes) { - if (jnode.type_key.length() != 0) { + if (jnode.type_key == ArrayNode::_type_key) { + CHECK(jnode.repr_bytes.empty()); + nodes.emplace_back(ArrayNode::CreateRepeated(jnode.data.size(), ObjectRef(nullptr))); + } else if (jnode.type_key.length() != 0) { ObjectPtr node = reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes); - nodes.emplace_back(node); + nodes.emplace_back(std::move(node)); } else { nodes.emplace_back(ObjectPtr()); } diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 5166a489e22f..076339d774b4 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -732,7 +732,7 @@ Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { Doc doc; doc << "["; std::vector arr_vals; - for (auto val : op->data) { + for (auto val : *op) { arr_vals.push_back(PrintAttr(val)); } doc << Doc::Concat(arr_vals); diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 511a24377738..0bcc1488cba8 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -164,11 +164,11 @@ Doc TIRTextPrinter::PrintIRModule(const IRModule& module) { Doc TIRTextPrinter::PrintArray(const ArrayNode* op) { Doc doc; doc << '['; - for (size_t i = 0; i < op->data.size(); ++i) { + for (size_t i = 0; i < op->size(); ++i) { if (i != 0) { doc << ", "; } - doc << Print(op->data[i]); + doc << Print(op->at(i)); } doc << ']'; return doc; diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 8ec094e9bdf1..6ccf5853e023 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -509,8 +509,8 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, if (param->newshape) { auto temp = param->newshape.value(); if (param->reverse) { - data_shape.assign(data->shape.rbegin(), data->shape.rend()); - newshape.assign(temp.rbegin(), temp.rend()); + data_shape.Assign(data->shape.rbegin(), data->shape.rend()); + newshape.Assign(temp.rbegin(), temp.rend()); } else { data_shape = data->shape; newshape = temp; @@ -1938,7 +1938,7 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, } reporter->Assign(types[1], TupleType(Array(fields))); } else { - auto indices = param->indices_or_sections.as()->data; + auto indices = Downcast>(param->indices_or_sections); auto begin = IndexExpr(tir::make_zero(DataType::Int(32))); std::vector fields; for (unsigned int i = 0; i < indices.size(); ++i) { diff --git a/src/relay/transforms/infer_layout_util.h b/src/relay/transforms/infer_layout_util.h index 7ced51db781e..9868ee5d03db 100644 --- a/src/relay/transforms/infer_layout_util.h +++ b/src/relay/transforms/infer_layout_util.h @@ -135,9 +135,9 @@ inline Array> BinaryBroadcastLayout(const Attrs& attrs, } if (new_in_layouts.defined()) { - layouts.assign(new_in_layouts.begin(), new_in_layouts.end()); + layouts.Assign(new_in_layouts.begin(), new_in_layouts.end()); } else { - layouts.assign(old_in_layouts.begin(), old_in_layouts.end()); + layouts.Assign(old_in_layouts.begin(), old_in_layouts.end()); } if (!layouts[0].defined() && !layouts[1].defined()) { diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index ed2880653d63..95612635a3d9 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -37,10 +37,10 @@ namespace te { template size_t FindNodeRef(ArrayNode* array_node, const T& v) { const Object* n = v.get(); - for (size_t i = 0; i < array_node->data.size(); ++i) { - if (array_node->data[i].get() == n) return i; + for (size_t i = 0; i < array_node->size(); ++i) { + if (array_node->at(i).get() == n) return i; } - return array_node->data.size(); + return array_node->size(); } // The replacer of cache. @@ -158,13 +158,13 @@ Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope, s->op = repl_op; } ReplaceDataFlow((*this)->stages, &vmap, &rvmap); - ArrayNode* stages = (*this)->stages.CopyOnWrite(); + Array& stages = (*this)->stages; Stage op_stage = operator[](tensor->op); - size_t pos = FindNodeRef(stages, op_stage); + size_t pos = FindNodeRef(stages.GetArrayNode(), op_stage); Stage cache_stage = Stage(cache->op); cache_stage.set_scope(scope); - CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos + 1, cache_stage); + CHECK_LT(pos, stages.size()); + stages.insert(stages.begin() + pos + 1, cache_stage); (*this)->stage_map.Set(cache->op, cache_stage); // Update group cache_stage->group = op_stage->group; @@ -251,12 +251,12 @@ Array ReplaceOriginalOp(Schedule sch, Stage orig_stage, const std::strin orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; orig_stage->relations = Array(); // create schedule for new cached stage. - ArrayNode* stages = sch->stages.CopyOnWrite(); - size_t pos = FindNodeRef(stages, orig_stage); + Array& stages = sch->stages; + size_t pos = FindNodeRef(stages.GetArrayNode(), orig_stage); Stage cache_stage = Stage(cache_op); cache_stage.set_scope(scope); - CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos, cache_stage); + CHECK_LT(pos, stages.size()); + stages.insert(stages.begin() + pos, cache_stage); sch->stage_map.Set(cache_op, cache_stage); // Update group cache_stage->group = orig_stage->group; @@ -465,14 +465,14 @@ void RebaseNonZeroMinLoop(const Schedule& sch) { if (it != s->iter_var_attrs.end() && (*it).second->bind_thread.defined()) { continue; } - if (idx < leaf_vars->data.size()) { + if (idx < leaf_vars->size()) { // insert rebase IterVar rebased = IterVarNode::make(Range(), iv->var.copy_with_suffix(""), iv->iter_type); s->relations.push_back(RebaseNode::make(iv, rebased)); if (s->iter_var_attrs.count(iv)) { s->iter_var_attrs.Set(rebased, s->iter_var_attrs.at(iv)); } - leaf_vars->data[idx] = rebased; + leaf_vars->SetItem(idx, rebased); rebase_map[iv] = rebased; } } @@ -635,8 +635,7 @@ Array Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite(); { size_t axis_pos = FindNodeRef(leaf_vars, axis); - CHECK_NE(axis_pos, leaf_vars->data.size()) - << "Cannot find IterVar " << axis << " in leaf iter vars"; + CHECK_NE(axis_pos, leaf_vars->size()) << "Cannot find IterVar " << axis << " in leaf iter vars"; } // Find touched reduction axis. std::unordered_map touch_map; @@ -762,12 +761,12 @@ Array Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int f } // initialize the factored stage. Operation factor_op(n); - ArrayNode* stages = (*this)->stages.CopyOnWrite(); - size_t stage_pos = FindNodeRef(stages, reduce_stage); + Array& stages = (*this)->stages; + size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage); Stage factor_stage = Stage(factor_op); factor_stage->relations = rels; - CHECK_LT(stage_pos, stages->data.size()); - stages->data.insert(stages->data.begin() + stage_pos, factor_stage); + CHECK_LT(stage_pos, stages.size()); + stages.insert(stages.begin() + stage_pos, factor_stage); (*this)->stage_map.Set(factor_op, factor_stage); factor_stage->group = reduce_stage->group; if (factor_stage->group.defined()) { diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc index e73c3c7eca5e..7de5257a18a4 100644 --- a/src/te/schedule/schedule_lang.cc +++ b/src/te/schedule/schedule_lang.cc @@ -37,17 +37,17 @@ namespace te { template size_t FindNodeRef(ArrayNode* array_node, const T& v) { const Object* n = v.get(); - for (size_t i = 0; i < array_node->data.size(); ++i) { - if (array_node->data[i].get() == n) return i; + for (size_t i = 0; i < array_node->size(); ++i) { + if (array_node->at(i).get() == n) return i; } - return array_node->data.size(); + return array_node->size(); } size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) { size_t pos = FindNodeRef(leaf_vars, v); - if (pos < leaf_vars->data.size()) return pos; + if (pos < leaf_vars->size()) return pos; - if (FindNodeRef(all_vars, v) < all_vars->data.size()) { + if (FindNodeRef(all_vars, v) < all_vars->size()) { LOG(FATAL) << "Operate on iter var " << v << "that has already been split"; } else { LOG(FATAL) << "Operate on iter var " << v << "that is not part of the schedule"; @@ -68,17 +68,17 @@ void Split(StageNode* self, IterVar parent, PrimExpr factor, PrimExpr nparts, It *p_outer = outer; *p_inner = inner; // The splits - ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); - ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); - size_t pos = FindLeafVar(all_vars, leaf_vars, parent); + Array& all_vars = self->all_iter_vars; + Array& leaf_vars = self->leaf_iter_vars; + size_t pos = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), parent); self->relations.push_back(SplitNode::make(parent, outer, inner, factor, nparts)); // add vars to all vars - all_vars->data.push_back(outer); - all_vars->data.push_back(inner); + all_vars.push_back(outer); + all_vars.push_back(inner); // replace the position. - leaf_vars->data.erase(leaf_vars->data.begin() + pos); - leaf_vars->data.insert(leaf_vars->data.begin() + pos, inner); - leaf_vars->data.insert(leaf_vars->data.begin() + pos, outer); + leaf_vars.erase(leaf_vars.begin() + pos); + leaf_vars.insert(leaf_vars.begin() + pos, inner); + leaf_vars.insert(leaf_vars.begin() + pos, outer); } Stage::Stage(Operation op) { @@ -188,14 +188,14 @@ Stage& Stage::env_threads(Array threads) { CHECK(self->op.defined() && self->op.as()) << "env_threads is only valid for composite ops such as ScanOp"; CHECK_EQ(self->env_threads.size(), 0U) << "Already set env_threads"; - ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); - ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); + Array& leaf_vars = self->leaf_iter_vars; + Array& all_vars = self->all_iter_vars; std::vector temp; for (IterVar iv : threads) { temp.push_back(iv); } - leaf_vars->data.insert(leaf_vars->data.begin(), temp.begin(), temp.end()); - all_vars->data.insert(all_vars->data.end(), temp.begin(), temp.end()); + leaf_vars.insert(leaf_vars.begin(), temp.begin(), temp.end()); + all_vars.insert(all_vars.end(), temp.begin(), temp.end()); self->env_threads = threads; return *this; } @@ -233,11 +233,11 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT IterVar fused = IterVarNode::make(Range(), Var(fused_name, outer->var.dtype()), iter_type); - ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); - ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); + Array& all_vars = self->all_iter_vars; + Array& leaf_vars = self->leaf_iter_vars; - size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner); - size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer); + size_t pos_inner = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), inner); + size_t pos_outer = FindLeafVar(all_vars.GetArrayNode(), leaf_vars.GetArrayNode(), outer); if (pos_inner + 1 == pos_outer) { std::swap(outer, inner); std::swap(pos_inner, pos_outer); @@ -245,10 +245,9 @@ Stage& Stage::fuse(IterVar outer, IterVar inner, IterVar* p_target) { // NOLINT CHECK_EQ(pos_inner, pos_outer + 1) << "Can only fuse iterations that are consecutive between each other"; self->relations.push_back(FuseNode::make(outer, inner, fused)); - all_vars->data.push_back(fused); - leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer, - leaf_vars->data.begin() + pos_inner + 1); - leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer, fused); + all_vars.push_back(fused); + leaf_vars.erase(leaf_vars.begin() + pos_outer, leaf_vars.begin() + pos_inner + 1); + leaf_vars.insert(leaf_vars.begin() + pos_outer, fused); *p_target = fused; return *this; } @@ -267,10 +266,10 @@ Stage& Stage::fuse(const Array& axes, IterVar* p_target) { // NOLINT(* IterVar singleton = IterVarNode::make(Range::make_by_min_extent(0, 1), Var("singleton", DataType::Int(32)), kDataPar); self->relations.push_back(SingletonNode::make(singleton)); - ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); - ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); - all_vars->data.push_back(singleton); - leaf_vars->data.insert(leaf_vars->data.begin(), singleton); + Array& all_vars = self->all_iter_vars; + Array& leaf_vars = self->leaf_iter_vars; + all_vars.push_back(singleton); + leaf_vars.insert(leaf_vars.begin(), singleton); *p_target = singleton; } return *this; @@ -296,11 +295,11 @@ Stage& Stage::reorder(const Array& order) { // NOLINT(*) } std::vector temp; for (size_t i = 0; i < pos.size(); ++i) { - temp.emplace_back(leaf_vars->data[pos[i]]); + temp.emplace_back(leaf_vars->at(pos[i])); } std::sort(pos.begin(), pos.end()); for (size_t i = 0; i < pos.size(); ++i) { - leaf_vars->data[pos[i]] = temp[i]; + leaf_vars->SetItem(pos[i], temp[i]); } return *this; } diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 35888bd7f9e1..cd749b9ced81 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -116,7 +116,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { IterVar iv = Downcast(op->node); env_threads_.push_back(iv); StmtExprVisitor::VisitStmt_(op); - env_threads_.CopyOnWrite()->data.pop_back(); + env_threads_.pop_back(); } else if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); env_threads_.push_back(iv); @@ -131,7 +131,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { } else { StmtExprVisitor::VisitStmt_(op); } - env_threads_.CopyOnWrite()->data.pop_back(); + env_threads_.pop_back(); } else { StmtExprVisitor::VisitStmt_(op); } diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 5d1f4720b965..f73355397887 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -171,6 +171,106 @@ TEST(Array, Iterator) { CHECK(vector[1].as()->value == 2); } +TEST(Array, PushPop) { + using namespace tvm; + Array a; + std::vector b; + for (int i = 0; i < 10; ++i) { + a.push_back(i); + b.push_back(i); + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), b.size()); + int n = a.size(); + for (int j = 0; j < n; ++j) { + ASSERT_EQ(a[j], b[j]); + } + } + for (int i = 9; i >= 0; --i) { + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), b.size()); + a.pop_back(); + b.pop_back(); + int n = a.size(); + for (int j = 0; j < n; ++j) { + ASSERT_EQ(a[j], b[j]); + } + } + ASSERT_EQ(a.empty(), true); +} + +TEST(Array, ResizeReserveClear) { + using namespace tvm; + for (size_t n = 0; n < 10; ++n) { + Array a; + Array b; + a.resize(n); + b.reserve(n); + ASSERT_EQ(a.size(), n); + ASSERT_GE(a.capacity(), n); + a.clear(); + b.clear(); + ASSERT_EQ(a.size(), 0); + ASSERT_EQ(b.size(), 0); + } +} + +TEST(Array, InsertErase) { + using namespace tvm; + Array a; + std::vector b; + for (int n = 1; n <= 10; ++n) { + a.insert(a.end(), n); + b.insert(b.end(), n); + for (int pos = 0; pos <= n; ++pos) { + a.insert(a.begin() + pos, pos); + b.insert(b.begin() + pos, pos); + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n + 1); + ASSERT_EQ(b.size(), n + 1); + for (int k = 0; k <= n; ++k) { + ASSERT_EQ(a[k], b[k]); + } + a.erase(a.begin() + pos); + b.erase(b.begin() + pos); + } + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n); + } +} + +TEST(Array, InsertEraseRange) { + using namespace tvm; + Array range_a{-1, -2, -3, -4}; + std::vector range_b{-1, -2, -3, -4}; + Array a; + std::vector b; + for (size_t n = 1; n <= 10; ++n) { + a.insert(a.end(), n); + b.insert(b.end(), n); + for (size_t pos = 0; pos <= n; ++pos) { + a.insert(a.begin() + pos, range_a.begin(), range_a.end()); + b.insert(b.begin() + pos, range_b.begin(), range_b.end()); + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n + range_a.size()); + ASSERT_EQ(b.size(), n + range_b.size()); + size_t m = n + range_a.size(); + for (size_t k = 0; k < m; ++k) { + ASSERT_EQ(a[k], b[k]); + } + a.erase(a.begin() + pos, a.begin() + pos + range_a.size()); + b.erase(b.begin() + pos, b.begin() + pos + range_b.size()); + } + ASSERT_EQ(a.front(), b.front()); + ASSERT_EQ(a.back(), b.back()); + ASSERT_EQ(a.size(), n); + } +} + TEST(Map, Expr) { using namespace tvm; Var x("x"); diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py index b0399a53a732..226d5ba218e8 100644 --- a/tests/python/relay/test_backend_graph_runtime.py +++ b/tests/python/relay/test_backend_graph_runtime.py @@ -105,7 +105,7 @@ def test_with_params(): mod.run() res = mod.get_output(0).asnumpy() ref_res = np.exp(y_data + x_data) - tvm.testing.assert_allclose(res, ref_res) + tvm.testing.assert_allclose(res, ref_res, atol=1e-5, rtol=1e-5) def test_plan_memory():