File tree Expand file tree Collapse file tree 2 files changed +69
-60
lines changed
Tensors/src/main/scala/com/thoughtworks/compute
Trees/src/main/scala/com/thoughtworks/compute Expand file tree Collapse file tree 2 files changed +69
-60
lines changed Original file line number Diff line number Diff line change @@ -32,10 +32,10 @@ trait Tensors extends OpenCL {
32
32
Factory [FloatArrayTrees with StructuralTrees ].newInstance()
33
33
import trees ._
34
34
35
- private def upvalues (tree : TreeApi ): List [Parameter ] = {
36
- val traversed : java.util.Set [TreeApi ] = Collections .newSetFromMap(new IdentityHashMap )
35
+ private def upvalues (tree : Tree ): List [Parameter ] = {
36
+ val traversed : java.util.Set [Tree ] = Collections .newSetFromMap(new IdentityHashMap )
37
37
val builder = List .newBuilder[Parameter ]
38
- def buildParameterList (tree : TreeApi ): Unit = {
38
+ def buildParameterList (tree : Tree ): Unit = {
39
39
tree match {
40
40
case tree : Parameter =>
41
41
builder += tree
@@ -44,7 +44,7 @@ trait Tensors extends OpenCL {
44
44
@ tailrec def loop (i : Int ): Unit = {
45
45
if (i < productArity) {
46
46
tree.productElement(i) match {
47
- case child : TreeApi @ unchecked =>
47
+ case child : Tree @ unchecked =>
48
48
val isNew = traversed.add(tree)
49
49
if (isNew) {
50
50
buildParameterList(child)
@@ -398,7 +398,11 @@ trait Tensors extends OpenCL {
398
398
compiledKernel
399
399
}
400
400
}
401
- kernelCache.get(float.factory.newInstance(convertedTree.asInstanceOf [FloatTerm # Tree ]), loader)
401
+ kernelCache.get(float.factory.newInstance(
402
+ convertedTree.asInstanceOf [
403
+ Tree { type TermIn [C <: trees.Category ] = C # FloatTerm }
404
+ ]),
405
+ loader)
402
406
case compiledKernel =>
403
407
compiledKernel
404
408
}
You can’t perform that action at this time.
0 commit comments