diff --git a/stdlib/Statistics/src/Statistics.jl b/stdlib/Statistics/src/Statistics.jl index e9d31486592174..d08d039e5e3a6a 100644 --- a/stdlib/Statistics/src/Statistics.jl +++ b/stdlib/Statistics/src/Statistics.jl @@ -74,6 +74,7 @@ function mean(f, itr) return total/count end mean(f, A::AbstractArray) = sum(f, A) / length(A) +mean(f, A; dims = 1) = sum(f, A, dims = dims) / size(A, dims) """ mean!(r, v) diff --git a/stdlib/Statistics/test/runtests.jl b/stdlib/Statistics/test/runtests.jl index 94ed6f7f7a9020..f56fbfe8ad1e7a 100644 --- a/stdlib/Statistics/test/runtests.jl +++ b/stdlib/Statistics/test/runtests.jl @@ -73,6 +73,8 @@ end @test mean([1,2,3]) == 2. @test mean([0 1 2; 4 5 6], dims=1) == [2. 3. 4.] @test mean([1 2 3; 4 5 6], dims=1) == [2.5 3.5 4.5] + @test mean(-, [1 2 3 ; 4 5 6], dims = 1) == [-2.5 -3.5 -4.5] + @test mean(-, [1 2 3 ; 4 5 6], dims = 2) == transpose([-2.0 -5.0]) @test mean(i->i+1, 0:2) === 2. @test mean(isodd, [3]) === 1. @test mean(x->3x, (1,1)) === 3.