Skip to content

Commit 3adc2d1

Browse files
committed
Rename TreeApi to Tree
1 parent b264a96 commit 3adc2d1

File tree

2 files changed

+71
-62
lines changed

2 files changed

+71
-62
lines changed

Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ trait Tensors extends OpenCL {
3232
Factory[FloatArrayTrees with StructuralTrees].newInstance()
3333
import trees._
3434

35-
private def upvalues(tree: TreeApi): List[Parameter] = {
36-
val traversed: java.util.Set[TreeApi] = Collections.newSetFromMap(new IdentityHashMap)
35+
private def upvalues(tree: Tree): List[Parameter] = {
36+
val traversed: java.util.Set[Tree] = Collections.newSetFromMap(new IdentityHashMap)
3737
val builder = List.newBuilder[Parameter]
38-
def buildParameterList(tree: TreeApi): Unit = {
38+
def buildParameterList(tree: Tree): Unit = {
3939
tree match {
4040
case tree: Parameter =>
4141
builder += tree
@@ -44,7 +44,7 @@ trait Tensors extends OpenCL {
4444
@tailrec def loop(i: Int): Unit = {
4545
if (i < productArity) {
4646
tree.productElement(i) match {
47-
case child: TreeApi @unchecked =>
47+
case child: Tree @unchecked =>
4848
val isNew = traversed.add(tree)
4949
if (isNew) {
5050
buildParameterList(child)
@@ -398,7 +398,11 @@ trait Tensors extends OpenCL {
398398
compiledKernel
399399
}
400400
}
401-
kernelCache.get(float.factory.newInstance(convertedTree.asInstanceOf[float.Tree]), loader)
401+
kernelCache.get(float.factory.newInstance(
402+
convertedTree.asInstanceOf[
403+
Tree { type TermIn[C <: trees.Category] = C#FloatTerm }
404+
]),
405+
loader)
402406
case compiledKernel =>
403407
compiledKernel
404408
}

0 commit comments

Comments
 (0)