From 0011fe263c9e2182500eae5ad91776ee4d494ab4 Mon Sep 17 00:00:00 2001 From: Yang Bo Date: Fri, 9 Feb 2018 13:03:52 +0800 Subject: [PATCH] Implement Tensor.toString --- .../com/thoughtworks/compute/Tensors.scala | 36 +++++++++++++++++++ .../thoughtworks/compute/TensorsSpec.scala | 14 ++++++++ 2 files changed, 50 insertions(+) diff --git a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala index 99bc4d98..d2a5a433 100644 --- a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala +++ b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala @@ -84,6 +84,42 @@ trait Tensors extends OpenCL { } sealed trait Tensor { thisTensor => + + override def toString: String = { + enqueue + .intransitiveFlatMap { pendingBuffer => + pendingBuffer.toHostBuffer.intransitiveMap { floatBuffer => + val floatArray = Array.ofDim[Float](floatBuffer.capacity()) + floatBuffer.asReadOnlyBuffer().get(floatArray) + floatArray + } + } + .run + .map { floatArray => + def toFastring(shape: Seq[Int], floatArray: Seq[Float]): Fastring = { + shape match { + case headSize +: tailShape => + val length = floatArray.length + if (tailShape.isEmpty) { + if (headSize == length) { + fast"[${floatArray.mkFastring(",")}]" + } else { + throw new IllegalArgumentException + } + } else { + val groupSize = length / headSize + def groups = for (i <- (0 until headSize).view) yield { + toFastring(tailShape, floatArray.view(i * groupSize, (i + 1) * groupSize)) + } + fast"[${groups.mkFastring(",")}]" + } + } + } + + toFastring(shape.view, floatArray).toString + }.blockingAwait + } + def broadcast(newShape: Array[Int]): Tensor = { val newLength = newShape.length val length = shape.length diff --git a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala index 365ebda1..64877785 100644 --- a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala +++ b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala @@ -54,6 +54,20 @@ class TensorsSpec extends AsyncFreeSpec with Matchers { val element = 42.0f val padding = 99.0f val translated = tensors.Tensor.fill(element, shape, padding = padding).translate(Array(1, 2, -3)) + translated.toString should be( + "[" + + "[" + + "[99.0,99.0,99.0,99.0,99.0]," + + "[99.0,99.0,99.0,99.0,99.0]," + + "[99.0,99.0,99.0,99.0,99.0]" + + "]," + + "[" + + "[99.0,99.0,99.0,99.0,99.0]," + + "[99.0,99.0,99.0,99.0,99.0]," + + "[42.0,42.0,99.0,99.0,99.0]" + + "]" + + "]") + for { pendingBuffer <- translated.enqueue floatBuffer <- pendingBuffer.toHostBuffer