Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation Of Visualization Framework #14

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions geomfum/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ class HierarchicalMeshRegistry(WhichRegistry):
MAP = {}


class MeshPlotterRegistry(WhichRegistry):
MAP = {}


def _create_register_funcs(module):
"""Create ``register`` functions for each class registry in this module.

Expand Down
34 changes: 34 additions & 0 deletions geomfum/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import abc
from geomfum._registry import MeshPlotterRegistry, WhichRegistryMixins

"""
In this file we define the plotting logic used in geomfum
Since each plotting library has its own method of plotting, we define general functions that works with any library implemented

"""

class ShapePlotter(abc.ABC):
"""
Primitive Calss to plot meshes, pointclouds or specific useful informations (scalar functions, landmarks, etc..)
"""
def plot(self, mesh):

"""
Function to plot a mesh

"""
raise NotImplementedError("This method should be overridden by subclasses")

def plot_function(self, mesh, function):
"""
Function to plot a mesh
"""
raise NotImplementedError("This method should be overridden by subclasses")
def show(self):
"""
Function to display the plot
"""
raise NotImplementedError("This method should be overridden by subclasses")

class MeshPlotter(WhichRegistryMixins, ShapePlotter):
_Registry = MeshPlotterRegistry
22 changes: 22 additions & 0 deletions geomfum/shape/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import numpy as np
# TODO: decide where its best to put this functions since they are related to the import of plotly and pyvista
import plotly.graph_objects as go
import pyvista as pv



def to_go_mesh3d(mesh):
"""
Convert a TriangleMesh object to a plotly Mesh3d object.
"""
x, y, z = mesh.vertices[:, 0], mesh.vertices[:, 1], mesh.vertices[:, 2]
f1, f2, f3 = mesh.faces[:, 0], mesh.faces[:, 1], mesh.faces[:, 2]

return go.Mesh3d(x=x, y=y, z=z, i=f1, j=f2, k=f3)


def to_pv_polydata(mesh):
"""
Convert a TriangleMesh object to a PyVista PolyData object.
"""
return pv.PolyData.from_regular_faces(points=mesh.vertices, faces=mesh.faces)
10 changes: 9 additions & 1 deletion geomfum/wrap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
register_hierarchical_mesh,
register_laplacian_finder,
register_wave_kernel_signature,
register_mesh_plotter,
)
from geomfum._utils import has_package

Expand Down Expand Up @@ -52,7 +53,14 @@
"pyfm", "PyFmFaceOrientationOperator", requires="pyFM", as_default=True
)


register_hierarchical_mesh(
"pyrmt", "PyrmtHierarchicalMesh", requires="PyRMT", as_default=True
)

register_mesh_plotter(
"plotly", "PlotlyMeshPlotter", requires="plotly", as_default=True
gviga marked this conversation as resolved.
Show resolved Hide resolved
)
register_mesh_plotter(
"pyvista", "PyvistaMeshPlotter", requires="pyvista", as_default=False
)

38 changes: 38 additions & 0 deletions geomfum/wrap/plotly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

import plotly.graph_objects as go
from geomfum.plot import ShapePlotter
from geomfum.shape.convert import to_go_mesh3d


class PlotlyMeshPlotter(ShapePlotter):

def __init__(self, colormap='viridis'):
self.colormap = colormap
self.fig = None

def plot(self, mesh):

mesh3d= to_go_mesh3d(mesh)
#update adding plotter properties
mesh3d.update(colorscale=self.colormap)
self.fig = go.Figure(data=[mesh3d])

return self.fig

def plot_function(self, mesh, function):


vertices=mesh.vertices
faces=mesh.faces
x, y, z = vertices[:,0], vertices[:,1], vertices[:,2]
f1,f2,f3= faces[:,0], faces[:,1], faces[:,2]
gviga marked this conversation as resolved.
Show resolved Hide resolved
self.fig = go.Figure(data=[go.Mesh3d(x=x,y=y,z=z, i=f1, j=f2, k=f3,
intensity = function,
colorscale = self.colormap,
)])

return self.fig

def show(self):
self.fig.show()

22 changes: 22 additions & 0 deletions geomfum/wrap/pyvista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pyvista as pv
from geomfum.plot import ShapePlotter
import numpy as np
from geomfum.shape.convert import to_pv_polydata



class PyvistaMeshPlotter(ShapePlotter):

def __init__(self, colormap='viridis'):
self.colormap = colormap
self.fig=None
def plot(self, mesh):

# here i call the plotter fig to be consistent with the plotly plotter
mesh_polydata = to_pv_polydata(mesh)
self.fig = pv.Plotter()
self.fig.add_mesh(mesh_polydata, cmap=self.colormap, show_edges=False)
return self.fig

def show(self):
self.fig.show()
Loading
Loading