Skip to content

Use x and y parameters for Image trace in imshow (for RGB or binary_string=True) #2761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 54 additions & 11 deletions packages/python/plotly/plotly/express/_imshow.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,23 +204,19 @@ def imshow(
args = locals()
apply_default_cascade(args)
labels = labels.copy()
img_is_xarray = False
# ----- Define x and y, set labels if img is an xarray -------------------
if xarray_imported and isinstance(img, xarray.DataArray):
if binary_string:
raise ValueError(
"It is not possible to use binary image strings for xarrays."
"Please pass your data as a numpy array instead using"
"`img.values`"
)
img_is_xarray = True
y_label, x_label = img.dims[0], img.dims[1]
# np.datetime64 is not handled correctly by go.Heatmap
for ax in [x_label, y_label]:
if np.issubdtype(img.coords[ax].dtype, np.datetime64):
img.coords[ax] = img.coords[ax].astype(str)
if x is None:
x = img.coords[x_label]
x = img.coords[x_label].values
if y is None:
y = img.coords[y_label]
y = img.coords[y_label].values
if aspect is None:
aspect = "auto"
if labels.get("x", None) is None:
Expand Down Expand Up @@ -330,6 +326,42 @@ def imshow(
_vectorize_zvalue(zmin, mode="min"),
_vectorize_zvalue(zmax, mode="max"),
)
x0, y0, dx, dy = (None,) * 4
error_msg_xarray = (
"Non-numerical coordinates were passed with xarray `img`, but "
"the Image trace cannot handle it. Please use `binary_string=False` "
"for 2D data or pass instead the numpy array `img.values` to `px.imshow`."
)
if x is not None:
x = np.asanyarray(x)
if np.issubdtype(x.dtype, np.number):
x0 = x[0]
dx = x[1] - x[0]
else:
error_msg = (
error_msg_xarray
if img_is_xarray
else (
"Only numerical values are accepted for the `x` parameter "
"when an Image trace is used."
)
)
raise ValueError(error_msg)
if y is not None:
y = np.asanyarray(y)
if np.issubdtype(y.dtype, np.number):
y0 = y[0]
dy = y[1] - y[0]
else:
error_msg = (
error_msg_xarray
if img_is_xarray
else (
"Only numerical values are accepted for the `y` parameter "
"when an Image trace is used."
)
)
raise ValueError(error_msg)
if binary_string:
if zmin is None and zmax is None: # no rescaling, faster
img_rescaled = img
Expand All @@ -355,13 +387,24 @@ def imshow(
compression=binary_compression_level,
ext=binary_format,
)
trace = go.Image(source=img_str)
trace = go.Image(source=img_str, x0=x0, y0=y0, dx=dx, dy=dy)
else:
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)
trace = go.Image(
z=img,
zmin=zmin,
zmax=zmax,
colormodel=colormodel,
x0=x0,
y0=y0,
dx=dx,
dy=dy,
)
layout = {}
if origin == "lower":
if origin == "lower" or (dy is not None and dy < 0):
layout["yaxis"] = dict(autorange=True)
if dx is not None and dx < 0:
layout["xaxis"] = dict(autorange="reversed")
else:
raise ValueError(
"px.imshow only accepts 2D single-channel, RGB or RGBA images. "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from PIL import Image
from io import BytesIO
import base64
import datetime
from plotly.express.imshow_utils import rescale_intensity

img_rgb = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]], dtype=np.uint8)
Expand Down Expand Up @@ -204,6 +205,37 @@ def test_imshow_labels_and_ranges():
with pytest.raises(ValueError):
fig = px.imshow([[1, 2], [3, 4], [5, 6]], x=["a"])

img = np.ones((2, 2), dtype=np.uint8)
fig = px.imshow(img, x=["a", "b"])
assert fig.data[0].x == ("a", "b")

with pytest.raises(ValueError):
img = np.ones((2, 2, 3), dtype=np.uint8)
fig = px.imshow(img, x=["a", "b"])

img = np.ones((2, 2), dtype=np.uint8)
base = datetime.datetime(2000, 1, 1)
fig = px.imshow(img, x=[base, base + datetime.timedelta(hours=1)])
assert fig.data[0].x == (
datetime.datetime(2000, 1, 1, 0, 0),
datetime.datetime(2000, 1, 1, 1, 0),
)

with pytest.raises(ValueError):
img = np.ones((2, 2, 3), dtype=np.uint8)
base = datetime.datetime(2000, 1, 1)
fig = px.imshow(img, x=[base, base + datetime.timedelta(hours=1)])


def test_imshow_ranges_image_trace():
fig = px.imshow(img_rgb, x=[1, 11, 21])
assert fig.data[0].dx == 10
assert fig.data[0].x0 == 1
fig = px.imshow(img_rgb, x=[21, 11, 1])
assert fig.data[0].dx == -10
assert fig.data[0].x0 == 21
assert fig.layout.xaxis.autorange == "reversed"


def test_imshow_dataframe():
df = px.data.medals_wide(indexed=False)
Expand Down