diff --git a/src/vanilla/utils/atomFamily.ts b/src/vanilla/utils/atomFamily.ts
index 81b0098613..4667354890 100644
--- a/src/vanilla/utils/atomFamily.ts
+++ b/src/vanilla/utils/atomFamily.ts
@@ -1,11 +1,27 @@
-import type { Atom } from '../../vanilla.ts'
+import { type Atom } from '../../vanilla.ts'
-type ShouldRemove = (createdAt: number, param: Param) => boolean
+/**
+ * in milliseconds
+ */
+type CreatedAt = number
+type ShouldRemove = (createdAt: CreatedAt, param: Param) => boolean
+type Cleanup = () => void
+type Callback = (event: {
+ type: 'CREATE' | 'REMOVE'
+ param: Param
+ atom: AtomType
+}) => void
export interface AtomFamily {
(param: Param): AtomType
+ getParams(): Iterable
remove(param: Param): void
setShouldRemove(shouldRemove: ShouldRemove | null): void
+ /**
+ * fires when a atom is created or removed
+ * This API is for advanced use cases, and can change without notice.
+ */
+ unstable_listen(callback: Callback): Cleanup
}
export function atomFamily>(
@@ -17,9 +33,9 @@ export function atomFamily>(
initializeAtom: (param: Param) => AtomType,
areEqual?: (a: Param, b: Param) => boolean,
) {
- type CreatedAt = number // in milliseconds
let shouldRemove: ShouldRemove | null = null
const atoms: Map = new Map()
+ const listeners = new Set>()
const createAtom = (param: Param) => {
let item: [AtomType, CreatedAt] | undefined
if (areEqual === undefined) {
@@ -44,16 +60,40 @@ export function atomFamily>(
const newAtom = initializeAtom(param)
atoms.set(param, [newAtom, Date.now()])
+ notifyListeners('CREATE', param, newAtom)
return newAtom
}
+ function notifyListeners(
+ type: 'CREATE' | 'REMOVE',
+ param: Param,
+ atom: AtomType,
+ ) {
+ for (const listener of listeners) {
+ listener({ type, param, atom })
+ }
+ }
+
+ createAtom.unstable_listen = (callback: Callback) => {
+ listeners.add(callback)
+ return () => {
+ listeners.delete(callback)
+ }
+ }
+
+ createAtom.getParams = () => atoms.keys()
+
createAtom.remove = (param: Param) => {
if (areEqual === undefined) {
+ if (!atoms.has(param)) return
+ const [atom] = atoms.get(param)!
atoms.delete(param)
+ notifyListeners('REMOVE', param, atom)
} else {
- for (const [key] of atoms) {
+ for (const [key, [atom]] of atoms) {
if (areEqual(key, param)) {
atoms.delete(key)
+ notifyListeners('REMOVE', key, atom)
break
}
}
@@ -63,9 +103,10 @@ export function atomFamily>(
createAtom.setShouldRemove = (fn: ShouldRemove | null) => {
shouldRemove = fn
if (!shouldRemove) return
- for (const [key, value] of atoms) {
- if (shouldRemove(value[1], key)) {
+ for (const [key, [atom, createdAt]] of atoms) {
+ if (shouldRemove(createdAt, key)) {
atoms.delete(key)
+ notifyListeners('REMOVE', key, atom)
}
}
}
diff --git a/tests/vanilla/utils/atomFamily.test.ts b/tests/vanilla/utils/atomFamily.test.ts
new file mode 100644
index 0000000000..a66f4d1a7b
--- /dev/null
+++ b/tests/vanilla/utils/atomFamily.test.ts
@@ -0,0 +1,95 @@
+import { expect, it, vi } from 'vitest'
+import { atom, createStore } from 'jotai/vanilla'
+import type { Atom } from 'jotai/vanilla'
+import { atomFamily } from 'jotai/vanilla/utils'
+
+it('should create atoms with different params', () => {
+ const store = createStore()
+ const aFamily = atomFamily((param: number) => atom(param))
+
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(store.get(aFamily(2))).toEqual(2)
+})
+
+it('should remove atoms', () => {
+ const store = createStore()
+ const initializeAtom = vi.fn((param: number) => atom(param))
+ const aFamily = atomFamily(initializeAtom)
+
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(store.get(aFamily(2))).toEqual(2)
+ aFamily.remove(2)
+ initializeAtom.mockClear()
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(initializeAtom).toHaveBeenCalledTimes(0)
+ expect(store.get(aFamily(2))).toEqual(2)
+ expect(initializeAtom).toHaveBeenCalledTimes(1)
+})
+
+it('should remove atoms with custom comparator', () => {
+ const store = createStore()
+ const initializeAtom = vi.fn((param: number) => atom(param))
+ const aFamily = atomFamily(initializeAtom, (a, b) => a === b)
+
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(store.get(aFamily(2))).toEqual(2)
+ expect(store.get(aFamily(3))).toEqual(3)
+ aFamily.remove(2)
+ initializeAtom.mockClear()
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(initializeAtom).toHaveBeenCalledTimes(0)
+ expect(store.get(aFamily(2))).toEqual(2)
+ expect(initializeAtom).toHaveBeenCalledTimes(1)
+})
+
+it('should remove atoms with custom shouldRemove', () => {
+ const store = createStore()
+ const initializeAtom = vi.fn((param: number) => atom(param))
+ const aFamily = atomFamily>(initializeAtom)
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(store.get(aFamily(2))).toEqual(2)
+ expect(store.get(aFamily(3))).toEqual(3)
+ aFamily.setShouldRemove((_createdAt, param) => param % 2 === 0)
+ initializeAtom.mockClear()
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(initializeAtom).toHaveBeenCalledTimes(0)
+ expect(store.get(aFamily(2))).toEqual(2)
+ expect(initializeAtom).toHaveBeenCalledTimes(1)
+ expect(store.get(aFamily(3))).toEqual(3)
+ expect(initializeAtom).toHaveBeenCalledTimes(1)
+})
+
+it('should notify listeners', () => {
+ const aFamily = atomFamily((param: number) => atom(param))
+ const listener = vi.fn(() => {})
+ type Event = { type: 'CREATE' | 'REMOVE'; param: number; atom: Atom }
+ const unsubscribe = aFamily.unstable_listen(listener)
+ const atom1 = aFamily(1)
+ expect(listener).toHaveBeenCalledTimes(1)
+ const eventCreate = listener.mock.calls[0]?.at(0) as unknown as Event
+ if (!eventCreate) throw new Error('eventCreate is undefined')
+ expect(eventCreate.type).toEqual('CREATE')
+ expect(eventCreate.param).toEqual(1)
+ expect(eventCreate.atom).toEqual(atom1)
+ listener.mockClear()
+ aFamily.remove(1)
+ expect(listener).toHaveBeenCalledTimes(1)
+ const eventRemove = listener.mock.calls[0]?.at(0) as unknown as Event
+ expect(eventRemove.type).toEqual('REMOVE')
+ expect(eventRemove.param).toEqual(1)
+ expect(eventRemove.atom).toEqual(atom1)
+ unsubscribe()
+ listener.mockClear()
+ aFamily(2)
+ expect(listener).toHaveBeenCalledTimes(0)
+})
+
+it('should return all params', () => {
+ const store = createStore()
+ const aFamily = atomFamily((param: number) => atom(param))
+
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(store.get(aFamily(2))).toEqual(2)
+ expect(store.get(aFamily(3))).toEqual(3)
+ expect(Array.from(aFamily.getParams())).toEqual([1, 2, 3])
+})