diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 0fa3287f36db..9cb2448fea0f 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -138,17 +138,17 @@ private[spark] class DirectKafkaInputDStream[K, V]( lagPerPartition.map { case (tp, lag) => val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp) - val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + val backpressureRate = lag / totalLag.toDouble * rate tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) } + case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp).toDouble } } if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> (secsPerBatch * limit).toLong + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) }) } else { None diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 453b5e5ab20d..8524743ee284 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -617,6 +617,54 @@ class DirectKafkaStreamSuite ssc.stop() } + test("maxMessagesPerPartition with zero offset and rate equal to one") { + val topic = "backpressure" + val kafkaParams = getKafkaParams() + val batchIntervalMilliseconds = 60000 + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + val estimateRate = 1L + val fromOffsets = Map( + new TopicPartition(topic, 0) -> 0L, + new TopicPartition(topic, 1) -> 0L, + new TopicPartition(topic, 2) -> 0L, + new TopicPartition(topic, 3) -> 0L + ) + val kafkaStream = withClue("Error creating direct stream") { + new DirectKafkaInputDStream[String, String]( + ssc, + preferredHosts, + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) { + currentOffsets = fromOffsets + override val rateController = Some(new ConstantRateController(id, null, estimateRate)) + } + } + + val offsets = Map[TopicPartition, Long]( + new TopicPartition(topic, 0) -> 0, + new TopicPartition(topic, 1) -> 100L, + new TopicPartition(topic, 2) -> 200L, + new TopicPartition(topic, 3) -> 300L + ) + val result = kafkaStream.maxMessagesPerPartition(offsets) + val expected = Map( + new TopicPartition(topic, 0) -> 1L, + new TopicPartition(topic, 1) -> 10L, + new TopicPartition(topic, 2) -> 20L, + new TopicPartition(topic, 3) -> 30L + ) + assert(result.contains(expected), s"Number of messages per partition must be at least 1") + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[ConsumerRecord[K, V]]): Seq[(Time, Array[OffsetRange])] = { diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index d52c230eb784..d6dd0744441e 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -104,17 +104,17 @@ class DirectKafkaInputDStream[ val totalLag = lagPerPartition.values.sum lagPerPartition.map { case (tp, lag) => - val backpressureRate = Math.round(lag / totalLag.toFloat * rate) + val backpressureRate = lag / totalLag.toDouble * rate tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition } + case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition.toDouble } } if (effectiveRateLimitPerPartition.values.sum > 0) { val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000 Some(effectiveRateLimitPerPartition.map { - case (tp, limit) => tp -> (secsPerBatch * limit).toLong + case (tp, limit) => tp -> Math.max((secsPerBatch * limit).toLong, 1L) }) } else { None diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 06ef5bc3f8bd..3fea6cfd910b 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -456,6 +456,57 @@ class DirectKafkaStreamSuite ssc.stop() } + test("maxMessagesPerPartition with zero offset and rate equal to one") { + val topic = "backpressure" + val kafkaParams = Map( + "metadata.broker.list" -> kafkaTestUtils.brokerAddress, + "auto.offset.reset" -> "smallest" + ) + + val batchIntervalMilliseconds = 60000 + val sparkConf = new SparkConf() + // Safe, even with streaming, because we're using the direct API. + // Using 1 core is useful to make the test more predictable. + .setMaster("local[1]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.kafka.maxRatePerPartition", "100") + + // Setup the streaming context + ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds)) + val estimatedRate = 1L + val kafkaStream = withClue("Error creating direct stream") { + val messageHandler = (mmd: MessageAndMetadata[String, String]) => (mmd.key, mmd.message) + val fromOffsets = Map( + TopicAndPartition(topic, 0) -> 0L, + TopicAndPartition(topic, 1) -> 0L, + TopicAndPartition(topic, 2) -> 0L, + TopicAndPartition(topic, 3) -> 0L + ) + new DirectKafkaInputDStream[String, String, StringDecoder, StringDecoder, (String, String)]( + ssc, kafkaParams, fromOffsets, messageHandler) { + override protected[streaming] val rateController = + Some(new DirectKafkaRateController(id, null) { + override def getLatestRate() = estimatedRate + }) + } + } + + val offsets = Map( + TopicAndPartition(topic, 0) -> 0L, + TopicAndPartition(topic, 1) -> 100L, + TopicAndPartition(topic, 2) -> 200L, + TopicAndPartition(topic, 3) -> 300L + ) + val result = kafkaStream.maxMessagesPerPartition(offsets) + val expected = Map( + TopicAndPartition(topic, 0) -> 1L, + TopicAndPartition(topic, 1) -> 10L, + TopicAndPartition(topic, 2) -> 20L, + TopicAndPartition(topic, 3) -> 30L + ) + assert(result.contains(expected), s"Number of messages per partition must be at least 1") + } + /** Get the generated offset ranges from the DirectKafkaStream */ private def getOffsetRanges[K, V]( kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = {