diff --git a/monai/losses/contrastive.py b/monai/losses/contrastive.py index 6213091bf6..a74f303ec6 100644 --- a/monai/losses/contrastive.py +++ b/monai/losses/contrastive.py @@ -68,13 +68,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: temperature_tensor = torch.as_tensor(self.temperature).to(input.device) batch_size = input.shape[0] - norm_i = F.normalize(input, dim=1) - norm_j = F.normalize(target, dim=1) - negatives_mask = ~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool) negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device) - repr = torch.cat([norm_i, norm_j], dim=0) + repr = torch.cat([input, target], dim=0) sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2) sim_ij = torch.diag(sim_matrix, batch_size) sim_ji = torch.diag(sim_matrix, -batch_size)