diff --git a/stdlib/public/Differentiation/ArrayDifferentiation.swift b/stdlib/public/Differentiation/ArrayDifferentiation.swift index a4b1149a20f56..1f8cbf9b09f33 100644 --- a/stdlib/public/Differentiation/ArrayDifferentiation.swift +++ b/stdlib/public/Differentiation/ArrayDifferentiation.swift @@ -312,6 +312,27 @@ extension Array where Element: Differentiable { } return (value: values, pullback: pullback) } + + @inlinable + @derivative(of: differentiableMap) + internal func _jvpDifferentiableMap( + _ body: @differentiable (Element) -> Result + ) -> ( + value: [Result], + differential: (Array.TangentVector) -> Array.TangentVector + ) { + var values: [Result] = [] + var differentials: [(Element.TangentVector) -> Result.TangentVector] = [] + for x in self { + let (y, df) = valueWithDifferential(at: x, in: body) + values.append(y) + differentials.append(df) + } + func differential(_ tans: Array.TangentVector) -> Array.TangentVector { + .init(zip(tans.base, differentials).map { tan, df in df(tan) }) + } + return (value: values, differential: differential) + } } extension Array where Element: Differentiable { diff --git a/test/AutoDiff/validation-test/forward_mode.swift b/test/AutoDiff/validation-test/forward_mode.swift index 22b049158c432..d3c4280c6910e 100644 --- a/test/AutoDiff/validation-test/forward_mode.swift +++ b/test/AutoDiff/validation-test/forward_mode.swift @@ -1319,4 +1319,23 @@ ForwardModeTests.test("ForceUnwrapping") { expectEqual(5, forceUnwrap(Float(2))) } +//===----------------------------------------------------------------------===// +// Array methods from ArrayDifferentiation.swift +//===----------------------------------------------------------------------===// + +ForwardModeTests.test("Array.differentiableMap") { + let x: [Float] = [1, 2, 3] + let tan = Array.TangentVector([1, 1, 1]) + + func multiplyMap(_ a: [Float]) -> [Float] { + return a.differentiableMap({ x in 3 * x }) + } + expectEqual([3, 3, 3], differential(at: x, in: multiplyMap)(tan)) + + func squareMap(_ a: [Float]) -> [Float] { + return a.differentiableMap({ x in x * x }) + } + expectEqual([2, 4, 6], differential(at: x, in: squareMap)(tan)) +} + runAllTests()