Skip to content

Commit

Permalink
StagedPriorityQueue is independent of Interval
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Sep 13, 2023
1 parent ed76ebb commit 2d18dd9
Show file tree
Hide file tree
Showing 5 changed files with 356 additions and 219 deletions.
28 changes: 24 additions & 4 deletions hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ import scala.collection.mutable
import scala.language.existentials

class EmitModuleBuilder(val ctx: ExecuteContext, val modb: ModuleBuilder) {

def getOrEmitNewClass[C: TypeInfo](name: String, sourceFile: Option[String] = None)
(body: EmitClassBuilder[C] => Unit)
: EmitClassBuilder[C] =
modb
.classes
.find(kb => kb.className == name && kb.sourceFile == sourceFile)
.map(kb => new EmitClassBuilder[C](this, kb.asInstanceOf[ClassBuilder[C]]))
.getOrElse {
val kb = newEmitClass[C](name, sourceFile)
body(kb)
kb
}


def newEmitClass[C](name: String, sourceFile: Option[String] = None)(implicit cti: TypeInfo[C]): EmitClassBuilder[C] =
new EmitClassBuilder(this, modb.newClass(name, sourceFile))

Expand Down Expand Up @@ -810,16 +825,21 @@ class EmitClassBuilder[C](

private[this] val methodMemo: mutable.Map[Any, EmitMethodBuilder[C]] = mutable.Map()

def getOrGenEmitMethod(
baseName: String, key: Any, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType
)(body: EmitMethodBuilder[C] => Unit): EmitMethodBuilder[C] = {
def getOrGenEmitMethod(baseName: String,
key: Any,
argsInfo: IndexedSeq[ParamType],
returnInfo: ParamType
)(body: EmitMethodBuilder[C] => Unit)
: EmitMethodBuilder[C] =
methodMemo.getOrElse(key, {
val mb = genEmitMethod(baseName, argsInfo, returnInfo)
methodMemo(key) = mb
body(mb)
mb
})
}

def getEmitMethod(key: Any): EmitMethodBuilder[C] =
methodMemo.getOrElse(key, throw new NoSuchMethodError(s"No such method '$key' in '$className'"))

def genEmitMethod(baseName: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType): EmitMethodBuilder[C] =
newEmitMethod(genName("m", baseName), argsInfo, returnInfo)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,17 @@ class StagedArrayBuilder(eltType: PType, kb: EmitClassBuilder[_], region: Value[
def elementOffset(cb: EmitCodeBuilder, idx: Value[Int]): Value[Long] =
cb.memoize(eltArray.elementOffset(data, capacity, idx))


def loadElement(cb: EmitCodeBuilder, idx: Value[Int]): EmitCode = {
val m = eltArray.isElementMissing(data, idx)
EmitCode(Code._empty, m, eltType.loadCheapSCode(cb, eltArray.loadElement(data, capacity, idx)))
}

def swap(cb: EmitCodeBuilder, p: Value[Int], q: Value[Int]): Unit = {
val tmp = loadElement(cb, p).memoize(cb, "tmp")
overwrite(cb, loadElement(cb, q).memoize(cb, ""), p)
overwrite(cb, tmp, q)
}

private def resize(cb: EmitCodeBuilder): Unit = {
val newDataOffset = kb.genFieldThisRef[Long]("new_data_offset")
cb.ifx(size.ceq(capacity),
Expand Down
168 changes: 113 additions & 55 deletions hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import is.hail.methods.{BitPackedVector, BitPackedVectorBuilder, LocalLDPrune, L
import is.hail.types.physical.stypes.concrete.{SBinaryPointer, SStackStruct, SUnreachable}
import is.hail.types.physical.stypes.interfaces._
import is.hail.types.physical.stypes.primitives.{SFloat64Value, SInt32Value}
import is.hail.types.physical.stypes.{EmitType, SSettable}
import is.hail.types.physical.stypes.{EmitType, SSettable, SValue}
import is.hail.types.physical.{PCanonicalArray, PCanonicalBinary, PCanonicalStruct, PType}
import is.hail.types.virtual._
import is.hail.types.{RIterable, TypeWithRequiredness, VirtualTypeWithReq}
Expand Down Expand Up @@ -1451,79 +1451,137 @@ object EmitStream {

case x@StreamLeftIntervalJoin(left, right, lKeyNames, rIntrvlName, lEltName, rEltName, body) =>

// Min-heap used requires right elements in the form TTuple(TInterval, TStruct)
val rTupled = {
val rEltRef =
Ref(genUID(), TIterable.elementType(right.typ))

StreamMap(right, rEltRef.name, MakeTuple(FastSeq(
0 -> GetField(rEltRef, rIntrvlName),
1 -> rEltRef
)))
}

produce(left, cb).flatMap(cb) { case lStream: SStreamValue =>
produce(rTupled, cb).map(cb) { case rStream: SStreamValue =>
produce(right, cb).map(cb) { case rStream: SStreamValue =>

// map over the keyStream
val lProd = lStream.getProducer(mb)
val rProd = rStream.getProducer(mb)
val minHeap = new StagedIntervalMinHeap(mb, rProd.element.st.asInstanceOf[SBaseStruct])

val leftStructField = mb.newPField(lProd.element.st)
val rElemSTy = rProd.element.st.asInstanceOf[SBaseStruct]

val intervalResultField = mb.newPField(minHeap.resultArraySType)
val intrvlParamTy: ParamType =
SCodeParamType(rElemSTy.fieldTypes(rElemSTy.fieldIdx(rIntrvlName)))

val eltRegion = mb.genFieldThisRef[Region]("interval_join_region")
val joinResult = EmitCode.fromI(mb) { cb =>
emit(
body,
cb,
region = eltRegion,
env = env.bind(
lEltName -> EmitValue.present(leftStructField),
rEltName -> EmitValue.present(intervalResultField)
val k: EmitClassBuilder[Unit] =
mb.genEmitClass[Unit]("RightEndpointComparator")

val loadInterval: EmitMethodBuilder[Unit] =
k.getOrGenEmitMethod("loadInterval", "loadInterval", Array.empty, intrvlParamTy) { mb =>
mb.emitSCode { cb =>
mb.getEmitParam(cb, 0)
.get(cb)
.asBaseStruct
.loadField(cb, rIntrvlName)
.get(cb)
}
}

val minHeap =
StagedPriorityQueue(mb.ecb.emodb, rElemSTy,
new EmitFunctionBuilder[Unit](
k.getOrGenEmitMethod("compare", "compare", FastSeq(intrvlParamTy, intrvlParamTy), IntInfo) { mb =>
mb.emitWithBuilder[Int] { cb =>
val l = cb.invokeSCode(loadInterval, mb.getEmitParam(cb, 0)).asInterval
val r = cb.invokeSCode(loadInterval, mb.getEmitParam(cb, 1)).asInterval
IntervalFunctions.intervalEndpointCompare(cb,
l.loadEnd(cb).get(cb), l.includesEnd,
r.loadEnd(cb).get(cb), r.includesEnd
)
}
}
)
)
}

new StreamProducer {
override def method: EmitMethodBuilder[_] = mb
val leftStructField = mb.newPField(lProd.element.st)
val intervalResultField = mb.newPField(PCanonicalArray(rProd.element.st.storageType(), required = true).sType)
val eltRegion = mb.genFieldThisRef[Region]("interval_join_region")

override val length: Option[EmitCodeBuilder => Code[Int]] = lProd.length
SStreamValue {
new StreamProducer {
override def method: EmitMethodBuilder[_] =
mb

override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = {
lProd.initialize(cb, outerRegion)
rProd.initialize(cb, outerRegion)
}
override val length: Option[EmitCodeBuilder => Code[Int]] =
lProd.length

override val elementRegion: Settable[Region] = eltRegion
override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = {
lProd.initialize(cb, outerRegion)
rProd.initialize(cb, outerRegion)
minHeap.initialize(cb)
}

override val requiresMemoryManagementPerElement: Boolean = lProd.requiresMemoryManagementPerElement || rProd.requiresMemoryManagementPerElement
override val elementRegion: Settable[Region] =
eltRegion

override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb =>
cb.goto(lProd.LproduceElement)
cb.define(lProd.LproduceElementDone)
val row = lProd.element.toI(cb).get(cb).asBaseStruct
val key = row.subset(lKeyNames: _*)
minHeap.dropLessThan(cb, key)
// now pull from the interval stream and insert into minheap while it is not exhausted and until
// the most-recent interval left endpoint is greater than the current key

cb.assign(leftStructField, row)
cb.assign(intervalResultField, minHeap.getAllContainedIntervalsAsArray(cb, eltRegion))
cb.goto(LproduceElementDone)
override val requiresMemoryManagementPerElement: Boolean =
lProd.requiresMemoryManagementPerElement || rProd.requiresMemoryManagementPerElement

cb.define(lProd.LendOfStream)
cb.goto(LendOfStream)
}
override val LproduceElement: CodeLabel =
mb.defineAndImplementLabel { cb =>
cb.goto(lProd.LproduceElement)
cb.define(lProd.LproduceElementDone)
val row = lProd.element.toI(cb).get(cb).asBaseStruct
val key = row.subset(lKeyNames: _*)

override val element: EmitCode = joinResult
cb.whileLoop(
minHeap.nonEmpty(cb) && {
val interval = cb.invokeSCode(loadInterval, minHeap.peek(cb)).asInterval
IntervalFunctions.intervalContains(cb, interval, key).get(cb).asBoolean.value
},
minHeap.poll(cb)
)

override def close(cb: EmitCodeBuilder): Unit = {
minHeap.close(cb)
rProd.close(cb)
lProd.close(cb)
minHeap.realloc(cb)

// now pull from the interval stream and insert into minheap while
// the most-recent interval left endpoint is greater than the current key

cb.goto(rProd.LproduceElement)
cb.define(rProd.LproduceElementDone)

val rElem: SValue = rProd.element.toI(cb).get(cb)
val rInterval = cb.invokeSCode(loadInterval, rElem).asInterval

cb.ifx(
IntervalFunctions.pointGTIntervalEndpoint(cb, key, rInterval.loadEnd(cb).get(cb), rInterval.includesEnd),
cb.goto(rProd.LproduceElement)
)

val LallIntervalsFound = CodeLabel()
// need lookahead
cb.ifx(
IntervalFunctions.pointLTIntervalEndpoint(cb, key, rInterval.loadStart(cb).get(cb), rInterval.includesStart),
cb.goto(LallIntervalsFound)
)

// we've found the first interval that contains key
// add interval to minheap
minHeap.add(cb, rElem)
cb.goto(rProd.LproduceElement)

cb.define(LallIntervalsFound)
cb.assign(leftStructField, row)
cb.assign(intervalResultField, minHeap.toArray(cb, eltRegion))
cb.goto(LproduceElementDone)

cb.define(lProd.LendOfStream)
cb.goto(LendOfStream)
}

override val element: EmitCode =
EmitCode.fromI(mb) { cb =>
emit(body, cb, region = eltRegion, env = env.bind(
lEltName -> EmitValue.present(leftStructField),
rEltName -> EmitValue.present(intervalResultField)
))
}

override def close(cb: EmitCodeBuilder): Unit = {
minHeap.close(cb)
rProd.close(cb)
lProd.close(cb)
}
}
}
}
Expand Down
Loading

0 comments on commit 2d18dd9

Please sign in to comment.