Skip to content

Commit

Permalink
Add debug prints for tensor shapes in UnifiedPredictor
Browse files Browse the repository at this point in the history
  • Loading branch information
devin-ai-integration[bot] committed Nov 14, 2024
1 parent 2e7620f commit ddc0850
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions models/analysis/multimodal_integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,27 +162,43 @@ def forward(
structure_features: torch.Tensor,
function_results: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
# Add debug prints for input shapes
print(f"Sequence features shape: {sequence_features.shape}")
print(f"Structure features shape: {structure_features.shape}")
print(f"Function results GO terms shape: {function_results['go_terms'].shape}")

# Transform function features to match dimensions
function_features = self.function_encoder(function_results['go_terms'])
print(f"Transformed function features shape: {function_features.shape}")

# Ensure all features have the same dimensions before combining
batch_size = sequence_features.size(0)
seq_len = sequence_features.size(1)
feature_dim = sequence_features.size(2) # Should be 768

# Reshape features if needed
sequence_features = sequence_features.view(batch_size, seq_len, feature_dim)
structure_features = structure_features.view(batch_size, seq_len, feature_dim)
function_features = function_features.view(batch_size, seq_len, feature_dim)

# Reshape function features if needed
function_features = function_features.view(batch_size, seq_len, -1)
print(f"Reshaped sequence features: {sequence_features.shape}")
print(f"Reshaped structure features: {structure_features.shape}")
print(f"Reshaped function features: {function_features.shape}")

# Combine all features
combined_features = torch.cat([
sequence_features,
structure_features,
function_features
], dim=-1)
], dim=-1) # Concatenate along feature dimension

print(f"Combined features shape: {combined_features.shape}")

# Generate unified representation
# Process through integration network
unified_features = self.integration_network(combined_features)
print(f"Unified features shape: {unified_features.shape}")

# Estimate prediction confidence
# Estimate confidence
confidence = self.confidence_estimator(unified_features)

return {
Expand Down

0 comments on commit ddc0850

Please sign in to comment.