Skip to content

Commit 60b0f69

Browse files
Merge pull request #1885 from plotly/imshow_px
Move imshow closer to px pattern
2 parents 16ae853 + c5842bc commit 60b0f69

File tree

3 files changed

+71
-13
lines changed

3 files changed

+71
-13
lines changed

Diff for: packages/python/plotly/plotly/express/_doc.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@
283283
],
284284
title=["str", "The figure title."],
285285
template=[
286-
"str or Plotly.py template object",
286+
"or dict or plotly.graph_objects.layout.Template instance",
287287
"The figure template name or definition.",
288288
],
289289
width=["int (default `None`)", "The figure width in pixels."],

Diff for: packages/python/plotly/plotly/express/_imshow.py

+66-9
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import plotly.graph_objs as go
2+
from _plotly_utils.basevalidators import ColorscaleValidator
3+
from ._core import apply_default_cascade
24
import numpy as np # is it fine to depend on np here?
35

46
_float_types = []
@@ -54,7 +56,19 @@ def _infer_zmax_from_type(img):
5456
return 2 ** 32
5557

5658

57-
def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
59+
def imshow(
60+
img,
61+
zmin=None,
62+
zmax=None,
63+
origin=None,
64+
color_continuous_scale=None,
65+
color_continuous_midpoint=None,
66+
range_color=None,
67+
title=None,
68+
template=None,
69+
width=None,
70+
height=None,
71+
):
5872
"""
5973
Display an image, i.e. data on a 2D regular raster.
6074
@@ -74,16 +88,38 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
7488
zmin and zmax correspond to the min and max values of the datatype for integer
7589
datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
7690
a multichannel image of floats, the max of the image is computed and zmax is the
77-
smallest power of 256 (1, 255, 65535) greater than this max value,
91+
smallest power of 256 (1, 255, 65535) greater than this max value,
7892
with a 5% tolerance. For a single-channel image, the max of the image is used.
7993
8094
origin : str, 'upper' or 'lower' (default 'upper')
8195
position of the [0, 0] pixel of the image array, in the upper left or lower left
8296
corner. The convention 'upper' is typically used for matrices and images.
8397
84-
colorscale : str
85-
colormap used to map scalar data to colors (for a 2D image). This parameter is not used for
86-
RGB or RGBA images.
98+
color_continuous_scale : str or list of str
99+
colormap used to map scalar data to colors (for a 2D image). This parameter is
100+
not used for RGB or RGBA images. If a string is provided, it should be the name
101+
of a known color scale, and if a list is provided, it should be a list of CSS-
102+
compatible colors.
103+
104+
color_continuous_midpoint : number
105+
If set, computes the bounds of the continuous color scale to have the desired
106+
midpoint.
107+
108+
range_color : list of two numbers
109+
If provided, overrides auto-scaling on the continuous color scale, including
110+
overriding `color_continuous_midpoint`.
111+
112+
title : str
113+
The figure title.
114+
115+
template : str or dict or plotly.graph_objects.layout.Template instance
116+
The figure template name or definition.
117+
118+
width : number
119+
The figure width in pixels.
120+
121+
height: number
122+
The figure height in pixels, defaults to 600.
87123
88124
Returns
89125
-------
@@ -101,21 +137,33 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
101137
In order to update and customize the returned figure, use
102138
`go.Figure.update_traces` or `go.Figure.update_layout`.
103139
"""
140+
args = locals()
141+
apply_default_cascade(args)
142+
104143
img = np.asanyarray(img)
105144
# Cast bools to uint8 (also one byte)
106145
if img.dtype == np.bool:
107146
img = 255 * img.astype(np.uint8)
108147

109148
# For 2d data, use Heatmap trace
110149
if img.ndim == 2:
111-
if colorscale is None:
112-
colorscale = "gray"
113-
trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale)
150+
trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, coloraxis="coloraxis1")
114151
autorange = True if origin == "lower" else "reversed"
115152
layout = dict(
116153
xaxis=dict(scaleanchor="y", constrain="domain"),
117154
yaxis=dict(autorange=autorange, constrain="domain"),
118155
)
156+
colorscale_validator = ColorscaleValidator("colorscale", "imshow")
157+
range_color = range_color or [None, None]
158+
layout["coloraxis1"] = dict(
159+
colorscale=colorscale_validator.validate_coerce(
160+
args["color_continuous_scale"]
161+
),
162+
cmid=color_continuous_midpoint,
163+
cmin=range_color[0],
164+
cmax=range_color[1],
165+
)
166+
119167
# For 2D+RGB data, use Image trace
120168
elif img.ndim == 3 and img.shape[-1] in [3, 4]:
121169
if zmax is None and img.dtype is not np.uint8:
@@ -127,8 +175,17 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
127175
layout["yaxis"] = dict(autorange=True)
128176
else:
129177
raise ValueError(
130-
"px.imshow only accepts 2D grayscale, RGB or RGBA images. "
178+
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
131179
"An image of shape %s was provided" % str(img.shape)
132180
)
181+
182+
layout_patch = dict()
183+
for v in ["title", "height", "width"]:
184+
if args[v]:
185+
layout_patch[v] = args[v]
186+
if "title" not in layout_patch and args["template"].layout.margin.t is None:
187+
layout_patch["margin"] = {"t": 60}
133188
fig = go.Figure(data=trace, layout=layout)
189+
fig.update_layout(layout_patch)
190+
fig.update_layout(template=args["template"], overwrite=True)
134191
return fig

Diff for: packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,10 @@ def test_origin():
5050

5151
def test_colorscale():
5252
fig = px.imshow(img_gray)
53-
assert fig.data[0].colorscale[0] == (0.0, "rgb(0, 0, 0)")
54-
fig = px.imshow(img_gray, colorscale="Viridis")
55-
assert fig.data[0].colorscale[0] == (0.0, "#440154")
53+
plasma_first_color = px.colors.sequential.Plasma[0]
54+
assert fig.layout.coloraxis1.colorscale[0] == (0.0, plasma_first_color)
55+
fig = px.imshow(img_gray, color_continuous_scale="Viridis")
56+
assert fig.layout.coloraxis1.colorscale[0] == (0.0, "#440154")
5657

5758

5859
def test_wrong_dimensions():

0 commit comments

Comments
 (0)