
summary not working with a custom pre-processing layer #282

Open · jaroslawjanas opened this issue Nov 5, 2023 · 2 comments

@jaroslawjanas

Describe the bug
I have a custom TextVectorization layer. It doesn't use any torch.nn submodules; it's just a dictionary of words used to fill in a torch.zeros vector. I want it baked into the model, so I put it as the first layer.

Unfortunately, it doesn't work with torchinfo.summary(model, input_data=["test"] * batch_size), which is bothersome.

model.forward(["this is a test"]) works just fine, so I am fairly confident the problem is that torchinfo cannot handle my custom layer. summary worked fine before I added it (with random int tokens as input data).

Code and Setup
TextVectorization

from collections import defaultdict

import torch
import torch.nn as nn


class TextVectorization(nn.Module):
    def __init__(self, max_vocabulary, max_tokens):
        super().__init__()
        self.max_tokens = max_tokens
        self.max_vocabulary = max_vocabulary
        self.word_dictionary = dict()
        self.dictionary_size = 0

    def adapt(self, dataset):
        word_frequencies = defaultdict(int)

        for text in dataset:
            for word in text[0].split():
                word_frequencies[word] += 1

        # Sort the dictionary by word frequencies in descending order
        sorted_word_frequencies = dict(sorted(word_frequencies.items(),
                                              key=lambda item: item[1],
                                              reverse=True)
        )

        # Take the top N most frequent words
        most_frequent = list(sorted_word_frequencies.items())[:self.max_vocabulary]
        self.dictionary_size = len(most_frequent)

        # Note starting at 2 since 0 (padding) and 1 (missing) are reserved
        for word_value, (word, frequency) in enumerate(most_frequent, 2):
            self.word_dictionary[word] = word_value

    def vocabulary_size(self):
        return self.dictionary_size

    def dictionary(self):
        return self.word_dictionary

    def forward(self, batch_x):
        batch_text_vectors = torch.zeros((len(batch_x), self.max_tokens), dtype=torch.int32)

        for i, text in enumerate(batch_x):

            # Split the text and tokenize it
            words = text.split()[:self.max_tokens]

            for pos, word in enumerate(words):
                batch_text_vectors[i, pos] = self.word_dictionary.get(word, 1)

        return batch_text_vectors
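
For context, here is a minimal sketch of how the layer is used (the sample texts and sizes below are made up for illustration):

# Hypothetical usage of the TextVectorization layer above.
vectorize_layer = TextVectorization(max_vocabulary=20000, max_tokens=256)

# adapt() expects an iterable of records where record[0] is the raw text;
# it builds a frequency-ranked word -> id dictionary (ids start at 2).
vectorize_layer.adapt([("this is a test",), ("another test sentence",)])

# forward() maps a batch of raw strings to an int32 tensor of shape
# (batch_size, max_tokens); unknown words map to 1, padding stays 0.
tokens = vectorize_layer(["this is a test"])
print(tokens.shape)  # torch.Size([1, 256])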

Model

class TransformerModel(nn.Module):
    def __init__(self, max_tokens, vocab_size, embed_dim, num_heads, ff_dim, vectorize_layer):
        super().__init__()
        self.vectorize_layer = vectorize_layer
        self.embedding_layer = TokenAndPositionEmbedding(
            max_tokens,
            vocab_size,
            embed_dim
        )
        self.transformer_block = TransformerBlock(embed_dim, num_heads, ff_dim)
        self.global_avg_pooling = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(embed_dim, 20)
        self.fc2 = nn.Linear(20, 3)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.vectorize_layer(x)
        x = self.embedding_layer(x)
        x = self.transformer_block(x)
        x = self.global_avg_pooling(x.permute(0, 2, 1)).squeeze(2)
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x
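
For reference, a hypothetical instantiation (the hyperparameter values are assumptions, and TokenAndPositionEmbedding / TransformerBlock are defined elsewhere in the notebook):

# Illustrative wiring only; these values are placeholders, not the notebook's.
model = TransformerModel(
    max_tokens=256,
    vocab_size=20000,
    embed_dim=32,
    num_heads=2,
    ff_dim=32,
    vectorize_layer=vectorize_layer,
)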

Summary

from torchinfo import summary

batch_size = 256  # matches the traceback below: 257 = self + 256 positional args

summary_samples = ["This is a test"] * batch_size
summary(model, input_data=summary_samples)

Runtime Error

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
    294             if isinstance(x, (list, tuple)):
--> 295                 _ = model(*x, **kwargs)
    296             elif isinstance(x, dict):

4 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1567 
-> 1568             result = forward_call(*args, **kwargs)
   1569             if _global_forward_hooks or self._forward_hooks:

TypeError: TransformerModel.forward() takes 2 positional arguments but 257 were given

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-107-41c92f8997e3> in <cell line: 5>()
      3 summary_samples = ["This is a test"] * batch_size
      4 # print(np.shape(summary_samples))
----> 5 summary(model, input_data=summary_samples)
      6 
      7 

/usr/local/lib/python3.10/dist-packages/torchinfo/torchinfo.py in summary(model, input_size, input_data, batch_dim, cache_forward_pass, col_names, col_width, depth, device, dtypes, mode, row_settings, verbose, **kwargs)
    221         input_data, input_size, batch_dim, device, dtypes
    222     )
--> 223     summary_list = forward_pass(
    224         model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs
    225     )

/usr/local/lib/python3.10/dist-packages/torchinfo/torchinfo.py in forward_pass(model, x, batch_dim, cache_forward_pass, device, mode, **kwargs)
    302     except Exception as e:
    303         executed_layers = [layer for layer in summary_list if layer.executed]
--> 304         raise RuntimeError(
    305             "Failed to run torchinfo. See above stack traces for more details. "
    306             f"Executed layers up to: {executed_layers}"

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []
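
One untested workaround, based on the unpacking visible in the trace (forward_pass calls model(*x) when input_data is a list, so each string becomes its own positional argument): wrap the batch in an outer list so the whole list of strings is passed as a single argument.

# Untested sketch: the outer list means model(*x) receives the string batch
# as one positional argument instead of batch_size separate ones.
summary(model, input_data=[summary_samples])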


Environment:

  • Colab
  • PyTorch = 2.1.0+cu118
  • torchinfo = 1.8.0
@allispaul

This is related to #254 and probably also #280. The code expects "tensor-like" input, not strings. Even if this isn't fixed, the error should definitely be caught earlier and stated more clearly. As it stands, process_input doesn't know what to do with this kind of input, and there are related issues coming from traverse_input_data.

I would love to work on this. Does anyone have opinions on what should be done: new functionality to handle text input, or a better error message?
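
For the error-message route, one possible shape (an illustrative sketch only, not torchinfo's actual code; the helper name and message are made up) is an early type check before the forward pass:

# Illustrative only; torchinfo's real input processing differs.
def _reject_string_input(input_data):
    items = input_data if isinstance(input_data, (list, tuple)) else [input_data]
    if any(isinstance(item, str) for item in items):
        raise TypeError(
            "torchinfo expects tensor-like input_data, but got strings. "
            "Pre-tokenize the text or pass integer token tensors instead."
        )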

@TylerYep (Owner) commented Feb 20, 2024

Either solution sounds good to me. The better error message sounds like a good place to start, and then handling text input would be a good followup. PRs are definitely welcome!
