From ef23805084d3bf13675b1d13c9bb68c0f7d2edf4 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Fri, 18 Mar 2022 16:14:59 +0900 Subject: [PATCH] inference: override `InterConditional` result with `Const` carefully MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I found that a tricky thing can happen when constant inference derives `Const`-result while non-constant inference has derived (non-constant) `InterConditional` result beforehand. In such a case, currently we discard the result with constant infernece (since `!(Const ⊑ InterConditional)`), but we can achieve more accuracy by not discarding that `Const`-information, e.g.: ```julia julia> iszero_simple(x) = x === 0 iszero_simple (generic function with 1 method) julia> @test Base.return_types() do iszero_simple(0) ? nothing : missing end |> only === Nothing Test Passed ``` --- base/compiler/abstractinterpretation.jl | 29 ++++++++++++++----------- test/compiler/inference.jl | 7 ++++++ 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 5d8c82bd5ee9a7..bf57ae51e41752 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -118,10 +118,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), splitsigs = switchtupleunion(sig) for sig_n in splitsigs result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv) - rt, edge = result.rt, result.edge - if edge !== nothing - push!(edges, edge) - end + rt = result.rt + edge = result.edge + edge !== nothing && push!(edges, edge) this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i] this_arginfo = ArgInfo(fargs, this_argtypes) const_call_result = abstract_call_method_with_const_args(interp, result, @@ -129,8 +128,10 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), effects = result.edge_effects const_result = nothing if const_call_result !== nothing - if const_call_result.rt ⊑ rt - (; rt, effects, const_result) = const_call_result + const_rt = const_call_result.rt + if const_rt ⊑ rt + rt = const_rt + (; effects, const_result) = const_call_result end end tristate_merge!(sv, effects) @@ -143,6 +144,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), break end end + this_conditional = ignorelimited(this_rt) + this_rt = widenwrappedconditional(this_rt) else if infer_compilation_signature(interp) # Also infer the compilation signature for this method, so it's available @@ -159,10 +162,10 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end result = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, sv) - this_rt, edge = result.rt, result.edge - if edge !== nothing - push!(edges, edge) - end + this_conditional = ignorelimited(result.rt) + this_rt = widenwrappedconditional(result.rt) + edge = result.edge + edge !== nothing && push!(edges, edge) # try constant propagation with argtypes for this match # this is in preparation for inlining, or improving the return result this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i] @@ -172,10 +175,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), effects = result.edge_effects const_result = nothing if const_call_result !== nothing - this_const_rt = const_call_result.rt + this_const_conditional = ignorelimited(const_call_result.rt) + this_const_rt = widenwrappedconditional(const_call_result.rt) # return type of const-prop' inference can be wider than that of non const-prop' inference # e.g. in cases when there are cycles but cached result is still accurate if this_const_rt ⊑ this_rt + this_conditional = this_const_conditional this_rt = this_const_rt (; effects, const_result) = const_call_result end @@ -186,8 +191,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), any_const_result = true end end - this_conditional = ignorelimited(this_rt) - this_rt = widenwrappedconditional(this_rt) @assert !(this_conditional isa Conditional) "invalid lattice element returned from inter-procedural context" seen += 1 rettype = tmerge(rettype, this_rt) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 48e913ba8b1672..64462520559fa3 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -2032,6 +2032,13 @@ end end @test ts == Any[Any] end + + # a tricky case: if constant inference derives `Const` while non-constant infernece has + # derived `InterConditional`, we should not discard that constant information + iszero_simple(x) = x === 0 + @test Base.return_types() do + iszero_simple(0) ? nothing : missing + end |> only === Nothing end @testset "branching on conditional object" begin