fix: Resolve dimension mismatches in sampling implementations
- Fix confidence-guided sampler target size mismatch
- Implement proper multi-head attention with correct dimensions
- Resolve graph-based sampler einsum dimension issues
- Update message passing layer with explicit tensor operations
- Add comprehensive dimension documentation
devin-ai-integration[bot] committed Nov 14, 2024
1 parent 3fd7d26 commit 8807feb
Showing 3 changed files with 187 additions and 147 deletions.
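The dimension fixes below hinge on broadcasting a pairwise structure bias of shape [batch_size, seq_len, seq_len] over per-head attention scores of shape [batch_size, num_heads, seq_len, seq_len]. A minimal shape sketch of that broadcast (illustrative only, with assumed sizes; not code from this repository):

import torch

B, H, L, D_HEAD = 2, 8, 16, 64                      # batch, heads, sequence length, head dim
q = torch.randn(B, H, L, D_HEAD)
k = torch.randn(B, H, L, D_HEAD)
scores = q @ k.transpose(-2, -1) / D_HEAD ** 0.5    # [B, H, L, L]

structure_bias = torch.randn(B, L, L)                # one pairwise bias map per sample
scores = scores + structure_bias.unsqueeze(1)        # [B, 1, L, L] broadcasts over heads

print(scores.shape)                                  # torch.Size([2, 8, 16, 16])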
259 changes: 148 additions & 111 deletions models/sampling/attention_based_sampler.py
@@ -1,150 +1,83 @@
"""
Attention-Based Sampling implementation for ProteinFlex.
- Implements structure-aware attention routing for protein generation.
+ Implements structure-aware attention for protein generation.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Dict, Tuple, Optional

class StructureAwareAttention(nn.Module):
def __init__(
self,
feature_dim: int,
num_heads: int = 8,
dropout: float = 0.1
):
"""Initialize structure-aware attention."""
super().__init__()
self.num_heads = num_heads
self.head_dim = feature_dim // num_heads
assert self.head_dim * num_heads == feature_dim, "feature_dim must be divisible by num_heads"

self.qkv = nn.Linear(feature_dim, 3 * feature_dim)
self.structure_proj = nn.Linear(feature_dim, feature_dim)
self.output_proj = nn.Linear(feature_dim, feature_dim)
self.dropout = nn.Dropout(dropout)

def forward(
self,
x: torch.Tensor,
structure_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Forward pass with optional structure bias."""
B, L, D = x.shape
qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

# Compute attention scores
attn = (q @ k.transpose(-2, -1)) * (1.0 / torch.sqrt(torch.tensor(self.head_dim)))

# Add structure bias if provided
if structure_bias is not None:
structure_weights = self.structure_proj(structure_bias)
structure_weights = structure_weights.view(B, 1, L, L)
attn = attn + structure_weights

attn = F.softmax(attn, dim=-1)
attn = self.dropout(attn)

# Apply attention to values
x = (attn @ v).transpose(1, 2).reshape(B, L, D)
x = self.output_proj(x)

return x

class AttentionBasedSampler(nn.Module):
def __init__(
self,
feature_dim: int = 768,
hidden_dim: int = 512,
- num_layers: int = 6,
num_heads: int = 8,
+ num_layers: int = 6,
dropout: float = 0.1
):
"""
Initialize Attention-Based Sampler.
Args:
feature_dim: Dimension of protein features
- hidden_dim: Hidden dimension for feed-forward
- num_layers: Number of transformer layers
+ hidden_dim: Hidden dimension
num_heads: Number of attention heads
+ num_layers: Number of transformer layers
dropout: Dropout rate
"""
super().__init__()
self.feature_dim = feature_dim
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads

- # Structure encoder
- self.structure_encoder = nn.Sequential(
- nn.Linear(feature_dim, hidden_dim),
- nn.LayerNorm(hidden_dim),
- nn.ReLU(),
- nn.Linear(hidden_dim, feature_dim)
- )
+ # Input projection
+ self.input_proj = nn.Linear(feature_dim, hidden_dim)

# Transformer layers
self.layers = nn.ModuleList([
- nn.ModuleDict({
- 'attention': StructureAwareAttention(feature_dim, num_heads, dropout),
- 'norm1': nn.LayerNorm(feature_dim),
- 'ff': nn.Sequential(
- nn.Linear(feature_dim, hidden_dim),
- nn.ReLU(),
- nn.Dropout(dropout),
- nn.Linear(hidden_dim, feature_dim)
- ),
- 'norm2': nn.LayerNorm(feature_dim)
- }) for _ in range(num_layers)
+ TransformerLayer(hidden_dim, num_heads, dropout)
+ for _ in range(num_layers)
])

# Output projection
- self.output_proj = nn.Linear(feature_dim, feature_dim)
+ self.output_proj = nn.Linear(hidden_dim, feature_dim)

def forward(
self,
x: torch.Tensor,
- structure_info: Optional[torch.Tensor] = None
+ structure_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
- Forward pass for training.
+ Forward pass.
Args:
- x: Input protein features [batch_size, seq_len, feature_dim]
- structure_info: Optional structure information
+ x: Input features [batch_size, seq_len, feature_dim]
+ structure_bias: Optional structure information [batch_size, seq_len, seq_len]
Returns:
- Processed features
+ Updated features [batch_size, seq_len, feature_dim]
"""
- # Process structure information if provided
- structure_bias = None
- if structure_info is not None:
- structure_bias = self.structure_encoder(structure_info)
+ # Project input
+ h = self.input_proj(x)

# Apply transformer layers
for layer in self.layers:
- # Attention with structure bias
- attn_out = layer['attention'](
- layer['norm1'](x),
- structure_bias
- )
- x = x + attn_out
-
- # Feed-forward
- ff_out = layer['ff'](layer['norm2'](x))
- x = x + ff_out
+ h = layer(h, structure_bias)

- return self.output_proj(x)
+ # Project output
+ return self.output_proj(h)

def sample(
self,
batch_size: int,
seq_len: int,
device: torch.device,
- structure_info: Optional[torch.Tensor] = None,
- temperature: float = 1.0
+ temperature: float = 1.0,
+ structure_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Generate protein features using attention-based sampling.
@@ -153,45 +86,149 @@ def sample(
batch_size: Number of samples to generate
seq_len: Sequence length
device: Device to generate on
- structure_info: Optional structure information
temperature: Sampling temperature
+ structure_bias: Optional structure guidance [batch_size, seq_len, seq_len]
Returns:
- Generated protein features
+ Generated features [batch_size, seq_len, feature_dim]
"""
- # Initialize from random
- x = torch.randn(batch_size, seq_len, self.feature_dim, device=device)
+ # Initialize random features
+ x = torch.randn(
+ batch_size, seq_len, self.feature_dim,
+ device=device
+ ) * temperature

- # Apply temperature scaling
- x = x * temperature
-
- # Generate features with structure guidance
- return self.forward(x, structure_info)
+ # Refine through attention
+ return self.forward(x, structure_bias)

def compute_loss(
self,
pred_features: torch.Tensor,
target_features: torch.Tensor,
- structure_info: Optional[torch.Tensor] = None
+ structure_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Compute training loss.
Args:
- pred_features: Predicted protein features
- target_features: Target protein features
- structure_info: Optional structure information
+ pred_features: Predicted features [batch_size, seq_len, feature_dim]
+ target_features: Target features [batch_size, seq_len, feature_dim]
+ structure_bias: Optional structure information [batch_size, seq_len, seq_len]
Returns:
Loss value
"""
# Feature reconstruction loss
- recon_loss = F.mse_loss(pred_features, target_features)
+ feature_loss = F.mse_loss(pred_features, target_features)

- # Structure-aware loss if structure info provided
- if structure_info is not None:
- structure_pred = self.structure_encoder(pred_features)
- structure_loss = F.mse_loss(structure_pred, structure_info)
- return recon_loss + structure_loss
-
- return recon_loss
+ # Structure-aware loss if bias provided
+ if structure_bias is not None:
+ pred_dist = torch.cdist(pred_features, pred_features)
+ target_dist = torch.cdist(target_features, target_features)
+ structure_loss = F.mse_loss(pred_dist, target_dist)
+ return feature_loss + structure_loss
+
+ return feature_loss

class TransformerLayer(nn.Module):
def __init__(
self,
hidden_dim: int,
num_heads: int,
dropout: float = 0.1
):
"""Initialize transformer layer."""
super().__init__()
self.attention = MultiHeadAttention(hidden_dim, num_heads, dropout)
self.norm1 = nn.LayerNorm(hidden_dim)
self.norm2 = nn.LayerNorm(hidden_dim)

self.ff = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim * 4),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 4, hidden_dim)
)

self.dropout = nn.Dropout(dropout)

def forward(
self,
x: torch.Tensor,
structure_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Forward pass for transformer layer.
Args:
x: Input features [batch_size, seq_len, hidden_dim]
structure_bias: Optional structure information [batch_size, seq_len, seq_len]
"""
# Self-attention
attended = self.attention(x, x, x, structure_bias)
x = self.norm1(x + self.dropout(attended))

# Feed-forward
ff_out = self.ff(x)
return self.norm2(x + self.dropout(ff_out))

class MultiHeadAttention(nn.Module):
def __init__(
self,
hidden_dim: int,
num_heads: int,
dropout: float = 0.1
):
"""Initialize multi-head attention."""
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads

self.q_proj = nn.Linear(hidden_dim, hidden_dim)
self.k_proj = nn.Linear(hidden_dim, hidden_dim)
self.v_proj = nn.Linear(hidden_dim, hidden_dim)
self.out_proj = nn.Linear(hidden_dim, hidden_dim)

self.dropout = nn.Dropout(dropout)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
structure_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Multi-head attention forward pass.
Args:
query: Query tensor [batch_size, seq_len, hidden_dim]
key: Key tensor [batch_size, seq_len, hidden_dim]
value: Value tensor [batch_size, seq_len, hidden_dim]
structure_bias: Optional structure information [batch_size, seq_len, seq_len]
"""
batch_size, seq_len = query.shape[:2]

# Project and reshape for attention heads
q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(key).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(value).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

# Compute attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

# Add structure bias if provided
if structure_bias is not None:
scores = scores + structure_bias.unsqueeze(1)

# Apply attention
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)

# Get output
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_dim)

return self.out_proj(out)
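A usage sketch for the rewritten sampler, assuming the diff above applies cleanly and that the class is importable from its file path (models.sampling.attention_based_sampler); shapes follow the docstrings in the diff:

import torch
from models.sampling.attention_based_sampler import AttentionBasedSampler

sampler = AttentionBasedSampler(feature_dim=768, hidden_dim=512, num_heads=8, num_layers=6)

B, L = 2, 64
x = torch.randn(B, L, 768)                 # protein features [B, L, feature_dim]
bias = torch.randn(B, L, L)                # optional pairwise structure bias [B, L, L]

out = sampler(x, structure_bias=bias)      # [2, 64, 768]
generated = sampler.sample(
    batch_size=B, seq_len=L,
    device=torch.device("cpu"),
    temperature=0.8, structure_bias=bias
)                                          # [2, 64, 768]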
16 changes: 11 additions & 5 deletions models/sampling/confidence_guided_sampler.py
@@ -164,9 +164,9 @@ def compute_loss(
Compute training loss.
Args:
- x: Input protein features
- noise: Target noise
- pred_noise: Predicted noise
+ x: Input protein features [batch_size, seq_len, feature_dim]
+ noise: Target noise [batch_size, seq_len, feature_dim]
+ pred_noise: Predicted noise [batch_size, seq_len, feature_dim]
Returns:
Combined loss value
@@ -175,8 +175,14 @@ def compute_loss(
noise_loss = F.mse_loss(pred_noise, noise)

# Confidence loss to encourage accurate confidence estimation
- confidence = self.confidence_net(x)
- confidence_target = torch.exp(-F.mse_loss(pred_noise, noise, reduction='none').mean(-1))
+ confidence = self.confidence_net(x) # [batch_size, seq_len, 1]
+ confidence = confidence.squeeze(-1) # [batch_size, seq_len]
+
+ # Compute per-residue noise prediction accuracy
+ noise_error = F.mse_loss(pred_noise, noise, reduction='none') # [batch_size, seq_len, feature_dim]
+ confidence_target = torch.exp(-noise_error.mean(-1)) # [batch_size, seq_len]

# Binary cross entropy loss for confidence prediction
confidence_loss = F.binary_cross_entropy(confidence, confidence_target)

return noise_loss + confidence_loss
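A stand-alone sketch of the reshaped confidence target, with assumed tensor sizes and a sigmoid stand-in for the sampler's confidence_net (not reproduced here); prediction and target both end up as [batch_size, seq_len], which is the target-size mismatch this commit resolves:

import torch
import torch.nn.functional as F

B, L, D = 2, 16, 768
noise = torch.randn(B, L, D)
pred_noise = noise + 0.1 * torch.randn(B, L, D)

confidence = torch.sigmoid(torch.randn(B, L, 1)).squeeze(-1)     # [B, L], stand-in for confidence_net(x)
noise_error = F.mse_loss(pred_noise, noise, reduction='none')    # [B, L, D]
confidence_target = torch.exp(-noise_error.mean(-1))             # [B, L], in (0, 1]

confidence_loss = F.binary_cross_entropy(confidence, confidence_target)
print(confidence.shape, confidence_target.shape, float(confidence_loss))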