diff --git a/packages/expect/src/jest-extend.ts b/packages/expect/src/jest-extend.ts index 0c2e2608c733..a7d9795668ff 100644 --- a/packages/expect/src/jest-extend.ts +++ b/packages/expect/src/jest-extend.ts @@ -17,9 +17,6 @@ import { subsetEquality, } from './jest-utils' -const isAsyncFunction = (fn: unknown) => - typeof fn === 'function' && (fn as any)[Symbol.toStringTag] === 'AsyncFunction' - const getMatcherState = (assertion: Chai.AssertionStatic & Chai.Assertion, expect: Vi.ExpectStatic) => { const obj = assertion._obj const isNot = util.flag(assertion, 'negate') as boolean @@ -56,30 +53,27 @@ class JestExtendError extends Error { function JestExtendPlugin(expect: Vi.ExpectStatic, matchers: MatchersObject): ChaiPlugin { return (c, utils) => { Object.entries(matchers).forEach(([expectAssertionName, expectAssertion]) => { - function expectSyncWrapper(this: Chai.AssertionStatic & Chai.Assertion, ...args: any[]) { + function expectWrapper(this: Chai.AssertionStatic & Chai.Assertion, ...args: any[]) { const { state, isNot, obj } = getMatcherState(this, expect) // @ts-expect-error args wanting tuple - const { pass, message, actual, expected } = expectAssertion.call(state, obj, ...args) as SyncExpectationResult - - if ((pass && isNot) || (!pass && !isNot)) - throw new JestExtendError(message(), actual, expected) - } + const result = expectAssertion.call(state, obj, ...args) - async function expectAsyncWrapper(this: Chai.AssertionStatic & Chai.Assertion, ...args: any[]) { - const { state, isNot, obj } = getMatcherState(this, expect) + if (result && typeof result === 'object' && result instanceof Promise) { + return result.then(({ pass, message, actual, expected }) => { + if ((pass && isNot) || (!pass && !isNot)) + throw new JestExtendError(message(), actual, expected) + }) + } - // @ts-expect-error args wanting tuple - const { pass, message, actual, expected } = await expectAssertion.call(state, obj, ...args) as SyncExpectationResult + const { pass, message, actual, expected } = result if ((pass && isNot) || (!pass && !isNot)) throw new JestExtendError(message(), actual, expected) } - const expectAssertionWrapper = isAsyncFunction(expectAssertion) ? expectAsyncWrapper : expectSyncWrapper - - utils.addMethod((globalThis as any)[JEST_MATCHERS_OBJECT].matchers, expectAssertionName, expectAssertionWrapper) - utils.addMethod(c.Assertion.prototype, expectAssertionName, expectAssertionWrapper) + utils.addMethod((globalThis as any)[JEST_MATCHERS_OBJECT].matchers, expectAssertionName, expectWrapper) + utils.addMethod(c.Assertion.prototype, expectAssertionName, expectWrapper) class CustomMatcher extends AsymmetricMatcher<[unknown, ...unknown[]]> { constructor(inverse = false, ...sample: [unknown, ...unknown[]]) { diff --git a/test/core/test/jest-expect.test.ts b/test/core/test/jest-expect.test.ts index f6d763cd903b..ae44403dc0e2 100644 --- a/test/core/test/jest-expect.test.ts +++ b/test/core/test/jest-expect.test.ts @@ -8,6 +8,9 @@ class TestError extends Error {} // For expect.extend interface CustomMatchers { toBeDividedBy(divisor: number): R + toBeTestedAsync(): Promise + toBeTestedSync(): R + toBeTestedPromise(): R } declare global { namespace Vi { @@ -142,7 +145,7 @@ describe('jest-expect', () => { expect(['Bob', 'Eve']).toEqual(expect.not.arrayContaining(['Steve'])) }) - it('expect.extend', () => { + it('expect.extend', async () => { expect.extend({ toBeDividedBy(received, divisor) { const pass = received % divisor === 0 @@ -161,6 +164,24 @@ describe('jest-expect', () => { } } }, + async toBeTestedAsync() { + return { + pass: false, + message: () => 'toBeTestedAsync', + } + }, + toBeTestedSync() { + return { + pass: false, + message: () => 'toBeTestedSync', + } + }, + toBeTestedPromise() { + return Promise.resolve({ + pass: false, + message: () => 'toBeTestedPromise', + }) + }, }) expect(5).toBeDividedBy(5) @@ -169,6 +190,11 @@ describe('jest-expect', () => { one: expect.toBeDividedBy(1), two: expect.not.toBeDividedBy(5), }) + expect(() => expect(2).toBeDividedBy(5)).toThrowError() + + expect(() => expect(null).toBeTestedSync()).toThrowError('toBeTestedSync') + await expect(async () => await expect(null).toBeTestedAsync()).rejects.toThrowError('toBeTestedAsync') + await expect(async () => await expect(null).toBeTestedPromise()).rejects.toThrowError('toBeTestedPromise') }) it('object', () => {