AssertionError when working with classes which are not in ImageNet #16

Open
saqibns opened this issue Jan 25, 2020 · 1 comment


saqibns commented Jan 25, 2020

The function `one_hot_from_names` throws an `AssertionError` when it is given a class name that is not one of the original ImageNet classes and for which no matching ImageNet synsets can be found either.

No index is appended to `classes` for such a name, so the list ends up shorter than `batch_size`. Because `batch_size` is not updated after the words are converted to their indices, the `assert batch_size == len(int_or_list)` check in `one_hot_from_int` (`utils.py`) fails.
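The failure can be reproduced without BigGAN or NLTK. In the minimal sketch below, `IMAGENET_INDEX` is a hypothetical two-entry stand-in for the WordNet-based lookup, and the `_sketch` functions are simplified, hypothetical versions of the two helpers in `utils.py`, not the library code itself.

```python
NUM_CLASSES = 1000

# Hypothetical stand-in for the WordNet -> ImageNet lookup in utils.py:
# only the names in this small dict resolve; anything else yields None.
IMAGENET_INDEX = {"bagel": 931, "ice cream": 928}


def one_hot_from_int_sketch(indices, batch_size=1):
    """Simplified version of the batch-size handling shown in the traceback."""
    if len(indices) == 1 and batch_size > 1:
        indices = [indices[0]] * batch_size  # repeat one class to fill the batch
    assert batch_size == len(indices)        # the check that fails for ['cake']
    rows = [[0.0] * NUM_CLASSES for _ in range(batch_size)]
    for row, idx in zip(rows, indices):
        row[idx] = 1.0
    return rows


def one_hot_from_names_sketch(names, batch_size=1):
    """Hypothetical name lookup that, like the report describes, drops misses."""
    batch_size = max(batch_size, len(names))
    classes = []
    for name in names:
        idx = IMAGENET_INDEX.get(name)
        if idx is not None:
            classes.append(idx)  # unresolved names are skipped silently...
    # ...so classes can be shorter than batch_size at this point
    return one_hot_from_int_sketch(classes, batch_size=batch_size)
```

Here `one_hot_from_names_sketch(['cake'], batch_size=1)` raises `AssertionError` because `classes` is `[]` while `batch_size` is 1. In this sketch, `['bagel', 'cake']` does not fail at all: the single resolved index is repeated to fill a batch of two, so an unresolved name can also go unnoticed.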

The following lines reproduce the error:

```python
import torch
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names

model = BigGAN.from_pretrained('biggan-deep-256')
# 'cake' has no ImageNet match, so this raises AssertionError
class_vector = one_hot_from_names(['cake'], batch_size=1)
```

This throws the following error:

```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-4-4dd4cd1296e1> in <module>()
      1 
----> 2 class_vector = one_hot_from_names(['cake'], batch_size=1)

/home/saqib/Projects/Poem2GIF/repos/pytorch-pretrained-BigGAN/pytorch_pretrained_biggan/utils.py in one_hot_from_names(class_name_or_list, batch_size)
    211                 classes.append(IMAGENET[possible_synsets[0].offset()])
    212 
--> 213     return one_hot_from_int(classes, batch_size=batch_size)
    214 
    215 

/home/saqib/Projects/Poem2GIF/repos/pytorch-pretrained-BigGAN/pytorch_pretrained_biggan/utils.py in one_hot_from_int(int_or_list, batch_size)
    164         int_or_list = [int_or_list[0]] * batch_size
    165 
--> 166     assert batch_size == len(int_or_list)
    167 
    168     array = np.zeros((batch_size, NUM_CLASSES), dtype=np.float32)

AssertionError: 
```
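One way to make this failure explicit, sketched below with a hypothetical lookup table, is to collect the names that do not resolve and raise a descriptive `ValueError` before building the batch. Updating `batch_size` to `len(classes)` would also silence the assertion, but for `['cake']` it would likely return a batch with zero rows instead of an error. `one_hot_from_names_checked` and `IMAGENET_INDEX` are illustrative names, not part of `pytorch_pretrained_biggan`.

```python
NUM_CLASSES = 1000

# Hypothetical subset of the name -> ImageNet index mapping (not the library's table).
IMAGENET_INDEX = {"bagel": 931, "ice cream": 928}


def one_hot_from_names_checked(names, batch_size=1):
    """Sketch of a stricter lookup: fail loudly on names with no ImageNet match."""
    if isinstance(names, str):
        names = [names]
    batch_size = max(batch_size, len(names))
    classes, missing = [], []
    for name in names:
        idx = IMAGENET_INDEX.get(name)
        if idx is None:
            missing.append(name)
        else:
            classes.append(idx)
    if missing:
        # Report every unresolved name instead of a bare AssertionError.
        raise ValueError(f"No ImageNet class found for: {', '.join(missing)}")
    if len(classes) == 1:
        classes = classes * batch_size  # repeat a single class to fill the batch
    if len(classes) != batch_size:
        raise ValueError(f"Got {len(classes)} classes for batch_size={batch_size}")
    rows = [[0.0] * NUM_CLASSES for _ in range(batch_size)]
    for row, idx in zip(rows, classes):
        row[idx] = 1.0
    return rows
```

With this check, `one_hot_from_names_checked(['cake'])` raises `ValueError: No ImageNet class found for: cake`, which points directly at the offending name.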

@arslanan

Hi, I could not find 'cake' among the ImageNet classes.
I tried "bagel" (931) and "ice cream" (928) instead; here are the results:
[Generated images: output_0, output_1]

You can find the full list of ImageNet classes here: https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt
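Following this suggestion, a name can be checked against the class list before it is used. The sketch below assumes the linked file is a Python dict literal mapping each class index to a comma-separated string of synonyms (the format the raw gist appears to use) and parses an inline, hypothetical two-entry excerpt with `ast.literal_eval`; `load_labels` and `find_class_indices` are illustrative helpers, not library functions.

```python
import ast

# Hypothetical excerpt in the format the linked gist appears to use:
# a Python dict literal mapping class index -> comma-separated synonyms.
LABELS_TEXT = """{928: 'ice cream, icecream',
 931: 'bagel, beigel'}"""


def load_labels(text):
    """Parse the dict literal safely (literal_eval does not run arbitrary code)."""
    return ast.literal_eval(text)


def find_class_indices(labels, name):
    """Return indices whose synonym list contains `name` exactly (case-insensitive)."""
    name = name.strip().lower()
    return [
        idx
        for idx, synonyms in sorted(labels.items())
        if name in (s.strip().lower() for s in synonyms.split(","))
    ]


labels = load_labels(LABELS_TEXT)
for query in ["bagel", "ice cream", "cake"]:
    print(query, find_class_indices(labels, query))
```

Indices found this way can be passed to `one_hot_from_int` (the function in the traceback above) instead of going through `one_hot_from_names`, e.g. `one_hot_from_int([931, 928], batch_size=2)`.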
