From 301db971daaeeb627ba768375538e6e7ff36d215 Mon Sep 17 00:00:00 2001 From: Stephan Hilb Date: Sun, 3 May 2020 00:36:59 +0200 Subject: [PATCH] implement `count` and `count!` using `mapreduce` (#34048) This creates the same calling interface for `count` as for other mapreduce-type functions such as `sum`, namely allowing the `dims` keyword. The implementation itself is shorter than before without sacrificing performance. More detailed documentation for `count` was added too. --- NEWS.md | 2 ++ base/exports.jl | 1 + base/reduce.jl | 18 ++++---------- base/reducedim.jl | 61 +++++++++++++++++++++++++++++++++++++++++++++++ test/reducedim.jl | 21 ++++++++++++++++ 5 files changed, 89 insertions(+), 14 deletions(-) diff --git a/NEWS.md b/NEWS.md index 748dda8738e25..94c9d166c8e94 100644 --- a/NEWS.md +++ b/NEWS.md @@ -146,6 +146,8 @@ New library features will acquire locks for safe multi-threaded access. Setting it to `false` provides better performance when only one thread will access the file. * The introspection macros (`@which`, `@code_typed`, etc.) now work with `do`-block syntax ([#35283]) and with dot syntax ([#35522]). +* `count` now accepts the `dims` keyword. +* New in-place `count!` function, similar to `sum!`. 
Standard library changes ------------------------ diff --git a/base/exports.jl b/base/exports.jl index 400b2025a946f..12d11601aa034 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -486,6 +486,7 @@ export any, firstindex, collect, + count!, count, delete!, deleteat!, diff --git a/base/reduce.jl b/base/reduce.jl index f95fb5ea74976..414bac5099f78 100644 --- a/base/reduce.jl +++ b/base/reduce.jl @@ -836,6 +836,8 @@ end ## count +_bool(f::Function) = x->f(x)::Bool + """ count(p, itr) -> Integer count(itr) -> Integer @@ -853,22 +855,10 @@ julia> count([true, false, true, true]) 3 ``` """ -function count(pred, itr) - n = 0 - for x in itr - n += pred(x)::Bool - end - return n -end -function count(pred, a::AbstractArrayOrBroadcasted) - n = 0 - for i in eachindex(a) - @inbounds n += pred(a[i])::Bool - end - return n -end count(itr) = count(identity, itr) +count(f, itr) = mapreduce(_bool(f), add_sum, itr, init=0) + function count(::typeof(identity), x::Array{Bool}) n = 0 chunks = length(x) ÷ sizeof(UInt) diff --git a/base/reducedim.jl b/base/reducedim.jl index 331ea9a2eb099..d1e5001492fc4 100644 --- a/base/reducedim.jl +++ b/base/reducedim.jl @@ -359,6 +359,67 @@ julia> reduce(max, a, dims=1) reduce(op, A::AbstractArray; kw...) = mapreduce(identity, op, A; kw...) ##### Specific reduction functions ##### + +""" + count([f=identity,] A::AbstractArray; dims=:) + +Count the number of elements in `A` for which `f` returns `true` over the given +dimensions. + +!!! compat "Julia 1.5" + `dims` keyword was added in Julia 1.5. 
+ +# Examples +```jldoctest +julia> A = [1 2; 3 4] +2×2 Array{Int64,2}: + 1 2 + 3 4 + +julia> count(<=(2), A, dims=1) +1×2 Array{Int64,2}: + 1 1 + +julia> count(<=(2), A, dims=2) +2×1 Array{Int64,2}: + 2 + 0 +``` +""" +count(A::AbstractArrayOrBroadcasted; dims=:) = count(identity, A, dims=dims) +count(f, A::AbstractArrayOrBroadcasted; dims=:) = mapreduce(_bool(f), add_sum, A, dims=dims, init=0) + +""" + count!([f=identity,] r, A; init=true) + +Count the number of elements in `A` for which `f` returns `true` over the +singleton dimensions of `r`, writing the result into `r` in-place. +If `init` is `true`, values in `r` are initialized to zero. + +!!! compat "Julia 1.5" + inplace `count!` was added in Julia 1.5. + +# Examples +```jldoctest +julia> A = [1 2; 3 4] +2×2 Array{Int64,2}: + 1 2 + 3 4 + +julia> count!(<=(2), [1 1], A) +1×2 Array{Int64,2}: + 1 1 + +julia> count!(<=(2), [1; 1], A) +2-element Array{Int64,1}: + 2 + 0 +``` +""" +count!(r::AbstractArray, A::AbstractArrayOrBroadcasted; init::Bool=true) = count!(identity, r, A; init=init) +count!(f, r::AbstractArray, A::AbstractArrayOrBroadcasted; init::Bool=true) = + mapreducedim!(_bool(f), add_sum, initarray!(r, add_sum, init, A), A) + """ sum(A::AbstractArray; dims) diff --git a/test/reducedim.jl b/test/reducedim.jl index 1194027356794..3f59ae6e2570a 100644 --- a/test/reducedim.jl +++ b/test/reducedim.jl @@ -12,6 +12,7 @@ safe_sum(A::Array{T}, region) where {T} = safe_mapslices(sum, A, region) safe_prod(A::Array{T}, region) where {T} = safe_mapslices(prod, A, region) safe_maximum(A::Array{T}, region) where {T} = safe_mapslices(maximum, A, region) safe_minimum(A::Array{T}, region) where {T} = safe_mapslices(minimum, A, region) +safe_count(A::AbstractArray{T}, region) where {T} = safe_mapslices(count, A, region) safe_sumabs(A::Array{T}, region) where {T} = safe_mapslices(sum, abs.(A), region) safe_sumabs2(A::Array{T}, region) where {T} = safe_mapslices(sum, abs2.(A), region) safe_maxabs(A::Array{T}, region) where 
{T} = safe_mapslices(maximum, abs.(A), region) @@ -21,15 +22,21 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re 1, 2, 3, 4, 5, (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), (1, 2, 3), (1, 3, 4), (2, 3, 4), (1, 2, 3, 4)] Areduc = rand(3, 4, 5, 6) + Breduc = rand(Bool, 3, 4, 5, 6) + @assert axes(Areduc) == axes(Breduc) + r = fill(NaN, map(length, Base.reduced_indices(axes(Areduc), region))) @test sum!(r, Areduc) ≈ safe_sum(Areduc, region) @test prod!(r, Areduc) ≈ safe_prod(Areduc, region) @test maximum!(r, Areduc) ≈ safe_maximum(Areduc, region) @test minimum!(r, Areduc) ≈ safe_minimum(Areduc, region) + @test count!(r, Breduc) ≈ safe_count(Breduc, region) + @test sum!(abs, r, Areduc) ≈ safe_sumabs(Areduc, region) @test sum!(abs2, r, Areduc) ≈ safe_sumabs2(Areduc, region) @test maximum!(abs, r, Areduc) ≈ safe_maxabs(Areduc, region) @test minimum!(abs, r, Areduc) ≈ safe_minabs(Areduc, region) + @test count!(!, r, Breduc) ≈ safe_count(.!Breduc, region) # With init=false r2 = similar(r) @@ -41,6 +48,9 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re @test maximum!(r, Areduc, init=false) ≈ fill!(r2, 1.8) fill!(r, -0.2) @test minimum!(r, Areduc, init=false) ≈ fill!(r2, -0.2) + fill!(r, 1) + @test count!(r, Breduc, init=false) ≈ safe_count(Breduc, region) .+ 1 + fill!(r, 8.1) @test sum!(abs, r, Areduc, init=false) ≈ safe_sumabs(Areduc, region) .+ 8.1 fill!(r, 8.1) @@ -49,15 +59,20 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re @test maximum!(abs, r, Areduc, init=false) ≈ fill!(r2, 1.5) fill!(r, -1.5) @test minimum!(abs, r, Areduc, init=false) ≈ fill!(r2, -1.5) + fill!(r, 1) + @test count!(!, r, Breduc, init=false) ≈ safe_count(.!Breduc, region) .+ 1 @test @inferred(sum(Areduc, dims=region)) ≈ safe_sum(Areduc, region) @test @inferred(prod(Areduc, dims=region)) ≈ safe_prod(Areduc, region) @test @inferred(maximum(Areduc, dims=region)) ≈ safe_maximum(Areduc, region) 
@test @inferred(minimum(Areduc, dims=region)) ≈ safe_minimum(Areduc, region) + @test @inferred(count(Breduc, dims=region)) ≈ safe_count(Breduc, region) + @test @inferred(sum(abs, Areduc, dims=region)) ≈ safe_sumabs(Areduc, region) @test @inferred(sum(abs2, Areduc, dims=region)) ≈ safe_sumabs2(Areduc, region) @test @inferred(maximum(abs, Areduc, dims=region)) ≈ safe_maxabs(Areduc, region) @test @inferred(minimum(abs, Areduc, dims=region)) ≈ safe_minabs(Areduc, region) + @test @inferred(count(!, Breduc, dims=region)) ≈ safe_count(.!Breduc, region) end # Test reduction along first dimension; this is special-cased for @@ -416,3 +431,9 @@ end @test sum([Variable(:x), Variable(:y)], dims=1) == [AffExpr([Variable(:x), Variable(:y)])] end + +# count +@testset "count: throw on non-bool types" begin + @test_throws TypeError count([1], dims=1) + @test_throws TypeError count!([1], [1]) +end