-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-8519][SPARK-11560][SPARK-11559] [ML] [MLlib] Optimize KMeans implementation #10306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Test build #47724 has finished for PR 10306 at commit
|
|
Test build #47809 has finished for PR 10306 at commit
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Disable this test case because it was blocked by SPARK-12363.
|
Test build #47912 has finished for PR 10306 at commit
|
|
@yanboliang Regarding your local performance test:
Could you re-run the test? I will take a look at your implementation. |
|
@mengxr Thanks for the prompt. I will check my environment and re-run the test. |
|
@mengxr I found the misconfiguration of my test environment and updated it, thanks! I also updated the test cases based on your advice. Now println(com.github.fommil.netlib.BLAS.getInstance().getClass.getName)
val n = 3000
val count = 10
val random = new Random()
val a = Vectors.dense(Array.fill(n)(random.nextDouble()))
val aa = Array.fill(n)(a)
val b = Vectors.dense(Array.fill(n)(random.nextDouble()))
val bb = Array.fill(n)(b)
val a1 = new DenseMatrix(n, n, aa.flatMap(_.toArray), true)
val b1 = new DenseMatrix(n, n, bb.flatMap(_.toArray), false)
val c1 = Matrices.zeros(n, n).asInstanceOf[DenseMatrix]
var total1 = 0.0
// Trial runs
for (i <- 0 until 10) {
gemm(2.0, a1, b1, 2.0, c1)
}
for (i <- 0 until count) {
val start = System.nanoTime()
gemm(2.0, a1, b1, 2.0, c1)
total1 += (System.nanoTime() - start)/1e9
}
total1 = total1 / count
println("gemm elapsed time: = %.3f".format(total1) + " seconds.")
// Trial runs
for (m <- 0 until 10) {
for (i <- 0 until n; j <- 0 until n) {
dot(bb(j), aa(i))
}
}
var total2 = 0.0
for (m <- 0 until count) {
val start = System.nanoTime()
for (i <- 0 until n; j <- 0 until n) {
// axpy(1.0, bb(j), aa(i))
dot(bb(j), aa(i))
}
total2 += (System.nanoTime() - start)/1e9
}
total2 = total2 / count
println("dot elapsed time: = %.3f".format(total2) + " seconds.")The output is: |
Note: I have a new implementation for this issue at #10806 , let's move the discussion there and review that code.
runsrelated code completely, it will have no effect after this change.Update:
Further more, I track the calling stack and found that the cost to construct the
pointMatrix,centerMatrixanddistanceMatrixis expensive. I try to compute and cachepointMatrixandcenterMatrixin advance, but it still can not get benefits. Looking forward others' comments.gemmis slower thanaxpyin themllib.linalgpackage. Consider the following code which is the abstract ofKMeansdistance computation scenarios of new and old version:I got the performance result on my Mac:
It means we can not get benefits from BLAS Level 3 matrix-matrix multiplications to compute pairwise distance. I also found others' complains which is similar with this issue (OpenMathLib/OpenBLAS#528).
If I replace
axpy(1.0, bb(j), aa(j))withdot(bb(j), aa(j)), I got:Please correct me if I have some misunderstanding.