Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ensure jl_compilation_sig does not narrow Vararg #48152

Merged
merged 2 commits into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 54 additions & 24 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -584,15 +584,12 @@ jl_value_t *jl_nth_slot_type(jl_value_t *sig, size_t i) JL_NOTSAFEPOINT
{
sig = jl_unwrap_unionall(sig);
size_t len = jl_nparams(sig);
if (len == 0)
return NULL;
if (i < len-1)
return jl_tparam(sig, i);
if (jl_is_vararg(jl_tparam(sig, len-1)))
return jl_unwrap_vararg(jl_tparam(sig, len-1));
if (i == len-1)
return jl_tparam(sig, i);
return NULL;
jl_value_t *p = jl_tparam(sig, len-1);
if (jl_is_vararg(p))
p = jl_unwrap_vararg(p);
return p;
}

// if concrete_match returns false, the sig may specify `Type{T::DataType}`, while the `tt` contained DataType
Expand Down Expand Up @@ -660,31 +657,62 @@ static jl_value_t *ml_matches(jl_methtable_t *mt,

// get the compilation signature specialization for this method
static void jl_compilation_sig(
jl_tupletype_t *const tt, // the original tupletype of the call : this is expected to be a relative simple type (no Varags, Union, UnionAll, etc.)
jl_tupletype_t *const tt, // the original tupletype of the call (or DataType from precompile)
jl_svec_t *sparams,
jl_method_t *definition,
intptr_t nspec,
// output:
jl_svec_t **const newparams JL_REQUIRE_ROOTED_SLOT)
{
assert(jl_is_tuple_type(tt));
jl_value_t *decl = definition->sig;
size_t nargs = definition->nargs; // == jl_nparams(jl_unwrap_unionall(decl));

if (definition->generator) {
// staged functions aren't optimized
// so assume the caller was intelligent about calling us
return;
}
if (definition->sig == (jl_value_t*)jl_anytuple_type && jl_atomic_load_relaxed(&definition->unspecialized)) {

if (decl == (jl_value_t*)jl_anytuple_type && jl_atomic_load_relaxed(&definition->unspecialized)) {
*newparams = jl_anytuple_type->parameters; // handle builtin methods
return;
}

jl_value_t *decl = definition->sig;
assert(jl_is_tuple_type(tt));
// some early sanity checks
size_t i, np = jl_nparams(tt);
size_t nargs = definition->nargs; // == jl_nparams(jl_unwrap_unionall(decl));
switch (jl_va_tuple_kind((jl_datatype_t*)decl)) {
case JL_VARARG_NONE:
if (jl_is_va_tuple(tt))
// odd
return;
if (np != nargs)
// there are not enough input parameters to make this into a compilation sig
return;
break;
case JL_VARARG_INT:
case JL_VARARG_BOUND:
if (jl_is_va_tuple(tt))
// the length needed is not known, but required for compilation
return;
if (np < nargs - 1)
// there are not enough input parameters to make this into a compilation sig
return;
break;
case JL_VARARG_UNBOUND:
if (np < nspec && jl_is_va_tuple(tt))
// there are insufficient given parameters for jl_isa_compileable_sig now to like this type
// (there were probably fewer methods defined when we first selected this signature)
return;
break;
}

jl_value_t *type_i = NULL;
JL_GC_PUSH1(&type_i);
for (i = 0; i < np; i++) {
jl_value_t *elt = jl_tparam(tt, i);
if (jl_is_vararg(elt))
elt = jl_unwrap_vararg(elt);
jl_value_t *decl_i = jl_nth_slot_type(decl, i);
type_i = jl_rewrap_unionall(decl_i, decl);
size_t i_arg = (i < nargs - 1 ? i : nargs - 1);
Expand Down Expand Up @@ -732,16 +760,14 @@ static void jl_compilation_sig(
if (!jl_has_free_typevars(decl_i) && !jl_is_kind(decl_i)) {
if (decl_i != elt) {
if (!*newparams) *newparams = jl_svec_copy(tt->parameters);
// n.b. it is possible here that !(elt <: decl_i), if elt was something unusual from intersection
// so this might narrow the result slightly, though still being compatible with the declared signature
jl_svecset(*newparams, i, (jl_value_t*)decl_i);
}
continue;
}
}

if (jl_is_vararg(elt)) {
continue;
}

if (jl_types_equal(elt, (jl_value_t*)jl_type_type)) { // elt == Type{T} where T
// not triggered for isdispatchtuple(tt), this attempts to handle
// some cases of adapting a random signature into a compilation signature
Expand Down Expand Up @@ -827,7 +853,7 @@ static void jl_compilation_sig(
// in general, here we want to find the biggest type that's not a
// supertype of any other method signatures. so far we are conservative
// and the types we find should be bigger.
if (jl_nparams(tt) >= nspec && jl_va_tuple_kind((jl_datatype_t*)decl) == JL_VARARG_UNBOUND) {
if (np >= nspec && jl_va_tuple_kind((jl_datatype_t*)decl) == JL_VARARG_UNBOUND) {
if (!*newparams) *newparams = tt->parameters;
type_i = jl_svecref(*newparams, nspec - 2);
// if all subsequent arguments are subtypes of type_i, specialize
Expand Down Expand Up @@ -2076,7 +2102,9 @@ JL_DLLEXPORT jl_value_t *jl_matching_methods(jl_tupletype_t *types, jl_value_t *
if (ambig != NULL)
*ambig = 0;
jl_value_t *unw = jl_unwrap_unionall((jl_value_t*)types);
if (jl_is_tuple_type(unw) && (unw == (jl_value_t*)jl_emptytuple_type || jl_tparam0(unw) == jl_bottom_type))
if (!jl_is_tuple_type(unw))
return (jl_value_t*)jl_an_empty_vec_any;
if (unw == (jl_value_t*)jl_emptytuple_type || jl_tparam0(unw) == jl_bottom_type)
return (jl_value_t*)jl_an_empty_vec_any;
if (mt == jl_nothing)
mt = (jl_value_t*)jl_method_table_for(unw);
Expand Down Expand Up @@ -2173,8 +2201,8 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
if (codeinst)
return codeinst;

// if mi has a better (wider) signature for compilation use that instead
// and just copy it here for caching
// if mi has a better (wider) signature preferred for compilation use that
// instead and just copy it here for caching
jl_method_instance_t *mi2 = jl_normalize_to_compilable_mi(mi);
if (mi2 != mi) {
jl_code_instance_t *codeinst2 = jl_compile_method_internal(mi2, world);
Expand Down Expand Up @@ -2363,7 +2391,7 @@ JL_DLLEXPORT jl_value_t *jl_normalize_to_compilable_sig(jl_methtable_t *mt, jl_t
jl_method_instance_t *jl_normalize_to_compilable_mi(jl_method_instance_t *mi JL_PROPAGATES_ROOT)
{
jl_method_t *def = mi->def.method;
if (!jl_is_method(def))
if (!jl_is_method(def) || !jl_is_datatype(mi->specTypes))
return mi;
jl_methtable_t *mt = jl_method_get_table(def);
if ((jl_value_t*)mt == jl_nothing)
Expand Down Expand Up @@ -2445,7 +2473,7 @@ jl_method_instance_t *jl_get_specialization1(jl_tupletype_t *types JL_PROPAGATES

// Get a MethodInstance for a precompile() call. This uses a special kind of lookup that
// tries to find a method for which the requested signature is compileable.
jl_method_instance_t *jl_get_compile_hint_specialization(jl_tupletype_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid, int mt_cache)
static jl_method_instance_t *jl_get_compile_hint_specialization(jl_tupletype_t *types JL_PROPAGATES_ROOT, size_t world, size_t *min_valid, size_t *max_valid, int mt_cache)
{
if (jl_has_free_typevars((jl_value_t*)types))
return NULL; // don't poison the cache due to a malformed query
Expand All @@ -2468,7 +2496,7 @@ jl_method_instance_t *jl_get_compile_hint_specialization(jl_tupletype_t *types J
if (n == 1) {
match = (jl_method_match_t*)jl_array_ptr_ref(matches, 0);
}
else {
else if (jl_is_datatype(types)) {
// first, select methods for which `types` is compileable
size_t count = 0;
for (i = 0; i < n; i++) {
Expand Down Expand Up @@ -2839,7 +2867,9 @@ JL_DLLEXPORT jl_value_t *jl_apply_generic(jl_value_t *F, jl_value_t **args, uint
static jl_method_match_t *_gf_invoke_lookup(jl_value_t *types JL_PROPAGATES_ROOT, jl_value_t *mt, size_t world, size_t *min_valid, size_t *max_valid)
{
jl_value_t *unw = jl_unwrap_unionall((jl_value_t*)types);
if (jl_is_tuple_type(unw) && jl_tparam0(unw) == jl_bottom_type)
if (!jl_is_tuple_type(unw))
return NULL;
if (jl_tparam0(unw) == jl_bottom_type)
return NULL;
if (mt == jl_nothing)
mt = (jl_value_t*)jl_method_table_for(unw);
Expand Down
4 changes: 4 additions & 0 deletions test/compiler/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -790,3 +790,7 @@ f47247(a::Ref{Int}, b::Nothing) = setfield!(a, :x, b)
f(x) = Core.bitcast(UInt64, x)
@test occursin("llvm.trap", get_llvm(f, Tuple{Union{}}))
end

f48085(@nospecialize x...) = length(x)
@test Core.Compiler.get_compileable_sig(which(f48085, (Vararg{Any},)), Tuple{typeof(f48085), Vararg{Int}}, Core.svec()) === nothing
@test Core.Compiler.get_compileable_sig(which(f48085, (Vararg{Any},)), Tuple{typeof(f48085), Int, Vararg{Int}}, Core.svec()) === Tuple{typeof(f48085), Any, Vararg{Any}}