From 34b8d58f0b5b12bc0237ff3035f7e1e12f7d3e8c Mon Sep 17 00:00:00 2001
From: Fabian Degen <106864199+degenfabian@users.noreply.github.com>
Date: Thu, 24 Oct 2024 10:43:33 +0200
Subject: [PATCH] BUG: Fix index error when values exceed 999T (#581)

* fix IndexError when values exceed 999T in _human_format function

Signed-off-by: Fabian Degen

* add testcase to test _human_format function with numbers larger 999T

Signed-off-by: Fabian Degen

* fix ruff format

Signed-off-by: Fabian Degen

---------

Signed-off-by: Fabian Degen
Co-authored-by: Fabian Degen
---
 python/interpret-core/interpret/visual/plot.py  | 10 +++++++---
 python/interpret-core/tests/visual/test_plot.py | 17 +++++++++++++++++
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/python/interpret-core/interpret/visual/plot.py b/python/interpret-core/interpret/visual/plot.py
index d6bbb6a02..077f483f0 100644
--- a/python/interpret-core/interpret/visual/plot.py
+++ b/python/interpret-core/interpret/visual/plot.py
@@ -203,15 +203,19 @@ def extend_y_range(y):
 
 # Taken from:
 # https://stackoverflow.com/questions/579310/formatting-long-numbers-as-strings-in-python
+# Adapted to handle numbers larger than 1e15
 def _human_format(num):
     num = float(f"{num:.3g}")
     magnitude = 0
+    suffixes = ["", "K", "M", "B", "T"]
+
+    if abs(num) >= 1000 ** (len(suffixes)):  # 1000 ^ 5 == 1e15
+        return f"{num:.2e}"
+
     while abs(num) >= 1000:
         magnitude += 1
         num /= 1000.0
-    return "{}{}".format(
-        f"{num:f}".rstrip("0").rstrip("."), ["", "K", "M", "B", "T"][magnitude]
-    )
+    return "{}{}".format(f"{num:f}".rstrip("0").rstrip("."), suffixes[magnitude])
 
 
 # TODO: Clean this up after validation.
diff --git a/python/interpret-core/tests/visual/test_plot.py b/python/interpret-core/tests/visual/test_plot.py
index b115a2015..dd30902ef 100644
--- a/python/interpret-core/tests/visual/test_plot.py
+++ b/python/interpret-core/tests/visual/test_plot.py
@@ -2,6 +2,7 @@
 # Distributed under the MIT software license
 
 from interpret.visual.plot import plot_line
+from interpret.visual.plot import plot_density
 
 
 def test_plot_line_bounds_smoke():
@@ -13,3 +14,19 @@ def test_plot_line_bounds_smoke():
     }
     figure = plot_line(data_dict)
     assert figure.data[0].name == "Lower Bound"
+
+
+def test_plot_density_large_numbers():
+    """
+    Test that density plots handle large numbers correctly using the new number formatting
+    """
+    data_dict = {
+        "scores": [1.0, 1.0],
+        "names": [9e13, 1e14, 1e15],  # 1e15 value will trigger new formatting
+    }
+
+    figure = plot_density(data_dict)
+
+    # The x-axis tick text should show ranges using our new formatting
+    assert "90T - 100T" in figure.layout.xaxis.ticktext[0]
+    assert "100T - 1.00e+15" in figure.layout.xaxis.ticktext[1]