1
1
import plotly .graph_objs as go
2
+ from _plotly_utils .basevalidators import ColorscaleValidator
3
+ from ._core import apply_default_cascade
2
4
import numpy as np # is it fine to depend on np here?
3
5
4
6
_float_types = []
@@ -54,7 +56,19 @@ def _infer_zmax_from_type(img):
54
56
return 2 ** 32
55
57
56
58
57
- def imshow (img , zmin = None , zmax = None , origin = None , colorscale = None ):
59
+ def imshow (
60
+ img ,
61
+ zmin = None ,
62
+ zmax = None ,
63
+ origin = None ,
64
+ color_continuous_scale = None ,
65
+ color_continuous_midpoint = None ,
66
+ range_color = None ,
67
+ title = None ,
68
+ template = None ,
69
+ width = None ,
70
+ height = None ,
71
+ ):
58
72
"""
59
73
Display an image, i.e. data on a 2D regular raster.
60
74
@@ -74,16 +88,38 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
74
88
zmin and zmax correspond to the min and max values of the datatype for integer
75
89
datatypes (ie [0-255] for uint8 images, [0, 65535] for uint16 images, etc.). For
76
90
a multichannel image of floats, the max of the image is computed and zmax is the
77
- smallest power of 256 (1, 255, 65535) greater than this max value,
91
+ smallest power of 256 (1, 255, 65535) greater than this max value,
78
92
with a 5% tolerance. For a single-channel image, the max of the image is used.
79
93
80
94
origin : str, 'upper' or 'lower' (default 'upper')
81
95
position of the [0, 0] pixel of the image array, in the upper left or lower left
82
96
corner. The convention 'upper' is typically used for matrices and images.
83
97
84
- colorscale : str
85
- colormap used to map scalar data to colors (for a 2D image). This parameter is not used for
86
- RGB or RGBA images.
98
+ color_continuous_scale : str or list of str
99
+ colormap used to map scalar data to colors (for a 2D image). This parameter is
100
+ not used for RGB or RGBA images. If a string is provided, it should be the name
101
+ of a known color scale, and if a list is provided, it should be a list of CSS-
102
+ compatible colors.
103
+
104
+ color_continuous_midpoint : number
105
+ If set, computes the bounds of the continuous color scale to have the desired
106
+ midpoint.
107
+
108
+ range_color : list of two numbers
109
+ If provided, overrides auto-scaling on the continuous color scale, including
110
+ overriding `color_continuous_midpoint`.
111
+
112
+ title : str
113
+ The figure title.
114
+
115
+ template : str or dict or plotly.graph_objects.layout.Template instance
116
+ The figure template name or definition.
117
+
118
+ width : number
119
+ The figure width in pixels.
120
+
121
+ height: number
122
+ The figure height in pixels, defaults to 600.
87
123
88
124
Returns
89
125
-------
@@ -101,21 +137,33 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
101
137
In order to update and customize the returned figure, use
102
138
`go.Figure.update_traces` or `go.Figure.update_layout`.
103
139
"""
140
+ args = locals ()
141
+ apply_default_cascade (args )
142
+
104
143
img = np .asanyarray (img )
105
144
# Cast bools to uint8 (also one byte)
106
145
if img .dtype == np .bool :
107
146
img = 255 * img .astype (np .uint8 )
108
147
109
148
# For 2d data, use Heatmap trace
110
149
if img .ndim == 2 :
111
- if colorscale is None :
112
- colorscale = "gray"
113
- trace = go .Heatmap (z = img , zmin = zmin , zmax = zmax , colorscale = colorscale )
150
+ trace = go .Heatmap (z = img , zmin = zmin , zmax = zmax , coloraxis = "coloraxis1" )
114
151
autorange = True if origin == "lower" else "reversed"
115
152
layout = dict (
116
153
xaxis = dict (scaleanchor = "y" , constrain = "domain" ),
117
154
yaxis = dict (autorange = autorange , constrain = "domain" ),
118
155
)
156
+ colorscale_validator = ColorscaleValidator ("colorscale" , "imshow" )
157
+ range_color = range_color or [None , None ]
158
+ layout ["coloraxis1" ] = dict (
159
+ colorscale = colorscale_validator .validate_coerce (
160
+ args ["color_continuous_scale" ]
161
+ ),
162
+ cmid = color_continuous_midpoint ,
163
+ cmin = range_color [0 ],
164
+ cmax = range_color [1 ],
165
+ )
166
+
119
167
# For 2D+RGB data, use Image trace
120
168
elif img .ndim == 3 and img .shape [- 1 ] in [3 , 4 ]:
121
169
if zmax is None and img .dtype is not np .uint8 :
@@ -127,8 +175,17 @@ def imshow(img, zmin=None, zmax=None, origin=None, colorscale=None):
127
175
layout ["yaxis" ] = dict (autorange = True )
128
176
else :
129
177
raise ValueError (
130
- "px.imshow only accepts 2D grayscale , RGB or RGBA images. "
178
+ "px.imshow only accepts 2D single-channel , RGB or RGBA images. "
131
179
"An image of shape %s was provided" % str (img .shape )
132
180
)
181
+
182
+ layout_patch = dict ()
183
+ for v in ["title" , "height" , "width" ]:
184
+ if args [v ]:
185
+ layout_patch [v ] = args [v ]
186
+ if "title" not in layout_patch and args ["template" ].layout .margin .t is None :
187
+ layout_patch ["margin" ] = {"t" : 60 }
133
188
fig = go .Figure (data = trace , layout = layout )
189
+ fig .update_layout (layout_patch )
190
+ fig .update_layout (template = args ["template" ], overwrite = True )
134
191
return fig
0 commit comments