Skip to content

Commit

Permalink
Model summary: add 1 decimal place
Browse files Browse the repository at this point in the history
Show 1999 parameters as 1.9 K and 1000 parameters as 1.0 K, rather than both as 1 K.
  • Loading branch information
jonashaag authored Nov 18, 2020
1 parent e7134a9 commit ba068bb
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions pytorch_lightning/core/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,15 +363,15 @@ def get_human_readable_count(number: int) -> str:
>>> get_human_readable_count(123)
'123 '
>>> get_human_readable_count(1234) # (one thousand)
'1 K'
'1.2 K'
>>> get_human_readable_count(2e6) # (two million)
'2 M'
'2.0 M'
>>> get_human_readable_count(3e9) # (three billion)
'3 B'
'3.0 B'
>>> get_human_readable_count(4e12) # (four trillion)
'4 T'
'4.0 T'
>>> get_human_readable_count(5e15) # (more than trillion)
'5,000 T'
'5,000.0 T'
Args:
number: a positive integer number
Expand All @@ -382,10 +382,12 @@ def get_human_readable_count(number: int) -> str:
"""
assert number >= 0
labels = PARAMETER_NUM_UNITS
if number < 1e3:
return f"{number:d} {labels[0]}"
num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
num_groups = int(np.ceil(num_digits / 3))
num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions
shift = -3 * (num_groups - 1)
number = number * (10 ** shift)
index = num_groups - 1
return f"{int(number):,d} {labels[index]}"
return f"{number:,.1f} {labels[index]}"

0 comments on commit ba068bb

Please sign in to comment.