diff --git a/immer/detail/hamts/champ.hpp b/immer/detail/hamts/champ.hpp index 64a8cb82..6ab8a563 100644 --- a/immer/detail/hamts/champ.hpp +++ b/immer/detail/hamts/champ.hpp @@ -12,6 +12,9 @@ #include #include +#include +#include +#include namespace immer { namespace detail { @@ -832,15 +835,306 @@ struct champ typename Combine, typename K, typename Fn> - champ update(const K& k, Fn&& fn) const + champ update(K&& k, Fn&& fn) const { auto hash = Hash{}(k); auto res = do_update( - root, k, std::forward(fn), hash, 0); + root, std::forward(k), std::forward(fn), hash, 0); auto new_size = size + (res.added ? 1 : 0); return {res.node, new_size}; } + using update_mut_result = add_mut_result; + + template + struct TryUpdater + { + static update_result + do_try_update(node_t* node, + byval_if_possible k, + byval_if_possible fn, + byval_if_possible valueEquals, + hash_t hash, + shift_t shift) + { + if (shift == max_shift) { + auto fst = node->collisions(); + auto lst = fst + node->collision_count(); + for (; fst != lst; ++fst) + if (Equal{}(*fst, k)) { + auto&& new_value = std::forward(fn)( + Project{}(detail::as_const(*fst))); + if (valueEquals(detail::as_const(new_value), + Project{}(detail::as_const(*fst)))) + return {nullptr, false}; + + return {node_t::copy_collision_replace( + node, + fst, + Combine{}(std::forward(k), + std::forward( + new_value))), + false}; + } + return {node_t::copy_collision_insert( + node, + Combine{}(std::forward(k), + std::forward(fn)(Default{}()))), + true}; + } else { + auto idx = (hash & (mask << shift)) >> shift; + auto bit = bitmap_t{1u} << idx; + if (node->nodemap() & bit) { + auto offset = node->children_count(bit); + auto result = do_try_update(node->children()[offset], + std::forward(k), + std::forward(fn), + valueEquals, + hash, + shift + B); + if (!result.node) + return result; + + IMMER_TRY { + result.node = node_t::copy_inner_replace( + node, offset, result.node); + return result; + } + IMMER_CATCH (...) { + node_t::delete_deep_shift(result.node, shift + B); + IMMER_RETHROW; + } + } else if (node->datamap() & bit) { + auto offset = node->data_count(bit); + auto val = node->values() + offset; + if (Equal{}(*val, k)) { + auto&& new_value = std::forward(fn)( + Project{}(detail::as_const(*val))); + if (detail::as_const(new_value) == + Project{}(detail::as_const(*val))) + return {nullptr, false}; + + return {node_t::copy_inner_replace_value( + node, + offset, + Combine{}(std::forward(k), + std::forward( + new_value))), + false}; + } else { + auto child = node_t::make_merged( + shift + B, + Combine{}(std::forward(k), + std::forward(fn)(Default{}())), + hash, + *val, + Hash{}(*val)); + IMMER_TRY { + return {node_t::copy_inner_replace_merged( + node, bit, offset, child), + true}; + } + IMMER_CATCH (...) { + node_t::delete_deep_shift(child, shift + B); + IMMER_RETHROW; + } + } + } else { + return {node_t::copy_inner_insert_value( + node, + bit, + Combine{}(std::forward(k), + std::forward(fn)(Default{}()))), + true}; + } + } + } + + static update_mut_result + do_try_update_mut(edit_t e, + node_t* node, + byval_if_possible k, + byval_if_possible fn, + byval_if_possible valueEquals, + hash_t hash, + shift_t shift) + { + if (shift == max_shift) { + auto fst = node->collisions(); + auto lst = fst + node->collision_count(); + for (; fst != lst; ++fst) + if (Equal{}(*fst, k)) { + auto&& new_value = std::forward(fn)( + Project{}(detail::as_const(*fst))); + if (valueEquals(detail::as_const(new_value), + Project{}(detail::as_const(*fst)))) { + return {nullptr, false, false}; + } + + if (node->can_mutate(e)) { + *fst = Combine{}( + std::forward(k), + std::forward(new_value)); + return {node, false, true}; + } else { + auto r = node_t::copy_collision_replace( + node, + fst, + Combine{}(std::forward(k), + std::forward( + new_value))); + return {node_t::owned(r, e), false, false}; + } + } + auto v = Combine{}(std::forward(k), + std::forward(fn)(Default{}())); + auto mutate = node->can_mutate(e); + auto r = + mutate ? node_t::move_collision_insert(node, std::move(v)) + : node_t::copy_collision_insert(node, std::move(v)); + return {node_t::owned(r, e), true, mutate}; + } else { + auto idx = (hash & (mask << shift)) >> shift; + auto bit = bitmap_t{1u} << idx; + if (node->nodemap() & bit) { + auto offset = node->children_count(bit); + auto child = node->children()[offset]; + if (node->can_mutate(e)) { + auto result = do_try_update_mut(e, + child, + std::forward(k), + std::forward(fn), + valueEquals, + hash, + shift + B); + if (!result.node) + return result; + + node->children()[offset] = result.node; + if (!result.mutated && child->dec()) + node_t::delete_deep_shift(child, shift + B); + return {node, result.added, true}; + } else { + auto result = do_try_update(child, + std::forward(k), + std::forward(fn), + valueEquals, + hash, + shift + B); + if (!result.node) + return {nullptr, false, false}; + + IMMER_TRY { + result.node = node_t::copy_inner_replace( + node, offset, result.node); + node_t::owned(result.node, e); + return {result.node, result.added, false}; + } + IMMER_CATCH (...) { + node_t::delete_deep_shift(result.node, shift + B); + IMMER_RETHROW; + } + } + } else if (node->datamap() & bit) { + auto offset = node->data_count(bit); + auto val = node->values() + offset; + if (Equal{}(*val, k)) { + if (node->can_mutate(e)) { + auto vals = node->ensure_mutable_values(e); + auto&& new_value = std::forward(fn)( + Project{}(detail::as_const(vals[offset]))); + if (valueEquals( + detail::as_const(new_value), + Project{}(detail::as_const(vals[offset])))) + return {nullptr, false, false}; + + vals[offset] = Combine{}( + std::forward(k), + std::forward(new_value)); + return {node, false, true}; + } else { + auto&& new_value = std::forward(fn)( + Project{}(detail::as_const(*val))); + if (valueEquals(detail::as_const(new_value), + Project{}(detail::as_const(*val)))) + return {nullptr, false, false}; + + auto r = node_t::copy_inner_replace_value( + node, + offset, + Combine{}(std::forward(k), + std::forward( + new_value))); + return {node_t::owned_values(r, e), false, false}; + } + } else { + auto mutate = node->can_mutate(e); + auto mutate_values = + mutate && node->can_mutate_values(e); + auto hash2 = Hash{}(*val); + auto child = node_t::make_merged_e( + e, + shift + B, + Combine{}(std::forward(k), + std::forward(fn)(Default{}())), + hash, + mutate_values ? std::move(*val) : *val, + hash2); + IMMER_TRY { + auto r = mutate ? node_t::move_inner_replace_merged( + e, node, bit, offset, child) + : node_t::copy_inner_replace_merged( + node, bit, offset, child); + return { + node_t::owned_values_safe(r, e), true, mutate}; + } + IMMER_CATCH (...) { + node_t::delete_deep_shift(child, shift + B); + IMMER_RETHROW; + } + } + } else { + auto mutate = node->can_mutate(e); + auto v = Combine{}(std::forward(k), + std::forward(fn)(Default{}())); + auto r = mutate ? node_t::move_inner_insert_value( + e, node, bit, std::move(v)) + : node_t::copy_inner_insert_value( + node, bit, std::move(v)); + return {node_t::owned_values(r, e), true, mutate}; + } + } + } + }; + + template > + champ try_update(K&& k, Fn&& fn, ValueEquals valueEquals = {}) const + { + auto hash = Hash{}(k); + auto res = TryUpdater:: + do_try_update(root, + std::forward(k), + std::forward(fn), + std::move(valueEquals), + hash, + 0); + if (!res.node) + return {root->inc(), size}; + + auto new_size = size + size_t(res.added); + return {res.node, new_size}; + } + template node_t* do_update_if_exists( node_t* node, K&& k, Fn&& fn, hash_t hash, shift_t shift) const @@ -909,8 +1203,6 @@ struct champ }; } - using update_mut_result = add_mut_result; - template - void update_mut(edit_t e, const K& k, Fn&& fn) + void update_mut(edit_t e, K&& k, Fn&& fn) { auto hash = Hash{}(k); auto res = do_update_mut( - e, root, k, std::forward(fn), hash, 0); + e, root, std::forward(k), std::forward(fn), hash, 0); + if (!res.mutated && root->dec()) + node_t::delete_deep(root, 0); + root = res.node; + size += res.added ? 1 : 0; + } + + template + void try_update_mut(edit_t e, K&& k, Fn&& fn, ValueEquals valueEquals) + { + auto hash = Hash{}(k); + auto res = TryUpdater:: + do_try_update_mut(e, + root, + std::forward(k), + std::forward(fn), + valueEquals, + hash, + 0); + if (!res.node) + return; + if (!res.mutated && root->dec()) node_t::delete_deep(root, 0); root = res.node; diff --git a/immer/detail/util.hpp b/immer/detail/util.hpp index d6ae246b..ec14bd61 100644 --- a/immer/detail/util.hpp +++ b/immer/detail/util.hpp @@ -310,5 +310,15 @@ distance(Iterator first, Sentinel last) return last - first; } +template +static constexpr bool can_efficiently_pass_by_value = + sizeof(T) <= 2 * sizeof(void*) && std::is_trivially_copyable::value; + +template +using byval_if_possible = + std::conditional_t>, + std::decay_t, + OrElse>; + } // namespace detail } // namespace immer diff --git a/immer/map.hpp b/immer/map.hpp index cab8e5d5..f3172d4e 100644 --- a/immer/map.hpp +++ b/immer/map.hpp @@ -424,6 +424,37 @@ class map return update_move(move_t{}, std::move(k), std::forward(fn)); } + /*! + * Returns a map replacing the association `(k, v)` by the + * association new association `(k, fn(v))`, where `v` is the + * currently associated value for `k` in the map or a default + * constructed value otherwise. It may allocate memory + * and its complexity is *effectively* @f$ O(1) @f$. + * + * If `fn(v) == v`, the map remains unchanged and no memory is allocated. + * You may customize the equality comparison for values by setting the + * ValueEquals callback + */ + template > + IMMER_NODISCARD map try_update(key_type k, + Fn&& fn, + ValueEquals valueEquals = {}) const& + { + return impl_ + .template try_update( + std::move(k), std::forward(fn), std::move(valueEquals)); + } + + template > + IMMER_NODISCARD decltype(auto) + try_update(key_type k, Fn&& fn, ValueEquals valueEquals = {}) && + { + return try_update_move(move_t{}, + std::move(k), + std::forward(fn), + std::move(valueEquals)); + } + /*! * Returns a map replacing the association `(k, v)` by the association new * association `(k, fn(v))`, where `v` is the currently associated value for @@ -516,6 +547,29 @@ class map std::move(k), std::forward(fn)); } + template + map&& try_update_move(std::true_type, + key_type k, + Fn&& fn, + ValueEquals valueEquals) + { + impl_.template try_update_mut( + {}, std::move(k), std::forward(fn), std::move(valueEquals)); + return std::move(*this); + } + template + map try_update_move(std::false_type, + key_type k, + Fn&& fn, + ValueEquals valueEquals) + { + return impl_ + .template try_update( + std::move(k), std::forward(fn), std::move(valueEquals)); + } + template map&& update_if_exists_move(std::true_type, key_type k, Fn&& fn) { diff --git a/test/algorithm.cpp b/test/algorithm.cpp index 6a785dbd..69cffb40 100644 --- a/test/algorithm.cpp +++ b/test/algorithm.cpp @@ -150,6 +150,23 @@ TEST_CASE("update maps") do_check(immer::table{}); } +TEST_CASE("try update maps") +{ + auto do_check = [](auto v) { + (void) v.try_update(0, [](auto&& x) { + using type_t = std::decay_t; + // for maps, we actually do not make a copy at all but pase the + // original instance directly, as const.. + static_assert(std::is_same::value, ""); + return x; + }); + }; + + do_check(immer::map{}); + // -- tables not supported yet with try_update + // do_check(immer::table{}); +} + TEST_CASE("update_if_exists maps") { auto do_check = [](auto v) { diff --git a/test/map/generic.ipp b/test/map/generic.ipp index e9b6cf58..b91bbed7 100644 --- a/test/map/generic.ipp +++ b/test/map/generic.ipp @@ -171,6 +171,9 @@ TEST_CASE("equals and setting") CHECK(v.set(1234, 42) == v.insert({1234, 42})); CHECK(v.update(1234, [](auto&& x) { return x + 1; }) == v.set(1234, 1)); CHECK(v.update(42, [](auto&& x) { return x + 1; }) == v.set(42, 43)); + CHECK(v.try_update(42, [](auto&& x) { return x + 1; }) == v.set(42, 43)); + CHECK(v.try_update(42, [](auto&& x) { return x; }).identity() == + v.identity()); CHECK(v.update_if_exists(1234, [](auto&& x) { return x + 1; }) == v); CHECK(v.update_if_exists(42, [](auto&& x) { return x + 1; }) == @@ -282,6 +285,22 @@ TEST_CASE("update a lot") } } + SECTION("try_update immutable") + { + for (decltype(v.size()) i = 0; i < v.size(); ++i) { + v = v.try_update(i, [](auto&& x) { return x + 1; }); + CHECK(v[i] == i + 1); + } + } + + SECTION("try_update move") + { + for (decltype(v.size()) i = 0; i < v.size(); ++i) { + v = std::move(v).try_update(i, [](auto&& x) { return x + 1; }); + CHECK(v[i] == i + 1); + } + } + SECTION("erase") { for (decltype(v.size()) i = 0; i < v.size(); ++i) { @@ -315,11 +334,11 @@ TEST_CASE("update_if_exists a lot") #if !IMMER_IS_LIBGC_TEST TEST_CASE("update boxed move string") { - constexpr auto N = 666u; - constexpr auto S = 7; + static constexpr auto N = 666u; + static constexpr auto S = 7; auto s = MAP_T>{}; - SECTION("preserve immutability") - { + + auto do_test = [&s](auto Updater) { auto s0 = s; auto i0 = 0u; // insert @@ -328,8 +347,9 @@ TEST_CASE("update boxed move string") s0 = s; i0 = i; } - s = std::move(s).update(std::to_string(i), - [&](auto&&) { return std::to_string(i); }); + s = Updater(std::move(s), std::to_string(i), [&](auto&&) { + return std::to_string(i); + }); { CHECK(s.size() == i + 1); for (auto j : test_irange(0u, i + 1)) { @@ -355,7 +375,7 @@ TEST_CASE("update boxed move string") s0 = s; i0 = i; } - s = std::move(s).update(std::to_string(i), [&](auto&&) { + s = Updater(std::move(s), std::to_string(i), [&](auto&&) { return std::to_string(i + 1); }); { @@ -373,6 +393,24 @@ TEST_CASE("update boxed move string") CHECK(*s0.find(std::to_string(j)) == std::to_string(j)); } } + }; + + SECTION("preserve immutability") + { + do_test([](auto&& map, auto&& key, auto&& cb) { + return std::forward(map).update( + std::forward(key), + std::forward(cb)); + }); + } + + SECTION("preserve immutability (try_update)") + { + do_test([](auto&& map, auto&& key, auto&& cb) { + return std::forward(map).try_update( + std::forward(key), + std::forward(cb)); + }); } } #endif @@ -449,6 +487,27 @@ TEST_CASE("exception safety") IMMER_TRACE_E(d.happenings); } + SECTION("try_update") + { + auto v = dadaist_map_t{}; + auto d = dadaism{}; + for (auto i = 0u; i < n; ++i) + v = std::move(v).set(i, i); + for (auto i = 0u; i < v.size();) { + try { + auto s = d.next(); + v = v.try_update(i, [](auto x) { return x + 1; }); + ++i; + } catch (dada_error) {} + for (auto i : test_irange(0u, i)) + CHECK(v.at(i) == i + 1); + for (auto i : test_irange(i, n)) + CHECK(v.at(i) == i); + } + CHECK(d.happenings > 0); + IMMER_TRACE_E(d.happenings); + } + SECTION("update_if_exists") { auto v = dadaist_map_t{}; @@ -492,6 +551,28 @@ TEST_CASE("exception safety") IMMER_TRACE_E(d.happenings); } + SECTION("try_update collisisions") + { + auto vals = make_values_with_collisions(n); + auto v = dadaist_conflictor_map_t{}; + auto d = dadaism{}; + for (auto i = 0u; i < n; ++i) + v = v.insert(vals[i]); + for (auto i = 0u; i < v.size();) { + try { + auto s = d.next(); + v = v.try_update(vals[i].first, [](auto x) { return x + 1; }); + ++i; + } catch (dada_error) {} + for (auto i : test_irange(0u, i)) + CHECK(v.at(vals[i].first) == vals[i].second + 1); + for (auto i : test_irange(i, n)) + CHECK(v.at(vals[i].first) == vals[i].second); + } + CHECK(d.happenings > 0); + IMMER_TRACE_E(d.happenings); + } + SECTION("update_if_exists collisisions") { auto vals = make_values_with_collisions(n); @@ -584,6 +665,29 @@ TEST_CASE("exception safety") IMMER_TRACE_E(d.happenings); } + SECTION("try_update collisisions move") + { + auto vals = make_values_with_collisions(n); + auto v = dadaist_conflictor_map_t{}; + auto d = dadaism{}; + for (auto i = 0u; i < n; ++i) + v = std::move(v).insert(vals[i]); + for (auto i = 0u; i < v.size();) { + try { + auto s = d.next(); + v = std::move(v).try_update(vals[i].first, + [](auto x) { return x + 1; }); + ++i; + } catch (dada_error) {} + for (auto i : test_irange(0u, i)) + CHECK(v.at(vals[i].first) == vals[i].second + 1); + for (auto i : test_irange(i, n)) + CHECK(v.at(vals[i].first) == vals[i].second); + } + CHECK(d.happenings > 0); + IMMER_TRACE_E(d.happenings); + } + SECTION("update_if_exists collisisions move") { auto vals = make_values_with_collisions(n);