Skip to content

Commit

Permalink
Implement accumulate and friends for Tuple (#34654)
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf authored and KristofferC committed Apr 11, 2020
1 parent 865a01c commit a6b237c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 4 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ New library features

* `isapprox` (or ``) now has a one-argument "curried" method `isapprox(x)` which returns a function, like `isequal` (or `==`)` ([#32305]).
* `Ref{NTuple{N,T}}` can be passed to `Ptr{T}`/`Ref{T}` `ccall` signatures ([#34199])
* `accumulate`, `cumsum`, and `cumprod` now support `Tuple` ([#34654]).


Standard library changes
Expand Down
31 changes: 27 additions & 4 deletions base/accumulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,15 @@ function cumsum(A::AbstractArray{T}; dims::Integer) where T
end

"""
cumsum(x::AbstractVector)
cumsum(itr::Union{AbstractVector,Tuple})
Cumulative sum a vector. See also [`cumsum!`](@ref)
Cumulative sum an iterator. See also [`cumsum!`](@ref)
to use a preallocated output array, both for performance and to control the precision of the
output (e.g. to avoid overflow).
!!! compat "Julia 1.5"
`cumsum` on a tuple requires at least Julia 1.5.
# Examples
```jldoctest
julia> cumsum([1, 1, 1])
Expand All @@ -111,9 +114,13 @@ julia> cumsum([fill(1, 2) for i in 1:3])
[1, 1]
[2, 2]
[3, 3]
julia> cumsum((1, 1, 1))
(1, 2, 3)
```
"""
cumsum(x::AbstractVector) = cumsum(x, dims=1)
cumsum(itr) = accumulate(add_sum, itr)


"""
Expand Down Expand Up @@ -163,12 +170,15 @@ function cumprod(A::AbstractArray; dims::Integer)
end

"""
cumprod(x::AbstractVector)
cumprod(itr::Union{AbstractVector,Tuple})
Cumulative product of a vector. See also
Cumulative product of an iterator. See also
[`cumprod!`](@ref) to use a preallocated output array, both for performance and
to control the precision of the output (e.g. to avoid overflow).
!!! compat "Julia 1.5"
`cumprod` on a tuple requires at least Julia 1.5.
# Examples
```jldoctest
julia> cumprod(fill(1//2, 3))
Expand All @@ -182,9 +192,13 @@ julia> cumprod([fill(1//3, 2, 2) for i in 1:3])
[1//3 1//3; 1//3 1//3]
[2//9 2//9; 2//9 2//9]
[4//27 4//27; 4//27 4//27]
julia> cumprod((1, 2, 1))
(1, 2, 2)
```
"""
cumprod(x::AbstractVector) = cumprod(x, dims=1)
cumprod(itr) = accumulate(mul_prod, itr)


"""
Expand Down Expand Up @@ -247,6 +261,15 @@ function accumulate(op, A; dims::Union{Nothing,Integer}=nothing, kw...)
accumulate!(op, out, A; dims=dims, kw...)
end

function accumulate(op, xs::Tuple; init = _InitialValue())
rf = BottomRF(op)
ys, = afoldl(((), init), xs...) do (ys, acc), x
acc = rf(acc, x)
(ys..., acc), acc
end
return ys
end

"""
accumulate!(op, B, A; [dims], [init])
Expand Down
11 changes: 11 additions & 0 deletions test/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,17 @@ end
end
end

@testset "accumulate" begin
@test @inferred(cumsum(())) == ()
@test @inferred(cumsum((1, 2, 3))) == (1, 3, 6)
@test @inferred(cumprod((1, 2, 3))) == (1, 2, 6)
@test @inferred(accumulate(+, (1, 2, 3); init=10)) == (11, 13, 16)
op(::Nothing, ::Any) = missing
op(::Missing, ::Any) = nothing
@test @inferred(accumulate(op, (1, 2, 3, 4); init = nothing)) ===
(missing, nothing, missing, nothing)
end

@testset "ntuple" begin
nttest1(x::NTuple{n, Int}) where {n} = n
@test nttest1(()) == 0
Expand Down

0 comments on commit a6b237c

Please sign in to comment.