Skip to content

Avoid Promise<Awaited<T>> in return type inference #45925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 40 additions & 29 deletions src/compiler/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25458,8 +25458,8 @@ namespace ts {

if (functionFlags & FunctionFlags.Async) { // Async function or AsyncGenerator function
// Get the awaited type without the `Awaited<T>` alias
const contextualAwaitedType = mapType(contextualReturnType, getAwaitedType);
return contextualAwaitedType && getUnionType([unwrapAwaitedType(contextualAwaitedType), createPromiseLikeType(contextualAwaitedType)]);
const contextualAwaitedType = mapType(contextualReturnType, getAwaitedTypeNoAlias);
return contextualAwaitedType && getUnionType([contextualAwaitedType, createPromiseLikeType(contextualAwaitedType)]);
}

return contextualReturnType; // Regular function or Generator function
Expand All @@ -25471,8 +25471,8 @@ namespace ts {
function getContextualTypeForAwaitOperand(node: AwaitExpression, contextFlags?: ContextFlags): Type | undefined {
const contextualType = getContextualType(node, contextFlags);
if (contextualType) {
const contextualAwaitedType = getAwaitedType(contextualType);
return contextualAwaitedType && getUnionType([unwrapAwaitedType(contextualAwaitedType), createPromiseLikeType(contextualAwaitedType)]);
const contextualAwaitedType = getAwaitedTypeNoAlias(contextualType);
return contextualAwaitedType && getUnionType([contextualAwaitedType, createPromiseLikeType(contextualAwaitedType)]);
}
return undefined;
}
Expand Down Expand Up @@ -31136,7 +31136,8 @@ namespace ts {
const globalPromiseType = getGlobalPromiseType(/*reportErrors*/ true);
if (globalPromiseType !== emptyGenericType) {
// if the promised type is itself a promise, get the underlying type; otherwise, fallback to the promised type
promisedType = getAwaitedType(promisedType) || unknownType;
// Unwrap an `Awaited<T>` to `T` to improve inference.
promisedType = getAwaitedTypeNoAlias(unwrapAwaitedType(promisedType)) || unknownType;
return createTypeReference(globalPromiseType, [promisedType]);
}

Expand All @@ -31148,7 +31149,8 @@ namespace ts {
const globalPromiseLikeType = getGlobalPromiseLikeType(/*reportErrors*/ true);
if (globalPromiseLikeType !== emptyGenericType) {
// if the promised type is itself a promise, get the underlying type; otherwise, fallback to the promised type
promisedType = getAwaitedType(promisedType) || unknownType;
// Unwrap an `Awaited<T>` to `T` to improve inference.
promisedType = getAwaitedTypeNoAlias(unwrapAwaitedType(promisedType)) || unknownType;
return createTypeReference(globalPromiseLikeType, [promisedType]);
}

Expand Down Expand Up @@ -31205,7 +31207,7 @@ namespace ts {
// Promise/A+ compatible implementation will always assimilate any foreign promise, so the
// return type of the body should be unwrapped to its awaited type, which we will wrap in
// the native Promise<T> type later in this function.
returnType = checkAwaitedType(returnType, /*errorNode*/ func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
returnType = unwrapAwaitedType(checkAwaitedType(returnType, /*withAlias*/ false, /*errorNode*/ func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member));
}
}
else if (isGenerator) { // Generator or AsyncGenerator function
Expand Down Expand Up @@ -31438,7 +31440,7 @@ namespace ts {
// Promise/A+ compatible implementation will always assimilate any foreign promise, so the
// return type of the body should be unwrapped to its awaited type, which should be wrapped in
// the native Promise<T> type by the caller.
type = checkAwaitedType(type, func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
type = unwrapAwaitedType(checkAwaitedType(type, /*withAlias*/ false, func, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member));
}
if (type.flags & TypeFlags.Never) {
hasReturnOfTypeNever = true;
Expand Down Expand Up @@ -31640,7 +31642,7 @@ namespace ts {
const returnOrPromisedType = returnType && unwrapReturnType(returnType, functionFlags);
if (returnOrPromisedType) {
if ((functionFlags & FunctionFlags.AsyncGenerator) === FunctionFlags.Async) { // Async function
const awaitedType = checkAwaitedType(exprType, node.body, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
const awaitedType = checkAwaitedType(exprType, /*withAlias*/ false, node.body, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
checkTypeAssignableToAndOptionallyElaborate(awaitedType, returnOrPromisedType, node.body, node.body);
}
else { // Normal function
Expand Down Expand Up @@ -31857,7 +31859,7 @@ namespace ts {
}

const operandType = checkExpression(node.expression);
const awaitedType = checkAwaitedType(operandType, node, Diagnostics.Type_of_await_operand_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
const awaitedType = checkAwaitedType(operandType, /*withAlias*/ true, node, Diagnostics.Type_of_await_operand_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
if (awaitedType === operandType && awaitedType !== errorType && !(operandType.flags & TypeFlags.AnyOrUnknown)) {
addErrorOrSuggestion(/*isError*/ false, createDiagnosticForNode(node, Diagnostics.await_has_no_effect_on_the_type_of_this_expression));
}
Expand Down Expand Up @@ -32809,8 +32811,8 @@ namespace ts {
let wouldWorkWithAwait = false;
const errNode = errorNode || operatorToken;
if (isRelated) {
const awaitedLeftType = unwrapAwaitedType(getAwaitedType(leftType));
const awaitedRightType = unwrapAwaitedType(getAwaitedType(rightType));
const awaitedLeftType = getAwaitedTypeNoAlias(leftType);
const awaitedRightType = getAwaitedTypeNoAlias(rightType);
wouldWorkWithAwait = !(awaitedLeftType === leftType && awaitedRightType === rightType)
&& !!(awaitedLeftType && awaitedRightType)
&& isRelated(awaitedLeftType, awaitedRightType);
Expand Down Expand Up @@ -34892,12 +34894,15 @@ namespace ts {
/**
* Gets the "awaited type" of a type.
* @param type The type to await.
* @param withAlias When `true`, wraps the "awaited type" in `Awaited<T>` if needed.
* @remarks The "awaited type" of an expression is its "promised type" if the expression is a
* Promise-like type; otherwise, it is the type of the expression. This is used to reflect
* The runtime behavior of the `await` keyword.
*/
function checkAwaitedType(type: Type, errorNode: Node, diagnosticMessage: DiagnosticMessage, arg0?: string | number): Type {
const awaitedType = getAwaitedType(type, errorNode, diagnosticMessage, arg0);
function checkAwaitedType(type: Type, withAlias: boolean, errorNode: Node, diagnosticMessage: DiagnosticMessage, arg0?: string | number): Type {
const awaitedType = withAlias ?
getAwaitedType(type, errorNode, diagnosticMessage, arg0) :
getAwaitedTypeNoAlias(type, errorNode, diagnosticMessage, arg0);
return awaitedType || errorType;
}

Expand Down Expand Up @@ -34931,10 +34936,7 @@ namespace ts {
/**
* For a generic `Awaited<T>`, gets `T`.
*/
function unwrapAwaitedType(type: Type): Type;
function unwrapAwaitedType(type: Type | undefined): Type | undefined;
function unwrapAwaitedType(type: Type | undefined) {
if (!type) return undefined;
function unwrapAwaitedType(type: Type) {
return type.flags & TypeFlags.Union ? mapType(type, unwrapAwaitedType) :
isAwaitedTypeInstantiation(type) ? type.aliasTypeArguments[0] :
type;
Expand Down Expand Up @@ -34989,6 +34991,16 @@ namespace ts {
* This is used to reflect the runtime behavior of the `await` keyword.
*/
function getAwaitedType(type: Type, errorNode?: Node, diagnosticMessage?: DiagnosticMessage, arg0?: string | number): Type | undefined {
const awaitedType = getAwaitedTypeNoAlias(type, errorNode, diagnosticMessage, arg0);
return awaitedType && createAwaitedTypeIfNeeded(awaitedType);
}

/**
* Gets the "awaited type" of a type without introducing an `Awaited<T>` wrapper.
*
* @see {@link getAwaitedType}
*/
function getAwaitedTypeNoAlias(type: Type, errorNode?: Node, diagnosticMessage?: DiagnosticMessage, arg0?: string | number): Type | undefined {
if (isTypeAny(type)) {
return type;
}
Expand All @@ -35001,14 +35013,13 @@ namespace ts {
// If we've already cached an awaited type, return a possible `Awaited<T>` for it.
const typeAsAwaitable = type as PromiseOrAwaitableType;
if (typeAsAwaitable.awaitedTypeOfType) {
return createAwaitedTypeIfNeeded(typeAsAwaitable.awaitedTypeOfType);
return typeAsAwaitable.awaitedTypeOfType;
}

// For a union, get a union of the awaited types of each constituent.
if (type.flags & TypeFlags.Union) {
const mapper = errorNode ? (constituentType: Type) => getAwaitedType(constituentType, errorNode, diagnosticMessage, arg0) : getAwaitedType;
typeAsAwaitable.awaitedTypeOfType = mapType(type, mapper);
return typeAsAwaitable.awaitedTypeOfType && createAwaitedTypeIfNeeded(typeAsAwaitable.awaitedTypeOfType);
const mapper = errorNode ? (constituentType: Type) => getAwaitedTypeNoAlias(constituentType, errorNode, diagnosticMessage, arg0) : getAwaitedTypeNoAlias;
return typeAsAwaitable.awaitedTypeOfType = mapType(type, mapper);
}

const promisedType = getPromisedTypeOfPromise(type);
Expand Down Expand Up @@ -35056,14 +35067,14 @@ namespace ts {
// Keep track of the type we're about to unwrap to avoid bad recursive promise types.
// See the comments above for more information.
awaitedTypeStack.push(type.id);
const awaitedType = getAwaitedType(promisedType, errorNode, diagnosticMessage, arg0);
const awaitedType = getAwaitedTypeNoAlias(promisedType, errorNode, diagnosticMessage, arg0);
awaitedTypeStack.pop();

if (!awaitedType) {
return undefined;
}

return createAwaitedTypeIfNeeded(typeAsAwaitable.awaitedTypeOfType = awaitedType);
return typeAsAwaitable.awaitedTypeOfType = awaitedType;
}

// The type was not a promise, so it could not be unwrapped any further.
Expand All @@ -35089,7 +35100,7 @@ namespace ts {
return undefined;
}

return createAwaitedTypeIfNeeded(typeAsAwaitable.awaitedTypeOfType = type);
return typeAsAwaitable.awaitedTypeOfType = type;
}

/**
Expand Down Expand Up @@ -35139,7 +35150,7 @@ namespace ts {
if (globalPromiseType !== emptyGenericType && !isReferenceToType(returnType, globalPromiseType)) {
// The promise type was not a valid type reference to the global promise type, so we
// report an error and return the unknown type.
error(returnTypeNode, Diagnostics.The_return_type_of_an_async_function_or_method_must_be_the_global_Promise_T_type_Did_you_mean_to_write_Promise_0, typeToString(unwrapAwaitedType(getAwaitedType(returnType)) || voidType));
error(returnTypeNode, Diagnostics.The_return_type_of_an_async_function_or_method_must_be_the_global_Promise_T_type_Did_you_mean_to_write_Promise_0, typeToString(getAwaitedTypeNoAlias(returnType) || voidType));
return;
}
}
Expand Down Expand Up @@ -35192,7 +35203,7 @@ namespace ts {
return;
}
}
checkAwaitedType(returnType, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
checkAwaitedType(returnType, /*withAlias*/ false, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member);
}

/** Check a decorator */
Expand Down Expand Up @@ -37473,7 +37484,7 @@ namespace ts {
const isGenerator = !!(functionFlags & FunctionFlags.Generator);
const isAsync = !!(functionFlags & FunctionFlags.Async);
return isGenerator ? getIterationTypeOfGeneratorFunctionReturnType(IterationTypeKind.Return, returnType, isAsync) || errorType :
isAsync ? unwrapAwaitedType(getAwaitedType(returnType)) || errorType :
isAsync ? getAwaitedTypeNoAlias(returnType) || errorType :
returnType;
}

Expand Down Expand Up @@ -37517,7 +37528,7 @@ namespace ts {
else if (getReturnTypeFromAnnotation(container)) {
const unwrappedReturnType = unwrapReturnType(returnType, functionFlags) ?? returnType;
const unwrappedExprType = functionFlags & FunctionFlags.Async
? checkAwaitedType(exprType, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member)
? checkAwaitedType(exprType, /*withAlias*/ false, node, Diagnostics.The_return_type_of_an_async_function_must_either_be_a_valid_promise_or_must_not_contain_a_callable_then_member)
: exprType;
if (unwrappedReturnType) {
// If the function has a return type, but promisedType is
Expand Down
16 changes: 16 additions & 0 deletions tests/baselines/reference/awaitedTypeStrictNull.errors.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,22 @@ tests/cases/compiler/awaitedTypeStrictNull.ts(22,12): error TS2589: Type instant
])
}

// https://github.com/microsoft/TypeScript/issues/45924
class Api<D = {}> {
// Should result in `Promise<T>` instead of `Promise<Awaited<T>>`.
async post<T = D>() { return this.request<T>(); }
async request<D>(): Promise<D> { throw new Error(); }
}

declare const api: Api;
interface Obj { x: number }

async function fn<T>(): Promise<T extends object ? { [K in keyof T]: Obj } : Obj> {
// Per #45924, this was failing due to incorrect inference both above and here.
// Should not error.
return api.post();
}

// helps with tests where '.types' just prints out the type alias name
type _Expect<TActual extends TExpected, TExpected> = TActual;

27 changes: 27 additions & 0 deletions tests/baselines/reference/awaitedTypeStrictNull.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,22 @@ async function main() {
])
}

// https://github.com/microsoft/TypeScript/issues/45924
class Api<D = {}> {
// Should result in `Promise<T>` instead of `Promise<Awaited<T>>`.
async post<T = D>() { return this.request<T>(); }
async request<D>(): Promise<D> { throw new Error(); }
}

declare const api: Api;
interface Obj { x: number }

async function fn<T>(): Promise<T extends object ? { [K in keyof T]: Obj } : Obj> {
// Per #45924, this was failing due to incorrect inference both above and here.
// Should not error.
return api.post();
}

// helps with tests where '.types' just prints out the type alias name
type _Expect<TActual extends TExpected, TExpected> = TActual;

Expand All @@ -56,3 +72,14 @@ async function main() {
MaybePromise(true),
]);
}
// https://github.com/microsoft/TypeScript/issues/45924
class Api {
// Should result in `Promise<T>` instead of `Promise<Awaited<T>>`.
async post() { return this.request(); }
async request() { throw new Error(); }
}
async function fn() {
// Per #45924, this was failing due to incorrect inference both above and here.
// Should not error.
return api.post();
}
Loading