From 8ae336f6811f4d4435666b0c616693ba37cf2c71 Mon Sep 17 00:00:00 2001 From: Edmund Higham Date: Mon, 22 Jan 2024 13:12:07 -0500 Subject: [PATCH] [compiler] Emit `Let` Bindings Iteratively (#14163) Previously `Emit(?:Stream)?$` would emit let bindings recursively, regardless of whether that binding was used. If a stream is not used, `Emit(?:Stream)?$` would define its missing labels, making emission recursive. This can lead to stack overflows for large numbers of let-bindings (and does so for the benchmark `matrix-multi-write-nothing`). By not emitting unused streams, we can make let-binding emission iterative. --- .../src/main/scala/is/hail/expr/ir/Emit.scala | 88 ++++++++++++------- .../is/hail/expr/ir/EmitCodeBuilder.scala | 31 ++++--- hail/src/main/scala/is/hail/expr/ir/Env.scala | 2 +- .../is/hail/expr/ir/streams/EmitStream.scala | 20 ++--- 4 files changed, 81 insertions(+), 60 deletions(-) diff --git a/hail/src/main/scala/is/hail/expr/ir/Emit.scala b/hail/src/main/scala/is/hail/expr/ir/Emit.scala index 6397686596c..2061e685536 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Emit.scala @@ -88,6 +88,7 @@ case class EmitEnv(bindings: Env[EmitValue], inputValues: IndexedSeq[EmitValue]) } (paramTypes, params, recreateFromMB) } + } object Emit { @@ -675,11 +676,7 @@ abstract class EstimableEmitter[C] { def estimatedSize: Int } -class Emit[C]( - val ctx: EmitContext, - val cb: EmitClassBuilder[C], -) { - emitSelf => +class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { val methods: mutable.Map[(String, Seq[Type], Seq[SType], SType), EmitMethodBuilder[C]] = mutable.Map() @@ -801,6 +798,7 @@ class Emit[C]( def emitI( ir: IR, + cb: EmitCodeBuilder = cb, region: Value[Region] = region, env: EmitEnv = env, container: Option[AggContainer] = container, @@ -840,19 +838,17 @@ class Emit[C]( emitI(cond).consume(cb, {}, m => cb.if_(m.asBoolean.value, emitVoid(cnsq), emitVoid(altr))) - case Let(bindings, 
body) => - def go(env: EmitEnv): IndexedSeq[(String, IR)] => Unit = { - case (name, value) +: rest => - val xVal = - if (value.typ.isInstanceOf[TStream]) emitStream(value, region, env = env) - else emit(value, env = env) - - cb.withScopedMaybeStreamValue(xVal, s"let_$name")(ev => go(env.bind(name, ev))(rest)) - case Seq() => - emitVoid(body, env = env) - } - - go(env)(bindings) + case let: Let => + emitLet( + emitI = (ir, cb, env) => + if (ir.typ.isInstanceOf[TStream]) emitStream(ir, region, env = env).toI(cb) + else emitI(ir, cb = cb, env = env), + emitBody = (ir, cb, env) => emitVoid(ir, cb, env = env), + )( + let, + cb, + env, + ) case StreamFor(a, valueName, body) => emitStream(a, region).toI(cb).consume( @@ -1448,7 +1444,7 @@ class Emit[C]( sorter.sort( cb, region, - makeDependentSortingFunction(cb, sct, lessThan, env, emitSelf, Array(left, right)), + makeDependentSortingFunction(cb, sct, lessThan, env, this, Array(left, right)), ) sorter.toRegion(cb, x.typ) } @@ -3559,22 +3555,18 @@ class Emit[C]( val result: EmitCode = (ir: @unchecked) match { - case Let(bindings, body) => + case let: Let => EmitCode.fromI(mb) { cb => - def go(env: EmitEnv): IndexedSeq[(String, IR)] => IEmitCode = { - case (name, value) +: rest => - val xVal = - if (value.typ.isInstanceOf[TStream]) emitStream(value, region, env = env) - else emit(value, env = env) - - cb.withScopedMaybeStreamValue(xVal, s"let_$name") { ev => - go(env.bind(name, ev))(rest) - } - case Seq() => - emitI(body, cb, env = env) - } - - go(env)(bindings) + emitLet( + emitI = (ir, cb, env) => + if (ir.typ.isInstanceOf[TStream]) emitStream(ir, region, env = env).toI(cb) + else emitI(ir, cb = cb, env = env), + emitBody = (ir, cb, env) => emitI(ir, cb, env = env), + )( + let, + cb, + env, + ) } case Ref(name, t) => @@ -3701,6 +3693,34 @@ class Emit[C]( (cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]) => cb.memoize(cb.invokeCode[Boolean](sort, cb.this_, region, l, r)) } + + def emitLet[A]( + emitI: 
(IR, EmitCodeBuilder, EmitEnv) => IEmitCode, + emitBody: (IR, EmitCodeBuilder, EmitEnv) => A, + )( + let: Let, + cb: EmitCodeBuilder, + env: EmitEnv, + ): A = { + val uses: mutable.Set[String] = + ctx.usesAndDefs.uses.get(let) match { + case Some(refs) => refs.map(_.t.name) + case None => mutable.Set.empty + } + + emitBody( + let.body, + cb, + let.bindings.foldLeft(env) { case (newEnv, (name, ir)) => + if (!uses.contains(name)) newEnv + else { + val value = emitI(ir, cb, newEnv) + val memo = cb.memoizeMaybeStreamValue(value, s"let_$name") + newEnv.bind(name, memo) + } + }, + ) + } } object NDArrayEmitter { diff --git a/hail/src/main/scala/is/hail/expr/ir/EmitCodeBuilder.scala b/hail/src/main/scala/is/hail/expr/ir/EmitCodeBuilder.scala index 0d9a1f726ab..b386f374de5 100644 --- a/hail/src/main/scala/is/hail/expr/ir/EmitCodeBuilder.scala +++ b/hail/src/main/scala/is/hail/expr/ir/EmitCodeBuilder.scala @@ -160,24 +160,27 @@ class EmitCodeBuilder(val emb: EmitMethodBuilder[_], var code: Code[Unit]) exten } def withScopedMaybeStreamValue[T](ec: EmitCode, name: String)(f: EmitValue => T): T = { - if (ec.st.isRealizable) { - f(memoizeField(ec, name)) - } else { - assert(ec.st.isInstanceOf[SStream]) - val ev = if (ec.required) - EmitValue(None, ec.toI(this).get(this, "")) + val ev = memoizeMaybeStreamValue(ec.toI(this), name) + val res = f(ev) + ec.pv match { + case ss: SStreamValue => + ss.defineUnusedLabels(emb) + case _ => + } + res + } + + def memoizeMaybeStreamValue(iec: IEmitCode, name: String): EmitValue = + if (iec.st.isRealizable) memoizeField(iec, name) + else { + assert(iec.st.isInstanceOf[SStream]) + if (iec.required) EmitValue(None, iec.get(this, "")) else { val m = emb.genFieldThisRef[Boolean](name + "_missing") - ec.toI(this).consume(this, assign(m, true), _ => assign(m, false)) - EmitValue(Some(m), ec.pv) - } - val res = f(ev) - ec.pv match { - case ss: SStreamValue => ss.defineUnusedLabels(emb) + iec.consume(this, assign(m, true), _ => assign(m, false)) + 
EmitValue(Some(m), iec.value) } - res } - } def memoizeField(v: IEmitCode, name: String): EmitValue = { require(v.st.isRealizable) diff --git a/hail/src/main/scala/is/hail/expr/ir/Env.scala b/hail/src/main/scala/is/hail/expr/ir/Env.scala index bd2a40384cc..8a6783ec9c1 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Env.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Env.scala @@ -150,7 +150,7 @@ class Env[V] private (val m: Map[Env.K, V]) { def apply(name: String): V = m(name) def lookup(name: String): V = - m.get(name).getOrElse(throw new RuntimeException(s"Cannot find $name in $m")) + m.getOrElse(name, throw new RuntimeException(s"Cannot find $name in $m")) def lookupOption(name: String): Option[V] = m.get(name) diff --git a/hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala b/hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala index e64f790bda5..710584f2503 100644 --- a/hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala +++ b/hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala @@ -364,17 +364,15 @@ object EmitStream { SStreamValue(producer) } - case Let(bindings, body) => - def go(env: EmitEnv): IndexedSeq[(String, IR)] => IEmitCode = { - case (name, value) +: rest => - cb.withScopedMaybeStreamValue( - EmitCode.fromI(cb.emb)(cb => emit(value, cb, env = env)), - s"let_$name", - )(ev => go(env.bind(name, ev))(rest)) - case Seq() => - produce(body, cb, env = env) - } - go(env)(bindings) + case let: Let => + emitter.emitLet( + emitI = (ir, cb, env) => emit(ir, cb, env = env), + emitBody = (ir, cb, env) => produce(ir, cb, env = env), + )( + let, + cb, + env, + ) case In(n, _) => // this, Code[Region], ...