Skip to content

Commit

Permalink
Update classifier.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fffffgggg54 committed Jul 31, 2024
1 parent 2a0254f commit e2152dc
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion timm/layers/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def __init__(
self.num_classes = 1000 if num_classes is None else num_classes
feature_dims = self.model.feature_info.channels()
self.decoders = nn.ModuleList([CrossAttention(embed_dim, kv_dim = dim, **kwargs) for dim in feature_dims])
self.queries = nn.ParameterList([nn.Parameter(torch.randn(32, embed_dim)) for _ in feature_dims])
self.queries = nn.ParameterList([nn.Parameter(torch.randn(1, embed_dim)) for _ in feature_dims])
self.norms = nn.ModuleList([create_norm_layer('layernorm2d', dim) for dim in feature_dims])
self.encoder = nn.Sequential(*[TransformerBlock(embed_dim, **kwargs) for _ in range(depth)])
self.head_norm = nn.LayerNorm(embed_dim)
Expand Down

0 comments on commit e2152dc

Please sign in to comment.