diff --git a/src/coreclr/jit/async.cpp b/src/coreclr/jit/async.cpp index e685e632aef289..ec92f9b23ca6a7 100644 --- a/src/coreclr/jit/async.cpp +++ b/src/coreclr/jit/async.cpp @@ -169,6 +169,8 @@ PhaseStatus Compiler::SaveAsyncContexts() // Await is inside a try, need to insert try-finally around it. restoreBB = InsertTryFinallyForContextRestore(curBB, stmt, restoreAfterStmt); restoreAfterStmt = nullptr; + // we have split the block that could have another await. + nextBB = restoreBB->Next(); #endif } @@ -1506,8 +1508,9 @@ void AsyncTransformation::FillInGCPointersOnSuspension(GenTreeCall* if (layout.ContinuationContextGCDataIndex != UINT_MAX) { const AsyncCallInfo& callInfo = call->GetAsyncInfo(); - assert(callInfo.SaveAndRestoreSynchronizationContextField && - (callInfo.SynchronizationContextLclNum != BAD_VAR_NUM)); + assert(callInfo.SaveAndRestoreSynchronizationContextField); + assert(callInfo.ExecutionContextHandling == ExecutionContextHandling::SaveAndRestore); + assert(callInfo.SynchronizationContextLclNum != BAD_VAR_NUM); // Insert call // AsyncHelpers.CaptureContinuationContext( diff --git a/src/coreclr/jit/codegenxarch.cpp b/src/coreclr/jit/codegenxarch.cpp index bf81042a88d841..bde2493ab0d137 100644 --- a/src/coreclr/jit/codegenxarch.cpp +++ b/src/coreclr/jit/codegenxarch.cpp @@ -106,14 +106,21 @@ void CodeGen::genEmitGSCookieCheck(bool pushReg) // we are generating GS cookie check after a GT_RETURN block. // Note: On Amd64 System V RDX is an arg register - REG_ARG_2 - as well // as return register for two-register-returned structs. +#ifdef TARGET_X86 + // Note: ARG_0 can be REG_ASYNC_CONTINUATION_RET + // we will check for that later if we end up saving/restoring this. + regGSCheck = REG_ARG_0; + regNumber regGSCheckAlternative = REG_ARG_1; +#else + // these cannot be a part of any kind of return + regGSCheck = REG_R8; + regNumber regGSCheckAlternative = REG_R9; +#endif + if (compiler->lvaKeepAliveAndReportThis() && compiler->lvaGetDesc(compiler->info.compThisArg)->lvIsInReg() && - (compiler->lvaGetDesc(compiler->info.compThisArg)->GetRegNum() == REG_ARG_0)) + (compiler->lvaGetDesc(compiler->info.compThisArg)->GetRegNum() == regGSCheck)) { - regGSCheck = REG_ARG_1; - } - else - { - regGSCheck = REG_ARG_0; + regGSCheck = regGSCheckAlternative; } } else @@ -158,6 +165,14 @@ void CodeGen::genEmitGSCookieCheck(bool pushReg) { // AOT case - GS cookie value needs to be accessed through an indirection. + // if we use the continuation reg, the pop/push requires no-GC + // this can happen only when AOT supports async on x86 + if (compiler->compIsAsync() && (regGSCheck == REG_ASYNC_CONTINUATION_RET)) + { + regMaskGSCheck = RBM_ASYNC_CONTINUATION_RET; + GetEmitter()->emitDisableGC(); + } + pushedRegs = genPushRegs(regMaskGSCheck, &byrefPushedRegs, &norefPushedRegs); instGen_Set_Reg_To_Imm(EA_HANDLE_CNS_RELOC, regGSCheck, (ssize_t)compiler->gsGlobalSecurityCookieAddr); diff --git a/src/coreclr/vm/asyncthunks.cpp b/src/coreclr/vm/asyncthunks.cpp index 941c0e08b44942..bf2a385cd6d8ec 100644 --- a/src/coreclr/vm/asyncthunks.cpp +++ b/src/coreclr/vm/asyncthunks.cpp @@ -338,6 +338,46 @@ SigPointer MethodDesc::GetAsyncThunkResultTypeSig() return SigPointer(returnTypeSig, (DWORD)(returnTypeSigEnd - returnTypeSig)); } +bool MethodDesc::IsValueTaskAsyncThunk() +{ + _ASSERTE(IsAsyncThunkMethod()); + PCCOR_SIGNATURE pSigRaw; + DWORD cSig; + if (FAILED(GetMDImport()->GetSigOfMethodDef(GetMemberDef(), &cSig, &pSigRaw))) + { + _ASSERTE(!"Loaded MethodDesc should not fail to get signature"); + pSigRaw = NULL; + cSig = 0; + } + + SigPointer pSig(pSigRaw, cSig); + uint32_t callConvInfo; + IfFailThrow(pSig.GetCallingConvInfo(&callConvInfo)); + + if ((callConvInfo & IMAGE_CEE_CS_CALLCONV_GENERIC) != 0) + { + // GenParamCount + IfFailThrow(pSig.GetData(NULL)); + } + + // ParamCount + IfFailThrow(pSig.GetData(NULL)); + + // ReturnType comes now. Skip the modifiers. + IfFailThrow(pSig.SkipCustomModifiers()); + + // here we should have something Task, ValueTask, Task or ValueTask + BYTE bElementType; + IfFailThrow(pSig.GetByte(&bElementType)); + + // skip ELEMENT_TYPE_GENERICINST + if (bElementType == ELEMENT_TYPE_GENERICINST) + IfFailThrow(pSig.GetByte(&bElementType)); + + _ASSERTE(bElementType == ELEMENT_TYPE_VALUETYPE || bElementType == ELEMENT_TYPE_CLASS); + return bElementType == ELEMENT_TYPE_VALUETYPE; +} + // Given a method Foo, return a MethodSpec token for Foo instantiated // with the result type from the current async method's return type. For // example, if "this" represents Task> Foo(), and "md" is @@ -435,30 +475,32 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pAsyncOtherVariant, MetaSig& m MethodDesc* mdIsCompleted; MethodDesc* mdGetResult; + bool isValueTask = IsValueTaskAsyncThunk(); + if (msig.IsReturnTypeVoid()) { - pMTTask = CoreLibBinder::GetClass(CLASS__TASK); - thTaskAwaiter = CoreLibBinder::GetClass(CLASS__TASK_AWAITER); - mdGetAwaiter = CoreLibBinder::GetMethod(METHOD__TASK__GET_AWAITER); - mdIsCompleted = CoreLibBinder::GetMethod(METHOD__TASK_AWAITER__GET_ISCOMPLETED); - mdGetResult = CoreLibBinder::GetMethod(METHOD__TASK_AWAITER__GET_RESULT); + pMTTask = CoreLibBinder::GetClass(isValueTask ? CLASS__VALUETASK : CLASS__TASK); + thTaskAwaiter = CoreLibBinder::GetClass(isValueTask ? CLASS__VALUETASK_AWAITER : CLASS__TASK_AWAITER); + mdGetAwaiter = CoreLibBinder::GetMethod(isValueTask ? METHOD__VALUETASK__GET_AWAITER : METHOD__TASK__GET_AWAITER); + mdIsCompleted = CoreLibBinder::GetMethod(isValueTask ? METHOD__VALUETASK_AWAITER__GET_ISCOMPLETED : METHOD__TASK_AWAITER__GET_ISCOMPLETED); + mdGetResult = CoreLibBinder::GetMethod(isValueTask ? METHOD__VALUETASK_AWAITER__GET_RESULT : METHOD__TASK_AWAITER__GET_RESULT); } else { TypeHandle thLogicalRetType = msig.GetRetTypeHandleThrowing(); - MethodTable* pMTTaskOpen = CoreLibBinder::GetClass(CLASS__TASK_1); + MethodTable* pMTTaskOpen = CoreLibBinder::GetClass(isValueTask ? CLASS__VALUETASK_1 : CLASS__TASK_1); pMTTask = ClassLoader::LoadGenericInstantiationThrowing(pMTTaskOpen->GetModule(), pMTTaskOpen->GetCl(), Instantiation(&thLogicalRetType, 1)).GetMethodTable(); - MethodTable* pMTTaskAwaiterOpen = CoreLibBinder::GetClass(CLASS__TASK_AWAITER_1); + MethodTable* pMTTaskAwaiterOpen = CoreLibBinder::GetClass(isValueTask ? CLASS__VALUETASK_AWAITER_1 : CLASS__TASK_AWAITER_1); thTaskAwaiter = ClassLoader::LoadGenericInstantiationThrowing(pMTTaskAwaiterOpen->GetModule(), pMTTaskAwaiterOpen->GetCl(), Instantiation(&thLogicalRetType, 1)); - mdGetAwaiter = CoreLibBinder::GetMethod(METHOD__TASK_1__GET_AWAITER); + mdGetAwaiter = CoreLibBinder::GetMethod(isValueTask ? METHOD__VALUETASK_1__GET_AWAITER : METHOD__TASK_1__GET_AWAITER); mdGetAwaiter = MethodDesc::FindOrCreateAssociatedMethodDesc(mdGetAwaiter, pMTTask, FALSE, Instantiation(), FALSE); - mdIsCompleted = CoreLibBinder::GetMethod(METHOD__TASK_AWAITER_1__GET_ISCOMPLETED); + mdIsCompleted = CoreLibBinder::GetMethod(isValueTask ? METHOD__VALUETASK_AWAITER_1__GET_ISCOMPLETED : METHOD__TASK_AWAITER_1__GET_ISCOMPLETED); mdIsCompleted = MethodDesc::FindOrCreateAssociatedMethodDesc(mdIsCompleted, thTaskAwaiter.GetMethodTable(), FALSE, Instantiation(), FALSE); - mdGetResult = CoreLibBinder::GetMethod(METHOD__TASK_AWAITER_1__GET_RESULT); + mdGetResult = CoreLibBinder::GetMethod(isValueTask ? METHOD__VALUETASK_AWAITER_1__GET_RESULT : METHOD__TASK_AWAITER_1__GET_RESULT); mdGetResult = MethodDesc::FindOrCreateAssociatedMethodDesc(mdGetResult, thTaskAwaiter.GetMethodTable(), FALSE, Instantiation(), FALSE); } @@ -550,7 +592,19 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pAsyncOtherVariant, MetaSig& m getResultToken = pCode->GetToken(mdGetResult); } - pCode->EmitCALLVIRT(getAwaiterToken, 1, 1); + if (isValueTask) + { + LocalDesc valuetaskLocalDesc(pMTTask); + DWORD valuetaskLocal = pCode->NewLocal(valuetaskLocalDesc); + pCode->EmitSTLOC(valuetaskLocal); + pCode->EmitLDLOCA(valuetaskLocal); + pCode->EmitCALL(getAwaiterToken, 1, 1); + } + else + { + pCode->EmitCALLVIRT(getAwaiterToken, 1, 1); + } + pCode->EmitSTLOC(awaiterLocal); pCode->EmitLDLOCA(awaiterLocal); pCode->EmitCALL(getIsCompletedToken, 1, 1); diff --git a/src/coreclr/vm/corelib.h b/src/coreclr/vm/corelib.h index 0b979353253f72..e188c5e1dff185 100644 --- a/src/coreclr/vm/corelib.h +++ b/src/coreclr/vm/corelib.h @@ -344,12 +344,14 @@ DEFINE_CLASS(THREAD_START_EXCEPTION,Threading, ThreadStartException DEFINE_METHOD(THREAD_START_EXCEPTION,EX_CTOR, .ctor, IM_Exception_RetVoid) DEFINE_CLASS(VALUETASK_1, Tasks, ValueTask`1) +DEFINE_METHOD(VALUETASK_1, GET_AWAITER, GetAwaiter, NoSig) DEFINE_CLASS(VALUETASK, Tasks, ValueTask) DEFINE_METHOD(VALUETASK, FROM_EXCEPTION, FromException, SM_Exception_RetValueTask) DEFINE_METHOD(VALUETASK, FROM_EXCEPTION_1, FromException, GM_Exception_RetValueTaskOfT) DEFINE_METHOD(VALUETASK, FROM_RESULT_T, FromResult, GM_T_RetValueTaskOfT) DEFINE_METHOD(VALUETASK, GET_COMPLETED_TASK, get_CompletedTask, SM_RetValueTask) +DEFINE_METHOD(VALUETASK, GET_AWAITER, GetAwaiter, NoSig) DEFINE_CLASS(TASK_1, Tasks, Task`1) DEFINE_METHOD(TASK_1, GET_AWAITER, GetAwaiter, NoSig) @@ -369,6 +371,14 @@ DEFINE_CLASS(TASK_AWAITER, CompilerServices, TaskAwaiter) DEFINE_METHOD(TASK_AWAITER, GET_ISCOMPLETED, get_IsCompleted, NoSig) DEFINE_METHOD(TASK_AWAITER, GET_RESULT, GetResult, NoSig) +DEFINE_CLASS(VALUETASK_AWAITER_1, CompilerServices, ValueTaskAwaiter`1) +DEFINE_METHOD(VALUETASK_AWAITER_1, GET_ISCOMPLETED, get_IsCompleted, NoSig) +DEFINE_METHOD(VALUETASK_AWAITER_1, GET_RESULT, GetResult, NoSig) + +DEFINE_CLASS(VALUETASK_AWAITER, CompilerServices, ValueTaskAwaiter) +DEFINE_METHOD(VALUETASK_AWAITER, GET_ISCOMPLETED, get_IsCompleted, NoSig) +DEFINE_METHOD(VALUETASK_AWAITER, GET_RESULT, GetResult, NoSig) + DEFINE_CLASS(TYPE_HANDLE, System, RuntimeTypeHandle) DEFINE_CLASS(RT_TYPE_HANDLE, System, RuntimeTypeHandle) DEFINE_METHOD(RT_TYPE_HANDLE, PVOID_CTOR, .ctor, IM_RuntimeType_RetVoid) diff --git a/src/coreclr/vm/method.hpp b/src/coreclr/vm/method.hpp index 07683f725f4a8d..9e272f64457e54 100644 --- a/src/coreclr/vm/method.hpp +++ b/src/coreclr/vm/method.hpp @@ -2099,6 +2099,7 @@ class MethodDesc void EmitTaskReturningThunk(MethodDesc* pAsyncOtherVariant, MetaSig& thunkMsig, ILStubLinker* pSL); void EmitAsyncMethodThunk(MethodDesc* pAsyncOtherVariant, MetaSig& msig, ILStubLinker* pSL); SigPointer GetAsyncThunkResultTypeSig(); + bool IsValueTaskAsyncThunk(); int GetTokenForGenericMethodCallWithAsyncReturnType(ILCodeStream* pCode, MethodDesc* md); int GetTokenForGenericTypeMethodCallWithAsyncReturnType(ILCodeStream* pCode, MethodDesc* md); int GetTokenForAwaitAwaiterInstantiatedOverTaskAwaiterType(ILCodeStream* pCode, TypeHandle taskAwaiterType);