Skip to content

Commit

Permalink
typeintersect: followup cleanup for the nothrow path of type instanti…
Browse files Browse the repository at this point in the history
…ation (#54514)

Adopt suggestions from
#54465 (review)
and fix various added regession & residual MWE.

(cherry picked from commit af545b9)
  • Loading branch information
N5N3 authored and KristofferC committed May 22, 2024
1 parent 92fccc8 commit d67cc7d
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 50 deletions.
109 changes: 67 additions & 42 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -1499,11 +1499,11 @@ jl_unionall_t *jl_rename_unionall(jl_unionall_t *u)
return (jl_unionall_t*)t;
}

jl_value_t *jl_substitute_var_nothrow(jl_value_t *t, jl_tvar_t *var, jl_value_t *val)
jl_value_t *jl_substitute_var_nothrow(jl_value_t *t, jl_tvar_t *var, jl_value_t *val, int nothrow)
{
if (val == (jl_value_t*)var)
return t;
int nothrow = jl_is_typevar(val) ? 0 : 1;
nothrow = jl_is_typevar(val) ? 0 : nothrow;
jl_typeenv_t env = { var, val, NULL };
return inst_type_w_(t, &env, NULL, 1, nothrow);
}
Expand Down Expand Up @@ -1725,7 +1725,7 @@ void jl_precompute_memoized_dt(jl_datatype_t *dt, int cacheable)
dt->hash = typekey_hash(dt->name, jl_svec_data(dt->parameters), l, cacheable);
}

static void check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, size_t np)
static int check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, size_t np, int nothrow)
{
jl_value_t *wrapper = tn->wrapper;
jl_value_t **bounds;
Expand All @@ -1743,6 +1743,10 @@ static void check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, si
assert(jl_is_unionall(wrapper));
jl_tvar_t *tv = ((jl_unionall_t*)wrapper)->var;
if (!within_typevar(params[i], bounds[2*i], bounds[2*i+1])) {
if (nothrow) {
JL_GC_POP();
return 1;
}
if (tv->lb != bounds[2*i] || tv->ub != bounds[2*i+1])
// pass a new version of `tv` containing the instantiated bounds
tv = jl_new_typevar(tv->name, bounds[2*i], bounds[2*i+1]);
Expand All @@ -1752,12 +1756,26 @@ static void check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, si
int j;
for (j = 2*i + 2; j < 2*np; j++) {
jl_value_t *bj = bounds[j];
if (bj != (jl_value_t*)jl_any_type && bj != jl_bottom_type)
bounds[j] = jl_substitute_var(bj, tv, params[i]);
if (bj != (jl_value_t*)jl_any_type && bj != jl_bottom_type) {
int isub = j & 1;
// use different nothrow level for lb and ub substitution.
// TODO: This assuming the top instantiation could only start with
// `nothrow == 2` or `nothrow == 0`. If `nothrow` is initially set to 1
// then we might miss some inner error, perhaps the normal path should
// also follow this rule?
jl_value_t *nb = jl_substitute_var_nothrow(bj, tv, params[i], nothrow ? (isub ? 2 : 1) : 0 );
if (nb == NULL) {
assert(nothrow);
JL_GC_POP();
return 1;
}
bounds[j] = nb;
}
}
wrapper = ((jl_unionall_t*)wrapper)->body;
}
JL_GC_POP();
return 0;
}

static jl_value_t *extract_wrapper(jl_value_t *t JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT JL_GLOBALLY_ROOTED
Expand Down Expand Up @@ -2004,13 +2022,8 @@ static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value
// for whether this is even valid
if (check && !istuple) {
assert(ntp > 0);
JL_TRY {
check_datatype_parameters(tn, iparams, ntp);
}
JL_CATCH {
if (!nothrow) jl_rethrow();
if (check_datatype_parameters(tn, iparams, ntp, nothrow))
return NULL;
}
}
else if (ntp == 0 && jl_emptytuple_type != NULL) {
// empty tuple type case
Expand Down Expand Up @@ -2401,7 +2414,8 @@ static jl_value_t *inst_tuple_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_
jl_value_t *elt = jl_svecref(tp, i);
jl_value_t *pi = inst_type_w_(elt, env, stack, check, nothrow);
if (pi == NULL) {
if (i == ntp-1 && jl_is_vararg(elt)) {
assert(nothrow);
if (nothrow == 1 || (i == ntp-1 && jl_is_vararg(elt))) {
t = NULL;
break;
}
Expand All @@ -2420,6 +2434,10 @@ static jl_value_t *inst_tuple_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_
return t;
}

// `nothrow` means that when type checking fails, the type instantiation should
// return `NULL` instead of immediately throwing an error. If `nothrow` == 2 then
// we further assume that the imprecise instantiation for non invariant parameters
// is acceptable, and inner error (`NULL`) would be ignored.
static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t *stack, int check, int nothrow)
{
size_t i;
Expand All @@ -2440,11 +2458,10 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t
jl_value_t *var = NULL;
jl_value_t *newbody = NULL;
JL_GC_PUSH3(&lb, &var, &newbody);
JL_TRY {
lb = inst_type_w_(ua->var->lb, env, stack, check, 0);
}
JL_CATCH {
if (!nothrow) jl_rethrow();
// set nothrow <= 1 to ensure lb's accuracy.
lb = inst_type_w_(ua->var->lb, env, stack, check, nothrow ? 1 : 0);
if (lb == NULL) {
assert(nothrow);
t = NULL;
}
if (t != NULL) {
Expand All @@ -2468,11 +2485,9 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t
if (newbody == NULL) {
t = NULL;
}
else if (newbody == (jl_value_t*)jl_emptytuple_type) {
// NTuple{0} => Tuple{} can make a typevar disappear
t = (jl_value_t*)jl_emptytuple_type;
}
else if (nothrow && !jl_has_typevar(newbody, (jl_tvar_t *)var)) {
else if (!jl_has_typevar(newbody, (jl_tvar_t *)var)) {
// inner instantiation might make a typevar disappear, e.g.
// NTuple{0,T} => Tuple{}
t = newbody;
}
else if (newbody != ua->body || var != (jl_value_t*)ua->var) {
Expand All @@ -2489,16 +2504,21 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t
jl_value_t *b = NULL;
JL_GC_PUSH2(&a, &b);
b = inst_type_w_(u->b, env, stack, check, nothrow);
if (nothrow) {
// ensure jl_type_union nothrow.
if (a && !(jl_is_typevar(a) || jl_is_type(a)))
a = NULL;
if (b && !(jl_is_typevar(b) || jl_is_type(b)))
b = NULL;
}
if (a != u->a || b != u->b) {
if (!check) {
// fast path for `jl_rename_unionall`.
t = jl_new_struct(jl_uniontype_type, a, b);
}
else if (nothrow && a == NULL) {
t = b;
}
else if (nothrow && b == NULL) {
t = a;
else if (a == NULL || b == NULL) {
assert(nothrow);
t = nothrow == 1 ? NULL : a == NULL ? b : a;
}
else {
assert(a != NULL && b != NULL);
Expand All @@ -2516,15 +2536,21 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t
JL_GC_PUSH2(&T, &N);
if (v->T) {
T = inst_type_w_(v->T, env, stack, check, nothrow);
if (T == NULL)
T = jl_bottom_type;
if (v->N) // This branch should never throw.
N = inst_type_w_(v->N, env, stack, check, 0);
if (T == NULL) {
if (nothrow == 2)
T = jl_bottom_type;
else
t = NULL;
}
if (t && v->N) {
// set nothrow <= 1 to ensure invariant parameter's accuracy.
N = inst_type_w_(v->N, env, stack, check, nothrow ? 1 : 0);
if (N == NULL)
t = NULL;
}
}
if (T != v->T || N != v->N) {
// `Vararg` is special, we'd better handle inner error at Tuple level.
if (t && (T != v->T || N != v->N))
t = (jl_value_t*)jl_wrap_vararg(T, N, check, nothrow);
}
JL_GC_POP();
return t;
}
Expand All @@ -2543,16 +2569,15 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t
int bound = 0;
for (i = 0; i < ntp; i++) {
jl_value_t *elt = jl_svecref(tp, i);
JL_TRY {
jl_value_t *pi = inst_type_w_(elt, env, stack, check, 0);
iparams[i] = pi;
bound |= (pi != elt);
}
JL_CATCH {
if (!nothrow) jl_rethrow();
// set nothrow <= 1 to ensure invariant parameter's accuracy.
jl_value_t *pi = inst_type_w_(elt, env, stack, check, nothrow ? 1 : 0);
if (pi == NULL) {
assert(nothrow);
t = NULL;
break;
}
if (t == NULL) break;
iparams[i] = pi;
bound |= (pi != elt);
}
// if t's parameters are not bound in the environment, return it uncopied (#9378)
if (t != NULL && bound)
Expand Down
2 changes: 1 addition & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ JL_DLLEXPORT int jl_type_morespecific_no_subtype(jl_value_t *a, jl_value_t *b);
jl_value_t *jl_instantiate_type_with(jl_value_t *t, jl_value_t **env, size_t n);
JL_DLLEXPORT jl_value_t *jl_instantiate_type_in_env(jl_value_t *ty, jl_unionall_t *env, jl_value_t **vals);
jl_value_t *jl_substitute_var(jl_value_t *t, jl_tvar_t *var, jl_value_t *val);
jl_value_t *jl_substitute_var_nothrow(jl_value_t *t, jl_tvar_t *var, jl_value_t *val);
jl_value_t *jl_substitute_var_nothrow(jl_value_t *t, jl_tvar_t *var, jl_value_t *val, int nothrow);
jl_unionall_t *jl_rename_unionall(jl_unionall_t *u);
JL_DLLEXPORT jl_value_t *jl_unwrap_unionall(jl_value_t *v JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_rewrap_unionall(jl_value_t *t, jl_value_t *u);
Expand Down
12 changes: 7 additions & 5 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -2784,7 +2784,7 @@ static jl_value_t *omit_bad_union(jl_value_t *u, jl_tvar_t *t)
res = jl_bottom_type;
}
else if (obviously_egal(var->lb, ub)) {
res = jl_substitute_var_nothrow(body, var, ub);
res = jl_substitute_var_nothrow(body, var, ub, 2);
if (res == NULL)
res = jl_bottom_type;
}
Expand Down Expand Up @@ -2958,9 +2958,11 @@ static jl_value_t *finish_unionall(jl_value_t *res JL_MAYBE_UNROOTED, jl_varbind
}
}
if (varval) {
*btemp->ub = jl_substitute_var_nothrow(iub, vb->var, varval);
if (*btemp->ub == NULL)
iub = jl_substitute_var_nothrow(iub, vb->var, varval, 2);
if (iub == NULL)
res = jl_bottom_type;
else
*btemp->ub = iub;
}
else if (iub == (jl_value_t*)vb->var) {
// TODO: this loses some constraints, such as in this test, where we replace T4<:S3 (e.g. T4==S3 since T4 only appears covariantly once) with T4<:Any
Expand Down Expand Up @@ -3091,12 +3093,12 @@ static jl_value_t *finish_unionall(jl_value_t *res JL_MAYBE_UNROOTED, jl_varbind
if (varval) {
// you can construct `T{x} where x` even if T's parameter is actually
// limited. in that case we might get an invalid instantiation here.
res = jl_substitute_var_nothrow(res, vb->var, varval);
res = jl_substitute_var_nothrow(res, vb->var, varval, 2);
// simplify chains of UnionAlls where bounds become equal
while (res != NULL && jl_is_unionall(res) && obviously_egal(((jl_unionall_t*)res)->var->lb,
((jl_unionall_t*)res)->var->ub)) {
jl_unionall_t * ures = (jl_unionall_t *)res;
res = jl_substitute_var_nothrow(ures->body, ures->var, ures->var->lb);
res = jl_substitute_var_nothrow(ures->body, ures->var, ures->var->lb, 2);
}
if (res == NULL)
res = jl_bottom_type;
Expand Down
21 changes: 19 additions & 2 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2617,13 +2617,30 @@ end
#issue 54356
abstract type A54356{T<:Real} end
struct B54356{T} <: A54356{T} end
let S = Tuple{Val, Val{T}} where {T}, R = Tuple{Val{Val{T}}, Val{T}} where {T}
# general parameters check
struct C54356{S,T<:Union{S,Complex{S}}} end
struct D54356{S<:Real,T} end
let S = Tuple{Val, Val{T}} where {T}, R = Tuple{Val{Val{T}}, Val{T}} where {T},
SS = Tuple{Val, Val{T}, Val{T}} where {T}, RR = Tuple{Val{Val{T}}, Val{T}, Val{T}} where {T}
# parameters check for self
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, Complex{B}}}, S{1}, R{1})
# parameters check for supertype (B54356 -> A54356)
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, B54356{B}}}, S{1}, R{1})
# enure unused TypeVar skips the `UnionAll` wrapping
@testintersect(Tuple{Val{A}, A} where {B, A<:(Union{Val{B}, D54356{B,C}} where {C})}, S{1}, R{1})
# invariant parameter should not get narrowed
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, Val{Union{Int,Complex{B}}}}}, S{1}, R{1})
# bit value could not be `Union` element
@testintersect(Tuple{Val{A}, A, Val{B}} where {B, A<:Union{B, Val{B}}}, SS{1}, RR{1})
@testintersect(Tuple{Val{A}, A, Val{B}} where {B, A<:Union{B, Complex{B}}}, SS{1}, Union{})
# `check_datatype_parameters` should ignore bad `Union` elements in constraint's ub
T = Tuple{Val{Union{Val{Nothing}, Val{C54356{V,V}}}}, Val{Nothing}} where {Nothing<:V<:Nothing}
@test T <: S{Nothing}
@test T <: Tuple{Val{A}, A} where {B, C, A<:Union{Val{B}, Val{C54356{B,C}}}}
@test T <: typeintersect(Tuple{Val{A}, A} where {B, C, A<:Union{Val{B}, Val{C54356{B,C}}}}, S{Nothing})
# extra check for Vararg
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, NTuple{B,Any}}}, S{-1}, R{-1})
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, Tuple{Any,Vararg{Any,B}}}}, S{-1}, R{-1})
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, Tuple{Vararg{Int,Union{Int,Complex{B}}}}}}, S{1}, R{1})
# extra check for NamedTuple
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, NamedTuple{B,Tuple{Int}}}}, S{1}, R{1})
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, NamedTuple{B,Tuple{Int}}}}, S{(1,)}, R{(1,)})
Expand Down

0 comments on commit d67cc7d

Please sign in to comment.