1717
1818package org .apache .spark .internal
1919
20+ import scala .concurrent .duration ._
2021import scala .jdk .CollectionConverters ._
2122
2223import org .apache .logging .log4j .{CloseableThreadContext , Level , LogManager }
@@ -27,6 +28,7 @@ import org.apache.logging.log4j.core.filter.AbstractFilter
2728import org .slf4j .{Logger , LoggerFactory }
2829
2930import org .apache .spark .internal .Logging .SparkShellLoggingFilter
31+ import org .apache .spark .internal .LogKeys
3032import org .apache .spark .util .SparkClassUtils
3133
3234/**
@@ -531,3 +533,158 @@ private[spark] object Logging {
531533 override def isStopped : Boolean = status == LifeCycle .State .STOPPED
532534 }
533535}
536+
/**
 * A thread-safe token bucket-based throttler implementation with nanosecond accuracy.
 *
 * Each instance must be shared across all scopes it should throttle.
 * For global throttling that means either by extending this class in an `object` or
 * by creating the instance as a field of an `object`.
 *
 * @param bucketSize This corresponds to the largest possible burst without throttling,
 *                   in number of executions.
 * @param tokenRecoveryInterval Time between two tokens being added back to the bucket.
 *                              This is reciprocal of the long-term average unthrottled rate.
 * @param timeSource Clock used for token accounting; pluggable so tests can control time.
 *
 * Example: With a bucket size of 100 and a recovery interval of 1s, we could log up to 100 events
 * in under a second without throttling, but at that point the bucket is exhausted and we only
 * regain the ability to log more events at 1 event per second. If we log less than 1 event/s
 * the bucket will slowly refill until it's back at 100.
 * Either way, we can always log at least 1 event/s.
 */
class LogThrottler(
    val bucketSize: Int = 100,
    val tokenRecoveryInterval: FiniteDuration = 1.second,
    val timeSource: NanoTimeTimeSource = SystemNanoTimeSource) extends Logging {

  // Tokens currently available for immediate (unthrottled) execution.
  // All three fields below are only read/written under `this.synchronized`.
  private var remainingTokens = bucketSize
  // The point in time at which the next token is due to be added back to the bucket.
  private var nextRecovery: DeadlineWithTimeSource =
    DeadlineWithTimeSource.now(timeSource) + tokenRecoveryInterval
  // Number of invocations skipped since the last invocation that actually ran.
  private var numSkipped: Long = 0

  /**
   * Run `thunk` as long as there are tokens remaining in the bucket,
   * otherwise skip and remember number of skips.
   *
   * The argument to `thunk` is how many previous invocations have been skipped since the last time
   * an invocation actually ran.
   *
   * Note: This method is `synchronized`, so it is concurrency safe.
   * However, that also means no heavy-lifting should be done as part of this
   * if the throttler is shared between concurrent threads.
   * This also means that the synchronized block of the `thunk` that *does* execute will still
   * hold up concurrent `thunk`s that will actually get rejected once they hold the lock.
   * This is fine at low concurrency/low recovery rates. But if we need this to be more efficient at
   * some point, we will need to decouple the check from the `thunk` execution.
   */
  def throttled(thunk: Long => Unit): Unit = this.synchronized {
    // Refill the bucket first so a long quiet period immediately restores capacity.
    tryRecoverTokens()
    if (remainingTokens > 0) {
      thunk(numSkipped)
      numSkipped = 0
      remainingTokens -= 1
    } else {
      numSkipped += 1L
    }
  }

  /**
   * Same as [[throttled]] but turns the number of skipped invocations into a logging message
   * that can be appended to item being logged in `thunk`.
   */
  def throttledWithSkippedLogMessage(thunk: MessageWithContext => Unit): Unit = {
    this.throttled { numSkipped =>
      // Only mention skips when there were any; otherwise append an empty message.
      val skippedStr = if (numSkipped != 0L) {
        log"[${MDC(LogKeys.NUM_SKIPPED, numSkipped)} similar messages were skipped.]"
      } else {
        log""
      }
      thunk(skippedStr)
    }
  }

  /**
   * Try to recover tokens, if the rate allows.
   *
   * Only call from within a `this.synchronized` block!
   */
  private[spark] def tryRecoverTokens(): Unit = {
    try {
      // Doing it one-by-one is a bit inefficient for long periods, but it's easy to avoid jumps
      // and rounding errors this way. The inefficiency shouldn't matter as long as the bucketSize
      // isn't huge.
      while (remainingTokens < bucketSize && nextRecovery.isOverdue()) {
        remainingTokens += 1
        nextRecovery += tokenRecoveryInterval
      }

      val currentTime = DeadlineWithTimeSource.now(timeSource)
      if (remainingTokens == bucketSize &&
        (currentTime - nextRecovery) > tokenRecoveryInterval) {
        // Reset the recovery time, so we don't accumulate infinite recovery while nothing is
        // going on.
        nextRecovery = currentTime + tokenRecoveryInterval
      }
    } catch {
      case _: IllegalArgumentException =>
        // Adding FiniteDuration throws IllegalArgumentException instead of wrapping on overflow.
        // Given that this happens every ~300 years, we can afford some non-linearity here,
        // rather than taking the effort to properly work around that.
        nextRecovery = DeadlineWithTimeSource(Duration(-Long.MaxValue, NANOSECONDS), timeSource)
    }
  }

  /**
   * Resets throttler state to initial state.
   * Visible for testing.
   */
  def reset(): Unit = this.synchronized {
    remainingTokens = bucketSize
    nextRecovery = DeadlineWithTimeSource.now(timeSource) + tokenRecoveryInterval
    numSkipped = 0
  }
}
647+
/**
 * Minimal re-implementation of Scala's [[Deadline]] with a pluggable source of nanoTime,
 * which makes deadline-based logic properly unit-testable.
 */
case class DeadlineWithTimeSource(
    time: FiniteDuration,
    timeSource: NanoTimeTimeSource = SystemNanoTimeSource) {
  // Only the operations that LogThrottler needs are implemented for now.

  /**
   * Return a deadline advanced (i.e., moved into the future) by the given duration.
   */
  def +(other: FiniteDuration): DeadlineWithTimeSource =
    DeadlineWithTimeSource(time + other, timeSource)

  /**
   * Calculate time difference between this and the other deadline, where the result is directed
   * (i.e., may be negative).
   */
  def -(other: DeadlineWithTimeSource): FiniteDuration = time - other.time

  /**
   * Determine whether the deadline lies in the past at the point where this method is called.
   */
  def isOverdue(): Boolean = {
    // Compare via subtraction (not `<=` on raw values) so the check stays correct across
    // nanoTime overflow, matching the approach of scala.concurrent.duration.Deadline.
    val remainingNanos = time.toNanos - timeSource.nanoTime()
    remainingNanos <= 0
  }
}
673+
/** Factory methods for [[DeadlineWithTimeSource]]. */
object DeadlineWithTimeSource {
  /**
   * Construct a deadline due exactly at the point where this method is called. Useful for then
   * advancing it to obtain a future deadline, or for sampling the current time exactly once and
   * then comparing it to multiple deadlines (using subtraction).
   */
  def now(timeSource: NanoTimeTimeSource = SystemNanoTimeSource): DeadlineWithTimeSource =
    DeadlineWithTimeSource(timeSource.nanoTime().nanos, timeSource)
}
683+
/** Generalisation of [[System.nanoTime()]] so the time source can be mocked in tests. */
private[spark] trait NanoTimeTimeSource {
  /** Current time in nanoseconds; only meaningful for measuring elapsed time. */
  def nanoTime(): Long
}
/** Default [[NanoTimeTimeSource]] backed directly by [[System.nanoTime()]]. */
private[spark] object SystemNanoTimeSource extends NanoTimeTimeSource {
  override def nanoTime(): Long = System.nanoTime()
}
0 commit comments