Skip to content

Commit

Permalink
Subtype: Fix some diagonal rule related false alarm (#53034)
Browse files Browse the repository at this point in the history
close #33137
close #53021

---------

Co-authored-by: Jameson Nash <vtjnash@gmail.com>
  • Loading branch information
N5N3 and vtjnash authored Jan 26, 2024
1 parent 1e45aba commit 5cf1021
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 45 deletions.
80 changes: 49 additions & 31 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,43 @@ static void isort_union(jl_value_t **a, size_t len) JL_NOTSAFEPOINT
}
}

static int simple_subtype(jl_value_t *a, jl_value_t *b, int hasfree, int isUnion)
{
if (a == jl_bottom_type || b == (jl_value_t*)jl_any_type)
return 1;
if (jl_egal(a, b))
return 1;
if (hasfree == 0) {
int mergeable = isUnion;
if (!mergeable) // issue #24521: don't merge Type{T} where typeof(T) varies
mergeable = !(jl_is_type_type(a) && jl_is_type_type(b) &&
jl_typeof(jl_tparam0(a)) != jl_typeof(jl_tparam0(b)));
return mergeable && jl_subtype(a, b);
}
if (jl_is_typevar(a)) {
jl_value_t *na = ((jl_tvar_t*)a)->ub;
hasfree &= jl_has_free_typevars(na);
return simple_subtype(na, b, hasfree, isUnion);
}
if (jl_is_typevar(b)) {
jl_value_t *nb = ((jl_tvar_t*)b)->lb;
// This branch is not valid if `b` obeys diagonal rule,
// as it might normalize `Union` into a single `TypeVar`, e.g.
// Tuple{Union{Int,T},T} where {T>:Int} != Tuple{T,T} where {T>:Int}
if (is_leaf_bound(nb))
return 0;
hasfree &= jl_has_free_typevars(nb) << 1;
return simple_subtype(a, nb, hasfree, isUnion);
}
if (b==(jl_value_t*)jl_datatype_type || b==(jl_value_t*)jl_typeofbottom_type) {
// This branch is not valid for `Union`/`UnionAll`, e.g.
// (Type{Union{Int,T2} where {T2<:T1}} where {T1}){Int} == Type{Int64}
// (Type{Union{Int,T1}} where {T1}){Int} == Type{Int64}
return jl_is_type_type(a) && jl_typeof(jl_tparam0(a)) == b;
}
return 0;
}

JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
{
if (n == 0)
Expand All @@ -580,13 +617,9 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
int has_free = temp[i] != NULL && jl_has_free_typevars(temp[i]);
for (j = 0; j < nt; j++) {
if (j != i && temp[i] && temp[j]) {
if (temp[i] == jl_bottom_type ||
temp[j] == (jl_value_t*)jl_any_type ||
jl_egal(temp[i], temp[j]) ||
(!has_free && !jl_has_free_typevars(temp[j]) &&
jl_subtype(temp[i], temp[j]))) {
int has_free2 = has_free | (jl_has_free_typevars(temp[j]) << 1);
if (simple_subtype(temp[i], temp[j], has_free2, 1))
temp[i] = NULL;
}
}
}
}
Expand All @@ -608,17 +641,7 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
return tu;
}

// note: this is turned off as `Union` doesn't do such normalization.
// static int simple_subtype(jl_value_t *a, jl_value_t *b)
// {
// if (jl_is_kind(b) && jl_is_type_type(a) && jl_typeof(jl_tparam0(a)) == b)
// return 1;
// if (jl_is_typevar(b) && obviously_egal(a, ((jl_tvar_t*)b)->lb))
// return 1;
// return 0;
// }

static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree)
static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree, int isUnion)
{
int subab = 0, subba = 0;
if (jl_egal(a, b)) {
Expand All @@ -630,9 +653,9 @@ static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree)
else if (b == jl_bottom_type || a == (jl_value_t*)jl_any_type) {
subba = 1;
}
else if (hasfree) {
// subab = simple_subtype(a, b);
// subba = simple_subtype(b, a);
else if (hasfree != 0) {
subab = simple_subtype(a, b, hasfree, isUnion);
subba = simple_subtype(b, a, hasfree, isUnion);
}
else if (jl_is_type_type(a) && jl_is_type_type(b) &&
jl_typeof(jl_tparam0(a)) != jl_typeof(jl_tparam0(b))) {
Expand Down Expand Up @@ -664,10 +687,11 @@ jl_value_t *simple_union(jl_value_t *a, jl_value_t *b)
// first remove cross-redundancy and check if `a >: b` or `a <: b`.
for (i = 0; i < nta; i++) {
if (temp[i] == NULL) continue;
int hasfree = jl_has_free_typevars(temp[i]);
int has_free = jl_has_free_typevars(temp[i]);
for (j = nta; j < nt; j++) {
if (temp[j] == NULL) continue;
int subs = simple_subtype2(temp[i], temp[j], hasfree || jl_has_free_typevars(temp[j]));
int has_free2 = has_free | (jl_has_free_typevars(temp[j]) << 1);
int subs = simple_subtype2(temp[i], temp[j], has_free2, 0);
int subab = subs & 1, subba = subs >> 1;
if (subab) {
temp[i] = NULL;
Expand Down Expand Up @@ -697,15 +721,9 @@ jl_value_t *simple_union(jl_value_t *a, jl_value_t *b)
size_t jmax = i < nta ? nta : nt;
for (j = jmin; j < jmax; j++) {
if (j != i && temp[i] && temp[j]) {
if (temp[i] == jl_bottom_type ||
temp[j] == (jl_value_t*)jl_any_type ||
jl_egal(temp[i], temp[j]) ||
(!has_free && !jl_has_free_typevars(temp[j]) &&
// issue #24521: don't merge Type{T} where typeof(T) varies
!(jl_is_type_type(temp[i]) && jl_is_type_type(temp[j]) && jl_typeof(jl_tparam0(temp[i])) != jl_typeof(jl_tparam0(temp[j]))) &&
jl_subtype(temp[i], temp[j]))) {
int has_free2 = has_free | (jl_has_free_typevars(temp[j]) << 1);
if (simple_subtype(temp[i], temp[j], has_free2, 0))
temp[i] = NULL;
}
}
}
}
Expand Down Expand Up @@ -769,7 +787,7 @@ jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi)
int hasfree = jl_has_free_typevars(temp[i]);
for (j = nta; j < nt; j++) {
if (temp[j] == NULL) continue;
int subs = simple_subtype2(temp[i], temp[j], hasfree || jl_has_free_typevars(temp[j]));
int subs = simple_subtype2(temp[i], temp[j], hasfree || jl_has_free_typevars(temp[j]), 0);
int subab = subs & 1, subba = subs >> 1;
if (subba && !subab) {
stemp[i] = -1;
Expand Down
2 changes: 2 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,8 @@ static inline int jl_field_isconst(jl_datatype_t *st, int i) JL_NOTSAFEPOINT

JL_DLLEXPORT int jl_subtype(jl_value_t *a, jl_value_t *b);

int is_leaf_bound(jl_value_t *v) JL_NOTSAFEPOINT;

STATIC_INLINE int jl_is_kind(jl_value_t *v) JL_NOTSAFEPOINT
{
return (v==(jl_value_t*)jl_uniontype_type || v==(jl_value_t*)jl_datatype_type ||
Expand Down
6 changes: 3 additions & 3 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ static int subtype_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int R, int pa
// check that a type is concrete or quasi-concrete (Type{T}).
// this is used to check concrete typevars:
// issubtype is false if the lower bound of a concrete type var is not concrete.
static int is_leaf_bound(jl_value_t *v) JL_NOTSAFEPOINT
int is_leaf_bound(jl_value_t *v) JL_NOTSAFEPOINT
{
if (v == jl_bottom_type)
return 1;
Expand Down Expand Up @@ -1997,7 +1997,7 @@ static int obvious_subtype(jl_value_t *x, jl_value_t *y, jl_value_t *y0, int *su
if (var_occurs_invariant(body, (jl_tvar_t*)b))
return 0;
}
if (nparams_expanded_x > npy && jl_is_typevar(b) && concrete_min(a1) > 1) {
if (nparams_expanded_x > npy && jl_is_typevar(b) && is_leaf_typevar((jl_tvar_t *)b) && concrete_min(a1) > 1) {
// diagonal rule for 2 or more elements: they must all be concrete on the LHS
*subtype = 0;
return 1;
Expand All @@ -2008,7 +2008,7 @@ static int obvious_subtype(jl_value_t *x, jl_value_t *y, jl_value_t *y0, int *su
}
for (; i < nparams_expanded_x; i++) {
jl_value_t *a = (vx != JL_VARARG_NONE && i >= npx - 1) ? vxt : jl_tparam(x, i);
if (i > npy && jl_is_typevar(b)) { // i == npy implies a == a1
if (i > npy && jl_is_typevar(b) && is_leaf_typevar((jl_tvar_t *)b)) { // i == npy implies a == a1
// diagonal rule: all the later parameters are also constrained to be type-equal to the first
jl_value_t *a2 = a;
jl_value_t *au = jl_unwrap_unionall(a);
Expand Down
15 changes: 13 additions & 2 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ k11840(::Type{Union{Tuple{Int32}, Tuple{Int64}}}) = '2'
# issue #20511
f20511(x::DataType) = 0
f20511(x) = 1
Type{Integer} # cache this
@test f20511(Union{Integer,T} where T <: Unsigned) == 1
Type{AbstractSet} # cache this
@test f20511(Union{AbstractSet,Set{T}} where T) == 1

# join
@test typejoin(Int8,Int16) === Signed
Expand Down Expand Up @@ -8101,3 +8101,14 @@ end

# #52433
@test_throws ErrorException Core.Intrinsics.pointerref(Ptr{Vector{Int64}}(C_NULL), 1, 0)

# #53034 (Union normalization for typevar elimination)
@test Tuple{Int,Any} <: Tuple{Union{Int,T},T} where {T>:Int}
@test Tuple{Int,Any} <: Tuple{Union{Int,T},T} where {T>:Integer}
# #53034 (Union normalization for Type elimination)
@test Int isa Type{Union{Int,T2} where {T2<:T1}} where {T1}
@test Int isa Type{Union{Int,T1}} where {T1}
@test Int isa Union{UnionAll, Type{Union{Int,T2} where {T2<:T1}}} where {T1}
@test Int isa Union{Union, Type{Union{Int,T1}}} where {T1}
@test_broken Int isa Union{UnionAll, Type{Union{Int,T2} where {T2<:T1}} where {T1}}
@test_broken Int isa Union{Union, Type{Union{Int,T1}} where {T1}}
19 changes: 10 additions & 9 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ function test_diagonal()
@test isequal_type(Ref{Tuple{T, T} where Int<:T<:Int},
Ref{Tuple{S, S}} where Int<:S<:Int)

# issue #53021
@test Tuple{X, X} where {X<:Union{}} <: Tuple{X, X, Vararg{Any}} where {Int<:X<:Int}
@test Tuple{Integer, X, Vararg{X}} where {X<:Int} <: Tuple{Any, Vararg{X}} where {X>:Int}
@test Tuple{Any, X, Vararg{X}} where {X<:Int} <: Tuple{Vararg{X}} where X>:Integer
@test Tuple{Integer, Integer, Any, Vararg{Any}} <: Tuple{Vararg{X}} where X>:Integer
# issue #53019
@test Tuple{T,T} where {T<:Int} <: Tuple{T,T} where {T>:Int}

let A = Tuple{Int,Int8,Vector{Integer}},
B = Tuple{T,T,Vector{T}} where T>:Integer,
C = Tuple{T,T,Vector{Union{Integer,T}}} where T
Expand Down Expand Up @@ -1260,14 +1268,7 @@ let a = Tuple{Tuple{T2,4},T6} where T2 where T6,
end
let a = Tuple{T3,Int64,Tuple{T3}} where T3,
b = Tuple{S3,S3,S4} where S4 where S3
I1 = typeintersect(a, b)
I2 = typeintersect(b, a)
@test I1 <: I2
@test I2 <: I1
@test_broken I1 <: a
@test I2 <: a
@test I1 <: b
@test I2 <: b
@testintersect(a, b, Tuple{Int64, Int64, Tuple{Int64}})
end
let a = Tuple{T1,Val{T2},T2} where T2 where T1,
b = Tuple{Float64,S1,S2} where S2 where S1
Expand Down Expand Up @@ -2445,7 +2446,7 @@ abstract type P47654{A} end
@test_broken typeintersect(Type{Tuple{Array{T,1} where T}}, UnionAll) != Union{}

#issue 33137
@test_broken (Tuple{Q,Int} where Q<:Int) <: Tuple{T,T} where T
@test (Tuple{Q,Int} where Q<:Int) <: Tuple{T,T} where T

# issue 24333
@test (Type{Union{Ref,Cvoid}} <: Type{Union{T,Cvoid}} where T)
Expand Down

0 comments on commit 5cf1021

Please sign in to comment.