Skip to content

Commit d2b2a1b

Browse files
committed
[compiler] Propagate CreateFunction effects for functions that return functions
If you have a local helper function that itself returns a function (`() => () => { ... }`), we currently infer the return effect of the outer function as `Create mutable`. We correctly track the aliasing, but we lose some precision because we don't understand that a function specifically is being returned. Here, we do some extra analysis of which values are returned in InferMutationAliasingRanges, and if the sole return value is a function we infer a `CreateFunction` effect. We also infer an `Assign` (instead of a Create) if the sole return value was one of the context variables or parameters.
1 parent 1671142 commit d2b2a1b

File tree

3 files changed

+168
-35
lines changed

3 files changed

+168
-35
lines changed

compiler/packages/babel-plugin-react-compiler/src/Inference/InferMutationAliasingEffects.ts

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2498,10 +2498,47 @@ function computeEffectsForSignature(
24982498
break;
24992499
}
25002500
case 'CreateFunction': {
2501-
CompilerError.throwTodo({
2502-
reason: `Support CreateFrom effects in signatures`,
2503-
loc: receiver.loc,
2501+
const applyInto = substitutions.get(effect.into.identifier.id);
2502+
if (applyInto == null || applyInto.length !== 1) {
2503+
return null;
2504+
}
2505+
const captures: Array<Place> = [];
2506+
for (let i = 0; i < effect.captures.length; i++) {
2507+
const substitution = substitutions.get(
2508+
effect.captures[i].identifier.id,
2509+
);
2510+
if (substitution == null || substitution.length !== 1) {
2511+
return null;
2512+
}
2513+
captures.push(substitution[0]);
2514+
}
2515+
const context: Array<Place> = [];
2516+
const originalContext = effect.function.loweredFunc.func.context;
2517+
for (let i = 0; i < originalContext.length; i++) {
2518+
const substitution = substitutions.get(
2519+
originalContext[i].identifier.id,
2520+
);
2521+
if (substitution == null || substitution.length !== 1) {
2522+
return null;
2523+
}
2524+
context.push(substitution[0]);
2525+
}
2526+
effects.push({
2527+
kind: 'CreateFunction',
2528+
into: applyInto[0],
2529+
function: {
2530+
...effect.function,
2531+
loweredFunc: {
2532+
...effect.function.loweredFunc,
2533+
func: {
2534+
...effect.function.loweredFunc.func,
2535+
context,
2536+
},
2537+
},
2538+
},
2539+
captures,
25042540
});
2541+
break;
25052542
}
25062543
default: {
25072544
assertExhaustive(

compiler/packages/babel-plugin-react-compiler/src/Inference/InferMutationAliasingRanges.ts

Lines changed: 110 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ export function inferMutationAliasingRanges(
141141
} else if (effect.kind === 'CreateFunction') {
142142
state.create(effect.into, {
143143
kind: 'Function',
144-
function: effect.function.loweredFunc.func,
144+
effect,
145145
});
146146
} else if (effect.kind === 'CreateFrom') {
147147
state.createFrom(index++, effect.from, effect.into);
@@ -156,7 +156,7 @@ export function inferMutationAliasingRanges(
156156
* invariant here.
157157
*/
158158
if (!state.nodes.has(effect.into.identifier)) {
159-
state.create(effect.into, {kind: 'Object'});
159+
state.create(effect.into, {kind: 'Assign'});
160160
}
161161
state.assign(index++, effect.from, effect.into);
162162
} else if (effect.kind === 'Alias') {
@@ -474,35 +474,112 @@ export function inferMutationAliasingRanges(
474474
}
475475
}
476476

477+
const tracked: Array<Place> = [];
478+
for (const param of [...fn.params, ...fn.context, fn.returns]) {
479+
const place = param.kind === 'Identifier' ? param : param.place;
480+
tracked.push(place);
481+
}
482+
483+
const returned: Set<Node> = new Set();
484+
const queue: Array<Node> = [state.nodes.get(fn.returns.identifier)!];
485+
const seen: Set<Node> = new Set();
486+
while (queue.length !== 0) {
487+
const node = queue.pop()!;
488+
if (seen.has(node)) {
489+
continue;
490+
}
491+
seen.add(node);
492+
for (const id of node.aliases.keys()) {
493+
queue.push(state.nodes.get(id)!);
494+
}
495+
for (const id of node.createdFrom.keys()) {
496+
queue.push(state.nodes.get(id)!);
497+
}
498+
if (node.id.id === fn.returns.identifier.id) {
499+
continue;
500+
}
501+
switch (node.value.kind) {
502+
case 'Assign':
503+
case 'CreateFrom': {
504+
break;
505+
}
506+
case 'Phi':
507+
case 'Object':
508+
case 'Function': {
509+
returned.add(node);
510+
break;
511+
}
512+
default: {
513+
assertExhaustive(
514+
node.value,
515+
`Unexpected node value kind '${(node.value as any).kind}'`,
516+
);
517+
}
518+
}
519+
}
520+
const returnedValues = [...returned];
521+
if (
522+
returnedValues.length === 1 &&
523+
returnedValues[0].value.kind === 'Object' &&
524+
tracked.some(place => place.identifier.id === returnedValues[0].id.id)
525+
) {
526+
const from = tracked.find(
527+
place => place.identifier.id === returnedValues[0].id.id,
528+
)!;
529+
functionEffects.push({
530+
kind: 'Assign',
531+
from,
532+
into: fn.returns,
533+
});
534+
} else if (
535+
returnedValues.length === 1 &&
536+
returnedValues[0].value.kind === 'Function'
537+
) {
538+
const outerContext = new Set(fn.context.map(p => p.identifier.id));
539+
const effect = returnedValues[0].value.effect;
540+
functionEffects.push({
541+
kind: 'CreateFunction',
542+
function: {
543+
...effect.function,
544+
loweredFunc: {
545+
func: {
546+
...effect.function.loweredFunc.func,
547+
context: effect.function.loweredFunc.func.context.filter(p =>
548+
outerContext.has(p.identifier.id),
549+
),
550+
},
551+
},
552+
},
553+
captures: effect.captures.filter(p => outerContext.has(p.identifier.id)),
554+
into: fn.returns,
555+
});
556+
} else {
557+
const returns = fn.returns.identifier;
558+
functionEffects.push({
559+
kind: 'Create',
560+
into: fn.returns,
561+
value: isPrimitiveType(returns)
562+
? ValueKind.Primitive
563+
: isJsxType(returns.type)
564+
? ValueKind.Frozen
565+
: ValueKind.Mutable,
566+
reason: ValueReason.KnownReturnSignature,
567+
});
568+
}
569+
477570
/**
478571
* Part 3
479572
* Finish populating the externally visible effects. Above we bubble-up the side effects
480573
* (MutateFrozen/MutableGlobal/Impure/Render) as well as mutations of context variables.
481574
* Here we populate an effect to create the return value as well as populating alias/capture
482575
* effects for how data flows between the params, context vars, and return.
483576
*/
484-
const returns = fn.returns.identifier;
485-
functionEffects.push({
486-
kind: 'Create',
487-
into: fn.returns,
488-
value: isPrimitiveType(returns)
489-
? ValueKind.Primitive
490-
: isJsxType(returns.type)
491-
? ValueKind.Frozen
492-
: ValueKind.Mutable,
493-
reason: ValueReason.KnownReturnSignature,
494-
});
495577
/**
496578
* Determine precise data-flow effects by simulating transitive mutations of the params/
497579
* captures and seeing what other params/context variables are affected. Anything that
498580
* would be transitively mutated needs a capture relationship.
499581
*/
500-
const tracked: Array<Place> = [];
501582
const ignoredErrors = new CompilerError();
502-
for (const param of [...fn.params, ...fn.context, fn.returns]) {
503-
const place = param.kind === 'Identifier' ? param : param.place;
504-
tracked.push(place);
505-
}
506583
for (const into of tracked) {
507584
const mutationIndex = index++;
508585
state.mutate(
@@ -588,9 +665,14 @@ type Node = {
588665
lastMutated: number;
589666
mutationReason: MutationReason | null;
590667
value:
668+
| {kind: 'Assign'}
669+
| {kind: 'CreateFrom'}
591670
| {kind: 'Object'}
592671
| {kind: 'Phi'}
593-
| {kind: 'Function'; function: HIRFunction};
672+
| {
673+
kind: 'Function';
674+
effect: Extract<AliasingEffect, {kind: 'CreateFunction'}>;
675+
};
594676
};
595677
class AliasingState {
596678
nodes: Map<Identifier, Node> = new Map();
@@ -612,7 +694,7 @@ class AliasingState {
612694
}
613695

614696
createFrom(index: number, from: Place, into: Place): void {
615-
this.create(into, {kind: 'Object'});
697+
this.create(into, {kind: 'CreateFrom'});
616698
const fromNode = this.nodes.get(from.identifier);
617699
const toNode = this.nodes.get(into.identifier);
618700
if (fromNode == null || toNode == null) {
@@ -674,7 +756,10 @@ class AliasingState {
674756
continue;
675757
}
676758
if (node.value.kind === 'Function') {
677-
appendFunctionErrors(errors, node.value.function);
759+
appendFunctionErrors(
760+
errors,
761+
node.value.effect.function.loweredFunc.func,
762+
);
678763
}
679764
for (const [alias, when] of node.createdFrom) {
680765
if (when >= index) {
@@ -738,7 +823,10 @@ class AliasingState {
738823
node.transitive == null &&
739824
node.local == null
740825
) {
741-
appendFunctionErrors(errors, node.value.function);
826+
appendFunctionErrors(
827+
errors,
828+
node.value.effect.function.loweredFunc.func,
829+
);
742830
}
743831
if (transitive) {
744832
if (node.transitive == null || node.transitive.kind < kind) {

compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/repro-returned-inner-fn-reassigns-context.expect.md

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ import { makeArray, Stringify, useIdentity } from "shared-runtime";
5555
*/
5656
function Foo(t0) {
5757
"use memo";
58-
const $ = _c(3);
58+
const $ = _c(5);
5959
const { b } = t0;
6060

6161
const fnFactory = () => () => {
@@ -66,18 +66,26 @@ function Foo(t0) {
6666
useIdentity();
6767

6868
const fn = fnFactory();
69-
const arr = makeArray(b);
70-
fn(arr);
7169
let t1;
72-
if ($[0] !== arr || $[1] !== myVar) {
73-
t1 = <Stringify cb={myVar} value={arr} shouldInvokeFns={true} />;
74-
$[0] = arr;
75-
$[1] = myVar;
76-
$[2] = t1;
70+
if ($[0] !== b) {
71+
t1 = makeArray(b);
72+
$[0] = b;
73+
$[1] = t1;
74+
} else {
75+
t1 = $[1];
76+
}
77+
const arr = t1;
78+
fn(arr);
79+
let t2;
80+
if ($[2] !== arr || $[3] !== myVar) {
81+
t2 = <Stringify cb={myVar} value={arr} shouldInvokeFns={true} />;
82+
$[2] = arr;
83+
$[3] = myVar;
84+
$[4] = t2;
7785
} else {
78-
t1 = $[2];
86+
t2 = $[4];
7987
}
80-
return t1;
88+
return t2;
8189
}
8290
function _temp2() {
8391
return console.log("b");

0 commit comments

Comments
 (0)