Skip to content

Commit

Permalink
Revert "Use intermediate embedding features"
Browse files Browse the repository at this point in the history
This reverts commit b3aff76.

Doesn't seem to help (quite the contrary), so removing it.
  • Loading branch information
klieret committed Jul 31, 2023
1 parent b3aff76 commit 199b9e8
Showing 1 changed file with 4 additions and 12 deletions.
16 changes: 4 additions & 12 deletions src/gnn_tracking/models/graph_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,10 @@ def forward(self, data: Data) -> dict[str, T]:
np.sqrt(self.hparams.beta) * layer(relu(x))
+ np.sqrt(1 - self.hparams.beta) * x
)
x1 = x
x = self._decoder(relu(x))
x *= self._latent_normalization
assert x.shape[1] == self.hparams.out_dim
return {"H": x, "H1": x1}
return {"H": x}


class GraphConstructionResIN(nn.Module, HyperparametersMixin):
Expand Down Expand Up @@ -208,7 +207,6 @@ def __init__(
ml_freeze: bool = True,
ec_freeze: bool = True,
embedding_slice: tuple[int | None, int | None] = (None, None),
use_intermediate_embedding_features: bool = False,
):
"""Builds graph from embedding space. If you want to start from a checkpoint,
use `MLGraphConstruction.from_chkpt`.
Expand Down Expand Up @@ -302,16 +300,10 @@ def forward(self, data: Data) -> Data:
y: T = ( # type: ignore
data.particle_id[edge_index[0]] == data.particle_id[edge_index[1]]
)
node_features = []
if self._ml and self.hparams.use_embedding_features:
node_features.append(mo["H"])
if self._ml and self.hparams.use_intermediate_embedding_features:
node_features.append(mo["H1"])
node_features.append(data.x)
if len(node_features) > 1:
x = torch.cat(node_features, dim=1)
if not self._ml or not self.hparams.use_embedding_features:
x = data.x
else:
x = node_features[0]
x = torch.cat((mo["H"], data.x), dim=1)
# print(edge_index.shape, )
if self.hparams.ratio_of_false and self.training:
num_true = y.sum()
Expand Down

0 comments on commit 199b9e8

Please sign in to comment.