diff --git a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala index 5368d490..a92eef92 100644 --- a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala +++ b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala @@ -140,49 +140,46 @@ trait Tensors extends OpenCL { } lazy val enqueue: Do[PendingBuffer] = { - val compiledKernel = kernelCache.get( - closure, - new Callable[CompiledKernel] { - def call(): CompiledKernel = { - - val alphConversionContext = new AlphaConversionContext - val convertedTree = closure.tree.alphaConversion(alphConversionContext) - - val sourceCode = { - val globalContext = new GlobalContext - val functionContext = Factory[OpenCLKernelBuilder].newInstance(globalContext) - - val exportContext = new ExportContext - val kernelBody = convertedTree.export(functionContext, exportContext).asInstanceOf[functionContext.Term] - - val kernelParameters = upvalues(closure.tree).map { upvalue: Parameter => - exportContext.get(alphConversionContext.get(upvalue)).asInstanceOf[functionContext.Term] - } - fastraw""" + val compiledKernel = kernelCache.getIfPresent(closure) match { + case null => + val alphConversionContext = new AlphaConversionContext + val convertedTree = closure.tree.alphaConversion(alphConversionContext) + val loader = new Callable[CompiledKernel] { + def call(): CompiledKernel = { + val sourceCode = { + val globalContext = new GlobalContext + val functionContext = Factory[OpenCLKernelBuilder].newInstance(globalContext) + + val exportContext = new ExportContext + val kernelBody = convertedTree.export(functionContext, exportContext).asInstanceOf[functionContext.Term] + + val kernelParameters = upvalues(closure.tree).map { upvalue: Parameter => + exportContext.get(alphConversionContext.get(upvalue)).asInstanceOf[functionContext.Term] + } + fastraw""" $globalContext ${functionContext.generateKernelSourceCode("jit_kernel", shape.length, kernelParameters, 
Seq(kernelBody))} """ - } + } - val program = createProgramWithSource(sourceCode) - program.build() + val program = createProgramWithSource(sourceCode) + program.build() - val compiledKernel = new CompiledKernel { + val compiledKernel = new CompiledKernel { - def monadicClose: UnitContinuation[Unit] = program.monadicClose + def monadicClose: UnitContinuation[Unit] = program.monadicClose - def run(upvalues: List[Parameter]): Do[PendingBuffer] = { - // TODO: Manage life cycle of upvalues more delicately - // e.g. a buffer should be release as soon as possible if it is a dependency of another buffer, - // e.g. however, it can be hold longer time if it is dependencies of many other buffers. + def run(upvalues: List[Parameter]): Do[PendingBuffer] = { + // TODO: Manage life cycle of upvalues more delicately + // e.g. a buffer should be release as soon as possible if it is a dependency of another buffer, + // e.g. however, it can be hold longer time if it is dependencies of many other buffers. 
- upvalues - .traverse[ParallelDo, PendingBuffer] { tree => - Parallel(tree.asInstanceOf[Parameter].id.asInstanceOf[Tensor].enqueue) - } - .unwrap - .intransitiveFlatMap { - arguments: List[PendingBuffer] => + upvalues + .traverse[ParallelDo, PendingBuffer] { tree => + Parallel(tree.id.asInstanceOf[Tensor].enqueue) + } + .unwrap + .intransitiveFlatMap { arguments: List[PendingBuffer] => Do.monadicCloseable(program.createFirstKernel()).intransitiveFlatMap { kernel: Kernel => allocateBuffer[Float](shape.product).flatMap { outputBuffer => for ((arugment, i) <- arguments.view.zipWithIndex) { @@ -197,15 +194,17 @@ trait Tensors extends OpenCL { } } } - } + } + } } + compiledKernel } - kernelCache.put(float.factory.newInstance(convertedTree.asInstanceOf[float.Tree]), compiledKernel) - compiledKernel } - } - ) + kernelCache.get(float.factory.newInstance(convertedTree.asInstanceOf[float.Tree]), loader) + case compiledKernel => + compiledKernel + } compiledKernel.run(upvalues(closure.tree)).shared } diff --git a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala index a97d0677..2a1a72b3 100644 --- a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala +++ b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala @@ -41,6 +41,9 @@ class TensorsSpec extends AsyncFreeSpec with Matchers { floatBuffer.position() should be(0) floatBuffer.limit() should be(shape.product) floatBuffer.capacity() should be(shape.product) + tensors.kernelCache.getIfPresent(zeros.closure) should not be null + val zeros2 = tensors.Tensor.fill(element, shape) + tensors.kernelCache.getIfPresent(zeros2.closure) should not be null } }.run.toScalaFuture diff --git a/Trees/src/main/scala/com/thoughtworks/compute/Trees.scala b/Trees/src/main/scala/com/thoughtworks/compute/Trees.scala index 0372bc56..f8025daa 100644 --- a/Trees/src/main/scala/com/thoughtworks/compute/Trees.scala +++ 
b/Trees/src/main/scala/com/thoughtworks/compute/Trees.scala @@ -72,7 +72,41 @@ trait Trees extends Expressions { } } + private def isSameChild(left: Any, right: Any, map: StructuralComparisonContext): Boolean = { + left match { + case left: TreeApi => + right match { + case right: TreeApi => + left.isSameStructure(right, map) + case _ => + false + } + case left: Array[_] => + right match { + case right: Array[_] => + val leftLength = left.length + val rightLength = right.length + @tailrec def arrayLoop(start: Int): Boolean = { + if (start < leftLength) { + if (isSameChild(left(start), right(start), map)) { + arrayLoop(start + 1) + } else { + false + } + } else { + true + } + } + leftLength == rightLength && arrayLoop(0) + case _ => + false + } + case _ => + left == right + } + } def isSameStructure(that: TreeApi, map: StructuralComparisonContext): Boolean = { + map.get(this) match { case null => this.getClass == that.getClass && { @@ -80,28 +114,18 @@ trait Trees extends Expressions { map.put(this, that) val productArity: Int = this.productArity @tailrec - def sameFields(start: Int = 0): Boolean = { - if (start < productArity) { - productElement(start) match { - case left: TreeApi => - that.productElement(start) match { - case right: TreeApi => - if (left.isSameStructure(right, map)) { - sameFields(start = start + 1) - } else { - false - } - case _ => - false - } - case _ => - false + def fieldLoop(from: Int): Boolean = { + if (from < productArity) { + if (isSameChild(this.productElement(from), that.productElement(from), map)) { + fieldLoop(from + 1) + } else { + false } } else { true } } - sameFields() + fieldLoop(0) } case existing => existing eq that @@ -113,15 +137,6 @@ trait Trees extends Expressions { val id: Any - def isSameStructure(that: TreeApi, map: StructuralComparisonContext): Boolean = { - map.get(this) match { - case null => - map.put(this, that) - true - case existing => - existing eq that - } - } } final class HashCodeContext extends 
IdentityHashMap[TreeApi, Int] { @@ -246,6 +261,15 @@ object Trees { final case class FloatParameter(id: Any) extends TreeApi with Parameter { thisParameter => type TermIn[C <: Category] = C#FloatTerm + def isSameStructure(that: TreeApi, map: StructuralComparisonContext): Boolean = { + map.get(this) match { + case null => + map.put(this, that) + true + case existing => + existing eq that + } + } def structuralHashCode(context: HashCodeContext): Int = { context.asScala.getOrElseUpdate(this, { @@ -482,6 +506,23 @@ object Trees { type Element = elementType.TermIn[C] } + def isSameStructure(that: TreeApi, map: StructuralComparisonContext): Boolean = { + map.get(this) match { + case null => + that match { + case ArrayParameter(thatId, thatElementType, thatPadding, thatShape) + if elementType == thatElementType && padding == thatPadding && java.util.Arrays.equals(shape, + thatShape) => + map.put(this, that) + true + case _ => + false + } + case existing => + existing eq that + } + } + def structuralHashCode(context: HashCodeContext): Int = { context.asScala.getOrElseUpdate( this, { diff --git a/Trees/src/test/scala/com/thoughtworks/compute/TreesSpec.scala b/Trees/src/test/scala/com/thoughtworks/compute/TreesSpec.scala index b94f775b..6fb7cc55 100644 --- a/Trees/src/test/scala/com/thoughtworks/compute/TreesSpec.scala +++ b/Trees/src/test/scala/com/thoughtworks/compute/TreesSpec.scala @@ -12,18 +12,41 @@ final class TreesSpec extends FreeSpec with Matchers { "hashCode" in { val trees: FloatArrayTrees = Factory[Trees.FloatArrayTrees with Trees.StructuralTrees].newInstance() - - trees.float.literal(42.0f).## should be(trees.float.literal(42.0f).##) - trees.float.literal(42.0f).## shouldNot be(trees.float.literal(41.0f).##) - trees.float.parameter("my_id").## should be(trees.float.parameter("my_id").##) - trees.float.parameter("my_id_1").## should be(trees.float.parameter("my_id_2").##) - - trees.array.parameter("my_id", trees.float, 42.0f, Array(12, 34)).## should be( -
trees.array.parameter("my_id2", trees.float, 42.0f, Array(12, 34)).##) - - trees.array.parameter("my_id", trees.float, 0.1f, Array(12, 34)).## shouldNot be( - trees.array.parameter("my_id2", trees.float, 99.9f, Array(56, 78)).##) - + def reflexive(term: => trees.Term) = { + val t0 = term + val t1 = term + t0 should be(t0) + t0.## should be(t0.##) + t1 should be(t1) + t1.## should be(t1.##) + t0 should be(t1) + t0.## should be(t1.##) + + sameStructuralDifferentParameterName(t0, t0.alphaConversion) + } + + def sameStructuralDifferentParameterName(term1: trees.Term, term2: trees.Term) = { + term1 should be(term2) + term1.## should be(term2.##) + } + + def differentStructural(term1: trees.Term, term2: trees.Term) = { + term1 shouldNot be(term2) + term1.## shouldNot be(term2.##) + } + + reflexive(trees.float.parameter("my_id")) + reflexive(trees.float.literal(42.0f)) + reflexive(trees.array.parameter("my_id", trees.float, 42.0f, Array(12, 34))) + + sameStructuralDifferentParameterName(trees.float.parameter("my_id_1"), trees.float.parameter("my_id_2")) + sameStructuralDifferentParameterName(trees.array.parameter("my_id_3", trees.float, 42.0f, Array(12, 34)), + trees.array.parameter("my_id_4", trees.float, 42.0f, Array(12, 34))) + + differentStructural( + trees.array.parameter("my_id", trees.float, 0.1f, Array(12, 34)), + trees.array.parameter("my_id2", trees.float, 99.9f, Array(56, 78)) + ) } }