diff --git a/src/jltypes.c b/src/jltypes.c index bf15611de4587..d8cbb9acb63c2 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -719,20 +719,6 @@ jl_value_t *simple_union(jl_value_t *a, jl_value_t *b) int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity); -static int simple_disjoint(jl_value_t *a, jl_value_t *b, int hasfree) -{ - if (jl_is_uniontype(b)) { - jl_value_t *b1 = ((jl_uniontype_t *)b)->a, *b2 = ((jl_uniontype_t *)b)->b; - JL_GC_PUSH2(&b1, &b2); - int res = simple_disjoint(a, b1, hasfree) && simple_disjoint(a, b2, hasfree); - JL_GC_POP(); - return res; - } - if (!hasfree && !jl_has_free_typevars(b)) - return jl_has_empty_intersection(a, b); - return obviously_disjoint(a, b, 0); -} - jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi) { // Unlike `Union`, we don't unwrap `UnionAll` here to avoid possible widening. @@ -746,19 +732,31 @@ jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi) flatten_type_union(&b, 1, temp, &count, 0); assert(count == nt); size_t i, j; + int8_t *stemp = (int8_t *)alloca(count); // first remove disjoint elements. + memset(stemp, 0, count); + for (i = 0; i < nta; i++) { + int hasfree = jl_has_free_typevars(temp[i]); + for (j = nta; j < nt; j++) { + if (!stemp[i] || !stemp[j]) { + int intersect = !hasfree && !jl_has_free_typevars(temp[j]); + if (!(intersect ? jl_has_empty_intersection(temp[i], temp[j]) : obviously_disjoint(temp[i], temp[j], 0))) + stemp[i] = stemp[j] = 1; + } + } + } for (i = 0; i < nt; i++) { - if (simple_disjoint(temp[i], (i < nta ? b : a), jl_has_free_typevars(temp[i]))) - temp[i] = NULL; + temp[i] = stemp[i] ? temp[i] : NULL; } // then check subtyping. // stemp[k] == -1 : ∃i temp[k] >:ₛ temp[i] // stemp[k] == 1 : ∃i temp[k] == temp[i] // stemp[k] == 2 : ∃i temp[k] <:ₛ temp[i] - int8_t *stemp = (int8_t *)alloca(count); memset(stemp, 0, count); + int all_disjoint = 1, subs[2] = {1, 1}, rs[2] = {1, 1}; for (i = 0; i < nta; i++) { if (temp[i] == NULL) continue; + all_disjoint = 0; int hasfree = jl_has_free_typevars(temp[i]); for (j = nta; j < nt; j++) { if (temp[j] == NULL) continue; @@ -778,22 +776,23 @@ jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi) } } } - int subs[2] = {1, 1}, rs[2] = {1, 1}; - for (i = 0; i < nt; i++) { - subs[i >= nta] &= (temp[i] == NULL || stemp[i] > 0); - rs[i >= nta] &= (temp[i] != NULL && stemp[i] > 0); - } - // return a(b) if a(b) <: b(a) - if (rs[0]) { - JL_GC_POP(); - return a; - } - if (rs[1]) { - JL_GC_POP(); - return b; + if (!all_disjoint) { + for (i = 0; i < nt; i++) { + subs[i >= nta] &= (temp[i] == NULL || stemp[i] > 0); + rs[i >= nta] &= (temp[i] != NULL && stemp[i] > 0); + } + // return a(b) if a(b) <: b(a) + if (rs[0]) { + JL_GC_POP(); + return a; + } + if (rs[1]) { + JL_GC_POP(); + return b; + } } // return `Union{}` for `merge_env` if we can't prove `<:` or `>:` - if (!overesi && !subs[0] && !subs[1]) { + if (all_disjoint || (!overesi && !subs[0] && !subs[1])) { JL_GC_POP(); return jl_bottom_type; }