Skip to content

Commit

Permalink
CAN Attention: standard management of heads
Browse files Browse the repository at this point in the history
  • Loading branch information
gbg141 committed May 14, 2024
1 parent 2267768 commit 3652f91
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 87 deletions.
6 changes: 3 additions & 3 deletions test/nn/cell/test_can.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_forward(self):
in_channels_1=2,
out_channels=2,
dropout=0.5,
heads=1,
heads=2,
n_layers=2,
att_lift=False,
pooling=True,
Expand All @@ -31,8 +31,8 @@ def test_forward(self):
).to_sparse()

x_0, x_1 = (
torch.tensor(x_0).float().to(device),
torch.tensor(x_1).float().to(device),
x_0.clone().detach().float().to(device),
x_1.clone().detach().float().to(device),
)
adjacency_1 = adjacency_1.float().to(device)
adjacency_2 = adjacency_1.float().to(device)
Expand Down
13 changes: 7 additions & 6 deletions test/nn/cell/test_can_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class TestCANLayer:
def test_forward(self):
"""Test the forward method of CANLayer."""
in_channels = 7
out_channels = 64
out_channels = 66
dropout_values = [0.5, 0.7]
heads_values = [1, 3]
concat_values = [True, False]
Expand Down Expand Up @@ -65,12 +65,12 @@ def test_forward(self):
)
x_out = can_layer.forward(x_1, lower_neighborhood, upper_neighborhood)
if concat:
assert x_out.shape == (n_cells, out_channels * heads)
else:
assert x_out.shape == (n_cells, out_channels)
else:
assert x_out.shape == (n_cells, out_channels // heads)

# Test if there are no non-zero values in the neighborhood
heads = 1
heads = 3
concat_list = [True, False]
skip_connection = True

Expand All @@ -79,6 +79,7 @@ def test_forward(self):
can_layer = CANLayer(
in_channels=in_channels,
out_channels=out_channels,
heads=heads,
concat=concat,
skip_connection=skip_connection,
version=version,
Expand All @@ -89,9 +90,9 @@ def test_forward(self):
torch.zeros_like(upper_neighborhood),
)
if concat:
assert x_out.shape == (n_cells, out_channels * heads)
else:
assert x_out.shape == (n_cells, out_channels)
else:
assert x_out.shape == (n_cells, out_channels // heads)

def test_reset_parameters(self):
"""Test the reset_parameters method of CANLayer."""
Expand Down
4 changes: 2 additions & 2 deletions topomodelx/nn/cell/can.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
for _ in range(n_layers - 1):
layers.append(
CANLayer(
in_channels=out_channels * heads,
in_channels=out_channels,
out_channels=out_channels,
dropout=dropout,
heads=heads,
Expand All @@ -110,7 +110,7 @@ def __init__(
layers.append(
PoolLayer(
k_pool=k_pool,
in_channels_0=out_channels * heads,
in_channels_0=out_channels,
signal_pool_activation=torch.nn.Sigmoid(),
readout=True,
**kwargs,
Expand Down
49 changes: 23 additions & 26 deletions topomodelx/nn/cell/can_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,9 @@ def __init__(
self.dropout = dropout
self.add_self_loops = add_self_loops

self.lin = torch.nn.Linear(in_channels, heads * out_channels, bias=False)
self.att_weight_src = Parameter(torch.Tensor(1, heads, out_channels))
self.att_weight_dst = Parameter(torch.Tensor(1, heads, out_channels))
self.lin = torch.nn.Linear(in_channels, out_channels, bias=False)
self.att_weight_src = Parameter(torch.Tensor(1, heads, out_channels // heads))
self.att_weight_dst = Parameter(torch.Tensor(1, heads, out_channels // heads))

self.reset_parameters()

Expand All @@ -468,7 +468,7 @@ def message(self, x_source):
"""
# Compute the linear transformation on the source features
x_message = self.lin(x_source).view(
-1, self.heads, self.out_channels
-1, self.heads, self.out_channels // self.heads
) # (n_k_cells, H, C)

# compute the source and target messages
Expand Down Expand Up @@ -534,12 +534,13 @@ def forward(self, x_source, neighborhood):
# If there are no non-zero values in the neighborhood, then the neighborhood is empty. -> return zero tensor
if not neighborhood.values().nonzero().size(0) > 0 and self.concat:
return torch.zeros(
(x_source.shape[0], self.out_channels * self.heads),
(x_source.shape[0], self.out_channels),
device=x_source.device,
) # (n_k_cells, H * C)
if not neighborhood.values().nonzero().size(0) > 0 and not self.concat:
return torch.zeros(
(x_source.shape[0], self.out_channels), device=x_source.device
(x_source.shape[0], self.out_channels // self.heads),
device=x_source.device,
) # (n_k_cells, C)

# Add self-loops to the neighborhood matrix if necessary
Expand All @@ -559,9 +560,7 @@ def forward(self, x_source, neighborhood):

# if concat true, concatenate the messages for each head. Otherwise, average the messages for each head.
if self.concat:
return aggregated_message.view(
-1, self.heads * self.out_channels
) # (n_k_cells, H * C)
return aggregated_message.view(-1, self.out_channels) # (n_k_cells, H * C)

return aggregated_message.mean(dim=1) # (n_k_cells, C)

Expand Down Expand Up @@ -613,7 +612,7 @@ def __init__(
heads: int,
concat: bool,
att_activation: torch.nn.Module,
add_self_loops: bool = False,
add_self_loops: bool = True,
aggr_func: Literal["sum", "mean", "add"] = "sum",
initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
share_weights: bool = False,
Expand All @@ -634,17 +633,13 @@ def __init__(

if share_weights:
self.lin_src = self.lin_dst = torch.nn.Linear(
in_channels, heads * out_channels, bias=False
in_channels, out_channels, bias=False
)
else:
self.lin_src = torch.nn.Linear(
in_channels, heads * out_channels, bias=False
)
self.lin_dst = torch.nn.Linear(
in_channels, heads * out_channels, bias=False
)
self.lin_src = torch.nn.Linear(in_channels, out_channels, bias=False)
self.lin_dst = torch.nn.Linear(in_channels, out_channels, bias=False)

self.att_weight = Parameter(torch.Tensor(1, heads, out_channels))
self.att_weight = Parameter(torch.Tensor(1, heads, out_channels // heads))

self.reset_parameters()

Expand All @@ -671,12 +666,12 @@ def message(self, x_source):
"""
# Compute the linear transformation on the source features
x_src_message = self.lin_src(x_source).view(
-1, self.heads, self.out_channels
-1, self.heads, self.out_channels // self.heads
) # (n_k_cells, H, C)

# Compute the linear transformation on the source features
x_dst_message = self.lin_dst(x_source).view(
-1, self.heads, self.out_channels
-1, self.heads, self.out_channels // self.heads
) # (n_k_cells, H, C)

# Get the source and target projections of the neighborhood
Expand Down Expand Up @@ -737,12 +732,13 @@ def forward(self, x_source, neighborhood):
# If there are no non-zero values in the neighborhood, then the neighborhood is empty. -> return zero tensor
if not neighborhood.values().nonzero().size(0) > 0 and self.concat:
return torch.zeros(
(x_source.shape[0], self.out_channels * self.heads),
(x_source.shape[0], self.out_channels),
device=x_source.device,
) # (n_k_cells, H * C)
if not neighborhood.values().nonzero().size(0) > 0 and not self.concat:
return torch.zeros(
(x_source.shape[0], self.out_channels), device=x_source.device
(x_source.shape[0], self.out_channels // self.heads),
device=x_source.device,
) # (n_k_cells, C)

# Add self-loops to the neighborhood matrix if necessary
Expand All @@ -762,9 +758,7 @@ def forward(self, x_source, neighborhood):

# if concat true, concatenate the messages for each head. Otherwise, average the messages for each head.
if self.concat:
return aggregated_message.view(
-1, self.heads * self.out_channels
) # (n_k_cells, H * C)
return aggregated_message.view(-1, self.out_channels) # (n_k_cells, H * C)

return aggregated_message.mean(dim=1) # (n_k_cells, C)

Expand Down Expand Up @@ -836,6 +830,9 @@ def __init__(
assert in_channels > 0, ValueError("Number of input channels must be > 0")
assert out_channels > 0, ValueError("Number of output channels must be > 0")
assert heads > 0, ValueError("Number of heads must be > 0")
assert out_channels % heads == 0, ValueError(
"Number of output channels must be divisible by the number of heads"
)
assert dropout >= 0.0 and dropout <= 1.0, ValueError("Dropout must be in [0,1]")

# assert that shared weight is True only if version is v2
Expand Down Expand Up @@ -893,7 +890,7 @@ def __init__(

# linear transformation
if skip_connection:
out_channels = out_channels * heads if concat else out_channels
out_channels = out_channels if concat else out_channels // heads
self.lin = Linear(in_channels, out_channels, bias=False)
self.eps = 1 + 1e-6

Expand Down
Loading

0 comments on commit 3652f91

Please sign in to comment.