diff --git a/src/compiler/factory.ts b/src/compiler/factory.ts index f88950f531a37..1f0294423a5d9 100644 --- a/src/compiler/factory.ts +++ b/src/compiler/factory.ts @@ -4701,7 +4701,7 @@ namespace ts { } } - function getLeftmostExpression(node: Expression, stopAtCallExpressions: boolean) { + export function getLeftmostExpression(node: Expression, stopAtCallExpressions: boolean) { while (true) { switch (node.kind) { case SyntaxKind.PostfixUnaryExpression: diff --git a/src/services/codefixes/removeUnnecessaryAwait.ts b/src/services/codefixes/removeUnnecessaryAwait.ts index c235d028efbce..07028e2793762 100644 --- a/src/services/codefixes/removeUnnecessaryAwait.ts +++ b/src/services/codefixes/removeUnnecessaryAwait.ts @@ -26,8 +26,18 @@ namespace ts.codefix { return; } - const parenthesizedExpression = tryCast(awaitExpression.parent, isParenthesizedExpression); - const removeParens = parenthesizedExpression && (isIdentifier(awaitExpression.expression) || isCallExpression(awaitExpression.expression)); - changeTracker.replaceNode(sourceFile, removeParens ? parenthesizedExpression || awaitExpression : awaitExpression, awaitExpression.expression); + let expressionToReplace: Node = awaitExpression; + const hasSurroundingParens = isParenthesizedExpression(awaitExpression.parent); + if (hasSurroundingParens) { + const leftMostExpression = getLeftmostExpression(awaitExpression.expression, /*stopAtCallExpressions*/ false); + if (isIdentifier(leftMostExpression)) { + const precedingToken = findPrecedingToken(awaitExpression.parent.pos, sourceFile); + if (precedingToken && precedingToken.kind !== SyntaxKind.NewKeyword) { + expressionToReplace = awaitExpression.parent; + } + } + } + + changeTracker.replaceNode(sourceFile, expressionToReplace, awaitExpression.expression); } } diff --git a/tests/cases/fourslash/codeFixRemoveUnnecessaryAwait.ts b/tests/cases/fourslash/codeFixRemoveUnnecessaryAwait.ts index b8e1de4605e27..3e209cd59f208 100644 --- a/tests/cases/fourslash/codeFixRemoveUnnecessaryAwait.ts +++ b/tests/cases/fourslash/codeFixRemoveUnnecessaryAwait.ts @@ -1,5 +1,6 @@ /// ////declare class C { foo(): void } +////declare function getC(): { Class: C }; ////declare function foo(): string; ////async function f() { //// await ""; @@ -7,6 +8,8 @@ //// (await foo()).toLowerCase(); //// (await 0).toFixed(); //// (await new C).foo(); +//// (await function() { }()); +//// new (await getC()).Class(); ////} verify.codeFix({ @@ -14,6 +17,7 @@ verify.codeFix({ index: 0, newFileContent: `declare class C { foo(): void } +declare function getC(): { Class: C }; declare function foo(): string; async function f() { ""; @@ -21,6 +25,8 @@ async function f() { (await foo()).toLowerCase(); (await 0).toFixed(); (await new C).foo(); + (await function() { }()); + new (await getC()).Class(); }` }); @@ -29,6 +35,7 @@ verify.codeFixAll({ fixId: "removeUnnecessaryAwait", newFileContent: `declare class C { foo(): void } +declare function getC(): { Class: C }; declare function foo(): string; async function f() { ""; @@ -36,5 +43,7 @@ async function f() { foo().toLowerCase(); (0).toFixed(); (new C).foo(); + (function() { } ()); + new (getC()).Class(); }` });