Skip to content

Commit

Permalink
feat: relaxed reset/reuse in the code generator
Browse files Browse the repository at this point in the history
closes #4089
  • Loading branch information
leodemoura committed May 7, 2024
1 parent 883a3e7 commit a18e517
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 22 deletions.
88 changes: 66 additions & 22 deletions src/Lean/Compiler/IR/ResetReuse.lean
Original file line number Diff line number Diff line change
Expand Up @@ -29,43 +29,50 @@ Here are the main differences:
does not occur in a function body. See example at `livevars.lean`.
-/

private def mayReuse (c₁ c₂ : CtorInfo) : Bool :=
private def mayReuse (c₁ c₂ : CtorInfo) (relaxedReuse : Bool) : Bool :=
c₁.size == c₂.size && c₁.usize == c₂.usize && c₁.ssize == c₂.ssize &&
/- The following condition is a heuristic.
We don't want to reuse cells from different types even when they are compatible
If `relaxedReuse := false`, then we don't want to reuse cells from
different constructors even when they are compatible
because it produces counterintuitive behavior. -/
c₁.name.getPrefix == c₂.name.getPrefix
(relaxedReuse || c₁.name.getPrefix == c₂.name.getPrefix)

/--
Replace `ctor` applications with `reuse` applications if compatible.
`w` contains the "memory cell" being reused.
-/
private partial def S (w : VarId) (c : CtorInfo) : FnBody → FnBody
private partial def S (w : VarId) (c : CtorInfo) (relaxedReuse : Bool) (b : FnBody) : FnBody :=
go b
where
go : FnBody → FnBody
| .vdecl x t v@(.ctor c' ys) b =>
if mayReuse c c' then
if mayReuse c c' relaxedReuse then
let updtCidx := c.cidx != c'.cidx
.vdecl x t (.reuse w c' updtCidx ys) b
else
.vdecl x t v (S w c b)
.vdecl x t v (go b)
| .jdecl j ys v b =>
let v' := S w c v
let v' := go v
if v == v' then
.jdecl j ys v (S w c b)
.jdecl j ys v (go b)
else
.jdecl j ys v' b
| .case tid x xType alts =>
.case tid x xType <| alts.map fun alt => alt.modifyBody (S w c)
.case tid x xType <| alts.map fun alt => alt.modifyBody go
| b =>
if b.isTerminal then
b
else let
(instr, b) := b.split
instr.setBody (S w c b)
else
let (instr, b) := b.split
instr.setBody (go b)

structure Context where
lctx : LocalContext := {}
/--
Contains all variables in `cases` statements in the current path.
Contains all variables in `cases` statements in the current path
and variables that are already in `reset` statements when we
invoke `R`.
We use this information to prevent double-reset in code such as
```
case x_i : obj of
Expand All @@ -74,8 +81,18 @@ structure Context where
Prod.mk →
...
```
A variable can already be in a `reset` statement when we
invoke `R` because we execute it with and without `relaxedReuse`.
-/
casesVars : PHashSet VarId := {}
alreadyFound : PHashSet VarId := {}
/--
If `relaxedReuse := true`, then allow memory cells from different
constructors to be reused. For example, we can reuse a `PSigma.mk`
to allocate a `Prod.mk`. To avoid counterintuitive behavior,
we first try `relaxedReuse := false`, and then `relaxedReuse := true`.
-/
relaxedReuse : Bool := false

/-- We use `Context` to track join points in scope. -/
abbrev M := ReaderT Context (StateT Index Id)
Expand All @@ -90,7 +107,7 @@ to replace a `ctor` withe `reuse` in `b`.
-/
private def tryS (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody := do
let w ← mkFresh
let b' := S w c b
let b' := S w c (← read).relaxedReuse b
if b == b' then
return b
else
Expand All @@ -102,8 +119,8 @@ private def Dfinalize (x : VarId) (c : CtorInfo) : FnBody × Bool → M FnBody

private def argsContainsVar (ys : Array Arg) (x : VarId) : Bool :=
ys.any fun arg => match arg with
| Arg.var y => x == y
| _ => false
| .var y => x == y
| _ => false

private def isCtorUsing (b : FnBody) (x : VarId) : Bool :=
match b with
Expand Down Expand Up @@ -161,8 +178,8 @@ private def D (x : VarId) (c : CtorInfo) (b : FnBody) : M FnBody :=
partial def R (e : FnBody) : M FnBody := do
match e with
| .case tid x xType alts =>
let alreadyFound := (← read).casesVars.contains x
withReader (fun ctx => { ctx with casesVars := ctx.casesVars.insert x }) do
let alreadyFound := (← read).alreadyFound.contains x
withReader (fun ctx => { ctx with alreadyFound := ctx.alreadyFound.insert x }) do
let alts ← alts.mapM fun alt => do
let alt ← alt.mmodifyBody R
match alt with
Expand All @@ -187,16 +204,43 @@ partial def R (e : FnBody) : M FnBody := do
let b ← R b
return instr.setBody b

end ResetReuse
abbrev N := StateT (PHashSet VarId) Id

partial def collectResets (e : FnBody) : N Unit := do
match e with
| .case _ _ _ alts => alts.forM fun alt => collectResets alt.body
| .jdecl _ _ v b => collectResets v; collectResets b
| .vdecl _ _ (.reset _ x) b => modify fun s => s.insert x; collectResets b
| e => unless e.isTerminal do
let (_, b) := e.split
collectResets b

end ResetReuse
open ResetReuse

def Decl.insertResetReuse (d : Decl) : Decl :=

def Decl.insertResetReuseCore (d : Decl) (relaxedReuse : Bool) : Decl :=
match d with
| .fdecl (body := b) .. =>
let nextIndex := d.maxIndex + 1
let bNew := (R b {}).run' nextIndex
-- First time we execute `insertResetReuseCore`, `relaxedReuse := false`.
let alreadyFound : PHashSet VarId := if relaxedReuse then (collectResets b *> get).run' {} else {}
let bNew := R b { relaxedReuse, alreadyFound } |>.run' nextIndex
d.updateBody! bNew
| other => other

def Decl.insertResetReuse (d : Decl) : Decl :=
/-
We execute the reset/reuse algorithm twice. The first time, we only reuse memory cells
between identical constructor memory cells. That is, we do not reuse a `PSigma.mk` memory cell
when allocating a `Prod.mk` memory cell, even though they have the same layout. Recall
that the reset/reuse placement algorithm is a heuristic, and the first pass prevents reuses
that are unlikely to be useful at runtime. Then, we run the procedure again,
relaxing this restriction. If there are still opportunities for reuse, we will take advantage of them.
The second pass addresses issue #4089.
-/
d.insertResetReuseCore (relaxedReuse := false)
|>.insertResetReuseCore (relaxedReuse := true)

end Lean.IR
12 changes: 12 additions & 0 deletions tests/lean/4089.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
set_option trace.compiler.ir.reset_reuse true

def f : Nat × Nat → Nat × Nat
| (a, b) => (b, a)

def Sigma.toProd : (_ : α) × β → α × β
| ⟨a, b⟩ => (a, b)

def foo : List (Nat × Nat) → List Nat
| [] => []
| x :: xs => match x with
| (a, _) => a :: foo xs
38 changes: 38 additions & 0 deletions tests/lean/4089.lean.expected.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

[reset_reuse]
def f (x_1 : obj) : obj :=
case x_1 : obj of
Prod.mk →
let x_2 : obj := proj[0] x_1;
let x_3 : obj := proj[1] x_1;
let x_5 : obj := reset[2] x_1;
let x_4 : obj := reuse x_5 in ctor_0[Prod.mk] x_3 x_2;
ret x_4
[reset_reuse]
def Sigma.toProd._rarg (x_1 : obj) : obj :=
case x_1 : obj of
Sigma.mk →
let x_2 : obj := proj[0] x_1;
let x_3 : obj := proj[1] x_1;
let x_5 : obj := reset[2] x_1;
let x_4 : obj := reuse x_5 in ctor_0[Prod.mk] x_2 x_3;
ret x_4
def Sigma.toProd (x_1 : ◾) (x_2 : ◾) : obj :=
let x_3 : obj := pap Sigma.toProd._rarg;
ret x_3
[reset_reuse]
def foo (x_1 : obj) : obj :=
case x_1 : obj of
List.nil →
let x_2 : obj := ctor_0[List.nil];
ret x_2
List.cons →
let x_3 : obj := proj[0] x_1;
case x_3 : obj of
Prod.mk →
let x_4 : obj := proj[1] x_1;
let x_9 : obj := reset[2] x_1;
let x_5 : obj := proj[0] x_3;
let x_6 : obj := foo x_4;
let x_7 : obj := reuse x_9 in ctor_1[List.cons] x_5 x_6;
ret x_7

0 comments on commit a18e517

Please sign in to comment.