Gradient flow check in PyTorch

Check that gradients flow properly through the network by recording the average gradient of each layer in every training iteration and plotting these averages at the end. If the average gradients of the initial layers are close to zero, the network is probably too deep for the gradient to flow back to them (the vanishing-gradient problem).
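The record-every-iteration, plot-at-the-end procedure can be sketched as follows. This is not the repo's `plot_grad_flow`: the `GradFlowRecorder` class, averaging the absolute value of each weight tensor's gradient, skipping biases, and saving to `grad_flow.png` are all assumptions made for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import torch
from torch import nn


class GradFlowRecorder:
    """Hypothetical helper: records per-layer average |grad| every iteration."""

    def __init__(self):
        self.layer_names = []
        self.history = []  # one list of per-layer averages per iteration

    def record(self, named_parameters):
        # Call after loss.backward() and before optimizer.zero_grad().
        names, averages = [], []
        for name, p in named_parameters:
            # Skip biases and frozen or unused parameters (no gradient).
            if p.requires_grad and p.grad is not None and "bias" not in name:
                names.append(name)
                averages.append(p.grad.abs().mean().item())
        self.layer_names = names
        self.history.append(averages)

    def plot(self, path="grad_flow.png"):
        fig, ax = plt.subplots(figsize=(8, 4))
        for averages in self.history:  # one faint line per iteration
            ax.plot(averages, alpha=0.3, color="b")
        ax.axhline(0, linewidth=1, color="k")
        ax.set_xticks(list(range(len(self.layer_names))))
        ax.set_xticklabels(self.layer_names, rotation=90)
        ax.set_xlabel("Layers")
        ax.set_ylabel("Average |gradient|")
        ax.set_title("Gradient flow")
        fig.tight_layout()
        fig.savefig(path)
        plt.close(fig)


# Usage on a toy model and random data (both hypothetical).
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
recorder = GradFlowRecorder()
x, y = torch.randn(32, 4), torch.randn(32, 1)
for _ in range(5):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    recorder.record(model.named_parameters())
    optimizer.step()
recorder.plot("grad_flow.png")
```

Averaging absolute values keeps positive and negative gradient entries from cancelling, so a near-zero average really means the layer is receiving little gradient.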

Usage

```python
loss = criterion(outputs, labels)
loss.backward()
# Call after backward() and before optimizer.zero_grad(), while .grad is populated
plot_grad_flow(model.named_parameters())  # version 1
# OR
plot_grad_flow_v2(model.named_parameters())  # version 2
```
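The README does not describe how version 2 differs from version 1. As one possible variant, assumed here rather than taken from this repo, a bar chart can show both the mean and the maximum absolute gradient of each weight tensor, which makes layers with a few large gradients but a near-zero mean easier to spot. `grad_stats`, `plot_grad_bars`, and the output file name are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import torch
from torch import nn


def grad_stats(named_parameters):
    """Hypothetical helper: per-layer mean and max of |grad| for weight tensors."""
    names, means, maxes = [], [], []
    for name, p in named_parameters:
        if p.requires_grad and p.grad is not None and "bias" not in name:
            g = p.grad.abs()
            names.append(name)
            means.append(g.mean().item())
            maxes.append(g.max().item())
    return names, means, maxes


def plot_grad_bars(named_parameters, path="grad_bars.png"):
    names, means, maxes = grad_stats(named_parameters)
    xs = list(range(len(names)))
    fig, ax = plt.subplots(figsize=(8, 4))
    # Draw the max bars first so the (smaller) mean bars stay visible on top.
    ax.bar(xs, maxes, alpha=0.4, color="c", label="max |grad|")
    ax.bar(xs, means, alpha=0.8, color="b", label="mean |grad|")
    ax.set_xticks(xs)
    ax.set_xticklabels(names, rotation=90)
    ax.set_ylabel("|gradient|")
    ax.legend()
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)
    return names, means, maxes


# Usage: one backward pass on a toy model (hypothetical).
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
loss = model(torch.randn(16, 4)).pow(2).mean()
loss.backward()
names, means, maxes = plot_grad_bars(model.named_parameters())
```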

Result

Bad gradient flow:

[Image: bad gradient flow plot]

Good gradient flow:

[Image: good gradient flow plot]
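The bad and good gradient-flow plots are compared by eye. A rough numeric counterpart, offered as a sketch and not part of this repo, is the ratio of the first layer's mean |grad| to the last layer's: a deep sigmoid MLP, a textbook vanishing-gradient case, should give a ratio many orders of magnitude below 1, while a shallow ReLU MLP should not. The helpers `weight_grad_means`, `first_to_last_ratio`, and `mlp`, and the network sizes, are hypothetical choices.

```python
import torch
from torch import nn


def weight_grad_means(model):
    """Mean |grad| of every weight tensor, ordered from input to output layer."""
    return [
        p.grad.abs().mean().item()
        for name, p in model.named_parameters()
        if p.grad is not None and "bias" not in name
    ]


def first_to_last_ratio(model, inputs, targets):
    """Hypothetical check: first layer's mean |grad| divided by the last layer's."""
    model.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    means = weight_grad_means(model)
    return means[0] / means[-1]


def mlp(depth, activation, width=32, in_features=8):
    """Build `depth` hidden Linear+activation blocks followed by a Linear output."""
    layers, features = [], in_features
    for _ in range(depth):
        layers += [nn.Linear(features, width), activation()]
        features = width
    layers.append(nn.Linear(features, 1))
    return nn.Sequential(*layers)


torch.manual_seed(0)
x, y = torch.randn(64, 8), torch.randn(64, 1)
deep = first_to_last_ratio(mlp(20, nn.Sigmoid), x, y)
shallow = first_to_last_ratio(mlp(2, nn.ReLU), x, y)
print(f"deep sigmoid ratio: {deep:.2e}, shallow ReLU ratio: {shallow:.2e}")
```

Each sigmoid layer multiplies the backward signal by at most 0.25 (the sigmoid's maximum slope) times a weight factor below 1 under PyTorch's default `nn.Linear` initialization, so the ratio shrinks roughly geometrically with depth.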

This repo is based on a post on the PyTorch discussion forum.