diff --git a/src/gf.c b/src/gf.c index 520db1ab45a98a..c68e20a449e778 100644 --- a/src/gf.c +++ b/src/gf.c @@ -35,28 +35,6 @@ JL_DLLEXPORT size_t jl_get_tls_world_age(void) return jl_get_ptls_states()->world_age; } -JL_DLLEXPORT jl_value_t *jl_invoke(jl_method_instance_t *meth, jl_value_t **args, uint32_t nargs) -{ - jl_callptr_t fptr = meth->invoke; - if (fptr != jl_fptr_trampoline) { - return fptr(meth, args, nargs); - } - else { - // if this hasn't been inferred (compiled) yet, - // inferring it might not be able to handle the world range - // so we just do a generic apply here - // because that might actually be faster - // since it can go through the unrolled caches for this world - // and if inference is successful, this meth would get updated anyways, - // and we'll get the fast path here next time - - // TODO: if `meth` came from an `invoke` call, we should make sure - // meth->def is called instead of doing normal dispatch. - - return jl_apply(args, nargs); - } -} - /// ----- Handling for Julia callbacks ----- /// JL_DLLEXPORT int8_t jl_is_in_pure_context(void) @@ -2200,6 +2178,8 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, size_t world) return (jl_value_t*)entry; } +jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t **args, size_t nargs); + // invoke() // this does method dispatch with a set of types to match other than the // types of the actual arguments. this means it sometimes does NOT call the @@ -2212,13 +2192,10 @@ JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup(jl_value_t *types, size_t world) jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs) { size_t world = jl_get_ptls_states()->world_age; - jl_svec_t *tpenv = jl_emptysvec; - jl_tupletype_t *tt = NULL; jl_value_t *types = NULL; - JL_GC_PUSH3(&types, &tpenv, &tt); + JL_GC_PUSH1(&types); jl_value_t *gf = args[0]; types = jl_argtype_with_function(gf, types0); - jl_methtable_t *mt = jl_gf_mtable(gf); jl_typemap_entry_t *entry = (jl_typemap_entry_t*)jl_gf_invoke_lookup(types, world); if ((jl_value_t*)entry == jl_nothing) { @@ -2228,10 +2205,19 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs) // now we have found the matching definition. // next look for or create a specialization of this definition. + JL_GC_POP(); + return jl_gf_invoke_by_method(entry->func.method, args, nargs); +} - jl_method_t *method = entry->func.method; +jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t **args, size_t nargs) +{ + size_t world = jl_get_ptls_states()->world_age; jl_method_instance_t *mfunc = NULL; jl_typemap_entry_t *tm = NULL; + jl_methtable_t *mt = jl_gf_mtable(args[0]); + jl_svec_t *tpenv = jl_emptysvec; + jl_tupletype_t *tt = NULL; + JL_GC_PUSH2(&tpenv, &tt); if (method->invokes.unknown != NULL) tm = jl_typemap_assoc_exact(method->invokes, args, nargs, jl_cachearg_offset(mt), world); if (tm) { @@ -2248,7 +2234,7 @@ jl_value_t *jl_gf_invoke(jl_value_t *types0, jl_value_t **args, size_t nargs) if (method->invokes.unknown == NULL) method->invokes.unknown = jl_nothing; - mfunc = cache_method(mt, &method->invokes, entry->func.value, tt, method, world, tpenv, 1); + mfunc = cache_method(mt, &method->invokes, (jl_value_t*)method, tt, method, world, tpenv, 1); JL_UNLOCK(&method->writelock); } JL_GC_POP(); @@ -2303,6 +2289,33 @@ JL_DLLEXPORT jl_value_t *jl_get_invoke_lambda(jl_methtable_t *mt, return (jl_value_t*)mfunc; } +JL_DLLEXPORT jl_value_t *jl_invoke(jl_method_instance_t *meth, jl_value_t **args, uint32_t nargs) +{ + jl_callptr_t fptr = meth->invoke; + if (fptr != jl_fptr_trampoline) { + return fptr(meth, args, nargs); + } + else { + // if this hasn't been inferred (compiled) yet, + // inferring it might not be able to handle the world range + // so we just do a generic apply here + // because that might actually be faster + // since it can go through the unrolled caches for this world + // and if inference is successful, this meth would get updated anyways, + // and we'll get the fast path here next time + + jl_method_instance_t *mfunc = jl_lookup_generic_(args, nargs, + jl_int32hash_fast(jl_return_address()), + jl_get_ptls_states()->world_age); + // check whether `jl_apply_generic` would call the right method + if (mfunc->def.method == meth->def.method) + return mfunc->invoke(mfunc, args, nargs); + + // no; came from an `invoke` call + return jl_gf_invoke_by_method(meth->def.method, args, nargs); + } +} + // Return value is rooted globally jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st, int iskw) { diff --git a/test/core.jl b/test/core.jl index ae9a90b5149cee..d1d4ff40bab0a1 100644 --- a/test/core.jl +++ b/test/core.jl @@ -2432,6 +2432,20 @@ const T24460 = Tuple{T,T} where T g24460() = invoke(f24460, T24460, 1, 2) @test @inferred(g24460()) === 2.0 +# issue #30679 +@noinline function f30679(::DataType) + b = IOBuffer() + write(b, 0x00) + 2 +end +@noinline function f30679(t::Type{Int}) + x = invoke(f30679, Tuple{DataType}, t) + b = IOBuffer() + write(b, 0x00) + return x + 40 +end +@test f30679(Int) == 42 + call_lambda1() = (()->x)(1) call_lambda2() = ((x)->x)() call_lambda3() = ((x)->x)(1,2)