refactor laplacian matrices
Summary:
Refactor of all functions that compute Laplacian matrices into a single file, pytorch3d/ops/laplacian_matrices.py.
Support for:
* Standard Laplacian
* Cotangent Laplacian
* Norm Laplacian

Reviewed By: nikhilaravi

Differential Revision: D29297466

fbshipit-source-id: b96b88915ce8ef0c2f5693ec9b179fd27b70abf9
gkioxari authored and facebook-github-bot committed Jun 24, 2021
1 parent da9974b commit 07a5a68
Showing 8 changed files with 298 additions and 198 deletions.
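For orientation, here is a minimal usage sketch of the consolidated API this commit exposes from pytorch3d.ops. The function names and signatures come from the diff below; the tiny triangle used as input is purely illustrative.

import torch
from pytorch3d.ops import laplacian, cot_laplacian, norm_laplacian

# A single triangle: 3 vertices, 1 face, 3 undirected edges (illustrative data).
verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = torch.tensor([[0, 1, 2]])
edges = torch.tensor([[0, 1], [0, 2], [1, 2]])

L_uni = laplacian(verts, edges)                 # sparse (V, V); -1 on the diagonal, 1/deg(i) on edges
L_cot, inv_areas = cot_laplacian(verts, faces)  # sparse (V, V) plus per-vertex inverse face areas
L_norm = norm_laplacian(verts, edges)           # sparse (V, V); each edge weighted by 1 / ||vi - vj||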
74 changes: 3 additions & 71 deletions pytorch3d/loss/mesh_laplacian_smoothing.py
@@ -6,6 +6,7 @@


import torch
from pytorch3d.ops import cot_laplacian


def mesh_laplacian_smoothing(meshes, method: str = "uniform"):
@@ -94,6 +95,7 @@ def mesh_laplacian_smoothing(meshes, method: str = "uniform"):

N = len(meshes)
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
faces_packed = meshes.faces_packed() # (sum(F_n), 3)
num_verts_per_mesh = meshes.num_verts_per_mesh() # (N,)
verts_packed_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),)
weights = num_verts_per_mesh.gather(0, verts_packed_idx) # (sum(V_n),)
@@ -106,7 +108,7 @@ def mesh_laplacian_smoothing(meshes, method: str = "uniform"):
if method == "uniform":
L = meshes.laplacian_packed()
elif method in ["cot", "cotcurv"]:
L, inv_areas = laplacian_cot(meshes)
L, inv_areas = cot_laplacian(verts_packed, faces_packed)
if method == "cot":
norm_w = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1)
idx = norm_w > 0
@@ -127,73 +129,3 @@ def mesh_laplacian_smoothing(meshes, method: str = "uniform"):

loss = loss * weights
return loss.sum() / N


def laplacian_cot(meshes):
"""
Returns the Laplacian matrix with cotangent weights and the inverse of the
face areas.
Args:
meshes: Meshes object with a batch of meshes.
Returns:
2-element tuple containing
- **L**: FloatTensor of shape (V,V) for the Laplacian matrix (V = sum(V_n))
Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge in meshes.
See the description above for more clarity.
- **inv_areas**: FloatTensor of shape (V,) containing the inverse of sum of
face areas containing each vertex
"""
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
faces_packed = meshes.faces_packed() # (sum(F_n), 3)
# V = sum(V_n), F = sum(F_n)
V, F = verts_packed.shape[0], faces_packed.shape[0]

face_verts = verts_packed[faces_packed]
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]

# Side lengths of each triangle, of shape (sum(F_n),)
# A is the side opposite v1, B is opposite v2, and C is opposite v3
A = (v1 - v2).norm(dim=1)
B = (v0 - v2).norm(dim=1)
C = (v0 - v1).norm(dim=1)

# Area of each triangle (with Heron's formula); shape is (sum(F_n),)
s = 0.5 * (A + B + C)
# note that the area can be negative (close to 0) causing nans after sqrt()
# we clip it to a small positive value
area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()

# Compute cotangents of angles, of shape (sum(F_n), 3)
A2, B2, C2 = A * A, B * B, C * C
cota = (B2 + C2 - A2) / area
cotb = (A2 + C2 - B2) / area
cotc = (A2 + B2 - C2) / area
cot = torch.stack([cota, cotb, cotc], dim=1)
cot /= 4.0

# Construct a sparse matrix by basically doing:
# L[v1, v2] = cota
# L[v2, v0] = cotb
# L[v0, v1] = cotc
ii = faces_packed[:, [1, 2, 0]]
jj = faces_packed[:, [2, 0, 1]]
idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))

# Make it symmetric; this means we are also setting
# L[v2, v1] = cota
# L[v0, v2] = cotb
# L[v1, v0] = cotc
L += L.t()

# For each vertex, compute the sum of areas for triangles containing it.
idx = faces_packed.view(-1)
inv_areas = torch.zeros(V, dtype=torch.float32, device=meshes.device)
val = torch.stack([area] * 3, dim=1).view(-1)
inv_areas.scatter_add_(0, idx, val)
idx = inv_areas > 0
inv_areas[idx] = 1.0 / inv_areas[idx]
inv_areas = inv_areas.view(-1, 1)

return L, inv_areas
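From the caller's side, mesh_laplacian_smoothing is unchanged; only its "cot"/"cotcurv" branch now routes through pytorch3d.ops.cot_laplacian. A brief usage sketch (the two-triangle mesh below is illustrative):

import torch
from pytorch3d.structures import Meshes
from pytorch3d.loss import mesh_laplacian_smoothing

verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.5]])
faces = torch.tensor([[0, 1, 2], [1, 3, 2]])
meshes = Meshes(verts=[verts], faces=[faces])

loss_uniform = mesh_laplacian_smoothing(meshes, method="uniform")  # uses Meshes.laplacian_packed()
loss_cot = mesh_laplacian_smoothing(meshes, method="cot")          # uses ops.cot_laplacian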
1 change: 1 addition & 0 deletions pytorch3d/ops/__init__.py
@@ -9,6 +9,7 @@
from .graph_conv import GraphConv
from .interp_face_attrs import interpolate_face_attributes
from .knn import knn_gather, knn_points
from .laplacian_matrices import laplacian, cot_laplacian, norm_laplacian
from .mesh_face_areas_normals import mesh_face_areas_normals
from .mesh_filtering import taubin_smoothing
from .packed_to_padded import packed_to_padded, padded_to_packed
170 changes: 170 additions & 0 deletions pytorch3d/ops/laplacian_matrices.py
@@ -0,0 +1,170 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch

# ------------------------ Laplacian Matrices ------------------------ #
# This file contains implementations of differentiable laplacian matrices.
# These include
# 1) Standard Laplacian matrix
# 2) Cotangent Laplacian matrix
# 3) Norm Laplacian matrix
# -------------------------------------------------------------------- #


def laplacian(verts: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
"""
Computes the laplacian matrix.
The definition of the laplacian is
L[i, j] = -1 , if i == j
L[i, j] = 1 / deg(i) , if (i, j) is an edge
L[i, j] = 0 , otherwise
where deg(i) is the degree of the i-th vertex in the graph.
Args:
verts: tensor of shape (V, 3) containing the vertices of the graph
edges: tensor of shape (E, 2) containing the vertex indices of each edge
Returns:
L: Sparse FloatTensor of shape (V, V)
"""
V = verts.shape[0]

e0, e1 = edges.unbind(1)

idx01 = torch.stack([e0, e1], dim=1) # (E, 2)
idx10 = torch.stack([e1, e0], dim=1) # (E, 2)
idx = torch.cat([idx01, idx10], dim=0).t() # (2, 2*E)

# First, we construct the adjacency matrix,
# i.e. A[i, j] = 1 if (i,j) is an edge, or
# A[e0, e1] = 1 & A[e1, e0] = 1
ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)
A = torch.sparse.FloatTensor(idx, ones, (V, V))

# the sum of i-th row of A gives the degree of the i-th vertex
deg = torch.sparse.sum(A, dim=1).to_dense()

# We construct the Laplacian matrix by adding the non diagonal values
# i.e. L[i, j] = 1 ./ deg(i) if (i, j) is an edge
deg0 = deg[e0]
deg0 = torch.where(deg0 > 0.0, 1.0 / deg0, deg0)
deg1 = deg[e1]
deg1 = torch.where(deg1 > 0.0, 1.0 / deg1, deg1)
val = torch.cat([deg0, deg1])
L = torch.sparse.FloatTensor(idx, val, (V, V))

# Then we add the diagonal values L[i, i] = -1.
idx = torch.arange(V, device=verts.device)
idx = torch.stack([idx, idx], dim=0)
ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)
L -= torch.sparse.FloatTensor(idx, ones, (V, V))

return L


def cot_laplacian(
verts: torch.Tensor, faces: torch.Tensor, eps: float = 1e-12
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns the Laplacian matrix with cotangent weights and the inverse of the
face areas.
Args:
verts: tensor of shape (V, 3) containing the vertices of the graph
faces: tensor of shape (F, 3) containing the vertex indices of each face
Returns:
2-element tuple containing
- **L**: Sparse FloatTensor of shape (V,V) for the Laplacian matrix.
Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge, where a_ij and
b_ij are the two angles opposite edge (i, j) in the faces that share it.
- **inv_areas**: FloatTensor of shape (V, 1) containing the inverse of the sum of
the areas of all faces incident to each vertex
"""
V, F = verts.shape[0], faces.shape[0]

face_verts = verts[faces]
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]

# Side lengths of each triangle, of shape (F,)
# A is the side opposite v0, B is opposite v1, and C is opposite v2
A = (v1 - v2).norm(dim=1)
B = (v0 - v2).norm(dim=1)
C = (v0 - v1).norm(dim=1)

# Area of each triangle (via Heron's formula); shape is (F,)
s = 0.5 * (A + B + C)
# For degenerate triangles the product under the square root can be slightly
# negative due to floating point error, causing nans after sqrt();
# we clamp it to a small positive value
area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=eps).sqrt()

# Compute cotangents of the triangle angles, of shape (F, 3)
A2, B2, C2 = A * A, B * B, C * C
cota = (B2 + C2 - A2) / area
cotb = (A2 + C2 - B2) / area
cotc = (A2 + B2 - C2) / area
cot = torch.stack([cota, cotb, cotc], dim=1)
cot /= 4.0

# Construct a sparse matrix by basically doing:
# L[v1, v2] = cota
# L[v2, v0] = cotb
# L[v0, v1] = cotc
ii = faces[:, [1, 2, 0]]
jj = faces[:, [2, 0, 1]]
idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))

# Make it symmetric; this means we are also setting
# L[v2, v1] = cota
# L[v0, v2] = cotb
# L[v1, v0] = cotc
L += L.t()

# For each vertex, compute the sum of areas for triangles containing it.
idx = faces.view(-1)
inv_areas = torch.zeros(V, dtype=torch.float32, device=verts.device)
val = torch.stack([area] * 3, dim=1).view(-1)
inv_areas.scatter_add_(0, idx, val)
idx = inv_areas > 0
inv_areas[idx] = 1.0 / inv_areas[idx]
inv_areas = inv_areas.view(-1, 1)

return L, inv_areas


def norm_laplacian(
verts: torch.Tensor, edges: torch.Tensor, eps: float = 1e-12
) -> torch.Tensor:
"""
Norm laplacian computes a variant of the Laplacian matrix which weights each
affinity with the inverse distance between the neighboring nodes.
More concretely,
L[i, j] = 1. / wij where wij = ||vi - vj|| if (vi, vj) is an edge
Args:
verts: tensor of shape (V, 3) containing the vertices of the graph
edges: tensor of shape (E, 2) containing the vertex indices of each edge
Returns:
L: Sparse FloatTensor of shape (V, V)
"""
edge_verts = verts[edges] # (E, 2, 3)
v0, v1 = edge_verts[:, 0], edge_verts[:, 1]

# Side lengths of each edge, of shape (E,)
w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps)

# Construct a sparse matrix by basically doing:
# L[v0, v1] = w01
# L[v1, v0] = w01
e01 = edges.t() # (2, E)

V = verts.shape[0]
L = torch.sparse.FloatTensor(e01, w01, (V, V))
L = L + L.t()

return L
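As a quick sanity check of the definitions above, the uniform Laplacian of a single triangle can be worked out by hand: every vertex has degree 2, so the entry for each edge is 1/2 and the diagonal is -1. A minimal sketch (the tensors are illustrative):

import torch
from pytorch3d.ops import laplacian, norm_laplacian

verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
edges = torch.tensor([[0, 1], [0, 2], [1, 2]])

expected = torch.tensor([[-1.0, 0.5, 0.5],
                         [0.5, -1.0, 0.5],
                         [0.5, 0.5, -1.0]])
assert torch.allclose(laplacian(verts, edges).to_dense(), expected)

# norm_laplacian weights each edge by its inverse length: 1 for the two unit
# edges and 1/sqrt(2) for the edge between vertices 1 and 2.
print(norm_laplacian(verts, edges).to_dense())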
30 changes: 1 addition & 29 deletions pytorch3d/ops/mesh_filtering.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import torch
from pytorch3d.ops import norm_laplacian
from pytorch3d.structures import Meshes, utils as struct_utils


@@ -19,35 +20,6 @@
# ----------------------- Taubin Smoothing ----------------------- #


def norm_laplacian(verts: torch.Tensor, edges: torch.Tensor, eps: float = 1e-12):
"""
Norm laplacian computes a variant of the laplacian matrix which weights each
affinity with the normalized distance of the neighboring nodes.
More concretely,
L[i, j] = 1. / wij where wij = ||vi - vj|| if (vi, vj) are neighboring nodes
Args:
verts: tensor of shape (V, 3) containing the vertices of the graph
edges: tensor of shape (E, 2) containing the vertex indices of each edge
"""
edge_verts = verts[edges] # (E, 2, 3)
v0, v1 = edge_verts[:, 0], edge_verts[:, 1]

# Side lengths of each edge, of shape (E,)
w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps)

# Construct a sparse matrix by basically doing:
# L[v0, v1] = w01
# L[v1, v0] = w01
e01 = edges.t() # (2, E)

V = verts.shape[0]
L = torch.sparse.FloatTensor(e01, w01, (V, V))
L = L + L.t()

return L


def taubin_smoothing(
meshes: Meshes, lambd: float = 0.53, mu: float = -0.53, num_iter: int = 10
) -> Meshes:
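taubin_smoothing now imports norm_laplacian from pytorch3d.ops instead of defining it locally; its signature and behavior are unchanged. A short usage sketch (the mesh below is illustrative):

import torch
from pytorch3d.structures import Meshes
from pytorch3d.ops import taubin_smoothing

verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.3]])
faces = torch.tensor([[0, 1, 2], [1, 3, 2]])
meshes = Meshes(verts=[verts], faces=[faces])

smoothed = taubin_smoothing(meshes, lambd=0.53, mu=-0.53, num_iter=10)
new_verts = smoothed.verts_packed()  # smoothed vertex positions, shape (4, 3)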
37 changes: 4 additions & 33 deletions pytorch3d/structures/meshes.py
@@ -1142,6 +1142,8 @@ def _compute_laplacian_packed(self, refresh: bool = False):
Sparse FloatTensor of shape (V, V) where V = sum(V_n)
"""
from ..ops import laplacian

if not (refresh or self._laplacian_packed is None):
return

@@ -1153,39 +1155,8 @@

verts_packed = self.verts_packed() # (sum(V_n), 3)
edges_packed = self.edges_packed() # (sum(E_n), 2)
V = verts_packed.shape[0] # sum(V_n)

e0, e1 = edges_packed.unbind(1)

idx01 = torch.stack([e0, e1], dim=1) # (sum(E_n), 2)
idx10 = torch.stack([e1, e0], dim=1) # (sum(E_n), 2)
idx = torch.cat([idx01, idx10], dim=0).t() # (2, 2*sum(E_n))

# First, we construct the adjacency matrix,
# i.e. A[i, j] = 1 if (i,j) is an edge, or
# A[e0, e1] = 1 & A[e1, e0] = 1
ones = torch.ones(idx.shape[1], dtype=torch.float32, device=self.device)
A = torch.sparse.FloatTensor(idx, ones, (V, V))

# the sum of i-th row of A gives the degree of the i-th vertex
deg = torch.sparse.sum(A, dim=1).to_dense()

# We construct the Laplacian matrix by adding the non diagonal values
# i.e. L[i, j] = 1 ./ deg(i) if (i, j) is an edge
deg0 = deg[e0]
deg0 = torch.where(deg0 > 0.0, 1.0 / deg0, deg0)
deg1 = deg[e1]
deg1 = torch.where(deg1 > 0.0, 1.0 / deg1, deg1)
val = torch.cat([deg0, deg1])
L = torch.sparse.FloatTensor(idx, val, (V, V))

# Then we add the diagonal values L[i, i] = -1.
idx = torch.arange(V, device=self.device)
idx = torch.stack([idx, idx], dim=0)
ones = torch.ones(idx.shape[1], dtype=torch.float32, device=self.device)
L -= torch.sparse.FloatTensor(idx, ones, (V, V))

self._laplacian_packed = L

self._laplacian_packed = laplacian(verts_packed, edges_packed)

def clone(self):
"""
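For Meshes.laplacian_packed(), the cached matrix is now computed by ops.laplacian on the packed vertices and edges, so the two should agree exactly. A hedged equivalence check (mesh construction illustrative):

import torch
from pytorch3d.ops import laplacian
from pytorch3d.structures import Meshes

verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = torch.tensor([[0, 1, 2]])
meshes = Meshes(verts=[verts], faces=[faces])

L_meshes = meshes.laplacian_packed().to_dense()
L_ops = laplacian(meshes.verts_packed(), meshes.edges_packed()).to_dense()
assert torch.allclose(L_meshes, L_ops)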
(Remaining changed files not shown in this view.)
