From d4a46211c3efd4452ffdb00d2f793bbf8125a8d5 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Fri, 8 Nov 2019 14:45:08 -0500 Subject: [PATCH 1/3] use three px-standard colorscale kwargs --- .../python/plotly/plotly/express/_imshow.py | 42 +++++++++++++++---- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index e41efbeaa0..f9f4c556d3 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -1,4 +1,5 @@ import plotly.graph_objs as go +from _plotly_utils.basevalidators import ColorscaleValidator import numpy as np # is it fine to depend on np here? _float_types = [] @@ -54,7 +55,15 @@ def _infer_zmax_from_type(img): return 2 ** 32 -def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None): +def imshow( + img, + zmin=None, + zmax=None, + origin=None, + color_continuous_scale=None, + color_continuous_midpoint=None, + range_color=None, +): """ Display an image, i.e. data on a 2D regular raster. @@ -74,16 +83,24 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None): zmin and zmax correspond to the min and max values of the datatype for integer datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For a multichannel image of floats, the max of the image is computed and zmax is the - smallest power of 256 (1, 255, 65535) greater than this max value, + smallest power of 256 (1, 255, 65535) greater than this max value, with a 5% tolerance. For a single-channel image, the max of the image is used. origin : str, 'upper' or 'lower' (default 'upper') position of the [0, 0] pixel of the image array, in the upper left or lower left corner. The convention 'upper' is typically used for matrices and images. - colorscale : str - colormap used to map scalar data to colors (for a 2D image). This parameter is not used for - RGB or RGBA images. + color_continuous_scale : str or list of str + colormap used to map scalar data to colors (for a 2D image). This parameter is + not used for RGB or RGBA images. + + color_continuous_midpoint : number + If set, computes the bounds of the continuous color scale to have the desired + midpoint. + + range_color : list of two numbers + If provided, overrides auto-scaling on the continuous color scale, including + overriding `color_continuous_midpoint`. Returns ------- @@ -108,14 +125,21 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None): # For 2d data, use Heatmap trace if img.ndim == 2: - if colorscale is None: - colorscale = "gray" - trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale) + trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, coloraxis="coloraxis1") autorange = True if origin == "lower" else "reversed" layout = dict( xaxis=dict(scaleanchor="y", constrain="domain"), yaxis=dict(autorange=autorange, constrain="domain"), ) + colorscale_validator = ColorscaleValidator("colorscale", "imshow") + range_color = range_color or [None, None] + layout["coloraxis1"] = dict( + colorscale=colorscale_validator.validate_coerce(color_continuous_scale), + cmid=color_continuous_midpoint, + cmin=range_color[0], + cmax=range_color[1], + ) + # For 2D+RGB data, use Image trace elif img.ndim == 3 and img.shape[-1] in [3, 4]: if zmax is None and img.dtype is not np.uint8: @@ -127,7 +151,7 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None): layout["yaxis"] = dict(autorange=True) else: raise ValueError( - "px.imshow only accepts 2D grayscale, RGB or RGBA images. " + "px.imshow only accepts 2D single-channel, RGB or RGBA images. " "An image of shape %s was provided" % str(img.shape) ) fig = go.Figure(data=trace, layout=layout) From f367c7227e7fd9b12c04db8fed98144780bcec32 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Fri, 8 Nov 2019 14:53:11 -0500 Subject: [PATCH 2/3] imshow defaults cascade --- .../python/plotly/plotly/express/_imshow.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index f9f4c556d3..316fc69d6d 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -1,5 +1,6 @@ import plotly.graph_objs as go from _plotly_utils.basevalidators import ColorscaleValidator +from ._core import apply_default_cascade import numpy as np # is it fine to depend on np here? _float_types = [] @@ -63,6 +64,10 @@ def imshow( color_continuous_scale=None, color_continuous_midpoint=None, range_color=None, + title=None, + template=None, + width=None, + height=None, ): """ Display an image, i.e. data on a 2D regular raster. @@ -118,6 +123,9 @@ def imshow( In order to update and customize the returned figure, use `go.Figure.update_traces` or `go.Figure.update_layout`. """ + args = locals() + apply_default_cascade(args) + img = np.asanyarray(img) # Cast bools to uint8 (also one byte) if img.dtype == np.bool: @@ -134,7 +142,9 @@ def imshow( colorscale_validator = ColorscaleValidator("colorscale", "imshow") range_color = range_color or [None, None] layout["coloraxis1"] = dict( - colorscale=colorscale_validator.validate_coerce(color_continuous_scale), + colorscale=colorscale_validator.validate_coerce( + args["color_continuous_scale"] + ), cmid=color_continuous_midpoint, cmin=range_color[0], cmax=range_color[1], @@ -154,5 +164,14 @@ def imshow( "px.imshow only accepts 2D single-channel, RGB or RGBA images. " "An image of shape %s was provided" % str(img.shape) ) + + layout_patch = dict() + for v in ["title", "height", "width"]: + if args[v]: + layout_patch[v] = args[v] + if "title" not in layout_patch and args["template"].layout.margin.t is None: + layout_patch["margin"] = {"t": 60} fig = go.Figure(data=trace, layout=layout) + fig.update_layout(layout_patch) + fig.update_layout(template=args["template"], overwrite=True) return fig From c5842bc9deaa2d5934fc32348e0ecff975fbae82 Mon Sep 17 00:00:00 2001 From: Nicolas Kruchten Date: Fri, 8 Nov 2019 15:10:22 -0500 Subject: [PATCH 3/3] docstrings and tests --- packages/python/plotly/plotly/express/_doc.py | 2 +- packages/python/plotly/plotly/express/_imshow.py | 16 +++++++++++++++- .../tests/test_core/test_px/test_imshow.py | 7 ++++--- 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index e01ec87fe7..39a133ef22 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -283,7 +283,7 @@ ], title=["str", "The figure title."], template=[ - "str or Plotly.py template object", + "or dict or plotly.graph_objects.layout.Template instance", "The figure template name or definition.", ], width=["int (default `None`)", "The figure width in pixels."], diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index 316fc69d6d..81c8be3a77 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -97,7 +97,9 @@ def imshow( color_continuous_scale : str or list of str colormap used to map scalar data to colors (for a 2D image). This parameter is - not used for RGB or RGBA images. + not used for RGB or RGBA images. If a string is provided, it should be the name + of a known color scale, and if a list is provided, it should be a list of CSS- + compatible colors. color_continuous_midpoint : number If set, computes the bounds of the continuous color scale to have the desired @@ -107,6 +109,18 @@ def imshow( If provided, overrides auto-scaling on the continuous color scale, including overriding `color_continuous_midpoint`. + title : str + The figure title. + + template : str or dict or plotly.graph_objects.layout.Template instance + The figure template name or definition. + + width : number + The figure width in pixels. + + height: number + The figure height in pixels, defaults to 600. + Returns ------- fig : graph_objects.Figure containing the displayed image diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py index 8b6130b998..8bf1657c4e 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py @@ -50,9 +50,10 @@ def test_origin(): def test_colorscale(): fig = px.imshow(img_gray) - assert fig.data[0].colorscale[0] == (0.0, "rgb(0, 0, 0)") - fig = px.imshow(img_gray, colorscale="Viridis") - assert fig.data[0].colorscale[0] == (0.0, "#440154") + plasma_first_color = px.colors.sequential.Plasma[0] + assert fig.layout.coloraxis1.colorscale[0] == (0.0, plasma_first_color) + fig = px.imshow(img_gray, color_continuous_scale="Viridis") + assert fig.layout.coloraxis1.colorscale[0] == (0.0, "#440154") def test_wrong_dimensions():