Skip to content

Commit

Permalink
static_filter_map::filter uses std::countr_zero
Browse files Browse the repository at this point in the history
  • Loading branch information
fhamonic committed Sep 5, 2024
1 parent dabd15c commit c17b182
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 109 deletions.
157 changes: 63 additions & 94 deletions include/melon/container/static_filter_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <ranges>
#include <vector>

#include "melon/detail/intrusive_view.hpp"

namespace fhamonic {
namespace melon {

Expand Down Expand Up @@ -131,17 +133,15 @@ class static_filter_map {
}

public:
friend constexpr bool operator==(const iterator_base & x,
const iterator_base & y) noexcept {
friend constexpr bool operator==(const I & x, const I & y) noexcept {
return x._p == y._p && x._local_index == y._local_index;
}
friend constexpr std::strong_ordering operator<=>(
const iterator_base & x, const iterator_base & y) noexcept {
const I & x, const I & y) noexcept {
if(const auto cmp = x._p <=> y._p; cmp != 0) return cmp;
return x._local_index <=> y._local_index;
}
constexpr difference_type operator-(
const iterator_base & other) const noexcept {
constexpr difference_type operator-(const I & other) const noexcept {
return (difference_type(N) * (_p - other._p) +
static_cast<difference_type>(_local_index) -
static_cast<difference_type>(other._local_index));
Expand All @@ -153,27 +153,27 @@ class static_filter_map {
return *static_cast<I *>(this);
}
constexpr I operator++(int) noexcept {
iterator tmp = *static_cast<iterator *>(this);
iterator tmp = *static_cast<I *>(this);
_bump_up();
return tmp;
}
constexpr I & operator--() noexcept {
_bump_down();
return *static_cast<iterator *>(this);
return *static_cast<I *>(this);
}
constexpr I operator--(int) noexcept {
iterator tmp = *static_cast<iterator *>(this);
iterator tmp = *static_cast<I *>(this);
_bump_down();
return tmp;
}
constexpr I & operator+=(difference_type i) noexcept {
_incr(i);
return *static_cast<iterator *>(this);
return *static_cast<I *>(this);
}

constexpr I & operator-=(difference_type i) noexcept {
_incr(-i);
return *static_cast<iterator *>(this);
return *static_cast<I *>(this);
}

friend constexpr I operator+(const I & x, difference_type n) {
Expand Down Expand Up @@ -212,6 +212,8 @@ class static_filter_map {
};

class const_iterator : public iterator_base<const_iterator> {
friend static_filter_map;

public:
using iterator_category = std::random_access_iterator_tag;
using difference_type = iterator_base<const_iterator>::difference_type;
Expand All @@ -223,8 +225,8 @@ class static_filter_map {
using iterator_base<const_iterator>::iterator_base;

constexpr const_reference operator*() const noexcept {
return (*iterator_base<iterator>::_p >>
iterator_base<iterator>::_local_index) &
return (*iterator_base<const_iterator>::_p >>
iterator_base<const_iterator>::_local_index) &
1;
}
constexpr const_reference operator[](difference_type i) const {
Expand Down Expand Up @@ -300,88 +302,55 @@ class static_filter_map {
b ? ~span_type(0) : span_type(0));
}

// template <std::viewable_range R>
// auto filter(R && r) noexcept const {
// if constexpr(std::same_as<R, std::ranges::iota_view<K, K>>) {
// K begin_index = std::max(static_cast<K>(0), *std::ranges::begin(r));
// K end_index = std::min(static_cast<K>(_size), *std::ranges::end(r));


// } /*else {
// return std::ranges::filter(r, [](const auto & k) { return operator[](k);});
// }*/
// }

// auto true_keys() const {
// // span_type * p = _data.get();
// // size_type index = 0;
// // const span_type * p_end = _data.get() + _size / N;
// // const size_type index_end = _size & span_index_mask;

// // for(;;) {
// // index += static_cast<size_type>(std::countr_zero((*p) >>
// index));
// // if(p == p_end && index >= index_end) co_return;
// // if(index >= N) {
// // ++p;
// // index = 0;
// // continue;
// // }
// // co_yield static_cast<size_type>(
// // difference_type(N) * (p - _data.get()) +
// // static_cast<difference_type>(index));
// // ++index;
// // }

// const span_type * data = _data.get();
// const size_type last_out_index = _size / N;
// const size_type last_in_index = _size & span_index_mask;

// struct {
// size_type out_index;
// size_type in_index;
// } cursor(0, 0);
// for(;;) {
// cursor.in_index += static_cast<size_type>(
// std::countr_zero((data[cursor.out_index]) >>
// cursor.in_index));
// if(cursor.out_index == last_out_index &&
// cursor.in_index >= last_in_index)
// break;
// if(cursor.in_index >= N) {
// ++cursor.out_index;
// cursor.in_index = 0;
// continue;
// }
// break;
// }

// return intrusive_view(
// cursor,
// [](const auto & cur) -> size_type {
// return cur.out_index * size_type(N) + cur.in_index;
// },
// [ data, last_out_index, last_in_index ](auto cur) -> auto{
// ++cur.in_index;
// for(;;) {
// cur.in_index += static_cast<size_type>(std::countr_zero(
// (data[cur.out_index]) >> cur.in_index));
// if(cur.out_index == last_out_index &&
// cur.in_index >= last_in_index)
// return cur;
// if(cur.in_index >= N) {
// ++cur.out_index;
// cur.in_index = 0;
// continue;
// }
// return cur;
// }
// },
// [last_out_index, last_in_index](const arc a) -> bool {
// return cur.out_index != last_out_index ||
// cur.in_index < last_in_index;
// });
// }
template <std::ranges::viewable_range R>
auto filter(R && r) const noexcept {
if constexpr(std::same_as<R, std::ranges::iota_view<K, K>>) {
K begin_index = std::max(static_cast<K>(0), *std::ranges::begin(r));
K end_index = std::min(static_cast<K>(_size), *std::ranges::end(r));

const_iterator begin_it(_data.get() + begin_index / N,
begin_index & span_index_mask);
const const_iterator end_it(_data.get() + end_index / N,
end_index & span_index_mask);

auto next_it = [end_it](const_iterator cursor) {
++cursor;
span_type shifted = (*cursor._p) >> cursor._local_index;
if(shifted == span_type{0}) {
do {
++cursor._p;
} while(cursor < end_it && *cursor._p == span_type{0});
cursor._local_index = 0;
shifted = *cursor._p;
}
cursor._local_index +=
static_cast<size_type>(std::countr_zero(shifted));
return cursor;
};

if(!*begin_it) begin_it = next_it(begin_it);

return intrusive_view(
begin_it,
[data = _data.get()](const const_iterator & cursor) -> K {
return static_cast<K>(
static_cast<size_type>(cursor._p - data) *
size_type(N) +
cursor._local_index);
},
std::move(next_it),
[end_it](const const_iterator & cursor) -> bool {
return cursor < end_it;
});
} else {
return std::views::filter(
std::views::transform(
r, [](auto && i) { return static_cast<K>(i); }),
[this](const auto & k) {
return operator[](static_cast<K>(k));
});
}
}
};

} // namespace melon
Expand Down
2 changes: 1 addition & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ add_executable(
dumb_digraph_test.cpp
mutable_digraph_test.cpp
static_map_test.cpp
static_map_bool_test.cpp
static_filter_map_test.cpp
static_digraph_builder_test.cpp
breadth_first_search_test.cpp
depth_first_search_test.cpp
Expand Down
31 changes: 29 additions & 2 deletions test/static_map_bool_test.cpp → test/static_filter_map_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

#include <random>

#include "melon/container/static_map.hpp"
#include "melon/container/static_digraph.hpp"
#include "melon/container/static_map.hpp"
#include "melon/mapping.hpp"

#include "ranges_test_helper.hpp"
Expand All @@ -13,7 +13,8 @@ using namespace fhamonic::melon;

static_assert(std::copyable<static_map<std::size_t, bool>>);
static_assert(std::ranges::random_access_range<static_map<std::size_t, bool>>);
static_assert(output_mapping_of<static_map<std::size_t, bool>, std::size_t, bool>);
static_assert(
output_mapping_of<static_map<std::size_t, bool>, std::size_t, bool>);

GTEST_TEST(static_map_bool, empty_constructor) {
static_map<std::size_t, bool> map;
Expand Down Expand Up @@ -134,4 +135,30 @@ GTEST_TEST(static_map_bool, iterator_extensive_read) {
std::random_access_iterator<static_map<std::size_t, bool>::iterator>);

// ASSERT_TRUE(std::ranges::equal(std::views::values(map), datas));
}

GTEST_TEST(static_map_bool, filter) {
const std::size_t nb_bools = 153;
static_filter_map<std::size_t> map(nb_bools, false);
std::vector<std::size_t> indices;

auto gen = std::bind(std::uniform_int_distribution<>(0, 1),
std::default_random_engine());

for(std::size_t i = 0; i < nb_bools; ++i) {
bool b = gen();
if(!b) continue;
indices.emplace_back(i);
map[i] = b;
}

static_assert(std::random_access_iterator<std::vector<bool>::iterator>);
static_assert(
std::random_access_iterator<static_map<std::size_t, bool>::iterator>);

ASSERT_TRUE(EQ_MULTISETS(
map.filter(std::views::iota(std::size_t{0}, nb_bools)), indices));

ASSERT_TRUE(EQ_MULTISETS(
map.filter(std::views::iota(int{0}, int{nb_bools})), indices));
}
12 changes: 0 additions & 12 deletions test/static_map_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,6 @@ GTEST_TEST(static_map, for_each_read) {
std::vector<int> datas = {0, 7, 3, 5, 6, 11};
const static_map<std::size_t, int> map(datas.begin(), datas.end());
std::size_t cpt = 0;
// for(auto && [k, v] : map) {
// ASSERT_EQ(k, cpt);
// ASSERT_EQ(v, datas[cpt]);
// ++cpt;
// }
for(auto && v : map) {
ASSERT_EQ(v, datas[cpt]);
++cpt;
Expand All @@ -132,13 +127,6 @@ GTEST_TEST(static_map, for_each_write) {
std::vector<int> datas = {0, 7, 3, 5, 6, 11};
static_map<std::size_t, int> map(datas.begin(), datas.end());
std::size_t cpt = 0;
// for(auto && [k, v] : map) {
// ASSERT_EQ(k, cpt);
// ASSERT_EQ(v, datas[cpt]);
// v = 3 * static_cast<int>(cpt) + 1;
// ASSERT_EQ(map[cpt], 3 * static_cast<int>(cpt) + 1);
// ++cpt;
// }
for(auto & v : map) {
ASSERT_EQ(v, datas[cpt]);
v = 3 * static_cast<int>(cpt) + 1;
Expand Down

0 comments on commit c17b182

Please sign in to comment.