diff --git a/compiler/packages/babel-plugin-react-compiler/src/Validation/ValidateNoRefAccesInRender.ts b/compiler/packages/babel-plugin-react-compiler/src/Validation/ValidateNoRefAccesInRender.ts index df6241a73f448..8a65b4709c174 100644 --- a/compiler/packages/babel-plugin-react-compiler/src/Validation/ValidateNoRefAccesInRender.ts +++ b/compiler/packages/babel-plugin-react-compiler/src/Validation/ValidateNoRefAccesInRender.ts @@ -11,12 +11,12 @@ import { IdentifierId, Place, SourceLocation, - isRefOrRefValue, isRefValueType, isUseRefType, } from '../HIR'; import { eachInstructionValueOperand, + eachPatternOperand, eachTerminalOperand, } from '../HIR/visitors'; import {Err, Ok, Result} from '../Utils/Result'; @@ -42,58 +42,165 @@ import {isEffectHook} from './ValidateMemoizedEffectDependencies'; * In the future we may reject more cases, based on either object names (`fooRef.current` is likely a ref) * or based on property name alone (`foo.current` might be a ref). */ +type State = { + refs: Set; + refValues: Map; + refAccessingFunctions: Set; +}; + export function validateNoRefAccessInRender(fn: HIRFunction): void { - const refAccessingFunctions: Set = new Set(); - validateNoRefAccessInRenderImpl(fn, refAccessingFunctions).unwrap(); + const state = { + refs: new Set(), + refValues: new Map(), + refAccessingFunctions: new Set(), + }; + validateNoRefAccessInRenderImpl(fn, state).unwrap(); } function validateNoRefAccessInRenderImpl( fn: HIRFunction, - refAccessingFunctions: Set, + state: State, ): Result { + let place; + for (const param of fn.params) { + if (param.kind === 'Identifier') { + place = param; + } else { + place = param.place; + } + + if (isRefValueType(place.identifier)) { + state.refValues.set(place.identifier.id, null); + } + if (isUseRefType(place.identifier)) { + state.refs.add(place.identifier.id); + } + } const errors = new CompilerError(); - const lookupLocations: Map = new Map(); for (const [, block] of fn.body.blocks) { + for (const phi of block.phis) { + phi.operands.forEach(operand => { + if (state.refs.has(operand.id) || isUseRefType(phi.id)) { + state.refs.add(phi.id.id); + } + const refValue = state.refValues.get(operand.id); + if (refValue !== undefined || isRefValueType(operand)) { + state.refValues.set( + phi.id.id, + refValue ?? state.refValues.get(phi.id.id) ?? null, + ); + } + if (state.refAccessingFunctions.has(operand.id)) { + state.refAccessingFunctions.add(phi.id.id); + } + }); + } + for (const instr of block.instructions) { + for (const operand of eachInstructionValueOperand(instr.value)) { + if (isRefValueType(operand.identifier)) { + CompilerError.invariant(state.refValues.has(operand.identifier.id), { + reason: 'Expected ref value to be in state', + loc: operand.loc, + }); + } + if (isUseRefType(operand.identifier)) { + CompilerError.invariant(state.refs.has(operand.identifier.id), { + reason: 'Expected ref to be in state', + loc: operand.loc, + }); + } + } + switch (instr.value.kind) { case 'JsxExpression': case 'JsxFragment': { for (const operand of eachInstructionValueOperand(instr.value)) { - validateNoDirectRefValueAccess(errors, operand, lookupLocations); + validateNoDirectRefValueAccess(errors, operand, state); } break; } + case 'ComputedLoad': case 'PropertyLoad': { + if (typeof instr.value.property !== 'string') { + validateNoRefValueAccess(errors, state, instr.value.property); + } if ( - isRefValueType(instr.lvalue.identifier) && - instr.value.property === 'current' + state.refAccessingFunctions.has(instr.value.object.identifier.id) ) { - lookupLocations.set(instr.lvalue.identifier.id, instr.loc); + state.refAccessingFunctions.add(instr.lvalue.identifier.id); + } + if (state.refs.has(instr.value.object.identifier.id)) { + /* + * Once an object contains a ref at any level, we treat it as a ref. + * If we look something up from it, that value may either be a ref + * or the ref value (or neither), so we conservatively assume it's both. + */ + state.refs.add(instr.lvalue.identifier.id); + state.refValues.set(instr.lvalue.identifier.id, instr.loc); } break; } + case 'LoadContext': case 'LoadLocal': { - if (refAccessingFunctions.has(instr.value.place.identifier.id)) { - refAccessingFunctions.add(instr.lvalue.identifier.id); + if ( + state.refAccessingFunctions.has(instr.value.place.identifier.id) + ) { + state.refAccessingFunctions.add(instr.lvalue.identifier.id); } - if (isRefValueType(instr.lvalue.identifier)) { - const loc = lookupLocations.get(instr.value.place.identifier.id); - if (loc !== undefined) { - lookupLocations.set(instr.lvalue.identifier.id, loc); - } + const refValue = state.refValues.get(instr.value.place.identifier.id); + if (refValue !== undefined) { + state.refValues.set(instr.lvalue.identifier.id, refValue); + } + if (state.refs.has(instr.value.place.identifier.id)) { + state.refs.add(instr.lvalue.identifier.id); } break; } + case 'StoreContext': case 'StoreLocal': { - if (refAccessingFunctions.has(instr.value.value.identifier.id)) { - refAccessingFunctions.add(instr.value.lvalue.place.identifier.id); - refAccessingFunctions.add(instr.lvalue.identifier.id); + if ( + state.refAccessingFunctions.has(instr.value.value.identifier.id) + ) { + state.refAccessingFunctions.add( + instr.value.lvalue.place.identifier.id, + ); + state.refAccessingFunctions.add(instr.lvalue.identifier.id); + } + const refValue = state.refValues.get(instr.value.value.identifier.id); + if ( + refValue !== undefined || + isRefValueType(instr.value.lvalue.place.identifier) + ) { + state.refValues.set( + instr.value.lvalue.place.identifier.id, + refValue ?? null, + ); + state.refValues.set(instr.lvalue.identifier.id, refValue ?? null); + } + if (state.refs.has(instr.value.value.identifier.id)) { + state.refs.add(instr.value.lvalue.place.identifier.id); + state.refs.add(instr.lvalue.identifier.id); } - if (isRefValueType(instr.value.lvalue.place.identifier)) { - const loc = lookupLocations.get(instr.value.value.identifier.id); - if (loc !== undefined) { - lookupLocations.set(instr.value.lvalue.place.identifier.id, loc); - lookupLocations.set(instr.lvalue.identifier.id, loc); + break; + } + case 'Destructure': { + const destructuredFunction = state.refAccessingFunctions.has( + instr.value.value.identifier.id, + ); + const destructuredRef = state.refs.has( + instr.value.value.identifier.id, + ); + for (const lval of eachPatternOperand(instr.value.lvalue.pattern)) { + if (isUseRefType(lval.identifier)) { + state.refs.add(lval.identifier.id); + } + if (destructuredRef || isRefValueType(lval.identifier)) { + state.refs.add(lval.identifier.id); + state.refValues.set(lval.identifier.id, null); + } + if (destructuredFunction) { + state.refAccessingFunctions.add(lval.identifier.id); } } break; @@ -107,32 +214,27 @@ function validateNoRefAccessInRenderImpl( */ [...eachInstructionValueOperand(instr.value)].some( operand => - isRefValueType(operand.identifier) || - refAccessingFunctions.has(operand.identifier.id), + state.refValues.has(operand.identifier.id) || + state.refAccessingFunctions.has(operand.identifier.id), ) || // check for cases where .current is accessed through an aliased ref ([...eachInstructionValueOperand(instr.value)].some(operand => - isUseRefType(operand.identifier), + state.refs.has(operand.identifier.id), ) && validateNoRefAccessInRenderImpl( instr.value.loweredFunc.func, - refAccessingFunctions, + state, ).isErr()) ) { // This function expression unconditionally accesses a ref - refAccessingFunctions.add(instr.lvalue.identifier.id); + state.refAccessingFunctions.add(instr.lvalue.identifier.id); } break; } case 'MethodCall': { if (!isEffectHook(instr.value.property.identifier)) { for (const operand of eachInstructionValueOperand(instr.value)) { - validateNoRefAccess( - errors, - refAccessingFunctions, - operand, - operand.loc, - ); + validateNoRefAccess(errors, state, operand, operand.loc); } } break; @@ -142,7 +244,7 @@ function validateNoRefAccessInRenderImpl( const isUseEffect = isEffectHook(callee.identifier); if (!isUseEffect) { // Report a more precise error when calling a local function that accesses a ref - if (refAccessingFunctions.has(callee.identifier.id)) { + if (state.refAccessingFunctions.has(callee.identifier.id)) { errors.push({ severity: ErrorSeverity.InvalidReact, reason: @@ -159,9 +261,9 @@ function validateNoRefAccessInRenderImpl( for (const operand of eachInstructionValueOperand(instr.value)) { validateNoRefAccess( errors, - refAccessingFunctions, + state, operand, - lookupLocations.get(operand.identifier.id) ?? operand.loc, + state.refValues.get(operand.identifier.id) ?? operand.loc, ); } } @@ -170,12 +272,17 @@ function validateNoRefAccessInRenderImpl( case 'ObjectExpression': case 'ArrayExpression': { for (const operand of eachInstructionValueOperand(instr.value)) { - validateNoRefAccess( - errors, - refAccessingFunctions, - operand, - lookupLocations.get(operand.identifier.id) ?? operand.loc, - ); + validateNoDirectRefValueAccess(errors, operand, state); + if (state.refAccessingFunctions.has(operand.identifier.id)) { + state.refAccessingFunctions.add(instr.lvalue.identifier.id); + } + if (state.refs.has(operand.identifier.id)) { + state.refs.add(instr.lvalue.identifier.id); + } + const refValue = state.refValues.get(operand.identifier.id); + if (refValue !== undefined) { + state.refValues.set(instr.lvalue.identifier.id, refValue); + } } break; } @@ -185,20 +292,15 @@ function validateNoRefAccessInRenderImpl( case 'ComputedStore': { validateNoRefAccess( errors, - refAccessingFunctions, + state, instr.value.object, - lookupLocations.get(instr.value.object.identifier.id) ?? instr.loc, + state.refValues.get(instr.value.object.identifier.id) ?? instr.loc, ); for (const operand of eachInstructionValueOperand(instr.value)) { if (operand === instr.value.object) { continue; } - validateNoRefValueAccess( - errors, - refAccessingFunctions, - lookupLocations, - operand, - ); + validateNoRefValueAccess(errors, state, operand); } break; } @@ -207,28 +309,27 @@ function validateNoRefAccessInRenderImpl( break; default: { for (const operand of eachInstructionValueOperand(instr.value)) { - validateNoRefValueAccess( - errors, - refAccessingFunctions, - lookupLocations, - operand, - ); + validateNoRefValueAccess(errors, state, operand); } break; } } + if (isUseRefType(instr.lvalue.identifier)) { + state.refs.add(instr.lvalue.identifier.id); + } + if ( + isRefValueType(instr.lvalue.identifier) && + !state.refValues.has(instr.lvalue.identifier.id) + ) { + state.refValues.set(instr.lvalue.identifier.id, instr.loc); + } } for (const operand of eachTerminalOperand(block.terminal)) { if (block.terminal.kind !== 'return') { - validateNoRefValueAccess( - errors, - refAccessingFunctions, - lookupLocations, - operand, - ); + validateNoRefValueAccess(errors, state, operand); } else { // Allow functions containing refs to be returned, but not direct ref values - validateNoDirectRefValueAccess(errors, operand, lookupLocations); + validateNoDirectRefValueAccess(errors, operand, state); } } } @@ -242,19 +343,18 @@ function validateNoRefAccessInRenderImpl( function validateNoRefValueAccess( errors: CompilerError, - refAccessingFunctions: Set, - lookupLocations: Map, + state: State, operand: Place, ): void { if ( - isRefValueType(operand.identifier) || - refAccessingFunctions.has(operand.identifier.id) + state.refValues.has(operand.identifier.id) || + state.refAccessingFunctions.has(operand.identifier.id) ) { errors.push({ severity: ErrorSeverity.InvalidReact, reason: 'Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef)', - loc: lookupLocations.get(operand.identifier.id) ?? operand.loc, + loc: state.refValues.get(operand.identifier.id) ?? operand.loc, description: operand.identifier.name !== null && operand.identifier.name.kind === 'named' @@ -267,13 +367,14 @@ function validateNoRefValueAccess( function validateNoRefAccess( errors: CompilerError, - refAccessingFunctions: Set, + state: State, operand: Place, loc: SourceLocation, ): void { if ( - isRefOrRefValue(operand.identifier) || - refAccessingFunctions.has(operand.identifier.id) + state.refs.has(operand.identifier.id) || + state.refValues.has(operand.identifier.id) || + state.refAccessingFunctions.has(operand.identifier.id) ) { errors.push({ severity: ErrorSeverity.InvalidReact, @@ -293,14 +394,14 @@ function validateNoRefAccess( function validateNoDirectRefValueAccess( errors: CompilerError, operand: Place, - lookupLocations: Map, + state: State, ): void { - if (isRefValueType(operand.identifier)) { + if (state.refValues.has(operand.identifier.id)) { errors.push({ severity: ErrorSeverity.InvalidReact, reason: 'Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef)', - loc: lookupLocations.get(operand.identifier.id) ?? operand.loc, + loc: state.refValues.get(operand.identifier.id) ?? operand.loc, description: operand.identifier.name !== null && operand.identifier.name.kind === 'named' diff --git a/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/error.invalid-use-ref-added-to-dep-without-type-info.expect.md b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/error.invalid-use-ref-added-to-dep-without-type-info.expect.md index a28a74730bfb8..f576bac764613 100644 --- a/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/error.invalid-use-ref-added-to-dep-without-type-info.expect.md +++ b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/error.invalid-use-ref-added-to-dep-without-type-info.expect.md @@ -22,13 +22,15 @@ function Foo({a}) { ## Error ``` - 3 | const ref = useRef(); - 4 | // type information is lost here as we don't track types of fields -> 5 | const val = {ref}; - | ^^^ InvalidReact: Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef) (5:5) - 6 | // without type info, we don't know that val.ref.current is a ref value so we - 7 | // *would* end up depending on val.ref.current - 8 | // however, this is an instance of accessing a ref during render and is disallowed + 8 | // however, this is an instance of accessing a ref during render and is disallowed + 9 | // under React's rules, so we reject this input +> 10 | const x = {a, val: val.ref.current}; + | ^^^^^^^^^^^^^^^ InvalidReact: Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef) (10:10) + +InvalidReact: Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef) (10:10) + 11 | + 12 | return ; + 13 | } ``` \ No newline at end of file diff --git a/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/error.return-ref-callback-structure.expect.md b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/error.return-ref-callback-structure.expect.md deleted file mode 100644 index 866d2e2fea657..0000000000000 --- a/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/error.return-ref-callback-structure.expect.md +++ /dev/null @@ -1,45 +0,0 @@ - -## Input - -```javascript -// @flow @validateRefAccessDuringRender @validatePreserveExistingMemoizationGuarantees - -import {useRef} from 'react'; - -component Foo(cond: boolean, cond2: boolean) { - const ref = useRef(); - - const s = () => { - return ref.current; - }; - - if (cond) return [s]; - else if (cond2) return {s}; - else return {s: [s]}; -} - -export const FIXTURE_ENTRYPOINT = { - fn: Foo, - params: [{cond: false, cond2: false}], -}; - -``` - - -## Error - -``` - 10 | }; - 11 | -> 12 | if (cond) return [s]; - | ^ InvalidReact: Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef) (12:12) - -InvalidReact: Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef) (13:13) - -InvalidReact: Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef) (14:14) - 13 | else if (cond2) return {s}; - 14 | else return {s: [s]}; - 15 | } -``` - - \ No newline at end of file diff --git a/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/return-ref-callback-structure.expect.md b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/return-ref-callback-structure.expect.md new file mode 100644 index 0000000000000..95976383cbff8 --- /dev/null +++ b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/return-ref-callback-structure.expect.md @@ -0,0 +1,87 @@ + +## Input + +```javascript +// @flow @validateRefAccessDuringRender @validatePreserveExistingMemoizationGuarantees + +import {useRef} from 'react'; + +component Foo(cond: boolean, cond2: boolean) { + const ref = useRef(); + + const s = () => { + return ref.current; + }; + + if (cond) return [s]; + else if (cond2) return {s}; + else return {s: [s]}; +} + +export const FIXTURE_ENTRYPOINT = { + fn: Foo, + params: [{cond: false, cond2: false}], +}; + +``` + +## Code + +```javascript +import { c as _c } from "react/compiler-runtime"; + +import { useRef } from "react"; + +function Foo(t0) { + const $ = _c(4); + const { cond, cond2 } = t0; + const ref = useRef(); + let t1; + if ($[0] === Symbol.for("react.memo_cache_sentinel")) { + t1 = () => ref.current; + $[0] = t1; + } else { + t1 = $[0]; + } + const s = t1; + if (cond) { + let t2; + if ($[1] === Symbol.for("react.memo_cache_sentinel")) { + t2 = [s]; + $[1] = t2; + } else { + t2 = $[1]; + } + return t2; + } else { + if (cond2) { + let t2; + if ($[2] === Symbol.for("react.memo_cache_sentinel")) { + t2 = { s }; + $[2] = t2; + } else { + t2 = $[2]; + } + return t2; + } else { + let t2; + if ($[3] === Symbol.for("react.memo_cache_sentinel")) { + t2 = { s: [s] }; + $[3] = t2; + } else { + t2 = $[3]; + } + return t2; + } + } +} + +export const FIXTURE_ENTRYPOINT = { + fn: Foo, + params: [{ cond: false, cond2: false }], +}; + +``` + +### Eval output +(kind: ok) {"s":["[[ function params=0 ]]"]} \ No newline at end of file diff --git a/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/error.return-ref-callback-structure.js b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/return-ref-callback-structure.js similarity index 100% rename from compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/error.return-ref-callback-structure.js rename to compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/return-ref-callback-structure.js