Skip to content

Commit

Permalink
Subtype: minor optimization for simple_intersect (#49477)
Browse files Browse the repository at this point in the history
1. remove duplicated disjoint check.
2. add a fast path for all disjoint case.
  • Loading branch information
N5N3 authored Apr 25, 2023
1 parent 86b819c commit b291522
Showing 1 changed file with 30 additions and 31 deletions.
61 changes: 30 additions & 31 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -719,20 +719,6 @@ jl_value_t *simple_union(jl_value_t *a, jl_value_t *b)

int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity);

static int simple_disjoint(jl_value_t *a, jl_value_t *b, int hasfree)
{
if (jl_is_uniontype(b)) {
jl_value_t *b1 = ((jl_uniontype_t *)b)->a, *b2 = ((jl_uniontype_t *)b)->b;
JL_GC_PUSH2(&b1, &b2);
int res = simple_disjoint(a, b1, hasfree) && simple_disjoint(a, b2, hasfree);
JL_GC_POP();
return res;
}
if (!hasfree && !jl_has_free_typevars(b))
return jl_has_empty_intersection(a, b);
return obviously_disjoint(a, b, 0);
}

jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi)
{
// Unlike `Union`, we don't unwrap `UnionAll` here to avoid possible widening.
Expand All @@ -746,19 +732,31 @@ jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi)
flatten_type_union(&b, 1, temp, &count, 0);
assert(count == nt);
size_t i, j;
int8_t *stemp = (int8_t *)alloca(count);
// first remove disjoint elements.
memset(stemp, 0, count);
for (i = 0; i < nta; i++) {
int hasfree = jl_has_free_typevars(temp[i]);
for (j = nta; j < nt; j++) {
if (!stemp[i] || !stemp[j]) {
int intersect = !hasfree && !jl_has_free_typevars(temp[j]);
if (!(intersect ? jl_has_empty_intersection(temp[i], temp[j]) : obviously_disjoint(temp[i], temp[j], 0)))
stemp[i] = stemp[j] = 1;
}
}
}
for (i = 0; i < nt; i++) {
if (simple_disjoint(temp[i], (i < nta ? b : a), jl_has_free_typevars(temp[i])))
temp[i] = NULL;
temp[i] = stemp[i] ? temp[i] : NULL;
}
// then check subtyping.
// stemp[k] == -1 : ∃i temp[k] >:ₛ temp[i]
// stemp[k] == 1 : ∃i temp[k] == temp[i]
// stemp[k] == 2 : ∃i temp[k] <:ₛ temp[i]
int8_t *stemp = (int8_t *)alloca(count);
memset(stemp, 0, count);
int all_disjoint = 1, subs[2] = {1, 1}, rs[2] = {1, 1};
for (i = 0; i < nta; i++) {
if (temp[i] == NULL) continue;
all_disjoint = 0;
int hasfree = jl_has_free_typevars(temp[i]);
for (j = nta; j < nt; j++) {
if (temp[j] == NULL) continue;
Expand All @@ -778,22 +776,23 @@ jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi)
}
}
}
int subs[2] = {1, 1}, rs[2] = {1, 1};
for (i = 0; i < nt; i++) {
subs[i >= nta] &= (temp[i] == NULL || stemp[i] > 0);
rs[i >= nta] &= (temp[i] != NULL && stemp[i] > 0);
}
// return a(b) if a(b) <: b(a)
if (rs[0]) {
JL_GC_POP();
return a;
}
if (rs[1]) {
JL_GC_POP();
return b;
if (!all_disjoint) {
for (i = 0; i < nt; i++) {
subs[i >= nta] &= (temp[i] == NULL || stemp[i] > 0);
rs[i >= nta] &= (temp[i] != NULL && stemp[i] > 0);
}
// return a(b) if a(b) <: b(a)
if (rs[0]) {
JL_GC_POP();
return a;
}
if (rs[1]) {
JL_GC_POP();
return b;
}
}
// return `Union{}` for `merge_env` if we can't prove `<:` or `>:`
if (!overesi && !subs[0] && !subs[1]) {
if (all_disjoint || (!overesi && !subs[0] && !subs[1])) {
JL_GC_POP();
return jl_bottom_type;
}
Expand Down

0 comments on commit b291522

Please sign in to comment.