diff --git a/src/index.test.tsx b/src/index.test.tsx
index 3844323..6686baa 100644
--- a/src/index.test.tsx
+++ b/src/index.test.tsx
@@ -21,3 +21,25 @@ test("mergeRefs", () => {
expect(refAsFunc).toHaveBeenCalledWith(null);
expect(refAsObj.current).toBe(null);
});
+
+test("mergeRefs with undefined and null refs", () => {
+ const Dummy = React.forwardRef(function Dummy(_, ref) {
+ React.useImperativeHandle(ref, () => "refValue");
+ return null;
+ });
+ const refAsFunc = jest.fn();
+ const refAsObj = { current: undefined };
+ const Example: React.FC<{ visible: boolean }> = ({ visible }) => {
+ return visible ? (
+
+ ) : null;
+ };
+ const { rerender } = render();
+ expect(refAsFunc).toHaveBeenCalledTimes(1);
+ expect(refAsFunc).toHaveBeenCalledWith("refValue");
+ expect(refAsObj.current).toBe("refValue");
+ rerender();
+ expect(refAsFunc).toHaveBeenCalledTimes(2);
+ expect(refAsFunc).toHaveBeenCalledWith(null);
+ expect(refAsObj.current).toBe(null);
+});
diff --git a/src/index.tsx b/src/index.tsx
index b2e688b..2ff12ff 100644
--- a/src/index.tsx
+++ b/src/index.tsx
@@ -1,7 +1,7 @@
import type * as React from "react";
export function mergeRefs(
- refs: Array | React.LegacyRef>
+ refs: Array | React.LegacyRef | undefined | null>
): React.RefCallback {
return (value) => {
refs.forEach((ref) => {