diff --git a/orangewidget/tests/test_matplotlib_export.py b/orangewidget/tests/test_matplotlib_export.py index dc7b27e4b..3a0f073e7 100644 --- a/orangewidget/tests/test_matplotlib_export.py +++ b/orangewidget/tests/test_matplotlib_export.py @@ -1,7 +1,10 @@ +import numpy as np import pyqtgraph as pg from orangewidget.tests.base import GuiTest -from orangewidget.utils.matplotlib_export import scatterplot_code +from orangewidget.utils.matplotlib_export import ( + scatterplot_code, numpy_repr, compress_if_all_same, numpy_repr_int +) def add_intro(a): @@ -15,8 +18,24 @@ class TestScatterPlot(GuiTest): def test_scatterplot_simple(self): plotWidget = pg.PlotWidget(background="w") scatterplot = pg.ScatterPlotItem() - scatterplot.setData(x=[1, 2, 3], y=[3, 2, 1]) + scatterplot.setData( + x=np.array([1., 2, 3]), + y=np.array([3., 2, 1]), + size=np.array([1., 1, 1]) + ) plotWidget.addItem(scatterplot) code = scatterplot_code(scatterplot) self.assertIn("plt.scatter", code) exec(add_intro(code), {}) + + def test_utils(self): + a = np.array([1.5, 2.5]) + self.assertIn("1.5, 2.5", numpy_repr(a)) + a = np.array([1, 1]) + v = compress_if_all_same(a) + self.assertEqual(v, 1) + self.assertEqual(repr(v), "1") + self.assertIs(type(v), int) + a = np.array([1, 2], dtype=int) + v = numpy_repr_int(a) + self.assertIn("1, 2", v) diff --git a/orangewidget/utils/matplotlib_export.py b/orangewidget/utils/matplotlib_export.py index f82a07feb..1bccf470a 100644 --- a/orangewidget/utils/matplotlib_export.py +++ b/orangewidget/utils/matplotlib_export.py @@ -14,7 +14,7 @@ def numpy_repr(a): # avoid numpy repr as it changes between versions # TODO handle numpy repr differences if isinstance(a, np.ndarray): - return "array(" + repr(list(a)) + ")" + return "array(" + repr(a.tolist()) + ")" try: np.set_printoptions(threshold=10**10) return repr(a) @@ -25,12 +25,20 @@ def numpy_repr(a): def numpy_repr_int(a): # avoid numpy repr as it changes between versions # TODO handle numpy repr differences - return "array(" + repr(list(a)) + ", dtype='int')" + if isinstance(a, np.ndarray): + a = a.tolist() + else: + a = list(a) + return "array(" + repr(a) + ", dtype='int')" def compress_if_all_same(l): s = set(l) - return s.pop() if len(s) == 1 else l + if len(s) == 1: + v = s.pop() + return v.item() if isinstance(v, np.generic) else v + else: + return l def is_sequence_not_string(a): @@ -188,6 +196,7 @@ def scene_code(scene): code = [] code.append("import matplotlib.pyplot as plt") + code.append("import numpy as np") code.append("from numpy import array") code.append("")