Skip to content

Commit fcf5372

Browse files
committed
fix broken test
1 parent 6edd128 commit fcf5372

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

mllib/src/test/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularizationSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class DifferentiableRegularizationSuite extends SparkFunSuite {
3131
val regFun = new L2Regularization(regParam, shouldApply, None)
3232
val (loss, grad) = regFun.calculate(coefficients)
3333
assert(loss === 0.5 * regParam * BLAS.dot(coefficients, coefficients))
34-
assert(grad === coefficients.toArray.map(_ * regParam))
34+
assert(grad === Vectors.dense(coefficients.toArray.map(_ * regParam)))
3535

3636
// check with features standard
3737
val featuresStd = Array(0.1, 1.1, 0.5)
@@ -40,9 +40,9 @@ class DifferentiableRegularizationSuite extends SparkFunSuite {
4040
val expectedLossStd = 0.5 * regParam * (0 until numFeatures).map { j =>
4141
coefficients(j) * coefficients(j) / (featuresStd(j) * featuresStd(j))
4242
}.sum
43-
val expectedGradientStd = (0 until numFeatures).map { j =>
43+
val expectedGradientStd = Vectors.dense((0 until numFeatures).map { j =>
4444
regParam * coefficients(j) / (featuresStd(j) * featuresStd(j))
45-
}.toArray
45+
}.toArray)
4646
assert(lossStd === expectedLossStd)
4747
assert(gradStd === expectedGradientStd)
4848

@@ -51,7 +51,7 @@ class DifferentiableRegularizationSuite extends SparkFunSuite {
5151
val regFunApply = new L2Regularization(regParam, shouldApply2, None)
5252
val (lossApply, gradApply) = regFunApply.calculate(coefficients)
5353
assert(lossApply === 0.5 * regParam * coefficients(1) * coefficients(1))
54-
assert(gradApply === Array(0.0, coefficients(1) * regParam, 0.0))
54+
assert(gradApply === Vectors.dense(0.0, coefficients(1) * regParam, 0.0))
5555

5656
// check with zero features standard
5757
val featuresStdZero = Array(0.1, 0.0, 0.5)

0 commit comments

Comments
 (0)