Skip to content

Commit

Permalink
inference: (slightly) improve type stability of capturing closures
Browse files Browse the repository at this point in the history
As an idea to improve type stability for capturing closures, such as in
#31909, I tried this idea of propagating the closure
object as a `PartialStruct` whose `fields` include captured variables
of which types are (partially) known. By performing const-prop on this
`closure::PartialStruct`, we can achieve a certain level of type
stability.
Specifically, I made some modifications to `abstract_eval_new` to allow
creating `PartialStruct` even for `:new` objects that are
`!isconcretedispatch` (since `PartialStruct` can now represent abstract
elements). I also adjusted `const_prop_argument_heuristic` to perform
aggressive constant propagation using such `closure::PartialStruct`.

As a result, the following code now achieves type stability:
```julia
julia> Base.infer_return_type((Bool,Int,)) do b, y
           x = b ? 1 : missing
           inner = y -> x + y
           return inner(y)
       end
Any                   # master
Union{Missing, Int64} # this commit
```

However, this alone was not enough to fully resolve #31909.
The call graph of `map` is extremely complex, and simply applying
constant propagation everywhere does not achieve the type safety
requested in the issue.

Nevertheless this commit alone would still improve type stability for
some cases, so I will go ahead and submit it as a PR.
  • Loading branch information
aviatesk committed Nov 22, 2024
1 parent 1bf2ef9 commit fb76bfe
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 33 deletions.
96 changes: 64 additions & 32 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1053,9 +1053,12 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter,
# N.B. remarks are emitted within `const_prop_rettype_heuristic`
return nothing
end
if !const_prop_argument_heuristic(interp, arginfo, sv)
arg_result = const_prop_argument_heuristic(interp, arginfo, sv)
if arg_result === nothing
add_remark!(interp, sv, "[constprop] Disabled by argument heuristics")
return nothing
else
force |= arg_result
end
all_overridden = is_all_overridden(interp, arginfo, sv)
if !force && !const_prop_function_heuristic(interp, f, arginfo, all_overridden, sv)
Expand Down Expand Up @@ -1122,16 +1125,19 @@ end
function const_prop_argument_heuristic(interp::AbstractInterpreter, arginfo::ArgInfo, sv::AbsIntState)
𝕃ᡒ = typeinf_lattice(interp)
argtypes = arginfo.argtypes
for i in 1:length(argtypes)
for i = 1:length(argtypes)
a = argtypes[i]
if has_conditional(𝕃ᡒ, sv) && isa(a, Conditional) && arginfo.fargs !== nothing
is_const_prop_profitable_conditional(a, arginfo.fargs, sv) && return true
is_const_prop_profitable_conditional(a, arginfo.fargs, sv) && return false
else
a = widenslotwrapper(a)
has_nontrivial_extended_info(𝕃ᡒ, a) && is_const_prop_profitable_arg(𝕃ᡒ, a) && return true
if has_nontrivial_extended_info(𝕃ᡒ, a) && is_const_prop_profitable_arg(𝕃ᡒ, a)
# force const-prop' if the function object itself has some profitable information
return i == 1 || widenconst(a) <: Function
end
end
end
return false
return nothing
end

function is_const_prop_profitable_conditional(cnd::Conditional, fargs::Vector{Any}, sv::InferenceState)
Expand Down Expand Up @@ -2992,8 +2998,8 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, sstate::Stateme
ismutable = ismutabletype(ut)
fcount = datatype_fieldcount(ut)
nargs = length(e.args) - 1
has_any_uninitialized = (fcount === nothing || (fcount > nargs && (let t = rt
any(i::Int -> !is_undefref_fieldtype(fieldtype(t, i)), (nargs+1):fcount)
has_any_uninitialized = (fcount === nothing || (fcount > nargs && (let boxed = Core.Box(rt)
any(i::Int -> !is_undefref_fieldtype(fieldtype(boxed.contents, i)), (nargs+1):fcount)
end)))
if has_any_uninitialized
# allocation with undefined field is inconsistent always
Expand All @@ -3005,43 +3011,69 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, sstate::Stateme
else
consistent = ALWAYS_TRUE # immutable allocation is consistent
end
if isconcretedispatch(rt)
nothrow = true
@assert fcount !== nothing && fcount β‰₯ nargs "malformed :new expression" # syntactically enforced by the front-end
ats = Vector{Any}(undef, nargs)
local anyrefine = false
local allconst = true
@inline function compute_fields_info(@nospecialize(rt))
local anyrefine, allconst, nothrow = false, true, true
βŠ‘, β‹€, βŠ“ = partialorder(𝕃ᡒ), strictneqpartialorder(𝕃ᡒ), meet(𝕃ᡒ)
for i = 1:nargs
at = widenslotwrapper(abstract_eval_value(interp, e.args[i+1], sstate, sv))
ft = fieldtype(rt, i)
nothrow && (nothrow = βŠ‘(𝕃ᡒ, at, ft))
at = tmeet(𝕃ᡒ, at, ft)
at === Bottom && return RTEffects(Bottom, TypeError, EFFECTS_THROWS)
nothrow && (nothrow = at βŠ‘ ft)
at = at βŠ“ ft
at === Bottom && return nothing
if ismutable && !isconst(rt, i)
ats[i] = ft # can't constrain this field (as it may be modified later)
# can't constrain this field (as it may be modified later)
allconst = false
continue
end
allconst &= isa(at, Const)
if !anyrefine
anyrefine = has_nontrivial_extended_info(𝕃ᡒ, at) || # extended lattice information
β‹€(𝕃ᡒ, at, ft) # just a type-level information, but more precise than the declared type
anyrefine || (anyrefine =
has_nontrivial_extended_info(𝕃ᡒ, at) || # extended lattice information
at β‹€ ft) # just a type-level information, but more precise than the declared type
end
return anyrefine, allconst, nothrow
end
@noinline function compute_fields(@nospecialize(rt), unwrap_const::Bool=false)
local fields = Vector{Any}(undef, nargs)
βŠ“ = meet(𝕃ᡒ)
for i = 1:nargs
at = widenslotwrapper(abstract_eval_value(interp, e.args[i+1], vtypes, sv))
ft = fieldtype(rt, i)
if ismutable && !isconst(rt, i)
@assert !unwrap_const
fields[i] = ft # can't constrain this field (as it may be modified later)
else
at = at βŠ“ ft
if unwrap_const
fields[i] = (at::Const).val
else
fields[i] = at
end
end
ats[i] = at
end
return fields
end
if isconcretedispatch(rt)
@assert fcount !== nothing && fcount β‰₯ nargs "malformed :new expression" # syntactically enforced by the front-end
ret = compute_fields_info(rt)
ret === nothing && return RTEffects(Bottom, TypeError, EFFECTS_THROWS)
anyrefine, allconst, nothrow = ret
if fcount == nargs && consistent === ALWAYS_TRUE && allconst
argvals = Vector{Any}(undef, nargs)
for j in 1:nargs
argvals[j] = (ats[j]::Const).val
end
argvals = compute_fields(rt, #=unwrap_const=#true)
rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs))
elseif anyrefine || nargs > datatype_min_ninitialized(rt)
# propagate partially initialized struct as `PartialStruct` when:
# - any refinement information is available (`anyrefine`), or when
# - `nargs` is greater than `n_initialized` derived from the struct type
# information alone
rt = PartialStruct(𝕃ᡒ, rt, ats)
rt = PartialStruct(𝕃ᡒ, rt, compute_fields(rt))
end
else
ret = compute_fields_info(rt)
ret === nothing && return RTEffects(Bottom, TypeError, EFFECTS_THROWS)
anyrefine = ret[1]
if anyrefine || nargs > datatype_min_ninitialized(ut)
rt = PartialStruct(𝕃ᡒ, rt, compute_fields(rt))
end
rt = refine_partial_type(rt)
nothrow = false
end
Expand All @@ -3063,18 +3095,18 @@ function abstract_eval_splatnew(interp::AbstractInterpreter, e::Expr, sstate::St
at = abstract_eval_value(interp, e.args[2], sstate, sv)
n = fieldcount(rt)
if (isa(at, Const) && isa(at.val, Tuple) && n == length(at.val::Tuple) &&
(let t = rt, at = at
all(i::Int -> getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n)
(let boxed = Core.Box(rt)
all(i::Int -> getfield(at.val::Tuple, i) isa fieldtype(boxed.contents, i), 1:n)
end))
nothrow = isexact
rt = Const(ccall(:jl_new_structt, Any, (Any, Any), rt, at.val))
elseif (isa(at, PartialStruct) && βŠ‘(𝕃ᡒ, at, Tuple) && n > 0 &&
n == length(at.fields::Vector{Any}) && !isvarargtype(at.fields[end]) &&
(let t = rt, at = at
all(i::Int -> βŠ‘(𝕃ᡒ, (at.fields::Vector{Any})[i], fieldtype(t, i)), 1:n)
n == length(at.fields) && !isvarargtype(at.fields[end]) &&
(let boxed = Core.Box(rt)
all(i::Int -> βŠ‘(𝕃ᡒ, (at.fields)[i], fieldtype(boxed.contents, i)), 1:n)
end))
nothrow = isexact
rt = PartialStruct(𝕃ᡒ, rt, at.fields::Vector{Any})
rt = PartialStruct(𝕃ᡒ, rt, at.fields)
end
else
rt = refine_partial_type(rt)
Expand Down
2 changes: 1 addition & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ end

const REFLECTION_COMPILER = RefValue{Union{Nothing, Module}}(nothing)

function invoke_in_typeinf_world(args...)
function invoke_in_typeinf_world(@nospecialize args...)
vargs = Any[args...]
return ccall(:jl_call_in_typeinf_world, Any, (Ptr{Any}, Cint), vargs, length(vargs))
end
Expand Down

0 comments on commit fb76bfe

Please sign in to comment.