Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix graph feature bug #14

Merged
merged 1 commit into from
Jun 2, 2021
Merged

fix graph feature bug #14

merged 1 commit into from
Jun 2, 2021

Conversation

zhulf0804
Copy link
Contributor

Hi, Shengyu.

Thanks for your nice work and open source code. But there may be a small bug in function get_graph_feature() in models/gcn.py in which i update all_feats = feats.unsqueeze(-1).repeat(1, 1, 1, N) to all_feats = feats.unsqueeze(2).repeat(1,1,N,1).

I implemented the following code to verify that the code all_feats = feats.unsqueeze(2).repeat(1,1,N,1) is right.

import numpy as np
import torch


def setup_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)


def square_distance(src, dst, normalised = False):
    """
    Calculate Euclid distance between each two points.
    Args:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Returns:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    if(normalised):
        dist += 2
    else:
        dist += torch.sum(src ** 2, dim=-1)[:, :, None]
        dist += torch.sum(dst ** 2, dim=-1)[:, None, :]

    dist = torch.clamp(dist, min=1e-12, max=None)
    return dist


def get_graph_feature(coords, feats, k=10):
    """
    Apply KNN search based on coordinates, then concatenate the features to the centroid features
    Input:
        X:          [B, 3, N]
        feats:      [B, C, N]
    Return:
        feats_cat:  [B, 2C, N, k]
    """
    # apply KNN search to build neighborhood
    B, C, N = feats.size()
    dist = square_distance(coords.transpose(1,2), coords.transpose(1,2))

    idx = dist.topk(k=k+1, dim=-1, largest=False, sorted=True)[1]  #[B, N, K+1], here we ignore the smallest element as it's the query itself
    idx = idx[:,:,1:]  #[B, N, K]
    print(features)
    print(idx)
    idx = idx.unsqueeze(1).repeat(1,C,1,1) #[B, C, N, K]
    all_feats = feats.unsqueeze(2).repeat(1,1,N,1)  #[B, C, N, N]
    # all_feats = feats.unsqueeze(-1).repeat(1, 1, 1, N)
    neighbor_feats = torch.gather(all_feats, dim=-1,index=idx) #[B, C, N, K]
    print(neighbor_feats)

    # concatenate the features with centroid
    feats = feats.unsqueeze(-1).repeat(1,1,1,k)
    feats_cat = torch.cat((feats, neighbor_feats-feats),dim=1)

    return feats_cat


if __name__ == '__main__':
    setup_seed(1234)
    coords = torch.randn(1, 3, 5)
    features = torch.randn(1, 4, 5)
    feats_cat = get_graph_feature(coords, features, k=2)

@ShengyuH
Copy link
Member

hi,

Thanks for the catch. Did you retrain the model afterwards?

It seems I do have the bug and is confirmed with the provided sanity check. I will retrain the model and come back with new results hopefully before CVPR conference.

Best,
Shengyu

@zhulf0804
Copy link
Contributor Author

Hi Shengyu,

Thanks for your reply. I am training the model on the 3DMatch and KITTI dataset with the updated get_graph_feature(). Hope for a better result.

Best,
Lifa

@ShengyuH ShengyuH merged commit 4403532 into prs-eth:main Jun 2, 2021
@zgojcic zgojcic linked an issue Jun 2, 2021 that may be closed by this pull request
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[FIXED] Bug in the self-attention GNN.
2 participants