Skip to content

Commit

Permalink
Add export for Splat(f), replacing Base.splat (#42717)
Browse files Browse the repository at this point in the history
* Deprecate `Base.splat(x)` in favor of `Splat(x)` (now exported)
* Add pretty printing of `Splat(f)`
  • Loading branch information
Seelengrab authored May 9, 2022
1 parent 3023693 commit ef10e52
Show file tree
Hide file tree
Showing 14 changed files with 40 additions and 17 deletions.
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ New library functions
---------------------

* `Iterators.flatmap` was added ([#44792]).
* New helper `Splat(f)` which acts like `x -> f(x...)`, with pretty printing for
inspecting which function `f` was originally wrapped. ([#42717])

Library changes
---------------
Expand Down Expand Up @@ -120,6 +122,7 @@ Standard library changes
Deprecated or removed
---------------------

* Unexported `splat` is deprecated in favor of exported `Splat`, which has pretty printing of the wrapped function. ([#42717])

External dependencies
---------------------
Expand Down
6 changes: 6 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,9 @@ const var"@_noinline_meta" = var"@noinline"
@deprecate getindex(t::Tuple, i::Real) t[convert(Int, i)]

# END 1.8 deprecations

# BEGIN 1.9 deprecations

@deprecate splat(x) Splat(x) false

# END 1.9 deprecations
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,7 @@ export
atreplinit,
exit,
ntuple,
Splat,

# I/O and events
close,
Expand Down
2 changes: 1 addition & 1 deletion base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ the `zip` iterator is a tuple of values of its subiterators.
`zip` orders the calls to its subiterators in such a way that stateful iterators will
not advance when another iterator finishes in the current iteration.
See also: [`enumerate`](@ref), [`splat`](@ref Base.splat).
See also: [`enumerate`](@ref), [`Splat`](@ref Base.Splat).
# Examples
```jldoctest
Expand Down
24 changes: 18 additions & 6 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1185,27 +1185,39 @@ used to implement specialized methods.
<(x) = Fix2(<, x)

"""
splat(f)
Splat(f)
Defined as
Equivalent to
```julia
splat(f) = args->f(args...)
my_splat(f) = args->f(args...)
```
i.e. given a function returns a new function that takes one argument and splats
its argument into the original function. This is useful as an adaptor to pass
a multi-argument function in a context that expects a single argument, but
passes a tuple as that single argument.
passes a tuple as that single argument. Additionally has pretty printing.
# Example usage:
```jldoctest
julia> map(Base.splat(+), zip(1:3,4:6))
julia> map(Base.Splat(+), zip(1:3,4:6))
3-element Vector{Int64}:
5
7
9
julia> my_add = Base.Splat(+)
Splat(+)
julia> my_add((1,2,3))
6
```
"""
splat(f) = args->f(args...)
struct Splat{F} <: Function
f::F
Splat(f) = new{Core.Typeof(f)}(f)
end
(s::Splat)(args) = s.f(args...)
print(io::IO, s::Splat) = print(io, "Splat(", s.f, ')')
show(io::IO, s::Splat) = print(io, s)

## in and related operators

Expand Down
1 change: 1 addition & 0 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ end

show(io::IO, ::MIME"text/plain", c::ComposedFunction) = show(io, c)
show(io::IO, ::MIME"text/plain", c::Returns) = show(io, c)
show(io::IO, ::MIME"text/plain", s::Splat) = show(io, s)

function show(io::IO, ::MIME"text/plain", iter::Union{KeySet,ValueIterator})
isempty(iter) && get(io, :compact, false) && return show(io, iter)
Expand Down
4 changes: 2 additions & 2 deletions base/strings/search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ function _searchindex(s::Union{AbstractString,ByteArray},
if i === nothing return 0 end
ii = nextind(s, i)::Int
a = Iterators.Stateful(trest)
matched = all(splat(==), zip(SubString(s, ii), a))
matched = all(Splat(==), zip(SubString(s, ii), a))
(isempty(a) && matched) && return i
i = ii
end
Expand Down Expand Up @@ -435,7 +435,7 @@ function _rsearchindex(s::AbstractString,
a = Iterators.Stateful(trest)
b = Iterators.Stateful(Iterators.reverse(
pairs(SubString(s, 1, ii))))
matched = all(splat(==), zip(a, (x[2] for x in b)))
matched = all(Splat(==), zip(a, (x[2] for x in b)))
if matched && isempty(a)
isempty(b) && return firstindex(s)
return nextind(s, popfirst!(b)[1])::Int
Expand Down
2 changes: 1 addition & 1 deletion doc/src/base/base.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ new
Base.:(|>)
Base.:(∘)
Base.ComposedFunction
Base.splat
Base.Splat
Base.Fix1
Base.Fix2
```
Expand Down
2 changes: 1 addition & 1 deletion doc/src/devdocs/ast.md
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ These symbols appear in the `head` field of [`Expr`](@ref)s in lowered form.
* `splatnew`

Similar to `new`, except field values are passed as a single tuple. Works similarly to
`Base.splat(new)` if `new` were a first-class function, hence the name.
`Base.Splat(new)` if `new` were a first-class function, hence the name.

* `isdefined`

Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
vec, zero
using Base: IndexLinear, promote_eltype, promote_op, promote_typeof,
@propagate_inbounds, reduce, typed_hvcat, typed_vcat, require_one_based_indexing,
splat
Splat
using Base.Broadcast: Broadcasted, broadcasted
using OpenBLAS_jll
using libblastrampoline_jll
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ function Base.hash(F::QRCompactWY, h::UInt)
return hash(F.factors, foldr(hash, _triuppers_qr(F.T); init=hash(QRCompactWY, h)))
end
function Base.:(==)(A::QRCompactWY, B::QRCompactWY)
return A.factors == B.factors && all(splat(==), zip(_triuppers_qr.((A.T, B.T))...))
return A.factors == B.factors && all(Splat(==), zip(_triuppers_qr.((A.T, B.T))...))
end
function Base.isequal(A::QRCompactWY, B::QRCompactWY)
return isequal(A.factors, B.factors) && all(zip(_triuppers_qr.((A.T, B.T))...)) do (a, b)
Expand Down
4 changes: 2 additions & 2 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -870,13 +870,13 @@ end
ys = 1:2:20
bc = Broadcast.instantiate(Broadcast.broadcasted(*, xs, ys))
@test IndexStyle(bc) == IndexLinear()
@test sum(bc) == mapreduce(Base.splat(*), +, zip(xs, ys))
@test sum(bc) == mapreduce(Base.Splat(*), +, zip(xs, ys))

xs2 = reshape(xs, 1, :)
ys2 = reshape(ys, 1, :)
bc = Broadcast.instantiate(Broadcast.broadcasted(*, xs2, ys2))
@test IndexStyle(bc) == IndexCartesian()
@test sum(bc) == mapreduce(Base.splat(*), +, zip(xs, ys))
@test sum(bc) == mapreduce(Base.Splat(*), +, zip(xs, ys))

xs = 1:5:3*5
ys = 1:4:3*4
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2842,7 +2842,7 @@ j30385(T, y) = k30385(f30385(T, y))
@test @inferred(j30385(:dummy, 1)) == "dummy"

@test Base.return_types(Tuple, (NamedTuple{<:Any,Tuple{Any,Int}},)) == Any[Tuple{Any,Int}]
@test Base.return_types(Base.splat(tuple), (typeof((a=1,)),)) == Any[Tuple{Int}]
@test Base.return_types(Base.Splat(tuple), (typeof((a=1,)),)) == Any[Tuple{Int}]

# test that return_type_tfunc isn't affected by max_methods differently than return_type
_rttf_test(::Int8) = 0
Expand Down
2 changes: 1 addition & 1 deletion test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ end
@test length(I) == iterate_length(I) == simd_iterate_length(I) == simd_trip_count(I)
@test collect(I) == iterate_elements(I) == simd_iterate_elements(I) == index_elements(I)
end
@test all(Base.splat(==), zip(Iterators.flatten(map(collect, P)), iter))
@test all(Base.Splat(==), zip(Iterators.flatten(map(collect, P)), iter))
end
end
@testset "empty/invalid partitions" begin
Expand Down

0 comments on commit ef10e52

Please sign in to comment.