Skip to content

Commit 3297eec

Browse files
committed
Add utilities for NDimensionalAffineTransform
1 parent 9d8c2b1 commit 3297eec

File tree

4 files changed

+151
-1
lines changed

4 files changed

+151
-1
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.4" % Test
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package com.thoughtworks.compute
2+
3+
import scala.annotation.tailrec
4+
5+
/**
6+
* @author 杨博 (Yang Bo)
7+
*/
8+
object NDimensionalAffineTransform {
9+
10+
def preConcatenate(matrix01: Array[Double], matrix12: Array[Double], length0: Int): Array[Double] = {
11+
val length1 = matrix01.length / (length0 + 1)
12+
val length2 = matrix12.length / (length1 + 1)
13+
val matrix02 = Array.ofDim[Double]((length0 + 1) * length2)
14+
concatenate(matrix01, matrix12, matrix02, length0, length1, length2)
15+
matrix02
16+
}
17+
18+
def concatenate(matrix01: Array[Double],
19+
matrix12: Array[Double],
20+
matrix02: Array[Double],
21+
length0: Int,
22+
length1: Int,
23+
length2: Int): Unit = {
24+
25+
@tailrec
26+
def loop2(index2: Int): Unit = {
27+
if (index2 < length2) {
28+
@tailrec
29+
def loop0(index0: Int): Unit = {
30+
if (index0 < length0) {
31+
@tailrec
32+
def loop1(index1: Int, accumulator: Double): Double = {
33+
if (index1 < length1) {
34+
loop1(index1 + 1,
35+
accumulator +
36+
matrix12(index2 * (length1 + 1) + index1) *
37+
matrix01(index1 * (length0 + 1) + index0))
38+
} else {
39+
accumulator
40+
}
41+
}
42+
matrix02(index2 * (length0 + 1) + index0) = loop1(0, 0.0)
43+
44+
loop0(index0 + 1)
45+
}
46+
}
47+
loop0(0)
48+
@tailrec
49+
def loop1(index1: Int, accumulator: Double): Double = {
50+
if (index1 < length1) {
51+
loop1(index1 + 1,
52+
accumulator +
53+
matrix12(index2 * (length1 + 1) + index1) *
54+
matrix01(index1 * (length0 + 1) + length0))
55+
} else {
56+
accumulator
57+
}
58+
}
59+
matrix02(index2 * (length0 + 1) + length0) = loop1(0, matrix12(index2 * (length1 + 1) + length1))
60+
loop2(index2 + 1)
61+
}
62+
}
63+
loop2(0)
64+
65+
}
66+
67+
def transform(matrix: Array[Double], source: Array[Double]): Array[Double] = {
68+
val sourceLength = source.length
69+
val destination = Array.ofDim[Double](matrix.length / (sourceLength + 1))
70+
transform(matrix, source, destination)
71+
destination
72+
}
73+
74+
private def transform(matrix: Array[Double], source: Array[Double], destination: Array[Double]): Unit = {
75+
val sourceLength = source.length
76+
val destinationLength = destination.length
77+
if (matrix.length != (sourceLength + 1) * destinationLength) {
78+
throw new IllegalArgumentException
79+
}
80+
@tailrec
81+
def rowLoop(y: Int): Unit = {
82+
if (y < destinationLength) {
83+
@tailrec
84+
def columnLoop(x: Int, accumulator: Double): Double = {
85+
if (x < sourceLength) {
86+
columnLoop(x + 1, accumulator + matrix(y * (sourceLength + 1) + x) * source(x))
87+
} else {
88+
accumulator
89+
}
90+
}
91+
destination(y) = columnLoop(0, matrix(y * (sourceLength + 1) + sourceLength))
92+
93+
rowLoop(y + 1)
94+
}
95+
}
96+
rowLoop(0)
97+
}
98+
99+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package com.thoughtworks.compute
2+
3+
import java.awt.geom.AffineTransform
4+
5+
import org.scalatest._
6+
7+
/**
8+
* @author 杨博 (Yang Bo)
9+
*/
10+
final class NDimensionalAffineTransformSpec extends FreeSpec with Matchers {
11+
12+
def arrayToAffineTransform(matrix: Array[Double]): AffineTransform = {
13+
matrix match {
14+
case Array(m00, m01, m02, m10, m11, m12) =>
15+
new AffineTransform(m00, m10, m01, m11, m02, m12)
16+
case _ =>
17+
throw new IllegalArgumentException
18+
}
19+
}
20+
21+
private def checkConcatenate2D(matrix0: Array[Double], matrix1: Array[Double]) = {
22+
val at = arrayToAffineTransform(matrix0)
23+
at.preConcatenate(arrayToAffineTransform(matrix1))
24+
25+
val expected = arrayToAffineTransform(NDimensionalAffineTransform.preConcatenate(matrix0, matrix1, 2))
26+
at should be(expected)
27+
28+
}
29+
30+
"concatenate 2D" in {
31+
checkConcatenate2D(
32+
Array(
33+
1.0, 0.0, 3.5, //
34+
0.0, 1.0, 4.2
35+
),
36+
Array(
37+
3.0, 0.0, 0.0, //
38+
0.0, 2.0, 0.0
39+
)
40+
)
41+
42+
checkConcatenate2D(
43+
Array.fill(6)(math.random()),
44+
Array.fill(6)(math.random())
45+
)
46+
}
47+
48+
}

build.sbt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ lazy val Expressions = project
1010

1111
lazy val Trees = project.dependsOn(Expressions)
1212

13-
lazy val OpenCLKernelBuilder = project.dependsOn(Expressions, Trees % Test)
13+
lazy val NDimensionalAffineTransform = project
14+
15+
lazy val OpenCLKernelBuilder = project.dependsOn(NDimensionalAffineTransform, Expressions, Trees % Test)
1416

1517
lazy val Tensors = project.dependsOn(OpenCLKernelBuilder, OpenCL, Trees)
1618

0 commit comments

Comments
 (0)