Skip to content

Commit

Permalink
fix #30679, call correct method for invoke calls in jl_invoke fal…
Browse files Browse the repository at this point in the history
…lback (#31845)

(cherry picked from commit 4e20cc2)
  • Loading branch information
JeffBezanson authored and KristofferC committed Apr 27, 2019
1 parent ef22206 commit e5de459
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 28 deletions.
69 changes: 41 additions & 28 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,6 @@ JL_DLLEXPORT size_t jl_get_tls_world_age(void)
return jl_get_ptls_states()->world_age;
}

JL_DLLEXPORT jl_value_t *jl_invoke(jl_method_instance_t *meth, jl_value_t **args, uint32_t nargs)
{
jl_callptr_t fptr = meth->invoke;
if (fptr != jl_fptr_trampoline) {
return fptr(meth, args, nargs);
}
else {
// if this hasn't been inferred (compiled) yet,
// inferring it might not be able to handle the world range
// so we just do a generic apply here
// because that might actually be faster
// since it can go through the unrolled caches for this world
// and if inference is successful, this meth would get updated anyways,
// and we'll get the fast path here next time

// TODO: if `meth` came from an `invoke` call, we should make sure
// meth->def is called instead of doing normal dispatch.

return jl_apply(args, nargs);
}
}

/// ----- Handling for Julia callbacks ----- ///

JL_DLLEXPORT int8_t jl_is_in_pure_context(void)
Expand Down Expand Up @@ -2200,6 +2178,8 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, size_t world)
return (jl_value_t*)entry;
}

jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t **args, size_t nargs);

// invoke()
// this does method dispatch with a set of types to match other than the
// types of the actual arguments. this means it sometimes does NOT call the
Expand All @@ -2212,13 +2192,10 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, size_t world)
jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)
{
size_t world = jl_get_ptls_states()->world_age;
jl_svec_t *tpenv = jl_emptysvec;
jl_tupletype_t *tt = NULL;
jl_value_t *types = NULL;
JL_GC_PUSH3(&types, &tpenv, &tt);
JL_GC_PUSH1(&types);
jl_value_t *gf = args[0];
types = jl_argtype_with_function(gf, types0);
jl_methtable_t *mt = jl_gf_mtable(gf);
jl_typemap_entry_t *entry = (jl_typemap_entry_t*)jl_gf_invoke_lookup(types, world);

if ((jl_value_t*)entry == jl_nothing) {
Expand All @@ -2228,10 +2205,19 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)

// now we have found the matching definition.
// next look for or create a specialization of this definition.
JL_GC_POP();
return jl_gf_invoke_by_method(entry->func.method, args, nargs);
}

jl_method_t *method = entry->func.method;
jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t **args, size_t nargs)
{
size_t world = jl_get_ptls_states()->world_age;
jl_method_instance_t *mfunc = NULL;
jl_typemap_entry_t *tm = NULL;
jl_methtable_t *mt = jl_gf_mtable(args[0]);
jl_svec_t *tpenv = jl_emptysvec;
jl_tupletype_t *tt = NULL;
JL_GC_PUSH2(&tpenv, &tt);
if (method->invokes.unknown != NULL)
tm = jl_typemap_assoc_exact(method->invokes, args, nargs, jl_cachearg_offset(mt), world);
if (tm) {
Expand All @@ -2248,7 +2234,7 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs)
if (method->invokes.unknown == NULL)
method->invokes.unknown = jl_nothing;

mfunc = cache_method(mt, &method->invokes, entry->func.value, tt, method, world, tpenv, 1);
mfunc = cache_method(mt, &method->invokes, (jl_value_t*)method, tt, method, world, tpenv, 1);
JL_UNLOCK(&method->writelock);
}
JL_GC_POP();
Expand Down Expand Up @@ -2303,6 +2289,33 @@ JL_DLLEXPORT jl_value_t *jl_get_invoke_lambda(jl_methtable_t *mt,
return (jl_value_t*)mfunc;
}

JL_DLLEXPORT jl_value_t *jl_invoke(jl_method_instance_t *meth, jl_value_t **args, uint32_t nargs)
{
jl_callptr_t fptr = meth->invoke;
if (fptr != jl_fptr_trampoline) {
return fptr(meth, args, nargs);
}
else {
// if this hasn't been inferred (compiled) yet,
// inferring it might not be able to handle the world range
// so we just do a generic apply here
// because that might actually be faster
// since it can go through the unrolled caches for this world
// and if inference is successful, this meth would get updated anyways,
// and we'll get the fast path here next time

jl_method_instance_t *mfunc = jl_lookup_generic_(args, nargs,
jl_int32hash_fast(jl_return_address()),
jl_get_ptls_states()->world_age);
// check whether `jl_apply_generic` would call the right method
if (mfunc->def.method == meth->def.method)
return mfunc->invoke(mfunc, args, nargs);

// no; came from an `invoke` call
return jl_gf_invoke_by_method(meth->def.method, args, nargs);
}
}

// Return value is rooted globally
jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st, int iskw)
{
Expand Down
14 changes: 14 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2432,6 +2432,20 @@ const T24460 = Tuple{T,T} where T
g24460() = invoke(f24460, T24460, 1, 2)
@test @inferred(g24460()) === 2.0

# issue #30679
@noinline function f30679(::DataType)
b = IOBuffer()
write(b, 0x00)
2
end
@noinline function f30679(t::Type{Int})
x = invoke(f30679, Tuple{DataType}, t)
b = IOBuffer()
write(b, 0x00)
return x + 40
end
@test f30679(Int) == 42

call_lambda1() = (()->x)(1)
call_lambda2() = ((x)->x)()
call_lambda3() = ((x)->x)(1,2)
Expand Down

0 comments on commit e5de459

Please sign in to comment.