Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Widen diagonal var during Type unwrapping in instanceof_tfunc #52228

Merged
merged 5 commits into from
Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,25 +95,31 @@ add_tfunc(throw, 1, 1, @nospecs((𝕃::AbstractLattice, x)->Bottom), 0)
# if isexact is false, the actual runtime type may (will) be a subtype of t
# if isconcrete is true, the actual runtime type is definitely concrete (unreachable if not valid as a typeof)
# if istype is true, the actual runtime value will definitely be a type (e.g. this is false for Union{Type{Int}, Int})
function instanceof_tfunc(@nospecialize(t), astag::Bool=false)
function instanceof_tfunc(@nospecialize(t), astag::Bool=false, @nospecialize(troot) = t)
if isa(t, Const)
if isa(t.val, Type) && valid_as_lattice(t.val, astag)
return t.val, true, isconcretetype(t.val), true
end
return Bottom, true, false, false # runtime throws on non-Type
end
t = widenconst(t)
troot = widenconst(troot)
if t === Bottom
return Bottom, true, true, false # runtime unreachable
elseif t === typeof(Bottom) || !hasintersect(t, Type)
return Bottom, true, false, false # literal Bottom or non-Type
elseif isType(t)
tp = t.parameters[1]
valid_as_lattice(tp, astag) || return Bottom, true, false, false # runtime unreachable / throws on non-Type
if troot isa UnionAll
# Free `TypeVar`s inside `Type` has violated the "diagonal" rule.
# Widen them before `UnionAll` rewraping to relax concrete constraint.
tp = widen_diagonal(tp, troot)
end
return tp, !has_free_typevars(tp), isconcretetype(tp), true
elseif isa(t, UnionAll)
t′ = unwrap_unionall(t)
t′′, isexact, isconcrete, istype = instanceof_tfunc(t′, astag)
t′′, isexact, isconcrete, istype = instanceof_tfunc(t′, astag, rewrap_unionall(t, troot))
tr = rewrap_unionall(t′′, t)
if t′′ isa DataType && t′′.name !== Tuple.name && !has_free_typevars(tr)
# a real instance must be within the declared bounds of the type,
Expand All @@ -128,8 +134,8 @@ function instanceof_tfunc(@nospecialize(t), astag::Bool=false)
end
return tr, isexact, isconcrete, istype
elseif isa(t, Union)
ta, isexact_a, isconcrete_a, istype_a = instanceof_tfunc(t.a, astag)
tb, isexact_b, isconcrete_b, istype_b = instanceof_tfunc(t.b, astag)
ta, isexact_a, isconcrete_a, istype_a = instanceof_tfunc(t.a, astag, troot)
tb, isexact_b, isconcrete_b, istype_b = instanceof_tfunc(t.b, astag, troot)
isconcrete = isconcrete_a && isconcrete_b
istype = istype_a && istype_b
# most users already handle the Union case, so here we assume that
Expand Down
5 changes: 5 additions & 0 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,11 @@ function rename_unionall(@nospecialize(u))
return UnionAll(nv, body{nv})
end

# remove concrete constraint on diagonal TypeVar if it comes from troot
function widen_diagonal(@nospecialize(t), troot::UnionAll)
body = ccall(:jl_widen_diagonal, Any, (Any, Any), t, troot)
end

function isvarargtype(@nospecialize(t))
return isa(t, Core.TypeofVararg)
end
Expand Down
207 changes: 207 additions & 0 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -4304,6 +4304,213 @@ int jl_subtype_matching(jl_value_t *a, jl_value_t *b, jl_svec_t **penv)
return sub;
}

// type utils
static void check_diagonal(jl_value_t *t, jl_varbinding_t *troot, int param)
{
if (jl_is_uniontype(t)) {
int i, len = 0;
jl_varbinding_t *v;
for (v = troot; v != NULL; v = v->prev)
len++;
int8_t *occurs = (int8_t *)alloca(len);
for (v = troot, i = 0; v != NULL; v = v->prev, i++)
occurs[i] = v->occurs_inv | (v->occurs_cov << 2);
check_diagonal(((jl_uniontype_t *)t)->a, troot, param);
for (v = troot, i = 0; v != NULL; v = v->prev, i++) {
int8_t occurs_inv = occurs[i] & 3;
int8_t occurs_cov = occurs[i] >> 2;
occurs[i] = v->occurs_inv | (v->occurs_cov << 2);
v->occurs_inv = occurs_inv;
v->occurs_cov = occurs_cov;
}
check_diagonal(((jl_uniontype_t *)t)->b, troot, param);
for (v = troot, i = 0; v != NULL; v = v->prev, i++) {
if (v->occurs_inv < (occurs[i] & 3))
v->occurs_inv = occurs[i] & 3;
if (v->occurs_cov < (occurs[i] >> 2))
v->occurs_cov = occurs[i] >> 2;
}
}
else if (jl_is_unionall(t)) {
assert(troot != NULL);
jl_varbinding_t *v1 = troot, *v2 = troot->prev;
while (v2 != NULL) {
if (v2->var == ((jl_unionall_t *)t)->var) {
v1->prev = v2->prev;
break;
}
v1 = v2;
v2 = v2->prev;
}
check_diagonal(((jl_unionall_t *)t)->body, troot, param);
v1->prev = v2;
}
else if (jl_is_datatype(t)) {
int nparam = jl_is_tuple_type(t) ? 1 : 2;
if (nparam < param) nparam = param;
for (size_t i = 0; i < jl_nparams(t); i++) {
check_diagonal(jl_tparam(t, i), troot, nparam);
}
}
else if (jl_is_vararg(t)) {
jl_value_t *T = jl_unwrap_vararg(t);
jl_value_t *N = jl_unwrap_vararg_num(t);
int n = (N && jl_is_long(N)) ? jl_unbox_long(N) : 2;
if (T && n > 0) check_diagonal(T, troot, param);
if (T && n > 1) check_diagonal(T, troot, param);
if (N) check_diagonal(N, troot, 2);
}
else if (jl_is_typevar(t)) {
jl_varbinding_t *v = troot;
for (; v != NULL; v = v->prev) {
if (v->var == (jl_tvar_t *)t) {
if (param == 1 && v->occurs_cov < 2) v->occurs_cov++;
if (param == 2 && v->occurs_inv < 2) v->occurs_inv++;
break;
}
}
if (v == NULL)
check_diagonal(((jl_tvar_t *)t)->ub, troot, 0);
}
}

static jl_value_t *insert_nondiagonal(jl_value_t *type, jl_varbinding_t *troot, int widen2ub)
{
if (jl_is_typevar(type)) {
int concretekind = widen2ub > 1 ? 0 : 1;
jl_varbinding_t *v = troot;
for (; v != NULL; v = v->prev) {
if (v->occurs_inv == 0 &&
v->occurs_cov > concretekind &&
v->var == (jl_tvar_t *)type)
break;
}
if (v != NULL) {
if (widen2ub) {
type = insert_nondiagonal(((jl_tvar_t *)type)->ub, troot, 2);
}
else {
// we must replace each covariant occurrence of newvar with a different newvar2<:newvar (diagonal rule)
if (v->innervars == NULL)
v->innervars = jl_alloc_array_1d(jl_array_any_type, 0);
jl_value_t *newvar = NULL, *lb = v->var->lb, *ub = (jl_value_t *)v->var;
jl_array_t *innervars = v->innervars;
JL_GC_PUSH4(&newvar, &lb, &ub, &innervars);
newvar = (jl_value_t *)jl_new_typevar(v->var->name, lb, ub);
jl_array_ptr_1d_push(innervars, newvar);
JL_GC_POP();
type = newvar;
}
}
}
else if (jl_is_unionall(type)) {
jl_value_t *body = ((jl_unionall_t*)type)->body;
jl_tvar_t *var = ((jl_unionall_t*)type)->var;
jl_varbinding_t *v = troot;
for (; v != NULL; v = v->prev) {
if (v->var == var)
break;
}
if (v) v->var = NULL; // Temporarily remove `type->var` from binding list.
jl_value_t *newbody = insert_nondiagonal(body, troot, widen2ub);
if (v) v->var = var; // And restore it after inner insertation.
jl_value_t *newvar = NULL;
JL_GC_PUSH2(&newbody, &newvar);
if (body == newbody || jl_has_typevar(newbody, var)) {
if (body != newbody)
newbody = jl_new_struct(jl_unionall_type, var, newbody);
// n.b. we do not widen lb, since that would be the wrong direction
newvar = insert_nondiagonal(var->ub, troot, widen2ub);
if (newvar != var->ub) {
newvar = (jl_value_t*)jl_new_typevar(var->name, var->lb, newvar);
newbody = jl_apply_type1(newbody, newvar);
newbody = jl_type_unionall((jl_tvar_t*)newvar, newbody);
}
}
type = newbody;
JL_GC_POP();
}
else if (jl_is_uniontype(type)) {
jl_value_t *a = ((jl_uniontype_t*)type)->a;
jl_value_t *b = ((jl_uniontype_t*)type)->b;
jl_value_t *newa = NULL;
jl_value_t *newb = NULL;
JL_GC_PUSH2(&newa, &newb);
newa = insert_nondiagonal(a, troot, widen2ub);
newb = insert_nondiagonal(b, troot, widen2ub);
if (newa != a || newb != b)
type = simple_union(newa, newb);
JL_GC_POP();
}
else if (jl_is_vararg(type)) {
// As for Vararg we'd better widen it's var to ub as otherwise they are still diagonal
jl_value_t *t = jl_unwrap_vararg(type);
jl_value_t *n = jl_unwrap_vararg_num(type);
if (widen2ub == 0)
widen2ub = !(n && jl_is_long(n)) || jl_unbox_long(n) > 1;
jl_value_t *newt;
JL_GC_PUSH2(&newt, &n);
newt = insert_nondiagonal(t, troot, widen2ub);
if (t != newt)
type = (jl_value_t *)jl_wrap_vararg(newt, n, 0);
JL_GC_POP();
}
else if (jl_is_datatype(type)) {
if (jl_is_tuple_type(type)) {
jl_svec_t *newparams = NULL;
jl_value_t *newelt = NULL;
JL_GC_PUSH2(&newparams, &newelt);
for (size_t i = 0; i < jl_nparams(type); i++) {
jl_value_t *elt = jl_tparam(type, i);
newelt = insert_nondiagonal(elt, troot, widen2ub);
if (elt != newelt) {
if (!newparams)
newparams = jl_svec_copy(((jl_datatype_t*)type)->parameters);
jl_svecset(newparams, i, newelt);
}
}
if (newparams)
type = (jl_value_t*)jl_apply_tuple_type(newparams, 1);
JL_GC_POP();
}
}
return type;
}

static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) {
check_diagonal(t, troot, 0);
int any_concrete = 0;
for (jl_varbinding_t *v = troot; v != NULL; v = v->prev)
any_concrete |= v->occurs_cov > 1 && v->occurs_inv == 0;
if (!any_concrete)
return t; // no diagonal
return insert_nondiagonal(t, troot, 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you already checked occurs_inv, it may be already fine to do wident2ub always?

Suggested change
return insert_nondiagonal(t, troot, 0);
return insert_nondiagonal(t, troot, 1);

In effect though, that may be the same as implementing it as a map of rewrap_unionall over the fields (but recursively) roughly doing map(rewrap_unionall, t.parameters, Repeated(u))?

I think the tricky cases I was thinking of were possibly something like:
Tuple{Val{S}, T, T} where {S, T<:S} or perhaps
Tuple{S, T, T} where {S, T<:S} or perhaps
Tuple{S, Vararg{T}} where {S, T<:S}

These could be safely widened to Tuple{Val, Any, Any}, Tuple{Any, Any, Any}, and Tuple{Any, Vararg{Any}} respectively, but I was hoping to instead convert them to the more precise results of Tuple{Val{S}, T1, T2} where {S, T1<:S, T2<:S}, Tuple{S, T1, T2} where {S, T1<:S, T2<:S}, and Tuple{S, Vararg{Union{T1,T2} where T1<:S where T2<:S}} where S

(though normally allocation normalization will have incorrectly broken that last type anyways, since it quickly converts that to Tuple{S, Vararg{T where T<:S}} where S which it thinks is equal to Tuple{S, Vararg{S}} where S even though it is not necessarily correct to replace T.ub with T in the expression T where T there–as that also implies a new constraint on the Vararg: that it is a diagonal Vararg when it was not intended to be)

Copy link
Member Author

@N5N3 N5N3 Nov 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems ok that

julia> CC.instanceof_tfunc(Type{Tuple{S, Vararg{T}}} where {S, T<:S})[1]
Tuple{S, Vararg{S}} where S

As S should be always concrete at runtime and Tuple{S, Vararg{S}} where S should cover all possible types.
Edit: Well we are doing inference, so that might be bad.

The current code fails for this case though

julia> CC.instanceof_tfunc(Type{Tuple{S, S, Vararg{T}}} where {S, T<:S})[1]
Tuple{S2, S1, Vararg{S}} where {S, S1<:S, S2<:S}

It shows that the widen2ub branch should be recursive, and return Tuple{S2, S1, Vararg{Any}} where {S, S1<:S, S2<:S} instead.

As for the first two examples, the current code with widen2ub = 0 should be precise, though with worse normalization.

}

static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot)
{
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
jl_value_t *nt;
JL_GC_PUSH2(&vb.innervars, &nt);
if (jl_is_unionall(u->body))
nt = widen_diagonal(t, (jl_unionall_t *)u->body, &vb);
else
nt = _widen_diagonal(t, &vb);
if (vb.innervars != NULL) {
for (size_t i = 0; i < jl_array_nrows(vb.innervars); i++) {
jl_tvar_t *var = (jl_tvar_t*)jl_array_ptr_ref(vb.innervars, i);
if (jl_has_typevar(nt, var))
nt = jl_type_unionall(var, nt);
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
}
}
JL_GC_POP();
return nt;
}

JL_DLLEXPORT jl_value_t *jl_widen_diagonal(jl_value_t *t, jl_unionall_t *ua)
{
return widen_diagonal(t, ua, NULL);
}

// specificity comparison

Expand Down
13 changes: 13 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5593,3 +5593,16 @@ end |> only === Float64
@test Base.infer_exception_type(c::Bool -> c ? 1 : 2) == Union{}
@test Base.infer_exception_type(c::Missing -> c ? 1 : 2) == TypeError
@test Base.infer_exception_type(c::Any -> c ? 1 : 2) == TypeError

# Issue #52168
f52168(x, t::Type) = x::NTuple{2, Base.inferencebarrier(t)::Type}
@test f52168((1, 2.), Any) === (1, 2.)

# Issue #27031
let x = 1, _Any = Any
@noinline bar27031(tt::Tuple{T,T}, ::Type{Val{T}}) where {T} = notsame27031(tt)
@noinline notsame27031(tt::Tuple{T, T}) where {T} = error()
@noinline notsame27031(tt::Tuple{T, S}) where {T, S} = "OK"
foo27031() = bar27031((x, 1.0), Val{_Any})
@test foo27031() == "OK"
end
11 changes: 11 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8059,3 +8059,14 @@ check_globalref_lowering() = @insert_global
let src = code_lowered(check_globalref_lowering)[1]
@test length(src.code) == 2
end

# Test correctness of widen_diagonal
let widen_diagonal(x::UnionAll) = Base.rewrap_unionall(Base.widen_diagonal(Base.unwrap_unionall(x), x), x),
check_widen_diagonal(x, y) = !<:(x, y) && x <: widen_diagonal(y)
@test Tuple{Int,Float64} <: widen_diagonal(NTuple)
@test Tuple{Int,Float64} <: widen_diagonal(Tuple{T,T} where {T})
@test Tuple{Real,Int,Float64} <: widen_diagonal(Tuple{S,Vararg{T}} where {S, T<:S})
@test Tuple{Int,Int,Float64,Float64} <: widen_diagonal(Tuple{S,S,Vararg{T}} where {S, T<:S})
@test Union{Tuple{T}, Tuple{T,Int}} where {T} === widen_diagonal(Union{Tuple{T}, Tuple{T,Int}} where {T})
@test Tuple === widen_diagonal(Union{Tuple{Vararg{S}}, Tuple{Vararg{T}}} where {S, T})
end
Loading