From d7345da7cffbe53b4855ff4b6bce7e571bf5b2fb Mon Sep 17 00:00:00 2001 From: Christian van der Loo Date: Mon, 26 Aug 2024 21:10:47 -0400 Subject: [PATCH] Fix Immer type inference for `setState` (#2696) * fix(immer): tweak type inference to base `setState` type off of store `setState` instead of `getState` * fix(immer): instead, infer type directly from StoreApi["setState"] * fix(immer): instead of using `StoreApi`, extract from A2 the non-functional component of state * docs: add comment describing why it is not derived from `A1` * test: add example middleware that modifies getState w/o setState * fix: add assertion for inner `set` and `get` types --------- Co-authored-by: Daishi Kato --- src/middleware/immer.ts | 15 +++++++-- tests/middlewareTypes.test.tsx | 58 ++++++++++++++++++++++++++++++++-- 2 files changed, 68 insertions(+), 5 deletions(-) diff --git a/src/middleware/immer.ts b/src/middleware/immer.ts index 7d692d760d..9b79fa1660 100644 --- a/src/middleware/immer.ts +++ b/src/middleware/immer.ts @@ -32,10 +32,11 @@ type SkipTwo = T extends { length: 0 } ? A : never +type SetStateType = Exclude any> + type WithImmer = Write> type StoreImmer = S extends { - getState: () => infer T setState: infer SetState } ? SetState extends { @@ -43,13 +44,21 @@ type StoreImmer = S extends { (...a: infer A2): infer Sr2 } ? { + // Ideally, we would want to infer the `nextStateOrUpdater` `T` type from the + // `A1` type, but this is infeasible since it is an intersection with + // a partial type. setState( - nextStateOrUpdater: T | Partial | ((state: Draft) => void), + nextStateOrUpdater: + | SetStateType + | Partial> + | ((state: Draft>) => void), shouldReplace?: false, ...a: SkipTwo ): Sr1 setState( - nextStateOrUpdater: T | ((state: Draft) => void), + nextStateOrUpdater: + | SetStateType + | ((state: Draft>) => void), shouldReplace: true, ...a: SkipTwo ): Sr2 diff --git a/tests/middlewareTypes.test.tsx b/tests/middlewareTypes.test.tsx index 0da6b12316..d38fa40e6e 100644 --- a/tests/middlewareTypes.test.tsx +++ b/tests/middlewareTypes.test.tsx @@ -1,9 +1,9 @@ /* eslint @typescript-eslint/no-unused-expressions: off */ // FIXME /* eslint react-compiler/react-compiler: off */ -import { describe, expect, it } from 'vitest' +import { describe, expect, expectTypeOf, it } from 'vitest' import { create } from 'zustand' -import type { StoreApi } from 'zustand' +import type { StateCreator, StoreApi, StoreMutatorIdentifier } from 'zustand' import { combine, devtools, @@ -19,6 +19,27 @@ type CounterState = { inc: () => void } +type ExampleStateCreator = < + Mps extends [StoreMutatorIdentifier, unknown][] = [], + Mcs extends [StoreMutatorIdentifier, unknown][] = [], + U = T, +>( + f: StateCreator, +) => StateCreator + +type Write = Omit & U +type StoreModifyAllButSetState = S extends { + getState: () => infer T +} + ? Omit, 'setState'> + : never + +declare module 'zustand/vanilla' { + interface StoreMutators { + 'org/example': Write> + } +} + describe('counter state spec (no middleware)', () => { it('no middleware', () => { const useBoundStore = create((set, get) => ({ @@ -64,6 +85,39 @@ describe('counter state spec (single middleware)', () => { immer(() => ({ count: 0 })), ) expect(testSubtyping).toBeDefined() + + const exampleMiddleware = ((initializer) => + initializer) as ExampleStateCreator + + const testDerivedSetStateType = create()( + exampleMiddleware( + immer((set, get) => ({ + count: 0, + inc: () => + set((state) => { + state.count = get().count + 1 + type OmitFn = Exclude any> + expectTypeOf< + OmitFn[0]> + >().not.toMatchTypeOf<{ additional: number }>() + expectTypeOf>().toMatchTypeOf<{ + additional: number + }>() + }), + })), + ), + ) + expect(testDerivedSetStateType).toBeDefined() + // the type of the `getState` should include our new property + expectTypeOf(testDerivedSetStateType.getState()).toMatchTypeOf<{ + additional: number + }>() + // the type of the `setState` should not include our new property + expectTypeOf< + Parameters[0] + >().not.toMatchTypeOf<{ + additional: number + }>() }) it('redux', () => {