diff --git a/torchinfo/enums.py b/torchinfo/enums.py index 6f403bb..5ee40d4 100644 --- a/torchinfo/enums.py +++ b/torchinfo/enums.py @@ -47,6 +47,7 @@ class Units(str, Enum): __slots__ = () AUTO = "auto" + KILOBYTES = "K" MEGABYTES = "M" GIGABYTES = "G" TERABYTES = "T" diff --git a/torchinfo/formatting.py b/torchinfo/formatting.py index 3b48405..3fc7fc2 100644 --- a/torchinfo/formatting.py +++ b/torchinfo/formatting.py @@ -21,6 +21,7 @@ Units.TERABYTES: 1e12, Units.GIGABYTES: 1e9, Units.MEGABYTES: 1e6, + Units.KILOBYTES: 1e3, Units.NONE: 1, } diff --git a/torchinfo/model_statistics.py b/torchinfo/model_statistics.py index 854bea8..40b58c2 100644 --- a/torchinfo/model_statistics.py +++ b/torchinfo/model_statistics.py @@ -75,18 +75,18 @@ def __repr__(self) -> str: macs = ModelStatistics.format_output_num( self.total_mult_adds, self.formatting.macs_units ) - input_size = self.to_megabytes(self.total_input) - output_bytes = self.to_megabytes(self.total_output_bytes) - param_bytes = self.to_megabytes(self.total_param_bytes) - total_bytes = self.to_megabytes( + input_size = self.to_readable(self.total_input) + output_bytes = self.to_readable(self.total_output_bytes) + param_bytes = self.to_readable(self.total_param_bytes) + total_bytes = self.to_readable( self.total_input + self.total_output_bytes + self.total_param_bytes ) summary_str += ( f"Total mult-adds{macs}\n{divider}\n" - f"Input size (MB): {input_size:0.2f}\n" - f"Forward/backward pass size (MB): {output_bytes:0.2f}\n" - f"Params size (MB): {param_bytes:0.2f}\n" - f"Estimated Total Size (MB): {total_bytes:0.2f}\n" + f"Input size ({input_size[0]}B): {input_size[1]:0.2f}\n" + f"Forward/backward pass size ({output_bytes[0]}B): {output_bytes[1]:0.2f}\n" + f"Params size ({param_bytes[0]}B): {param_bytes[1]:0.2f}\n" + f"Estimated Total Size ({total_bytes[0]}B): {total_bytes[1]:0.2f}\n" ) summary_str += divider return summary_str @@ -109,7 +109,9 @@ def to_readable(num: int, units: Units = Units.AUTO) -> tuple[Units, float]: return Units.TERABYTES, num / 1e12 if num >= 1e9: return Units.GIGABYTES, num / 1e9 - return Units.MEGABYTES, num / 1e6 + if num >= 1e6: + return Units.MEGABYTES, num / 1e6 + return Units.KILOBYTES, num / 1e3 return units, num / CONVERSION_FACTORS[units] @staticmethod