From 6ccbf66cc3784c876e8da36fa80abb2e95f7f12d Mon Sep 17 00:00:00 2001 From: Jonas Haag Date: Fri, 20 Nov 2020 23:22:21 +0100 Subject: [PATCH] Model summary: add 1 decimal place (#4745) Show 1999 parameters as 1.9 K and 1000 parameters as 1.0 K, rather than both as 1 K. Co-authored-by: chaton Co-authored-by: Jirka Borovec Co-authored-by: Sean Naren (cherry picked from commit 8dfbf6371bd552bed3ee572c6db5f378e28c9cd5) --- docs/source/debugging.rst | 2 +- pytorch_lightning/core/memory.py | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst index 845f86a52b231..fea230d67d016 100644 --- a/docs/source/debugging.rst +++ b/docs/source/debugging.rst @@ -102,7 +102,7 @@ You can also display the intermediate input- and output sizes of all your layers -------------------------------------------------------------- 0 | net | Sequential | 132 K | [10, 256] | [10, 512] 1 | net.0 | Linear | 131 K | [10, 256] | [10, 512] - 2 | net.1 | BatchNorm1d | 1 K | [10, 512] | [10, 512] + 2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512] when you call ``.fit()`` on the Trainer. This can help you find bugs in the composition of your layers. diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index fe7d0c37ae6e0..58a556e0e6727 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -160,7 +160,7 @@ class ModelSummary(object): -------------------------------------------------------------- 0 | net | Sequential | 132 K | [10, 256] | [10, 512] 1 | net.0 | Linear | 131 K | [10, 256] | [10, 512] - 2 | net.1 | BatchNorm1d | 1 K | [10, 512] | [10, 512] + 2 | net.1 | BatchNorm1d | 1.0 K | [10, 512] | [10, 512] """ MODE_TOP = "top" @@ -363,13 +363,13 @@ def get_human_readable_count(number: int) -> str: >>> get_human_readable_count(123) '123 ' >>> get_human_readable_count(1234) # (one thousand) - '1 K' + '1.2 K' >>> get_human_readable_count(2e6) # (two million) - '2 M' + '2.0 M' >>> get_human_readable_count(3e9) # (three billion) - '3 B' - >>> get_human_readable_count(4e12) # (four trillion) - '4 T' + '3.0 B' + >>> get_human_readable_count(4e14) # (four hundred trillion) + '400 T' >>> get_human_readable_count(5e15) # (more than trillion) '5,000 T' @@ -388,4 +388,7 @@ def get_human_readable_count(number: int) -> str: shift = -3 * (num_groups - 1) number = number * (10 ** shift) index = num_groups - 1 - return f"{int(number):,d} {labels[index]}" + if index < 1 or number >= 100: + return f"{int(number):,d} {labels[index]}" + else: + return f"{number:,.1f} {labels[index]}"