Skip to content

Commit

Permalink
small fix in the backward rule of norm (#131)
Browse files Browse the repository at this point in the history
* small fix in backward rule of `norm`

* format `norm_pullback`

* use `hypot`
  • Loading branch information
tangwei94 authored Jun 14, 2024
1 parent b5096da commit b026cf2
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion ext/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ end
function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2)
p == 2 || error("currently only implemented for p = 2")
n = norm(a, p)
norm_pullback(Δn) = NoTangent(), a * (Δn' + Δn) / (n * 2), NoTangent()
function norm_pullback(Δn)
return NoTangent(), a * (Δn' + Δn) / 2 / hypot(n, eps(one(n))), NoTangent()
end
return n, norm_pullback
end

Expand Down

0 comments on commit b026cf2

Please sign in to comment.