diff --git a/Sources/TensorFlow/Layers/Dense.swift b/Sources/TensorFlow/Layers/Dense.swift index 1583c19f4..3a86c13cf 100644 --- a/Sources/TensorFlow/Layers/Dense.swift +++ b/Sources/TensorFlow/Layers/Dense.swift @@ -62,7 +62,7 @@ public struct Dense: Layer { } // TODO(TF-433): Remove custom derivative after `try_apply` differentiation is supported. - @derivative(of: init) + @derivative(of: init, wrt: weight) @usableFromInline static func vjpInit( weight: Tensor,