diff --git a/base/abstractdict.jl b/base/abstractdict.jl index 6bcae02d539dc2..c1b8f79d1b3bc3 100644 --- a/base/abstractdict.jl +++ b/base/abstractdict.jl @@ -465,6 +465,63 @@ function hash(a::AbstractDict, h::UInt) hash(hv, h) end +""" + modify!(f, d::AbstractDict{K, V}, key) + +Lookup and then update, insert or delete in one go without re-computing the hash. + +`f` is a callable object that must accept `Union{Some{V}, Nothing}` and return +`Union{T, Some{T}, Nothing}` where `T` is a type [`convert`](@ref)-able to the value type +`V`. The value `Some(d[key])` is passed to `f` if `haskey(d, key)`; otherwise `nothing` +is passed. If `f` returns `nothing`, corresponding entry in the dictionary `d` is removed. +If `f` returns non-`nothing` value `x`, `something(x)` is inserted to `d`. + +`modify!` returns whatever `f` returns as-is. + +# Examples +```jldoctest +julia> dict = Dict("a" => 1); + +julia> modify!(dict, "a") do val + Some(val === nothing ? 1 : something(val) + 1) + end +Some(2) + +julia> dict +Dict{String,Int64} with 1 entry: + "a" => 2 + +julia> dict = Dict(); + +julia> modify!(dict, "a") do val + Some(val === nothing ? 1 : something(val) + 1) + end +Some(1) + +julia> dict +Dict{Any,Any} with 1 entry: + "a" => 1 + +julia> modify!(_ -> nothing, dict, "a") + +julia> dict +Dict{Any,Any} with 0 entries +``` +""" +function modify!(f, dict::AbstractDict, key) + if haskey(dict, key) + val = f(Some(dict[key])) + else + val = f(nothing) + end + if val === nothing + delete!(dict, key) + else + dict[key] = something(val) + end + return val +end + function getindex(t::AbstractDict, key) v = get(t, key, secret_table_token) if v === secret_table_token diff --git a/base/dict.jl b/base/dict.jl index 8c1d762527bb8d..bd1cb29123b8ba 100644 --- a/base/dict.jl +++ b/base/dict.jl @@ -391,6 +391,45 @@ function setindex!(h::Dict{K,V}, v0, key::K) where V where K return h end +function modify!(f, h::Dict{K}, key0) where K + key = convert(K, key0) + if !isequal(key, key0) + throw(ArgumentError("$(limitrepr(key0)) is not a valid key for type $K")) + end + + # Ideally, to improve performance for the case that requires + # resizing, we should use something like `ht_keyindex` while + # keeping computed hash value and then do something like + # `ht_keyindex2!` if `f` returns non-`nothing`. + idx = ht_keyindex2!(h, key) + + age0 = h.age + if idx > 0 + @inbounds vold = h.vals[idx] + vnew = f(Some(vold)) + else + vnew = f(nothing) + end + if h.age != age0 + idx = ht_keyindex2!(h, key) + end + + if vnew === nothing + if idx > 0 + _delete!(h, idx) + end + else + if idx > 0 + h.age += 1 + @inbounds h.keys[idx] = key + @inbounds h.vals[idx] = something(vnew) + else + @inbounds _setindex!(h, something(vnew), key, -idx) + end + end + return vnew +end + """ get!(collection, key, default) diff --git a/test/dict.jl b/test/dict.jl index 1224d41bb220a6..e8a0c006d73fbf 100644 --- a/test/dict.jl +++ b/test/dict.jl @@ -1057,8 +1057,11 @@ end new{keytype(d), valtype(d)}(d) end end + Base.Dict(td::TestDict) = td.dict + Base.haskey(td::TestDict, key) = haskey(td.dict, key) Base.setindex!(td::TestDict, args...) = setindex!(td.dict, args...) Base.getindex(td::TestDict, args...) = getindex(td.dict, args...) + Base.delete!(td::TestDict, key) = delete!(td.dict, key) Base.pairs(D::TestDict) = pairs(D.dict) testdict = TestDict(:a=>1, :b=>2) map!(v->v-1, values(testdict)) @@ -1072,3 +1075,51 @@ end @test testdict[:b] == 1 end end + +@testset "modify!(f, ::$constructor, key)" for constructor in [ + Dict, + TestDict, +] + @testset "update" begin + dict = constructor(Dict("a" => 1)) + + @test modify!(dict, "a") do val + Some(val === nothing ? 1 : something(val) + 1) + end == Some(2) + + @test Dict(dict) == Dict("a" => 2) + end + + @testset "insert" begin + dict = constructor(Dict()) + + @test modify!(dict, "a") do val + Some(val === nothing ? 1 : something(val) + 1) + end == Some(1) + + @test Dict(dict) == Dict("a" => 1) + end + + @testset "delete" begin + dict = constructor(Dict("a" => 1)) + @test modify!(_ -> nothing, dict, "a") === nothing + @test Dict(dict) == Dict() + end + + @testset "no-op" begin + dict = constructor(Dict("a" => 1)) + @test modify!(_ -> nothing, dict, "b") === nothing + @test Dict(dict) == Dict("a" => 1) + end + + @testset "mutation inside `f`" begin + dict = constructor(Dict()) + + @test modify!(dict, "a") do val + dict["a"] = 0 + Some(val === nothing ? 1 : something(val) + 1) + end == Some(1) + + @test Dict(dict) == Dict("a" => 1) + end +end