diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 9827dfab8684a..e27c44d03c317 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -269,7 +269,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali if (Random.nextDouble() < probability) { Some(vidVvals._1) } else { None } } - if (selectedVertices.count > 1) { + if (selectedVertices.count > 0) { found = true val collectedVertices = selectedVertices.collect() retVal = collectedVertices(Random.nextInt(collectedVertices.size)) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 1f5e27d5508b8..9acbd7960e12f 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -428,4 +428,12 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext { } } + test("SPARK-14219: pickRandomVertex") { + withSpark { sc => + val vert = sc.parallelize(List((1L, "a")), 1) + val edges = sc.parallelize(List(Edge[Long](1L, 1L)), 1) + val g0 = Graph(vert, edges) + assert(g0.pickRandomVertex() === 1L) + } + } }