This repository has been archived by the owner on Dec 19, 2023. It is now read-only.

Retrieve node embeddings from checkpoint #35

Open
matteomedioli opened this issue Aug 17, 2021 · 1 comment

Comments

@matteomedioli

Hi, thanks for your amazing work. I'm trying to retrieve node embeddings from the GAT and conv checkpoints. What is the right way to do that?
I'm trying something like this:

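```python
import torch

# Assumes Corpus_, entity_embeddings, relation_embeddings and SpKBGATConvOnly
# are already defined at module level, as in the training script.
def get_embedding(args, unique_entities):
    model_conv = SpKBGATConvOnly(entity_embeddings, relation_embeddings, args.entity_out_dim, args.entity_out_dim,
                                 args.drop_GAT, args.drop_conv, args.alpha, args.alpha_conv,
                                 args.nheads_GAT, args.out_channels)
    # Load the conv checkpoint saved after the last conv epoch.
    model_conv.load_state_dict(torch.load(
        '{0}conv/trained_{1}.pth'.format(args.output_folder, args.epochs_conv - 1)), strict=False)

    model_conv.cuda()
    model_conv.eval()
    with torch.no_grad():
        # Not sure whether this forward call returns node embeddings.
        preds = model_conv(Corpus_, Corpus_.train_adj_matrix, unique_entities)
        print(preds.size())
```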

But I'm not sure whether `preds = model_conv(Corpus_, Corpus_.train_adj_matrix, unique_entities)` is the right call.
Thanks in advance!
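In the repository's model code, `SpKBGATConvOnly.forward` appears to take a batch of (head, relation, tail) triples and return ConvKB scores, so `preds` would hold one score per triple rather than node embeddings. The learned tables seem to be stored as parameters named `final_entity_embeddings` and `final_relation_embeddings`; treat those names as an assumption and confirm them by printing `torch.load(path).keys()`. A minimal sketch that reads the tables straight from a saved `state_dict`, without a forward pass, follows. `TinyKBModel` and `load_embeddings` are hypothetical names used only so the example runs on its own.

```python
import io

import torch
import torch.nn as nn


class TinyKBModel(nn.Module):
    """Hypothetical stand-in for a KBGAT model: it only holds the two
    embedding tables, under the parameter names assumed above."""

    def __init__(self, num_entities, num_relations, dim):
        super().__init__()
        self.final_entity_embeddings = nn.Parameter(torch.randn(num_entities, dim))
        self.final_relation_embeddings = nn.Parameter(torch.randn(num_relations, dim))


def load_embeddings(checkpoint, device="cpu"):
    """Return (entity, relation) embedding tensors from a saved state_dict.

    `checkpoint` may be a file path or a file-like object (anything torch.load accepts).
    The key names are an assumption about how the checkpoint was saved.
    """
    state = torch.load(checkpoint, map_location=device)
    entity = state["final_entity_embeddings"].detach().clone()
    relation = state["final_relation_embeddings"].detach().clone()
    return entity, relation


# Usage: save a toy model to an in-memory buffer, then read the tables back.
model = TinyKBModel(num_entities=5, num_relations=2, dim=4)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)
entity_embed, relation_embed = load_embeddings(buffer)
print(tuple(entity_embed.shape))  # (5, 4)
```

With a real checkpoint you would pass the file path instead, for example the `'{0}conv/trained_{1}.pth'` path built in the snippet above. Loading with `map_location="cpu"` also avoids needing a GPU just to inspect the weights.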

@matteomedioli
Author

I think I've found a possible solution, but I'm not sure about it: I need the embeddings of all entities, yet I'm passing only the training indices:

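```python
import pickle

import torch

# Assumes args, CUDA, Corpus_, entity_embeddings, relation_embeddings and
# SpKBGATModified are already defined at module level, as in the training script.
def get_embeddings():
    # Load the precomputed 2-hop neighbourhoods.
    fl = args.data + "/2hop.pickle"
    with open(fl, 'rb') as handle:
        node_neighbors_2hop = pickle.load(handle)

    current_batch_2hop_indices = Corpus_.get_batch_nhop_neighbors_all(args, Corpus_.unique_entities_train, node_neighbors_2hop)
    # Build the tensors outside the CUDA branch so train_indices is always defined
    # (torch.autograd.Variable is deprecated; plain tensors behave the same).
    current_batch_2hop_indices = torch.LongTensor(current_batch_2hop_indices)
    train_indices = torch.LongTensor(Corpus_.train_indices)
    if CUDA:
        current_batch_2hop_indices = current_batch_2hop_indices.cuda()
        train_indices = train_indices.cuda()

    model_gat = SpKBGATModified(entity_embeddings, relation_embeddings, args.entity_out_dim, args.entity_out_dim,
                                args.drop_GAT, args.alpha, args.nheads_GAT)
    model_gat.load_state_dict(torch.load(
        '{}/trained_{}.pth'.format(args.output_folder, args.epochs_gat - 1)), strict=False)
    if CUDA:
        model_gat.cuda()
    model_gat.eval()
    with torch.no_grad():
        entity_embed, relation_embed = model_gat(
            Corpus_, Corpus_.train_adj_matrix, train_indices, current_batch_2hop_indices)
    return entity_embed, relation_embed
```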
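On the "all entities vs. train indices" doubt: the entity tensor returned by the GAT forward pass appears to have one row per entity in the graph, with the training triples affecting how the rows are computed rather than how many rows there are. Comparing `entity_embed.size(0)` with the number of entities is a quick way to confirm this. To use the rows, they then need to be matched back to entity names. The sketch below assumes a name-to-row dictionary such as the `entity2id` mapping built when the dataset is loaded; `embeddings_by_name` is a hypothetical helper, not part of the repository.

```python
import torch


def embeddings_by_name(entity_embed, entity2id):
    """Map each entity name to its embedding row.

    Assumes `entity_embed` has shape [num_entities, dim], with row i holding
    the entity whose id is i, and that `entity2id` maps names to those ids.
    """
    if entity_embed.size(0) != len(entity2id):
        raise ValueError(
            f"expected {len(entity2id)} rows, got {entity_embed.size(0)}")
    table = entity_embed.detach().cpu()  # move off the GPU once, not per row
    return {name: table[idx] for name, idx in entity2id.items()}


# Usage with a hypothetical three-entity mapping and a toy 3 x 2 table.
entity2id = {"paris": 0, "france": 1, "rome": 2}
entity_embed = torch.arange(6, dtype=torch.float32).view(3, 2)
vectors = embeddings_by_name(entity_embed, entity2id)
print(vectors["rome"])  # tensor([4., 5.])
```

The row-count check is a deliberate choice: if the table were built from a subset of entities, it fails loudly instead of silently pairing names with the wrong vectors.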
