From b026cf2c1d470c6df1788a8f742c20acca67db83 Mon Sep 17 00:00:00 2001 From: tangwei94 <34451674+tangwei94@users.noreply.github.com> Date: Fri, 14 Jun 2024 22:36:20 +0200 Subject: [PATCH] small fix in the backward rule of `norm` (#131) * small fix in backward rule of `norm` * format `norm_pullback` * use `hypot` --- ext/TensorKitChainRulesCoreExt.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 514d4749..6ff4b791 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -172,7 +172,9 @@ end function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2) p == 2 || error("currently only implemented for p = 2") n = norm(a, p) - norm_pullback(Δn) = NoTangent(), a * (Δn' + Δn) / (n * 2), NoTangent() + function norm_pullback(Δn) + return NoTangent(), a * (Δn' + Δn) / 2 / hypot(n, eps(one(n))), NoTangent() + end return n, norm_pullback end