Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] Emit Let Bindings Iteratively #14163

Merged
merged 11 commits into from
Jan 22, 2024
88 changes: 54 additions & 34 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ case class EmitEnv(bindings: Env[EmitValue], inputValues: IndexedSeq[EmitValue])
}
(paramTypes, params, recreateFromMB)
}

}

object Emit {
Expand Down Expand Up @@ -674,11 +675,7 @@ abstract class EstimableEmitter[C] {
def estimatedSize: Int
}

class Emit[C](
val ctx: EmitContext,
val cb: EmitClassBuilder[C],
) {
emitSelf =>
class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) {

val methods: mutable.Map[(String, Seq[Type], Seq[SType], SType), EmitMethodBuilder[C]] =
mutable.Map()
Expand Down Expand Up @@ -800,6 +797,7 @@ class Emit[C](

def emitI(
ir: IR,
cb: EmitCodeBuilder = cb,
region: Value[Region] = region,
env: EmitEnv = env,
container: Option[AggContainer] = container,
Expand Down Expand Up @@ -839,19 +837,17 @@ class Emit[C](

emitI(cond).consume(cb, {}, m => cb.if_(m.asBoolean.value, emitVoid(cnsq), emitVoid(altr)))

case Let(bindings, body) =>
def go(env: EmitEnv): IndexedSeq[(String, IR)] => Unit = {
case (name, value) +: rest =>
val xVal =
if (value.typ.isInstanceOf[TStream]) emitStream(value, region, env = env)
else emit(value, env = env)

cb.withScopedMaybeStreamValue(xVal, s"let_$name")(ev => go(env.bind(name, ev))(rest))
case Seq() =>
emitVoid(body, env = env)
}

go(env)(bindings)
case let: Let =>
emitLet(
emitI = (ir, cb, env) =>
if (ir.typ.isInstanceOf[TStream]) emitStream(ir, region, env = env).toI(cb)
else emitI(ir, cb = cb, env = env),
emitBody = (ir, cb, env) => emitVoid(ir, cb, env = env),
)(
let,
cb,
env,
)

case StreamFor(a, valueName, body) =>
emitStream(a, region).toI(cb).consume(
Expand Down Expand Up @@ -1447,7 +1443,7 @@ class Emit[C](
sorter.sort(
cb,
region,
makeDependentSortingFunction(cb, sct, lessThan, env, emitSelf, Array(left, right)),
makeDependentSortingFunction(cb, sct, lessThan, env, this, Array(left, right)),
)
sorter.toRegion(cb, x.typ)
}
Expand Down Expand Up @@ -3558,22 +3554,18 @@ class Emit[C](

val result: EmitCode = (ir: @unchecked) match {

case Let(bindings, body) =>
case let: Let =>
EmitCode.fromI(mb) { cb =>
def go(env: EmitEnv): IndexedSeq[(String, IR)] => IEmitCode = {
case (name, value) +: rest =>
val xVal =
if (value.typ.isInstanceOf[TStream]) emitStream(value, region, env = env)
else emit(value, env = env)

cb.withScopedMaybeStreamValue(xVal, s"let_$name") { ev =>
go(env.bind(name, ev))(rest)
}
case Seq() =>
emitI(body, cb, env = env)
}

go(env)(bindings)
emitLet(
emitI = (ir, cb, env) =>
if (ir.typ.isInstanceOf[TStream]) emitStream(ir, region, env = env).toI(cb)
else emitI(ir, cb = cb, env = env),
emitBody = (ir, cb, env) => emitI(ir, cb, env = env),
)(
let,
cb,
env,
)
}

case Ref(name, t) =>
Expand Down Expand Up @@ -3700,6 +3692,34 @@ class Emit[C](
(cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]) =>
cb.memoize(cb.invokeCode[Boolean](sort, cb.this_, region, l, r))
}

def emitLet[A](
emitI: (IR, EmitCodeBuilder, EmitEnv) => IEmitCode,
emitBody: (IR, EmitCodeBuilder, EmitEnv) => A,
)(
let: Let,
cb: EmitCodeBuilder,
env: EmitEnv,
): A = {
val uses: mutable.Set[String] =
ctx.usesAndDefs.uses.get(let) match {
case Some(refs) => refs.map(_.t.name)
case None => mutable.Set.empty
}

emitBody(
let.body,
cb,
let.bindings.foldLeft(env) { case (newEnv, (name, ir)) =>
if (!uses.contains(name)) newEnv
else {
val value = emitI(ir, cb, newEnv)
val memo = cb.memoizeMaybeStreamValue(value, s"let_$name")
newEnv.bind(name, memo)
}
},
)
}
}

object NDArrayEmitter {
Expand Down
31 changes: 17 additions & 14 deletions hail/src/main/scala/is/hail/expr/ir/EmitCodeBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,24 +160,27 @@ class EmitCodeBuilder(val emb: EmitMethodBuilder[_], var code: Code[Unit]) exten
}

def withScopedMaybeStreamValue[T](ec: EmitCode, name: String)(f: EmitValue => T): T = {
if (ec.st.isRealizable) {
f(memoizeField(ec, name))
} else {
assert(ec.st.isInstanceOf[SStream])
val ev = if (ec.required)
EmitValue(None, ec.toI(this).get(this, ""))
val ev = memoizeMaybeStreamValue(ec.toI(this), name)
val res = f(ev)
ec.pv match {
case ss: SStreamValue =>
ss.defineUnusedLabels(emb)
case _ =>
}
res
}

def memoizeMaybeStreamValue(iec: IEmitCode, name: String): EmitValue =
if (iec.st.isRealizable) memoizeField(iec, name)
else {
assert(iec.st.isInstanceOf[SStream])
if (iec.required) EmitValue(None, iec.get(this, ""))
else {
val m = emb.genFieldThisRef[Boolean](name + "_missing")
ec.toI(this).consume(this, assign(m, true), _ => assign(m, false))
EmitValue(Some(m), ec.pv)
}
val res = f(ev)
ec.pv match {
case ss: SStreamValue => ss.defineUnusedLabels(emb)
iec.consume(this, assign(m, true), _ => assign(m, false))
EmitValue(Some(m), iec.value)
}
res
}
}

def memoizeField(v: IEmitCode, name: String): EmitValue = {
require(v.st.isRealizable)
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Env.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class Env[V] private (val m: Map[Env.K, V]) {
def apply(name: String): V = m(name)

def lookup(name: String): V =
m.get(name).getOrElse(throw new RuntimeException(s"Cannot find $name in $m"))
m.getOrElse(name, throw new RuntimeException(s"Cannot find $name in $m"))

def lookupOption(name: String): Option[V] = m.get(name)

Expand Down
20 changes: 9 additions & 11 deletions hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -365,17 +365,15 @@ object EmitStream {
SStreamValue(producer)
}

case Let(bindings, body) =>
def go(env: EmitEnv): IndexedSeq[(String, IR)] => IEmitCode = {
case (name, value) +: rest =>
cb.withScopedMaybeStreamValue(
EmitCode.fromI(cb.emb)(cb => emit(value, cb, env = env)),
s"let_$name",
)(ev => go(env.bind(name, ev))(rest))
case Seq() =>
produce(body, cb, env = env)
}
go(env)(bindings)
case let: Let =>
emitter.emitLet(
emitI = (ir, cb, env) => emit(ir, cb, env = env),
emitBody = (ir, cb, env) => produce(ir, cb, env = env),
)(
let,
cb,
env,
)

case In(n, _) =>
// this, Code[Region], ...
Expand Down