ClusteringEvaluator.scala
@@ -21,11 +21,10 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.linalg.{BLAS, DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
-import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable,
-  SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.{Column, DataFrame, Dataset}
 import org.apache.spark.sql.functions.{avg, col, udf}
 import org.apache.spark.sql.types.DoubleType
@@ -107,15 +106,19 @@ class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: String)

@Since("2.3.0")
override def evaluate(dataset: Dataset[_]): Double = {
SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
SchemaUtils.validateVectorCompatibleColumn(dataset.schema, $(featuresCol))
SchemaUtils.checkNumericType(dataset.schema, $(predictionCol))

val vectorCol = DatasetUtils.columnToVector(dataset, $(featuresCol))
val df = dataset.select(col($(predictionCol)),
Contributor:

I am not sure this is the right way. We will probably face the same issue everywhere we use DatasetUtils.columnToVector, so it is probably better to fix the problem there.

Contributor Author:

@mgaido91 Thanks for your review!
I have considered this, but there is a problem: if we want to attach metadata to the transformed column inside DatasetUtils.columnToVector (for example via .as(alias: String, metadata: Metadata)), how can we get the name of the transformed column? The only way I know of is:

val metadata = ...
val vectorCol = ...
val vectorName = dataset.select(vectorCol).schema.head.name
vectorCol.as(vectorName, metadata)

Contributor (@mgaido91, Jun 15, 2018):

We have the new column we are returning, so we can get its name with .expr.sql.

Contributor Author:

@mgaido91 I think it may be nice to first add a name getter for Column.

Contributor:

We can propose:

vectorCol.as($(featuresCol), dataset.schema($(featuresCol)).metadata)
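
To make the idea in this thread concrete, here is a small self-contained sketch of the last suggestion: alias the converted vector column back to the original name and carry the original column's metadata along, so schema-based lookups downstream keep working. The toy data, the column names, and the UDF standing in for DatasetUtils.columnToVector are illustrative only, not the exact code in this PR.

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.MetadataBuilder

object MetadataPreservingSelectSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("metadata-sketch").getOrCreate()
    import spark.implicits._

    // Toy data: an array-typed features column and a numeric prediction column.
    val raw = Seq((Array(1.0, 2.0), 0), (Array(3.0, 4.0), 1)).toDF("features", "prediction")

    // Give the original features column some metadata (real pipelines get this from ML attributes).
    val meta = new MetadataBuilder().putLong("numFeatures", 2L).build()
    val df = raw.select(col("prediction"), col("features").as("features", meta))

    // Stand-in for what DatasetUtils.columnToVector does: wrap the array in an ML Vector.
    val toVector = udf((xs: Seq[Double]) => Vectors.dense(xs.toArray))
    val vectorCol = toVector(col("features"))

    // The point under discussion: re-alias the converted column to its original name and
    // re-attach the original column's metadata, so downstream schema lookups still see it.
    val converted = df.select(
      col("prediction"),
      vectorCol.as("features", df.schema("features").metadata))

    println(converted.schema("features").metadata)  // prints {"numFeatures":2}
    spark.stop()
  }
}

The same aliasing pattern, applied to the evaluator's own $(featuresCol), is what the suggestion above describes.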

     ($(metricName), $(distanceMeasure)) match {
       case ("silhouette", "squaredEuclidean") =>
         SquaredEuclideanSilhouette.computeSilhouetteScore(
-          dataset, $(predictionCol), $(featuresCol))
+          df, $(predictionCol), $(featuresCol))
       case ("silhouette", "cosine") =>
-        CosineSilhouette.computeSilhouetteScore(dataset, $(predictionCol), $(featuresCol))
+        CosineSilhouette.computeSilhouetteScore(df, $(predictionCol), $(featuresCol))
     }
   }
 }
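
For context before the test changes below, here is a hedged end-to-end sketch of what this change enables from the caller's side: ClusteringEvaluator can now score a DataFrame whose features column is an Array[Double] (or Array[Float]) rather than an ML Vector. The toy data and column names are made up for illustration and are not part of this PR.

import org.apache.spark.ml.evaluation.ClusteringEvaluator
import org.apache.spark.sql.SparkSession

object ArrayFeaturesEvaluatorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("array-features-sketch").getOrCreate()
    import spark.implicits._

    // With this change, the features column may be Array[Double] instead of a Vector;
    // the evaluator converts it internally via DatasetUtils.columnToVector.
    val df = Seq(
      (Array(0.0, 0.1), 0),
      (Array(0.2, 0.0), 0),
      (Array(9.8, 10.0), 1),
      (Array(10.1, 9.9), 1)
    ).toDF("features", "prediction")

    val evaluator = new ClusteringEvaluator()
      .setFeaturesCol("features")
      .setPredictionCol("prediction")

    println(evaluator.evaluate(df))  // silhouette score in [-1, 1]
    spark.stop()
  }
}
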
ClusteringEvaluatorSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Dataset
@@ -33,10 +33,17 @@ class ClusteringEvaluatorSuite
   import testImplicits._
 
   @transient var irisDataset: Dataset[_] = _
+  @transient var newIrisDataset: Dataset[_] = _
+  @transient var newIrisDatasetD: Dataset[_] = _
+  @transient var newIrisDatasetF: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
     irisDataset = spark.read.format("libsvm").load("../data/mllib/iris_libsvm.txt")
+    val datasets = MLTestingUtils.generateArrayFeatureDataset(irisDataset)
+    newIrisDataset = datasets._1
+    newIrisDatasetD = datasets._2
+    newIrisDatasetF = datasets._3
   }
 
   test("params") {
@@ -66,6 +73,9 @@ class ClusteringEvaluatorSuite
.setPredictionCol("label")

assert(evaluator.evaluate(irisDataset) ~== 0.6564679231 relTol 1e-5)
assert(evaluator.evaluate(newIrisDataset) ~== 0.6564679231 relTol 1e-5)
assert(evaluator.evaluate(newIrisDatasetD) ~== 0.6564679231 relTol 1e-5)
assert(evaluator.evaluate(newIrisDatasetF) ~== 0.6564679231 relTol 1e-5)
}

/*
@@ -85,6 +95,9 @@ class ClusteringEvaluatorSuite
.setDistanceMeasure("cosine")

assert(evaluator.evaluate(irisDataset) ~== 0.7222369298 relTol 1e-5)
assert(evaluator.evaluate(newIrisDataset) ~== 0.7222369298 relTol 1e-5)
assert(evaluator.evaluate(newIrisDatasetD) ~== 0.7222369298 relTol 1e-5)
assert(evaluator.evaluate(newIrisDatasetF) ~== 0.7222369298 relTol 1e-5)
}

test("number of clusters must be greater than one") {
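
A note on the helper used in beforeAll: MLTestingUtils.generateArrayFeatureDataset is not part of this diff. Judging from its call site, it appears to return three variants of the input data, with the features column as a Vector, as Array[Double], and as Array[Float]. The sketch below is a hypothetical re-implementation of that idea, not the actual Spark helper, and it assumes the features column is named "features" as in the libsvm iris data.

import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.{col, udf}

object ArrayFeatureDatasetSketch {
  // Hypothetical stand-in for MLTestingUtils.generateArrayFeatureDataset: same rows,
  // with the "features" column kept as a Vector, converted to Array[Double],
  // and converted to Array[Float], respectively.
  def generateArrayFeatureDataset(dataset: Dataset[_]): (DataFrame, DataFrame, DataFrame) = {
    val toDoubleArray = udf((v: Vector) => v.toArray)
    val toFloatArray = udf((v: Vector) => v.toArray.map(_.toFloat))

    val asVector = dataset.toDF()
    val asDoubleArray = asVector.withColumn("features", toDoubleArray(col("features")))
    val asFloatArray = asVector.withColumn("features", toFloatArray(col("features")))
    (asVector, asDoubleArray, asFloatArray)
  }
}

The three assertions added to each test then check that the evaluator produces the same silhouette score regardless of which representation the features column uses.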