Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pretty printing for splat(f) #42717

Merged
merged 11 commits into from
May 9, 2022
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ New library functions
---------------------

* `Iterators.flatmap` was added ([#44792]).
* New helper `Splat(f)` which acts like `x -> f(x...)`, with pretty printing for
inspecting which function `f` was originally wrapped. ([#42717])

Library changes
---------------
Expand Down Expand Up @@ -120,6 +122,7 @@ Standard library changes
Deprecated or removed
---------------------

* Unexported `splat` is deprecated in favor of exported `Splat`, which has pretty printing of the wrapped function. ([#42717])

External dependencies
---------------------
Expand Down
6 changes: 6 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,9 @@ const var"@_noinline_meta" = var"@noinline"
@deprecate getindex(t::Tuple, i::Real) t[convert(Int, i)]

# END 1.8 deprecations

# BEGIN 1.9 deprecations

@deprecate splat(x) Splat(x) false

# END 1.9 deprecations
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,7 @@ export
atreplinit,
exit,
ntuple,
Splat,

# I/O and events
close,
Expand Down
2 changes: 1 addition & 1 deletion base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ the `zip` iterator is a tuple of values of its subiterators.
`zip` orders the calls to its subiterators in such a way that stateful iterators will
not advance when another iterator finishes in the current iteration.

See also: [`enumerate`](@ref), [`splat`](@ref Base.splat).
See also: [`enumerate`](@ref), [`Splat`](@ref Base.Splat).

# Examples
```jldoctest
Expand Down
24 changes: 18 additions & 6 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1185,27 +1185,39 @@ used to implement specialized methods.
<(x) = Fix2(<, x)

"""
splat(f)
Splat(f)

Defined as
Equivalent to
```julia
splat(f) = args->f(args...)
my_splat(f) = args->f(args...)
```
i.e. given a function returns a new function that takes one argument and splats
its argument into the original function. This is useful as an adaptor to pass
a multi-argument function in a context that expects a single argument, but
passes a tuple as that single argument.
passes a tuple as that single argument. Additionally has pretty printing.

# Example usage:
```jldoctest
julia> map(Base.splat(+), zip(1:3,4:6))
julia> map(Base.Splat(+), zip(1:3,4:6))
3-element Vector{Int64}:
5
7
9

julia> my_add = Base.Splat(+)
Splat(+)

julia> my_add((1,2,3))
6
```
"""
splat(f) = args->f(args...)
struct Splat{F} <: Function
f::F
Splat(f) = new{Core.Typeof(f)}(f)
end
Seelengrab marked this conversation as resolved.
Show resolved Hide resolved
(s::Splat)(args) = s.f(args...)
print(io::IO, s::Splat) = print(io, "Splat(", s.f, ')')
show(io::IO, s::Splat) = print(io, s)

## in and related operators

Expand Down
1 change: 1 addition & 0 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ end

show(io::IO, ::MIME"text/plain", c::ComposedFunction) = show(io, c)
show(io::IO, ::MIME"text/plain", c::Returns) = show(io, c)
show(io::IO, ::MIME"text/plain", s::Splat) = show(io, s)

function show(io::IO, ::MIME"text/plain", iter::Union{KeySet,ValueIterator})
isempty(iter) && get(io, :compact, false) && return show(io, iter)
Expand Down
4 changes: 2 additions & 2 deletions base/strings/search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ function _searchindex(s::Union{AbstractString,ByteArray},
if i === nothing return 0 end
ii = nextind(s, i)::Int
a = Iterators.Stateful(trest)
matched = all(splat(==), zip(SubString(s, ii), a))
matched = all(Splat(==), zip(SubString(s, ii), a))
(isempty(a) && matched) && return i
i = ii
end
Expand Down Expand Up @@ -435,7 +435,7 @@ function _rsearchindex(s::AbstractString,
a = Iterators.Stateful(trest)
b = Iterators.Stateful(Iterators.reverse(
pairs(SubString(s, 1, ii))))
matched = all(splat(==), zip(a, (x[2] for x in b)))
matched = all(Splat(==), zip(a, (x[2] for x in b)))
if matched && isempty(a)
isempty(b) && return firstindex(s)
return nextind(s, popfirst!(b)[1])::Int
Expand Down
2 changes: 1 addition & 1 deletion doc/src/base/base.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ new
Base.:(|>)
Base.:(∘)
Base.ComposedFunction
Base.splat
Base.Splat
Base.Fix1
Base.Fix2
```
Expand Down
2 changes: 1 addition & 1 deletion doc/src/devdocs/ast.md
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ These symbols appear in the `head` field of [`Expr`](@ref)s in lowered form.
* `splatnew`

Similar to `new`, except field values are passed as a single tuple. Works similarly to
`Base.splat(new)` if `new` were a first-class function, hence the name.
`Base.Splat(new)` if `new` were a first-class function, hence the name.

* `isdefined`

Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
vec, zero
using Base: IndexLinear, promote_eltype, promote_op, promote_typeof,
@propagate_inbounds, reduce, typed_hvcat, typed_vcat, require_one_based_indexing,
splat
Splat
using Base.Broadcast: Broadcasted, broadcasted
using OpenBLAS_jll
using libblastrampoline_jll
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ function Base.hash(F::QRCompactWY, h::UInt)
return hash(F.factors, foldr(hash, _triuppers_qr(F.T); init=hash(QRCompactWY, h)))
end
function Base.:(==)(A::QRCompactWY, B::QRCompactWY)
return A.factors == B.factors && all(splat(==), zip(_triuppers_qr.((A.T, B.T))...))
return A.factors == B.factors && all(Splat(==), zip(_triuppers_qr.((A.T, B.T))...))
end
function Base.isequal(A::QRCompactWY, B::QRCompactWY)
return isequal(A.factors, B.factors) && all(zip(_triuppers_qr.((A.T, B.T))...)) do (a, b)
Expand Down
4 changes: 2 additions & 2 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -870,13 +870,13 @@ end
ys = 1:2:20
bc = Broadcast.instantiate(Broadcast.broadcasted(*, xs, ys))
@test IndexStyle(bc) == IndexLinear()
@test sum(bc) == mapreduce(Base.splat(*), +, zip(xs, ys))
@test sum(bc) == mapreduce(Base.Splat(*), +, zip(xs, ys))

xs2 = reshape(xs, 1, :)
ys2 = reshape(ys, 1, :)
bc = Broadcast.instantiate(Broadcast.broadcasted(*, xs2, ys2))
@test IndexStyle(bc) == IndexCartesian()
@test sum(bc) == mapreduce(Base.splat(*), +, zip(xs, ys))
@test sum(bc) == mapreduce(Base.Splat(*), +, zip(xs, ys))

xs = 1:5:3*5
ys = 1:4:3*4
Expand Down
2 changes: 1 addition & 1 deletion test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2842,7 +2842,7 @@ j30385(T, y) = k30385(f30385(T, y))
@test @inferred(j30385(:dummy, 1)) == "dummy"

@test Base.return_types(Tuple, (NamedTuple{<:Any,Tuple{Any,Int}},)) == Any[Tuple{Any,Int}]
@test Base.return_types(Base.splat(tuple), (typeof((a=1,)),)) == Any[Tuple{Int}]
@test Base.return_types(Base.Splat(tuple), (typeof((a=1,)),)) == Any[Tuple{Int}]

# test that return_type_tfunc isn't affected by max_methods differently than return_type
_rttf_test(::Int8) = 0
Expand Down
2 changes: 1 addition & 1 deletion test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ end
@test length(I) == iterate_length(I) == simd_iterate_length(I) == simd_trip_count(I)
@test collect(I) == iterate_elements(I) == simd_iterate_elements(I) == index_elements(I)
end
@test all(Base.splat(==), zip(Iterators.flatten(map(collect, P)), iter))
@test all(Base.Splat(==), zip(Iterators.flatten(map(collect, P)), iter))
end
end
@testset "empty/invalid partitions" begin
Expand Down