diff --git a/src/index.test.tsx b/src/index.test.tsx index 3844323..6686baa 100644 --- a/src/index.test.tsx +++ b/src/index.test.tsx @@ -21,3 +21,25 @@ test("mergeRefs", () => { expect(refAsFunc).toHaveBeenCalledWith(null); expect(refAsObj.current).toBe(null); }); + +test("mergeRefs with undefined and null refs", () => { + const Dummy = React.forwardRef(function Dummy(_, ref) { + React.useImperativeHandle(ref, () => "refValue"); + return null; + }); + const refAsFunc = jest.fn(); + const refAsObj = { current: undefined }; + const Example: React.FC<{ visible: boolean }> = ({ visible }) => { + return visible ? ( + + ) : null; + }; + const { rerender } = render(); + expect(refAsFunc).toHaveBeenCalledTimes(1); + expect(refAsFunc).toHaveBeenCalledWith("refValue"); + expect(refAsObj.current).toBe("refValue"); + rerender(); + expect(refAsFunc).toHaveBeenCalledTimes(2); + expect(refAsFunc).toHaveBeenCalledWith(null); + expect(refAsObj.current).toBe(null); +}); diff --git a/src/index.tsx b/src/index.tsx index b2e688b..2ff12ff 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -1,7 +1,7 @@ import type * as React from "react"; export function mergeRefs( - refs: Array | React.LegacyRef> + refs: Array | React.LegacyRef | undefined | null> ): React.RefCallback { return (value) => { refs.forEach((ref) => {