Error about implementation of cp2p dataset #2

Open
SeanSiyang opened this issue Jan 12, 2023 · 0 comments
To use the cp2p dataset with DPFM, I wrote two Python files:

  • cp2p_dataset.py
  • train_cp2p.py (a rough sketch of the relevant part is shown after the dataset code below)

In cp2p_dataset.py I defined a class Cp2pDataset that inherits from the Dataset class imported via:

from torch.utils.data import Dataset

The full implementation of cp2p_dataset.py is as follows:

import os
from pathlib import Path
import numpy as np
import potpourri3d as pp3d
import torch
from torch.utils.data import Dataset
import diffusion_net as dfn
from tqdm import tqdm
from itertools import permutations
from utils import farthest_point_sample, square_distance

class Cp2pDataset(Dataset):
    def __init__(self, root_dir, name="cp2p", k_eig=128, n_fmap=30, use_cache=True, op_cache_dir=None):

        self.k_eig = k_eig
        self.n_fmap = n_fmap
        self.root_dir = root_dir
        self.cache_dir = root_dir
        self.op_cache_dir = op_cache_dir

        if use_cache:
            train_cache = os.path.join(self.cache_dir, "train.pt")
            load_cache = train_cache
            print("using dataset cache path: " + str(load_cache))
            if os.path.exists(load_cache):
                print("  --> loading dataset from cache")
                # the loaded tuple must match exactly what is saved at the end of __init__
                (
                    self.verts_list,
                    self.faces_list,
                    self.frames_list,
                    self.massvec_list,
                    self.L_list,
                    self.evals_list,
                    self.evecs_list,
                    self.gradX_list,
                    self.gradY_list,
                    self.used_shapes,
                    self.corres_dict,
                    self.sample_list,
                ) = torch.load(load_cache)
                self.combinations = list(self.corres_dict.keys())
                return
            print("  --> dataset not in cache, repopulating")

        # Load the meshes and labels
        # pick the split file (hard-coded to the train split here)
        train = True
        split_path = "./data/cp2p/splits/train.txt" if train else "./data/cp2p/splits/test.txt"
        with open(split_path, 'r') as f:
            mesh_lists = f.read().strip().split()
        self.used_shapes = sorted(x[:-4] for x in mesh_lists)

        corres_path = Path(root_dir) / "maps"
        all_combs = [x.stem for x in corres_path.iterdir()]
        self.corres_dict = {}
        for x, y in map(lambda x: (x[:x.rfind("_")], x[x.rfind("_") + 1:]), all_combs):
            if x in self.used_shapes and y in self.used_shapes:
                map_ = torch.from_numpy(np.loadtxt(corres_path / f"{x}_{y}.map", dtype=np.int32)).long() - 1
                self.corres_dict[(self.used_shapes.index(y), self.used_shapes.index(x))] = map_

        # set combinations
        self.combinations = list(self.corres_dict.keys())
        mesh_dirpath = Path(root_dir) / "shapes"

        # Get all the files
        self.verts_list = []
        self.faces_list = []
        self.sample_list = []

        # Load the actual files
        for shape_name in self.used_shapes:
            print("loading mesh " + str(shape_name))
            verts, faces = pp3d.read_mesh(str(mesh_dirpath / f"{shape_name}.off"))

            # to torch
            verts = torch.tensor(np.ascontiguousarray(verts)).float()
            faces = torch.tensor(np.ascontiguousarray(faces))
            self.verts_list.append(verts)
            self.faces_list.append(faces)
            idx0 = farthest_point_sample(verts.t(), ratio=0.9)
            dists, idx1 = square_distance(verts.unsqueeze(0), verts[idx0].unsqueeze(0)).sort(dim=-1)
            dists, idx1 = dists[:, :, :130].clone(), idx1[:, :, :130].clone()
            self.sample_list.append((idx0, idx1, dists))

        # Precompute operators
        (
            self.frames_list,
            self.massvec_list,
            self.L_list,
            self.evals_list,
            self.evecs_list,
            self.gradX_list,
            self.gradY_list,
        ) = dfn.geometry.get_all_operators(
            self.verts_list,
            self.faces_list,
            k_eig=self.k_eig,
            op_cache_dir=self.op_cache_dir,
        )

        # save to cache
        if use_cache:
            dfn.utils.ensure_dir_exists(self.cache_dir)
            torch.save(
                (
                    self.verts_list,
                    self.faces_list,
                    self.frames_list,
                    self.massvec_list,
                    self.L_list,
                    self.evals_list,
                    self.evecs_list,
                    self.gradX_list,
                    self.gradY_list,
                    self.used_shapes,
                    self.corres_dict,
                    self.sample_list,
                ),
                load_cache,
            )

    def __len__(self):
        return len(self.combinations)

    def __getitem__(self, item):
        idx1, idx2 = self.combinations[item]

        shape1 = {
            "xyz": self.verts_list[idx1],
            "faces": self.faces_list[idx1],
            "frames": self.frames_list[idx1],
            "mass": self.massvec_list[idx1],
            "L": self.L_list[idx1],
            "evals": self.evals_list[idx1],
            "evecs": self.evecs_list[idx1],
            "gradX": self.gradX_list[idx1],
            "gradY": self.gradY_list[idx1],
            "name": self.used_shapes[idx1],
            "sample_idx": self.sample_list[idx1],
        }

        shape2 = {
            "xyz": self.verts_list[idx2],
            "faces": self.faces_list[idx2],
            "frames": self.frames_list[idx2],
            "mass": self.massvec_list[idx2],
            "L": self.L_list[idx2],
            "evals": self.evals_list[idx2],
            "evecs": self.evecs_list[idx2],
            "gradX": self.gradX_list[idx2],
            "gradY": self.gradY_list[idx2],
            "name": self.used_shapes[idx2],
            "sample_idx": self.sample_list[idx2],
        }
        # Ground-truth functional map: map21 is indexed by the vertices of shape2 and
        # stores, for each of them, the matching vertex index on shape1 (-1 if unmatched)
        map21 = self.corres_dict[(idx1, idx2)]

        evec_1, evec_2, mass2 = shape1["evecs"][:, :self.n_fmap], shape2["evecs"][:, :self.n_fmap], shape2["mass"]
        trans_evec2 = evec_2.t() @ torch.diag(mass2)

        # build the (n2 x n1) point-to-point matrix from map21 and project it onto
        # the truncated spectral bases to obtain the ground-truth functional map C_gt
        P = torch.zeros(evec_2.size(0), evec_1.size(0))
        P[range(evec_2.size(0)), map21.flatten()] = 1
        C_gt = trans_evec2 @ P @ evec_1

        # compute region labels
        gt_partiality_mask12 = torch.zeros(shape1["xyz"].size(0)).long().detach()
        gt_partiality_mask12[map21[map21 != -1]] = 1
        gt_partiality_mask21 = torch.zeros(shape2["xyz"].size(0)).long().detach()
        gt_partiality_mask21[map21 != -1] = 1

        return {"shape1": shape1, "shape2": shape2, "C_gt": C_gt,
                "map21": map21, "gt_partiality_mask12": gt_partiality_mask12, "gt_partiality_mask21": gt_partiality_mask21}


def shape_to_device(dict_shape, device):
    names_to_device = ["xyz", "faces", "mass", "evals", "evecs", "gradX", "gradY"]
    for k, v in dict_shape.items():
        if "shape" in k:
            for name in names_to_device:
                v[name] = v[name].to(device)
            dict_shape[k] = v
        else:
            dict_shape[k] = v.to(device)

    return dict_shape
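
For context, the relevant part of train_cp2p.py looks roughly like this (a simplified sketch: the DataLoader arguments and the cfg fields are paraphrased, not copied verbatim from my file):

import torch
from torch.utils.data import DataLoader

from cp2p_dataset import Cp2pDataset, shape_to_device


def train_net(cfg):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # batch_size=None disables automatic batching, so every item keeps the
    # dict structure returned by Cp2pDataset.__getitem__
    train_dataset = Cp2pDataset(root_dir="./data/cp2p", use_cache=True)
    train_loader = DataLoader(train_dataset, batch_size=None, shuffle=True)

    for epoch in range(cfg["training"]["epochs"]):
        for i, data in enumerate(train_loader):  # this is the loop where the error below is raised
            data = shape_to_device(data, device)
            # ... forward pass, loss, optimizer step ...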

However, I encountered the following error:

Traceback (most recent call last):
  File "train_cp2p.py", line 90, in <module>
    train_net(cfg)
  File "train_cp2p.py", line 55, in train_net
    for i, data in enumerate(train_loader):
  File "/home/anaconda3/envs/fm2023/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/home/anaconda3/envs/fm2023/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 557, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/anaconda3/envs/fm2023/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 46, in fetch
    data = self.dataset[possibly_batched_index]
  File "/home/FM_Code/dpfm/Cp2p_dataset.py", line 174, in __getitem__
    P[range(evec_2.size(0)), map21.flatten()] = 1
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [1777], [5830]

Fortunately, loading the shapes and running get_all_operators both work fine.
From the error above, I think the problem is in my implementation of the __getitem__ function (see the small sanity check sketched below).
I don't know why this error occurs and hope to get an answer from you~
Have a nice day :)
