Skip to content

Commit

Permalink
Fix syntax error and implement global average pooling for confidence …
Browse files Browse the repository at this point in the history
…estimation
  • Loading branch information
devin-ai-integration[bot] committed Nov 14, 2024
1 parent 2025439 commit d97446d
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions models/analysis/multimodal_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __init__(self, hidden_size: int):
nn.Linear(1536, 768) # Final output dimension
)

# Global average pooling followed by confidence estimation
self.confidence_estimator = nn.Sequential(
nn.Linear(768, 384),
nn.ReLU(),
Expand Down Expand Up @@ -210,8 +211,13 @@ def forward(
unified_features = self.integration_network(combined_features)
print(f"Unified features shape: {unified_features.shape}")

# Estimate confidence
confidence = self.confidence_estimator(unified_features)
# Global average pooling for confidence estimation
pooled_features = torch.mean(unified_features, dim=1) # Average across sequence length
print(f"Pooled features shape: {pooled_features.shape}")

# Estimate confidence (single value per protein)
confidence = self.confidence_estimator(pooled_features)
print(f"Confidence shape: {confidence.shape}")

return {
'unified_features': unified_features,
Expand Down

0 comments on commit d97446d

Please sign in to comment.