Skip to content

Commit

Permalink
Model summary: add 1 decimal place (#4745)
Browse files Browse the repository at this point in the history
Show 1999 parameters as 1.9 K and 1000 parameters as 1.0 K, rather than both as 1 K.

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
(cherry picked from commit 8dfbf63)
  • Loading branch information
jonashaag authored and Borda committed Nov 23, 2020
1 parent 4cb51f2 commit 6ccbf66
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/source/debugging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ You can also display the intermediate input- and output sizes of all your layers
--------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1 K | [10, 512] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512]
when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers.

Expand Down
17 changes: 10 additions & 7 deletions pytorch_lightning/core/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class ModelSummary(object):
--------------------------------------------------------------
0 | net | Sequential | 132 K | [10, 256] | [10, 512]
1 | net.0 | Linear | 131 K | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1 K | [10, 512] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512]
"""

MODE_TOP = "top"
Expand Down Expand Up @@ -363,13 +363,13 @@ def get_human_readable_count(number: int) -> str:
>>> get_human_readable_count(123)
'123 '
>>> get_human_readable_count(1234) # (one thousand)
'1 K'
'1.2 K'
>>> get_human_readable_count(2e6) # (two million)
'2 M'
'2.0 M'
>>> get_human_readable_count(3e9) # (three billion)
'3 B'
>>> get_human_readable_count(4e12) # (four trillion)
'4 T'
'3.0 B'
>>> get_human_readable_count(4e14) # (four hundred trillion)
'400 T'
>>> get_human_readable_count(5e15) # (more than trillion)
'5,000 T'
Expand All @@ -388,4 +388,7 @@ def get_human_readable_count(number: int) -> str:
shift = -3 * (num_groups - 1)
number = number * (10 ** shift)
index = num_groups - 1
return f"{int(number):,d} {labels[index]}"
if index < 1 or number >= 100:
return f"{int(number):,d} {labels[index]}"
else:
return f"{number:,.1f} {labels[index]}"

0 comments on commit 6ccbf66

Please sign in to comment.