Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 39 additions & 40 deletions Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -140,49 +140,46 @@ trait Tensors extends OpenCL {
}

lazy val enqueue: Do[PendingBuffer] = {
val compiledKernel = kernelCache.get(
closure,
new Callable[CompiledKernel] {
def call(): CompiledKernel = {

val alphConversionContext = new AlphaConversionContext
val convertedTree = closure.tree.alphaConversion(alphConversionContext)

val sourceCode = {
val globalContext = new GlobalContext
val functionContext = Factory[OpenCLKernelBuilder].newInstance(globalContext)

val exportContext = new ExportContext
val kernelBody = convertedTree.export(functionContext, exportContext).asInstanceOf[functionContext.Term]

val kernelParameters = upvalues(closure.tree).map { upvalue: Parameter =>
exportContext.get(alphConversionContext.get(upvalue)).asInstanceOf[functionContext.Term]
}
fastraw"""
val compiledKernel = kernelCache.getIfPresent(closure) match {
case null =>
val alphConversionContext = new AlphaConversionContext
val convertedTree = closure.tree.alphaConversion(alphConversionContext)
val loader = new Callable[CompiledKernel] {
def call(): CompiledKernel = {
val sourceCode = {
val globalContext = new GlobalContext
val functionContext = Factory[OpenCLKernelBuilder].newInstance(globalContext)

val exportContext = new ExportContext
val kernelBody = convertedTree.export(functionContext, exportContext).asInstanceOf[functionContext.Term]

val kernelParameters = upvalues(closure.tree).map { upvalue: Parameter =>
exportContext.get(alphConversionContext.get(upvalue)).asInstanceOf[functionContext.Term]
}
fastraw"""
$globalContext
${functionContext.generateKernelSourceCode("jit_kernel", shape.length, kernelParameters, Seq(kernelBody))}
"""
}
}

val program = createProgramWithSource(sourceCode)
program.build()
val program = createProgramWithSource(sourceCode)
program.build()

val compiledKernel = new CompiledKernel {
val compiledKernel = new CompiledKernel {

def monadicClose: UnitContinuation[Unit] = program.monadicClose
def monadicClose: UnitContinuation[Unit] = program.monadicClose

def run(upvalues: List[Parameter]): Do[PendingBuffer] = {
// TODO: Manage life cycle of upvalues more delicately
// e.g. a buffer should be release as soon as possible if it is a dependency of another buffer,
// e.g. however, it can be hold longer time if it is dependencies of many other buffers.
def run(upvalues: List[Parameter]): Do[PendingBuffer] = {
// TODO: Manage life cycle of upvalues more delicately
// e.g. a buffer should be release as soon as possible if it is a dependency of another buffer,
// e.g. however, it can be hold longer time if it is dependencies of many other buffers.

upvalues
.traverse[ParallelDo, PendingBuffer] { tree =>
Parallel(tree.asInstanceOf[Parameter].id.asInstanceOf[Tensor].enqueue)
}
.unwrap
.intransitiveFlatMap {
arguments: List[PendingBuffer] =>
upvalues
.traverse[ParallelDo, PendingBuffer] { tree =>
Parallel(tree.id.asInstanceOf[Tensor].enqueue)
}
.unwrap
.intransitiveFlatMap { arguments: List[PendingBuffer] =>
Do.monadicCloseable(program.createFirstKernel()).intransitiveFlatMap { kernel: Kernel =>
allocateBuffer[Float](shape.product).flatMap { outputBuffer =>
for ((arugment, i) <- arguments.view.zipWithIndex) {
Expand All @@ -197,15 +194,17 @@ trait Tensors extends OpenCL {
}
}
}
}
}

}
}
compiledKernel
}
kernelCache.put(float.factory.newInstance(convertedTree.asInstanceOf[float.Tree]), compiledKernel)
compiledKernel
}
}
)
kernelCache.get(float.factory.newInstance(convertedTree.asInstanceOf[float.Tree]), loader)
case compiledKernel =>
compiledKernel
}

compiledKernel.run(upvalues(closure.tree)).shared
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
floatBuffer.position() should be(0)
floatBuffer.limit() should be(shape.product)
floatBuffer.capacity() should be(shape.product)
tensors.kernelCache.getIfPresent(zeros.closure) should not be null
val zeros2 = tensors.Tensor.fill(element, shape)
tensors.kernelCache.getIfPresent(zeros2.closure) should not be null
}
}.run.toScalaFuture

Expand Down
93 changes: 67 additions & 26 deletions Trees/src/main/scala/com/thoughtworks/compute/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,36 +72,60 @@ trait Trees extends Expressions {
}
}

private def isSameChild(left: Any, right: Any, map: StructuralComparisonContext): Boolean = {
left match {
case left: TreeApi =>
right match {
case right: TreeApi =>
left.isSameStructure(right, map)
case _ =>
false
}
case left: Array[_] =>
right match {
case right: Array[_] =>
val leftLength = left.length
val rightLength = right.length
@tailrec def arrayLoop(start: Int): Boolean = {
if (start < leftLength) {
if (isSameChild(left(start), right(start), map)) {
arrayLoop(start + 1)
} else {
false
}
} else {
true
}
}
leftLength == rightLength && arrayLoop(0)
case _ =>
false
}
case _ =>
left == right
}
}
def isSameStructure(that: TreeApi, map: StructuralComparisonContext): Boolean = {

map.get(this) match {
case null =>
this.getClass == that.getClass && {
assert(this.productArity == that.productArity)
map.put(this, that)
val productArity: Int = this.productArity
@tailrec
def sameFields(start: Int = 0): Boolean = {
if (start < productArity) {
productElement(start) match {
case left: TreeApi =>
that.productElement(start) match {
case right: TreeApi =>
if (left.isSameStructure(right, map)) {
sameFields(start = start + 1)
} else {
false
}
case _ =>
false
}
case _ =>
false
def fieldLoop(from: Int): Boolean = {
if (from < productArity) {
if (isSameChild(this.productElement(from), that.productElement(from), map)) {
fieldLoop(from + 1)
} else {
false
}
} else {
true
}
}
sameFields()
fieldLoop(0)
}
case existing =>
existing eq that
Expand All @@ -113,15 +137,6 @@ trait Trees extends Expressions {

val id: Any

def isSameStructure(that: TreeApi, map: StructuralComparisonContext): Boolean = {
map.get(this) match {
case null =>
map.put(this, that)
true
case existing =>
existing eq that
}
}
}

final class HashCodeContext extends IdentityHashMap[TreeApi, Int] {
Expand Down Expand Up @@ -246,6 +261,15 @@ object Trees {

final case class FloatParameter(id: Any) extends TreeApi with Parameter { thisParameter =>
type TermIn[C <: Category] = C#FloatTerm
def isSameStructure(that: TreeApi, map: StructuralComparisonContext): Boolean = {
map.get(this) match {
case null =>
map.put(this, that)
true
case existing =>
existing eq that
}
}

def structuralHashCode(context: HashCodeContext): Int = {
context.asScala.getOrElseUpdate(this, {
Expand Down Expand Up @@ -482,6 +506,23 @@ object Trees {
type Element = elementType.TermIn[C]
}

def isSameStructure(that: TreeApi, map: StructuralComparisonContext): Boolean = {
map.get(this) match {
case null =>
that match {
case ArrayParameter(thatId, thatElemenetType, thatPadding, thatShape)
if elementType == thatElemenetType && padding == padding && java.util.Arrays.equals(shape,
thatShape) =>
map.put(this, that)
true
case _ =>
false
}
case existing =>
existing eq that
}
}

def structuralHashCode(context: HashCodeContext): Int = {
context.asScala.getOrElseUpdate(
this, {
Expand Down
47 changes: 35 additions & 12 deletions Trees/src/test/scala/com/thoughtworks/compute/TreesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,41 @@ final class TreesSpec extends FreeSpec with Matchers {
"hashCode" in {

val trees: FloatArrayTrees = Factory[Trees.FloatArrayTrees with Trees.StructuralTrees].newInstance()

trees.float.literal(42.0f).## should be(trees.float.literal(42.0f).##)
trees.float.literal(42.0f).## shouldNot be(trees.float.literal(41.0f).##)
trees.float.parameter("my_id").## should be(trees.float.parameter("my_id").##)
trees.float.parameter("my_id_1").## should be(trees.float.parameter("my_id_2").##)

trees.array.parameter("my_id", trees.float, 42.0f, Array(12, 34)).## should be(
trees.array.parameter("my_id2", trees.float, 42.0f, Array(12, 34)).##)

trees.array.parameter("my_id", trees.float, 0.1f, Array(12, 34)).## shouldNot be(
trees.array.parameter("my_id2", trees.float, 99.9f, Array(56, 78)).##)

def reflexive(term: => trees.Term) = {
val t0 = term
val t1 = term
t0 should be(t0)
t0.## should be(t0.##)
t1 should be(t1)
t1.## should be(t1.##)
t0 should be(t1)
t0.## should be(t1.##)

sameStructuralDifferentParameterName(t0, t0.alphaConversion)
}

def sameStructuralDifferentParameterName(term1: trees.Term, term2: trees.Term) = {
term1 should be(term2)
term1.## should be(term2.##)
}

def differentStructural(term1: trees.Term, term2: trees.Term) = {
term1 shouldNot be(term2)
term1.## shouldNot be(term2.##)
}

reflexive(trees.float.parameter("my_id"))
reflexive(trees.float.literal(42.0f))
reflexive(trees.array.parameter("my_id", trees.float, 42.0f, Array(12, 34)))

sameStructuralDifferentParameterName(trees.float.parameter("my_id_1"), trees.float.parameter("my_id_2"))
sameStructuralDifferentParameterName(trees.array.parameter("my_id_3", trees.float, 42.0f, Array(12, 34)),
trees.array.parameter("my_id_4", trees.float, 42.0f, Array(12, 34)))

differentStructural(
trees.array.parameter("my_id", trees.float, 0.1f, Array(12, 34)),
trees.array.parameter("my_id2", trees.float, 99.9f, Array(56, 78))
)
}

}