Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

typeintersect: followup cleanup for the nothrow path of type instantiation #54514

Merged
merged 7 commits into from
May 19, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 60 additions & 42 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -1499,11 +1499,11 @@ jl_unionall_t *jl_rename_unionall(jl_unionall_t *u)
return (jl_unionall_t*)t;
}

jl_value_t *jl_substitute_var_nothrow(jl_value_t *t, jl_tvar_t *var, jl_value_t *val)
jl_value_t *jl_substitute_var_nothrow(jl_value_t *t, jl_tvar_t *var, jl_value_t *val, int nothrow)
{
if (val == (jl_value_t*)var)
return t;
int nothrow = jl_is_typevar(val) ? 0 : 1;
nothrow = jl_is_typevar(val) ? 0 : nothrow;
jl_typeenv_t env = { var, val, NULL };
return inst_type_w_(t, &env, NULL, 1, nothrow);
}
Expand Down Expand Up @@ -1725,7 +1725,7 @@ void jl_precompute_memoized_dt(jl_datatype_t *dt, int cacheable)
dt->hash = typekey_hash(dt->name, jl_svec_data(dt->parameters), l, cacheable);
}

static void check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, size_t np)
static int check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, size_t np, int nothrow)
{
jl_value_t *wrapper = tn->wrapper;
jl_value_t **bounds;
Expand All @@ -1743,6 +1743,10 @@ static void check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, si
assert(jl_is_unionall(wrapper));
jl_tvar_t *tv = ((jl_unionall_t*)wrapper)->var;
if (!within_typevar(params[i], bounds[2*i], bounds[2*i+1])) {
if (nothrow) {
JL_GC_POP();
return 1;
}
if (tv->lb != bounds[2*i] || tv->ub != bounds[2*i+1])
// pass a new version of `tv` containing the instantiated bounds
tv = jl_new_typevar(tv->name, bounds[2*i], bounds[2*i+1]);
Expand All @@ -1752,12 +1756,23 @@ static void check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, si
int j;
for (j = 2*i + 2; j < 2*np; j++) {
jl_value_t *bj = bounds[j];
if (bj != (jl_value_t*)jl_any_type && bj != jl_bottom_type)
bounds[j] = jl_substitute_var(bj, tv, params[i]);
if (bj != (jl_value_t*)jl_any_type && bj != jl_bottom_type) {
int isub = j & 1;
// use different nothrow level for lb and ub substitution.
// TODO: should normal path of type instantiation follows this rule?
jl_value_t *nb = jl_substitute_var_nothrow(bj, tv, params[i], nothrow ? (isub ? 2 : 1) : 0 );
if (nb == NULL) {
assert(nothrow);
JL_GC_POP();
return 1;
}
bounds[j] = nb;
}
}
wrapper = ((jl_unionall_t*)wrapper)->body;
}
JL_GC_POP();
return 0;
}

static jl_value_t *extract_wrapper(jl_value_t *t JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT JL_GLOBALLY_ROOTED
Expand Down Expand Up @@ -2004,13 +2019,8 @@ static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value
// for whether this is even valid
if (check && !istuple) {
assert(ntp > 0);
JL_TRY {
check_datatype_parameters(tn, iparams, ntp);
}
JL_CATCH {
if (!nothrow) jl_rethrow();
if (check_datatype_parameters(tn, iparams, ntp, nothrow))
return NULL;
}
}
else if (ntp == 0 && jl_emptytuple_type != NULL) {
// empty tuple type case
Expand Down Expand Up @@ -2401,7 +2411,8 @@ static jl_value_t *inst_tuple_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_
jl_value_t *elt = jl_svecref(tp, i);
jl_value_t *pi = inst_type_w_(elt, env, stack, check, nothrow);
if (pi == NULL) {
if (i == ntp-1 && jl_is_vararg(elt)) {
assert(nothrow);
if (nothrow == 1 || (i == ntp-1 && jl_is_vararg(elt))) {
t = NULL;
break;
}
Expand Down Expand Up @@ -2440,11 +2451,10 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t
jl_value_t *var = NULL;
jl_value_t *newbody = NULL;
JL_GC_PUSH3(&lb, &var, &newbody);
JL_TRY {
lb = inst_type_w_(ua->var->lb, env, stack, check, 0);
}
JL_CATCH {
if (!nothrow) jl_rethrow();
// set nothrow <= 1 to ensure lb's accuracy.
lb = inst_type_w_(ua->var->lb, env, stack, check, nothrow ? 1 : 0);
if (lb == NULL) {
assert(nothrow);
t = NULL;
}
if (t != NULL) {
Expand All @@ -2468,11 +2478,9 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t
if (newbody == NULL) {
t = NULL;
}
else if (newbody == (jl_value_t*)jl_emptytuple_type) {
// NTuple{0} => Tuple{} can make a typevar disappear
t = (jl_value_t*)jl_emptytuple_type;
}
else if (nothrow && !jl_has_typevar(newbody, (jl_tvar_t *)var)) {
else if (!jl_has_typevar(newbody, (jl_tvar_t *)var)) {
// inner instantiation might make a typevar disappear, e.g.
// NTuple{0,T} => Tuple{}
t = newbody;
}
else if (newbody != ua->body || var != (jl_value_t*)ua->var) {
Expand All @@ -2489,16 +2497,21 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t
jl_value_t *b = NULL;
JL_GC_PUSH2(&a, &b);
b = inst_type_w_(u->b, env, stack, check, nothrow);
if (nothrow) {
// ensure jl_type_union nothrow.
if (a && !(jl_is_typevar(a) || jl_is_type(a)))
a = NULL;
if (b && !(jl_is_typevar(b) || jl_is_type(b)))
b = NULL;
}
if (a != u->a || b != u->b) {
if (!check) {
// fast path for `jl_rename_unionall`.
t = jl_new_struct(jl_uniontype_type, a, b);
}
else if (nothrow && a == NULL) {
t = b;
}
else if (nothrow && b == NULL) {
t = a;
else if (a == NULL || b == NULL) {
assert(nothrow);
t = nothrow == 1 ? NULL : a == NULL ? b : a;
}
else {
assert(a != NULL && b != NULL);
Expand All @@ -2516,15 +2529,21 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t
JL_GC_PUSH2(&T, &N);
if (v->T) {
T = inst_type_w_(v->T, env, stack, check, nothrow);
if (T == NULL)
T = jl_bottom_type;
if (v->N) // This branch should never throw.
N = inst_type_w_(v->N, env, stack, check, 0);
if (T == NULL) {
if (nothrow == 2)
T = jl_bottom_type;
else
t = NULL;
}
if (t && v->N) {
// set nothrow <= 1 to ensure invariant parameter's accuracy.
N = inst_type_w_(v->N, env, stack, check, nothrow ? 1 : 0);
if (N == NULL)
t = NULL;
}
}
if (T != v->T || N != v->N) {
// `Vararg` is special, we'd better handle inner error at Tuple level.
if (t && (T != v->T || N != v->N))
t = (jl_value_t*)jl_wrap_vararg(T, N, check, nothrow);
}
JL_GC_POP();
return t;
}
Expand All @@ -2543,16 +2562,15 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t
int bound = 0;
for (i = 0; i < ntp; i++) {
jl_value_t *elt = jl_svecref(tp, i);
JL_TRY {
jl_value_t *pi = inst_type_w_(elt, env, stack, check, 0);
iparams[i] = pi;
bound |= (pi != elt);
}
JL_CATCH {
if (!nothrow) jl_rethrow();
// set nothrow <= 1 to ensure invariant parameter's accuracy.
jl_value_t *pi = inst_type_w_(elt, env, stack, check, nothrow ? 1 : 0);
if (pi == NULL) {
assert(nothrow);
t = NULL;
break;
}
if (t == NULL) break;
iparams[i] = pi;
bound |= (pi != elt);
}
// if t's parameters are not bound in the environment, return it uncopied (#9378)
if (t != NULL && bound)
Expand Down
2 changes: 1 addition & 1 deletion src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ JL_DLLEXPORT int jl_type_morespecific_no_subtype(jl_value_t *a, jl_value_t *b);
jl_value_t *jl_instantiate_type_with(jl_value_t *t, jl_value_t **env, size_t n);
JL_DLLEXPORT jl_value_t *jl_instantiate_type_in_env(jl_value_t *ty, jl_unionall_t *env, jl_value_t **vals);
jl_value_t *jl_substitute_var(jl_value_t *t, jl_tvar_t *var, jl_value_t *val);
jl_value_t *jl_substitute_var_nothrow(jl_value_t *t, jl_tvar_t *var, jl_value_t *val);
jl_value_t *jl_substitute_var_nothrow(jl_value_t *t, jl_tvar_t *var, jl_value_t *val, int nothrow);
jl_unionall_t *jl_rename_unionall(jl_unionall_t *u);
JL_DLLEXPORT jl_value_t *jl_unwrap_unionall(jl_value_t *v JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_rewrap_unionall(jl_value_t *t, jl_value_t *u);
Expand Down
12 changes: 7 additions & 5 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -2784,7 +2784,7 @@ static jl_value_t *omit_bad_union(jl_value_t *u, jl_tvar_t *t)
res = jl_bottom_type;
}
else if (obviously_egal(var->lb, ub)) {
res = jl_substitute_var_nothrow(body, var, ub);
res = jl_substitute_var_nothrow(body, var, ub, 2);
if (res == NULL)
res = jl_bottom_type;
}
Expand Down Expand Up @@ -2958,9 +2958,11 @@ static jl_value_t *finish_unionall(jl_value_t *res JL_MAYBE_UNROOTED, jl_varbind
}
}
if (varval) {
*btemp->ub = jl_substitute_var_nothrow(iub, vb->var, varval);
if (*btemp->ub == NULL)
iub = jl_substitute_var_nothrow(iub, vb->var, varval, 2);
if (iub == NULL)
res = jl_bottom_type;
else
*btemp->ub = iub;
}
else if (iub == (jl_value_t*)vb->var) {
// TODO: this loses some constraints, such as in this test, where we replace T4<:S3 (e.g. T4==S3 since T4 only appears covariantly once) with T4<:Any
Expand Down Expand Up @@ -3091,12 +3093,12 @@ static jl_value_t *finish_unionall(jl_value_t *res JL_MAYBE_UNROOTED, jl_varbind
if (varval) {
// you can construct `T{x} where x` even if T's parameter is actually
// limited. in that case we might get an invalid instantiation here.
res = jl_substitute_var_nothrow(res, vb->var, varval);
res = jl_substitute_var_nothrow(res, vb->var, varval, 2);
// simplify chains of UnionAlls where bounds become equal
while (res != NULL && jl_is_unionall(res) && obviously_egal(((jl_unionall_t*)res)->var->lb,
((jl_unionall_t*)res)->var->ub)) {
jl_unionall_t * ures = (jl_unionall_t *)res;
res = jl_substitute_var_nothrow(ures->body, ures->var, ures->var->lb);
res = jl_substitute_var_nothrow(ures->body, ures->var, ures->var->lb, 2);
}
if (res == NULL)
res = jl_bottom_type;
Expand Down
21 changes: 19 additions & 2 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2617,13 +2617,30 @@ end
#issue 54356
abstract type A54356{T<:Real} end
struct B54356{T} <: A54356{T} end
let S = Tuple{Val, Val{T}} where {T}, R = Tuple{Val{Val{T}}, Val{T}} where {T}
# general parameters check
struct C54356{S,T<:Union{S,Complex{S}}} end
struct D54356{S<:Real,T} end
let S = Tuple{Val, Val{T}} where {T}, R = Tuple{Val{Val{T}}, Val{T}} where {T},
SS = Tuple{Val, Val{T}, Val{T}} where {T}, RR = Tuple{Val{Val{T}}, Val{T}, Val{T}} where {T}
# parameters check for self
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, Complex{B}}}, S{1}, R{1})
# parameters check for supertype (B54356 -> A54356)
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, B54356{B}}}, S{1}, R{1})
# enure unused TypeVar skips the `UnionAll` wrapping
@testintersect(Tuple{Val{A}, A} where {B, A<:(Union{Val{B}, D54356{B,C}} where {C})}, S{1}, R{1})
# invariant parameter should not get narrowed
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, Val{Union{Int,Complex{B}}}}}, S{1}, R{1})
# bit value could not be `Union` element
@testintersect(Tuple{Val{A}, A, Val{B}} where {B, A<:Union{B, Val{B}}}, SS{1}, RR{1})
@testintersect(Tuple{Val{A}, A, Val{B}} where {B, A<:Union{B, Complex{B}}}, SS{1}, Union{})
# `check_datatype_parameters` should ignore bad `Union` elements in constraint's ub
T = Tuple{Val{Union{Val{Nothing}, Val{C54356{V,V}}}}, Val{Nothing}} where {Nothing<:V<:Nothing}
@test T <: S{Nothing}
@test T <: Tuple{Val{A}, A} where {B, C, A<:Union{Val{B}, Val{C54356{B,C}}}}
@test T <: typeintersect(Tuple{Val{A}, A} where {B, C, A<:Union{Val{B}, Val{C54356{B,C}}}}, S{Nothing})
# extra check for Vararg
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, NTuple{B,Any}}}, S{-1}, R{-1})
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, Tuple{Any,Vararg{Any,B}}}}, S{-1}, R{-1})
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, Tuple{Vararg{Int,Union{Int,Complex{B}}}}}}, S{1}, R{1})
# extra check for NamedTuple
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, NamedTuple{B,Tuple{Int}}}}, S{1}, R{1})
@testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, NamedTuple{B,Tuple{Int}}}}, S{(1,)}, R{(1,)})
Expand Down