Skip to content

Commit

Permalink
[compiler] Emit Let Bindings Iteratively (hail-is#14163)
Browse files Browse the repository at this point in the history
Previously `Emit(?:Stream)?$` would emit let bindings recursively,
regardless of whether that binding was used.
If a stream is not used, `Emit(?:Stream)?$` would define its missing
labels, making emission recursive.
This can lead to stack overflows for large numbers of let-bindings (and
does so for the benchmark `matrix-multi-write-nothing`).

By not emitting unused streams, we can make let-binding emission
iterative.
  • Loading branch information
ehigham authored Jan 22, 2024
1 parent 2858259 commit 8ae336f
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 60 deletions.
88 changes: 54 additions & 34 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ case class EmitEnv(bindings: Env[EmitValue], inputValues: IndexedSeq[EmitValue])
}
(paramTypes, params, recreateFromMB)
}

}

object Emit {
Expand Down Expand Up @@ -675,11 +676,7 @@ abstract class EstimableEmitter[C] {
def estimatedSize: Int
}

class Emit[C](
val ctx: EmitContext,
val cb: EmitClassBuilder[C],
) {
emitSelf =>
class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {

val methods: mutable.Map[(String, Seq[Type], Seq[SType], SType), EmitMethodBuilder[C]] =
mutable.Map()
Expand Down Expand Up @@ -801,6 +798,7 @@ class Emit[C](

def emitI(
ir: IR,
cb: EmitCodeBuilder = cb,
region: Value[Region] = region,
env: EmitEnv = env,
container: Option[AggContainer] = container,
Expand Down Expand Up @@ -840,19 +838,17 @@ class Emit[C](

emitI(cond).consume(cb, {}, m => cb.if_(m.asBoolean.value, emitVoid(cnsq), emitVoid(altr)))

case Let(bindings, body) =>
def go(env: EmitEnv): IndexedSeq[(String, IR)] => Unit = {
case (name, value) +: rest =>
val xVal =
if (value.typ.isInstanceOf[TStream]) emitStream(value, region, env = env)
else emit(value, env = env)

cb.withScopedMaybeStreamValue(xVal, s"let_$name")(ev => go(env.bind(name, ev))(rest))
case Seq() =>
emitVoid(body, env = env)
}

go(env)(bindings)
case let: Let =>
emitLet(
emitI = (ir, cb, env) =>
if (ir.typ.isInstanceOf[TStream]) emitStream(ir, region, env = env).toI(cb)
else emitI(ir, cb = cb, env = env),
emitBody = (ir, cb, env) => emitVoid(ir, cb, env = env),
)(
let,
cb,
env,
)

case StreamFor(a, valueName, body) =>
emitStream(a, region).toI(cb).consume(
Expand Down Expand Up @@ -1448,7 +1444,7 @@ class Emit[C](
sorter.sort(
cb,
region,
makeDependentSortingFunction(cb, sct, lessThan, env, emitSelf, Array(left, right)),
makeDependentSortingFunction(cb, sct, lessThan, env, this, Array(left, right)),
)
sorter.toRegion(cb, x.typ)
}
Expand Down Expand Up @@ -3559,22 +3555,18 @@ class Emit[C](

val result: EmitCode = (ir: @unchecked) match {

case Let(bindings, body) =>
case let: Let =>
EmitCode.fromI(mb) { cb =>
def go(env: EmitEnv): IndexedSeq[(String, IR)] => IEmitCode = {
case (name, value) +: rest =>
val xVal =
if (value.typ.isInstanceOf[TStream]) emitStream(value, region, env = env)
else emit(value, env = env)

cb.withScopedMaybeStreamValue(xVal, s"let_$name") { ev =>
go(env.bind(name, ev))(rest)
}
case Seq() =>
emitI(body, cb, env = env)
}

go(env)(bindings)
emitLet(
emitI = (ir, cb, env) =>
if (ir.typ.isInstanceOf[TStream]) emitStream(ir, region, env = env).toI(cb)
else emitI(ir, cb = cb, env = env),
emitBody = (ir, cb, env) => emitI(ir, cb, env = env),
)(
let,
cb,
env,
)
}

case Ref(name, t) =>
Expand Down Expand Up @@ -3701,6 +3693,34 @@ class Emit[C](
(cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]) =>
cb.memoize(cb.invokeCode[Boolean](sort, cb.this_, region, l, r))
}

/** Emits a [[Let]] by folding over its bindings in order and then emitting its body.
  *
  * A binding is only emitted when its name appears among the uses recorded for `let` in
  * `ctx.usesAndDefs`; other bindings are skipped and leave the environment untouched. Each
  * emitted binding is memoized through `cb.memoizeMaybeStreamValue` under the name
  * `let_<name>` and bound in the environment seen by the later bindings and by the body.
  *
  * @param emitI    emits a single binding's IR in the environment accumulated so far
  * @param emitBody emits the body of the let in the fully extended environment
  * @param let      the let node being emitted
  * @param cb       the code builder passed to every emission
  * @param env      the environment in effect before any binding is added
  * @return whatever `emitBody` produces for `let.body`
  */
def emitLet[A](
  emitI: (IR, EmitCodeBuilder, EmitEnv) => IEmitCode,
  emitBody: (IR, EmitCodeBuilder, EmitEnv) => A,
)(
  let: Let,
  cb: EmitCodeBuilder,
  env: EmitEnv,
): A = {
  // Names referenced under this let; an absent entry means nothing is referenced.
  val referenced: mutable.Set[String] =
    ctx.usesAndDefs.uses.get(let).fold(mutable.Set.empty[String])(_.map(_.t.name))

  // Extend the environment one binding at a time, in declaration order, so that each
  // binding's IR is emitted against the bindings that precede it.
  val bodyEnv: EmitEnv =
    let.bindings.foldLeft(env) { (acc, binding) =>
      val (name, ir) = binding
      if (referenced.contains(name)) {
        val emitted = emitI(ir, cb, acc)
        acc.bind(name, cb.memoizeMaybeStreamValue(emitted, s"let_$name"))
      } else acc
    }

  emitBody(let.body, cb, bodyEnv)
}
}

object NDArrayEmitter {
Expand Down
31 changes: 17 additions & 14 deletions hail/src/main/scala/is/hail/expr/ir/EmitCodeBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,24 +160,27 @@ class EmitCodeBuilder(val emb: EmitMethodBuilder[_], var code: Code[Unit]) exten
}

def withScopedMaybeStreamValue[T](ec: EmitCode, name: String)(f: EmitValue => T): T = {
if (ec.st.isRealizable) {
f(memoizeField(ec, name))
} else {
assert(ec.st.isInstanceOf[SStream])
val ev = if (ec.required)
EmitValue(None, ec.toI(this).get(this, ""))
val ev = memoizeMaybeStreamValue(ec.toI(this), name)
val res = f(ev)
ec.pv match {
case ss: SStreamValue =>
ss.defineUnusedLabels(emb)
case _ =>
}
res
}

def memoizeMaybeStreamValue(iec: IEmitCode, name: String): EmitValue =
if (iec.st.isRealizable) memoizeField(iec, name)
else {
assert(iec.st.isInstanceOf[SStream])
if (iec.required) EmitValue(None, iec.get(this, ""))
else {
val m = emb.genFieldThisRef[Boolean](name + "_missing")
ec.toI(this).consume(this, assign(m, true), _ => assign(m, false))
EmitValue(Some(m), ec.pv)
}
val res = f(ev)
ec.pv match {
case ss: SStreamValue => ss.defineUnusedLabels(emb)
iec.consume(this, assign(m, true), _ => assign(m, false))
EmitValue(Some(m), iec.value)
}
res
}
}

def memoizeField(v: IEmitCode, name: String): EmitValue = {
require(v.st.isRealizable)
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Env.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class Env[V] private (val m: Map[Env.K, V]) {
def apply(name: String): V = m(name)

def lookup(name: String): V =
m.get(name).getOrElse(throw new RuntimeException(s"Cannot find $name in $m"))
m.getOrElse(name, throw new RuntimeException(s"Cannot find $name in $m"))

def lookupOption(name: String): Option[V] = m.get(name)

Expand Down
20 changes: 9 additions & 11 deletions hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -364,17 +364,15 @@ object EmitStream {
SStreamValue(producer)
}

case Let(bindings, body) =>
def go(env: EmitEnv): IndexedSeq[(String, IR)] => IEmitCode = {
case (name, value) +: rest =>
cb.withScopedMaybeStreamValue(
EmitCode.fromI(cb.emb)(cb => emit(value, cb, env = env)),
s"let_$name",
)(ev => go(env.bind(name, ev))(rest))
case Seq() =>
produce(body, cb, env = env)
}
go(env)(bindings)
case let: Let =>
emitter.emitLet(
emitI = (ir, cb, env) => emit(ir, cb, env = env),
emitBody = (ir, cb, env) => produce(ir, cb, env = env),
)(
let,
cb,
env,
)

case In(n, _) =>
// this, Code[Region], ...
Expand Down

0 comments on commit 8ae336f

Please sign in to comment.