diff --git a/NEWS.md b/NEWS.md index 0e61bc480cac2..786eb324232c8 100644 --- a/NEWS.md +++ b/NEWS.md @@ -24,7 +24,6 @@ New library functions Standard library changes ------------------------ - #### LinearAlgebra @@ -36,6 +35,7 @@ Standard library changes #### Statistics +* `mean` now accepts both a function argument and a `dims` keyword ([#31576]). #### Miscellaneous diff --git a/stdlib/Statistics/src/Statistics.jl b/stdlib/Statistics/src/Statistics.jl index 22d78bd3de65f..19e73e6ab13bf 100644 --- a/stdlib/Statistics/src/Statistics.jl +++ b/stdlib/Statistics/src/Statistics.jl @@ -73,7 +73,32 @@ function mean(f, itr) end return total/count end -mean(f, A::AbstractArray) = sum(f, A) / length(A) + +""" + mean(f::Function, A::AbstractArray; dims) + +Apply the function `f` to each element of array `A` and take the mean over dimensions `dims`. + +!!! compat "Julia 1.3" + This method requires at least Julia 1.3. + +```jldoctest +julia> mean(√, [1, 2, 3]) +1.3820881233139908 + +julia> mean([√1, √2, √3]) +1.3820881233139908 + +julia> mean(√, [1 2 3; 4 5 6], dims=2) +2×1 Array{Float64,2}: + 1.3820881233139908 + 2.2285192400943226 +``` +""" +mean(f, A::AbstractArray; dims=:) = _mean(f, A, dims) + +_mean(f, A::AbstractArray, ::Colon) = sum(f, A) / length(A) +_mean(f, A::AbstractArray, dims) = sum(f, A, dims=dims) / mapreduce(i -> size(A, i), *, unique(dims); init=1) """ mean!(r, v) diff --git a/stdlib/Statistics/test/runtests.jl b/stdlib/Statistics/test/runtests.jl index 1e080de32b895..e4849e04d4263 100644 --- a/stdlib/Statistics/test/runtests.jl +++ b/stdlib/Statistics/test/runtests.jl @@ -73,6 +73,11 @@ end @test mean([1,2,3]) == 2. @test mean([0 1 2; 4 5 6], dims=1) == [2. 3. 4.] @test mean([1 2 3; 4 5 6], dims=1) == [2.5 3.5 4.5] + @test mean(-, [1 2 3 ; 4 5 6], dims=1) == [-2.5 -3.5 -4.5] + @test mean(-, [1 2 3 ; 4 5 6], dims=2) == transpose([-2.0 -5.0]) + @test mean(-, [1 2 3 ; 4 5 6], dims=(1, 2)) == -3.5 .* ones(1, 1) + @test mean(-, [1 2 3 ; 4 5 6], dims=(1, 1)) == [-2.5 -3.5 -4.5] + @test mean(-, [1 2 3 ; 4 5 6], dims=()) == Float64[-1 -2 -3 ; -4 -5 -6] @test mean(i->i+1, 0:2) === 2. @test mean(isodd, [3]) === 1. @test mean(x->3x, (1,1)) === 3.