diff --git a/src/vanilla/store.ts b/src/vanilla/store.ts index d5010a8ccf1..42c0c2a4392 100644 --- a/src/vanilla/store.ts +++ b/src/vanilla/store.ts @@ -104,6 +104,8 @@ type AtomState<Value = AnyValue> = { v?: Value /** Atom error */ e?: AnyError + /** Indicates whether the atom value is has been changed */ + x?: boolean } const isAtomStateInitialized = <Value>(atomState: AtomState<Value>) => @@ -167,7 +169,14 @@ type Pending = readonly [ functions: Set<() => void>, ] -const createPending = (): Pending => [new Map(), new Map(), new Set()] +const createPending = (): Pending => [ + /** dependents */ + new Map(), + /** atomStates */ + new Map(), + /** functions */ + new Set(), +] const addPendingAtom = ( pending: Pending, @@ -198,33 +207,6 @@ const addPendingFunction = (pending: Pending, fn: () => void) => { pending[2].add(fn) } -const flushPending = (pending: Pending) => { - let error: AnyError - let hasError = false - const call = (fn: () => void) => { - try { - fn() - } catch (e) { - if (!hasError) { - error = e - hasError = true - } - } - } - while (pending[1].size || pending[2].size) { - pending[0].clear() - const atomStates = new Set(pending[1].values()) - pending[1].clear() - const functions = new Set(pending[2]) - pending[2].clear() - atomStates.forEach((atomState) => atomState.m?.l.forEach(call)) - functions.forEach(call) - } - if (hasError) { - throw error - } -} - // internal & unstable type type StoreArgs = readonly [ getAtomState: <Value>(atom: Atom<Value>) => AtomState<Value>, @@ -276,6 +258,33 @@ const buildStore = ( debugMountedAtoms = new Set() } + const flushPending = (pending: Pending) => { + let error: AnyError + let hasError = false + const call = (fn: () => void) => { + try { + fn() + } catch (e) { + if (!hasError) { + error = e + hasError = true + } + } + } + while (pending[0].size || pending[1].size || pending[2].size) { + recomputeDependents(pending, new Set(pending[0].keys())) + const atomStates = new Set(pending[1].values()) + pending[1].clear() + const functions = new Set(pending[2]) + pending[2].clear() + atomStates.forEach((atomState) => atomState.m?.l.forEach(call)) + functions.forEach(call) + } + if (hasError) { + throw error + } + } + const setAtomStateValueOrPromise = ( atom: AnyAtom, atomState: AtomState, @@ -306,7 +315,6 @@ const buildStore = ( const readAtomState = <Value>( pending: Pending | undefined, atom: Atom<Value>, - dirtyAtoms?: Set<AnyAtom>, ): AtomState<Value> => { const atomState = getAtomState(atom) // See if we can skip recomputing this atom. @@ -314,7 +322,7 @@ const buildStore = ( // If the atom is mounted, we can use cached atom state. // because it should have been updated by dependencies. // We can't use the cache if the atom is dirty. - if (atomState.m && !dirtyAtoms?.has(atom)) { + if (atomState.m && !atomState.x) { return atomState } // Otherwise, check if the dependencies have changed. @@ -324,7 +332,7 @@ const buildStore = ( ([a, n]) => // Recursively, read the atom state of the dependency, and // check if the atom epoch number is unchanged - readAtomState(pending, a, dirtyAtoms).n === n, + readAtomState(pending, a).n === n, ) ) { return atomState @@ -347,7 +355,7 @@ const buildStore = ( return returnAtomValue(aState) } // a !== atom - const aState = readAtomState(pending, a, dirtyAtoms) + const aState = readAtomState(pending, a) try { return returnAtomValue(aState) } finally { @@ -418,57 +426,103 @@ const buildStore = ( const readAtom = <Value>(atom: Atom<Value>): Value => returnAtomValue(readAtomState(undefined, atom)) - const getMountedOrPendingDependents = <Value>( + const markRecomputePending = ( pending: Pending, - atom: Atom<Value>, - atomState: AtomState<Value>, - ): Map<AnyAtom, AtomState> => { - const dependents = new Map<AnyAtom, AtomState>() - for (const a of atomState.m?.t || []) { + atom: AnyAtom, + atomState: AtomState, + ) => { + addPendingAtom(pending, atom, atomState) + if (isPendingRecompute(atom)) { + return + } + const dependents = getAllDependents(pending, [atom]) + for (const [dependent] of dependents) { + getAtomState(dependent).x = true + } + } + + const markRecomputeComplete = ( + pending: Pending, + atom: AnyAtom, + atomState: AtomState, + ) => { + atomState.x = false + pending[0].delete(atom) + } + + const isPendingRecompute = (atom: AnyAtom) => getAtomState(atom).x + + const getMountedDependents = ( + pending: Pending, + a: AnyAtom, + aState: AtomState, + ) => { + return new Set<AnyAtom>( + [ + ...(aState.m?.t || []), + ...aState.p, + ...(getPendingDependents(pending, a) || []), + ].filter((a) => getAtomState(a).m), + ) + } + + /** @returns map of all dependents or dependencies (deep) of the root atoms */ + const getDeep = ( + /** function to get immediate dependents or dependencies of the atom */ + getDeps: (a: AnyAtom, aState: AtomState) => Iterable<AnyAtom>, + rootAtoms: Iterable<AnyAtom>, + ) => { + const visited = new Map<AnyAtom, Set<AnyAtom>>() + const stack: AnyAtom[] = Array.from(rootAtoms) + while (stack.length > 0) { + const a = stack.pop()! const aState = getAtomState(a) - if (aState.m) { - dependents.set(a, aState) + if (visited.has(a)) { + continue + } + const deps = new Set(getDeps(a, aState)) + visited.set(a, deps) + for (const d of deps) { + if (!visited.has(d)) { + stack.push(d) + } } } - for (const atomWithPendingPromise of atomState.p) { - dependents.set( - atomWithPendingPromise, - getAtomState(atomWithPendingPromise), - ) - } - getPendingDependents(pending, atom)?.forEach((dependent) => { - dependents.set(dependent, getAtomState(dependent)) - }) - return dependents + return visited } + const getAllDependents = (pending: Pending, atoms: Iterable<AnyAtom>) => + getDeep((a, aState) => getMountedDependents(pending, a, aState), atoms) + // This is a topological sort via depth-first search, slightly modified from // what's described here for simplicity and performance reasons: // https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search - function getSortedDependents( + const getSortedDependents = ( pending: Pending, - rootAtom: AnyAtom, - rootAtomState: AtomState, - ): [[AnyAtom, AtomState, number][], Set<AnyAtom>] { - const sorted: [atom: AnyAtom, atomState: AtomState, epochNumber: number][] = - [] + rootAtoms: Iterable<AnyAtom>, + ) => { + const atomMap = getAllDependents(pending, rootAtoms) + const sorted: AnyAtom[] = [] const visiting = new Set<AnyAtom>() const visited = new Set<AnyAtom>() - // Visit the root atom. This is the only atom in the dependency graph + // Visit the root atoms. These are the only atoms in the dependency graph // without incoming edges, which is one reason we can simplify the algorithm - const stack: [a: AnyAtom, aState: AtomState][] = [[rootAtom, rootAtomState]] + const stack: [a: AnyAtom, dependents: Set<AnyAtom>][] = [] + for (const a of rootAtoms) { + if (atomMap.has(a)) { + stack.push([a, atomMap.get(a)!]) + } + } while (stack.length > 0) { - const [a, aState] = stack[stack.length - 1]! + const [a, dependents] = stack[stack.length - 1]! if (visited.has(a)) { - // All dependents have been processed, now process this atom stack.pop() continue } if (visiting.has(a)) { - // The algorithm calls for pushing onto the front of the list. For - // performance, we will simply push onto the end, and then will iterate in - // reverse order later. - sorted.push([a, aState, aState.n]) + // The algorithm calls for pushing onto the front of the list. + // For performance we push on the end, and will reverse the order later. + sorted.push(a) // Atom has been visited but not yet processed visited.add(a) stack.pop() @@ -476,50 +530,46 @@ const buildStore = ( } visiting.add(a) // Push unvisited dependents onto the stack - for (const [d, s] of getMountedOrPendingDependents(pending, a, aState)) { - if (a !== d && !visiting.has(d)) { - stack.push([d, s]) + for (const d of dependents) { + if (a !== d && !visiting.has(d) && atomMap.has(d)) { + stack.push([d, atomMap.get(d)!]) } } } - return [sorted, visited] + return sorted.reverse() } - const recomputeDependents = <Value>( - pending: Pending, - atom: Atom<Value>, - atomState: AtomState<Value>, - ) => { - // Step 1: traverse the dependency graph to build the topsorted atom list - // We don't bother to check for cycles, which simplifies the algorithm. - const [topsortedAtoms, markedAtoms] = getSortedDependents( - pending, - atom, - atomState, - ) - - // Step 2: use the topsorted atom list to recompute all affected atoms - // Track what's changed, so that we can short circuit when possible - const changedAtoms = new Set<AnyAtom>([atom]) - for (let i = topsortedAtoms.length - 1; i >= 0; --i) { - const [a, aState, prevEpochNumber] = topsortedAtoms[i]! - let hasChangedDeps = false - for (const dep of aState.d.keys()) { - if (dep !== a && changedAtoms.has(dep)) { - hasChangedDeps = true - break - } - } - if (hasChangedDeps) { - readAtomState(pending, a, markedAtoms) + const recomputeDependents = (pending: Pending, rootAtoms: Set<AnyAtom>) => { + if (rootAtoms.size === 0) { + return + } + const hasChangedDeps = (aState: AtomState) => + Array.from(aState.d.keys()).some((d) => rootAtoms.has(d)) + // traverse the dependency graph to build the topsorted atom list + for (const a of getSortedDependents(pending, rootAtoms)) { + // use the topsorted atom list to recompute all affected atoms + // Track what's changed, so that we can short circuit when possible + const aState = getAtomState(a) + const prevEpochNumber = aState.n + if (isPendingRecompute(a) || hasChangedDeps(aState)) { + readAtomState(pending, a) mountDependencies(pending, a, aState) if (prevEpochNumber !== aState.n) { - addPendingAtom(pending, a, aState) - changedAtoms.add(a) + markRecomputePending(pending, a, aState) } } - markedAtoms.delete(a) + markRecomputeComplete(pending, a, aState) + } + } + + const recomputeDependencies = (pending: Pending, a: AnyAtom) => { + if (!isPendingRecompute(a)) { + return } + const getDependencies = (_: unknown, aState: AtomState) => aState.d.keys() + const dependencies = Array.from(getDeep(getDependencies, [a]).keys()) + const dirtyDependencies = new Set(dependencies.filter(isPendingRecompute)) + recomputeDependents(pending, dirtyDependencies) } const writeAtomState = <Value, Args extends unknown[], Result>( @@ -528,8 +578,10 @@ const buildStore = ( ...args: Args ): Result => { let isSync = true - const getter: Getter = <V>(a: Atom<V>) => - returnAtomValue(readAtomState(pending, a)) + const getter: Getter = <V>(a: Atom<V>) => { + recomputeDependencies(pending, atom) + return returnAtomValue(readAtomState(pending, a)) + } const setter: Setter = <V, As extends unknown[], R>( a: WritableAtom<V, As, R>, ...args: As @@ -546,8 +598,7 @@ const buildStore = ( setAtomStateValueOrPromise(a, aState, v) mountDependencies(pending, a, aState) if (prevEpochNumber !== aState.n) { - addPendingAtom(pending, a, aState) - recomputeDependents(pending, a, aState) + markRecomputePending(pending, a, aState) } return undefined as R } else { @@ -732,8 +783,7 @@ const buildStore = ( setAtomStateValueOrPromise(atom, atomState, value) mountDependencies(pending, atom, atomState) if (prevEpochNumber !== atomState.n) { - addPendingAtom(pending, atom, atomState) - recomputeDependents(pending, atom, atomState) + markRecomputePending(pending, atom, atomState) } } } diff --git a/tests/vanilla/dependency.test.tsx b/tests/vanilla/dependency.test.tsx index 3970cfa7eba..90f0c114e7d 100644 --- a/tests/vanilla/dependency.test.tsx +++ b/tests/vanilla/dependency.test.tsx @@ -1,9 +1,5 @@ import { expect, it, vi } from 'vitest' import { atom, createStore } from 'jotai/vanilla' -import type { - INTERNAL_DevStoreRev4, - INTERNAL_PrdStore, -} from 'jotai/vanilla/store' it('can propagate updates with async atom chains', async () => { const store = createStore() @@ -412,37 +408,20 @@ it('can cache reading an atom in write function (with mounting)', () => { it('batches sync writes', () => { const a = atom(0) - a.debugLabel = 'a' - const b = atom((get) => get(a) + 1) - b.debugLabel = 'b' + const b = atom((get) => get(a)) const fetch = vi.fn() const c = atom((get) => fetch(get(a))) - c.debugLabel = 'c' const w = atom(null, (get, set) => { - const b1 = get(b) // 1 - set(a, b1) - expect(fetch).toHaveBeenCalledTimes(0) - const b2 = get(b) // 2 - set(a, b2) + set(a, 1) + expect(get(b)).toBe(1) expect(fetch).toHaveBeenCalledTimes(0) }) - w.debugLabel = 'w' - const store = createStore() as INTERNAL_DevStoreRev4 & INTERNAL_PrdStore + const store = createStore() store.sub(b, () => {}) store.sub(c, () => {}) - const getAtomState = store.dev4_get_internal_weak_map().get - const aState = getAtomState(a) as any - aState.label = 'a' - const bState = getAtomState(b) as any - bState.label = 'b' - const cState = getAtomState(c) as any - cState.label = 'c' fetch.mockClear() store.set(w) - // we expect b to be recomputed when a's value is changed by `set` - // we expect c to be recomputed in flushPending after the graph has updated - // this distinction is possible by tracking what atoms are accessed with w.write's `get` - expect(store.get(a)).toBe(2) expect(fetch).toHaveBeenCalledOnce() - expect(fetch).toBeCalledWith(2) + expect(fetch).toBeCalledWith(1) + expect(store.get(a)).toBe(1) })