diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5edecfad4c..a247597da4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed the MNIST download giving HTTP 404 with torchvision>=0.9.1 ([#674](https://github.com/PyTorchLightning/lightning-bolts/pull/674))
 
+- Removed momentum updating from val step and added a separate val queue ([#631](https://github.com/PyTorchLightning/lightning-bolts/pull/631))
+
+
 ## [0.3.4] - 2021-06-17
 
 ### Changed
diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py
index fe3b64d3b2..7a50fad853 100644
--- a/pl_bolts/models/self_supervised/moco/moco2_module.py
+++ b/pl_bolts/models/self_supervised/moco/moco2_module.py
@@ -121,6 +121,12 @@ def __init__(
 
         self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
 
+        # create the validation queue
+        self.register_buffer("val_queue", torch.randn(emb_dim, num_negatives))
+        self.val_queue = nn.functional.normalize(self.val_queue, dim=0)
+
+        self.register_buffer("val_queue_ptr", torch.zeros(1, dtype=torch.long))
+
     def init_encoders(self, base_encoder):
         """
         Override to add your own encoders
@@ -142,21 +148,21 @@ def _momentum_update_key_encoder(self):
             param_k.data = param_k.data * em + param_q.data * (1. - em)
 
     @torch.no_grad()
-    def _dequeue_and_enqueue(self, keys):
+    def _dequeue_and_enqueue(self, keys, queue_ptr, queue):
         # gather keys before updating queue
         if self.trainer.use_ddp or self.trainer.use_ddp2:
             keys = concat_all_gather(keys)
 
         batch_size = keys.shape[0]
 
-        ptr = int(self.queue_ptr)
+        ptr = int(queue_ptr)
         assert self.hparams.num_negatives % batch_size == 0  # for simplicity
 
         # replace the keys at ptr (dequeue and enqueue)
-        self.queue[:, ptr:ptr + batch_size] = keys.T
+        queue[:, ptr:ptr + batch_size] = keys.T
         ptr = (ptr + batch_size) % self.hparams.num_negatives  # move pointer
 
-        self.queue_ptr[0] = ptr
+        queue_ptr[0] = ptr
 
     @torch.no_grad()
     def _batch_shuffle_ddp(self, x):  # pragma: no cover
@@ -205,11 +211,12 @@ def _batch_unshuffle_ddp(self, x, idx_unshuffle):  # pragma: no cover
 
         return x_gather[idx_this]
 
-    def forward(self, img_q, img_k):
+    def forward(self, img_q, img_k, queue):
         """
         Input:
             im_q: a batch of query images
             im_k: a batch of key images
+            queue: a queue from which to pick negative samples
         Output:
             logits, targets
         """
@@ -220,7 +227,6 @@ def forward(self, img_q, img_k):
 
         # compute key features
         with torch.no_grad():  # no gradient to keys
-            self._momentum_update_key_encoder()  # update the key encoder
 
             # shuffle for making use of BN
             if self.trainer.use_ddp or self.trainer.use_ddp2:
@@ -238,7 +244,7 @@ def forward(self, img_q, img_k):
         # positive logits: Nx1
         l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
         # negative logits: NxK
-        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
+        l_neg = torch.einsum('nc,ck->nk', [q, queue.clone().detach()])
 
         # logits: Nx(1+K)
         logits = torch.cat([l_pos, l_neg], dim=1)
@@ -250,10 +256,7 @@ def forward(self, img_q, img_k):
         labels = torch.zeros(logits.shape[0], dtype=torch.long)
         labels = labels.type_as(logits)
 
-        # dequeue and enqueue
-        self._dequeue_and_enqueue(k)
-
-        return logits, labels
+        return logits, labels, k
 
     def training_step(self, batch, batch_idx):
         # in STL10 we pass in both lab+unl for online ft
@@ -264,7 +267,10 @@ def training_step(self, batch, batch_idx):
 
         (img_1, img_2), _ = batch
 
-        output, target = self(img_q=img_1, img_k=img_2)
+        self._momentum_update_key_encoder()  # update the key encoder
+        output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.queue)
+        self._dequeue_and_enqueue(keys, queue=self.queue, queue_ptr=self.queue_ptr)  # dequeue and enqueue
+
         loss = F.cross_entropy(output.float(), target.long())
 
         acc1, acc5 = precision_at_k(output, target, top_k=(1, 5))
@@ -282,7 +288,9 @@ def validation_step(self, batch, batch_idx):
 
         (img_1, img_2), labels = batch
 
-        output, target = self(img_q=img_1, img_k=img_2)
+        output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.val_queue)
+        self._dequeue_and_enqueue(keys, queue=self.val_queue, queue_ptr=self.val_queue_ptr)  # dequeue and enqueue
+
         loss = F.cross_entropy(output, target.long())
 
         acc1, acc5 = precision_at_k(output, target, top_k=(1, 5))
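For reviewers, a minimal self-contained sketch of the pattern this diff introduces: `_dequeue_and_enqueue` becomes a pure ring-buffer update parameterized by the queue and its pointer, and training and validation each maintain their own queue, so the val step no longer mutates the negatives used for training. The names and sizes below (`emb_dim`, `num_negatives`, the standalone `dequeue_and_enqueue` helper) are illustrative stand-ins, not the module's actual hparams or API.

```python
import torch
import torch.nn.functional as F

emb_dim, num_negatives, batch_size = 128, 1024, 32

# one L2-normalized queue of negatives per phase, as in the PR
queues = {
    "train": F.normalize(torch.randn(emb_dim, num_negatives), dim=0),
    "val": F.normalize(torch.randn(emb_dim, num_negatives), dim=0),
}
ptrs = {
    "train": torch.zeros(1, dtype=torch.long),
    "val": torch.zeros(1, dtype=torch.long),
}


@torch.no_grad()
def dequeue_and_enqueue(keys: torch.Tensor, queue: torch.Tensor, queue_ptr: torch.Tensor) -> None:
    """Overwrite the oldest batch of columns in `queue` with `keys` (N x C)."""
    n = keys.shape[0]
    ptr = int(queue_ptr)
    assert num_negatives % n == 0  # for simplicity, mirroring the module's assert
    queue[:, ptr:ptr + n] = keys.T  # enqueue newest keys, dropping the oldest
    queue_ptr[0] = (ptr + n) % num_negatives  # advance the ring-buffer pointer


# training updates the train queue; validation updates its own queue
train_keys = F.normalize(torch.randn(batch_size, emb_dim), dim=1)
dequeue_and_enqueue(train_keys, queues["train"], ptrs["train"])

val_keys = F.normalize(torch.randn(batch_size, emb_dim), dim=1)
dequeue_and_enqueue(val_keys, queues["val"], ptrs["val"])
```

The same separation is why the momentum update of the key encoder moves from `forward` into `training_step`: with `forward` shared by both steps, keeping the EMA update there would keep mutating the key encoder during validation.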