Skip to content

Commit

Permalink
Removing L2-norm in contrastive loss (L2-norm already present in cosi…
Browse files Browse the repository at this point in the history
…ne-similarity computation)

Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr>
  • Loading branch information
Lucas Robinet authored and Lucas-rbnt committed May 24, 2023
1 parent 960249f commit 34e84a3
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions monai/losses/contrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
temperature_tensor = torch.as_tensor(self.temperature).to(input.device)
batch_size = input.shape[0]

norm_i = F.normalize(input, dim=1)
norm_j = F.normalize(target, dim=1)

negatives_mask = ~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool)
negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device)

repr = torch.cat([norm_i, norm_j], dim=0)
repr = torch.cat([input, target], dim=0)
sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2)
sim_ij = torch.diag(sim_matrix, batch_size)
sim_ji = torch.diag(sim_matrix, -batch_size)
Expand Down

0 comments on commit 34e84a3

Please sign in to comment.