Skip to content

Commit

Permalink
correctly limit depth and length
Browse files Browse the repository at this point in the history
remove code to handle exponential blowup,
since there isn't any
  • Loading branch information
vtjnash committed Oct 9, 2017
1 parent 813525c commit b89e88e
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 88 deletions.
110 changes: 50 additions & 60 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -909,8 +909,8 @@ end
function limit_type_size(@nospecialize(t), @nospecialize(compare), @nospecialize(source), allowed_tuplelen::Int)
source = svec(unwrap_unionall(compare), unwrap_unionall(source))
source[1] === source[2] && (source = svec(source[1]))
type_more_complex(t, compare, source, TUPLE_COMPLEXITY_LIMIT_DEPTH, allowed_tuplelen) || return t
r = _limit_type_size(t, compare, source, allowed_tuplelen)
type_more_complex(t, compare, source, 1, TUPLE_COMPLEXITY_LIMIT_DEPTH, allowed_tuplelen) || return t
r = _limit_type_size(t, compare, source, 1, allowed_tuplelen)
@assert t <: r
#@assert r === _limit_type_size(r, t, source) # this monotonicity constraint is slightly stronger than actually required,
# since we only actually need to demonstrate that repeated application would reaches a fixed point,
Expand All @@ -920,7 +920,7 @@ end

sym_isless(a::Symbol, b::Symbol) = ccall(:strcmp, Int32, (Ptr{UInt8}, Ptr{UInt8}), a, b) < 0

function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVector, tupledepth::Int, allowed_tuplelen::Int)
function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVector, depth::Int, tupledepth::Int, allowed_tuplelen::Int)
# detect cases where the comparison is trivial
if t === c
return false
Expand All @@ -930,7 +930,7 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe
return false # fastpath: unparameterized types are always finite
elseif tupledepth > 0 && isa(unwrap_unionall(t), DataType) && isa(c, Type) && c !== Union{} && c <: t
return false # t is already wider than the comparison in the type lattice
elseif tupledepth > 0 && is_derived_type_from_any(unwrap_unionall(t), sources)
elseif tupledepth > 0 && is_derived_type_from_any(unwrap_unionall(t), sources, depth)
return false # t isn't something new
end
# peel off wrappers
Expand All @@ -944,19 +944,20 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe
end
# rules for various comparison types
if isa(c, TypeVar)
tupledepth = 1 # allow replacing a TypeVar with a concrete value (since we know the UnionAll must be in covariant position)
if isa(t, TypeVar)
return !(t.lb === Union{} || t.lb === c.lb) || # simplify lb towards Union{}
type_more_complex(t.ub, c.ub, sources, tupledepth, 0)
type_more_complex(t.ub, c.ub, sources, depth + 1, tupledepth, 0)
end
c.lb === Union{} || return true
return type_more_complex(t, c.ub, sources, max(tupledepth, 1), 0) # allow replacing a TypeVar with a concrete value
return type_more_complex(t, c.ub, sources, depth, tupledepth, 0)
elseif isa(c, Union)
if isa(t, Union)
return type_more_complex(t.a, c.a, sources, tupledepth, allowed_tuplelen) ||
type_more_complex(t.b, c.b, sources, tupledepth, allowed_tuplelen)
return type_more_complex(t.a, c.a, sources, depth, tupledepth, allowed_tuplelen) ||
type_more_complex(t.b, c.b, sources, depth, tupledepth, allowed_tuplelen)
end
return type_more_complex(t, c.a, sources, tupledepth, allowed_tuplelen) &&
type_more_complex(t, c.b, sources, tupledepth, allowed_tuplelen)
return type_more_complex(t, c.a, sources, depth, tupledepth, allowed_tuplelen) &&
type_more_complex(t, c.b, sources, depth, tupledepth, allowed_tuplelen)
elseif isa(t, Int) && isa(c, Int)
return t !== 1 # alternatively, could use !(0 <= t < c)
end
Expand Down Expand Up @@ -989,34 +990,41 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe
end
end
end
type_more_complex(tPi, cPi, sources, tupledepth, 0) && return true
type_more_complex(tPi, cPi, sources, depth + 1, tupledepth, 0) && return true
end
return false
elseif isvarargtype(c)
return type_more_complex(t, unwrapva(c), sources, tupledepth, 0)
return type_more_complex(t, unwrapva(c), sources, depth, tupledepth, 0)
end
if isType(t) # allow taking typeof any source type anywhere as Type{...}, as long as it isn't nesting Type{Type{...}}
tt = unwrap_unionall(t.parameters[1])
if isa(tt, DataType) && !isType(tt)
is_derived_type_from_any(tt, sources) || return true
is_derived_type_from_any(tt, sources, depth) || return true
return false
end
end
end
return true
end

function is_derived_type(@nospecialize(t), @nospecialize(c)) # try to find `type` somewhere in `comparison` type
t === c && return true
# try to find `type` somewhere in `comparison` type
# at a minimum nesting depth of `mindepth`
function is_derived_type(@nospecialize(t), @nospecialize(c), mindepth::Int)
if mindepth > 0
mindepth -= 1
end
if t === c
return mindepth == 0
end
if isa(c, TypeVar)
# see if it is replacing a TypeVar upper bound with something simpler
return is_derived_type(t, c.ub)
return is_derived_type(t, c.ub, mindepth)
elseif isa(c, Union)
# see if it is one of the elements of the union
return is_derived_type(t, c.a) || is_derived_type(t, c.b)
return is_derived_type(t, c.a, mindepth + 1) || is_derived_type(t, c.b, mindepth + 1)
elseif isa(c, UnionAll)
# see if it is derived from the body
return is_derived_type(t, c.body)
return is_derived_type(t, c.body, mindepth)
elseif isa(c, DataType)
if isa(t, DataType)
# see if it is one of the supertypes of a parameter
Expand All @@ -1029,7 +1037,7 @@ function is_derived_type(@nospecialize(t), @nospecialize(c)) # try to find `type
# see if it was extracted from a type parameter
cP = c.parameters
for p in cP
is_derived_type(t, p) && return true
is_derived_type(t, p, mindepth) && return true
end
if isleaftype(c) && isbits(c)
# see if it was extracted from a fieldtype
Expand All @@ -1040,21 +1048,22 @@ function is_derived_type(@nospecialize(t), @nospecialize(c)) # try to find `type
# it cannot have a reference cycle in the type graph
cF = c.types
for f in cF
is_derived_type(t, f) && return true
is_derived_type(t, f, mindepth) && return true
end
end
end
return false
end

function is_derived_type_from_any(@nospecialize(t), sources::SimpleVector)
function is_derived_type_from_any(@nospecialize(t), sources::SimpleVector, mindepth::Int)
for s in sources
is_derived_type(t, s) && return true
is_derived_type(t, s, mindepth) && return true
end
return false
end

function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVector, allowed_tuplelen::Int) # type vs. comparison which was derived from source
# type vs. comparison or which was derived from source
function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVector, depth::Int, allowed_tuplelen::Int)
if t === c
return t # quick egal test
elseif t === Union{}
Expand All @@ -1063,7 +1072,7 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
return t # fast path: unparameterized are always simple
elseif isa(unwrap_unionall(t), DataType) && isa(c, Type) && c !== Union{} && c <: t
return t # t is already wider than the comparison in the type lattice
elseif is_derived_type_from_any(unwrap_unionall(t), sources)
elseif is_derived_type_from_any(unwrap_unionall(t), sources, depth)
return t # t isn't something new
end
if isa(t, TypeVar)
Expand All @@ -1074,8 +1083,8 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
end
elseif isa(t, Union)
if isa(c, Union)
a = _limit_type_size(t.a, c.a, sources, allowed_tuplelen)
b = _limit_type_size(t.b, c.b, sources, allowed_tuplelen)
a = _limit_type_size(t.a, c.a, sources, depth, allowed_tuplelen)
b = _limit_type_size(t.b, c.b, sources, depth, allowed_tuplelen)
return Union{a, b}
end
elseif isa(t, UnionAll)
Expand All @@ -1084,11 +1093,11 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
cv = c.var
if tv.ub === cv.ub
if tv.lb === cv.lb
return UnionAll(tv, _limit_type_size(t.body, c.body, sources, allowed_tuplelen))
return UnionAll(tv, _limit_type_size(t.body, c.body, sources, depth + 1, allowed_tuplelen))
end
ub = tv.ub
else
ub = _limit_type_size(tv.ub, cv.ub, sources, 0)
ub = _limit_type_size(tv.ub, cv.ub, sources, depth + 1, 0)
end
if tv.lb === cv.lb
lb = tv.lb
Expand All @@ -1097,21 +1106,21 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
lb = Bottom
end
v2 = TypeVar(tv.name, lb, ub)
return UnionAll(v2, _limit_type_size(t{v2}, c{v2}, sources, allowed_tuplelen))
return UnionAll(v2, _limit_type_size(t{v2}, c{v2}, sources, depth + 1, allowed_tuplelen))
end
tbody = _limit_type_size(t.body, c, sources, allowed_tuplelen)
tbody = _limit_type_size(t.body, c, sources, depth + 1, allowed_tuplelen)
tbody === t.body && return t
return UnionAll(t.var, tbody)
elseif isa(c, UnionAll)
# peel off non-matching wrapper of comparison
return _limit_type_size(t, c.body, sources, allowed_tuplelen)
return _limit_type_size(t, c.body, sources, depth, allowed_tuplelen)
elseif isa(t, DataType)
if isa(c, DataType)
tP = t.parameters
cP = c.parameters
if t.name === c.name && !isempty(cP)
if isvarargtype(t)
VaT = _limit_type_size(tP[1], cP[1], sources, 0)
VaT = _limit_type_size(tP[1], cP[1], sources, depth + 1, 0)
N = tP[2]
if isa(N, TypeVar) || N === cP[2]
return Vararg{VaT, N}
Expand All @@ -1138,19 +1147,19 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
else
cPi = Any
end
Q[i] = _limit_type_size(Q[i], cPi, sources, 0)
Q[i] = _limit_type_size(Q[i], cPi, sources, depth + 1, 0)
end
return Tuple{Q...}
end
elseif isvarargtype(c)
# Tuple{Vararg{T}} --> Tuple{T} is OK
return _limit_type_size(t, cP[1], sources, 0)
return _limit_type_size(t, cP[1], sources, depth, 0)
end
end
if isType(t) # allow taking typeof as Type{...}, but ensure it doesn't start nesting
tt = unwrap_unionall(t.parameters[1])
if isa(tt, DataType) && !isType(tt)
is_derived_type_from_any(tt, sources) && return t
is_derived_type_from_any(tt, sources, depth) && return t
end
end
if isvarargtype(t)
Expand Down Expand Up @@ -1866,43 +1875,23 @@ function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(si
end

if limited
newsig = sig
sigtuple = unwrap_unionall(sig)::DataType
msig = unwrap_unionall(method.sig)::DataType
max_spec_len = length(msig.parameters) + 1
spec_len = length(msig.parameters) + 1
ls = length(sigtuple.parameters)
if method === sv.linfo.def
# direct self-recursion permits much greater use of reducers
# without using non-local state (just the total edge)
# here we assume that complexity(specTypes) :>= complexity(sig)
comparison = sv.linfo.specTypes
l_comparison = length(unwrap_unionall(comparison).parameters)
max_spec_len = max(max_spec_len, l_comparison)
spec_len = max(spec_len, l_comparison)
else
comparison = method.sig
end
if method.isva && ls > max_spec_len
# limit length based on size of definition signature.
# for example, given function f(T, Any...), limit to 3 arguments
# instead of the default (MAX_TUPLETYPE_LEN)
fst = sigtuple.parameters[max_spec_len]
allsame = true
# allow specializing on longer arglists if all the trailing
# arguments are the same, since there is no exponential
# blowup in this case.
for i = (max_spec_len + 1):ls
if sigtuple.parameters[i] != fst
allsame = false
break
end
end
if !allsame
sigtuple = limit_tuple_type_n(sigtuple, max_spec_len)
newsig = rewrap_unionall(sigtuple, newsig)
end
end
# see if the type is still too big, and limit it further if still required
newsig = limit_type_size(newsig, comparison, sv.linfo.specTypes, max_spec_len)
# see if the type is too big, and limit it if required
newsig = limit_type_size(sig, comparison, sv.linfo.specTypes, spec_len)

if newsig !== sig
if !sv.limited
# continue inference, but limit parameter complexity to ensure (quick) convergence
Expand Down Expand Up @@ -1939,6 +1928,7 @@ function abstract_call_method(method::Method, @nospecialize(f), @nospecialize(si
end
sparams = recomputed[2]::SimpleVector
end

rt, edge = typeinf_edge(method, sig, sparams, sv)
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
return rt
Expand Down
43 changes: 20 additions & 23 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -926,37 +926,34 @@ end
# vectors/matrices in mixedargs in their orginal order, and such that the result of
# broadcast(parevalf, passedargstup...) is broadcast(f, mixedargs...)
@inline function capturescalars(f, mixedargs)
let makeargs = _capturescalars(mixedargs...),
parevalf = (passed...) -> f(makeargs(passed...)...),
passedsrcargstup = _capturenonscalars(mixedargs...)
let (passedsrcargstup, makeargs) = _capturescalars(mixedargs...)
parevalf = (passed...) -> f(makeargs(passed...)...)
return (parevalf, passedsrcargstup)
end
end

@inline _capturenonscalars(nonscalararg::SparseVecOrMat, mixedargs...) =
(nonscalararg, _capturenonscalars(mixedargs...)...)
@inline _capturenonscalars(scalararg, mixedargs...) =
_capturenonscalars(mixedargs...)
@inline _capturenonscalars() = ()
nonscalararg(::SparseVecOrMat) = true
nonscalararg(::Any) = false

@inline _capturescalars(nonscalararg::SparseVecOrMat, mixedargs...) =
let f = _capturescalars(mixedargs...)
(head, tail...) -> (head, f(tail...)...) # pass-through
@inline function _capturescalars()
return (), () -> ()
end
@inline function _capturescalars(arg, mixedargs...)
let (rest, f) = _capturescalars(mixedargs...)
if nonscalararg(arg)
return (arg, rest...), (head, tail...) -> (head, f(tail...)...) # pass-through to broadcast
else
return rest, (tail...) -> (arg, f(tail...)...) # add back scalararg after (in makeargs)
end
end
@inline _capturescalars(scalararg, mixedargs...) =
let f = _capturescalars(mixedargs...)
(tail...) -> (scalararg, f(tail...)...) # add scalararg
end
@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
if nonscalararg(arg)
return (arg,), (head,) -> (head,) # pass-through
else
return (), () -> (arg,) # add scalararg
end
# TODO: use the implicit version once inference can handle it
# handle too-many-arguments explicitly
@inline function _capturescalars()
too_many_arguments() = ()
too_many_arguments(tail...) = throw(ArgumentError("too many"))
end
#@inline _capturescalars(nonscalararg::SparseVecOrMat) =
# (head,) -> (head,) # pass-through
#@inline _capturescalars(scalararg) =
# () -> (scalararg,) # add scalararg

# NOTE: The following two method definitions work around #19096.
broadcast(f::Tf, ::Type{T}, A::SparseMatrixCSC) where {Tf,T} = broadcast(y -> f(T, y), A)
Expand Down
4 changes: 0 additions & 4 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3339,10 +3339,6 @@ end
@test EmptyIIOtherField13175(EmptyImmutable13175(), 1.0) == EmptyIIOtherField13175(EmptyImmutable13175(), 1.0)
@test EmptyIIOtherField13175(EmptyImmutable13175(), 1.0) != EmptyIIOtherField13175(EmptyImmutable13175(), 2.0)

# issue #13183
gg13183(x::X...) where {X} = 1==0 ? gg13183(x, x) : 0
@test gg13183(5) == 0

# issue 8932 (llvm return type legalizer error)
struct Vec3_8932
x::Float32
Expand Down
20 changes: 19 additions & 1 deletion test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@

# tests for Core.Inference correctness and precision
import Core.Inference: Const, Conditional,
const isleaftype = Core.Inference._isleaftype

# demonstrate some of the type-size limits
@test Core.Inference.limit_type_size(Ref{Complex{T} where T}, Ref, Ref, 0) == Ref
@test Core.Inference.limit_type_size(Ref{Complex{T} where T}, Ref{Complex{T} where T}, Ref, 0) == Ref{Complex{T} where T}
let comparison = Tuple{X, X} where X<:Tuple
sig = Tuple{X, X} where X<:comparison
ref = Tuple{X, X} where X
@test Core.Inference.limit_type_size(sig, comparison, comparison, 10) == comparison
@test Core.Inference.limit_type_size(sig, ref, comparison, 10) == comparison
@test Core.Inference.limit_type_size(Tuple{sig}, Tuple{ref}, comparison, 10) == Tuple{comparison}
@test Core.Inference.limit_type_size(sig, ref, Tuple{comparison}, 10) == sig
end


# issue 9770
@noinline x9770() = false
Expand Down Expand Up @@ -186,7 +200,6 @@ function find_tvar10930(arg)
end
@test find_tvar10930(Vararg{Int}) === 1

const isleaftype = Base._isleaftype

# issue #12474
@generated function f12474(::Any)
Expand Down Expand Up @@ -1225,3 +1238,8 @@ end
let t = Tuple{Type{T23786{D, N} where N where D<:Tuple{Vararg{Array{T, 1} where T, N} where N}}}
@test Core.Inference.limit_type_depth(t, 4) >: t
end

# issue #13183
_false13183 = false
gg13183(x::X...) where {X} = (_false13183 ? gg13183(x, x) : 0)
@test gg13183(5) == 0

0 comments on commit b89e88e

Please sign in to comment.