@@ -376,3 +376,79 @@ public class AMSGrad<Model: Differentiable & KeyPathIterable>: Optimizer
    model.move(along: -stepSize * firstMoments ./ denominator)
  }
}

/// RAdam optimizer.
///
/// Rectified Adam, a variant of Adam that introduces a term to rectify the variance of the
/// adaptive learning rate. When that variance is intractable (early in training, before enough
/// gradients have been observed), the update falls back to SGD with momentum.
///
/// Reference: ["On the Variance of the Adaptive Learning Rate and Beyond"](
/// https://arxiv.org/pdf/1908.03265.pdf)
public class RAdam<Model: Differentiable>: Optimizer
  where Model.TangentVector: VectorProtocol & PointwiseMultiplicative &
        ElementaryFunctions & KeyPathIterable,
        Model.TangentVector.VectorSpaceScalar == Float {
  public typealias Model = Model
  /// The learning rate.
  public var learningRate: Float
  /// The exponential decay rate for the first-moment estimates of the gradients.
  public var beta1: Float
  /// The exponential decay rate for the second-moment estimates of the gradients.
  public var beta2: Float
  /// A small scalar added to the denominator to improve numerical stability.
  public var epsilon: Float
  /// The learning rate decay.
  public var decay: Float
  /// The current step.
  public var step: Int = 0
  /// The first moments of the weights.
  public var firstMoments: Model.TangentVector = .zero
  /// The second moments of the weights.
  public var secondMoments: Model.TangentVector = .zero

  public init(
    for model: __shared Model,
    learningRate: Float = 1e-3,
    beta1: Float = 0.9,
    beta2: Float = 0.999,
    epsilon: Float = 1e-8,
    decay: Float = 0
  ) {
    precondition(learningRate >= 0, "Learning rate must be non-negative")
    precondition(0 <= beta1 && beta1 <= 1, "Beta parameter must be between 0 and 1")
    precondition(0 <= beta2 && beta2 <= 1, "Beta parameter must be between 0 and 1")
    precondition(decay >= 0, "Learning rate decay must be non-negative")

    self.learningRate = learningRate
    self.beta1 = beta1
    self.beta2 = beta2
    self.epsilon = epsilon
    self.decay = decay
  }
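
  /// Updates `model` along `direction`, following Algorithm 2 of the referenced paper: the
  /// variance-rectified adaptive step is taken only once the approximated SMA length
  /// (`N_sma_t` below) is long enough for the variance to be tractable; earlier steps fall
  /// back to SGD with momentum.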
  public func update(_ model: inout Model, along direction: Model.TangentVector) {
    step += 1
    let step = Float(self.step)
    let beta1Power = pow(beta1, step)
    let beta2Power = pow(beta2, step)
    // Decay the learning rate, as the other optimizers in this file do.
    let learningRate = self.learningRate / (1 + decay * step)
    // Update the biased first- and second-moment estimates.
    secondMoments = beta2 * secondMoments + direction .* direction * (1 - beta2)
    firstMoments = beta1 * firstMoments + direction * (1 - beta1)
    // Compute the maximum length of the approximated SMA and its value at the current step
    // (`ρ_∞` and `ρ_t` in the paper).
    let N_sma_inf = 2 / (1 - beta2) - 1
    let N_sma_t = N_sma_inf - 2 * step * beta2Power / (1 - beta2Power)

    if N_sma_t > 5 {
      // The variance is tractable: compute the bias-corrected moments and the rectification
      // term, then take the adaptive step.
      let denominator = Model.TangentVector.sqrt(secondMoments) + epsilon
      let rectification = sqrt(
        (N_sma_t - 4) * (N_sma_t - 2) * N_sma_inf /
        ((N_sma_inf - 4) * (N_sma_inf - 2) * N_sma_t))
      let stepSize = learningRate * rectification * sqrt(1 - beta2Power) / (1 - beta1Power)
      model.move(along: -stepSize * firstMoments ./ denominator)
    } else {
      // The variance is intractable: update with bias-corrected, un-adapted momentum.
      let stepSize = learningRate / (1 - beta1Power)
      model.move(along: -stepSize * firstMoments)
    }
  }
}
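
A minimal usage sketch, assuming the standard Swift for TensorFlow `Dense` layer,
`meanSquaredError` loss, and `gradient(at:in:)`; the toy model, data, and step count are
illustrative, not part of this change:

import TensorFlow

// Toy regression setup; any differentiable model whose tangent vector satisfies the
// constraints above works.
var model = Dense<Float>(inputSize: 2, outputSize: 1)
let x = Tensor<Float>(randomNormal: [16, 2])
let y = Tensor<Float>(randomNormal: [16, 1])

let optimizer = RAdam(for: model, learningRate: 1e-3)
for _ in 0..<100 {
  let grad = gradient(at: model) { model -> Tensor<Float> in
    meanSquaredError(predicted: model(x), expected: y)
  }
  // Early steps take the momentum-only branch; once `N_sma_t` exceeds 5, updates switch to
  // the variance-rectified adaptive step.
  optimizer.update(&model, along: grad)
}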