Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NDimensionalAffineTransform/build.sbt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.4" % Test
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package com.thoughtworks.compute

import scala.annotation.tailrec

/**
* @author 杨博 (Yang Bo)
*/
object NDimensionalAffineTransform {

def preConcatenate(matrix01: Array[Double], matrix12: Array[Double], length0: Int): Array[Double] = {
val length1 = matrix01.length / (length0 + 1)
val length2 = matrix12.length / (length1 + 1)
val matrix02 = Array.ofDim[Double]((length0 + 1) * length2)
concatenate(matrix01, matrix12, matrix02, length0, length1, length2)
matrix02
}

def concatenate(matrix01: Array[Double],
matrix12: Array[Double],
matrix02: Array[Double],
length0: Int,
length1: Int,
length2: Int): Unit = {

@tailrec
def loop2(index2: Int): Unit = {
if (index2 < length2) {
@tailrec
def loop0(index0: Int): Unit = {
if (index0 < length0) {
@tailrec
def loop1(index1: Int, accumulator: Double): Double = {
if (index1 < length1) {
loop1(index1 + 1,
accumulator +
matrix12(index2 * (length1 + 1) + index1) *
matrix01(index1 * (length0 + 1) + index0))
} else {
accumulator
}
}
matrix02(index2 * (length0 + 1) + index0) = loop1(0, 0.0)

loop0(index0 + 1)
}
}
loop0(0)
@tailrec
def loop1(index1: Int, accumulator: Double): Double = {
if (index1 < length1) {
loop1(index1 + 1,
accumulator +
matrix12(index2 * (length1 + 1) + index1) *
matrix01(index1 * (length0 + 1) + length0))
} else {
accumulator
}
}
matrix02(index2 * (length0 + 1) + length0) = loop1(0, matrix12(index2 * (length1 + 1) + length1))
loop2(index2 + 1)
}
}
loop2(0)

}

def transform(matrix: Array[Double], source: Array[Double]): Array[Double] = {
val sourceLength = source.length
val destination = Array.ofDim[Double](matrix.length / (sourceLength + 1))
transform(matrix, source, destination)
destination
}

private def transform(matrix: Array[Double], source: Array[Double], destination: Array[Double]): Unit = {
val sourceLength = source.length
val destinationLength = destination.length
if (matrix.length != (sourceLength + 1) * destinationLength) {
throw new IllegalArgumentException
}
@tailrec
def rowLoop(y: Int): Unit = {
if (y < destinationLength) {
@tailrec
def columnLoop(x: Int, accumulator: Double): Double = {
if (x < sourceLength) {
columnLoop(x + 1, accumulator + matrix(y * (sourceLength + 1) + x) * source(x))
} else {
accumulator
}
}
destination(y) = columnLoop(0, matrix(y * (sourceLength + 1) + sourceLength))

rowLoop(y + 1)
}
}
rowLoop(0)
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.thoughtworks.compute

import java.awt.geom.AffineTransform

import org.scalatest._

/**
* @author 杨博 (Yang Bo)
*/
final class NDimensionalAffineTransformSpec extends FreeSpec with Matchers {

def arrayToAffineTransform(matrix: Array[Double]): AffineTransform = {
matrix match {
case Array(m00, m01, m02, m10, m11, m12) =>
new AffineTransform(m00, m10, m01, m11, m02, m12)
case _ =>
throw new IllegalArgumentException
}
}

private def checkConcatenate2D(matrix0: Array[Double], matrix1: Array[Double]) = {
val at = arrayToAffineTransform(matrix0)
at.preConcatenate(arrayToAffineTransform(matrix1))

val expected = arrayToAffineTransform(NDimensionalAffineTransform.preConcatenate(matrix0, matrix1, 2))
at should be(expected)

}

"concatenate 2D" in {
checkConcatenate2D(
Array(
1.0, 0.0, 3.5, //
0.0, 1.0, 4.2
),
Array(
3.0, 0.0, 0.0, //
0.0, 2.0, 0.0
)
)

checkConcatenate2D(
Array.fill(6)(math.random()),
Array.fill(6)(math.random())
)
}

}
4 changes: 3 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ lazy val Expressions = project

lazy val Trees = project.dependsOn(Expressions)

lazy val OpenCLKernelBuilder = project.dependsOn(Expressions, Trees % Test)
lazy val NDimensionalAffineTransform = project

lazy val OpenCLKernelBuilder = project.dependsOn(NDimensionalAffineTransform, Expressions, Trees % Test)

lazy val Tensors = project.dependsOn(OpenCLKernelBuilder, OpenCL, Trees)

Expand Down