Skip to content

Commit

Permalink
add more test and remove cat_nested
Browse files Browse the repository at this point in the history
`cat_nested` failed to infer in some cases.
It has been inserted into `make_makeargs`, So I remove it.
  • Loading branch information
N5N3 committed Dec 8, 2021
1 parent fd6721b commit 7429b48
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
15 changes: 4 additions & 11 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,8 @@ some cases.
"""
function flatten(bc::Broadcasted{Style}) where {Style}
isflat(bc) && return bc
# concatenate the nested arguments into {a, b, c, d}
# args = cat_nested(bc)
# build a function `makeargs` that takes a "flat" argument list and
# 1. concatenate the nested arguments into {a, b, c, d}
# 2. build a function `makeargs` that takes a "flat" argument list and
# and creates the appropriate input arguments for `f`, e.g.,
# makeargs = (w, x, y, z) -> (w, g(x, y), z)
#
Expand All @@ -329,8 +328,7 @@ function flatten(bc::Broadcasted{Style}) where {Style}
@inline function (args::Vararg{Any,N})
f(makeargs(args...)...)
end
newf = _make(args)
return Broadcasted{Style}(newf, args, bc.axes)
return Broadcasted{Style}(_make(args), args, bc.axes)
end
end

Expand All @@ -340,18 +338,13 @@ _isflat(args::NestedTuple) = false
_isflat(args::Tuple) = _isflat(tail(args))
_isflat(args::Tuple{}) = true

cat_nested(t::Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...)
cat_nested() = ()

"""
make_makeargs(args::Tuple, t::Tuple) -> Function, Tuple
Each element of `t` is one (consecutive) node in a broadcast tree.
`args` contains the rest arguments on the "right" side of `t`.
The jobs of `make_makeargs` are:
1. append the flattened arguments in `t` at the beginning of `args`, i.e.
`(cat_nested(t)..., args...)`
1. append the flattened arguments in `t` at the beginning of `args`.
2. return a function that takes in flattened argument list and returns a
tuple (each entry corresponding to an entry in `t`, having evaluated
the corresponding element in the broadcast tree).
Expand Down
11 changes: 7 additions & 4 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -772,16 +772,19 @@ end

# issue #27988: inference of Broadcast.flatten
using .Broadcast: Broadcasted
let
let _cat_nested(bc) = Broadcast.flatten(bc).args
bc = Broadcasted(+, (Broadcasted(*, (1, 2)), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
@test @inferred(Broadcast.cat_nested(bc)) == (1,2,3,4,5)
@test @inferred(_cat_nested(bc)) == (1,2,3,4,5)
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 62
bc = Broadcasted(+, (Broadcasted(*, (1, Broadcasted(/, (2.0, 2.5)))), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
@test @inferred(Broadcast.cat_nested(bc)) == (1,2.0,2.5,3,4,5)
@test @inferred(_cat_nested(bc)) == (1,2.0,2.5,3,4,5)
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 60.8
# 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3
bc = Base.Broadcast.Broadcasted(+, (Base.Broadcast.Broadcasted(+, (Base.Broadcast.Broadcasted(-, (Base.Broadcast.Broadcasted(*, (1, 1)), Base.Broadcast.Broadcasted(*, (1, Base.Broadcast.Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{2}}(Val{2}()))))))), Base.Broadcast.Broadcasted(*, (1, 1)))), Base.Broadcast.Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{3}}(Val{3}())))))
bc = Broadcasted(+, (Broadcasted(+, (Broadcasted(-, (Broadcasted(*, (1, 1)), Broadcasted(*, (1, Broadcasted(Base.literal_pow, (Ref(^), 1, Ref(Val(2)))))))), Broadcasted(*, (1, 1)))), Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{3}}(Val{3}())))))
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 2
# @. 1 + 1 * (1 + 1 + 1 + 1)
bc = Broadcasted(+, (1, Broadcasted(*, (1, Broadcasted(+, (1, 1, 1, 1))))))
@test @inferred(_cat_nested(bc)) == (1,1,1,1,1,1) # `cat_nested` failed to infer this
end

let
Expand Down

0 comments on commit 7429b48

Please sign in to comment.