Make Broadcast.flatten(bc).f more compiler friendly. (better inferred and inlined) #43322

Merged
merged 1 commit into from
Jul 15, 2023


@N5N3 N5N3 commented Dec 3, 2021

A follow-up attempt to fix #27988. (close #47493, close #50554)
Examples:

julia> using LazyArrays
julia> bc = @~ @. 1*(1 + 1) + 1*1;
julia> bc2 = @~ 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3;

On master:


julia> @code_typed Broadcast.flatten(bc).f(1,1,1,1,1)
CodeInfo(
1 ─ %1  = Core.getfield(args, 1)::Int64
│   %2  = Core.getfield(args, 2)::Int64
│   %3  = Core.getfield(args, 3)::Int64
│   %4  = Core.getfield(args, 4)::Int64
│   %5  = Core.getfield(args, 5)::Int64
│   %6  = invoke Base.Broadcast.var"#13#14"{Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}, typeof(+)}}(Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}, typeof(+)}(Base.Broadcast.var"#15#17"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"())), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), +))(%1::Int64, %2::Int64, %3::Vararg{Int64}, %4, %5)::Tuple{Int64, Int64, Vararg{Int64}}
│   %7  = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), %6)::Tuple{Int64, Int64}
│   %8  = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %6)::Tuple{Vararg{Int64}}
│   %9  = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#16#18"{Base.Broadcast.var"#9#11", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}, typeof(*)}(Base.Broadcast.var"#9#11"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"())), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), *), %8)::Tuple{Int64}
│   %10 = Core.getfield(%7, 1)::Int64
│   %11 = Core.getfield(%7, 2)::Int64
│   %12 = Base.mul_int(%10, %11)::Int64
│   %13 = Core.getfield(%9, 1)::Int64
│   %14 = Base.add_int(%12, %13)::Int64
└──       return %14
) => Int64

julia> @code_typed Broadcast.flatten(bc2).f(1,1,1,^,1,Val(2),1,1,^,1,Val(3))
CodeInfo(
1 ─ %1  = Core.getfield(args, 1)::Int64
│   %2  = Core.getfield(args, 2)::Int64
│   %3  = Core.getfield(args, 3)::Int64
│   %4  = Core.getfield(args, 5)::Int64
│   %5  = Core.getfield(args, 7)::Int64
│   %6  = Core.getfield(args, 8)::Int64
│   %7  = Core.getfield(args, 10)::Int64
│   %8  = invoke Base.Broadcast.var"#13#14"{Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}, typeof(Base.literal_pow)}}(Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}, typeof(Base.literal_pow)}(Base.Broadcast.var"#15#17"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"()))), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"()))), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"()))), Base.literal_pow))(%3::Int64, ^::Function, %4::Vararg{Any}, $(QuoteNode(Val{2}())), %5, %6, ^, %7, $(QuoteNode(Val{3}())))::Tuple{Int64, Any, Vararg{Any}}
│   %9  = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), %8)::Tuple{Int64, Any}
│   %10 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %8)::Tuple
│   %11 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#15#17"(), %10)::Tuple
│   %12 = Core.getfield(%9, 1)::Int64
│   %13 = Core.getfield(%9, 2)::Any
│   %14 = (*)(%12, %13)::Any
│   %15 = Core.tuple(%14)::Tuple{Any}
│   %16 = Core._apply_iterate(Base.iterate, Core.tuple, %15, %11)::Tuple{Any, Vararg{Any}}
│   %17 = Base.mul_int(%1, %2)::Int64
│   %18 = Core.tuple(%17)::Tuple{Int64}
│   %19 = Core._apply_iterate(Base.iterate, Core.tuple, %18, %16)::Tuple{Int64, Any, Vararg{Any}}
│   %20 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), %19)::Tuple{Int64, Any}
│   %21 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %19)::Tuple
│   %22 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#16#18"{Base.Broadcast.var"#15#17", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}, typeof(*)}(Base.Broadcast.var"#15#17"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"())), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), *), %21)::Tuple{Any, Vararg{Any}}
│   %23 = Core.getfield(%20, 1)::Int64
│   %24 = Core.getfield(%20, 2)::Any
│   %25 = (-)(%23, %24)::Any
│   %26 = Core.tuple(%25)::Tuple{Any}
│   %27 = Core._apply_iterate(Base.iterate, Core.tuple, %26, %22)::Tuple{Any, Any, Vararg{Any}}
│   %28 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"())), %27)::Tuple{Any, Any}
│   %29 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %27)::Tuple
│   %30 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#16#18"{Base.Broadcast.var"#9#11", Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}, Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}, typeof(Base.literal_pow)}(Base.Broadcast.var"#9#11"(), Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"()))), Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"()))), Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}}(Base.Broadcast.var"#19#20"{Base.Broadcast.var"#21#22"}(Base.Broadcast.var"#21#22"()))), Base.literal_pow), %29)::Tuple{Any}
│   %31 = Core.getfield(%28, 1)::Any
│   %32 = Core.getfield(%28, 2)::Any
│   %33 = (+)(%31, %32)::Any
│   %34 = Core.getfield(%30, 1)::Any
│   %35 = (+)(%33, %34)::Any
└──       return %35
) => Any

On this PR

julia> @code_typed Broadcast.flatten(bc).f(1,1,1,1,1)   
CodeInfo(
1 ─ %1 = Core.getfield(args, 1)::Int64
│   %2 = Core.getfield(args, 2)::Int64
│   %3 = Core.getfield(args, 3)::Int64
│   %4 = Core.getfield(args, 4)::Int64
│   %5 = Core.getfield(args, 5)::Int64
│   %6 = Base.add_int(%2, %3)::Int64
│   %7 = Base.mul_int(%1, %6)::Int64
│   %8 = Base.mul_int(%4, %5)::Int64
│   %9 = Base.add_int(%7, %8)::Int64
└──      return %9
) => Int64

julia> @code_typed Broadcast.flatten(bc2).f(1,1,1,^,1,Val(2),1,1,^,1,Val(3))
CodeInfo(
1 ─ %1  = Core.getfield(args, 1)::Int64
│   %2  = Core.getfield(args, 2)::Int64
│   %3  = Core.getfield(args, 3)::Int64
│   %4  = Core.getfield(args, 5)::Int64
│   %5  = Core.getfield(args, 7)::Int64
│   %6  = Core.getfield(args, 8)::Int64
│   %7  = Core.getfield(args, 10)::Int64
│   %8  = Base.mul_int(%1, %2)::Int64
│   %9  = Base.mul_int(%4, %4)::Int64
│   %10 = Base.mul_int(%3, %9)::Int64
│   %11 = Base.sub_int(%8, %10)::Int64
│   %12 = Base.mul_int(%5, %6)::Int64
│   %13 = Base.add_int(%11, %12)::Int64
│   %14 = Base.mul_int(%7, %7)::Int64
│   %15 = Base.mul_int(%14, %7)::Int64
│   %16 = Base.add_int(%13, %15)::Int64
└──       return %16
) => Int64

@dkarrasch dkarrasch added broadcast Applying a function over a collection compiler:inference Type inference labels Dec 3, 2021
@N5N3 N5N3 changed the title Add more type annotation to Broadcast.flatten for better inference. Make Broadcast.flatten(bc).f more compiler friendly. Dec 4, 2021
@N5N3

N5N3 commented Dec 4, 2021

I only added the second example as a test, since the first one's instability doesn't affect the return type.
So the test might not be that robust. (Do we have tools to test internal instability?)
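
To illustrate the concern, here is a minimal sketch showing that a return-type check such as `Test.@inferred` passes even when an intermediate value is not concretely typed. The function `hidden_instability` is hypothetical and unrelated to the broadcast code in this PR.

```julia
using Test

# Hypothetical example: `y` is a Union{Int, Float64} inside the function,
# but every branch is converted to Int before returning.
function hidden_instability(x::Int)
    y = x > 0 ? x : 1.0
    return round(Int, y)
end

# @inferred compares only the inferred return type with the actual one,
# so it does not report the unstable intermediate `y`.
@inferred hidden_instability(3)
```

Inspecting the body with `@code_warntype` (from InteractiveUtils) is one manual way to look at intermediates, though it is not an automated test.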


@N5N3 N5N3 changed the title Make Broadcast.flatten(bc).f more compiler friendly. Make Broadcast.flatten(bc).f more compiler friendly. (better inferred and inlined) Feb 11, 2022
@N5N3

N5N3 commented Jun 19, 2022

It turns out this problem can be overcome much more easily, as a nested Broadcasted's getindex is well inferred and inlined by our compiler.
In [a8283f2], the process of Broadcast.flatten follows Broadcast.preprocess's style: make_makeargs generates a tuple of functions. Each of these functions takes in the whole "flattened" argument list and returns the appropriate input argument of bc.f, i.e.

flattened(args...) = bc.f(map(f -> f(args), makeargs)...)
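
A minimal standalone sketch of this scheme, assuming the expression `a * (b + c)`; it is not the code in base/broadcast.jl. `Pick` mirrors the name used later in this thread, while `inner`, `makeargs`, and `flattened` are hypothetical stand-ins defined only for illustration.

```julia
# Illustrative sketch only; not the actual Base.Broadcast implementation.
# Pick{N} is a callable singleton that selects the N-th entry of a tuple.
struct Pick{N} <: Function end
(::Pick{N})(args::Tuple) where {N} = args[N]

# Flattening a * (b + c): the flat argument list is (a, b, c).
# Each entry of `makeargs` receives the whole flat tuple and returns
# one input of the outer function, which is `*` here.
inner(args::Tuple) = Pick{2}()(args) + Pick{3}()(args)   # rebuilds b + c
const makeargs = (Pick{1}(), inner)

# Same shape as the one-liner above, with bc.f replaced by `*`.
flattened(args...) = (*)(map(f -> f(args), makeargs)...)

flattened(2, 3, 4)   # evaluates 2 * (3 + 4)
```

Because every selector sees the full argument tuple, no intermediate head/tail tuples need to be built and re-splatted.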

@ChrisRackauckas

Maybe @vchuravy can take a look at this? In general packages shouldn't use as much Broadcast.flatten as they do, but I think this will have a lot of knock-on effects because they do use it a lot, and the current lowering is pretty bad.

@N5N3 N5N3 added the forget me not PRs that one wants to make sure aren't forgotten label Aug 17, 2022

@vtjnash vtjnash left a comment


This reminds me of #45789–was that the inspiration?

Ah, this is a fun new way to make use of our compiler's constant prop. I feel a bit nervous, though, that the compiler won't be up to the challenge of figuring out the Pick{N} allocation, or that it should be a Val{N} instead (with the caller doing the indexing explicitly with the result rather than calling the result as a function). How well has that been working in your experience?

@N5N3

N5N3 commented Sep 2, 2022

In fact, this was inspired by the current broadcast implementation itself.
I just noticed that our nested Broadcasted's getindex/preprocess could be well inferred and inlined even with indirect recursion.

As for Pick{N}, I haven't encountered any real problem with it.
I tested it with more than 32 args; it turns out we hit the inlining limit of tail (in Broadcast._getindex) first, while flatten and the "flattened function" still work well.
Thus my local version even tries to store singleton instance/type/bit scalar inputs in Returns to reduce the number of broadcasted arguments.

Since the flatten process is stable, I think our compiler should be able to handle it well.
BTW, is there a difference between Pick{N} and Val{N}?

@vtjnash

vtjnash commented Jul 15, 2023

No difference, content- or concept-wise. Val just means you have to unwrap it yourself later and call getindex, rather than calling it directly with the Tuple and having that defined to return the indexed value.
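
The two styles can be sketched side by side. This is an illustration under assumptions: `Pick` is redefined here so the snippet is self-contained, and `pick_val` is a hypothetical helper, not a function in Base.

```julia
# Callable-selector style: the object itself is applied to the tuple.
struct Pick{N} <: Function end
(::Pick{N})(args::Tuple) where {N} = args[N]

# Val style: the caller unwraps N and indexes explicitly.
# `pick_val` is a made-up name for illustration.
pick_val(::Val{N}, args::Tuple) where {N} = args[N]

args = (10, 20, 30)
Pick{2}()(args)          # selector is called like a function
pick_val(Val(2), args)   # index is unwrapped by the helper, then getindex
```

Both carry the index in a type parameter; they differ only in who performs the indexing.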

1. Make `cat_nested` better inferred by switching to direct self-recursion.
2. `make_makeargs` now creates a tuple of functions which take in the whole argument list and return the corresponding input for the broadcasted function.
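
As a rough illustration of what direct self-recursion looks like, the sketch below concatenates nested tuples with a single function that only ever calls itself. It is an assumption-laden stand-in: `flat_args` is a hypothetical name, and the real `cat_nested` works on nested `Broadcasted` objects rather than plain tuples.

```julia
# Hypothetical sketch; not the `cat_nested` in base/broadcast.jl.
# One function, three methods, each recursing directly into `flat_args`.
flat_args() = ()
flat_args(x, rest...) = (x, flat_args(rest...)...)
flat_args(t::Tuple, rest...) = (flat_args(t...)..., flat_args(rest...)...)

flat_args(1, (2, (3, 4)), 5)   # nested tuples are spliced into one flat tuple
```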
@N5N3 N5N3 removed the forget me not PRs that one wants to make sure aren't forgotten label Jul 15, 2023
@N5N3 N5N3 added this pull request to the merge queue Jul 15, 2023
Merged via the queue into JuliaLang:master with commit f15eb4e Jul 15, 2023
4 checks passed
@N5N3 N5N3 deleted the N5N3flatten branch July 15, 2023 12:40