diff --git a/stdlib/public/Differentiation/ArrayDifferentiation.swift b/stdlib/public/Differentiation/ArrayDifferentiation.swift index a4b1149a20f56..177e9a3572d6e 100644 --- a/stdlib/public/Differentiation/ArrayDifferentiation.swift +++ b/stdlib/public/Differentiation/ArrayDifferentiation.swift @@ -42,6 +42,14 @@ where Element: Differentiable { return (base, { $0 }) } + @usableFromInline + @derivative(of: base) + func _jvpBase() -> ( + value: [Element], differential: (Array.TangentVector) -> TangentVector + ) { + return (base, { $0 }) + } + /// Creates a differentiable view of the given array. public init(_ base: [Element]) { self._base = base } @@ -53,6 +61,14 @@ where Element: Differentiable { return (Array.DifferentiableView(base), { $0 }) } + @usableFromInline + @derivative(of: init(_:)) + static func _jvpInit(_ base: [Element]) -> ( + value: Array.DifferentiableView, differential: (TangentVector) -> TangentVector + ) { + return (Array.DifferentiableView(base), { $0 }) + } + public typealias TangentVector = Array.DifferentiableView @@ -191,6 +207,17 @@ extension Array where Element: Differentiable { return (self[index], pullback) } + @usableFromInline + @derivative(of: subscript) + func _jvpSubscript(index: Int) -> ( + value: Element, differential: (TangentVector) -> Element.TangentVector + ) { + func differential(_ v: TangentVector) -> Element.TangentVector { + return v[index] + } + return (self[index], differential) + } + @usableFromInline @derivative(of: +) static func _vjpConcatenate(_ lhs: Self, _ rhs: Self) -> ( @@ -210,8 +237,26 @@ extension Array where Element: Differentiable { } return (lhs + rhs, pullback) } + + @usableFromInline + @derivative(of: +) + static func _jvpConcatenate(_ lhs: Self, _ rhs: Self) -> ( + value: Self, + differential: (TangentVector, TangentVector) -> TangentVector + ) { + func differential(_ l: TangentVector, _ r: TangentVector) -> TangentVector { + precondition( + l.base.count == lhs.count && r.base.count == rhs.count, """ + Tangent vectors with invalid count; expected to equal the \ + operand counts \(lhs.count) and \(rhs.count) + """) + return .init(l.base + r.base) + } + return (lhs + rhs, differential) + } } + extension Array where Element: Differentiable { @usableFromInline @derivative(of: append) @@ -277,6 +322,17 @@ extension Array where Element: Differentiable { } ) } + + @usableFromInline + @derivative(of: init(repeating:count:)) + static func _jvpInit(repeating repeatedValue: Element, count: Int) -> ( + value: Self, differential: (Element.TangentVector) -> TangentVector + ) { + ( + value: Self(repeating: repeatedValue, count: count), + differential: { v in TangentVector(.init(repeating: v, count: count)) } + ) + } } //===----------------------------------------------------------------------===// @@ -312,6 +368,27 @@ extension Array where Element: Differentiable { } return (value: values, pullback: pullback) } + + @inlinable + @derivative(of: differentiableMap) + internal func _jvpDifferentiableMap( + _ body: @differentiable (Element) -> Result + ) -> ( + value: [Result], + differential: (Array.TangentVector) -> Array.TangentVector + ) { + var values: [Result] = [] + var differentials: [(Element.TangentVector) -> Result.TangentVector] = [] + for x in self { + let (y, df) = valueWithDifferential(at: x, in: body) + values.append(y) + differentials.append(df) + } + func differential(_ tans: Array.TangentVector) -> Array.TangentVector { + .init(zip(tans.base, differentials).map { tan, df in df(tan) }) + } + return (value: values, differential: differential) + } } extension Array where Element: Differentiable { @@ -361,4 +438,33 @@ extension Array where Element: Differentiable { } ) } + + @inlinable + @derivative(of: differentiableReduce, wrt: (self, initialResult)) + func _jvpDifferentiableReduce( + _ initialResult: Result, + _ nextPartialResult: @differentiable (Result, Element) -> Result + ) -> (value: Result, + differential: (Array.TangentVector, Result.TangentVector) + -> Result.TangentVector) { + var differentials: + [(Result.TangentVector, Element.TangentVector) -> Result.TangentVector] + = [] + let count = self.count + differentials.reserveCapacity(count) + var result = initialResult + for element in self { + let (y, df) = + valueWithDifferential(at: result, element, in: nextPartialResult) + result = y + differentials.append(df) + } + return (value: result, differential: { dSelf, dInitial in + var dResult = dInitial + for (dElement, df) in zip(dSelf.base, differentials) { + dResult = df(dResult, dElement) + } + return dResult + }) + } } diff --git a/test/AutoDiff/validation-test/forward_mode.swift b/test/AutoDiff/validation-test/forward_mode.swift index 22b049158c432..e43740c1244a5 100644 --- a/test/AutoDiff/validation-test/forward_mode.swift +++ b/test/AutoDiff/validation-test/forward_mode.swift @@ -1319,4 +1319,112 @@ ForwardModeTests.test("ForceUnwrapping") { expectEqual(5, forceUnwrap(Float(2))) } +//===----------------------------------------------------------------------===// +// Array methods from ArrayDifferentiation.swift +//===----------------------------------------------------------------------===// + +typealias FloatArrayTan = Array.TangentVector + +ForwardModeTests.test("Array.+") { + func sumFirstThreeConcatenating(_ a: [Float], _ b: [Float]) -> Float { + let c = a + b + return c[0] + c[1] + c[2] + } + + expectEqual(3, differential(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([1, 1]), .init([1, 1]))) + expectEqual(0, differential(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([0, 0]), .init([0, 1]))) + expectEqual(1, differential(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([0, 1]), .init([0, 1]))) + expectEqual(1, differential(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([1, 0]), .init([0, 1]))) + expectEqual(1, differential(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([0, 0]), .init([1, 1]))) + expectEqual(2, differential(at: [0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([1, 1]), .init([0, 1]))) + + expectEqual( + 3, + differential(at: [0, 0, 0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([1, 1, 1, 1]), .init([1, 1]))) + expectEqual( + 3, + differential(at: [0, 0, 0, 0], [0, 0], in: sumFirstThreeConcatenating)(.init([1, 1, 1, 0]), .init([0, 0]))) + + expectEqual( + 3, + differential(at: [], [0, 0, 0, 0], in: sumFirstThreeConcatenating)(.init([]), .init([1, 1, 1, 1]))) + expectEqual( + 0, + differential(at: [], [0, 0, 0, 0], in: sumFirstThreeConcatenating)(.init([]), .init([0, 0, 0, 1]))) +} + +ForwardModeTests.test("Array.init(repeating:count:)") { + @differentiable + func repeating(_ x: Float) -> [Float] { + Array(repeating: x, count: 10) + } + expectEqual(Float(10), derivative(at: .zero) { x in + repeating(x).differentiableReduce(0, {$0 + $1}) + }) + expectEqual(Float(20), differential(at: .zero, in: { x in + repeating(x).differentiableReduce(0, {$0 + $1}) + })(2)) +} + +ForwardModeTests.test("Array.DifferentiableView.init") { + @differentiable + func constructView(_ x: [Float]) -> Array.DifferentiableView { + return Array.DifferentiableView(x) + } + + let forward = differential(at: [5, 6, 7, 8], in: constructView) + expectEqual( + FloatArrayTan([1, 2, 3, 4]), + forward(FloatArrayTan([1, 2, 3, 4]))) +} + +ForwardModeTests.test("Array.DifferentiableView.base") { + @differentiable + func accessBase(_ x: Array.DifferentiableView) -> [Float] { + return x.base + } + + let forward = differential( + at: Array.DifferentiableView([5, 6, 7, 8]), + in: accessBase) + expectEqual( + FloatArrayTan([1, 2, 3, 4]), + forward(FloatArrayTan([1, 2, 3, 4]))) +} + +ForwardModeTests.test("Array.differentiableMap") { + let x: [Float] = [1, 2, 3] + let tan = Array.TangentVector([1, 1, 1]) + + func multiplyMap(_ a: [Float]) -> [Float] { + return a.differentiableMap({ x in 3 * x }) + } + expectEqual([3, 3, 3], differential(at: x, in: multiplyMap)(tan)) + + func squareMap(_ a: [Float]) -> [Float] { + return a.differentiableMap({ x in x * x }) + } + expectEqual([2, 4, 6], differential(at: x, in: squareMap)(tan)) +} + +ForwardModeTests.test("Array.differentiableReduce") { + let x: [Float] = [1, 2, 3] + let tan = Array.TangentVector([1, 1, 1]) + + func sumReduce(_ a: [Float]) -> Float { + return a.differentiableReduce(0, { $0 + $1 }) + } + expectEqual(1 + 1 + 1, differential(at: x, in: sumReduce)(tan)) + + func productReduce(_ a: [Float]) -> Float { + return a.differentiableReduce(1, { $0 * $1 }) + } + expectEqual(x[1] * x[2] + x[0] * x[2] + x[0] * x[1], differential(at: x, in: productReduce)(tan)) + + func sumOfSquaresReduce(_ a: [Float]) -> Float { + return a.differentiableReduce(0, { $0 + $1 * $1 }) + } + expectEqual(2 * x[0] + 2 * x[1] + 2 * x[2], differential(at: x, in: sumOfSquaresReduce)(tan)) +} + runAllTests()