
Support forward with multiple arguments #251

Open
joaolcguerreiro opened this issue May 9, 2023 · 4 comments


@joaolcguerreiro

Imagine I have a module like this:

import torch.nn as nn

class Model(nn.Module):
    def __init__(self, generator, discriminator):
        super(Model, self).__init__()

        # Define Generator
        self.generator = generator

        # Define Discriminator
        self.discriminator = discriminator

    def forward(self, lr, hr):
        gen = self.generator(lr)

        return gen, self.discriminator(gen), self.discriminator(hr)

If I want to call summary(model, input_size=..., depth=1) what should the input_size look like? Is it supported?

I believe the summary function could handle an input_size given as a list, meaning forward would receive as many arguments as there are elements in the list.
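Purely as an illustration of what I mean (the shapes here are made up), the call I have in mind would look something like this, with one size per positional argument of forward(lr, hr):

# Hypothetical call -- I don't know whether input_size accepts this form,
# which is exactly what I'm asking.
summary(model, input_size=[(1, 3, 64, 64), (1, 3, 256, 256)], depth=1)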

@snimu
Contributor

snimu commented May 9, 2023

It is not supported via input_size, but you can easily work around that by using the input_data argument. For example:

import torch
from torchinfo import summary

generator, discriminator = ...
model = Model(generator, discriminator)
lr = torch.randn(1, 2, 3, 4, 5)  # whatever lr is
hr = torch.randn(2, 5, 2, 5)  # whatever hr is

summary(model, input_data=(lr, hr), depth=1)

Because I don't know what the generator or discriminator is, or what lr and hr are, I cannot be more specific, and I don't know for sure that this will work. But in principle you can just generate pseudo-data and give that to summary; by packaging multiple inputs into a single tuple or list, you can handle models like yours.

If you try that and it still fails, write again; I would then need more detail to look into it more closely.
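For what it's worth, here is a self-contained sketch of the same idea with trivial stand-in networks (the Conv2d/Linear stand-ins and the shapes are made up just so the example runs end to end; they are not your real generator/discriminator):

import torch
import torch.nn as nn
from torchinfo import summary

class Model(nn.Module):
    def __init__(self, generator, discriminator):
        super().__init__()
        self.generator = generator
        self.discriminator = discriminator

    def forward(self, lr, hr):
        gen = self.generator(lr)
        return gen, self.discriminator(gen), self.discriminator(hr)

# Trivial stand-ins so the example runs; replace with your real networks.
generator = nn.Conv2d(3, 3, kernel_size=3, padding=1)
discriminator = nn.Sequential(nn.Flatten(), nn.Linear(3 * 16 * 16, 1))

model = Model(generator, discriminator)
lr = torch.randn(1, 3, 16, 16)
hr = torch.randn(1, 3, 16, 16)

# forward(lr, hr) takes two positional arguments, so pack both tensors
# into one tuple for input_data.
summary(model, input_data=(lr, hr), depth=1)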

@neverstoplearn

@snimu, I hit the same error. My code looks like this:
def test_batch(self, img, label):
    self.model.eval()

    with torch.no_grad():
        label_input, label_length, label_target = self.converter.test_encode(label)
        if self.use_gpu:
            img = img.cuda()
            # print(img.shape)
            label_input = label_input.cuda()

        if self.need_text:
            pred = self.model((img, label_input))
            from torchinfo import summary
            print(img.shape, label_input.shape)
            lr = torch.randn(288, 1, 32, 100)
            hr = torch.randn(288, 1)
            summary(self.model, input_data=(lr, hr), depth=1)
        else:
            pred = self.model((img,))

        pred, prob = self.postprocess(pred, self.postprocess_cfg)
        self.metric.measure(pred, prob, label)
        self.backup_metric.measure(pred, prob, label)
but I got this error:
torch.Size([288, 1, 32, 100]) torch.Size([288, 1])
Traceback (most recent call last):
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 288, in forward_pass
    _ = model.to(device)(*x, **kwargs)
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "tools/test.py", line 46, in <module>
    main()
  File "tools/test.py", line 42, in main
    runner()
  File "tools/../vedastr/runners/test_runner.py", line 49, in __call__
    self.test_batch(img, label)
  File "tools/../vedastr/runners/test_runner.py", line 32, in test_batch
    summary(self.model, input_data=(lr, hr), depth=1)
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 219, in summary
    model, x, batch_dim, cache_forward_pass, device, model_mode, **kwargs
  File "/home/zhengxin/anaconda3/envs/torch182/lib/python3.7/site-packages/torchinfo/torchinfo.py", line 300, in forward_pass
    ) from e
RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []
Can you help me? Thanks.

@snimu
Contributor

snimu commented May 12, 2023

It looks like your model's forward method only takes one input, but you have given it two.

Here is the call that effectively happens inside summary, given your arguments:

# Setup:
model = Model()

# The call:
model(lr, hr)

You can see this from this part of the error message: TypeError: forward() takes 2 positional arguments but 3 were given. The three arguments are self, lr, and hr (self is passed automatically). You have not provided your model's code, but it seems clear that its forward pass takes only a single argument besides self.
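In case it helps, here is a minimal sketch of the adjustment I would try, assuming your forward takes a single tuple, as the call self.model((img, label_input)) in your snippet suggests (I have not run this against your model, and I am going from memory on how summary unpacks input_data):

import torch
from torchinfo import summary

lr = torch.randn(288, 1, 32, 100)
hr = torch.randn(288, 1)

# Nest the (lr, hr) tuple inside an outer list: summary unpacks the outer
# container, so forward() should receive a single tuple -- mirroring the
# call self.model((img, label_input)) in your snippet.
summary(self.model, input_data=[(lr, hr)], depth=1)  # inside test_batch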

@snimu
Contributor

snimu commented May 29, 2023

@joaolcguerreiro Is your issue resolved?
