svd = mat.computeSVD(5, true, 1.0E-9d);
+ RowMatrix U = svd.U(); // The U factor is a RowMatrix.
+ Vector s = svd.s(); // The singular values are stored in a local dense vector.
+ Matrix V = svd.V(); // The V factor is a local dense matrix.
// $example off$
Vector[] collectPartitions = (Vector[]) U.rows().collect();
System.out.println("U factor is:");
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java
index da3a5dfe8628..6b8e6554f1bb 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java
@@ -28,8 +28,6 @@
import java.sql.Timestamp;
import java.util.*;
-import scala.Tuple2;
-
/**
* Counts words in UTF8 encoded, '\n' delimited text received from the network.
*
@@ -76,8 +74,6 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio
for (String word : lineWithTimestamp.getLine().split(" ")) {
eventList.add(new Event(word, lineWithTimestamp.getTimestamp()));
}
- System.out.println(
- "Number of events from " + lineWithTimestamp.getLine() + " = " + eventList.size());
return eventList.iterator();
}
};
@@ -100,7 +96,7 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio
// If timed out, then remove session and send final update
if (state.hasTimedOut()) {
SessionUpdate finalUpdate = new SessionUpdate(
- sessionId, state.get().getDurationMs(), state.get().getNumEvents(), true);
+ sessionId, state.get().calculateDuration(), state.get().getNumEvents(), true);
state.remove();
return finalUpdate;
@@ -133,7 +129,7 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio
// Set timeout such that the session will be expired if no data received for 10 seconds
state.setTimeoutDuration("10 seconds");
return new SessionUpdate(
- sessionId, state.get().getDurationMs(), state.get().getNumEvents(), false);
+ sessionId, state.get().calculateDuration(), state.get().getNumEvents(), false);
}
}
};
@@ -215,7 +211,8 @@ public void setStartTimestampMs(long startTimestampMs) {
public long getEndTimestampMs() { return endTimestampMs; }
public void setEndTimestampMs(long endTimestampMs) { this.endTimestampMs = endTimestampMs; }
- public long getDurationMs() { return endTimestampMs - startTimestampMs; }
+ public long calculateDuration() { return endTimestampMs - startTimestampMs; }
+
@Override public String toString() {
return "SessionInfo(numEvents = " + numEvents +
", timestamps = " + startTimestampMs + " to " + endTimestampMs + ")";
diff --git a/examples/src/main/python/ml/fpgrowth_example.py b/examples/src/main/python/ml/fpgrowth_example.py
new file mode 100644
index 000000000000..c92c3c27abb2
--- /dev/null
+++ b/examples/src/main/python/ml/fpgrowth_example.py
@@ -0,0 +1,56 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# $example on$
+from pyspark.ml.fpm import FPGrowth
+# $example off$
+from pyspark.sql import SparkSession
+
+"""
+An example demonstrating FPGrowth.
+Run with:
+ bin/spark-submit examples/src/main/python/ml/fpgrowth_example.py
+"""
+
+if __name__ == "__main__":
+ spark = SparkSession\
+ .builder\
+ .appName("FPGrowthExample")\
+ .getOrCreate()
+
+ # $example on$
+ df = spark.createDataFrame([
+ (0, [1, 2, 5]),
+ (1, [1, 2, 3, 5]),
+ (2, [1, 2])
+ ], ["id", "items"])
+
+ fpGrowth = FPGrowth(itemsCol="items", minSupport=0.5, minConfidence=0.6)
+ model = fpGrowth.fit(df)
+
+ # Display frequent itemsets.
+ model.freqItemsets.show()
+
+ # Display generated association rules.
+ model.associationRules.show()
+
+ # transform examines the input items against all the association rules and summarizes the
+ # consequents as the prediction
+ model.transform(df).show()
+ # $example off$
+
+ spark.stop()
diff --git a/examples/src/main/python/mllib/pca_rowmatrix_example.py b/examples/src/main/python/mllib/pca_rowmatrix_example.py
new file mode 100644
index 000000000000..49b9b1bbe08e
--- /dev/null
+++ b/examples/src/main/python/mllib/pca_rowmatrix_example.py
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import SparkContext
+# $example on$
+from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.linalg.distributed import RowMatrix
+# $example off$
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="PythonPCAOnRowMatrixExample")
+
+ # $example on$
+ rows = sc.parallelize([
+ Vectors.sparse(5, {1: 1.0, 3: 7.0}),
+ Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
+ Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
+ ])
+
+ mat = RowMatrix(rows)
+ # Compute the top 4 principal components.
+ # Principal components are stored in a local dense matrix.
+ pc = mat.computePrincipalComponents(4)
+
+ # Project the rows to the linear space spanned by the top 4 principal components.
+ projected = mat.multiply(pc)
+ # $example off$
+ collected = projected.rows.collect()
+ print("Projected Row Matrix of principal component:")
+ for vector in collected:
+ print(vector)
+ sc.stop()
diff --git a/examples/src/main/python/mllib/svd_example.py b/examples/src/main/python/mllib/svd_example.py
new file mode 100644
index 000000000000..5b220fdb3fd6
--- /dev/null
+++ b/examples/src/main/python/mllib/svd_example.py
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import SparkContext
+# $example on$
+from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.linalg.distributed import RowMatrix
+# $example off$
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="PythonSVDExample")
+
+ # $example on$
+ rows = sc.parallelize([
+ Vectors.sparse(5, {1: 1.0, 3: 7.0}),
+ Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
+ Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
+ ])
+
+ mat = RowMatrix(rows)
+
+ # Compute the top 5 singular values and corresponding singular vectors.
+ svd = mat.computeSVD(5, computeU=True)
+ U = svd.U # The U factor is a RowMatrix.
+ s = svd.s # The singular values are stored in a local dense vector.
+ V = svd.V # The V factor is a local dense matrix.
+ # $example off$
+ collected = U.rows.collect()
+ print("U factor is:")
+ for vector in collected:
+ print(vector)
+ print("Singular values are: %s" % s)
+ print("V factor is:\n%s" % V)
+ sc.stop()
diff --git a/examples/src/main/r/ml/fpm.R b/examples/src/main/r/ml/fpm.R
new file mode 100644
index 000000000000..89c4564457d9
--- /dev/null
+++ b/examples/src/main/r/ml/fpm.R
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# To run this example use
+# ./bin/spark-submit examples/src/main/r/ml/fpm.R
+
+# Load SparkR library into your R session
+library(SparkR)
+
+# Initialize SparkSession
+sparkR.session(appName = "SparkR-ML-fpm-example")
+
+# $example on$
+# Load training data
+
+df <- selectExpr(createDataFrame(data.frame(rawItems = c(
+ "1,2,5", "1,2,3,5", "1,2"
+))), "split(rawItems, ',') AS items")
+
+fpm <- spark.fpGrowth(df, itemsCol="items", minSupport=0.5, minConfidence=0.6)
+
+# Extracting frequent itemsets
+
+spark.freqItemsets(fpm)
+
+# Extracting association rules
+
+spark.associationRules(fpm)
+
+# Predict uses the association rules and combines possible consequents
+
+predict(fpm, df)
+
+# $example off$
+
+sparkR.session.stop()
diff --git a/examples/src/main/r/streaming/structured_network_wordcount.R b/examples/src/main/r/streaming/structured_network_wordcount.R
new file mode 100644
index 000000000000..cda18ebc072e
--- /dev/null
+++ b/examples/src/main/r/streaming/structured_network_wordcount.R
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Counts words in UTF8 encoded, '\n' delimited text received from the network.
+
+# To run this on your local machine, you need to first run a Netcat server
+# $ nc -lk 9999
+# and then run the example
+# ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999
+
+# Load SparkR library into your R session
+library(SparkR)
+
+# Initialize SparkSession
+sparkR.session(appName = "SparkR-Streaming-structured-network-wordcount-example")
+
+args <- commandArgs(trailing = TRUE)
+
+if (length(args) != 2) {
+ print("Usage: structured_network_wordcount.R ")
+ print(" and describe the TCP server that Structured Streaming")
+ print("would connect to receive data.")
+ q("no")
+}
+
+hostname <- args[[1]]
+port <- as.integer(args[[2]])
+
+# Create DataFrame representing the stream of input lines from connection to localhost:9999
+lines <- read.stream("socket", host = hostname, port = port)
+
+# Split the lines into words
+words <- selectExpr(lines, "explode(split(value, ' ')) as word")
+
+# Generate running word count
+wordCounts <- count(groupBy(words, "word"))
+
+# Start running the query that prints the running counts to the console
+query <- write.stream(wordCounts, "console", outputMode = "complete")
+
+awaitTermination(query)
+
+sparkR.session.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala
new file mode 100644
index 000000000000..59110d70de55
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+// scalastyle:off println
+
+// $example on$
+import org.apache.spark.ml.fpm.FPGrowth
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+/**
+ * An example demonstrating FP-Growth.
+ * Run with
+ * {{{
+ * bin/run-example ml.FPGrowthExample
+ * }}}
+ */
+object FPGrowthExample {
+
+ def main(args: Array[String]): Unit = {
+ val spark = SparkSession
+ .builder
+ .appName(s"${this.getClass.getSimpleName}")
+ .getOrCreate()
+ import spark.implicits._
+
+ // $example on$
+ val dataset = spark.createDataset(Seq(
+ "1 2 5",
+ "1 2 3 5",
+ "1 2")
+ ).map(t => t.split(" ")).toDF("items")
+
+ val fpgrowth = new FPGrowth().setItemsCol("items").setMinSupport(0.5).setMinConfidence(0.6)
+ val model = fpgrowth.fit(dataset)
+
+ // Display frequent itemsets.
+ model.freqItemsets.show()
+
+ // Display generated association rules.
+ model.associationRules.show()
+
+ // transform examines the input items against all the association rules and summarizes the
+ // consequents as the prediction
+ model.transform(dataset).show()
+ // $example off$
+
+ spark.stop()
+ }
+}
+// scalastyle:on println
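The fitted model above also lends itself to a couple of follow-up queries. A minimal sketch (not part of the patch), assuming the `spark`, `model`, and implicits from FPGrowthExample and the default rule columns (antecedent, consequent, confidence):

import org.apache.spark.sql.functions.col

// Keep only the high-confidence association rules.
model.associationRules
  .filter(col("confidence") >= 0.8)
  .show()

// transform() can also score unseen baskets, provided they carry the same "items" column.
val newBaskets = spark.createDataset(Seq("1 5", "2")).map(_.split(" ")).toDF("items")
model.transform(newBaskets).show()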
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala
index a137ba2a2f9d..da43a8d9c7e8 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala
@@ -39,9 +39,9 @@ object PCAOnRowMatrixExample {
Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
- val dataRDD = sc.parallelize(data, 2)
+ val rows = sc.parallelize(data)
- val mat: RowMatrix = new RowMatrix(dataRDD)
+ val mat: RowMatrix = new RowMatrix(rows)
// Compute the top 4 principal components.
// Principal components are stored in a local dense matrix.
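For completeness, RowMatrix also exposes computePrincipalComponentsAndExplainedVariance, which reports how much variance each component captures. A short sketch (not part of the patch), reusing `mat` from the example:

// Top 4 principal components together with the fraction of variance each one explains.
val (pc4, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(4)
println(s"Explained variance of the top 4 principal components: $explainedVariance")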
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala
index b286a3f7b909..769ae2a3a88b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala
@@ -28,6 +28,9 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix
// $example off$
+/**
+ * Example for SingularValueDecomposition.
+ */
object SVDExample {
def main(args: Array[String]): Unit = {
@@ -41,15 +44,15 @@ object SVDExample {
Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
- val dataRDD = sc.parallelize(data, 2)
+ val rows = sc.parallelize(data)
- val mat: RowMatrix = new RowMatrix(dataRDD)
+ val mat: RowMatrix = new RowMatrix(rows)
// Compute the top 5 singular values and corresponding singular vectors.
val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(5, computeU = true)
val U: RowMatrix = svd.U // The U factor is a RowMatrix.
- val s: Vector = svd.s // The singular values are stored in a local dense vector.
- val V: Matrix = svd.V // The V factor is a local dense matrix.
+ val s: Vector = svd.s // The singular values are stored in a local dense vector.
+ val V: Matrix = svd.V // The V factor is a local dense matrix.
// $example off$
val collect = U.rows.collect()
println("U factor is:")
diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml
index 8948df2da89e..04afe28fb788 100644
--- a/external/docker-integration-tests/pom.xml
+++ b/external/docker-integration-tests/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml
index f8ef8a991316..47e03419d3df 100644
--- a/external/flume-assembly/pom.xml
+++ b/external/flume-assembly/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml
index 6d547c46d6a2..f961a8f54d9a 100644
--- a/external/flume-sink/pom.xml
+++ b/external/flume-sink/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/flume/pom.xml b/external/flume/pom.xml
index 46901d64eda9..d8bc7dcf7524 100644
--- a/external/flume/pom.xml
+++ b/external/flume/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml
index 295142cbfdff..6d46430d6e96 100644
--- a/external/kafka-0-10-assembly/pom.xml
+++ b/external/kafka-0-10-assembly/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml
index 6cf448e65e8b..5d979ddf2f74 100644
--- a/external/kafka-0-10-sql/pom.xml
+++ b/external/kafka-0-10-sql/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
index 6d76904fb0e5..7c4f38e02fb2 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
@@ -28,6 +28,7 @@ import org.apache.kafka.common.TopicPartition
import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.kafka010.KafkaSource._
+import org.apache.spark.util.UninterruptibleThread
/**
@@ -62,11 +63,20 @@ private[kafka010] case class CachedKafkaConsumer private(
case class AvailableOffsetRange(earliest: Long, latest: Long)
+ private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match {
+ case ut: UninterruptibleThread =>
+ ut.runUninterruptibly(body)
+ case _ =>
+ logWarning("CachedKafkaConsumer is not running in UninterruptibleThread. " +
+ "It may hang when CachedKafkaConsumer's methods are interrupted because of KAFKA-1894")
+ body
+ }
+
/**
* Return the available offset range of the current partition. It's a pair of the earliest offset
* and the latest offset.
*/
- def getAvailableOffsetRange(): AvailableOffsetRange = {
+ def getAvailableOffsetRange(): AvailableOffsetRange = runUninterruptiblyIfPossible {
consumer.seekToBeginning(Set(topicPartition).asJava)
val earliestOffset = consumer.position(topicPartition)
consumer.seekToEnd(Set(topicPartition).asJava)
@@ -92,7 +102,8 @@ private[kafka010] case class CachedKafkaConsumer private(
offset: Long,
untilOffset: Long,
pollTimeoutMs: Long,
- failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = {
+ failOnDataLoss: Boolean):
+ ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible {
require(offset < untilOffset,
s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]")
logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset")
@@ -276,7 +287,7 @@ private[kafka010] case class CachedKafkaConsumer private(
reportDataLoss0(failOnDataLoss, finalMessage, cause)
}
- private def close(): Unit = consumer.close()
+ def close(): Unit = consumer.close()
private def seek(offset: Long): Unit = {
logDebug(s"Seeking to $groupId $topicPartition $offset")
@@ -371,7 +382,7 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
// If this is reattempt at running the task, then invalidate cache and start with
// a new consumer
- if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) {
+ if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) {
removeKafkaConsumer(topic, partition, kafkaParams)
val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams)
consumer.inuse = true
@@ -387,6 +398,14 @@ private[kafka010] object CachedKafkaConsumer extends Logging {
}
}
+ /** Create a [[CachedKafkaConsumer]] but don't put it into the cache. */
+ def createUncached(
+ topic: String,
+ partition: Int,
+ kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = {
+ new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams)
+ }
+
private def reportDataLoss0(
failOnDataLoss: Boolean,
finalMessage: String,
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
index 2696d6f089d2..3e65949a6fd1 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
@@ -95,8 +95,10 @@ private[kafka010] class KafkaOffsetReader(
* Closes the connection to Kafka, and cleans up state.
*/
def close(): Unit = {
- consumer.close()
- kafkaReaderThread.shutdownNow()
+ runUninterruptibly {
+ consumer.close()
+ }
+ kafkaReaderThread.shutdown()
}
/**
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
index f180bbad6e36..97bd28316932 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.kafka010
import java.{util => ju}
+import java.util.UUID
import org.apache.kafka.common.TopicPartition
@@ -33,9 +34,9 @@ import org.apache.spark.unsafe.types.UTF8String
private[kafka010] class KafkaRelation(
override val sqlContext: SQLContext,
- kafkaReader: KafkaOffsetReader,
- executorKafkaParams: ju.Map[String, Object],
+ strategy: ConsumerStrategy,
sourceOptions: Map[String, String],
+ specifiedKafkaParams: Map[String, String],
failOnDataLoss: Boolean,
startingOffsets: KafkaOffsetRangeLimit,
endingOffsets: KafkaOffsetRangeLimit)
@@ -53,9 +54,27 @@ private[kafka010] class KafkaRelation(
override def schema: StructType = KafkaOffsetReader.kafkaSchema
override def buildScan(): RDD[Row] = {
+ // Each running query should use its own group id. Otherwise, the query may be only assigned
+ // partial data since Kafka will assign partitions to multiple consumers having the same group
+ // id. Hence, we should generate a unique id for each query.
+ val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}"
+
+ val kafkaOffsetReader = new KafkaOffsetReader(
+ strategy,
+ KafkaSourceProvider.kafkaParamsForDriver(specifiedKafkaParams),
+ sourceOptions,
+ driverGroupIdPrefix = s"$uniqueGroupId-driver")
+
// Leverage the KafkaReader to obtain the relevant partition offsets
- val fromPartitionOffsets = getPartitionOffsets(startingOffsets)
- val untilPartitionOffsets = getPartitionOffsets(endingOffsets)
+ val (fromPartitionOffsets, untilPartitionOffsets) = {
+ try {
+ (getPartitionOffsets(kafkaOffsetReader, startingOffsets),
+ getPartitionOffsets(kafkaOffsetReader, endingOffsets))
+ } finally {
+ kafkaOffsetReader.close()
+ }
+ }
+
// Obtain topicPartitions in both from and until partition offset, ignoring
// topic partitions that were added and/or deleted between the two above calls.
if (fromPartitionOffsets.keySet != untilPartitionOffsets.keySet) {
@@ -82,6 +101,8 @@ private[kafka010] class KafkaRelation(
offsetRanges.sortBy(_.topicPartition.toString).mkString(", "))
// Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
+ val executorKafkaParams =
+ KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId)
val rdd = new KafkaSourceRDD(
sqlContext.sparkContext, executorKafkaParams, offsetRanges,
pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer = false).map { cr =>
@@ -98,6 +119,7 @@ private[kafka010] class KafkaRelation(
}
private def getPartitionOffsets(
+ kafkaReader: KafkaOffsetReader,
kafkaOffsets: KafkaOffsetRangeLimit): Map[TopicPartition, Long] = {
def validateTopicPartitions(partitions: Set[TopicPartition],
partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index ab1ce347cbe3..3cb4d8cad12c 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -111,10 +111,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
validateBatchOptions(parameters)
- // Each running query should use its own group id. Otherwise, the query may be only assigned
- // partial data since Kafka will assign partitions to multiple consumers having the same group
- // id. Hence, we should generate a unique id for each query.
- val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}"
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams =
parameters
@@ -131,20 +127,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
assert(endingRelationOffsets != EarliestOffsetRangeLimit)
- val kafkaOffsetReader = new KafkaOffsetReader(
- strategy(caseInsensitiveParams),
- kafkaParamsForDriver(specifiedKafkaParams),
- parameters,
- driverGroupIdPrefix = s"$uniqueGroupId-driver")
-
new KafkaRelation(
sqlContext,
- kafkaOffsetReader,
- kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
- parameters,
- failOnDataLoss(caseInsensitiveParams),
- startingRelationOffsets,
- endingRelationOffsets)
+ strategy(caseInsensitiveParams),
+ sourceOptions = parameters,
+ specifiedKafkaParams = specifiedKafkaParams,
+ failOnDataLoss = failOnDataLoss(caseInsensitiveParams),
+ startingOffsets = startingRelationOffsets,
+ endingOffsets = endingRelationOffsets)
}
override def createSink(
@@ -213,46 +203,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
}
- private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) =
- ConfigUpdater("source", specifiedKafkaParams)
- .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
- .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
-
- // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial
- // offsets by itself instead of counting on KafkaConsumer.
- .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
-
- // So that consumers in the driver does not commit offsets unnecessarily
- .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
-
- // So that the driver does not pull too much data
- .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1))
-
- // If buffer config is not set, set it to reasonable value to work around
- // buffer issues (see KAFKA-3135)
- .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
- .build()
-
- private def kafkaParamsForExecutors(
- specifiedKafkaParams: Map[String, String], uniqueGroupId: String) =
- ConfigUpdater("executor", specifiedKafkaParams)
- .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
- .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
-
- // Make sure executors do only what the driver tells them.
- .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none")
-
- // So that consumers in executors do not mess with any existing group id
- .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor")
-
- // So that consumers in executors does not commit offsets unnecessarily
- .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
-
- // If buffer config is not set, set it to reasonable value to work around
- // buffer issues (see KAFKA-3135)
- .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
- .build()
-
private def strategy(caseInsensitiveParams: Map[String, String]) =
caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match {
case ("assign", value) =>
@@ -414,30 +364,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
logWarning("maxOffsetsPerTrigger option ignored in batch queries")
}
}
-
- /** Class to conveniently update Kafka config params, while logging the changes */
- private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) {
- private val map = new ju.HashMap[String, Object](kafkaParams.asJava)
-
- def set(key: String, value: Object): this.type = {
- map.put(key, value)
- logInfo(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}")
- this
- }
-
- def setIfUnset(key: String, value: Object): ConfigUpdater = {
- if (!map.containsKey(key)) {
- map.put(key, value)
- logInfo(s"$module: Set $key to $value")
- }
- this
- }
-
- def build(): ju.Map[String, Object] = map
- }
}
-private[kafka010] object KafkaSourceProvider {
+private[kafka010] object KafkaSourceProvider extends Logging {
private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign")
private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets"
private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets"
@@ -459,4 +388,66 @@ private[kafka010] object KafkaSourceProvider {
case None => defaultOffsets
}
}
+
+ def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]): ju.Map[String, Object] =
+ ConfigUpdater("source", specifiedKafkaParams)
+ .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
+ .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
+
+ // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial
+ // offsets by itself instead of counting on KafkaConsumer.
+ .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest")
+
+ // So that consumers in the driver does not commit offsets unnecessarily
+ .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
+
+ // So that the driver does not pull too much data
+ .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1))
+
+ // If buffer config is not set, set it to reasonable value to work around
+ // buffer issues (see KAFKA-3135)
+ .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
+ .build()
+
+ def kafkaParamsForExecutors(
+ specifiedKafkaParams: Map[String, String],
+ uniqueGroupId: String): ju.Map[String, Object] =
+ ConfigUpdater("executor", specifiedKafkaParams)
+ .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName)
+ .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName)
+
+ // Make sure executors do only what the driver tells them.
+ .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none")
+
+ // So that consumers in executors do not mess with any existing group id
+ .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor")
+
+ // So that consumers in executors does not commit offsets unnecessarily
+ .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false")
+
+ // If buffer config is not set, set it to reasonable value to work around
+ // buffer issues (see KAFKA-3135)
+ .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
+ .build()
+
+ /** Class to conveniently update Kafka config params, while logging the changes */
+ private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) {
+ private val map = new ju.HashMap[String, Object](kafkaParams.asJava)
+
+ def set(key: String, value: Object): this.type = {
+ map.put(key, value)
+ logDebug(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}")
+ this
+ }
+
+ def setIfUnset(key: String, value: Object): ConfigUpdater = {
+ if (!map.containsKey(key)) {
+ map.put(key, value)
+ logDebug(s"$module: Set $key to $value")
+ }
+ this
+ }
+
+ def build(): ju.Map[String, Object] = map
+ }
}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
index 6fb3473eb75f..9d9e2aaba807 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala
@@ -125,16 +125,15 @@ private[kafka010] class KafkaSourceRDD(
context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = {
val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition]
val topic = sourcePartition.offsetRange.topic
- if (!reuseKafkaConsumer) {
- // if we can't reuse CachedKafkaConsumers, let's reset the groupId to something unique
- // to each task (i.e., append the task's unique partition id), because we will have
- // multiple tasks (e.g., in the case of union) reading from the same topic partitions
- val old = executorKafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
- val id = TaskContext.getPartitionId()
- executorKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, old + "-" + id)
- }
val kafkaPartition = sourcePartition.offsetRange.partition
- val consumer = CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams)
+ val consumer =
+ if (!reuseKafkaConsumer) {
+ // If we can't reuse CachedKafkaConsumers, create a new CachedKafkaConsumer. As we use
+ // `assign` here, we don't need to worry about "group.id" conflicts.
+ CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams)
+ } else {
+ CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams)
+ }
val range = resolveRange(consumer, sourcePartition.offsetRange)
assert(
range.fromOffset <= range.untilOffset,
@@ -170,7 +169,7 @@ private[kafka010] class KafkaSourceRDD(
override protected def close(): Unit = {
if (!reuseKafkaConsumer) {
// Don't forget to close non-reuse KafkaConsumers. You may take down your cluster!
- CachedKafkaConsumer.removeKafkaConsumer(topic, kafkaPartition, executorKafkaParams)
+ consumer.close()
} else {
// Indicate that we're no longer using this consumer
CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams)
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
index a637d52c933a..61936e32fd83 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
@@ -47,7 +47,7 @@ private[kafka010] object KafkaWriter extends Logging {
queryExecution: QueryExecution,
kafkaParameters: ju.Map[String, Object],
topic: Option[String] = None): Unit = {
- val schema = queryExecution.logical.output
+ val schema = queryExecution.analyzed.output
schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
if (topic == None) {
throw new AnalysisException(s"topic option required when no " +
@@ -84,7 +84,7 @@ private[kafka010] object KafkaWriter extends Logging {
queryExecution: QueryExecution,
kafkaParameters: ju.Map[String, Object],
topic: Option[String] = None): Unit = {
- val schema = queryExecution.logical.output
+ val schema = queryExecution.analyzed.output
validateQuery(queryExecution, kafkaParameters, topic)
SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
queryExecution.toRdd.foreachPartition { iter =>
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
index 4bd052d249ec..2ab336c7ac47 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection}
import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BinaryType, DataType}
@@ -108,6 +109,21 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext {
s"save mode overwrite not allowed for kafka"))
}
+ test("SPARK-20496: batch - enforce analyzed plans") {
+ val inputEvents =
+ spark.range(1, 1000)
+ .select(to_json(struct("*")) as 'value)
+
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+ // used to throw UnresolvedException
+ inputEvents.write
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("topic", topic)
+ .save()
+ }
+
test("streaming - write to kafka with topic field") {
val input = MemoryStream[String]
val topic = newTopic()
diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml
index 88499240cd56..e4336ecb07da 100644
--- a/external/kafka-0-10/pom.xml
+++ b/external/kafka-0-10/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
index 4c6e2ce87e29..62cdf5b1134e 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
@@ -199,7 +199,7 @@ private[spark] class KafkaRDD[K, V](
val consumer = if (useConsumerCache) {
CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
- if (context.attemptNumber > 1) {
+ if (context.attemptNumber >= 1) {
// just in case the prior attempt failures were cache related
CachedKafkaConsumer.remove(groupId, part.topic, part.partition)
}
diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml
index 3fedd9eda195..2489d29ebe16 100644
--- a/external/kafka-0-8-assembly/pom.xml
+++ b/external/kafka-0-8-assembly/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml
index 8368a1f12218..98f81aee376a 100644
--- a/external/kafka-0-8/pom.xml
+++ b/external/kafka-0-8/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml
index 90bb0e4987c8..88515f853edb 100644
--- a/external/kinesis-asl-assembly/pom.xml
+++ b/external/kinesis-asl-assembly/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml
index daa79e79163b..28797e3fe432 100644
--- a/external/kinesis-asl/pom.xml
+++ b/external/kinesis-asl/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml
index 7da27817ebaf..701455f22609 100644
--- a/external/spark-ganglia-lgpl/pom.xml
+++ b/external/spark-ganglia-lgpl/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/graphx/pom.xml b/graphx/pom.xml
index 8df33660ea9d..1ed38a794f44 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index 646462b4a835..755c6febc48e 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -19,7 +19,10 @@ package org.apache.spark.graphx
import scala.reflect.ClassTag
+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
/**
* Implements a Pregel-like bulk-synchronous message-passing API.
@@ -122,27 +125,39 @@ object Pregel extends Logging {
require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," +
s" but got ${maxIterations}")
- var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
+ val checkpointInterval = graph.vertices.sparkContext.getConf
+ .getInt("spark.graphx.pregel.checkpointInterval", -1)
+ var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg))
+ val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED](
+ checkpointInterval, graph.vertices.sparkContext)
+ graphCheckpointer.update(g)
+
// compute the messages
var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
+ val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)](
+ checkpointInterval, graph.vertices.sparkContext)
+ messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
var activeMessages = messages.count()
+
// Loop
var prevG: Graph[VD, ED] = null
var i = 0
while (activeMessages > 0 && i < maxIterations) {
// Receive the messages and update the vertices.
prevG = g
- g = g.joinVertices(messages)(vprog).cache()
+ g = g.joinVertices(messages)(vprog)
+ graphCheckpointer.update(g)
val oldMessages = messages
// Send new messages, skipping edges where neither side received a message. We must cache
// messages so it can be materialized on the next line, allowing us to uncache the previous
// iteration.
messages = GraphXUtils.mapReduceTriplets(
- g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
+ g, sendMsg, mergeMsg, Some((oldMessages, activeDirection)))
// The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages
// (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages
// and the vertices of g).
+ messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
activeMessages = messages.count()
logInfo("Pregel finished iteration " + i)
@@ -154,7 +169,9 @@ object Pregel extends Logging {
// count the iteration
i += 1
}
- messages.unpersist(blocking = false)
+ messageCheckpointer.unpersistDataSet()
+ graphCheckpointer.deleteAllCheckpoints()
+ messageCheckpointer.deleteAllCheckpoints()
g
} // end of apply
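From the user's side, the new checkpointing only kicks in when a checkpoint directory has been set and spark.graphx.pregel.checkpointInterval is positive; the -1 default effectively disables it. A minimal sketch (not part of the patch; the directory and data path are illustrative):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.GraphLoader

val conf = new SparkConf()
  .setAppName("PregelCheckpointingSketch")
  .set("spark.graphx.pregel.checkpointInterval", "10") // checkpoint every 10 Pregel iterations
val sc = new SparkContext(conf)
sc.setCheckpointDir("/tmp/graphx-checkpoints") // needed before any checkpoint can be written

// Any Pregel-based algorithm now checkpoints periodically, e.g. connected components.
val graph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt")
println(graph.connectedComponents().vertices.count())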
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
index 13b2b5771918..fd7b7f7c1c48 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala
@@ -226,18 +226,18 @@ object PageRank extends Logging {
// Propagates the message along outbound edges
// and adding start nodes back in with activation resetProb
val rankUpdates = rankGraph.aggregateMessages[BV[Double]](
- ctx => ctx.sendToDst(ctx.srcAttr :* ctx.attr),
- (a : BV[Double], b : BV[Double]) => a :+ b, TripletFields.Src)
+ ctx => ctx.sendToDst(ctx.srcAttr *:* ctx.attr),
+ (a : BV[Double], b : BV[Double]) => a +:+ b, TripletFields.Src)
rankGraph = rankGraph.outerJoinVertices(rankUpdates) {
(vid, oldRank, msgSumOpt) =>
- val popActivations: BV[Double] = msgSumOpt.getOrElse(zero) :* (1.0 - resetProb)
+ val popActivations: BV[Double] = msgSumOpt.getOrElse(zero) *:* (1.0 - resetProb)
val resetActivations = if (sourcesInitMapBC.value contains vid) {
- sourcesInitMapBC.value(vid) :* resetProb
+ sourcesInitMapBC.value(vid) *:* resetProb
} else {
zero
}
- popActivations :+ resetActivations
+ popActivations +:+ resetActivations
}.cache()
rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices
@@ -250,9 +250,9 @@ object PageRank extends Logging {
}
// SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks
- val rankSums = rankGraph.vertices.values.fold(zero)(_ :+ _)
+ val rankSums = rankGraph.vertices.values.fold(zero)(_ +:+ _)
rankGraph.mapVertices { (vid, attr) =>
- Vectors.fromBreeze(attr :/ rankSums)
+ Vectors.fromBreeze(attr /:/ rankSums)
}
}
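The PageRank changes above are a mechanical migration to the element-wise operator spellings used by newer Breeze releases (0.13.x). A quick illustration (not part of the patch):

import breeze.linalg.DenseVector

val a = DenseVector(1.0, 2.0, 3.0)
val b = DenseVector(4.0, 5.0, 6.0)
val product  = a *:* b // element-wise product, previously written a :* b
val sum      = a +:+ b // element-wise sum, previously written a :+ b
val quotient = a /:/ b // element-wise division, previously written a :/ b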
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala
similarity index 91%
rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
rename to graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala
index 80074897567e..fda501aa757d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.graphx.util
import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicCheckpointer
/**
@@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel
* @tparam VD Vertex descriptor type
* @tparam ED Edge descriptor type
*
- * TODO: Move this out of MLlib?
*/
-private[mllib] class PeriodicGraphCheckpointer[VD, ED](
+private[spark] class PeriodicGraphCheckpointer[VD, ED](
checkpointInterval: Int,
sc: SparkContext)
extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
@@ -87,10 +87,13 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](
override protected def persist(data: Graph[VD, ED]): Unit = {
if (data.vertices.getStorageLevel == StorageLevel.NONE) {
- data.vertices.persist()
+ /* We need to use cache because persist does not honor the default storage level requested
+ * when constructing the graph. Only cache does that.
+ */
+ data.vertices.cache()
}
if (data.edges.getStorageLevel == StorageLevel.NONE) {
- data.edges.persist()
+ data.edges.cache()
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala
similarity index 70%
rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
rename to graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala
index a13e7f63a929..e0c65e6940f6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala
@@ -15,77 +15,81 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.graphx.util
import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkContext, SparkFunSuite}
-import org.apache.spark.graphx.{Edge, Graph}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContext {
import PeriodicGraphCheckpointerSuite._
test("Persisting") {
var graphsToCheck = Seq.empty[GraphToCheck]
- val graph1 = createGraph(sc)
- val checkpointer =
- new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
- checkpointer.update(graph1)
- graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
- checkPersistence(graphsToCheck, 1)
-
- var iteration = 2
- while (iteration < 9) {
- val graph = createGraph(sc)
- checkpointer.update(graph)
- graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
- checkPersistence(graphsToCheck, iteration)
- iteration += 1
+ withSpark { sc =>
+ val graph1 = createGraph(sc)
+ val checkpointer =
+ new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
+ checkpointer.update(graph1)
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+ checkPersistence(graphsToCheck, 1)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val graph = createGraph(sc)
+ checkpointer.update(graph)
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+ checkPersistence(graphsToCheck, iteration)
+ iteration += 1
+ }
}
}
test("Checkpointing") {
- val tempDir = Utils.createTempDir()
- val path = tempDir.toURI.toString
- val checkpointInterval = 2
- var graphsToCheck = Seq.empty[GraphToCheck]
- sc.setCheckpointDir(path)
- val graph1 = createGraph(sc)
- val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
- checkpointInterval, graph1.vertices.sparkContext)
- checkpointer.update(graph1)
- graph1.edges.count()
- graph1.vertices.count()
- graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
- checkCheckpoint(graphsToCheck, 1, checkpointInterval)
-
- var iteration = 2
- while (iteration < 9) {
- val graph = createGraph(sc)
- checkpointer.update(graph)
- graph.vertices.count()
- graph.edges.count()
- graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
- checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
- iteration += 1
- }
+ withSpark { sc =>
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ val checkpointInterval = 2
+ var graphsToCheck = Seq.empty[GraphToCheck]
+ sc.setCheckpointDir(path)
+ val graph1 = createGraph(sc)
+ val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
+ checkpointInterval, graph1.vertices.sparkContext)
+ checkpointer.update(graph1)
+ graph1.edges.count()
+ graph1.vertices.count()
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+ checkCheckpoint(graphsToCheck, 1, checkpointInterval)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val graph = createGraph(sc)
+ checkpointer.update(graph)
+ graph.vertices.count()
+ graph.edges.count()
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+ checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
+ iteration += 1
+ }
- checkpointer.deleteAllCheckpoints()
- graphsToCheck.foreach { graph =>
- confirmCheckpointRemoved(graph.graph)
- }
+ checkpointer.deleteAllCheckpoints()
+ graphsToCheck.foreach { graph =>
+ confirmCheckpointRemoved(graph.graph)
+ }
- Utils.deleteRecursively(tempDir)
+ Utils.deleteRecursively(tempDir)
+ }
}
}
private object PeriodicGraphCheckpointerSuite {
+ private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER
case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int)
@@ -96,7 +100,8 @@ private object PeriodicGraphCheckpointerSuite {
Edge[Double](3, 4, 0))
def createGraph(sc: SparkContext): Graph[Double, Double] = {
- Graph.fromEdges[Double, Double](sc.parallelize(edges), 0)
+ Graph.fromEdges[Double, Double](
+ sc.parallelize(edges), 0, defaultStorageLevel, defaultStorageLevel)
}
def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = {
@@ -116,8 +121,8 @@ private object PeriodicGraphCheckpointerSuite {
assert(graph.vertices.getStorageLevel == StorageLevel.NONE)
assert(graph.edges.getStorageLevel == StorageLevel.NONE)
} else {
- assert(graph.vertices.getStorageLevel != StorageLevel.NONE)
- assert(graph.edges.getStorageLevel != StorageLevel.NONE)
+ assert(graph.vertices.getStorageLevel == defaultStorageLevel)
+ assert(graph.edges.getStorageLevel == defaultStorageLevel)
}
} catch {
case _: AssertionError =>
diff --git a/launcher/pom.xml b/launcher/pom.xml
index 025cd84f20f0..a4bb50ce7dda 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml
index 663f7fb0b010..16cce0a49653 100644
--- a/mllib-local/pom.xml
+++ b/mllib-local/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 82f840b0fc26..fec1be909946 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.11</artifactId>
- <version>2.2.0-SNAPSHOT</version>
+ <version>2.2.1-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala
index 32d78e9b226e..3aea568cd652 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala
@@ -56,7 +56,7 @@ private[ann] class SigmoidLayerModelWithSquaredError
extends FunctionalLayerModel(new FunctionalLayer(new SigmoidFunction)) with LossFunction {
override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = {
ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t)
- val error = Bsum(delta :* delta) / 2 / output.cols
+ val error = Bsum(delta *:* delta) / 2 / output.cols
ApplyInPlace(delta, output, delta, (x: Double, o: Double) => x * (o - o * o))
error
}
@@ -119,6 +119,6 @@ private[ann] class SoftmaxLayerModelWithCrossEntropyLoss extends LayerModel with
override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = {
ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t)
- -Bsum( target :* brzlog(output)) / output.cols
+ -Bsum( target *:* brzlog(output)) / output.cols
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index f76b14eeeb54..7507c7539d4e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -458,9 +458,7 @@ private class LinearSVCAggregator(
*/
def add(instance: Instance): this.type = {
instance match { case Instance(label, weight, features) =>
- require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
- require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
- s" Expecting $numFeatures but got ${features.size}.")
+
if (weight == 0.0) return this
val localFeaturesStd = bcFeaturesStd.value
val localCoefficients = coefficientsArray
@@ -512,6 +510,7 @@ private class LinearSVCAggregator(
* @return This LinearSVCAggregator object.
*/
def merge(other: LinearSVCAggregator): this.type = {
+
if (other.weightSum != 0.0) {
weightSum += other.weightSum
lossSum += other.lossSum
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 965ce3d6f275..42dc7fbebe4c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -22,7 +22,7 @@ import java.util.Locale
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
-import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
@@ -178,11 +178,90 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
}
}
+ /**
+ * The lower bounds on coefficients if fitting under bound constrained optimization.
+ * The bound matrix must be compatible with the shape (1, number of features) for binomial
+ * regression, or (number of classes, number of features) for multinomial regression.
+ * Otherwise, it throws an exception.
+ * Default is none.
+ *
+ * @group expertParam
+ */
+ @Since("2.2.0")
+ val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients",
+ "The lower bounds on coefficients if fitting under bound constrained optimization.")
+
+ /** @group expertGetParam */
+ @Since("2.2.0")
+ def getLowerBoundsOnCoefficients: Matrix = $(lowerBoundsOnCoefficients)
+
+ /**
+ * The upper bounds on coefficients if fitting under bound constrained optimization.
+ * The bound matrix must be compatible with the shape (1, number of features) for binomial
+ * regression, or (number of classes, number of features) for multinomial regression.
+ * Otherwise, it throws an exception.
+ * Default is none.
+ *
+ * @group expertParam
+ */
+ @Since("2.2.0")
+ val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients",
+ "The upper bounds on coefficients if fitting under bound constrained optimization.")
+
+ /** @group expertGetParam */
+ @Since("2.2.0")
+ def getUpperBoundsOnCoefficients: Matrix = $(upperBoundsOnCoefficients)
+
+ /**
+ * The lower bounds on intercepts if fitting under bound constrained optimization.
+ * The bound vector size must be equal to 1 for binomial regression, or to the number
+ * of classes for multinomial regression. Otherwise, it throws an exception.
+ * Default is none.
+ *
+ * @group expertParam
+ */
+ @Since("2.2.0")
+ val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts",
+ "The lower bounds on intercepts if fitting under bound constrained optimization.")
+
+ /** @group expertGetParam */
+ @Since("2.2.0")
+ def getLowerBoundsOnIntercepts: Vector = $(lowerBoundsOnIntercepts)
+
+ /**
+ * The upper bounds on intercepts if fitting under bound constrained optimization.
+ * The bound vector size must be equal to 1 for binomial regression, or to the number
+ * of classes for multinomial regression. Otherwise, it throws an exception.
+ * Default is none.
+ *
+ * @group expertParam
+ */
+ @Since("2.2.0")
+ val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts",
+ "The upper bounds on intercepts if fitting under bound constrained optimization.")
+
+ /** @group expertGetParam */
+ @Since("2.2.0")
+ def getUpperBoundsOnIntercepts: Vector = $(upperBoundsOnIntercepts)
+
+ protected def usingBoundConstrainedOptimization: Boolean = {
+ isSet(lowerBoundsOnCoefficients) || isSet(upperBoundsOnCoefficients) ||
+ isSet(lowerBoundsOnIntercepts) || isSet(upperBoundsOnIntercepts)
+ }
+
override protected def validateAndTransformSchema(
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
checkThresholdConsistency()
+ if (usingBoundConstrainedOptimization) {
+ require($(elasticNetParam) == 0.0, "Fitting under bound constrained optimization only " +
+ s"supports L2 regularization, but got elasticNetParam = $getElasticNetParam.")
+ }
+ if (!$(fitIntercept)) {
+ require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts),
+ "Please don't set bounds on intercepts if fitting without intercept.")
+ }
super.validateAndTransformSchema(schema, fitting, featuresDataType)
}
}
@@ -217,6 +296,9 @@ class LogisticRegression @Since("1.2.0") (
* For alpha in (0,1), the penalty is a combination of L1 and L2.
* Default is 0.0 which is an L2 penalty.
*
+ * Note: Fitting under bound constrained optimization only supports L2 regularization,
+ * so an exception is thrown if this param is set to a non-zero value.
+ *
* @group setParam
*/
@Since("1.4.0")
@@ -312,6 +394,83 @@ class LogisticRegression @Since("1.2.0") (
def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
setDefault(aggregationDepth -> 2)
+ /**
+ * Set the lower bounds on coefficients if fitting under bound constrained optimization.
+ *
+ * @group expertSetParam
+ */
+ @Since("2.2.0")
+ def setLowerBoundsOnCoefficients(value: Matrix): this.type = set(lowerBoundsOnCoefficients, value)
+
+ /**
+ * Set the upper bounds on coefficients if fitting under bound constrained optimization.
+ *
+ * @group expertSetParam
+ */
+ @Since("2.2.0")
+ def setUpperBoundsOnCoefficients(value: Matrix): this.type = set(upperBoundsOnCoefficients, value)
+
+ /**
+ * Set the lower bounds on intercepts if fitting under bound constrained optimization.
+ *
+ * @group expertSetParam
+ */
+ @Since("2.2.0")
+ def setLowerBoundsOnIntercepts(value: Vector): this.type = set(lowerBoundsOnIntercepts, value)
+
+ /**
+ * Set the upper bounds on intercepts if fitting under bound constrained optimization.
+ *
+ * @group expertSetParam
+ */
+ @Since("2.2.0")
+ def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value)
+
+ private def assertBoundConstrainedOptimizationParamsValid(
+ numCoefficientSets: Int,
+ numFeatures: Int): Unit = {
+ if (isSet(lowerBoundsOnCoefficients)) {
+ require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets &&
+ $(lowerBoundsOnCoefficients).numCols == numFeatures,
+ "The shape of LowerBoundsOnCoefficients must be compatible with (1, number of features) " +
+ "for binomial regression, or (number of classes, number of features) for multinomial " +
+ "regression, but found: " +
+ s"(${getLowerBoundsOnCoefficients.numRows}, ${getLowerBoundsOnCoefficients.numCols}).")
+ }
+ if (isSet(upperBoundsOnCoefficients)) {
+ require($(upperBoundsOnCoefficients).numRows == numCoefficientSets &&
+ $(upperBoundsOnCoefficients).numCols == numFeatures,
+ "The shape of upperBoundsOnCoefficients must be compatible with (1, number of features) " +
+ "for binomial regression, or (number of classes, number of features) for multinomial " +
+ "regression, but found: " +
+ s"(${getUpperBoundsOnCoefficients.numRows}, ${getUpperBoundsOnCoefficients.numCols}).")
+ }
+ if (isSet(lowerBoundsOnIntercepts)) {
+ require($(lowerBoundsOnIntercepts).size == numCoefficientSets, "The size of " +
+ "lowerBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " +
+ s"classes for multinomial regression, but found: ${getLowerBoundsOnIntercepts.size}.")
+ }
+ if (isSet(upperBoundsOnIntercepts)) {
+ require($(upperBoundsOnIntercepts).size == numCoefficientSets, "The size of " +
+ "upperBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " +
+ s"classes for multinomial regression, but found: ${getUpperBoundsOnIntercepts.size}.")
+ }
+ if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) {
+ require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray)
+ .forall(x => x._1 <= x._2), "lowerBoundsOnCoefficients should always be " +
+ "less than or equal to upperBoundsOnCoefficients, but found: " +
+ s"lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients, " +
+ s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients.")
+ }
+ if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) {
+ require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray)
+ .forall(x => x._1 <= x._2), "lowerBoundsOnIntercepts should always be " +
+ "less than or equal to upperBoundsOnIntercepts, but found: " +
+ s"lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts, " +
+ s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts.")
+ }
+ }
+
private var optInitialModel: Option[LogisticRegressionModel] = None
private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = {
@@ -378,6 +537,11 @@ class LogisticRegression @Since("1.2.0") (
}
val numCoefficientSets = if (isMultinomial) numClasses else 1
+ // Check that the interaction between params is valid if fitting under bound constrained optimization.
+ if (usingBoundConstrainedOptimization) {
+ assertBoundConstrainedOptimizationParamsValid(numCoefficientSets, numFeatures)
+ }
+
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
@@ -397,7 +561,7 @@ class LogisticRegression @Since("1.2.0") (
val isConstantLabel = histogram.count(_ != 0.0) == 1
- if ($(fitIntercept) && isConstantLabel) {
+ if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) {
logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " +
s"will be zeros. Training is not needed.")
val constantLabelIndex = Vectors.dense(histogram).argmax
@@ -434,8 +598,53 @@ class LogisticRegression @Since("1.2.0") (
$(standardization), bcFeaturesStd, regParamL2, multinomial = isMultinomial,
$(aggregationDepth))
+ val numCoeffsPlusIntercepts = numFeaturesPlusIntercept * numCoefficientSets
+
+ val (lowerBounds, upperBounds): (Array[Double], Array[Double]) = {
+ if (usingBoundConstrainedOptimization) {
+ val lowerBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.NegativeInfinity)
+ val upperBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.PositiveInfinity)
+ val isSetLowerBoundsOnCoefficients = isSet(lowerBoundsOnCoefficients)
+ val isSetUpperBoundsOnCoefficients = isSet(upperBoundsOnCoefficients)
+ val isSetLowerBoundsOnIntercepts = isSet(lowerBoundsOnIntercepts)
+ val isSetUpperBoundsOnIntercepts = isSet(upperBoundsOnIntercepts)
+
+ var i = 0
+ while (i < numCoeffsPlusIntercepts) {
+ val coefficientSetIndex = i % numCoefficientSets
+ val featureIndex = i / numCoefficientSets
+ if (featureIndex < numFeatures) {
+ if (isSetLowerBoundsOnCoefficients) {
+ lowerBounds(i) = $(lowerBoundsOnCoefficients)(
+ coefficientSetIndex, featureIndex) * featuresStd(featureIndex)
+ }
+ if (isSetUpperBoundsOnCoefficients) {
+ upperBounds(i) = $(upperBoundsOnCoefficients)(
+ coefficientSetIndex, featureIndex) * featuresStd(featureIndex)
+ }
+ } else {
+ if (isSetLowerBoundsOnIntercepts) {
+ lowerBounds(i) = $(lowerBoundsOnIntercepts)(coefficientSetIndex)
+ }
+ if (isSetUpperBoundsOnIntercepts) {
+ upperBounds(i) = $(upperBoundsOnIntercepts)(coefficientSetIndex)
+ }
+ }
+ i += 1
+ }
+ (lowerBounds, upperBounds)
+ } else {
+ (null, null)
+ }
+ }
+
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
- new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
+ if (lowerBounds != null && upperBounds != null) {
+ new BreezeLBFGSB(
+ BDV[Double](lowerBounds), BDV[Double](upperBounds), $(maxIter), 10, $(tol))
+ } else {
+ new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
+ }
} else {
val standardizationParam = $(standardization)
def regParamL1Fun = (index: Int) => {
@@ -546,6 +755,26 @@ class LogisticRegression @Since("1.2.0") (
math.log(histogram(1) / histogram(0)))
}
+ if (usingBoundConstrainedOptimization) {
+ // Make sure all initial values lie within their corresponding bounds.
+ var i = 0
+ while (i < numCoeffsPlusIntercepts) {
+ val coefficientSetIndex = i % numCoefficientSets
+ val featureIndex = i / numCoefficientSets
+ if (initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) < lowerBounds(i))
+ {
+ initialCoefWithInterceptMatrix.update(
+ coefficientSetIndex, featureIndex, lowerBounds(i))
+ } else if (
+ initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) > upperBounds(i))
+ {
+ initialCoefWithInterceptMatrix.update(
+ coefficientSetIndex, featureIndex, upperBounds(i))
+ }
+ i += 1
+ }
+ }
+
val states = optimizer.iterations(new CachedDiffFunction(costFun),
new BDV[Double](initialCoefWithInterceptMatrix.toArray))
@@ -599,7 +828,7 @@ class LogisticRegression @Since("1.2.0") (
if (isIntercept) interceptVec.toArray(classIndex) = value
}
- if ($(regParam) == 0.0 && isMultinomial) {
+ if ($(regParam) == 0.0 && isMultinomial && !usingBoundConstrainedOptimization) {
/*
When no regularization is applied, the multinomial coefficients lack identifiability
because we do not use a pivot class. We can add any constant value to the coefficients
@@ -609,13 +838,18 @@ class LogisticRegression @Since("1.2.0") (
Friedman, et al. "Regularization Paths for Generalized Linear Models via
Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf
*/
- val denseValues = denseCoefficientMatrix.values
- val coefficientMean = denseValues.sum / denseValues.length
- denseCoefficientMatrix.update(_ - coefficientMean)
+ val centers = Array.fill(numFeatures)(0.0)
+ denseCoefficientMatrix.foreachActive { case (i, j, v) =>
+ centers(j) += v
+ }
+ centers.transform(_ / numCoefficientSets)
+ denseCoefficientMatrix.foreachActive { case (i, j, v) =>
+ denseCoefficientMatrix.update(i, j, v - centers(j))
+ }
}
// center the intercepts when using multinomial algorithm
- if ($(fitIntercept) && isMultinomial) {
+ if ($(fitIntercept) && isMultinomial && !usingBoundConstrainedOptimization) {
val interceptArray = interceptVec.toArray
val interceptMean = interceptArray.sum / interceptArray.length
(0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean }
@@ -1566,9 +1800,6 @@ private class LogisticAggregator(
*/
def add(instance: Instance): this.type = {
instance match { case Instance(label, weight, features) =>
- require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
- s" Expecting $numFeatures but got ${features.size}.")
- require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
if (weight == 0.0) return this
@@ -1591,8 +1822,6 @@ private class LogisticAggregator(
* @return This LogisticAggregator object.
*/
def merge(other: LogisticAggregator): this.type = {
- require(numFeatures == other.numFeatures, s"Dimensions mismatch when merging with another " +
- s"LogisticAggregator. Expecting $numFeatures but got ${other.numFeatures}.")
if (other.weightSum != 0.0) {
weightSum += other.weightSum
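The LogisticRegression changes above add bound-constrained fitting: when any of the four bound params is set, training switches from LBFGS/OWLQN to LBFGS-B and only L2 regularization is allowed. A minimal usage sketch under those assumptions; the DataFrame `training` (with "label" and a 4-dimensional "features" column) and the bound values are illustrative, not part of this patch:

  import org.apache.spark.ml.classification.LogisticRegression
  import org.apache.spark.ml.linalg.{Matrices, Vectors}

  // Binomial case: coefficient bounds are shaped (1, numFeatures), intercept bounds have size 1.
  // elasticNetParam must remain 0.0, since bound constrained optimization is L2-only.
  val boundedLr = new LogisticRegression()
    .setFitIntercept(true)
    .setRegParam(0.1)
    .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(0.0)))
    .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(1.0)))
    .setLowerBoundsOnIntercepts(Vectors.dense(0.0))
    .setUpperBoundsOnIntercepts(Vectors.dense(10.0))

  // val model = boundedLr.fit(training)  // coefficients land in [0, 1], intercept in [0, 10]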
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index a9c1a7ba0bc8..5259ee419445 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -472,7 +472,7 @@ class GaussianMixture @Since("2.0.0") (
*/
val cov = {
val ss = new DenseVector(new Array[Double](numFeatures)).asBreeze
- slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) :^ 2.0)
+ slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) ^:^ 2.0)
val diagVec = Vectors.fromBreeze(ss)
BLAS.scal(1.0 / numSamples, diagVec)
val covVec = new DenseVector(Array.fill[Double](
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 2f50dc7c85f3..e3026c8efa82 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -36,7 +36,6 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
OnlineLDAOptimizer => OldOnlineLDAOptimizer}
-import org.apache.spark.mllib.impl.PeriodicCheckpointer
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.MatrixImplicits._
import org.apache.spark.mllib.linalg.VectorImplicits._
@@ -45,9 +44,9 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.PeriodicCheckpointer
import org.apache.spark.util.VersionUtils
-
private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter
with HasSeed with HasCheckpointInterval {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index d1f3b2af1e48..bb8f2a3aa5f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -116,7 +116,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
}
- val newCol = bucketizer(filteredDataset($(inputCol)))
+ val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType))
val newField = prepOutputField(filteredDataset.schema)
filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
}
@@ -130,7 +130,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
- SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(inputCol))
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
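With the two Bucketizer hunks above, transformSchema accepts any numeric input column and transform casts it to double before bucketing, instead of requiring DoubleType. A small sketch, assuming a spark-shell style SparkSession named `spark`; the column names are illustrative:

  import org.apache.spark.ml.feature.Bucketizer

  val df = spark.createDataFrame(Seq((0, "a"), (2, "b"), (5, "c"))).toDF("intFeature", "id")

  val bucketizer = new Bucketizer()
    .setInputCol("intFeature")   // integer column now passes the schema check and is cast to double
    .setOutputCol("bucket")
    .setSplits(Array(Double.NegativeInfinity, 1.0, 3.0, Double.PositiveInfinity))

  // bucketizer.transform(df).select("intFeature", "bucket").show()  // buckets 0.0, 1.0, 2.0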
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index d604c1ac001a..8f00daa59f1a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -17,7 +17,6 @@
package org.apache.spark.ml.fpm
-import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import org.apache.hadoop.fs.Path
@@ -54,7 +53,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol {
/**
* Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears
- * more than (minSupport * size-of-the-dataset) times will be output
+ * more than (minSupport * size-of-the-dataset) times will be output in the frequent itemsets.
* Default: 0.3
* @group param
*/
@@ -82,8 +81,8 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol {
def getNumPartitions: Int = $(numPartitions)
/**
- * Minimal confidence for generating Association Rule.
- * Note that minConfidence has no effect during fitting.
+ * Minimal confidence for generating association rules. minConfidence will not affect the mining
+ * of frequent itemsets, but it will affect the association rule generation.
* Default: 0.8
* @group param
*/
@@ -118,7 +117,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol {
* Recommendation. PFP distributes computation in such a way that each worker executes an
* independent group of mining tasks. The FP-Growth algorithm is described in
* Han et al., Mining frequent patterns without
- * candidate generation. Note null values in the feature column are ignored during fit().
+ * candidate generation. Note null values in the itemsCol column are ignored during fit().
*
* @see
* Association rule learning (Wikipedia)
@@ -167,7 +166,6 @@ class FPGrowth @Since("2.2.0") (
}
val parentModel = mllibFP.run(items)
val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
-
val schema = StructType(Seq(
StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false),
StructField("freq", LongType, nullable = false)))
@@ -196,7 +194,7 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] {
* :: Experimental ::
* Model fitted by FPGrowth.
*
- * @param freqItemsets frequent items in the format of DataFrame("items"[Seq], "freq"[Long])
+ * @param freqItemsets frequent itemsets in the format of DataFrame("items"[Array], "freq"[Long])
*/
@Since("2.2.0")
@Experimental
@@ -244,10 +242,13 @@ class FPGrowthModel private[ml] (
/**
* The transform method first generates the association rules according to the frequent itemsets.
- * Then for each association rule, it will examine the input items against antecedents and
- * summarize the consequents as prediction. The prediction column has the same data type as the
- * input column(Array[T]) and will not contain existing items in the input column. The null
- * values in the feature columns are treated as empty sets.
+ * Then for each transaction in itemsCol, the transform method will compare its items against the
+ * antecedents of each association rule. If the record contains all the antecedents of a
+ * specific association rule, the rule is considered applicable and its consequents
+ * will be added to the prediction result. The transform method summarizes the consequents
+ * from all the applicable rules as the prediction. The prediction column has the same data
+ * type as the input column (Array[T]) and will not contain existing items in the input column.
+ * Null values in the itemsCol column are treated as empty sets.
* WARNING: internally it collects association rules to the driver and uses broadcast for
* efficiency. This may bring pressure to driver memory for large set of association rules.
*/
@@ -335,13 +336,13 @@ private[fpm] object AssociationRules {
/**
* Computes the association rules with confidence above minConfidence.
- * @param dataset DataFrame("items", "freq") containing frequent itemset obtained from
- * algorithms like [[FPGrowth]].
+ * @param dataset DataFrame("items"[Array], "freq"[Long]) containing frequent itemsets obtained
+ * from algorithms like [[FPGrowth]].
* @param itemsCol column name for frequent itemsets
- * @param freqCol column name for frequent itemsets count
- * @param minConfidence minimum confidence for the result association rules
- * @return a DataFrame("antecedent", "consequent", "confidence") containing the association
- * rules.
+ * @param freqCol column name for appearance count of the frequent itemsets
+ * @param minConfidence minimum confidence for generating the association rules
+ * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double])
+ * containing the association rules.
*/
def getAssociationRulesFromFP[T: ClassTag](
dataset: Dataset[_],
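Before the next file, a short end-to-end sketch of the FPGrowth API whose documentation is reworded above, assuming a SparkSession `spark`; the toy transactions are illustrative:

  import org.apache.spark.ml.fpm.FPGrowth

  val transactions = spark.createDataFrame(Seq(
    (0, Array("a", "b", "c")),
    (1, Array("a", "b")),
    (2, Array("a", "c"))
  )).toDF("id", "items")

  val model = new FPGrowth()
    .setItemsCol("items")
    .setMinSupport(0.5)      // keep itemsets appearing in more than half of the transactions
    .setMinConfidence(0.6)   // affects only association-rule generation, not itemset mining
    .fit(transactions)

  // model.freqItemsets.show()        // DataFrame("items"[Array], "freq"[Long])
  // model.associationRules.show()    // DataFrame("antecedent", "consequent", "confidence")
  // model.transform(transactions)    // adds a "prediction" column of suggested consequents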
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index d6093a01c671..bff0d9bbb46f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -894,10 +894,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine
private[regression] object Probit extends Link("probit") {
- override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu)
+ override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).inverseCdf(mu)
override def deriv(mu: Double): Double = {
- 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).icdf(mu))
+ 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).inverseCdf(mu))
}
override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta)
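The Probit hunk swaps Breeze's deprecated `icdf` for `inverseCdf` on the standard Gaussian without changing the math. A minimal sketch of the quantities involved, assuming breeze-stats is on the classpath; the value of mu is illustrative:

  import breeze.stats.{distributions => dist}

  val stdNormal = dist.Gaussian(0.0, 1.0)
  val mu = 0.975
  val eta = stdNormal.inverseCdf(mu)        // link: eta = Phi^-1(mu), roughly 1.96
  val derivAtMu = 1.0 / stdNormal.pdf(eta)  // deriv: 1 / phi(Phi^-1(mu))
  val back = stdNormal.cdf(eta)             // unlink: recovers mu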
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index f7e3c8fa5b6e..eaad54985229 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -971,9 +971,6 @@ private class LeastSquaresAggregator(
*/
def add(instance: Instance): this.type = {
instance match { case Instance(label, weight, features) =>
- require(dim == features.size, s"Dimensions mismatch when adding new sample." +
- s" Expecting $dim but got ${features.size}.")
- require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
if (weight == 0.0) return this
@@ -1005,8 +1002,6 @@ private class LeastSquaresAggregator(
* @return This LeastSquaresAggregator object.
*/
def merge(other: LeastSquaresAggregator): this.type = {
- require(dim == other.dim, s"Dimensions mismatch when merging with another " +
- s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")
if (other.weightSum != 0) {
totalCnt += other.totalCnt
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 4c525c0714ec..ce2bd7b430f4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -21,12 +21,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
-import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
import org.apache.spark.storage.StorageLevel
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
index 051ec2404fb6..4d952ac88c9b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
@@ -271,7 +271,7 @@ class GaussianMixture private (
private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = {
val mu = vectorMean(x)
val ss = BDV.zeros[Double](x(0).length)
- x.foreach(xi => ss += (xi - mu) :^ 2.0)
+ x.foreach(xi => ss += (xi - mu) ^:^ 2.0)
diag(ss / x.length.toDouble)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 7fd722a33292..663f63c25a94 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -314,7 +314,7 @@ class LocalLDAModel private[spark] (
docBound += count * LDAUtils.logSumExp(Elogthetad + localElogbeta(idx, ::).t)
}
// E[log p(theta | alpha) - log q(theta | gamma)]
- docBound += sum((brzAlpha - gammad) :* Elogthetad)
+ docBound += sum((brzAlpha - gammad) *:* Elogthetad)
docBound += sum(lgamma(gammad) - lgamma(brzAlpha))
docBound += lgamma(sum(brzAlpha)) - lgamma(sum(gammad))
@@ -324,7 +324,7 @@ class LocalLDAModel private[spark] (
// Bound component for prob(topic-term distributions):
// E[log p(beta | eta) - log q(beta | lambda)]
val sumEta = eta * vocabSize
- val topicsPart = sum((eta - lambda) :* Elogbeta) +
+ val topicsPart = sum((eta - lambda) *:* Elogbeta) +
sum(lgamma(lambda) - lgamma(eta)) +
sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*))))
@@ -721,7 +721,7 @@ class DistributedLDAModel private[clustering] (
val N_wj = edgeContext.attr
val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0)
val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0)
- val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k
+ val phi_wk: TopicCounts = smoothed_N_wk /:/ smoothed_N_k
val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0)
val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj))
edgeContext.sendToDst(tokenLogLikelihood)
@@ -748,7 +748,7 @@ class DistributedLDAModel private[clustering] (
if (isTermVertex(vertex)) {
val N_wk = vertex._2
val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0)
- val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k
+ val phi_wk: TopicCounts = smoothed_N_wk /:/ smoothed_N_k
sumPrior + (eta - 1.0) * sum(phi_wk.map(math.log))
} else {
val N_kj = vertex._2
@@ -788,20 +788,14 @@ class DistributedLDAModel private[clustering] (
@Since("1.5.0")
def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = {
graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) =>
- // TODO: Remove work-around for the breeze bug.
- // https://github.com/scalanlp/breeze/issues/561
- val topIndices = if (k == topicCounts.length) {
- Seq.range(0, k)
- } else {
- argtopk(topicCounts, k)
- }
+ val topIndices = argtopk(topicCounts, k)
val sumCounts = sum(topicCounts)
val weights = if (sumCounts != 0) {
- topicCounts(topIndices) / sumCounts
+ topicCounts(topIndices).toArray.map(_ / sumCounts)
} else {
- topicCounts(topIndices)
+ topicCounts(topIndices).toArray
}
- (docID.toLong, topIndices.toArray, weights.toArray)
+ (docID.toLong, topIndices.toArray, weights)
}
}
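This hunk drops the workaround for the old Breeze argtopk issue (scalanlp/breeze#561) and converts the sliced weights explicitly. A tiny sketch of argtopk itself, assuming Breeze 0.13+; the counts are illustrative:

  import breeze.linalg.{argtopk, DenseVector}

  val topicCounts = DenseVector(3.0, 9.0, 1.0, 7.0)
  val top2 = argtopk(topicCounts, 2)                  // indices of the two largest counts: 1, 3
  val all = argtopk(topicCounts, topicCounts.length)  // after the upgrade this also works for k == length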
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 48bae4276c48..d633893e55f5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -25,7 +25,7 @@ import breeze.stats.distributions.{Gamma, RandBasis}
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.graphx._
-import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -482,7 +482,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*)
stats.unpersist()
expElogbetaBc.destroy(false)
- val batchResult = statsSum :* expElogbeta.t
+ val batchResult = statsSum *:* expElogbeta.t
// Note that this is an optimization to avoid batch.count
updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt)
@@ -522,7 +522,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer {
val dalpha = -(gradf - b) / q
- if (all((weight * dalpha + alpha) :> 0D)) {
+ if (all((weight * dalpha + alpha) >:> 0D)) {
alpha :+= weight * dalpha
this.alpha = Vectors.dense(alpha.toArray)
}
@@ -584,7 +584,7 @@ private[clustering] object OnlineLDAOptimizer {
val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K
val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K
- val phiNorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids
+ val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids
var meanGammaChange = 1D
val ctsVector = new BDV[Double](cts) // ids
@@ -592,14 +592,14 @@ private[clustering] object OnlineLDAOptimizer {
while (meanGammaChange > 1e-3) {
val lastgamma = gammad.copy
// K K * ids ids
- gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phiNorm))) :+ alpha
+ gammad := (expElogthetad *:* (expElogbetad.t * (ctsVector /:/ phiNorm))) +:+ alpha
expElogthetad := exp(LDAUtils.dirichletExpectation(gammad))
// TODO: Keep more values in log space, and only exponentiate when needed.
- phiNorm := expElogbetad * expElogthetad :+ 1e-100
+ phiNorm := expElogbetad * expElogthetad +:+ 1e-100
meanGammaChange = sum(abs(gammad - lastgamma)) / k
}
- val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phiNorm).asDenseMatrix
+ val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector /:/ phiNorm).asDenseMatrix
(gammad, sstatsd, ids)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
index 1f6e1a077f92..c4bbe51a46c3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala
@@ -29,7 +29,7 @@ private[clustering] object LDAUtils {
*/
private[clustering] def logSumExp(x: BDV[Double]): Double = {
val a = max(x)
- a + log(sum(exp(x :- a)))
+ a + log(sum(exp(x -:- a)))
}
/**
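The logSumExp helper above subtracts the max before exponentiating, so the result stays finite even when the raw exponentials would overflow. An equivalent standalone sketch, assuming Breeze on the classpath:

  import breeze.linalg.{max, sum, DenseVector}
  import breeze.numerics.{exp, log}

  def logSumExp(x: DenseVector[Double]): Double = {
    val a = max(x)               // shifting by the max keeps every exp(...) <= 1
    a + log(sum(exp(x -:- a)))
  }

  // logSumExp(DenseVector(1000.0, 1000.0)) == 1000.0 + math.log(2.0), with no overflow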
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index c858b9bbfc25..bf6bfe30bfe2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.classification.LogisticRegressionSuite._
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
-import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, Vector, Vectors}
+import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors}
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
@@ -150,6 +150,54 @@ class LogisticRegressionSuite
assert(!model.hasSummary)
}
+ test("logistic regression: illegal params") {
+ val lowerBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0))
+ val upperBoundsOnCoefficients1 = Matrices.dense(1, 4, Array(0.0, 1.0, 1.0, 0.0))
+ val upperBoundsOnCoefficients2 = Matrices.dense(1, 3, Array(1.0, 0.0, 1.0))
+ val lowerBoundsOnIntercepts = Vectors.dense(1.0)
+
+ // Works fine when bounds are set on only one side.
+ new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .fit(binaryDataset)
+
+ withClue("bound constrained optimization only supports L2 regularization") {
+ intercept[IllegalArgumentException] {
+ new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setElasticNetParam(1.0)
+ .fit(binaryDataset)
+ }
+ }
+
+ withClue("lowerBoundsOnCoefficients should less than or equal to upperBoundsOnCoefficients") {
+ intercept[IllegalArgumentException] {
+ new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients1)
+ .fit(binaryDataset)
+ }
+ }
+
+ withClue("the coefficients bound matrix mismatched with shape (1, number of features)") {
+ intercept[IllegalArgumentException] {
+ new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients2)
+ .fit(binaryDataset)
+ }
+ }
+
+ withClue("bounds on intercepts should not be set if fitting without intercept") {
+ intercept[IllegalArgumentException] {
+ new LogisticRegression()
+ .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts)
+ .setFitIntercept(false)
+ .fit(binaryDataset)
+ }
+ }
+ }
+
test("empty probabilityCol") {
val lr = new LogisticRegression().setProbabilityCol("")
val model = lr.fit(smallBinaryDataset)
@@ -610,6 +658,107 @@ class LogisticRegressionSuite
assert(model2.coefficients ~= coefficientsR relTol 1E-3)
}
+ test("binary logistic regression with intercept without regularization with bound") {
+ // Bound constrained optimization with bounds on one side.
+ val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0))
+ val upperBoundsOnIntercepts = Vectors.dense(1.0)
+
+ val trainer1 = new LogisticRegression()
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients)
+ .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts)
+ .setFitIntercept(true)
+ .setStandardization(true)
+ .setWeightCol("weight")
+ val trainer2 = new LogisticRegression()
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients)
+ .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts)
+ .setFitIntercept(true)
+ .setStandardization(false)
+ .setWeightCol("weight")
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
+
+ // The solution is generated by https://github.com/yanboliang/bound-optimization.
+ val coefficientsExpected1 = Vectors.dense(0.06079437, 0.0, -0.26351059, -0.59102199)
+ val interceptExpected1 = 1.0
+
+ assert(model1.intercept ~== interceptExpected1 relTol 1E-3)
+ assert(model1.coefficients ~= coefficientsExpected1 relTol 1E-3)
+
+ // Without regularization, training with or without standardization converges to the same solution.
+ assert(model2.intercept ~== interceptExpected1 relTol 1E-3)
+ assert(model2.coefficients ~= coefficientsExpected1 relTol 1E-3)
+
+ // Bound constrained optimization with bounds on both sides.
+ val lowerBoundsOnCoefficients = Matrices.dense(1, 4, Array(0.0, -1.0, 0.0, -1.0))
+ val lowerBoundsOnIntercepts = Vectors.dense(0.0)
+
+ val trainer3 = new LogisticRegression()
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients)
+ .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts)
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts)
+ .setFitIntercept(true)
+ .setStandardization(true)
+ .setWeightCol("weight")
+ val trainer4 = new LogisticRegression()
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients)
+ .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts)
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts)
+ .setFitIntercept(true)
+ .setStandardization(false)
+ .setWeightCol("weight")
+
+ val model3 = trainer3.fit(binaryDataset)
+ val model4 = trainer4.fit(binaryDataset)
+
+ // The solution is generated by https://github.com/yanboliang/bound-optimization.
+ val coefficientsExpected3 = Vectors.dense(0.0, 0.0, 0.0, -0.71708632)
+ val interceptExpected3 = 0.58776113
+
+ assert(model3.intercept ~== interceptExpected3 relTol 1E-3)
+ assert(model3.coefficients ~= coefficientsExpected3 relTol 1E-3)
+
+ // Without regularization, training with or without standardization converges to the same solution.
+ assert(model4.intercept ~== interceptExpected3 relTol 1E-3)
+ assert(model4.coefficients ~= coefficientsExpected3 relTol 1E-3)
+
+ // Bound constrained optimization with infinite bounds on both sides.
+ val trainer5 = new LogisticRegression()
+ .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.PositiveInfinity)))
+ .setUpperBoundsOnIntercepts(Vectors.dense(Double.PositiveInfinity))
+ .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.NegativeInfinity)))
+ .setLowerBoundsOnIntercepts(Vectors.dense(Double.NegativeInfinity))
+ .setFitIntercept(true)
+ .setStandardization(true)
+ .setWeightCol("weight")
+ val trainer6 = new LogisticRegression()
+ .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.PositiveInfinity)))
+ .setUpperBoundsOnIntercepts(Vectors.dense(Double.PositiveInfinity))
+ .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.NegativeInfinity)))
+ .setLowerBoundsOnIntercepts(Vectors.dense(Double.NegativeInfinity))
+ .setFitIntercept(true)
+ .setStandardization(false)
+ .setWeightCol("weight")
+
+ val model5 = trainer5.fit(binaryDataset)
+ val model6 = trainer6.fit(binaryDataset)
+
+ // The solution is generated by https://github.com/yanboliang/bound-optimization.
+ // It should be the same as unconstrained optimization with LBFGS.
+ val coefficientsExpected5 = Vectors.dense(-0.5734389, 0.8911736, -0.3878645, -0.8060570)
+ val interceptExpected5 = 2.7355261
+
+ assert(model5.intercept ~== interceptExpected5 relTol 1E-3)
+ assert(model5.coefficients ~= coefficientsExpected5 relTol 1E-3)
+
+ // Without regularization, training with or without standardization converges to the same solution.
+ assert(model6.intercept ~== interceptExpected5 relTol 1E-3)
+ assert(model6.coefficients ~= coefficientsExpected5 relTol 1E-3)
+ }
+
test("binary logistic regression without intercept without regularization") {
val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true)
.setWeightCol("weight")
@@ -650,6 +799,34 @@ class LogisticRegressionSuite
assert(model2.coefficients ~= coefficientsR relTol 1E-2)
}
+ test("binary logistic regression without intercept without regularization with bound") {
+ val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)).toSparse
+
+ val trainer1 = new LogisticRegression()
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients)
+ .setFitIntercept(false)
+ .setStandardization(true)
+ .setWeightCol("weight")
+ val trainer2 = new LogisticRegression()
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients)
+ .setFitIntercept(false)
+ .setStandardization(false)
+ .setWeightCol("weight")
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
+
+ // The solution is generated by https://github.com/yanboliang/bound-optimization.
+ val coefficientsExpected = Vectors.dense(0.20847553, 0.0, -0.24240289, -0.55568071)
+
+ assert(model1.intercept ~== 0.0 relTol 1E-3)
+ assert(model1.coefficients ~= coefficientsExpected relTol 1E-3)
+
+ // Without regularization, training with or without standardization converges to the same solution.
+ assert(model2.intercept ~== 0.0 relTol 1E-3)
+ assert(model2.coefficients ~= coefficientsExpected relTol 1E-3)
+ }
+
test("binary logistic regression with intercept with L1 regularization") {
val trainer1 = (new LogisticRegression).setFitIntercept(true)
.setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true).setWeightCol("weight")
@@ -815,6 +992,40 @@ class LogisticRegressionSuite
assert(model2.coefficients ~= coefficientsR relTol 1E-3)
}
+ test("binary logistic regression with intercept with L2 regularization with bound") {
+ val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0))
+ val upperBoundsOnIntercepts = Vectors.dense(1.0)
+
+ val trainer1 = new LogisticRegression()
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients)
+ .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts)
+ .setRegParam(1.37)
+ .setFitIntercept(true)
+ .setStandardization(true)
+ .setWeightCol("weight")
+ val trainer2 = new LogisticRegression()
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients)
+ .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts)
+ .setRegParam(1.37)
+ .setFitIntercept(true)
+ .setStandardization(false)
+ .setWeightCol("weight")
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
+
+ // The solution is generated by https://github.com/yanboliang/bound-optimization.
+ val coefficientsExpectedWithStd = Vectors.dense(-0.06985003, 0.0, -0.04794278, -0.10168595)
+ val interceptExpectedWithStd = 0.45750141
+ val coefficientsExpected = Vectors.dense(-0.0494524, 0.0, -0.11360797, -0.06313577)
+ val interceptExpected = 0.53722967
+
+ assert(model1.intercept ~== interceptExpectedWithStd relTol 1E-3)
+ assert(model1.coefficients ~= coefficientsExpectedWithStd relTol 1E-3)
+ assert(model2.intercept ~== interceptExpected relTol 1E-3)
+ assert(model2.coefficients ~= coefficientsExpected relTol 1E-3)
+ }
+
test("binary logistic regression without intercept with L2 regularization") {
val trainer1 = (new LogisticRegression).setFitIntercept(false)
.setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true).setWeightCol("weight")
@@ -864,6 +1075,35 @@ class LogisticRegressionSuite
assert(model2.coefficients ~= coefficientsR relTol 1E-2)
}
+ test("binary logistic regression without intercept with L2 regularization with bound") {
+ val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0))
+
+ val trainer1 = new LogisticRegression()
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients)
+ .setRegParam(1.37)
+ .setFitIntercept(false)
+ .setStandardization(true)
+ .setWeightCol("weight")
+ val trainer2 = new LogisticRegression()
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients)
+ .setRegParam(1.37)
+ .setFitIntercept(false)
+ .setStandardization(false)
+ .setWeightCol("weight")
+
+ val model1 = trainer1.fit(binaryDataset)
+ val model2 = trainer2.fit(binaryDataset)
+
+ // The solution is generated by https://github.com/yanboliang/bound-optimization.
+ val coefficientsExpectedWithStd = Vectors.dense(-0.00796538, 0.0, -0.0394228, -0.0873314)
+ val coefficientsExpected = Vectors.dense(0.01105972, 0.0, -0.08574949, -0.05079558)
+
+ assert(model1.intercept ~== 0.0 relTol 1E-3)
+ assert(model1.coefficients ~= coefficientsExpectedWithStd relTol 1E-3)
+ assert(model2.intercept ~== 0.0 relTol 1E-3)
+ assert(model2.coefficients ~= coefficientsExpected relTol 1E-3)
+ }
+
test("binary logistic regression with intercept with ElasticNet regularization") {
val trainer1 = (new LogisticRegression).setFitIntercept(true).setMaxIter(200)
.setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true).setWeightCol("weight")
@@ -1084,7 +1324,6 @@ class LogisticRegressionSuite
}
test("multinomial logistic regression with intercept without regularization") {
-
val trainer1 = (new LogisticRegression).setFitIntercept(true)
.setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight")
val trainer2 = (new LogisticRegression).setFitIntercept(true)
@@ -1139,6 +1378,9 @@ class LogisticRegressionSuite
0.10095851, -0.85897154, 0.08392798, 0.07904499), isTransposed = true)
val interceptsR = Vectors.dense(-2.10320093, 0.3394473, 1.76375361)
+ model1.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps))
+ model2.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps))
+
assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05)
assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps)
assert(model1.interceptVector ~== interceptsR relTol 0.05)
@@ -1149,6 +1391,110 @@ class LogisticRegressionSuite
assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps)
}
+ test("multinomial logistic regression with intercept without regularization with bound") {
+ // Bound constrained optimization with bounds on one side.
+ val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0))
+ val lowerBoundsOnIntercepts = Vectors.dense(Array.fill(3)(1.0))
+
+ val trainer1 = new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts)
+ .setFitIntercept(true)
+ .setStandardization(true)
+ .setWeightCol("weight")
+ val trainer2 = new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts)
+ .setFitIntercept(true)
+ .setStandardization(false)
+ .setWeightCol("weight")
+
+ val model1 = trainer1.fit(multinomialDataset)
+ val model2 = trainer2.fit(multinomialDataset)
+
+ // The solution is generated by https://github.com/yanboliang/bound-optimization.
+ val coefficientsExpected1 = new DenseMatrix(3, 4, Array(
+ 2.52076464, 2.73596057, 1.87984904, 2.73264492,
+ 1.93302281, 3.71363303, 1.50681746, 1.93398782,
+ 2.37839917, 1.93601818, 1.81924758, 2.45191255), isTransposed = true)
+ val interceptsExpected1 = Vectors.dense(1.00010477, 3.44237083, 4.86740286)
+
+ checkCoefficientsEquivalent(model1.coefficientMatrix, coefficientsExpected1)
+ assert(model1.interceptVector ~== interceptsExpected1 relTol 0.01)
+ checkCoefficientsEquivalent(model2.coefficientMatrix, coefficientsExpected1)
+ assert(model2.interceptVector ~== interceptsExpected1 relTol 0.01)
+
+ // Bound constrained optimization with bounds on both sides.
+ val upperBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(2.0))
+ val upperBoundsOnIntercepts = Vectors.dense(Array.fill(3)(2.0))
+
+ val trainer3 = new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts)
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients)
+ .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts)
+ .setFitIntercept(true)
+ .setStandardization(true)
+ .setWeightCol("weight")
+ val trainer4 = new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts)
+ .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients)
+ .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts)
+ .setFitIntercept(true)
+ .setStandardization(false)
+ .setWeightCol("weight")
+
+ val model3 = trainer3.fit(multinomialDataset)
+ val model4 = trainer4.fit(multinomialDataset)
+
+ // The solution is generated by https://github.com/yanboliang/bound-optimization.
+ val coefficientsExpected3 = new DenseMatrix(3, 4, Array(
+ 1.61967097, 1.16027835, 1.45131448, 1.97390431,
+ 1.30529317, 2.0, 1.12985473, 1.26652854,
+ 1.61647195, 1.0, 1.40642959, 1.72985589), isTransposed = true)
+ val interceptsExpected3 = Vectors.dense(1.0, 2.0, 2.0)
+
+ checkCoefficientsEquivalent(model3.coefficientMatrix, coefficientsExpected3)
+ assert(model3.interceptVector ~== interceptsExpected3 relTol 0.01)
+ checkCoefficientsEquivalent(model4.coefficientMatrix, coefficientsExpected3)
+ assert(model4.interceptVector ~== interceptsExpected3 relTol 0.01)
+
+ // Bound constrained optimization with infinite bounds on both sides.
+ val trainer5 = new LogisticRegression()
+ .setLowerBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.NegativeInfinity)))
+ .setLowerBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.NegativeInfinity)))
+ .setUpperBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.PositiveInfinity)))
+ .setUpperBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.PositiveInfinity)))
+ .setFitIntercept(true)
+ .setStandardization(true)
+ .setWeightCol("weight")
+ val trainer6 = new LogisticRegression()
+ .setLowerBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.NegativeInfinity)))
+ .setLowerBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.NegativeInfinity)))
+ .setUpperBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.PositiveInfinity)))
+ .setUpperBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.PositiveInfinity)))
+ .setFitIntercept(true)
+ .setStandardization(false)
+ .setWeightCol("weight")
+
+ val model5 = trainer5.fit(multinomialDataset)
+ val model6 = trainer6.fit(multinomialDataset)
+
+ // The solution is generated by https://github.com/yanboliang/bound-optimization.
+ // It should be the same as unconstrained optimization with LBFGS.
+ val coefficientsExpected5 = new DenseMatrix(3, 4, Array(
+ 0.24337896, -0.05916156, 0.14446790, 0.35976165,
+ -0.3443375, 0.9181331, -0.2283959, -0.4388066,
+ 0.10095851, -0.85897154, 0.08392798, 0.07904499), isTransposed = true)
+ val interceptsExpected5 = Vectors.dense(-2.10320093, 0.3394473, 1.76375361)
+
+ checkCoefficientsEquivalent(model5.coefficientMatrix, coefficientsExpected5)
+ assert(model5.interceptVector ~== interceptsExpected5 relTol 0.01)
+ checkCoefficientsEquivalent(model6.coefficientMatrix, coefficientsExpected5)
+ assert(model6.interceptVector ~== interceptsExpected5 relTol 0.01)
+ }
+
test("multinomial logistic regression without intercept without regularization") {
val trainer1 = (new LogisticRegression).setFitIntercept(false)
@@ -1204,6 +1550,9 @@ class LogisticRegressionSuite
-0.3180040, 0.9679074, -0.2252219, -0.4319914,
0.2452411, -0.6046524, 0.1050710, 0.1180180), isTransposed = true)
+ model1.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps))
+ model2.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps))
+
assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05)
assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps)
assert(model1.interceptVector.toArray === Array.fill(3)(0.0))
@@ -1214,6 +1563,35 @@ class LogisticRegressionSuite
assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps)
}
+ test("multinomial logistic regression without intercept without regularization with bound") {
+ val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0))
+
+ val trainer1 = new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setFitIntercept(false)
+ .setStandardization(true)
+ .setWeightCol("weight")
+ val trainer2 = new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setFitIntercept(false)
+ .setStandardization(false)
+ .setWeightCol("weight")
+
+ val model1 = trainer1.fit(multinomialDataset)
+ val model2 = trainer2.fit(multinomialDataset)
+
+ // The solution is generated by https://github.com/yanboliang/bound-optimization.
+ val coefficientsExpected = new DenseMatrix(3, 4, Array(
+ 1.62410051, 1.38219391, 1.34486618, 1.74641729,
+ 1.23058989, 2.71787825, 1.0, 1.00007073,
+ 1.79478632, 1.14360459, 1.33011603, 1.55093897), isTransposed = true)
+
+ checkCoefficientsEquivalent(model1.coefficientMatrix, coefficientsExpected)
+ assert(model1.interceptVector.toArray === Array.fill(3)(0.0))
+ checkCoefficientsEquivalent(model2.coefficientMatrix, coefficientsExpected)
+ assert(model2.interceptVector.toArray === Array.fill(3)(0.0))
+ }
+
test("multinomial logistic regression with intercept with L1 regularization") {
// use tighter constraints because OWL-QN solver takes longer to converge
@@ -1512,6 +1890,46 @@ class LogisticRegressionSuite
assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps)
}
+ test("multinomial logistic regression with intercept with L2 regularization with bound") {
+ val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0))
+ val lowerBoundsOnIntercepts = Vectors.dense(Array.fill(3)(1.0))
+
+ val trainer1 = new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts)
+ .setRegParam(0.1)
+ .setFitIntercept(true)
+ .setStandardization(true)
+ .setWeightCol("weight")
+ val trainer2 = new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts)
+ .setRegParam(0.1)
+ .setFitIntercept(true)
+ .setStandardization(false)
+ .setWeightCol("weight")
+
+ val model1 = trainer1.fit(multinomialDataset)
+ val model2 = trainer2.fit(multinomialDataset)
+
+ // The solution is generated by https://github.com/yanboliang/bound-optimization.
+ val coefficientsExpectedWithStd = new DenseMatrix(3, 4, Array(
+ 1.0, 1.0, 1.0, 1.01647497,
+ 1.0, 1.44105616, 1.0, 1.0,
+ 1.0, 1.0, 1.0, 1.0), isTransposed = true)
+ val interceptsExpectedWithStd = Vectors.dense(2.52055893, 1.0, 2.560682)
+ val coefficientsExpected = new DenseMatrix(3, 4, Array(
+ 1.0, 1.0, 1.03189386, 1.0,
+ 1.0, 1.0, 1.0, 1.0,
+ 1.0, 1.0, 1.0, 1.0), isTransposed = true)
+ val interceptsExpected = Vectors.dense(1.06418835, 1.0, 1.20494701)
+
+ assert(model1.coefficientMatrix ~== coefficientsExpectedWithStd relTol 0.01)
+ assert(model1.interceptVector ~== interceptsExpectedWithStd relTol 0.01)
+ assert(model2.coefficientMatrix ~== coefficientsExpected relTol 0.01)
+ assert(model2.interceptVector ~== interceptsExpected relTol 0.01)
+ }
+
test("multinomial logistic regression without intercept with L2 regularization") {
val trainer1 = (new LogisticRegression).setFitIntercept(false)
.setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true).setWeightCol("weight")
@@ -1609,6 +2027,41 @@ class LogisticRegressionSuite
assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps)
}
+ test("multinomial logistic regression without intercept with L2 regularization with bound") {
+ val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0))
+
+ val trainer1 = new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setRegParam(0.1)
+ .setFitIntercept(false)
+ .setStandardization(true)
+ .setWeightCol("weight")
+ val trainer2 = new LogisticRegression()
+ .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients)
+ .setRegParam(0.1)
+ .setFitIntercept(false)
+ .setStandardization(false)
+ .setWeightCol("weight")
+
+ val model1 = trainer1.fit(multinomialDataset)
+ val model2 = trainer2.fit(multinomialDataset)
+
+ // The solution is generated by https://github.com/yanboliang/bound-optimization.
+ val coefficientsExpectedWithStd = new DenseMatrix(3, 4, Array(
+ 1.01324653, 1.0, 1.0, 1.0415767,
+ 1.0, 1.0, 1.0, 1.0,
+ 1.02244888, 1.0, 1.0, 1.0), isTransposed = true)
+ val coefficientsExpected = new DenseMatrix(3, 4, Array(
+ 1.0, 1.0, 1.03932259, 1.0,
+ 1.0, 1.0, 1.0, 1.0,
+ 1.0, 1.0, 1.03274649, 1.0), isTransposed = true)
+
+ assert(model1.coefficientMatrix ~== coefficientsExpectedWithStd absTol 0.01)
+ assert(model1.interceptVector.toArray === Array.fill(3)(0.0))
+ assert(model2.coefficientMatrix ~== coefficientsExpected absTol 0.01)
+ assert(model2.interceptVector.toArray === Array.fill(3)(0.0))
+ }
+
test("multinomial logistic regression with intercept with elasticnet regularization") {
val trainer1 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight")
.setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true)
@@ -2267,4 +2720,19 @@ object LogisticRegressionSuite {
val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i)))
testData
}
+
+ /**
+ * When no regularization is applied, the multinomial coefficients lack identifiability
+ * because we do not use a pivot class. We can add any constant value to the coefficients
+ * and get the same likelihood. When fitting under bound-constrained optimization, we don't
+ * choose the mean-centered coefficients as we do for unbounded problems, since they may
+ * fall outside the bounds. We use this function to check whether two sets of coefficients
+ * are equivalent.
+ */
+ def checkCoefficientsEquivalent(coefficients1: Matrix, coefficients2: Matrix): Unit = {
+ coefficients1.colIter.zip(coefficients2.colIter).foreach { case (col1: Vector, col2: Vector) =>
+ (col1.asBreeze - col2.asBreeze).toArray.toSeq.sliding(2).foreach {
+ case Seq(v1, v2) => assert(v1 ~= v2 absTol 1E-3)
+ }
+ }
+ }
}
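The new checkCoefficientsEquivalent helper treats two multinomial coefficient matrices as equivalent when each column differs only by a single constant shift, which leaves the likelihood unchanged. A minimal standalone sketch of the same check in numpy (the names and data here are illustrative, not part of the Spark API):

    import numpy as np

    def coefficients_equivalent(c1, c2, tol=1e-3):
        # Column j of c1 must equal column j of c2 plus one constant, so
        # consecutive entries of the per-column difference must agree.
        diff = c1 - c2
        return bool(np.all(np.abs(np.diff(diff, axis=0)) < tol))

    a = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    b = a + np.array([0.5, -0.2])   # shift each column by its own constant
    assert coefficients_equivalent(a, b)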
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index b56f8e19ca53..3a2be236f125 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -168,7 +168,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
assert(m1.pi ~== m2.pi relTol 0.01)
assert(m1.theta ~== m2.theta relTol 0.01)
}
- val testParams = Seq(
+ val testParams = Seq[(String, Dataset[_])](
("bernoulli", bernoulliDataset),
("multinomial", dataset)
)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index aac29137d791..420fb17ddce8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -26,6 +26,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -162,6 +164,29 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
.setSplits(Array(0.1, 0.8, 0.9))
testDefaultReadWrite(t)
}
+
+ test("Bucket numeric features") {
+ val splits = Array(-3.0, 0.0, 3.0)
+ val data = Array(-2.0, -1.0, 0.0, 1.0, 2.0)
+ val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 1.0)
+ val dataFrame: DataFrame = data.zip(expectedBuckets).toSeq.toDF("feature", "expected")
+
+ val bucketizer: Bucketizer = new Bucketizer()
+ .setInputCol("feature")
+ .setOutputCol("result")
+ .setSplits(splits)
+
+ val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType,
+ ByteType, DecimalType(10, 0))
+ for (mType <- types) {
+ val df = dataFrame.withColumn("feature", col("feature").cast(mType))
+ bucketizer.transform(df).select("result", "expected").collect().foreach {
+ case Row(x: Double, y: Double) =>
+ assert(x === y, "The result is not correct after bucketing in type " +
+ mType.toString + ". " + s"Expected $y but found $x.")
+ }
+ }
+ }
}
private object BucketizerSuite extends SparkFunSuite {
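The new Bucketizer test above casts the input column to several numeric types and expects the same bucket assignments as for doubles. A rough pyspark analogue, assuming an existing SparkSession named `spark` (whether the Python wrapper benefits depends on the same JVM-side change, so treat this as a sketch):

    from pyspark.ml.feature import Bucketizer
    from pyspark.sql.functions import col
    from pyspark.sql.types import IntegerType

    df = spark.createDataFrame([(-2.0,), (-1.0,), (0.0,), (1.0,), (2.0,)], ["feature"])
    bucketizer = Bucketizer(splits=[-3.0, 0.0, 3.0], inputCol="feature", outputCol="result")

    # Cast the feature column to an integer type first; the expected buckets
    # are still 0.0, 0.0, 1.0, 1.0, 1.0 as in the double-typed case.
    bucketizer.transform(df.withColumn("feature", col("feature").cast(IntegerType()))).show()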
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 6806cb03bc42..87f8b9034dde 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -122,6 +122,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
.setMinConfidence(0.5678)
assert(fpGrowth.getMinSupport === 0.4567)
assert(model.getMinConfidence === 0.5678)
+    // numPartitions should not have a default value.
+ assert(fpGrowth.isDefined(fpGrowth.numPartitions) === false)
MLTestingUtils.checkCopyAndUids(fpGrowth, model)
ParamsSuite.checkParams(fpGrowth)
ParamsSuite.checkParams(model)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index 572959200f47..3d6a9f8d84ca 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -191,8 +191,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers
// With smaller convergenceTol, it takes more steps.
assert(lossLBFGS3.length > lossLBFGS2.length)
- // Based on observation, lossLBFGS2 runs 5 iterations, no theoretically guaranteed.
- assert(lossLBFGS3.length == 6)
+    // Based on observation, lossLBFGS3 runs 7 iterations; this is not theoretically guaranteed.
+ assert(lossLBFGS3.length == 7)
assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol)
}
diff --git a/pom.xml b/pom.xml
index 14370d92a908..ccd8546a269c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -26,7 +26,7 @@
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-parent_2.11</artifactId>
-  <version>2.2.0-SNAPSHOT</version>
+  <version>2.2.1-SNAPSHOT</version>
   <packaging>pom</packaging>
   <name>Spark Project Parent POM</name>
   <url>http://spark.apache.org/</url>
@@ -58,10 +58,6 @@
https://issues.apache.org/jira/browse/SPARK
-
- ${maven.version}
-
-
Dev Mailing List
@@ -136,13 +132,12 @@
10.12.1.1
1.8.2
1.6.0
-    <jetty.version>9.2.16.v20160414</jetty.version>
+    <jetty.version>9.3.11.v20160721</jetty.version>
3.1.0
0.8.0
2.4.0
2.0.8
3.1.2
-
1.7.7
hadoop2
0.9.3
@@ -659,7 +654,7 @@
        <groupId>org.scalanlp</groupId>
        <artifactId>breeze_${scala.binary.version}</artifactId>
-       <version>0.12</version>
+       <version>0.13.1</version>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 77dae289f775..e52baf51aed1 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -318,8 +318,8 @@ object SparkBuild extends PomBuild {
enable(MimaBuild.mimaSettings(sparkHome, x))(x)
}
- /* Generate and pick the spark build info from extra-resources and override a dependency */
- enable(Core.settings ++ CoreDependencyOverrides.settings)(core)
+ /* Generate and pick the spark build info from extra-resources */
+ enable(Core.settings)(core)
/* Unsafe settings */
enable(Unsafe.settings)(unsafe)
@@ -443,16 +443,6 @@ object DockerIntegrationTests {
)
}
-/**
- * Overrides to work around sbt's dependency resolution being different from Maven's in Unidoc.
- *
- * Note that, this is a hack that should be removed in the future. See SPARK-20343
- */
-object CoreDependencyOverrides {
- lazy val settings = Seq(
- dependencyOverrides += "org.apache.avro" % "avro" % "1.7.7")
-}
-
/**
* Overrides to work around sbt's dependency resolution being different from Maven's.
*/
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 2961cda553d6..3be07325f416 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -240,6 +240,32 @@ def signal_handler(signal, frame):
if isinstance(threading.current_thread(), threading._MainThread):
signal.signal(signal.SIGINT, signal_handler)
+    def __repr__(self):
+        return "<SparkContext master={master} appName={appName}>".format(
+            master=self.master,
+            appName=self.appName,
+        )
+
+    def _repr_html_(self):
+        return """
+        <div>
+            <p><b>SparkContext</b></p>
+
+            <p><a href="{sc.uiWebUrl}">Spark UI</a></p>
+
+            <dl>
+              <dt>Version</dt>
+                <dd><code>v{sc.version}</code></dd>
+              <dt>Master</dt>
+                <dd><code>{sc.master}</code></dd>
+              <dt>AppName</dt>
+                <dd><code>{sc.appName}</code></dd>
+            </dl>
+        </div>
+        """.format(
+            sc=self
+        )
+
def _initialize_context(self, jconf):
"""
Initialize SparkContext in function to allow subclass specific initialization
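With the two methods above, a SparkContext prints a compact summary in a plain REPL and renders an HTML card in notebook frontends. A small illustration, assuming an existing context named `sc` (the repr values shown are examples only):

    # In a Jupyter cell, evaluating the object calls _repr_html_() and renders the card.
    sc

    # Anywhere else, repr() falls back to __repr__(), e.g.
    # '<SparkContext master=local[*] appName=PySparkShell>'
    repr(sc)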
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index b4fc357e42d7..a9756ea4af99 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -185,36 +185,33 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
>>> bdf = sc.parallelize([
- ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)),
- ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF()
- >>> blor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight")
+ ... Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)),
+ ... Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)),
+ ... Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)),
+ ... Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF()
+ >>> blor = LogisticRegression(regParam=0.01, weightCol="weight")
>>> blorModel = blor.fit(bdf)
>>> blorModel.coefficients
- DenseVector([5.5...])
+ DenseVector([-1.080..., -0.646...])
>>> blorModel.intercept
- -2.68...
- >>> mdf = sc.parallelize([
- ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)),
- ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], [])),
- ... Row(label=2.0, weight=2.0, features=Vectors.dense(3.0))]).toDF()
- >>> mlor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight",
- ... family="multinomial")
+ 3.112...
+ >>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
+ >>> mdf = spark.read.format("libsvm").load(data_path)
+ >>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial")
>>> mlorModel = mlor.fit(mdf)
- >>> print(mlorModel.coefficientMatrix)
- DenseMatrix([[-2.3...],
- [ 0.2...],
- [ 2.1... ]])
+ >>> mlorModel.coefficientMatrix
+ SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1)
>>> mlorModel.interceptVector
- DenseVector([2.0..., 0.8..., -2.8...])
- >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
+ DenseVector([0.04..., -0.42..., 0.37...])
+ >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF()
>>> result = blorModel.transform(test0).head()
>>> result.prediction
- 0.0
+ 1.0
>>> result.probability
- DenseVector([0.99..., 0.00...])
+ DenseVector([0.02..., 0.97...])
>>> result.rawPrediction
- DenseVector([8.22..., -8.22...])
- >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
+ DenseVector([-3.54..., 3.54...])
+ >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
>>> blorModel.transform(test1).head().prediction
1.0
>>> blor.setParams("vector")
@@ -224,8 +221,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> lr_path = temp_path + "/lr"
>>> blor.save(lr_path)
>>> lr2 = LogisticRegression.load(lr_path)
- >>> lr2.getMaxIter()
- 5
+ >>> lr2.getRegParam()
+ 0.01
>>> model_path = temp_path + "/lr_model"
>>> blorModel.save(model_path)
>>> model2 = LogisticRegressionModel.load(model_path)
@@ -1482,31 +1479,33 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable):
>>> from pyspark.sql import Row
>>> from pyspark.ml.linalg import Vectors
- >>> df = sc.parallelize([
- ... Row(label=0.0, features=Vectors.dense(1.0, 0.8)),
- ... Row(label=1.0, features=Vectors.sparse(2, [], [])),
- ... Row(label=2.0, features=Vectors.dense(0.5, 0.5))]).toDF()
- >>> lr = LogisticRegression(maxIter=5, regParam=0.01)
+ >>> data_path = "data/mllib/sample_multiclass_classification_data.txt"
+ >>> df = spark.read.format("libsvm").load(data_path)
+ >>> lr = LogisticRegression(regParam=0.01)
>>> ovr = OneVsRest(classifier=lr)
>>> model = ovr.fit(df)
- >>> [x.coefficients for x in model.models]
- [DenseVector([3.3925, 1.8785]), DenseVector([-4.3016, -6.3163]), DenseVector([-4.5855, 6.1785])]
+ >>> model.models[0].coefficients
+ DenseVector([0.5..., -1.0..., 3.4..., 4.2...])
+ >>> model.models[1].coefficients
+ DenseVector([-2.1..., 3.1..., -2.6..., -2.3...])
+ >>> model.models[2].coefficients
+ DenseVector([0.3..., -3.4..., 1.0..., -1.1...])
>>> [x.intercept for x in model.models]
- [-3.64747..., 2.55078..., -1.10165...]
- >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0))]).toDF()
+ [-2.7..., -2.5..., -1.3...]
+ >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF()
>>> model.transform(test0).head().prediction
- 1.0
- >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()
- >>> model.transform(test1).head().prediction
0.0
- >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF()
- >>> model.transform(test2).head().prediction
+ >>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF()
+ >>> model.transform(test1).head().prediction
2.0
+ >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF()
+ >>> model.transform(test2).head().prediction
+ 0.0
>>> model_path = temp_path + "/ovr_model"
>>> model.save(model_path)
>>> model2 = OneVsRestModel.load(model_path)
>>> model2.transform(test0).head().prediction
- 1.0
+ 0.0
.. versionadded:: 2.0.0
"""
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 8bc899a0788b..bcfb36880eb0 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -82,6 +82,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
Row(user=1, item=0, prediction=2.6258413791656494)
>>> predictions[2]
Row(user=2, item=0, prediction=-1.5018409490585327)
+ >>> user_recs = model.recommendForAllUsers(3)
+ >>> user_recs.where(user_recs.user == 0)\
+ .select("recommendations.item", "recommendations.rating").collect()
+ [Row(item=[0, 1, 2], rating=[3.910..., 1.992..., -0.138...])]
+ >>> item_recs = model.recommendForAllItems(3)
+ >>> item_recs.where(item_recs.item == 2)\
+ .select("recommendations.user", "recommendations.rating").collect()
+ [Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])]
>>> als_path = temp_path + "/als"
>>> als.save(als_path)
>>> als2 = ALS.load(als_path)
@@ -384,6 +392,28 @@ def itemFactors(self):
"""
return self._call_java("itemFactors")
+ @since("2.2.0")
+ def recommendForAllUsers(self, numItems):
+ """
+ Returns top `numItems` items recommended for each user, for all users.
+
+ :param numItems: max number of recommendations for each user
+ :return: a DataFrame of (userCol, recommendations), where recommendations are
+ stored as an array of (itemCol, rating) Rows.
+ """
+ return self._call_java("recommendForAllUsers", numItems)
+
+ @since("2.2.0")
+ def recommendForAllItems(self, numUsers):
+ """
+ Returns top `numUsers` users recommended for each item, for all items.
+
+ :param numUsers: max number of recommendations for each item
+ :return: a DataFrame of (itemCol, recommendations), where recommendations are
+ stored as an array of (userCol, rating) Rows.
+ """
+ return self._call_java("recommendForAllItems", numUsers)
+
if __name__ == "__main__":
import doctest
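A minimal usage sketch for the two new model methods; the toy ratings DataFrame and its column names are assumptions made for illustration:

    from pyspark.ml.recommendation import ALS

    ratings = spark.createDataFrame(
        [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
        ["user", "item", "rating"])
    model = ALS(rank=10, maxIter=5, seed=0,
                userCol="user", itemCol="item", ratingCol="rating").fit(ratings)

    model.recommendForAllUsers(3).show(truncate=False)   # top 3 items per user
    model.recommendForAllItems(3).show(truncate=False)   # top 3 users per item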
diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py
index 600655c912ca..4cb802514be5 100644
--- a/python/pyspark/mllib/linalg/distributed.py
+++ b/python/pyspark/mllib/linalg/distributed.py
@@ -28,14 +28,13 @@
from pyspark import RDD, since
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import _convert_to_vector, Matrix, QRDecomposition
+from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition
from pyspark.mllib.stat import MultivariateStatisticalSummary
from pyspark.storagelevel import StorageLevel
-__all__ = ['DistributedMatrix', 'RowMatrix', 'IndexedRow',
- 'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix',
- 'BlockMatrix']
+__all__ = ['BlockMatrix', 'CoordinateMatrix', 'DistributedMatrix', 'IndexedRow',
+ 'IndexedRowMatrix', 'MatrixEntry', 'RowMatrix', 'SingularValueDecomposition']
class DistributedMatrix(object):
@@ -301,6 +300,136 @@ def tallSkinnyQR(self, computeQ=False):
R = decomp.call("R")
return QRDecomposition(Q, R)
+ @since('2.2.0')
+ def computeSVD(self, k, computeU=False, rCond=1e-9):
+ """
+ Computes the singular value decomposition of the RowMatrix.
+
+        The given row matrix A of dimension (m X n) is decomposed into
+        U * s * V' where
+
+        * U: (m X k) (left singular vectors) is a RowMatrix whose
+             columns are the eigenvectors of (A X A')
+        * s: DenseVector consisting of the square roots of the eigenvalues
+             (singular values) in descending order.
+        * V: (n X k) (right singular vectors) is a Matrix whose columns
+             are the eigenvectors of (A' X A)
+
+        For more specific details on the implementation, please refer
+        to the Scala documentation.
+
+ :param k: Number of leading singular values to keep (`0 < k <= n`).
+ It might return less than k if there are numerically zero singular values
+ or there are not enough Ritz values converged before the maximum number of
+ Arnoldi update iterations is reached (in case that matrix A is ill-conditioned).
+        :param computeU: Whether or not to compute U. If set to
+                         True, then U is computed as A * V * s^-1.
+ :param rCond: Reciprocal condition number. All singular values
+ smaller than rCond * s[0] are treated as zero
+ where s[0] is the largest singular value.
+ :returns: :py:class:`SingularValueDecomposition`
+
+ >>> rows = sc.parallelize([[3, 1, 1], [-1, 3, 1]])
+ >>> rm = RowMatrix(rows)
+
+ >>> svd_model = rm.computeSVD(2, True)
+ >>> svd_model.U.rows.collect()
+ [DenseVector([-0.7071, 0.7071]), DenseVector([-0.7071, -0.7071])]
+ >>> svd_model.s
+ DenseVector([3.4641, 3.1623])
+ >>> svd_model.V
+ DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0)
+ """
+ j_model = self._java_matrix_wrapper.call(
+ "computeSVD", int(k), bool(computeU), float(rCond))
+ return SingularValueDecomposition(j_model)
+
+ @since('2.2.0')
+ def computePrincipalComponents(self, k):
+ """
+        Computes the k principal components of the given row matrix.
+
+ .. note:: This cannot be computed on matrices with more than 65535 columns.
+
+ :param k: Number of principal components to keep.
+ :returns: :py:class:`pyspark.mllib.linalg.DenseMatrix`
+
+ >>> rows = sc.parallelize([[1, 2, 3], [2, 4, 5], [3, 6, 1]])
+ >>> rm = RowMatrix(rows)
+
+ >>> # Returns the two principal components of rm
+ >>> pca = rm.computePrincipalComponents(2)
+ >>> pca
+ DenseMatrix(3, 2, [-0.349, -0.6981, 0.6252, -0.2796, -0.5592, -0.7805], 0)
+
+ >>> # Transform into new dimensions with the greatest variance.
+ >>> rm.multiply(pca).rows.collect() # doctest: +NORMALIZE_WHITESPACE
+ [DenseVector([0.1305, -3.7394]), DenseVector([-0.3642, -6.6983]), \
+ DenseVector([-4.6102, -4.9745])]
+ """
+ return self._java_matrix_wrapper.call("computePrincipalComponents", k)
+
+ @since('2.2.0')
+ def multiply(self, matrix):
+ """
+ Multiply this matrix by a local dense matrix on the right.
+
+ :param matrix: a local dense matrix whose number of rows must match the number of columns
+ of this matrix
+ :returns: :py:class:`RowMatrix`
+
+ >>> rm = RowMatrix(sc.parallelize([[0, 1], [2, 3]]))
+ >>> rm.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect()
+ [DenseVector([2.0, 3.0]), DenseVector([6.0, 11.0])]
+ """
+ if not isinstance(matrix, DenseMatrix):
+ raise ValueError("Only multiplication with DenseMatrix "
+ "is supported.")
+ j_model = self._java_matrix_wrapper.call("multiply", matrix)
+ return RowMatrix(j_model)
+
+
+class SingularValueDecomposition(JavaModelWrapper):
+ """
+ Represents singular value decomposition (SVD) factors.
+
+ .. versionadded:: 2.2.0
+ """
+
+ @property
+ @since('2.2.0')
+ def U(self):
+ """
+        Returns a distributed matrix whose columns are the left
+        singular vectors of the SingularValueDecomposition if computeU was set to True.
+ """
+ u = self.call("U")
+ if u is not None:
+ mat_name = u.getClass().getSimpleName()
+ if mat_name == "RowMatrix":
+ return RowMatrix(u)
+ elif mat_name == "IndexedRowMatrix":
+ return IndexedRowMatrix(u)
+ else:
+ raise TypeError("Expected RowMatrix/IndexedRowMatrix got %s" % mat_name)
+
+ @property
+ @since('2.2.0')
+ def s(self):
+ """
+ Returns a DenseVector with singular values in descending order.
+ """
+ return self.call("s")
+
+ @property
+ @since('2.2.0')
+ def V(self):
+ """
+ Returns a DenseMatrix whose columns are the right singular
+ vectors of the SingularValueDecomposition.
+ """
+ return self.call("V")
+
class IndexedRow(object):
"""
@@ -528,6 +657,68 @@ def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024):
colsPerBlock)
return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock)
+ @since('2.2.0')
+ def computeSVD(self, k, computeU=False, rCond=1e-9):
+ """
+ Computes the singular value decomposition of the IndexedRowMatrix.
+
+        The given row matrix A of dimension (m X n) is decomposed into
+        U * s * V' where
+
+        * U: (m X k) (left singular vectors) is an IndexedRowMatrix
+             whose columns are the eigenvectors of (A X A')
+        * s: DenseVector consisting of the square roots of the eigenvalues
+             (singular values) in descending order.
+        * V: (n X k) (right singular vectors) is a Matrix whose columns
+             are the eigenvectors of (A' X A)
+
+        For more specific details on the implementation, please refer
+        to the Scala documentation.
+
+ :param k: Number of leading singular values to keep (`0 < k <= n`).
+ It might return less than k if there are numerically zero singular values
+ or there are not enough Ritz values converged before the maximum number of
+ Arnoldi update iterations is reached (in case that matrix A is ill-conditioned).
+        :param computeU: Whether or not to compute U. If set to
+                         True, then U is computed as A * V * s^-1.
+ :param rCond: Reciprocal condition number. All singular values
+ smaller than rCond * s[0] are treated as zero
+ where s[0] is the largest singular value.
+ :returns: SingularValueDecomposition object
+
+ >>> rows = [(0, (3, 1, 1)), (1, (-1, 3, 1))]
+ >>> irm = IndexedRowMatrix(sc.parallelize(rows))
+ >>> svd_model = irm.computeSVD(2, True)
+ >>> svd_model.U.rows.collect() # doctest: +NORMALIZE_WHITESPACE
+ [IndexedRow(0, [-0.707106781187,0.707106781187]),\
+ IndexedRow(1, [-0.707106781187,-0.707106781187])]
+ >>> svd_model.s
+ DenseVector([3.4641, 3.1623])
+ >>> svd_model.V
+ DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0)
+ """
+ j_model = self._java_matrix_wrapper.call(
+ "computeSVD", int(k), bool(computeU), float(rCond))
+ return SingularValueDecomposition(j_model)
+
+ @since('2.2.0')
+ def multiply(self, matrix):
+ """
+ Multiply this matrix by a local dense matrix on the right.
+
+ :param matrix: a local dense matrix whose number of rows must match the number of columns
+ of this matrix
+ :returns: :py:class:`IndexedRowMatrix`
+
+ >>> mat = IndexedRowMatrix(sc.parallelize([(0, (0, 1)), (1, (2, 3))]))
+ >>> mat.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect()
+ [IndexedRow(0, [2.0,3.0]), IndexedRow(1, [6.0,11.0])]
+ """
+ if not isinstance(matrix, DenseMatrix):
+ raise ValueError("Only multiplication with DenseMatrix "
+ "is supported.")
+ return IndexedRowMatrix(self._java_matrix_wrapper.call("multiply", matrix))
+
class MatrixEntry(object):
"""
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 523b3f111331..1037bab7f108 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -23,6 +23,7 @@
import sys
import tempfile
import array as pyarray
+from math import sqrt
from time import time, sleep
from shutil import rmtree
@@ -54,6 +55,7 @@
from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
+from pyspark.mllib.linalg.distributed import RowMatrix
from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
from pyspark.mllib.recommendation import Rating
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
@@ -1699,6 +1701,67 @@ def test_binary_term_freqs(self):
": expected " + str(expected[i]) + ", got " + str(output[i]))
+class DimensionalityReductionTests(MLlibTestCase):
+
+ denseData = [
+ Vectors.dense([0.0, 1.0, 2.0]),
+ Vectors.dense([3.0, 4.0, 5.0]),
+ Vectors.dense([6.0, 7.0, 8.0]),
+ Vectors.dense([9.0, 0.0, 1.0])
+ ]
+ sparseData = [
+ Vectors.sparse(3, [(1, 1.0), (2, 2.0)]),
+ Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]),
+ Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]),
+ Vectors.sparse(3, [(0, 9.0), (2, 1.0)])
+ ]
+
+ def assertEqualUpToSign(self, vecA, vecB):
+ eq1 = vecA - vecB
+ eq2 = vecA + vecB
+ self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6)
+
+ def test_svd(self):
+ denseMat = RowMatrix(self.sc.parallelize(self.denseData))
+ sparseMat = RowMatrix(self.sc.parallelize(self.sparseData))
+ m = 4
+ n = 3
+ for mat in [denseMat, sparseMat]:
+ for k in range(1, 4):
+ rm = mat.computeSVD(k, computeU=True)
+ self.assertEqual(rm.s.size, k)
+ self.assertEqual(rm.U.numRows(), m)
+ self.assertEqual(rm.U.numCols(), k)
+ self.assertEqual(rm.V.numRows, n)
+ self.assertEqual(rm.V.numCols, k)
+
+ # Test that U returned is None if computeU is set to False.
+ self.assertEqual(mat.computeSVD(1).U, None)
+
+ # Test that low rank matrices cannot have number of singular values
+ # greater than a limit.
+ rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1))))
+ self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1)
+
+ def test_pca(self):
+ expected_pcs = array([
+ [0.0, 1.0, 0.0],
+ [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0],
+ [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0]
+ ])
+ n = 3
+ denseMat = RowMatrix(self.sc.parallelize(self.denseData))
+ sparseMat = RowMatrix(self.sc.parallelize(self.sparseData))
+ for mat in [denseMat, sparseMat]:
+ for k in range(1, 4):
+ pcs = mat.computePrincipalComponents(k)
+ self.assertEqual(pcs.numRows, n)
+ self.assertEqual(pcs.numCols, k)
+
+ # We can just test the updated principal component for equality.
+ self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1])
+
+
if __name__ == "__main__":
from pyspark.mllib.tests import *
if not _have_scipy:
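The assertEqualUpToSign helper above reflects the fact that singular vectors and principal components are only determined up to a sign flip, so the tests compare directions rather than raw vectors. A standalone numpy sketch of the same idea:

    import numpy as np

    def equal_up_to_sign(a, b, tol=1e-6):
        a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
        return np.abs(a - b).sum() < tol or np.abs(a + b).sum() < tol

    v = np.array([0.6, -0.8])
    assert equal_up_to_sign(v, -v)   # a flipped eigenvector is the same direction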
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 774caf53f3a4..d62ba9623b44 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -371,6 +371,35 @@ def withWatermark(self, eventTime, delayThreshold):
jdf = self._jdf.withWatermark(eventTime, delayThreshold)
return DataFrame(jdf, self.sql_ctx)
+ @since(2.2)
+ def hint(self, name, *parameters):
+ """Specifies some hint on the current DataFrame.
+
+ :param name: A name of the hint.
+ :param parameters: Optional parameters.
+ :return: :class:`DataFrame`
+
+ >>> df.join(df2.hint("broadcast"), "name").show()
+ +----+---+------+
+ |name|age|height|
+ +----+---+------+
+        | Bob|  5|    85|
+ +----+---+------+
+ """
+ if len(parameters) == 1 and isinstance(parameters[0], list):
+ parameters = parameters[0]
+
+ if not isinstance(name, str):
+ raise TypeError("name should be provided as str, got {0}".format(type(name)))
+
+ for p in parameters:
+ if not isinstance(p, str):
+ raise TypeError(
+ "all parameters should be str, got {0} of type {1}".format(p, type(p)))
+
+ jdf = self._jdf.hint(name, self._jseq(parameters))
+ return DataFrame(jdf, self.sql_ctx)
+
@since(1.3)
def count(self):
"""Returns the number of rows in this :class:`DataFrame`.
@@ -1238,7 +1267,7 @@ def fillna(self, value, subset=None):
Value to replace null values with.
If the value is a dict, then `subset` is ignored and `value` must be a mapping
from column name (string) to replacement value. The replacement value must be
- an int, long, float, or string.
+ an int, long, float, boolean, or string.
:param subset: optional list of column names to consider.
Columns specified in subset that do not have matching data type are ignored.
For example, if `value` is a string, and subset contains a non-string column,
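The fillna() docstring change above reflects that boolean replacement values are now accepted. A small sketch mirroring the new test case, assuming an existing SparkSession named `spark`:

    from pyspark.sql import Row

    df = spark.createDataFrame([Row(a=None), Row(a=True)])
    df.fillna({"a": True}).show()   # the null in column `a` is replaced with true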
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 9f4772eec9f2..c1bf2bd76fb7 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -221,6 +221,17 @@ def __init__(self, sparkContext, jsparkSession=None):
or SparkSession._instantiatedSession._sc._jsc is None:
SparkSession._instantiatedSession = self
+    def _repr_html_(self):
+        return """
+            <div>
+                <p><b>SparkSession - {catalogImplementation}</b></p>
+                {sc_HTML}
+            </div>
+        """.format(
+            catalogImplementation=self.conf.get("spark.sql.catalogImplementation"),
+            sc_HTML=self.sparkContext._repr_html_()
+        )
+
@since(2.0)
def newSession(self):
"""
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2b2444304e04..2aa2d23c6f0d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1711,6 +1711,10 @@ def test_fillna(self):
self.assertEqual(row.age, None)
self.assertEqual(row.height, None)
+ # fillna with dictionary for boolean types
+ row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first()
+ self.assertEqual(row.a, True)
+
def test_bitwise_operations(self):
from pyspark.sql import functions
row = Row(a=170, b=75)
@@ -1902,6 +1906,22 @@ def test_functions_broadcast(self):
# planner should not crash without a join
broadcast(df1)._jdf.queryExecution().executedPlan()
+ def test_generic_hints(self):
+ from pyspark.sql import DataFrame
+
+ df1 = self.spark.range(10e10).toDF("id")
+ df2 = self.spark.range(10e10).toDF("id")
+
+ self.assertIsInstance(df1.hint("broadcast"), DataFrame)
+ self.assertIsInstance(df1.hint("broadcast", []), DataFrame)
+
+ # Dummy rules
+ self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame)
+ self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame)
+
+ plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan()
+ self.assertEqual(1, plan.toString().count("BroadcastHashJoin"))
+
def test_toDF_with_schema_string(self):
data = [Row(key=i, value=str(i)) for i in range(100)]
rdd = self.sc.parallelize(data, 5)
diff --git a/python/pyspark/version.py b/python/pyspark/version.py
index 41bf8c269b79..c0bb1968b4b9 100644
--- a/python/pyspark/version.py
+++ b/python/pyspark/version.py
@@ -16,4 +16,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "2.2.0.dev0"
+__version__ = "2.2.1.dev0"
diff --git a/repl/pom.xml b/repl/pom.xml
index a256ae3b8418..f3c49dfb0060 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-   <version>2.2.0-SNAPSHOT</version>
+   <version>2.2.1-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml
index 03846d9f5a3b..547836050a61 100644
--- a/resource-managers/mesos/pom.xml
+++ b/resource-managers/mesos/pom.xml
@@ -20,7 +20,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.11</artifactId>
-   <version>2.2.0-SNAPSHOT</version>
+   <version>2.2.1-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
index cd98110ddcc0..127fadabcce5 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala
@@ -101,7 +101,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")