This repository has been archived by the owner on Feb 26, 2024. It is now read-only.

A code issue #16

Closed
Wastedzz opened this issue Nov 22, 2021 · 6 comments

Comments

@Wastedzz

Hi, thanks for your open-source code. I hit a bug when running the beam search code in `utils/beamsearch.py`, line 96:

```python
self.mask = self.mask.gather(1, perm_mask)
```

Here `perm_mask` should be a LongTensor, but its type is FloatTensor. This makes several beam-search-based functions fail.
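The failure can be reproduced in isolation: `torch.gather` only accepts an int64 index tensor. The shapes and values below are illustrative assumptions (batch size 1, 3 beams, 4 nodes), not taken from the repository; the sketch shows a float index being rejected and the same values succeeding once cast to int64.

```python
import torch

# Toy mask with shape (batch_size=1, beam_size=3, num_nodes=4); values are illustrative only
mask = torch.arange(12, dtype=torch.float32).view(1, 3, 4)

# Index telling each output beam row which input beam row to copy, stored as float
perm_float = torch.tensor([[2.0, 0.0, 1.0]]).unsqueeze(2).expand_as(mask)

float_index_failed = False
try:
    mask.gather(1, perm_float)  # torch.gather rejects a non-int64 index
except RuntimeError as err:
    float_index_failed = True
    print("float index rejected:", err)

# Same index values as int64: gather now reorders the beam rows
perm_long = perm_float.long()
reordered = mask.gather(1, perm_long)
print(reordered)
```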

@chaitjo
Owner

chaitjo commented Dec 5, 2021

Hi @Wastedzz @maoxiaowei97, thank you for your interest. I believe you may be using an incorrect version of PyTorch -- the code was tested with a now ancient version 0.4, but PyTorch has undergone several changes since then. For reproducing exactly, you may have to downgrade your PyTorch version.
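One version change that plausibly explains the float index: in newer PyTorch releases, `/` on integer tensors performs true division and returns a floating-point tensor, whereas `//` keeps an integer dtype. The sketch below assumes an illustrative `num_nodes` and a hypothetical tensor `best_ids` of flattened top-k indices; it only checks dtypes and values.

```python
import torch

num_nodes = 4  # illustrative value
# Hypothetical flattened indices over beam_size * num_nodes candidates (int64)
best_ids = torch.tensor([[9, 2, 7]])

true_div = best_ids / num_nodes    # true division: result is a floating-point tensor
floor_div = best_ids // num_nodes  # floor division: result stays int64

print(true_div.dtype, true_div)
print(floor_div.dtype, floor_div)
```

A tensor like `true_div` passed on to `gather()` reproduces the error in this issue, while `floor_div` is a valid index.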

@chaitjo
Owner

chaitjo commented Dec 5, 2021

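Here are some related issues and discussions for reference:

* [Error in beamsearch.py #11](https://github.com/chaitjo/graph-convnet-tsp/issues/11)

* [GNN Encoder learning-tsp#1 (comment)](https://github.com/chaitjo/learning-tsp/issues/1#issuecomment-814688655)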

Maybe one simple thing to try first to get the code to run would be to update backpointers via integer division:

```python
# Update backpointers
prev_k = bestScoresId // self.num_nodes
```
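To show where the backpointer fits, here is a minimal sketch of one beam search step under stated assumptions. It is not the repository's code: the function `beam_step`, the tensor layout `(batch_size, beam_size, num_nodes)`, and the convention that the mask holds 1 for unvisited nodes are all hypothetical. Scores for every (beam, next node) pair are flattened, the top `beam_size` are kept, integer division recovers the parent beam, and the mask is re-indexed with an int64 `perm_mask`.

```python
import torch

def beam_step(scores, mask, beam_size):
    """One hypothetical beam search step.

    scores: (batch_size, beam_size, num_nodes), assumed to already include each beam's cumulative score
    mask:   (batch_size, beam_size, num_nodes), 1.0 for unvisited nodes, 0.0 for visited
    """
    batch_size, _, num_nodes = scores.shape
    # Flatten beams and candidate nodes so topk ranks all beam_size * num_nodes extensions jointly
    flat = scores.reshape(batch_size, -1)
    best_scores, best_ids = flat.topk(beam_size, dim=1)  # best_ids is int64

    # Backpointers: integer division keeps the parent-beam index as int64
    prev_k = best_ids // num_nodes
    # Node appended to that parent beam
    next_nodes = best_ids - prev_k * num_nodes

    # Re-index the mask so each surviving beam inherits its parent's mask row
    perm_mask = prev_k.unsqueeze(2).expand_as(mask)  # int64, as gather() requires
    new_mask = mask.gather(1, perm_mask)
    # Mark the newly appended node as visited in each surviving beam
    new_mask = new_mask.scatter(2, next_nodes.unsqueeze(2), 0.0)
    return best_scores, prev_k, next_nodes, new_mask

# Usage: batch of 1, 2 beams, 4 nodes (illustrative scores)
scores = torch.tensor([[[0.0, -1.0, -2.0, -3.0],
                        [-0.5, -4.0, -0.2, -5.0]]])
mask = torch.ones(1, 2, 4)
best_scores, prev_k, next_nodes, new_mask = beam_step(scores, mask, beam_size=2)
print(prev_k, next_nodes, new_mask)
```

On PyTorch 1.8 and later, `torch.div(best_ids, num_nodes, rounding_mode="floor")` is an equivalent spelling that avoids the `//` deprecation warning some intermediate releases emit; for non-negative indices both give the same result.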

@Wastedzz
Author

Wastedzz commented Dec 6, 2021

> Here are some related issues and discussions for reference:
>
> * [Error in beamsearch.py #11](https://github.com/chaitjo/graph-convnet-tsp/issues/11)
>
> * [GNN Encoder learning-tsp#1 (comment)](https://github.com/chaitjo/learning-tsp/issues/1#issuecomment-814688655)
>
> Maybe one simple thing to try first to get the code to run would be to update backpointers via integer division:
>
> ```python
> # Update backpointers
> prev_k = bestScoresId // self.num_nodes
> ```

You are right, thanks for your reply!

Wastedzz closed this as completed Dec 6, 2021
@maoxiaowei97

> Hi @Wastedzz @maoxiaowei97, thank you for your interest. I believe you may be using an incorrect version of PyTorch -- the code was tested with a now ancient version 0.4, but PyTorch has undergone several changes since then. For reproducing exactly, you may have to downgrade your PyTorch version.

You are right. I've figured it out, and it worked. Thank you so much!

@chaitjo
Owner

chaitjo commented Dec 7, 2021

Great, happy to help, no worries.

@kumarr99

kumarr99 commented Apr 5, 2023

```python
perm_mask = perm_mask.type(torch.int64)
```

Just add this line before the `gather()` call to change the dtype to int64.
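The cast works for the `gather()` call because truncating a non-negative float quotient gives the same integer as floor division. The sketch below assumes, as the owner's suggestion implies, that `perm_mask` is derived from the backpointer quotient; `best_ids`, `num_nodes`, and the derived node index are hypothetical names for illustration. Whether any other variable in `beamsearch.py` reuses the float quotient is not stated in this thread, so the second half is only a caveat worth checking: a node index computed from a true-division quotient collapses to zero, which fixing the division itself avoids.

```python
import torch

num_nodes = 4
best_ids = torch.tensor([[13, 6, 3]])  # hypothetical flattened top-k indices (int64)

# Backpointers computed with true division, as happens on newer PyTorch
prev_k_float = best_ids / num_nodes  # 3.25, 1.5, 0.75

# Workaround from this comment: cast just before gather(); truncates toward zero
perm_idx = prev_k_float.type(torch.int64)

# For non-negative values, truncation matches floor division, so the gather index is correct
assert torch.equal(perm_idx, best_ids // num_nodes)

# Caveat: a node index derived from the float quotient is wrong,
# since best_ids - (best_ids / num_nodes) * num_nodes is zero here
next_from_float = best_ids - prev_k_float * num_nodes
next_from_int = best_ids - (best_ids // num_nodes) * num_nodes  # [[1, 2, 3]]
print(next_from_float, next_from_int)
```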

4 participants