Skip to content

vectorizing some trisurf functions for performance improvement #472

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 43 additions & 47 deletions plotly/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,11 +1464,10 @@ def _find_intermediate_color(lowcolor, highcolor, intermed):
diff_1 = float(highcolor[1] - lowcolor[1])
diff_2 = float(highcolor[2] - lowcolor[2])

new_tuple = (lowcolor[0] + intermed*diff_0,
lowcolor[1] + intermed*diff_1,
lowcolor[2] + intermed*diff_2)

return new_tuple
inter_colors = np.array([lowcolor[0] + intermed * diff_0,
lowcolor[1] + intermed * diff_1,
lowcolor[2] + intermed * diff_2])
return inter_colors

@staticmethod
def _unconvert_from_RGB_255(colors):
Expand All @@ -1491,7 +1490,7 @@ def _unconvert_from_RGB_255(colors):
return un_rgb_colors

@staticmethod
def _map_z2color(zval, colormap, vmin, vmax):
def _map_z2color(zvals, colormap, vmin, vmax):
"""
Returns the color corresponding zval's place between vmin and vmax

Expand All @@ -1508,21 +1507,14 @@ def _map_z2color(zval, colormap, vmin, vmax):
"of vmax.")
# find distance t of zval from vmin to vmax where the distance
# is normalized to be between 0 and 1
t = (zval - vmin)/float((vmax - vmin))
t_color = FigureFactory._find_intermediate_color(colormap[0],
colormap[1],
t)
t_color = (t_color[0]*255.0, t_color[1]*255.0, t_color[2]*255.0)
labelled_color = 'rgb{}'.format(t_color)

return labelled_color

@staticmethod
def _tri_indices(simplices):
"""
Returns a triplet of lists containing simplex coordinates
"""
return ([triplet[c] for triplet in simplices] for c in range(3))
t = (zvals - vmin) / float((vmax - vmin))
t_colors = FigureFactory._find_intermediate_color(colormap[0],
colormap[1],
t)
t_colors = t_colors * 255.
labelled_colors = ['rgb(%s, %s, %s)' % (i, j, k)
for i, j, k in t_colors.T]
return labelled_colors

@staticmethod
def _trisurf(x, y, z, simplices, colormap=None, dist_func=None,
Expand All @@ -1539,11 +1531,11 @@ def _trisurf(x, y, z, simplices, colormap=None, dist_func=None,
points3D = np.vstack((x, y, z)).T

# vertices of the surface triangles
tri_vertices = list(map(lambda index: points3D[index], simplices))
tri_vertices = points3D[simplices]

if not dist_func:
# mean values of z-coordinates of triangle vertices
mean_dists = [np.mean(tri[:, 2]) for tri in tri_vertices]
mean_dists = tri_vertices[:, :, 2].mean(-1)
else:
# apply user inputted function to calculate
# custom coloring for triangle vertices
Expand All @@ -1559,38 +1551,43 @@ def _trisurf(x, y, z, simplices, colormap=None, dist_func=None,

min_mean_dists = np.min(mean_dists)
max_mean_dists = np.max(mean_dists)
facecolor = ([FigureFactory._map_z2color(zz, colormap, min_mean_dists,
max_mean_dists) for zz in mean_dists])
ii, jj, kk = FigureFactory._tri_indices(simplices)
facecolor = FigureFactory._map_z2color(mean_dists, colormap,
min_mean_dists, max_mean_dists)
ii, jj, kk = zip(*simplices)

triangles = graph_objs.Mesh3d(x=x, y=y, z=z, facecolor=facecolor,
i=ii, j=jj, k=kk, name='')

if plot_edges is None: # the triangle sides are not plotted
if plot_edges is not True: # the triangle sides are not plotted
return graph_objs.Data([triangles])

# define the lists x_edge, y_edge and z_edge, of x, y, resp z
# coordinates of edge end points for each triangle
# None separates data corresponding to two consecutive triangles
lists_coord = ([[[T[k % 3][c] for k in range(4)]+[None]
for T in tri_vertices] for c in range(3)])
if x_edge is None:
x_edge = []
for array in lists_coord[0]:
for item in array:
x_edge.append(item)

if y_edge is None:
y_edge = []
for array in lists_coord[1]:
for item in array:
y_edge.append(item)

if z_edge is None:
z_edge = []
for array in lists_coord[2]:
for item in array:
z_edge.append(item)
is_none = [ii is None for ii in [x_edge, y_edge, z_edge]]
if any(is_none):
if not all(is_none):
raise ValueError('If any (x_edge, y_edge, z_edge) is None,'
' all must be None')
else:
x_edge = []
y_edge = []
z_edge = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool logic.


# Pull indices we care about, then add a None column to separate tris
ixs_triangles = [0, 1, 2, 0]
pull_edges = tri_vertices[:, ixs_triangles, :]
x_edge_pull = np.hstack([pull_edges[:, :, 0],
np.tile(None, [pull_edges.shape[0], 1])])
y_edge_pull = np.hstack([pull_edges[:, :, 1],
np.tile(None, [pull_edges.shape[0], 1])])
z_edge_pull = np.hstack([pull_edges[:, :, 2],
np.tile(None, [pull_edges.shape[0], 1])])

# Now unravel the edges into a 1-d vector for plotting
x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]])
y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]])
z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]])

# define the lines for plotting
lines = graph_objs.Scatter3d(
Expand Down Expand Up @@ -5621,4 +5618,3 @@ def make_table_annotations(self):
font=dict(color=font_color),
showarrow=False))
return annotations