Skip to content

Commit

Permalink
Subtype: avoid false alarm caused by eager forall_exists_subtype. (J…
Browse files Browse the repository at this point in the history
…uliaLang#48441)

* Avoid earsing `Runion` within nested `forall_exists_subtype`

If `Runion.more != 0` we‘d better not erase the local `Runion` as we need it if the subtyping fails after.

This commit replaces `forall_exists_subtype` with a local version.
It first tries `forall_exists_subtype` and estimates the "problem scale".
If subtyping fails and the scale looks small then it switches to the slow path.

TODO: At present, the "problem scale" only counts the number of checked `Lunion`s.
But perhaps we need a more accurate result (e.g. sum of `Runion.depth`)

* Change the reversed subtyping into a local check.

Make sure we don't forget the bound in `env`.
(And we can fuse `local_forall_exists_subtype`)

* Optimization for non-union invariant parameter.
  • Loading branch information
N5N3 committed Mar 17, 2023
1 parent 2117018 commit 2a2068d
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 41 deletions.
119 changes: 82 additions & 37 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv
return u;
}

static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);
static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow);

// subtype for variable bounds consistency check. needs its own forall/exists environment.
static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
Expand All @@ -571,17 +571,7 @@ static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
if (x == (jl_value_t*)jl_any_type && jl_is_datatype(y))
return 0;
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
int sub;
e->Lunions.used = e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 0);

pop_unionstate(&e->Runions, &oldRunions);
int sub = local_forall_exists_subtype(x, y, e, 0, 1);
pop_unionstate(&e->Lunions, &oldLunions);
return sub;
}
Expand Down Expand Up @@ -1362,6 +1352,72 @@ static int is_definite_length_tuple_type(jl_value_t *x)
return k == JL_VARARG_NONE || k == JL_VARARG_INT;
}

static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore);

static int may_contain_union_decision(jl_value_t *x, jl_stenv_t *e, jl_typeenv_t *log) JL_NOTSAFEPOINT
{
if (x == NULL || x == (jl_value_t*)jl_any_type || x == jl_bottom_type)
return 0;
if (jl_is_unionall(x))
return may_contain_union_decision(((jl_unionall_t *)x)->body, e, log);
if (jl_is_datatype(x)) {
jl_datatype_t *xd = (jl_datatype_t *)x;
for (int i = 0; i < jl_nparams(xd); i++) {
jl_value_t *param = jl_tparam(xd, i);
if (jl_is_vararg(param))
param = jl_unwrap_vararg(param);
if (may_contain_union_decision(param, e, log))
return 1;
}
return 0;
}
if (!jl_is_typevar(x))
return 1;
jl_typeenv_t *t = log;
while (t != NULL) {
if (x == (jl_value_t *)t->var)
return 1;
t = t->prev;
}
jl_typeenv_t newlog = { (jl_tvar_t*)x, NULL, log };
jl_varbinding_t *xb = lookup(e, (jl_tvar_t *)x);
return may_contain_union_decision(xb ? xb->lb : ((jl_tvar_t *)x)->lb, e, &newlog) ||
may_contain_union_decision(xb ? xb->ub : ((jl_tvar_t *)x)->ub, e, &newlog);
}

static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow)
{
int16_t oldRmore = e->Runions.more;
int sub;
if (may_contain_union_decision(y, e, NULL) && pick_union_decision(e, 1) == 0) {
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
e->Lunions.used = e->Runions.used = 0;
e->Lunions.depth = e->Runions.depth = 0;
e->Lunions.more = e->Runions.more = 0;
int count = 0, noRmore = 0;
sub = _forall_exists_subtype(x, y, e, param, &count, &noRmore);
pop_unionstate(&e->Runions, &oldRunions);
// we should not try the slow path if `forall_exists_subtype` has tested all cases;
// Once limit_slow == 1, also skip it if
// 1) `forall_exists_subtype` return false
// 2) the left `Union` looks big
if (noRmore || (limit_slow && (count > 3 || !sub)))
e->Runions.more = oldRmore;
}
else {
// slow path
e->Lunions.used = 0;
while (1) {
e->Lunions.more = 0;
e->Lunions.depth = 0;
sub = subtype(x, y, e, param);
if (!sub || !next_union_state(e, 0))
break;
}
}
return sub;
}

static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
if (obviously_egal(x, y)) return 1;
Expand All @@ -1380,33 +1436,13 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
}

jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
e->Lunions.used = 0;
int sub;

if (!jl_has_free_typevars(x) || !jl_has_free_typevars(y)) {
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
e->Runions.used = 0;
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;

sub = forall_exists_subtype(x, y, e, 2);

pop_unionstate(&e->Runions, &oldRunions);
}
else {
while (1) {
e->Lunions.more = 0;
e->Lunions.depth = 0;
sub = subtype(x, y, e, 2);
if (!sub || !next_union_state(e, 0))
break;
}
}
int limit_slow = !jl_has_free_typevars(x) || !jl_has_free_typevars(y);
int sub = local_forall_exists_subtype(x, y, e, 2, limit_slow) &&
local_forall_exists_subtype(y, x, e, 0, 0);

pop_unionstate(&e->Lunions, &oldLunions);
return sub && subtype(y, x, e, 0);
return sub;
}

static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param)
Expand All @@ -1433,7 +1469,7 @@ static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_
}
}

static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
static int _forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int *count, int *noRmore)
{
// The depth recursion has the following shape, after simplification:
// ∀₁
Expand All @@ -1446,8 +1482,12 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in

e->Lunions.used = 0;
int sub;
if (count) *count = 0;
if (noRmore) *noRmore = 1;
while (1) {
sub = exists_subtype(x, y, e, saved, &se, param);
if (count) *count = (*count < 4) ? *count + 1 : 4;
if (noRmore) *noRmore = *noRmore && e->Runions.more == 0;
if (!sub || !next_union_state(e, 0))
break;
free_env(&se);
Expand All @@ -1459,6 +1499,11 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
return sub;
}

static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
{
return _forall_exists_subtype(x, y, e, param, NULL, NULL);
}

static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
{
e->vars = NULL;
Expand Down
13 changes: 9 additions & 4 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,8 @@ f24521(::Type{T}, ::Type{T}) where {T} = T
@test !(Ref{Union{Int64, Val{Number}}} <: Ref{Union{Val{T}, T}} where T)
@test !(Ref{Union{Ref{Number}, Int64}} <: Ref{Union{Ref{T}, T}} where T)
@test !(Ref{Union{Val{Number}, Int64}} <: Ref{Union{Val{T}, T}} where T)
@test !(Val{Ref{Union{Int64, Ref{Number}}}} <: Val{S} where {S<:Ref{Union{Ref{T}, T}} where T})
@test !(Tuple{Ref{Union{Int64, Ref{Number}}}} <: Tuple{S} where {S<:Ref{Union{Ref{T}, T}} where T})

# issue #26180
@test !(Ref{Union{Ref{Int64}, Ref{Number}}} <: Ref{Ref{T}} where T)
Expand Down Expand Up @@ -2270,8 +2272,8 @@ abstract type P47654{A} end
@test_broken typeintersect(Tuple{Vector{VT}, Vector{VT}} where {N1, VT<:AbstractVector{N1}},
Tuple{Vector{VN} where {N, VN<:AbstractVector{N}}, Vector{Vector{Float64}}}) !== Union{}
#issue 40865
@test_broken Tuple{Set{Ref{Int}}, Set{Ref{Int}}} <: Tuple{Set{KV}, Set{K}} where {K,KV<:Union{K,Ref{K}}}
@test_broken Tuple{Set{Val{Int}}, Set{Val{Int}}} <: Tuple{Set{KV}, Set{K}} where {K,KV<:Union{K,Val{K}}}
@test Tuple{Set{Ref{Int}}, Set{Ref{Int}}} <: Tuple{Set{KV}, Set{K}} where {K,KV<:Union{K,Ref{K}}}
@test Tuple{Set{Val{Int}}, Set{Val{Int}}} <: Tuple{Set{KV}, Set{K}} where {K,KV<:Union{K,Val{K}}}

#issue 39099
A = Tuple{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Vararg{Int, N}}, Tuple{Vararg{Int, N}}} where N
Expand Down Expand Up @@ -2308,7 +2310,10 @@ end

# try to fool a greedy algorithm that picks X=Int, Y=String here
@test Tuple{Ref{Union{Int,String}}, Ref{Union{Int,String}}} <: Tuple{Ref{Union{X,Y}}, Ref{X}} where {X,Y}
# this slightly more complex case has been broken since 1.0 (worked in 0.6)
@test_broken Tuple{Ref{Union{Int,String,Missing}}, Ref{Union{Int,String}}} <: Tuple{Ref{Union{X,Y}}, Ref{X}} where {X,Y}
@test Tuple{Ref{Union{Int,String,Missing}}, Ref{Union{Int,String}}} <: Tuple{Ref{Union{X,Y}}, Ref{X}} where {X,Y}

@test !(Tuple{Any, Any, Any} <: Tuple{Any, Vararg{T}} where T)

let a = (isodd(i) ? Pair{Char, String} : Pair{String, String} for i in 1:2000)
@test Tuple{Type{Pair{Union{Char, String}, String}}, a...} <: Tuple{Type{Pair{K, V}}, Vararg{Pair{A, B} where B where A}} where V where K
end

0 comments on commit 2a2068d

Please sign in to comment.