Skip to content

Commit

Permalink
Subtype: Code clean for union state stack. (JuliaLang#48479)
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Mar 17, 2023
1 parent ab9e44e commit 1107361
Showing 1 changed file with 42 additions and 69 deletions.
111 changes: 42 additions & 69 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -519,23 +519,38 @@ static jl_unionall_t *rename_unionall(jl_unionall_t *u)

static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);

static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
if (state->more == 0)
return 0;
// reset `used` and let `pick_union_decision` clean the stack.
state->used = state->more;
statestack_set(state, state->used - 1, 1);
return 1;
}

static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
int ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0)
state->more = state->depth; // memorize that this was the deepest available choice
return ui;
}

static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
do {
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
int ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0) {
state->more = state->depth; // memorize that this was the deepest available choice
u = ((jl_uniontype_t*)u)->a;
}
else {
if (pick_union_decision(e, R))
u = ((jl_uniontype_t*)u)->b;
}
else
u = ((jl_uniontype_t*)u)->a;
} while (jl_is_uniontype(u));
return u;
}
Expand Down Expand Up @@ -1208,15 +1223,7 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
// of unions and vars: if matching `typevar <: union`, first try to match the whole
// union against the variable before trying to take it apart to see if there are any
// variables lurking inside.
jl_unionstate_t *state = &e->Runions;
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0)
state->more = state->depth; // memorize that this was the deepest available choice
ui = pick_union_decision(e, 1);
}
if (ui == 1)
y = pick_union_element(y, e, 1);
Expand Down Expand Up @@ -1383,14 +1390,7 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
// here to try the usual algorithm if subtyping later fails.
jl_unionstate_t *state = &e->Runions;
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, state);
if (state->depth >= state->used) {
statestack_set(state, state->used, 0);
state->used++;
}
int ui = statestack_get(state, state->depth);
state->depth++;
if (ui == 0) {
state->more = state->depth; // memorize that this was the deepest available choice
if (pick_union_decision(e, 1) == 0) {
if (equal_unions((jl_uniontype_t*)x, (jl_uniontype_t*)y, e))
return 1;
pop_unionstate(state, &oldRunions);
Expand All @@ -1414,18 +1414,12 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
pop_unionstate(&e->Runions, &oldRunions);
}
else {
int lastset = 0;
while (1) {
e->Lunions.more = 0;
e->Lunions.depth = 0;
sub = subtype(x, y, e, 2);
int set = e->Lunions.more;
if (!sub || !set)
if (!sub || !next_union_state(e, 0))
break;
for (int i = set; i <= lastset; i++)
statestack_set(&e->Lunions, i, 0);
lastset = set - 1;
statestack_set(&e->Lunions, lastset, 1);
}
}

Expand All @@ -1436,16 +1430,14 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param)
{
e->Runions.used = 0;
int lastset = 0;
while (1) {
e->Runions.depth = 0;
e->Runions.more = 0;
e->Lunions.depth = 0;
e->Lunions.more = 0;
if (subtype(x, y, e, param))
return 1;
int set = e->Runions.more;
if (set) {
if (next_union_state(e, 1)) {
// We preserve `envout` here as `subtype_unionall` needs previous assigned env values.
int oldidx = e->envidx;
e->envidx = e->envsz;
Expand All @@ -1456,10 +1448,6 @@ static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_
restore_env(e, saved, se);
return 0;
}
for (int i = set; i <= lastset; i++)
statestack_set(&e->Runions, i, 0);
lastset = set - 1;
statestack_set(&e->Runions, lastset, 1);
}
}

Expand All @@ -1475,19 +1463,13 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
save_env(e, &saved, &se);

e->Lunions.used = 0;
int lastset = 0;
int sub;
while (1) {
sub = exists_subtype(x, y, e, saved, &se, param);
int set = e->Lunions.more;
if (!sub || !set)
if (!sub || !next_union_state(e, 0))
break;
free_env(&se);
save_env(e, &saved, &se);
for (int i = set; i <= lastset; i++)
statestack_set(&e->Lunions, i, 0);
lastset = set - 1;
statestack_set(&e->Lunions, lastset, 1);
}

free_env(&se);
Expand Down Expand Up @@ -3326,39 +3308,30 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
jl_value_t **merged = &is[3];
jl_savedenv_t se, me;
save_env(e, saved, &se);
int lastset = 0, niter = 0, total_iter = 0;
jl_value_t *ii = intersect(x, y, e, 0);
is[0] = ii; // root
int niter = 0, total_iter = 0;
is[0] = intersect(x, y, e, 0); // root
if (is[0] != jl_bottom_type)
niter = merge_env(e, merged, &me, niter);
restore_env(e, *saved, &se);
while (e->Runions.more) {
if (e->emptiness_only && ii != jl_bottom_type)
while (next_union_state(e, 1)) {
if (e->emptiness_only && is[0] != jl_bottom_type)
break;
e->Runions.depth = 0;
int set = e->Runions.more - 1;
e->Runions.more = 0;
statestack_set(&e->Runions, set, 1);
for (int i = set + 1; i <= lastset; i++)
statestack_set(&e->Runions, i, 0);
lastset = set;

is[0] = ii;
is[1] = intersect(x, y, e, 0);
if (is[1] != jl_bottom_type)
niter = merge_env(e, merged, &me, niter);
restore_env(e, *saved, &se);
if (is[0] == jl_bottom_type)
ii = is[1];
else if (is[1] == jl_bottom_type)
ii = is[0];
else {
is[0] = is[1];
else if (is[1] != jl_bottom_type) {
// TODO: the repeated subtype checks in here can get expensive
ii = jl_type_union(is, 2);
is[0] = jl_type_union(is, 2);
}
total_iter++;
if (niter > 4 || total_iter > 400000) {
ii = y;
is[0] = y;
break;
}
}
Expand All @@ -3368,7 +3341,7 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
}
free_env(&se);
JL_GC_POP();
return ii;
return is[0];
}

// type intersection entry points
Expand Down

0 comments on commit 1107361

Please sign in to comment.