diff --git a/packages/react-reconciler/src/ReactFiberBeginWork.js b/packages/react-reconciler/src/ReactFiberBeginWork.js index ef947a15a33e2..47b5957128934 100644 --- a/packages/react-reconciler/src/ReactFiberBeginWork.js +++ b/packages/react-reconciler/src/ReactFiberBeginWork.js @@ -1168,18 +1168,13 @@ export function replayFunctionComponent( workInProgress: Fiber, nextProps: any, Component: any, + secondArg: any, renderLanes: Lanes, ): Fiber | null { // This function is used to replay a component that previously suspended, // after its data resolves. It's a simplified version of // updateFunctionComponent that reuses the hooks from the previous attempt. - let context; - if (!disableLegacyContext) { - const unmaskedContext = getUnmaskedContext(workInProgress, Component, true); - context = getMaskedContext(workInProgress, unmaskedContext); - } - prepareToReadContext(workInProgress, renderLanes); if (enableSchedulingProfiler) { markComponentRenderStarted(workInProgress); @@ -1189,7 +1184,7 @@ export function replayFunctionComponent( workInProgress, Component, nextProps, - context, + secondArg, ); const hasId = checkDidRenderIdHook(); if (enableSchedulingProfiler) { diff --git a/packages/react-reconciler/src/ReactFiberWorkLoop.js b/packages/react-reconciler/src/ReactFiberWorkLoop.js index 292cad9d0735e..f9c7d1fa61678 100644 --- a/packages/react-reconciler/src/ReactFiberWorkLoop.js +++ b/packages/react-reconciler/src/ReactFiberWorkLoop.js @@ -39,6 +39,7 @@ import { enableTransitionTracing, useModernStrictMode, revertRemovalOfSiblingPrerendering, + disableLegacyContext, } from 'shared/ReactFeatureFlags'; import ReactSharedInternals from 'shared/ReactSharedInternals'; import is from 'shared/objectIs'; @@ -281,6 +282,7 @@ import { flushSyncWorkOnLegacyRootsOnly, getContinuationForRoot, } from './ReactFiberRootScheduler'; +import {getMaskedContext, getUnmaskedContext} from './ReactFiberContext'; const ceil = Math.ceil; @@ -2379,8 +2381,8 @@ function replaySuspendedUnitOfWork(unitOfWork: Fiber): void { // Fallthrough to the next branch. } // eslint-disable-next-line no-fallthrough - case FunctionComponent: - case ForwardRef: { + case SimpleMemoComponent: + case FunctionComponent: { // Resolve `defaultProps`. This logic is copied from `beginWork`. // TODO: Consider moving this switch statement into that module. Also, // could maybe use this as an opportunity to say `use` doesn't work with @@ -2391,23 +2393,39 @@ function replaySuspendedUnitOfWork(unitOfWork: Fiber): void { unitOfWork.elementType === Component ? unresolvedProps : resolveDefaultProps(Component, unresolvedProps); + let context: any; + if (!disableLegacyContext) { + const unmaskedContext = getUnmaskedContext(unitOfWork, Component, true); + context = getMaskedContext(unitOfWork, unmaskedContext); + } next = replayFunctionComponent( current, unitOfWork, resolvedProps, Component, + context, workInProgressRootRenderLanes, ); break; } - case SimpleMemoComponent: { - const Component = unitOfWork.type; - const nextProps = unitOfWork.pendingProps; + case ForwardRef: { + // Resolve `defaultProps`. This logic is copied from `beginWork`. + // TODO: Consider moving this switch statement into that module. Also, + // could maybe use this as an opportunity to say `use` doesn't work with + // `defaultProps` :) + const Component = unitOfWork.type.render; + const unresolvedProps = unitOfWork.pendingProps; + const resolvedProps = + unitOfWork.elementType === Component + ? unresolvedProps + : resolveDefaultProps(Component, unresolvedProps); + next = replayFunctionComponent( current, unitOfWork, - nextProps, + resolvedProps, Component, + unitOfWork.ref, workInProgressRootRenderLanes, ); break; diff --git a/packages/react-reconciler/src/__tests__/ReactUse-test.js b/packages/react-reconciler/src/__tests__/ReactUse-test.js index 86bb586ec5829..ec3b86acb1cda 100644 --- a/packages/react-reconciler/src/__tests__/ReactUse-test.js +++ b/packages/react-reconciler/src/__tests__/ReactUse-test.js @@ -1472,4 +1472,132 @@ describe('ReactUse', () => { assertLog(['Hi']); expect(root).toMatchRenderedOutput('Hi'); }); + + test('unwrap uncached promises inside forwardRef', async () => { + const asyncInstance = {}; + const Async = React.forwardRef((props, ref) => { + React.useImperativeHandle(ref, () => asyncInstance); + const text = use(Promise.resolve('Async')); + return ; + }); + + const ref = React.createRef(); + function App() { + return ( + }> + + + ); + } + + const root = ReactNoop.createRoot(); + await act(() => { + startTransition(() => { + root.render(); + }); + }); + assertLog(['Async']); + expect(root).toMatchRenderedOutput('Async'); + expect(ref.current).toBe(asyncInstance); + }); + + test('unwrap uncached promises inside memo', async () => { + const Async = React.memo( + props => { + const text = use(Promise.resolve(props.text)); + return ; + }, + (a, b) => a.text === b.text, + ); + + function App({text}) { + return ( + }> + + + ); + } + + const root = ReactNoop.createRoot(); + await act(() => { + startTransition(() => { + root.render(); + }); + }); + assertLog(['Async']); + expect(root).toMatchRenderedOutput('Async'); + + // Update to the same value + await act(() => { + startTransition(() => { + root.render(); + }); + }); + // Should not have re-rendered, because it's memoized + assertLog([]); + expect(root).toMatchRenderedOutput('Async'); + + // Update to a different value + await act(() => { + startTransition(() => { + root.render(); + }); + }); + assertLog(['Async!']); + expect(root).toMatchRenderedOutput('Async!'); + }); + + // @gate !disableLegacyContext + test('unwrap uncached promises in component that accesses legacy context', async () => { + class ContextProvider extends React.Component { + static childContextTypes = { + legacyContext() {}, + }; + getChildContext() { + return {legacyContext: 'Async'}; + } + render() { + return this.props.children; + } + } + + function Async({label}, context) { + const text = use(Promise.resolve(context.legacyContext + ` (${label})`)); + return ; + } + Async.contextTypes = { + legacyContext: () => {}, + }; + + const AsyncMemo = React.memo(Async, (a, b) => a.label === b.label); + + function App() { + return ( + + }> +
+ +
+
+ +
+
+
+ ); + } + + const root = ReactNoop.createRoot(); + await act(() => { + startTransition(() => { + root.render(); + }); + }); + assertLog(['Async (function component)', 'Async (memo component)']); + expect(root).toMatchRenderedOutput( + <> +
Async (function component)
+
Async (memo component)
+ , + ); + }); });