Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compute/cker] Introduce the ShapeIterator #14311

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions compute/cker/include/cker/ShapeIterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __NNFW_CKER_SHAPE_ITERATOR_H__
#define __NNFW_CKER_SHAPE_ITERATOR_H__

#include <utility>
#include "cker/Shape.h"

namespace nnfw
{
namespace cker
{
struct ShapeIterator
{
/// Definition of this iterator's traits that can be accessed by std::iterator_traits<It>
using value_type = decltype(std::declval<Shape>().Dims(0));
using difference_type = std::ptrdiff_t;
using pointer = value_type *;
using reference = value_type &;
using iterator_category = std::bidirectional_iterator_tag;

ShapeIterator(const Shape &s) : _shape{s}, _current{0}, _last{s.DimensionsCount()} {}
static ShapeIterator end_iterator(const Shape &s) { return ShapeIterator(s, EndIteratorTag{}); }

ShapeIterator &operator++()
{
++_current;
return *this;
}

// postincrement
ShapeIterator operator++(int)
{
auto copy = *this;
++_current;
return copy;
}

ShapeIterator &operator--()
{
--_current;
return *this;
}

ShapeIterator operator--(int)
{
auto copy = *this;
--_current;
return copy;
}

bool operator!=(const ShapeIterator &other) const { return _current != other._current; }
bool operator==(const ShapeIterator &other) const { return _current == other._current; }

/// Because the underlying method returns by-value, this operator does the same
/// instead of returning by-reference like most iterators do.
value_type operator*() const { return _shape.Dims(_current); }

private:
struct EndIteratorTag
{
};
// Creates an iterator instance pointing to the past-the-end element
// This iterator doesn't point to a valid element and thus its dereference is undefined behavior
ShapeIterator(const Shape &s, EndIteratorTag)
: _shape{s}, _current{s.DimensionsCount()}, _last{s.DimensionsCount()}
{
}

const Shape &_shape;
int32_t _current = 0, _last = 0;
};

inline ShapeIterator begin(const Shape &s) { return ShapeIterator(s); }
inline ShapeIterator end(const Shape &s) { return ShapeIterator::end_iterator(s); }

} // namespace cker
} // namespace nnfw

#endif //
27 changes: 27 additions & 0 deletions compute/cker/include/cker/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
#define __NNFW_CKER_UTILS_H__

#include "Shape.h"
#include "ShapeIterator.h"

#include "neon/neon_check.h"

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <string>
#include <fixedpoint/fixedpoint.h>

namespace nnfw
Expand Down Expand Up @@ -480,6 +483,30 @@ template <typename T> class SequentialTensorWriter
T *output_ptr_;
};

inline std::ostream &operator<<(std::ostream &os, const Shape &shape)
{
using std::begin;
using std::end;

std::string formatted =
std::accumulate(begin(shape), end(shape), std::string{"["},
[](std::string joined, ShapeIterator::value_type dim) {
return std::move(joined).append(std::to_string(dim)).append(",");
});

if (formatted.back() == '[')
{
formatted.push_back(']');
}
else
{
formatted.back() = ']';
}

os << formatted;
return os;
}

} // namespace cker
} // namespace nnfw

Expand Down
108 changes: 108 additions & 0 deletions compute/cker/src/ShapeIterator.test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cker/ShapeIterator.h>
#include <cker/Utils.h>
#include <gtest/gtest.h>
#include <numeric>

using namespace nnfw::cker;

TEST(CKer_Utils, ShapeIterator_basic)
{
const Shape test_shape{1, 3, 1024, 768};
{
// test the front and back iterability with basic operators
ShapeIterator it{test_shape};
EXPECT_EQ(*it, 1);
++it;
EXPECT_EQ(*it, 3);
it++;
EXPECT_EQ(*it, 1024);
--it;
EXPECT_EQ(*it, 3);
it--;
EXPECT_EQ(*it, 1);
}
{
// test the iterator's compatibility with STL iterator functions
ShapeIterator it{test_shape};
auto it2 = std::next(it);
EXPECT_EQ(*it2, 3);
EXPECT_EQ(*it, 1); // make sure the original iterator is untouched

std::advance(it2, 2);
EXPECT_EQ(*it2, 768);

std::advance(it2, -1);
EXPECT_EQ(*it2, 1024);
}
{
// postincrement operator test
ShapeIterator it{test_shape};
const auto it2 = it++;
EXPECT_EQ(*it, 3);
EXPECT_EQ(*it2, 1);
}
{
// test the ability to iterate over a Shape with range-based loops
int expected_dims[] = {1, 3, 1024, 768};
int i = 0;
for (auto &&dim : test_shape)
{
EXPECT_EQ(dim, expected_dims[i++]);
}
}
{
// test the ability to retrieve iterators using begin & end
const auto first = begin(test_shape);
const auto last = end(test_shape);
EXPECT_GT(std::distance(first, last), 0);
EXPECT_EQ(std::distance(first, last), test_shape.DimensionsCount());
}

{
// test and demostrate the usage of iterators with STL algos
const auto first = begin(test_shape);
const auto last = end(test_shape);
const auto shape_elems =
std::accumulate(first, last, 1, std::multiplies<ShapeIterator::value_type>{});
EXPECT_EQ(shape_elems, test_shape.FlatSize());
}

{
// Shape and ofstream interoperability test
std::stringstream ss;
ss << test_shape;
EXPECT_EQ(ss.str(), "[1,3,1024,768]");
}
}

TEST(CKer_Utils, neg_ShapeIterator_empty_shape)
{
const Shape test_shape{};
{
const auto first = begin(test_shape);
const auto last = end(test_shape);
EXPECT_EQ(first, last);
}

{
std::stringstream ss;
ss << test_shape;
EXPECT_EQ(ss.str(), "[]");
}
}