diff --git a/Sources/TensorFlow/CMakeLists.txt b/Sources/TensorFlow/CMakeLists.txt index 4644b48fb..cdbd92d83 100644 --- a/Sources/TensorFlow/CMakeLists.txt +++ b/Sources/TensorFlow/CMakeLists.txt @@ -37,6 +37,8 @@ add_library(TensorFlow SHARED Core/Threading.swift Core/Utilities.swift Core/EuclideanDifferentiable.swift + Core/VectorProtocol.swift + Core/PointwiseMultiplicative.swift Core/ElementaryFunctions.swift Epochs/Algorithms.swift diff --git a/Sources/TensorFlow/Core/EuclideanDifferentiable.swift b/Sources/TensorFlow/Core/EuclideanDifferentiable.swift index 6a662e931..d5f4fadbc 100644 --- a/Sources/TensorFlow/Core/EuclideanDifferentiable.swift +++ b/Sources/TensorFlow/Core/EuclideanDifferentiable.swift @@ -110,4 +110,10 @@ where Element: EuclideanDifferentiable { out = Array.DifferentiableView.TangentVector(self.base.map { $0.differentiableVectorView }) } } +extension RNNCellInput: _EuclideanDifferentiable + where Input: EuclideanDifferentiable, State: EuclideanDifferentiable {} +extension RNNCellOutput: _EuclideanDifferentiable + where Output: EuclideanDifferentiable, State: EuclideanDifferentiable {} +extension Tensor: _EuclideanDifferentiable where Scalar: TensorFlowFloatingPoint {} + #endif diff --git a/Sources/TensorFlow/Core/PointwiseMultiplicative.swift b/Sources/TensorFlow/Core/PointwiseMultiplicative.swift new file mode 100644 index 000000000..ef256f928 --- /dev/null +++ b/Sources/TensorFlow/Core/PointwiseMultiplicative.swift @@ -0,0 +1,137 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import _Differentiation + +#if TENSORFLOW_USE_STANDARD_TOOLCHAIN +@_spi(Reflection) import Swift + +infix operator .*: MultiplicationPrecedence +infix operator .*=: AssignmentPrecedence + +/// Implementation detail of the reflection default implementation. +/// +/// Contains versions of functions in PointwiseMultiplicative that +/// operate over key paths and modify a child of `Root` in-place. +/// The key paths must all be WritableKeyPath. This is a workaround +/// to simulate having Self constraints. +public protocol _PointwiseMultiplicative { + /// lhs[keyPath: kp] .*= rhs[keyPath: kp] + static func _pointwiseMult(_ lhs: inout Root, _ rhs: Root, _ kp: PartialKeyPath) + /// out[keyPath: kp] = Self.one + static func _setOne(_ out: inout Root, _ kp: PartialKeyPath) + /// out[keyPath: kp] = out[keyPath: kp].reciprocal + static func _setReciprocal(_ out: inout Root, _ kp: PartialKeyPath) +} + +public protocol PointwiseMultiplicative: _PointwiseMultiplicative & AdditiveArithmetic { + /// The one value. + /// + /// One is the identity element for multiplication. For any value, + /// `x .* .one == x` and `.one .* x == x`. + static var one: Self { get } + + /// The multiplicative inverse of self. + /// + /// For any value, `x .* x.reciprocal == .one` and + /// `x.reciprocal .* x == .one`. + var reciprocal: Self { get } + + /// Multiplies two values and produces their product. + /// + /// - Parameters: + /// - lhs: The first value to multiply. + /// - rhs: The second value to multiply. + static func .* (lhs: Self, rhs: Self) -> Self + + /// Multiplies two values and produces their product. + /// + /// - Parameters: + /// - lhs: The first value to multiply. + /// - rhs: The second value to multiply. + static func .*= (lhs: inout Self, rhs: Self) +} + +extension PointwiseMultiplicative { + public static func .*= (lhs: inout Self, rhs: Self) { + lhs = lhs .* rhs + } +} + +extension PointwiseMultiplicative +where Self: ExpressibleByIntegerLiteral { + public static var one: Self { + return 1 + } +} + +extension PointwiseMultiplicative { + public static var one: Self { + var out = self.zero + visitChildren { kp, t in t._setOne(&out, kp) } + return out + } + public var reciprocal: Self { + var out = self + Self.visitChildren { kp, t in t._setReciprocal(&out, kp) } + return out + } + public static func .* (lhs: Self, rhs: Self) -> Self { + var out = lhs + visitChildren { kp, t in + t._pointwiseMult(&out, rhs, kp) + } + return out + } + public static func _pointwiseMult( + _ lhs: inout Root, _ rhs: Root, _ kp: PartialKeyPath + ) { + let kp = kp as! WritableKeyPath + lhs[keyPath: kp] .*= rhs[keyPath: kp] + } + public static func _setOne(_ out: inout Root, _ kp: PartialKeyPath) { + let kp = kp as! WritableKeyPath + out[keyPath: kp] = Self.one + } + public static func _setReciprocal(_ out: inout Root, _ kp: PartialKeyPath) { + let kp = kp as! WritableKeyPath + out[keyPath: kp] = out[keyPath: kp].reciprocal + } +} + +extension PointwiseMultiplicative { + internal static func visitChildren( + _ body: (PartialKeyPath, _PointwiseMultiplicative.Type) -> Void + ) { + if !_forEachFieldWithKeyPath( + of: Self.self, + body: { name, kp in + let valueType = type(of: kp).valueType + guard let valueType = valueType as? _PointwiseMultiplicative.Type else { + fatalError("not PointwiseMultiplicative: \(valueType)") + } + body(kp, valueType) + return true + }) + { + fatalError( + "Unreflectable member of \(Self.self) while implementing PointwiseMultiplicative.") + } + } +} + +extension Array.DifferentiableView: _PointwiseMultiplicative +where Element: Differentiable & PointwiseMultiplicative {} +extension Tensor: _PointwiseMultiplicative where Scalar: Numeric {} +#endif diff --git a/Sources/TensorFlow/Core/VectorProtocol.swift b/Sources/TensorFlow/Core/VectorProtocol.swift new file mode 100644 index 000000000..554535c29 --- /dev/null +++ b/Sources/TensorFlow/Core/VectorProtocol.swift @@ -0,0 +1,128 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import _Differentiation + +#if TENSORFLOW_USE_STANDARD_TOOLCHAIN +@_spi(Reflection) import Swift + +/// Implementation detail for reflection. +/// +/// This should contain the methods of `VectorProtocol` +/// that do not require Self constraints. +public protocol _VectorProtocol { + typealias VectorSpaceScalar = Float + + /// Adds the specified scalar to `self`. + mutating func add(_ x: VectorSpaceScalar) + + /// Subtracts the specified scalar to `self`. + mutating func subtract(_ x: VectorSpaceScalar) + + /// Scales `self` by the specified scalar. + mutating func scale(by scalar: VectorSpaceScalar) +} + +extension VectorProtocol { + internal static func visitChildren( + _ body: (PartialKeyPath, _VectorProtocol.Type) -> Void + ) { + if !_forEachFieldWithKeyPath( + of: Self.self, + body: { name, kp in + let valueType = type(of: kp).valueType + guard let valueType = valueType as? _VectorProtocol.Type else { + fatalError("not VectorProtocol: \(valueType)") + } + body(kp, valueType) + return true + }) + { + fatalError("Unreflectable member of \(Self.self) while implementing VectorProtocol.") + } + } +} + +extension _VectorProtocol { + static func add(_ v: inout Root, _ kp: PartialKeyPath, _ x: VectorSpaceScalar) { + v[keyPath: (kp as! WritableKeyPath)].add(x) + } + static func subtract(_ v: inout Root, _ kp: PartialKeyPath, _ x: VectorSpaceScalar) + { + v[keyPath: (kp as! WritableKeyPath)].subtract(x) + } + static func scale( + _ v: inout Root, _ kp: PartialKeyPath, by scalar: VectorSpaceScalar + ) { + v[keyPath: (kp as! WritableKeyPath)].scale(by: scalar) + } +} + +/// A type that represents an unranked vector space. Values of this type are +/// elements in this vector space and have either no shape or a static shape. +public protocol VectorProtocol: _VectorProtocol & AdditiveArithmetic { + /// The type of scalars in the vector space. + associatedtype VectorSpaceScalar = Float + + func adding(_ x: VectorSpaceScalar) -> Self + + mutating func add(_ x: VectorSpaceScalar) + + func subtracting(_ x: VectorSpaceScalar) -> Self + + mutating func subtract(_ x: VectorSpaceScalar) + + /// Returns `self` multiplied by the given scalar. + func scaled(by scalar: VectorSpaceScalar) -> Self + + /// Multiplies `self` by the given scalar. + mutating func scale(by scalar: VectorSpaceScalar) +} + +extension VectorProtocol { + public mutating func add(_ x: VectorSpaceScalar) { + self = adding(x) + } + + public mutating func subtract(_ x: VectorSpaceScalar) { + self = subtracting(x) + } + + public mutating func scale(by scalar: VectorSpaceScalar) { + self = scaled(by: scalar) + } +} + +extension VectorProtocol { + public func adding(_ x: VectorSpaceScalar) -> Self { + var out = self + Self.visitChildren { kp, t in t.add(&out, kp, x) } + return out + } + public func subtracting(_ x: VectorSpaceScalar) -> Self { + var out = self + Self.visitChildren { kp, t in t.subtract(&out, kp, x) } + return out + } + public func scaled(by scalar: VectorSpaceScalar) -> Self { + var out = self + Self.visitChildren { kp, t in t.scale(&out, kp, by: scalar) } + return out + } +} + +extension Tensor: _VectorProtocol where Scalar: TensorFlowFloatingPoint {} +extension Array.DifferentiableView: _VectorProtocol +where Element: Differentiable & VectorProtocol {} +#endif