Output value of Resampler (IP-Adapter-Plus-SDXL) #434

Open
232525 opened this issue Oct 21, 2024 · 1 comment
232525 commented Oct 21, 2024

Hello, I tried printing the intermediate values of the Resampler and found that the outputs are strangely large (min: -12296, max: 1219).
I also tried to train IP-Adapter-Plus-SDXL in fp16, which easily overflows numerically and produces a NaN loss. Have you adopted any tricks to stabilize training?

    def forward(self, x):
        # Optionally add learned positional embeddings to the image tokens.
        if self.pos_emb is not None:
            n, device = x.shape[1], x.device
            pos_emb = self.pos_emb(torch.arange(n, device=device))
            x = x + pos_emb

        # Broadcast the learned query latents across the batch dimension.
        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        # Optionally prepend latents derived from the mean-pooled sequence.
        if self.to_latents_from_mean_pooled_seq:
            meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
            latents = torch.cat((meanpooled_latents, latents), dim=-2)

        # Perceiver-style blocks: residual cross-attention, then residual FF.
        # The residual sums let activation magnitudes grow layer by layer.
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
            print("image (attn):", latents.min(), latents.max(), latents.mean())
        print("image (bfr proj_out):", latents.min(), latents.max(), latents.mean())

        latents = self.proj_out(latents)
        print("image (bfr norm):", latents.min(), latents.max(), latents.mean())
        return self.norm_out(latents)

image (attn): tensor(-55.5000, device='cuda:0', dtype=torch.float16) tensor(41.8750, device='cuda:0', dtype=torch.float16) tensor(-0.0103, device='cuda:0', dtype=torch.float16)
image (attn): tensor(-82.1875, device='cuda:0', dtype=torch.float16) tensor(64.3750, device='cuda:0', dtype=torch.float16) tensor(-0.0302, device='cuda:0', dtype=torch.float16)
image (attn): tensor(-271.7500, device='cuda:0', dtype=torch.float16) tensor(116.8750, device='cuda:0', dtype=torch.float16) tensor(-0.0097, device='cuda:0', dtype=torch.float16)
image (attn): tensor(-252.5000, device='cuda:0', dtype=torch.float16) tensor(184.6250, device='cuda:0', dtype=torch.float16) tensor(0.0262, device='cuda:0', dtype=torch.float16)
image (bfr proj_out): tensor(-252.5000, device='cuda:0', dtype=torch.float16) tensor(184.6250, device='cuda:0', dtype=torch.float16) tensor(0.0262, device='cuda:0', dtype=torch.float16)
image (bfr norm): tensor(-12296., device='cuda:0', dtype=torch.float16) tensor(1219., device='cuda:0', dtype=torch.float16) tensor(0.5435, device='cuda:0', dtype=torch.float16)
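
For reference, one common mitigation for this kind of fp16 overflow (not a confirmed fix from the IP-Adapter authors) is to run the overflow-prone tail of the Resampler in fp32. The sketch below assumes training under torch.autocast mixed precision, so the parameters are already stored in fp32 and only the activations need upcasting; resampler_tail_fp32 is a hypothetical helper, not part of the repo. fp16 saturates at ~65504, and the values logged above already reach -12296 after proj_out, so the LayerNorm statistics that follow can easily overflow.

    import torch

    def resampler_tail_fp32(self, latents):
        # Hypothetical replacement for the last two ops of Resampler.forward.
        # Disable autocast and upcast the activations so that proj_out and
        # norm_out run entirely in fp32; under autocast the module parameters
        # are already fp32, so no weight cast is needed.
        with torch.autocast(device_type="cuda", enabled=False):
            return self.norm_out(self.proj_out(latents.float()))

Alternatively, training in bf16 sidesteps the problem entirely, since bf16 has the same exponent range as fp32 and does not saturate at these magnitudes.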

@ValentiaSulli

Hello, how did you solve this problem?
