From f995dcdaab7fe14557768e40b03c13bf5b037b22 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Thu, 7 Sep 2023 16:08:55 +0000 Subject: [PATCH 1/2] cleanup Vararg instantiation code Various simplifications and improvements from investigating #51228. Improves the logic for showing of NTuple to handle constant lengths. Improves the logic for showing NTuple of bound length (e.g. NTuple itself). Also makes a choice to avoid showing non-types as NTuple, but instead try to write them out, to make it more visually obvious when the parameters have been swapped. --- base/show.jl | 65 +++++++++++++++++++++++++++++-------- src/codegen.cpp | 4 +-- src/gf.c | 8 ++--- src/jltypes.c | 74 +++++++++++++++++++++++++----------------- src/julia.h | 2 +- src/method.c | 2 +- src/precompile_utils.c | 2 +- src/subtype.c | 4 +-- test/show.jl | 10 ++++-- 9 files changed, 115 insertions(+), 56 deletions(-) diff --git a/base/show.jl b/base/show.jl index 7da7aa925fa1c..079bf5a423cc0 100644 --- a/base/show.jl +++ b/base/show.jl @@ -1084,29 +1084,68 @@ function show_datatype(io::IO, x::DataType, wheres::Vector{TypeVar}=TypeVar[]) # Print tuple types with homogeneous tails longer than max_n compactly using `NTuple` or `Vararg` if istuple + if n == 0 + print(io, "Tuple{}") + return + end + + # find the length of the homogeneous tail max_n = 3 taillen = 1 - for i in (n-1):-1:1 - if parameters[i] === parameters[n] - taillen += 1 + pn = parameters[n] + fulln = n + vakind = :none + vaN = 0 + if pn isa Core.TypeofVararg + if isdefined(pn, :N) + vaN = pn.N + if vaN isa Int + taillen = vaN + fulln += taillen - 1 + vakind = :fixed + else + vakind = :bound + end else - break + vakind = :unbound + end + pn = unwrapva(pn) + end + if !(pn isa TypeVar || pn isa Type) + # prefer Tuple over NTuple if it contains something other than types + # (e.g. if the user has switched the N and T accidentally) + taillen = 0 + elseif vakind === :none || vakind === :fixed + for i in (n-1):-1:1 + if parameters[i] === pn + taillen += 1 + else + break + end end end - if n == taillen > max_n - print(io, "NTuple{", n, ", ") - show(io, parameters[1]) + + # prefer NTuple over Tuple if it is a Vararg without a fixed length + # and prefer Tuple for short lists of elements + if (vakind == :bound && n == 1 == taillen) || (vakind === :fixed && taillen == fulln > max_n) || + (vakind === :none && taillen == fulln > max_n) + print(io, "NTuple{") + vakind === :bound ? show(io, vaN) : print(io, fulln) + print(io, ", ") + show(io, pn) print(io, "}") else print(io, "Tuple{") - for i = 1:(taillen > max_n ? n-taillen : n) + headlen = (taillen > max_n ? fulln - taillen : fulln) + for i = 1:headlen i > 1 && print(io, ", ") - show(io, parameters[i]) + show(io, vakind === :fixed && i >= n ? pn : parameters[i]) end - if taillen > max_n - print(io, ", Vararg{") - show(io, parameters[n]) - print(io, ", ", taillen, "}") + if headlen < fulln + headlen > 0 && print(io, ", ") + print(io, "Vararg{") + show(io, pn) + print(io, ", ", fulln - headlen, "}") end print(io, "}") end diff --git a/src/codegen.cpp b/src/codegen.cpp index 324157814b19b..edc3b614b2ccc 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -6738,7 +6738,7 @@ static jl_cgval_t emit_cfunction(jl_codectx_t &ctx, jl_value_t *output_type, con sigt = NULL; } else { - sigt = jl_apply_tuple_type((jl_svec_t*)sigt); + sigt = jl_apply_tuple_type((jl_svec_t*)sigt, 1); } if (sigt && !(unionall_env && jl_has_typevar_from_unionall(rt, unionall_env))) { unionall_env = NULL; @@ -7242,7 +7242,7 @@ static jl_datatype_t *compute_va_type(jl_method_instance_t *lam, size_t nreq) } jl_svecset(tupargs, i-nreq, argType); } - jl_value_t *typ = jl_apply_tuple_type(tupargs); + jl_value_t *typ = jl_apply_tuple_type(tupargs, 1); JL_GC_POP(); return (jl_datatype_t*)typ; } diff --git a/src/gf.c b/src/gf.c index e8cdb493026b0..b403a6e1af087 100644 --- a/src/gf.c +++ b/src/gf.c @@ -1266,7 +1266,7 @@ static jl_method_instance_t *cache_method( intptr_t max_varargs = get_max_varargs(definition, kwmt, mt, NULL); jl_compilation_sig(tt, sparams, definition, max_varargs, &newparams); if (newparams) { - temp2 = jl_apply_tuple_type(newparams); + temp2 = jl_apply_tuple_type(newparams, 1); // Now there may be a problem: the widened signature is more general // than just the given arguments, so it might conflict with another // definition that does not have cache instances yet. To fix this, we @@ -1389,7 +1389,7 @@ static jl_method_instance_t *cache_method( } } if (newparams) { - simplett = (jl_datatype_t*)jl_apply_tuple_type(newparams); + simplett = (jl_datatype_t*)jl_apply_tuple_type(newparams, 1); temp2 = (jl_value_t*)simplett; } @@ -2579,7 +2579,7 @@ JL_DLLEXPORT jl_value_t *jl_normalize_to_compilable_sig(jl_methtable_t *mt, jl_t jl_compilation_sig(ti, env, m, max_varargs, &newparams); int is_compileable = ((jl_datatype_t*)ti)->isdispatchtuple; if (newparams) { - tt = (jl_datatype_t*)jl_apply_tuple_type(newparams); + tt = (jl_datatype_t*)jl_apply_tuple_type(newparams, 1); if (!is_compileable) { // compute new env, if used below jl_value_t *ti = jl_type_intersection_env((jl_value_t*)tt, (jl_value_t*)m->sig, &newparams); @@ -2834,7 +2834,7 @@ jl_value_t *jl_argtype_with_function_type(jl_value_t *ft JL_MAYBE_UNROOTED, jl_v jl_svecset(tt, 0, ft); for (size_t i = 0; i < l; i++) jl_svecset(tt, i+1, jl_tparam(types,i)); - tt = (jl_value_t*)jl_apply_tuple_type((jl_svec_t*)tt); + tt = (jl_value_t*)jl_apply_tuple_type((jl_svec_t*)tt, 1); tt = jl_rewrap_unionall_(tt, types0); JL_GC_POP(); return tt; diff --git a/src/jltypes.c b/src/jltypes.c index f38197e49353d..57c7ae9e9616e 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -333,7 +333,7 @@ JL_DLLEXPORT int jl_get_size(jl_value_t *val, size_t *pnt) if (jl_is_long(val)) { ssize_t slen = jl_unbox_long(val); if (slen < 0) - jl_errorf("size or dimension is negative: %d", slen); + jl_errorf("size or dimension is negative: %zd", slen); *pnt = slen; return 1; } @@ -1435,17 +1435,6 @@ jl_datatype_t *jl_apply_cmpswap_type(jl_value_t *ty) return rettyp; } -// used to expand an NTuple to a flat representation -static jl_value_t *jl_tupletype_fill(size_t n, jl_value_t *v) -{ - jl_value_t *p = NULL; - JL_GC_PUSH1(&p); - p = (jl_value_t*)jl_svec_fill(n, v); - p = jl_apply_tuple_type((jl_svec_t*)p); - JL_GC_POP(); - return p; -} - JL_EXTENSION struct _jl_typestack_t { jl_datatype_t *tt; struct _jl_typestack_t *prev; @@ -1796,13 +1785,13 @@ int _may_substitute_ub(jl_value_t *v, jl_tvar_t *var, int inside_inv, int *cov_c // * `var` does not appear in invariant position // * `var` appears at most once (in covariant position) and not in a `Vararg` // unless the upper bound is concrete (diagonal rule) -int may_substitute_ub(jl_value_t *v, jl_tvar_t *var) JL_NOTSAFEPOINT +static int may_substitute_ub(jl_value_t *v, jl_tvar_t *var) JL_NOTSAFEPOINT { int cov_count = 0; return _may_substitute_ub(v, var, 0, &cov_count); } -jl_value_t *normalize_unionalls(jl_value_t *t) +static jl_value_t *normalize_unionalls(jl_value_t *t) { if (jl_is_uniontype(t)) { jl_uniontype_t *u = (jl_uniontype_t*)t; @@ -1840,6 +1829,29 @@ jl_value_t *normalize_unionalls(jl_value_t *t) return t; } +// used to expand an NTuple to a flat representation +static jl_value_t *jl_tupletype_fill(size_t n, jl_value_t *t, int check) +{ + if (check) { + // Since we are skipping making the Vararg and skipping checks later, + // we inline the checks from jl_wrap_vararg here now + if (!jl_valid_type_param(t)) + jl_type_error_rt("Vararg", "type", (jl_value_t*)jl_type_type, t); + // jl_wrap_vararg sometimes simplifies the type, so we only do this 1 time, instead of for each n later + t = normalize_unionalls(t); + jl_value_t *tw = extract_wrapper(t); + if (tw && t != tw && jl_types_equal(t, tw)) + t = tw; + check = 0; // remember that checks are already done now + } + jl_value_t *p = NULL; + JL_GC_PUSH1(&p); + p = (jl_value_t*)jl_svec_fill(n, t); + p = jl_apply_tuple_type((jl_svec_t*)p, check); + JL_GC_POP(); + return p; +} + static jl_value_t *_jl_instantiate_type_in_env(jl_value_t *ty, jl_unionall_t *env, jl_value_t **vals, jl_typeenv_t *prev, jl_typestack_t *stack); static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value_t **iparams, size_t ntp, @@ -1962,7 +1974,7 @@ static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value if (nt == 0 || !jl_has_free_typevars(va0)) { if (ntp == 1) { JL_GC_POP(); - return jl_tupletype_fill(nt, va0); + return jl_tupletype_fill(nt, va0, 0); } size_t i, l; p = jl_alloc_svec(ntp - 1 + nt); @@ -1971,7 +1983,7 @@ static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value l = ntp - 1 + nt; for (; i < l; i++) jl_svecset(p, i, va0); - jl_value_t *ndt = jl_apply_tuple_type(p); + jl_value_t *ndt = jl_apply_tuple_type(p, check); JL_GC_POP(); return ndt; } @@ -2136,19 +2148,19 @@ static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value return (jl_value_t*)ndt; } -static jl_value_t *jl_apply_tuple_type_v_(jl_value_t **p, size_t np, jl_svec_t *params) +static jl_value_t *jl_apply_tuple_type_v_(jl_value_t **p, size_t np, jl_svec_t *params, int check) { - return inst_datatype_inner(jl_anytuple_type, params, p, np, NULL, NULL, 1); + return inst_datatype_inner(jl_anytuple_type, params, p, np, NULL, NULL, check); } -JL_DLLEXPORT jl_value_t *jl_apply_tuple_type(jl_svec_t *params) +JL_DLLEXPORT jl_value_t *jl_apply_tuple_type(jl_svec_t *params, int check) { - return jl_apply_tuple_type_v_(jl_svec_data(params), jl_svec_len(params), params); + return jl_apply_tuple_type_v_(jl_svec_data(params), jl_svec_len(params), params, check); } JL_DLLEXPORT jl_value_t *jl_apply_tuple_type_v(jl_value_t **p, size_t np) { - return jl_apply_tuple_type_v_(p, np, NULL); + return jl_apply_tuple_type_v_(p, np, NULL, 1); } jl_tupletype_t *jl_lookup_arg_tuple_type(jl_value_t *arg1, jl_value_t **args, size_t nargs, int leaf) @@ -2211,13 +2223,15 @@ static jl_value_t *inst_tuple_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_ jl_datatype_t *tt = (jl_datatype_t*)t; jl_svec_t *tp = tt->parameters; size_t ntp = jl_svec_len(tp); - // Instantiate NTuple{3,Int} + // Instantiate Tuple{Vararg{T,N}} where T is fixed and N is known, such as Dims{3} + // And avoiding allocating the intermediate steps // Note this does not instantiate Tuple{Vararg{Int,3}}; that's done in inst_datatype_inner + // Note this does not instantiate NTuple{N,T}, since it is unnecessary and inefficient to expand that now if (jl_is_va_tuple(tt) && ntp == 1) { - // If this is a Tuple{Vararg{T,N}} with known N, expand it to + // If this is a Tuple{Vararg{T,N}} with known N and T, expand it to // a fixed-length tuple jl_value_t *T=NULL, *N=NULL; - jl_value_t *va = jl_unwrap_unionall(jl_tparam0(tt)); + jl_value_t *va = jl_tparam0(tt); jl_value_t *ttT = jl_unwrap_vararg(va); jl_value_t *ttN = jl_unwrap_vararg_num(va); jl_typeenv_t *e = env; @@ -2228,11 +2242,12 @@ static jl_value_t *inst_tuple_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_ N = e->val; e = e->prev; } - if (T != NULL && N != NULL && jl_is_long(N)) { + if (T != NULL && N != NULL && jl_is_long(N)) { // TODO: && !jl_has_free_typevars(T) to match inst_datatype_inner, or even && jl_is_concrete_type(T) + // Since this is skipping jl_wrap_vararg, we inline the checks from it here ssize_t nt = jl_unbox_long(N); if (nt < 0) - jl_errorf("size or dimension is negative: %zd", nt); - return jl_tupletype_fill(nt, T); + jl_errorf("Vararg length is negative: %zd", nt); + return jl_tupletype_fill(nt, T, check); } } jl_value_t **iparams; @@ -2428,9 +2443,8 @@ jl_vararg_t *jl_wrap_vararg(jl_value_t *t, jl_value_t *n, int check) } } if (t) { - if (!jl_valid_type_param(t)) { + if (!jl_valid_type_param(t)) jl_type_error_rt("Vararg", "type", (jl_value_t*)jl_type_type, t); - } t = normalize_unionalls(t); jl_value_t *tw = extract_wrapper(t); if (tw && t != tw && jl_types_equal(t, tw)) @@ -2735,7 +2749,7 @@ void jl_init_types(void) JL_GC_DISABLED jl_anytuple_type->layout = NULL; jl_typeofbottom_type->super = jl_wrap_Type(jl_bottom_type); - jl_emptytuple_type = (jl_datatype_t*)jl_apply_tuple_type(jl_emptysvec); + jl_emptytuple_type = (jl_datatype_t*)jl_apply_tuple_type(jl_emptysvec, 0); jl_emptytuple = jl_gc_permobj(0, jl_emptytuple_type); jl_emptytuple_type->instance = jl_emptytuple; diff --git a/src/julia.h b/src/julia.h index 50c4f8994de15..84ff230f71843 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1564,7 +1564,7 @@ JL_DLLEXPORT jl_value_t *jl_apply_type1(jl_value_t *tc, jl_value_t *p1); JL_DLLEXPORT jl_value_t *jl_apply_type2(jl_value_t *tc, jl_value_t *p1, jl_value_t *p2); JL_DLLEXPORT jl_datatype_t *jl_apply_modify_type(jl_value_t *dt); JL_DLLEXPORT jl_datatype_t *jl_apply_cmpswap_type(jl_value_t *dt); -JL_DLLEXPORT jl_value_t *jl_apply_tuple_type(jl_svec_t *params); +JL_DLLEXPORT jl_value_t *jl_apply_tuple_type(jl_svec_t *params, int check); // if uncertain, set check=1 JL_DLLEXPORT jl_value_t *jl_apply_tuple_type_v(jl_value_t **p, size_t np); JL_DLLEXPORT jl_datatype_t *jl_new_datatype(jl_sym_t *name, jl_module_t *module, diff --git a/src/method.c b/src/method.c index 68110d6995bbf..7d8d0e9ec4a78 100644 --- a/src/method.c +++ b/src/method.c @@ -998,7 +998,7 @@ JL_DLLEXPORT jl_method_t* jl_method_def(jl_svec_t *argdata, JL_GC_PUSH3(&f, &m, &argtype); size_t i, na = jl_svec_len(atypes); - argtype = jl_apply_tuple_type(atypes); + argtype = jl_apply_tuple_type(atypes, 1); if (!jl_is_datatype(argtype)) jl_error("invalid type in method definition (Union{})"); diff --git a/src/precompile_utils.c b/src/precompile_utils.c index 055ec4b3330f1..9a577b900a1b7 100644 --- a/src/precompile_utils.c +++ b/src/precompile_utils.c @@ -120,7 +120,7 @@ static void _compile_all_union(jl_value_t *sig) jl_svecset(p, i, ty); } } - methsig = jl_apply_tuple_type(p); + methsig = jl_apply_tuple_type(p, 1); methsig = jl_rewrap_unionall(methsig, sig); _compile_all_tvar_union(methsig); } diff --git a/src/subtype.c b/src/subtype.c index c67beecae9dbd..d8177f0fd21ff 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -3393,7 +3393,7 @@ static jl_value_t *intersect_tuple(jl_datatype_t *xd, jl_datatype_t *yd, jl_sten else if (isy) res = (jl_value_t*)yd; else if (p) - res = jl_apply_tuple_type(p); + res = jl_apply_tuple_type(p, 1); else res = jl_apply_tuple_type_v(params, np); } @@ -4130,7 +4130,7 @@ static jl_value_t *switch_union_tuple(jl_value_t *a, jl_value_t *b) ts[1] = jl_tparam(b, i); jl_svecset(vec, i, jl_type_union(ts, 2)); } - jl_value_t *ans = jl_apply_tuple_type(vec); + jl_value_t *ans = jl_apply_tuple_type(vec, 1); JL_GC_POP(); return ans; } diff --git a/test/show.jl b/test/show.jl index 0aa4d805491b1..f95f943c3c1a4 100644 --- a/test/show.jl +++ b/test/show.jl @@ -1368,6 +1368,9 @@ test_repr("(:).a") @test repr(Tuple{Float32, Float32, Float32}) == "Tuple{Float32, Float32, Float32}" @test repr(Tuple{String, Int64, Int64, Int64}) == "Tuple{String, Int64, Int64, Int64}" @test repr(Tuple{String, Int64, Int64, Int64, Int64}) == "Tuple{String, Vararg{Int64, 4}}" +@test repr(NTuple) == "NTuple{N, T} where {N, T}" +@test repr(Tuple{NTuple{N}, Vararg{NTuple{N}, 4}} where N) == "NTuple{5, NTuple{N, T} where T} where N" +@test repr(Tuple{Float64, NTuple{N}, Vararg{NTuple{N}, 4}} where N) == "Tuple{Float64, Vararg{NTuple{N, T} where T, 5}} where N" # Test printing of NamedTuples using the macro syntax @test repr(@NamedTuple{kw::Int64}) == "@NamedTuple{kw::Int64}" @@ -1380,17 +1383,20 @@ test_repr("(:).a") @test repr(@Kwargs{init::Int}) == "Base.Pairs{Symbol, $Int, Tuple{Symbol}, @NamedTuple{init::$Int}}" @testset "issue #42931" begin - @test repr(NTuple{4, :A}) == "NTuple{4, :A}" + @test repr(NTuple{4, :A}) == "Tuple{:A, :A, :A, :A}" @test repr(NTuple{3, :A}) == "Tuple{:A, :A, :A}" @test repr(NTuple{2, :A}) == "Tuple{:A, :A}" @test repr(NTuple{1, :A}) == "Tuple{:A}" @test repr(NTuple{0, :A}) == "Tuple{}" @test repr(Tuple{:A, :A, :A, :B}) == "Tuple{:A, :A, :A, :B}" - @test repr(Tuple{:A, :A, :A, :A}) == "NTuple{4, :A}" + @test repr(Tuple{:A, :A, :A, :A}) == "Tuple{:A, :A, :A, :A}" @test repr(Tuple{:A, :A, :A}) == "Tuple{:A, :A, :A}" @test repr(Tuple{:A}) == "Tuple{:A}" @test repr(Tuple{}) == "Tuple{}" + + @test repr(Tuple{Vararg{N, 10}} where N) == "NTuple{10, N} where N" + @test repr(Tuple{Vararg{10, N}} where N) == "Tuple{Vararg{10, N}} where N" end # Test that REPL/mime display of invalid UTF-8 data doesn't throw an exception: From 90932cacced44f52b91d3eb6c87158fdd353638a Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Sun, 24 Sep 2023 13:54:40 +0000 Subject: [PATCH 2/2] fixup! cleanup Vararg instantiation code --- src/jltypes.c | 8 +++++--- test/docs.jl | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/jltypes.c b/src/jltypes.c index 57c7ae9e9616e..998f3fe47f157 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -1713,7 +1713,7 @@ static void check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, si JL_GC_POP(); } -jl_value_t *extract_wrapper(jl_value_t *t JL_PROPAGATES_ROOT) JL_GLOBALLY_ROOTED +jl_value_t *extract_wrapper(jl_value_t *t JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT JL_GLOBALLY_ROOTED { t = jl_unwrap_unionall(t); if (jl_is_datatype(t)) @@ -1832,6 +1832,8 @@ static jl_value_t *normalize_unionalls(jl_value_t *t) // used to expand an NTuple to a flat representation static jl_value_t *jl_tupletype_fill(size_t n, jl_value_t *t, int check) { + jl_value_t *p = NULL; + JL_GC_PUSH1(&p); if (check) { // Since we are skipping making the Vararg and skipping checks later, // we inline the checks from jl_wrap_vararg here now @@ -1839,13 +1841,13 @@ static jl_value_t *jl_tupletype_fill(size_t n, jl_value_t *t, int check) jl_type_error_rt("Vararg", "type", (jl_value_t*)jl_type_type, t); // jl_wrap_vararg sometimes simplifies the type, so we only do this 1 time, instead of for each n later t = normalize_unionalls(t); + p = t; jl_value_t *tw = extract_wrapper(t); if (tw && t != tw && jl_types_equal(t, tw)) t = tw; + p = t; check = 0; // remember that checks are already done now } - jl_value_t *p = NULL; - JL_GC_PUSH1(&p); p = (jl_value_t*)jl_svec_fill(n, t); p = jl_apply_tuple_type((jl_svec_t*)p, check); JL_GC_POP(); diff --git a/test/docs.jl b/test/docs.jl index a2f556b7ee848..760aa76671c0c 100644 --- a/test/docs.jl +++ b/test/docs.jl @@ -1028,7 +1028,7 @@ struct $(curmod_prefix)Undocumented.st3{T<:Integer, N} # Fields ``` -a :: Tuple{Vararg{T<:Integer, N}} +a :: NTuple{N, T<:Integer} b :: Array{Int64, N} c :: Int64 ``` @@ -1052,7 +1052,7 @@ struct $(curmod_prefix)Undocumented.st4{T, N} # Fields ``` a :: T -b :: Tuple{Vararg{T, N}} +b :: NTuple{N, T} ``` # Supertype Hierarchy