diff --git a/plotly/tools.py b/plotly/tools.py index 0f71d950e85..e8fe75443d3 100644 --- a/plotly/tools.py +++ b/plotly/tools.py @@ -1464,11 +1464,10 @@ def _find_intermediate_color(lowcolor, highcolor, intermed): diff_1 = float(highcolor[1] - lowcolor[1]) diff_2 = float(highcolor[2] - lowcolor[2]) - new_tuple = (lowcolor[0] + intermed*diff_0, - lowcolor[1] + intermed*diff_1, - lowcolor[2] + intermed*diff_2) - - return new_tuple + inter_colors = np.array([lowcolor[0] + intermed * diff_0, + lowcolor[1] + intermed * diff_1, + lowcolor[2] + intermed * diff_2]) + return inter_colors @staticmethod def _unconvert_from_RGB_255(colors): @@ -1491,7 +1490,7 @@ def _unconvert_from_RGB_255(colors): return un_rgb_colors @staticmethod - def _map_z2color(zval, colormap, vmin, vmax): + def _map_z2color(zvals, colormap, vmin, vmax): """ Returns the color corresponding zval's place between vmin and vmax @@ -1508,21 +1507,14 @@ def _map_z2color(zval, colormap, vmin, vmax): "of vmax.") # find distance t of zval from vmin to vmax where the distance # is normalized to be between 0 and 1 - t = (zval - vmin)/float((vmax - vmin)) - t_color = FigureFactory._find_intermediate_color(colormap[0], - colormap[1], - t) - t_color = (t_color[0]*255.0, t_color[1]*255.0, t_color[2]*255.0) - labelled_color = 'rgb{}'.format(t_color) - - return labelled_color - - @staticmethod - def _tri_indices(simplices): - """ - Returns a triplet of lists containing simplex coordinates - """ - return ([triplet[c] for triplet in simplices] for c in range(3)) + t = (zvals - vmin) / float((vmax - vmin)) + t_colors = FigureFactory._find_intermediate_color(colormap[0], + colormap[1], + t) + t_colors = t_colors * 255. + labelled_colors = ['rgb(%s, %s, %s)' % (i, j, k) + for i, j, k in t_colors.T] + return labelled_colors @staticmethod def _trisurf(x, y, z, simplices, colormap=None, dist_func=None, @@ -1539,11 +1531,11 @@ def _trisurf(x, y, z, simplices, colormap=None, dist_func=None, points3D = np.vstack((x, y, z)).T # vertices of the surface triangles - tri_vertices = list(map(lambda index: points3D[index], simplices)) + tri_vertices = points3D[simplices] if not dist_func: # mean values of z-coordinates of triangle vertices - mean_dists = [np.mean(tri[:, 2]) for tri in tri_vertices] + mean_dists = tri_vertices[:, :, 2].mean(-1) else: # apply user inputted function to calculate # custom coloring for triangle vertices @@ -1559,38 +1551,43 @@ def _trisurf(x, y, z, simplices, colormap=None, dist_func=None, min_mean_dists = np.min(mean_dists) max_mean_dists = np.max(mean_dists) - facecolor = ([FigureFactory._map_z2color(zz, colormap, min_mean_dists, - max_mean_dists) for zz in mean_dists]) - ii, jj, kk = FigureFactory._tri_indices(simplices) + facecolor = FigureFactory._map_z2color(mean_dists, colormap, + min_mean_dists, max_mean_dists) + ii, jj, kk = zip(*simplices) triangles = graph_objs.Mesh3d(x=x, y=y, z=z, facecolor=facecolor, i=ii, j=jj, k=kk, name='') - if plot_edges is None: # the triangle sides are not plotted + if plot_edges is not True: # the triangle sides are not plotted return graph_objs.Data([triangles]) # define the lists x_edge, y_edge and z_edge, of x, y, resp z # coordinates of edge end points for each triangle # None separates data corresponding to two consecutive triangles - lists_coord = ([[[T[k % 3][c] for k in range(4)]+[None] - for T in tri_vertices] for c in range(3)]) - if x_edge is None: - x_edge = [] - for array in lists_coord[0]: - for item in array: - x_edge.append(item) - - if y_edge is None: - y_edge = [] - for array in lists_coord[1]: - for item in array: - y_edge.append(item) - - if z_edge is None: - z_edge = [] - for array in lists_coord[2]: - for item in array: - z_edge.append(item) + is_none = [ii is None for ii in [x_edge, y_edge, z_edge]] + if any(is_none): + if not all(is_none): + raise ValueError('If any (x_edge, y_edge, z_edge) is None,' + ' all must be None') + else: + x_edge = [] + y_edge = [] + z_edge = [] + + # Pull indices we care about, then add a None column to separate tris + ixs_triangles = [0, 1, 2, 0] + pull_edges = tri_vertices[:, ixs_triangles, :] + x_edge_pull = np.hstack([pull_edges[:, :, 0], + np.tile(None, [pull_edges.shape[0], 1])]) + y_edge_pull = np.hstack([pull_edges[:, :, 1], + np.tile(None, [pull_edges.shape[0], 1])]) + z_edge_pull = np.hstack([pull_edges[:, :, 2], + np.tile(None, [pull_edges.shape[0], 1])]) + + # Now unravel the edges into a 1-d vector for plotting + x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]]) + y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]]) + z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]]) # define the lines for plotting lines = graph_objs.Scatter3d( @@ -5621,4 +5618,3 @@ def make_table_annotations(self): font=dict(color=font_color), showarrow=False)) return annotations -