Skip to content

Commit

Permalink
[Runtime] Introduce runtime::Array (apache#5585)
Browse files Browse the repository at this point in the history
* Introduce runtime::Array

* Sync with dmlc-core

* Tests added: size, capacity, empty, front, back, push_back, pop_back, insert * 2, erase * 2, resize, reserve, clear
  • Loading branch information
junrushao authored and Trevor Morris committed Jun 9, 2020
1 parent caacf21 commit 8c7f09a
Show file tree
Hide file tree
Showing 18 changed files with 990 additions and 388 deletions.
286 changes: 6 additions & 280 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@

namespace tvm {

using runtime::Array;
using runtime::ArrayNode;
using runtime::IterAdapter;
using runtime::make_object;
using runtime::Object;
using runtime::ObjectEqual;
Expand All @@ -46,16 +49,6 @@ using runtime::ObjectRef;
using runtime::String;
using runtime::StringObj;

/*! \brief array node content in array */
class ArrayNode : public Object {
public:
/*! \brief the data content */
std::vector<ObjectRef> data;

static constexpr const char* _type_key = "Array";
TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);
};

/*! \brief map node content */
class MapNode : public Object {
public:
Expand All @@ -82,273 +75,6 @@ class StrMapNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Object);
};

/*!
* \brief iterator adapter that adapts TIter to return another type.
* \tparam Converter a struct that contains converting function
* \tparam TIter the content iterator type.
*/
template <typename Converter, typename TIter>
class IterAdapter {
public:
using difference_type = typename std::iterator_traits<TIter>::difference_type;
using value_type = typename Converter::ResultType;
using pointer = typename Converter::ResultType*;
using reference = typename Converter::ResultType&; // NOLINT(*)
using iterator_category = typename std::iterator_traits<TIter>::iterator_category;

explicit IterAdapter(TIter iter) : iter_(iter) {}
inline IterAdapter& operator++() {
++iter_;
return *this;
}
inline IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); }

template <typename T = IterAdapter>
typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value,
typename T::difference_type>::type inline
operator-(const IterAdapter& rhs) const {
return iter_ - rhs.iter_;
}

inline bool operator==(IterAdapter other) const { return iter_ == other.iter_; }
inline bool operator!=(IterAdapter other) const { return !(*this == other); }
inline const value_type operator*() const { return Converter::convert(*iter_); }

private:
TIter iter_;
};

/*!
* \brief Array container of NodeRef in DSL graph.
* Array implements copy on write semantics, which means array is mutable
* but copy will happen when array is referenced in more than two places.
*
* operator[] only provide const acces, use Set to mutate the content.
* \tparam T The content NodeRef type.
*/
template <typename T,
typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
class Array : public ObjectRef {
public:
/*!
* \brief default constructor
*/
Array() { data_ = make_object<ArrayNode>(); }
/*!
* \brief move constructor
* \param other source
*/
Array(Array<T>&& other) : ObjectRef() { // NOLINT(*)
data_ = std::move(other.data_);
}
/*!
* \brief copy constructor
* \param other source
*/
Array(const Array<T>& other) : ObjectRef() { // NOLINT(*)
data_ = std::move(other.data_);
}
/*!
* \brief constructor from pointer
* \param n the container pointer
*/
explicit Array(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template <typename IterType>
Array(IterType begin, IterType end) {
assign(begin, end);
}
/*!
* \brief constructor from initializer list
* \param init The initalizer list
*/
Array(std::initializer_list<T> init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
Array(const std::vector<T>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief Constructs a container with n elements. Each element is a copy of val
* \param n The size of the container
* \param val The init value
*/
explicit Array(size_t n, const T& val) {
auto tmp_node = make_object<ArrayNode>();
for (size_t i = 0; i < n; ++i) {
tmp_node->data.push_back(val);
}
data_ = std::move(tmp_node);
}
/*!
* \brief move assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(Array<T>&& other) {
data_ = std::move(other.data_);
return *this;
}
/*!
* \brief copy assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(const Array<T>& other) {
data_ = other.data_;
return *this;
}
/*!
* \brief reset the array to content from iterator.
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template <typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_object<ArrayNode>();
for (IterType it = begin; it != end; ++it) {
n->data.push_back(T(*it));
}
data_ = std::move(n);
}
/*!
* \brief Read i-th element from array.
* \param i The index
* \return the i-th element.
*/
inline const T operator[](size_t i) const {
return DowncastNoCheck<T>(static_cast<const ArrayNode*>(data_.get())->data[i]);
}
/*! \return The size of the array */
inline size_t size() const {
if (data_.get() == nullptr) return 0;
return static_cast<const ArrayNode*>(data_.get())->data.size();
}
/*!
* \brief copy on write semantics
* Do nothing if current handle is the unique copy of the array.
* Otherwise make a new copy of the array to ensure the current handle
* hold a unique copy.
*
* \return Handle to the internal node container(which ganrantees to be unique)
*/
inline ArrayNode* CopyOnWrite() {
if (data_.get() == nullptr || !data_.unique()) {
ObjectPtr<ArrayNode> n = make_object<ArrayNode>();
n->data = static_cast<ArrayNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_);
}
return static_cast<ArrayNode*>(data_.get());
}
/*!
* \brief push a new item to the back of the list
* \param item The item to be pushed.
*/
inline void push_back(const T& item) {
ArrayNode* n = this->CopyOnWrite();
n->data.push_back(item);
}
/*!
* \brief Resize the array.
* \param size The new size.
*/
inline void resize(size_t size) {
ArrayNode* n = this->CopyOnWrite();
n->data.resize(size);
}
/*!
* \brief set i-th element of the array.
* \param i The index
* \param value The value to be setted.
*/
inline void Set(size_t i, const T& value) {
ArrayNode* n = this->CopyOnWrite();
n->data[i] = value;
}
/*! \return whether array is empty */
inline bool empty() const { return size() == 0; }
/*!
* \brief Helper function to apply fmutate to mutate an array.
* \param fmutate The transformation function T -> T.
* \tparam F the type of the mutation function.
* \note This function performs copy on write optimization.
*/
template <typename F>
inline void MutateByApply(F fmutate) {
ArrayNode* ptr = static_cast<ArrayNode*>(data_.get());
if (ptr == nullptr) return;
if (data_.unique()) {
// Copy on write optimization.
// Perform inplace update because this is an unique copy.
for (size_t i = 0; i < ptr->data.size(); ++i) {
// It is important to use move here
// to make prevent the element's ref count from increasing
// so fmutate itself can perform copy-on-write optimization
T old_elem = DowncastNoCheck<T>(std::move(ptr->data[i]));
T new_elem = fmutate(std::move(old_elem));
ptr->data[i] = std::move(new_elem);
}
} else {
// lazily trigger copy if there is element change.
ObjectPtr<ArrayNode> copy;
for (size_t i = 0; i < ptr->data.size(); ++i) {
T old_elem = DowncastNoCheck<T>(ptr->data[i]);
T new_elem = fmutate(old_elem);
if (!new_elem.same_as(ptr->data[i])) {
// copy the old array
if (copy == nullptr) {
copy = runtime::make_object<ArrayNode>(*ptr);
}
copy->data[i] = std::move(new_elem);
}
}
// replace the data with the new copy.
if (copy != nullptr) {
data_ = std::move(copy);
}
}
}

/*! \brief specify container node */
using ContainerType = ArrayNode;

struct ValueConverter {
using ResultType = T;
static inline T convert(const ObjectRef& n) { return DowncastNoCheck<T>(n); }
};
using iterator = IterAdapter<ValueConverter, std::vector<ObjectRef>::const_iterator>;

using reverse_iterator =
IterAdapter<ValueConverter, std::vector<ObjectRef>::const_reverse_iterator>;

/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const ArrayNode*>(data_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const ArrayNode*>(data_.get())->data.end());
}
/*! \return rbegin iterator */
inline reverse_iterator rbegin() const {
return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rbegin());
}
/*! \return rend iterator */
inline reverse_iterator rend() const {
return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rend());
}
};

/*!
* \brief Map container of NodeRef->NodeRef in DSL graph.
* Map implements copy on write semantics, which means map is mutable
Expand Down Expand Up @@ -404,8 +130,8 @@ class Map : public ObjectRef {
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
* \brief constructor from unordered_map
* \param init The unordered_map
*/
template <typename Hash, typename Equal>
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
Expand Down Expand Up @@ -625,7 +351,7 @@ struct ObjectTypeChecker<Array<T> > {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<ArrayNode>()) return false;
const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
for (const auto& p : n->data) {
for (const ObjectRef& p : *n) {
if (!ObjectTypeChecker<T>::Check(p.get())) {
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
"If `axis = None`, all axis of dimension 1 get squeezed;"
"Else, the dimension in axes get squeezed."
"It is an error if an axis does not has dimension 1.")
.set_default(NullValue<Array<Integer> >());
.set_default(NullValue<Array<Integer>>());
}
}; // struct SqueezeAttrs

Expand Down
Loading

0 comments on commit 8c7f09a

Please sign in to comment.