Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move imshow closer to px pattern #1885

Merged
merged 3 commits into from
Nov 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/python/plotly/plotly/express/_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@
],
title=["str", "The figure title."],
template=[
"str or Plotly.py template object",
"or dict or plotly.graph_objects.layout.Template instance",
"The figure template name or definition.",
],
width=["int (default `None`)", "The figure width in pixels."],
Expand Down
75 changes: 66 additions & 9 deletions packages/python/plotly/plotly/express/_imshow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import plotly.graph_objs as go
from _plotly_utils.basevalidators import ColorscaleValidator
from ._core import apply_default_cascade
import numpy as np # is it fine to depend on np here?

_float_types = []
Expand Down Expand Up @@ -54,7 +56,19 @@ def _infer_zmax_from_type(img):
return 2 ** 32


def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
def imshow(
img,
zmin=None,
zmax=None,
origin=None,
color_continuous_scale=None,
color_continuous_midpoint=None,
range_color=None,
title=None,
template=None,
width=None,
height=None,
):
"""
Display an image, i.e. data on a 2D regular raster.

Expand All @@ -74,16 +88,38 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
zmin and zmax correspond to the min and max values of the datatype for integer
datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
a multichannel image of floats, the max of the image is computed and zmax is the
smallest power of 256 (1, 255, 65535) greater than this max value,
smallest power of 256 (1, 255, 65535) greater than this max value,
with a 5% tolerance. For a single-channel image, the max of the image is used.

origin : str, 'upper' or 'lower' (default 'upper')
position of the [0, 0] pixel of the image array, in the upper left or lower left
corner. The convention 'upper' is typically used for matrices and images.

colorscale : str
colormap used to map scalar data to colors (for a 2D image). This parameter is not used for
RGB or RGBA images.
color_continuous_scale : str or list of str
emmanuelle marked this conversation as resolved.
Show resolved Hide resolved
colormap used to map scalar data to colors (for a 2D image). This parameter is
not used for RGB or RGBA images. If a string is provided, it should be the name
of a known color scale, and if a list is provided, it should be a list of CSS-
compatible colors.

color_continuous_midpoint : number
If set, computes the bounds of the continuous color scale to have the desired
midpoint.

range_color : list of two numbers
If provided, overrides auto-scaling on the continuous color scale, including
overriding `color_continuous_midpoint`.

title : str
The figure title.

template : str or dict or plotly.graph_objects.layout.Template instance
The figure template name or definition.

width : number
The figure width in pixels.

height: number
The figure height in pixels, defaults to 600.

Returns
-------
Expand All @@ -101,21 +137,33 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
In order to update and customize the returned figure, use
`go.Figure.update_traces` or `go.Figure.update_layout`.
"""
args = locals()
apply_default_cascade(args)

img = np.asanyarray(img)
# Cast bools to uint8 (also one byte)
if img.dtype == np.bool:
img = 255 * img.astype(np.uint8)

# For 2d data, use Heatmap trace
if img.ndim == 2:
if colorscale is None:
colorscale = "gray"
trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, colorscale=colorscale)
trace = go.Heatmap(z=img, zmin=zmin, zmax=zmax, coloraxis="coloraxis1")
autorange = True if origin == "lower" else "reversed"
layout = dict(
xaxis=dict(scaleanchor="y", constrain="domain"),
yaxis=dict(autorange=autorange, constrain="domain"),
)
colorscale_validator = ColorscaleValidator("colorscale", "imshow")
range_color = range_color or [None, None]
layout["coloraxis1"] = dict(
colorscale=colorscale_validator.validate_coerce(
args["color_continuous_scale"]
),
cmid=color_continuous_midpoint,
cmin=range_color[0],
cmax=range_color[1],
)

# For 2D+RGB data, use Image trace
elif img.ndim == 3 and img.shape[-1] in [3, 4]:
if zmax is None and img.dtype is not np.uint8:
Expand All @@ -127,8 +175,17 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
layout["yaxis"] = dict(autorange=True)
else:
raise ValueError(
"px.imshow only accepts 2D grayscale, RGB or RGBA images. "
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
"An image of shape %s was provided" % str(img.shape)
)

layout_patch = dict()
for v in ["title", "height", "width"]:
if args[v]:
layout_patch[v] = args[v]
if "title" not in layout_patch and args["template"].layout.margin.t is None:
layout_patch["margin"] = {"t": 60}
fig = go.Figure(data=trace, layout=layout)
fig.update_layout(layout_patch)
fig.update_layout(template=args["template"], overwrite=True)
return fig
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ def test_origin():

def test_colorscale():
fig = px.imshow(img_gray)
assert fig.data[0].colorscale[0] == (0.0, "rgb(0, 0, 0)")
fig = px.imshow(img_gray, colorscale="Viridis")
assert fig.data[0].colorscale[0] == (0.0, "#440154")
plasma_first_color = px.colors.sequential.Plasma[0]
assert fig.layout.coloraxis1.colorscale[0] == (0.0, plasma_first_color)
fig = px.imshow(img_gray, color_continuous_scale="Viridis")
assert fig.layout.coloraxis1.colorscale[0] == (0.0, "#440154")


def test_wrong_dimensions():
Expand Down