Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error at merge_prepare.py #145

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 36 additions & 52 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,98 +1,82 @@
<img src='docs/imgs/alien.gif' align="right" width=325>
<img src='docs/imgs/C60.png' align="right">
<br><br><br>

# MeshCNN in PyTorch

# MedMeshCNN

### SIGGRAPH 2019 [[Paper]](https://bit.ly/meshcnn) [[Project Page]](https://ranahanocka.github.io/MeshCNN/)<br>
MedMeshCNN is an expansion of [MeshCNN](https://ranahanocka.github.io/MeshCNN/) proposed by [Rana Hanocka](https://www.cs.tau.ac.il/~hanocka/) et al.

MeshCNN is a general-purpose deep neural network for 3D triangular meshes, which can be used for tasks such as 3D shape classification or segmentation. This framework includes convolution, pooling and unpooling layers which are applied directly on the mesh edges.
[MeshCNN](https://ranahanocka.github.io/MeshCNN/) is a general-purpose deep neural network for 3D triangular meshes, which can be used for tasks such as 3D shape classification or segmentation. This framework includes convolution, pooling and unpooling layers which are applied directly on the mesh edges.

<img src="docs/imgs/meshcnn_overview.png" align="center" width="750px"> <br>
MedMeshCNN enables the use of [MeshCNN](https://ranahanocka.github.io/MeshCNN/) for medical surface meshes through an improved memory efficiency that allows to
to keep patient-specific properties and fine-grained patterns during segmentation. Furthermore, a weighted loss function improves the performance of MedMeshCNN on imbalanced datasets that are often caused by pathological appearances.

MedMeshCNN may also be used beyond the medical domain for all applications that include imbalanced datasets and require fine-grained segmentation results.

Advances of MedMeshCNN include:
* Processing meshes with 170.000 edges (NVIDIA GeForce GTX 1080 TiGPU with 12GB RAM)
* IoU metrics
* Weighted loss function to enable better performances on imbalanced class distributions

Please check out the corresponding [PartSegmentationToolbox](https://github.com/LSnyd/PartSegmentationToolbox) to find further information on how to create a segmentation ground truth and helper scripts to scale segmentation results to different mesh resultions.

The code was written by [Rana Hanocka](https://www.cs.tau.ac.il/~hanocka/) and [Amir Hertz](http://pxcm.org/) with support from [Noa Fish](http://www.cs.tau.ac.il/~noafish/).

# Getting Started


### Installation
- Clone this repo:
```bash
git clone https://github.com/ranahanocka/MeshCNN.git
cd MeshCNN
git clone https://github.com/LSnyd/MedMeshCNN.git
cd MedMeshCNN
```
- Install dependencies: [PyTorch](https://pytorch.org/) version 1.2. <i> Optional </i>: [tensorboardX](https://github.com/lanpa/tensorboardX) for training plots.
- Install dependencies: [PyTorch](https://pytorch.org/) version 1.4. <i> Optional </i>: [tensorboardX](https://github.com/lanpa/tensorboardX) for training plots.
- Via new conda environment `conda env create -f environment.yml` (creates an environment called meshcnn)

### 3D Shape Classification on SHREC


### 3D Shape Segmentation on Humans
Download the dataset
```bash
bash ./scripts/shrec/get_data.sh
bash ./scripts/human_seg/get_data.sh
```

Run training (if using conda env first activate env e.g. ```source activate meshcnn```)
```bash
bash ./scripts/shrec/train.sh
bash ./scripts/human_seg/train.sh
```

To view the training loss plots, in another terminal run ```tensorboard --logdir runs``` and click [http://localhost:6006](http://localhost:6006).

Run test and export the intermediate pooled meshes:
```bash
bash ./scripts/shrec/test.sh
bash /scripts/human_seg/test.sh
```

Visualize the network-learned edge collapses:
```bash
bash ./scripts/shrec/view.sh
bash ./scripts/human_seg/view.sh
```

An example of collapses for a mesh:
Some segmentation result examples:

<img src="/docs/imgs/T252.png" width="450px"/>
<img src="/docs/imgs/shrec__10_0.png" height="150px"/> <img src="/docs/imgs/shrec__14_0.png" height="150px"/> <img src="/docs/imgs/shrec__2_0.png" height="150px"/>

Note, you can also get pre-trained weights using bash ```./scripts/shrec/get_pretrained.sh```.
### Hyperparameters

In order to use the pre-trained weights, run ```train.sh``` which will compute and save the mean / standard deviation of the training data.
To alter the values of the hyperparameters, change the bash scripts above accordingly.
This also includes the weight vector for the weighted loss function, which requires one weight per class.


### 3D Shape Segmentation on Humans
The same as above, to download the dataset / run train / get pretrained / run test / view
```bash
bash ./scripts/human_seg/get_data.sh
bash ./scripts/human_seg/train.sh
bash ./scripts/human_seg/get_pretrained.sh
bash ./scripts/human_seg/test.sh
bash ./scripts/human_seg/view.sh
```

Some segmentation result examples:
### More Info

<img src="/docs/imgs/shrec__10_0.png" height="150px"/> <img src="/docs/imgs/shrec__14_0.png" height="150px"/> <img src="/docs/imgs/shrec__2_0.png" height="150px"/>

### Additional Datasets
The same scripts also exist for COSEG segmentation in ```scripts/coseg_seg``` and cubes classification in ```scripts/cubes```.
Check out the corresponding [PartSegmentationToolbox](https://github.com/LSnyd/PartSegmentationToolbox) and my [medium article](https://medium.com/@lisa_81193/how-to-perform-a-3d-segmentation-in-blender-2-82-d87300305f3f) to find further information on how to create a segmentation ground truth as illustrated below. You can also find helper scripts that scale segmentation results to different mesh resolutions.

# More Info
Check out the [MeshCNN wiki](https://github.com/ranahanocka/MeshCNN/wiki) for more details. Specifically, see info on [segmentation](https://github.com/ranahanocka/MeshCNN/wiki/Segmentation) and [data processing](https://github.com/ranahanocka/MeshCNN/wiki/Data-Processing).
<img src='docs/imgs/C060_seg_fine.png' align="center" width="350px">

# Citation
If you find this code useful, please consider citing our paper
```
@article{hanocka2019meshcnn,
title={MeshCNN: A Network with an Edge},
author={Hanocka, Rana and Hertz, Amir and Fish, Noa and Giryes, Raja and Fleishman, Shachar and Cohen-Or, Daniel},
journal={ACM Transactions on Graphics (TOG)},
volume={38},
number={4},
pages = {90:1--90:12},
year={2019},
publisher={ACM}
}
```
Also check out the [MeshCNN wiki](https://github.com/ranahanocka/MeshCNN/wiki) for more details. Specifically, see info on [segmentation](https://github.com/ranahanocka/MeshCNN/wiki/Segmentation) and [data processing](https://github.com/ranahanocka/MeshCNN/wiki/Data-Processing).


# Questions / Issues
If you have questions or issues running this code, please open an issue so we can know to fix it.

# Acknowledgments
This code design was adopted from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
If you have questions or issues running this code, please open an issue.
Binary file added docs/imgs/C060_seg_fine.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/imgs/C60.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed docs/imgs/T18.png
Binary file not shown.
Binary file removed docs/imgs/T252.png
Binary file not shown.
Binary file removed docs/imgs/T76.png
Binary file not shown.
Binary file removed docs/imgs/alien.gif
Binary file not shown.
Binary file removed docs/imgs/coseg_alien.png
Binary file not shown.
Binary file removed docs/imgs/coseg_chair.png
Binary file not shown.
Binary file removed docs/imgs/coseg_vase.png
Binary file not shown.
Binary file removed docs/imgs/cubes.png
Binary file not shown.
Binary file removed docs/imgs/cubes2.png
Binary file not shown.
Binary file removed docs/imgs/input_edge_features.png
Binary file not shown.
Binary file removed docs/imgs/mesh_conv.png
Binary file not shown.
Binary file removed docs/imgs/mesh_pool_unpool.png
Binary file not shown.
10 changes: 5 additions & 5 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ channels:
- pytorch
- defaults
dependencies:
- python=3.6.8
- python=3.8.5
- cython=0.27.3
- pytorch=1.2.0
- numpy=1.15.0
- matplotlib=3.0.3
- pip
- pytorch=1.4.0
- numpy=1.19.1
- matplotlib=3.3.1
- pip
- pip:
- git+https://github.com/lanpa/tensorboardX.git
- pytest==5.1.1
53 changes: 41 additions & 12 deletions models/layers/mesh_union.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
import torch
from torch.nn import ConstantPad2d
import time
from util.util import myindexrowselect

from options.base_options import BaseOptions

class MeshUnion:
def __init__(self, n, device=torch.device('cpu')):
def __init__(self, n, device=torch.device('cpu')):
gpu_ids = BaseOptions().get_device()
self.device = torch.device('cuda:{}'.format(gpu_ids[0])) if len(gpu_ids)>0 else torch.device('cpu')

self.__size = n
self.rebuild_features = self.rebuild_features_average
self.groups = torch.eye(n, device=device)
self.values = torch.ones(n, dtype= torch.float)
self.groups = torch.sparse_coo_tensor(indices= torch.stack((torch.arange(n), torch.arange(n)),dim=0), values= self.values,

size=(self.__size, self.__size), device=self.device)


def union(self, source, target):
self.groups[target, :] += self.groups[source, :]
index = torch.tensor([source], dtype=torch.long)
row = myindexrowselect(self.groups, index, self.device).to(self.device)
row._indices()[0] = torch.tensor(target)
row = torch.sparse_coo_tensor(indices=row._indices(), values= row._values(),
size=(self.__size, self.__size), device=self.device)
self.groups = self.groups.add(row)
self.groups = self.groups.coalesce()
del index, row


def remove_group(self, index):
return
Expand All @@ -18,27 +36,38 @@ def get_group(self, edge_key):
return self.groups[edge_key, :]

def get_occurrences(self):
return torch.sum(self.groups, 0)
return torch.sparse.sum(self.groups, 0).values()


def get_groups(self, tensor_mask):
self.groups = torch.clamp(self.groups, 0, 1)
return self.groups[tensor_mask, :]
## Max comp
mask_index = torch.squeeze((tensor_mask == True).nonzero()).to(self.device)
return myindexrowselect(self.groups, mask_index, self.device)


def rebuild_features_average(self, features, mask, target_edges):
self.prepare_groups(features, mask)
fe = torch.matmul(features.squeeze(-1), self.groups)
occurrences = torch.sum(self.groups, 0).expand(fe.shape)

self.groups = self.groups.to(self.device)
fe = torch.matmul(self.groups.transpose(0,1),features.squeeze(-1).transpose(1,0)).transpose(0,1)
occurrences = torch.sparse.sum(self.groups, 0).to_dense()
fe = fe / occurrences
padding_b = target_edges - fe.shape[1]
if padding_b > 0:
padding_b = ConstantPad2d((0, padding_b, 0, 0), 0)
fe = padding_b(fe)
return fe


def prepare_groups(self, features, mask):
tensor_mask = torch.from_numpy(mask)
self.groups = torch.clamp(self.groups[tensor_mask, :], 0, 1).transpose_(1, 0)
mask_index = torch.squeeze((torch.from_numpy(mask) == True).nonzero())

self.groups = myindexrowselect(self.groups, mask_index, self.device).transpose(1,0)
padding_a = features.shape[1] - self.groups.shape[0]

if padding_a > 0:
padding_a = ConstantPad2d((0, 0, 0, padding_a), 0)
self.groups = padding_a(self.groups)
self.groups = torch.sparse_coo_tensor(
indices=self.groups._indices(), values=self.groups._values(), dtype=torch.float32,
size=(features.shape[1], self.groups.shape[1]))


57 changes: 46 additions & 11 deletions models/layers/mesh_unpool.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import torch
import torch.nn as nn


from options.base_options import BaseOptions

class MeshUnpool(nn.Module):
def __init__(self, unroll_target):
super(MeshUnpool, self).__init__()
self.unroll_target = unroll_target
gpu_ids = BaseOptions().get_device()
self.device = torch.device('cuda:{}'.format(gpu_ids[0])) if len(gpu_ids)>0 else torch.device('cpu')

def __call__(self, features, meshes):
return self.forward(features, meshes)
Expand All @@ -16,8 +17,11 @@ def pad_groups(self, group, unroll_start):
padding_rows = unroll_start - start
padding_cols = self.unroll_target - end
if padding_rows != 0 or padding_cols !=0:
padding = nn.ConstantPad2d((0, padding_cols, 0, padding_rows), 0)
group = padding(group)
size1 = group.shape[0] + padding_rows
size2 = group.shape[1] + padding_cols
group = torch.sparse_coo_tensor(
indices=group._indices(), values=group._values(), dtype=torch.float32,
size=(size1, size2))
return group

def pad_occurrences(self, occurrences):
Expand All @@ -29,13 +33,44 @@ def pad_occurrences(self, occurrences):

def forward(self, features, meshes):
batch_size, nf, edges = features.shape
groups = [self.pad_groups(mesh.get_groups(), edges) for mesh in meshes]
unroll_mat = torch.cat(groups, dim=0).view(batch_size, edges, -1)
groups = [self.pad_groups(mesh.get_groups(), edges).to(self.device) for mesh in meshes]
unroll_mat = torch.stack(groups)
occurrences = [self.pad_occurrences(mesh.get_occurrences()) for mesh in meshes]
occurrences = torch.cat(occurrences, dim=0).view(batch_size, 1, -1)
occurrences = occurrences.expand(unroll_mat.shape)
unroll_mat = unroll_mat / occurrences
unroll_mat = unroll_mat.to(features.device)
occurrences = torch.unsqueeze(torch.stack(occurrences, dim=0), dim=1)

#Sparse division only possible for scalars
#Iterate over dense batches

imin = 0
length = 500
result = []
while imin <= unroll_mat.size()[2]:
try:
sliceUnroll_mat = unroll_mat.narrow_copy(2, imin, length).to_dense().to(self.device)
sliceOcc = occurrences.narrow_copy(2, imin, length).to(self.device)
sliceResult = (sliceUnroll_mat / sliceOcc).to_sparse()
imin = imin + 500

result.append(sliceResult)

except Exception:
length = unroll_mat.size()[2] - imin

unroll_mat = torch.cat(result, -1).to(features.device)

for mesh in meshes:
mesh.unroll_gemm()
return torch.matmul(features, unroll_mat)

#Fix Matmul, due to missing strides of sparse representation
result = []
unroll_mat = unroll_mat.transpose(1,2)
features = features.transpose(1,2)

#iterate over batches
for batch in range(batch_size):
mat = torch.matmul(unroll_mat[batch], features[batch])
mat = torch.unsqueeze(mat, dim=0)
result.append(mat)
return torch.cat(result, dim=0).transpose(1,2)


11 changes: 9 additions & 2 deletions models/mesh_classifier.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from . import networks
from os.path import join
from util.util import seg_accuracy, print_network
from util.util import seg_accuracy, print_network, mean_iou_calc


class ClassifierModel:
Expand Down Expand Up @@ -113,7 +113,8 @@ def test(self):
label_class = self.labels
self.export_segmentation(pred_class.cpu())
correct = self.get_accuracy(pred_class, label_class)
return correct, len(label_class)
mean_iou, iou =self.get_iou(pred_class, label_class)
return correct, len(label_class), mean_iou, iou

def get_accuracy(self, pred, labels):
"""computes accuracy for classification / segmentation """
Expand All @@ -123,6 +124,12 @@ def get_accuracy(self, pred, labels):
correct = seg_accuracy(pred, self.soft_label, self.mesh)
return correct

def get_iou(self, pred, labels):
"""computes IoU for segmentation """
if self.opt.dataset_mode == 'segmentation':
mean_iou, iou = mean_iou_calc(pred, self.labels, self.nclasses)
return mean_iou, iou

def export_segmentation(self, pred_seg):
if self.opt.dataset_mode == 'segmentation':
for meshi, mesh in enumerate(self.mesh):
Expand Down
3 changes: 2 additions & 1 deletion models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def define_loss(opt):
if opt.dataset_mode == 'classification':
loss = torch.nn.CrossEntropyLoss()
elif opt.dataset_mode == 'segmentation':
loss = torch.nn.CrossEntropyLoss(ignore_index=-1)
weights = torch.FloatTensor(opt.weighted_loss)
loss = torch.nn.CrossEntropyLoss(weights, ignore_index=-1)
return loss

##############################################################################
Expand Down
Loading