
Remove momentum updating from val step and add separate val queue (#631)
* Remove momentum updating from val step and add separate val queue

* Fix val queue init

* Update changelog

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: jirka <jirka.borovec@seznam.cz>
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
4 people authored Jul 4, 2021
1 parent 0045e64 commit 270867c
Showing 2 changed files with 24 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- Fixed the MNIST download giving HTTP 404 with torchvision>=0.9.1 ([#674](https://github.com/PyTorchLightning/lightning-bolts/pull/674))


+ - Removed momentum updating from val step and added a separate val queue ([#631](https://github.com/PyTorchLightning/lightning-bolts/pull/631))


## [0.3.4] - 2021-06-17

### Changed
34 changes: 21 additions & 13 deletions pl_bolts/models/self_supervised/moco/moco2_module.py
@@ -121,6 +121,12 @@ def __init__(

self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

+ # create the validation queue
+ self.register_buffer("val_queue", torch.randn(emb_dim, num_negatives))
+ self.val_queue = nn.functional.normalize(self.val_queue, dim=0)
+
+ self.register_buffer("val_queue_ptr", torch.zeros(1, dtype=torch.long))

def init_encoders(self, base_encoder):
"""
Override to add your own encoders
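For orientation, here is a minimal sketch of the buffer layout this hunk creates: a toy module assuming the same `emb_dim`/`num_negatives` hyperparameters as the code above, not the Bolts implementation itself.

```python
import torch
from torch import nn


class QueuePair(nn.Module):
    """Toy illustration of the paired train/val queues (not the real module)."""

    def __init__(self, emb_dim: int = 128, num_negatives: int = 65536):
        super().__init__()
        # register_buffer puts the tensor in state_dict and moves it with
        # .to(device), but keeps it out of the optimizer's parameter list
        self.register_buffer("queue", nn.functional.normalize(torch.randn(emb_dim, num_negatives), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        # an independent queue, so validation keys never pollute the training negatives
        self.register_buffer("val_queue", nn.functional.normalize(torch.randn(emb_dim, num_negatives), dim=0))
        self.register_buffer("val_queue_ptr", torch.zeros(1, dtype=torch.long))
```

Each column of a queue is one L2-normalized key embedding; the pointer marks where the next batch of keys will be written.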
@@ -142,21 +148,21 @@ def _momentum_update_key_encoder(self):
param_k.data = param_k.data * em + param_q.data * (1. - em)

@torch.no_grad()
- def _dequeue_and_enqueue(self, keys):
+ def _dequeue_and_enqueue(self, keys, queue_ptr, queue):
# gather keys before updating queue
if self.trainer.use_ddp or self.trainer.use_ddp2:
keys = concat_all_gather(keys)

batch_size = keys.shape[0]

- ptr = int(self.queue_ptr)
+ ptr = int(queue_ptr)
assert self.hparams.num_negatives % batch_size == 0 # for simplicity

# replace the keys at ptr (dequeue and enqueue)
- self.queue[:, ptr:ptr + batch_size] = keys.T
+ queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % self.hparams.num_negatives # move pointer

- self.queue_ptr[0] = ptr
+ queue_ptr[0] = ptr

@torch.no_grad()
def _batch_shuffle_ddp(self, x): # pragma: no cover
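Read as a plain function, the parameterized helper is a ring-buffer update over whichever storage the caller passes in. A standalone sketch under that reading (shapes illustrative; the DDP gather is omitted):

```python
import torch


@torch.no_grad()
def dequeue_and_enqueue(keys: torch.Tensor, queue_ptr: torch.Tensor, queue: torch.Tensor) -> None:
    """Overwrite the oldest keys in `queue` and advance `queue_ptr` in place."""
    num_negatives = queue.shape[1]
    batch_size = keys.shape[0]
    assert num_negatives % batch_size == 0  # batches must tile the queue exactly

    ptr = int(queue_ptr)
    queue[:, ptr:ptr + batch_size] = keys.T            # queue is (emb_dim, K), keys are (N, emb_dim)
    queue_ptr[0] = (ptr + batch_size) % num_negatives  # move pointer, wrapping around


# usage: one helper, two storages; only the tensors passed in differ
queue, ptr = torch.randn(16, 64), torch.zeros(1, dtype=torch.long)
dequeue_and_enqueue(torch.randn(8, 16), ptr, queue)
assert int(ptr) == 8
```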
@@ -205,11 +211,12 @@ def _batch_unshuffle_ddp(self, x, idx_unshuffle):  # pragma: no cover

return x_gather[idx_this]

- def forward(self, img_q, img_k):
+ def forward(self, img_q, img_k, queue):
"""
Input:
im_q: a batch of query images
im_k: a batch of key images
+ queue: a queue from which to pick negative samples
Output:
logits, targets
"""
@@ -220,7 +227,6 @@ def forward(self, img_q, img_k):

# compute key features
with torch.no_grad(): # no gradient to keys
- self._momentum_update_key_encoder()  # update the key encoder

# shuffle for making use of BN
if self.trainer.use_ddp or self.trainer.use_ddp2:
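The deleted line is the EMA step shown in the earlier hunk (`param_k = em * param_k + (1 - em) * param_q`). Hoisting it out of `forward` and into `training_step` guarantees the key encoder can never drift during validation. A self-contained sketch of that update (0.999 is MoCo's usual momentum default, assumed here):

```python
import torch
from torch import nn


@torch.no_grad()
def momentum_update_key_encoder(encoder_q: nn.Module, encoder_k: nn.Module, em: float = 0.999) -> None:
    # the key encoder trails the query encoder as an exponential moving average
    for param_q, param_k in zip(encoder_q.parameters(), encoder_k.parameters()):
        param_k.data = param_k.data * em + param_q.data * (1.0 - em)


enc_q, enc_k = nn.Linear(4, 4), nn.Linear(4, 4)
momentum_update_key_encoder(enc_q, enc_k)  # call once per *training* step only
```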
@@ -238,7 +244,7 @@ def forward(self, img_q, img_k):
# positive logits: Nx1
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# negative logits: NxK
- l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
+ l_neg = torch.einsum('nc,ck->nk', [q, queue.clone().detach()])

# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)
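A quick shape check on this logits construction (illustrative sizes); the einsums are nothing more than row-wise dot products and a matrix multiply:

```python
import torch

N, C, K = 8, 16, 32  # batch, embedding dim, queue length
q = torch.nn.functional.normalize(torch.randn(N, C), dim=1)
k = torch.nn.functional.normalize(torch.randn(N, C), dim=1)
queue = torch.randn(C, K)

l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)  # (N, 1): each query vs. its own key
l_neg = torch.einsum('nc,ck->nk', [q, queue])           # (N, K): each query vs. every negative
logits = torch.cat([l_pos, l_neg], dim=1)               # (N, 1 + K), positive in column 0

assert torch.allclose(l_pos, (q * k).sum(dim=1, keepdim=True), atol=1e-6)
assert torch.allclose(l_neg, q @ queue, atol=1e-6)
```

Because the positive logit always sits in column 0, the cross-entropy targets built below are simply all zeros.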
@@ -250,10 +256,7 @@ def forward(self, img_q, img_k):
labels = torch.zeros(logits.shape[0], dtype=torch.long)
labels = labels.type_as(logits)

- # dequeue and enqueue
- self._dequeue_and_enqueue(k)
-
- return logits, labels
+ return logits, labels, k

def training_step(self, batch, batch_idx):
# in STL10 we pass in both lab+unl for online ft
@@ -264,7 +267,10 @@ def training_step(self, batch, batch_idx):

(img_1, img_2), _ = batch

- output, target = self(img_q=img_1, img_k=img_2)
+ self._momentum_update_key_encoder()  # update the key encoder
+ output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.queue)
+ self._dequeue_and_enqueue(keys, queue=self.queue, queue_ptr=self.queue_ptr)  # dequeue and enqueue

loss = F.cross_entropy(output.float(), target.long())

acc1, acc5 = precision_at_k(output, target, top_k=(1, 5))
@@ -282,7 +288,9 @@ def validation_step(self, batch, batch_idx):

(img_1, img_2), labels = batch

- output, target = self(img_q=img_1, img_k=img_2)
+ output, target, keys = self(img_q=img_1, img_k=img_2, queue=self.val_queue)
+ self._dequeue_and_enqueue(keys, queue=self.val_queue, queue_ptr=self.val_queue_ptr)  # dequeue and enqueue

loss = F.cross_entropy(output, target.long())

acc1, acc5 = precision_at_k(output, target, top_k=(1, 5))
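The net effect of the two steps above: training runs EMA update, then forward(queue=self.queue), then enqueue, while validation only reads and refreshes its own queue. A small check that the two queues really evolve independently, reusing the hypothetical `dequeue_and_enqueue` sketch from earlier (shapes illustrative):

```python
import torch

train_queue, train_ptr = torch.randn(16, 64), torch.zeros(1, dtype=torch.long)
val_queue, val_ptr = torch.randn(16, 64), torch.zeros(1, dtype=torch.long)

snapshot = train_queue.clone()
dequeue_and_enqueue(torch.randn(8, 16), val_ptr, val_queue)  # a validation step's enqueue

assert torch.equal(train_queue, snapshot)          # training negatives untouched
assert int(val_ptr) == 8 and int(train_ptr) == 0   # only the val pointer moved
```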
