diff --git a/docs/DifferentiableProgramming.md b/docs/DifferentiableProgramming.md index 91b396adada17..0e1196f138cca 100644 --- a/docs/DifferentiableProgramming.md +++ b/docs/DifferentiableProgramming.md @@ -1192,7 +1192,14 @@ extension Optional: Differentiable where Wrapped: Differentiable { @noDerivative public var zeroTangentVectorInitializer: () -> TangentVector { - { TangentVector(.zero) } + switch self { + case nil: + return { TangentVector(nil) } + case let x?: + return { [zeroTanInit = x.zeroTangentVectorInitializer] in + TangentVector(zeroTanInit()) + } + } } } ``` diff --git a/stdlib/public/Differentiation/ArrayDifferentiation.swift b/stdlib/public/Differentiation/ArrayDifferentiation.swift index 177e9a3572d6e..046af6f186095 100644 --- a/stdlib/public/Differentiation/ArrayDifferentiation.swift +++ b/stdlib/public/Differentiation/ArrayDifferentiation.swift @@ -82,6 +82,12 @@ where Element: Differentiable { base[i].move(along: direction.base[i]) } } + + /// A closure that produces a `TangentVector` of zeros with the same + /// `count` as `self`. + public var zeroTangentVectorInitializer: () -> TangentVector { + return base.zeroTangentVectorInitializer + } } extension Array.DifferentiableView: Equatable diff --git a/stdlib/public/Differentiation/CMakeLists.txt b/stdlib/public/Differentiation/CMakeLists.txt index 0fa1ed82d6c8e..6e0257aa3f73e 100644 --- a/stdlib/public/Differentiation/CMakeLists.txt +++ b/stdlib/public/Differentiation/CMakeLists.txt @@ -16,6 +16,7 @@ add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPE DifferentiationUtilities.swift AnyDifferentiable.swift ArrayDifferentiation.swift + OptionalDifferentiation.swift GYB_SOURCES FloatingPointDifferentiation.swift.gyb diff --git a/stdlib/public/Differentiation/OptionalDifferentiation.swift b/stdlib/public/Differentiation/OptionalDifferentiation.swift new file mode 100644 index 0000000000000..c8aa82f01e980 --- /dev/null +++ b/stdlib/public/Differentiation/OptionalDifferentiation.swift @@ -0,0 +1,83 @@ +//===--- OptionalDifferentiation.swift ------------------------*- swift -*-===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2020 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// + +import Swift + +extension Optional: Differentiable where Wrapped: Differentiable { + public struct TangentVector: Differentiable, AdditiveArithmetic { + public typealias TangentVector = Self + + public var value: Wrapped.TangentVector? + + public init(_ value: Wrapped.TangentVector?) { + self.value = value + } + + public static var zero: Self { + return Self(.zero) + } + + public static func + (lhs: Self, rhs: Self) -> Self { + switch (lhs.value, rhs.value) { + case (nil, nil): return Self(nil) + case let (x?, nil): return Self(x) + case let (nil, y?): return Self(y) + case let (x?, y?): return Self(x + y) + } + } + + public static func - (lhs: Self, rhs: Self) -> Self { + switch (lhs.value, rhs.value) { + case (nil, nil): return Self(nil) + case let (x?, nil): return Self(x) + case let (nil, y?): return Self(.zero - y) + case let (x?, y?): return Self(x - y) + } + } + + public mutating func move(along direction: TangentVector) { + if let value = direction.value { + self.value?.move(along: value) + } + } + + @noDerivative + public var zeroTangentVectorInitializer: () -> TangentVector { + switch value { + case nil: + return { Self(nil) } + case let x?: + return { [zeroTanInit = x.zeroTangentVectorInitializer] in + Self(zeroTanInit()) + } + } + } + } + + public mutating func move(along direction: TangentVector) { + if let value = direction.value { + self?.move(along: value) + } + } + + @noDerivative + public var zeroTangentVectorInitializer: () -> TangentVector { + switch self { + case nil: + return { TangentVector(nil) } + case let x?: + return { [zeroTanInit = x.zeroTangentVectorInitializer] in + TangentVector(zeroTanInit()) + } + } + } +} diff --git a/test/AutoDiff/stdlib/optional.swift b/test/AutoDiff/stdlib/optional.swift new file mode 100644 index 0000000000000..5039f8098294e --- /dev/null +++ b/test/AutoDiff/stdlib/optional.swift @@ -0,0 +1,103 @@ +// RUN: %target-run-simple-swift +// REQUIRES: executable_test + +import _Differentiation +import StdlibUnittest + +var OptionalDifferentiationTests = TestSuite("OptionalDifferentiation") + +OptionalDifferentiationTests.test("Optional operations") { + // Differentiable.move(along:) + do { + var some: Float? = 2 + some.move(along: .init(3)) + expectEqual(5, some) + + var none: Float? = nil + none.move(along: .init(3)) + expectEqual(nil, none) + } + + // Differentiable.zeroTangentVectorInitializer + do { + let some: [Float]? = [1, 2, 3] + expectEqual(.init([0, 0, 0]), some.zeroTangentVectorInitializer()) + + let none: [Float]? = nil + expectEqual(.init(nil), none.zeroTangentVectorInitializer()) + } +} + +OptionalDifferentiationTests.test("Optional.TangentVector operations") { + // Differentiable.move(along:) + do { + var some: Optional.TangentVector = .init(2) + some.move(along: .init(3)) + expectEqual(5, some.value) + + var none: Optional.TangentVector = .init(nil) + none.move(along: .init(3)) + expectEqual(nil, none.value) + + var nestedSome: Optional>.TangentVector = .init(.init(2)) + nestedSome.move(along: .init(.init(3))) + expectEqual(.init(5), nestedSome.value) + + var nestedNone: Optional>.TangentVector = .init(.init(nil)) + nestedNone.move(along: .init(.init(3))) + expectEqual(.init(nil), nestedNone.value) + } + + // Differentiable.zeroTangentVectorInitializer + do { + let some: [Float]? = [1, 2, 3] + expectEqual(.init([0, 0, 0]), some.zeroTangentVectorInitializer()) + + let none: [Float]? = nil + expectEqual(.init(nil), none.zeroTangentVectorInitializer()) + + let nestedSome: [Float]?? = [1, 2, 3] + expectEqual(.init(.init([0, 0, 0])), nestedSome.zeroTangentVectorInitializer()) + + let nestedNone: [Float]?? = nil + expectEqual(.init(nil), nestedNone.zeroTangentVectorInitializer()) + } + + // AdditiveArithmetic.zero + expectEqual(.init(Float.zero), Float?.TangentVector.zero) + expectEqual(.init([Float].TangentVector.zero), [Float]?.TangentVector.zero) + + expectEqual(.init(.init(Float.zero)), Float??.TangentVector.zero) + expectEqual(.init(.init([Float].TangentVector.zero)), [Float]??.TangentVector.zero) + + // AdditiveArithmetic.+, AdditiveArithmetic.- + do { + let some: Optional.TangentVector = .init(2) + let none: Optional.TangentVector = .init(nil) + + expectEqual(.init(4), some + some) + expectEqual(.init(2), some + none) + expectEqual(.init(2), none + some) + expectEqual(.init(nil), none + none) + + expectEqual(.init(0), some - some) + expectEqual(.init(2), some - none) + expectEqual(.init(-2), none - some) + expectEqual(.init(nil), none - none) + + let nestedSome: Optional>.TangentVector = .init(.init(2)) + let nestedNone: Optional>.TangentVector = .init(.init(nil)) + + expectEqual(.init(.init(4)), nestedSome + nestedSome) + expectEqual(.init(.init(2)), nestedSome + nestedNone) + expectEqual(.init(.init(2)), nestedNone + nestedSome) + expectEqual(.init(.init(nil)), nestedNone + nestedNone) + + expectEqual(.init(.init(0)), nestedSome - nestedSome) + expectEqual(.init(.init(2)), nestedSome - nestedNone) + expectEqual(.init(.init(-2)), nestedNone - nestedSome) + expectEqual(.init(.init(nil)), nestedNone - nestedNone) + } +} + +runAllTests()