[WIP] Support FSDP #358

Draft: wants to merge 37 commits into base: main

Changes from all commits
Commits (37, all by mehdidc):

- 67a0e13  resolve conflicts (Jan 31, 2023)
- 173cba4  show before fsdp memory usage (Jan 6, 2023)
- a45acae  add ddp again (Jan 6, 2023)
- fa80396  resolve conflicts (Jan 31, 2023)
- 9f967b7  resolve conflicts (Jan 31, 2023)
- 1832e13  resolve conflicts (Jan 31, 2023)
- 9d5369e  minor (Jan 7, 2023)
- 08016d0  fix logit scale and eval issues on FSDP (Jan 7, 2023)
- 8820831  support cpu offload (Jan 7, 2023)
- 188bc9c  wrap residual blocks with FSDP (Jan 7, 2023)
- 2782ab1  add forward trick to CustomCLIP (Jan 8, 2023)
- afd8ef3  test_training_clip_with_jit test error (Jan 31, 2023)
- 6627268  select layers to wrap in FSDP and grad checkpointing (Jan 31, 2023)
- fd42631  support unlocking (Feb 4, 2023)
- 4f65c85  fix hang after epoch finish (Feb 18, 2023)
- 3bada34  use `use_orig_params=True` (thanks to @nkflash) to use original param… (Feb 19, 2023)
- f495986  fix distill (Mar 7, 2023)
- 397b8fc  fix FSDP optim state save/load so that we save the full optim state d… (Mar 13, 2023)
- f2c72f8  offload to cpu when saving checkpoint to avoid OOM (Mar 14, 2023)
- a69c0a7  - use the new ModuleWrapPolicy instead of transformer_auto_wrap_polic… (May 17, 2023)
- 62980cb  use ShardedGradScaler for fsdp, thanks to @nkflash (May 17, 2023)
- 9e47140  - FSDP printouts: use logging info. (May 17, 2023)
- a8d644b  parametrize FSDP mixed precision (May 17, 2023)
- 16013c4  use a boolean param args.fsdp to match current args.horovod instead o… (May 17, 2023)
- 7735cac  replace last args.distributed_engine mention in the code (May 17, 2023)
- f4165f7  fsdp log on rank zero only (May 17, 2023)
- 3aa42f4  minor (May 17, 2023)
- 5e167b2  minor (May 17, 2023)
- 5704ada  rank0 only and offload to cpu both true as recommended (May 18, 2023)
- ffcf226  cli parameters description (May 18, 2023)
- d3ab217  support CoCa models (May 22, 2023)
- 86799c2  fix optimizer resuming in FSDP and remove param/buffer precision (May 24, 2023)
- 0859c84  use original_model instead of model (Nov 3, 2023)
- 0a98da2  delete old import (Nov 3, 2023)
- acd5af7  remove old zero shot classifier builder (Nov 3, 2023)
- 67bfcaa  fix again zero-shot eval (Nov 3, 2023)
- 4206d56  support sharded checkpointing for FSDP to handle large models, following (Nov 4, 2023)
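
Note (not part of the PR): the commit messages above name most of the moving parts: per-block auto-wrapping with ModuleWrapPolicy, parametrized mixed precision, optional CPU offload, use_orig_params=True, and ShardedGradScaler. Below is a minimal sketch of how these fit together with the stock PyTorch 2.x FSDP API, assuming open_clip's ResidualAttentionBlock is the wrap unit (per the "wrap residual blocks with FSDP" commit); the model name and dtypes are illustrative assumptions, not the PR's training code.

# Sketch only: stock PyTorch FSDP wiring for an open_clip model.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, CPUOffload
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
import open_clip
from open_clip.transformer import ResidualAttentionBlock

dist.init_process_group(backend="nccl")  # assumes torchrun has set RANK/WORLD_SIZE/MASTER_ADDR
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model, _, _ = open_clip.create_model_and_transforms("ViT-B-32")
model = model.cuda()

model = FSDP(
    model,
    auto_wrap_policy=ModuleWrapPolicy({ResidualAttentionBlock}),  # shard per transformer block
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,    # "parametrize FSDP mixed precision"
                                   reduce_dtype=torch.bfloat16),
    cpu_offload=CPUOffload(offload_params=False),                 # flip to True for "support cpu offload"
    use_orig_params=True,                                         # per the use_orig_params=True commit
    device_id=torch.cuda.current_device(),
)
scaler = ShardedGradScaler()  # replaces torch.cuda.amp.GradScaler once gradients are sharded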
34 changes: 23 additions & 11 deletions src/open_clip/coca_model.py
@@ -154,25 +154,37 @@ def encode_text(self, text, normalize: bool = True):
         text_latent, _ = self._encode_text(text, normalize=normalize)
         return text_latent
 
-    def forward(
-            self,
-            image,
-            text: Optional[torch.Tensor] = None,
-            image_latent: Optional[torch.Tensor] = None,
-            image_embs: Optional[torch.Tensor] = None,
-    ):
-        if image_latent is None or image_embs is None:
-            image_latent, image_embs = self._encode_image(image)
-
-        if text is None:
-            return {"image_features": image_latent, "image_embs": image_embs}
-
-        text_latent, token_embs = self._encode_text(text)
-
-        # TODO: add assertion to avoid bugs?
-        labels = text[:, -token_embs.shape[1]:]
-
-        logits = self.text_decoder(image_embs, token_embs)
+    def forward(self, image=None, text=None, embed_cls=True, image_latent=None, image_embs=None, clamp_logit_scale_to=0):
+        if text is not None:
+            text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls)
+        else:
+            text_latent, token_embs = None, None
+
+        if image is not None:
+            if image_latent is None or image_embs is None:
+                image_latent, image_embs = self._encode_image(image)
+        else:
+            image_latent = None
+            image_embs = None
+
+        if text is not None and token_embs is not None:
+            labels = text[:, -token_embs.shape[1]:]
+            logits = self.text_decoder(image_embs, token_embs)
+        else:
+            labels = None
+            logits = None
+
+        if clamp_logit_scale_to:
+            with torch.no_grad():
+                self.logit_scale.data.clamp_(0, clamp_logit_scale_to)
         out_dict = {
             "image_features": image_latent,
             "text_features": text_latent,
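The reworked CoCa forward above makes image and text keyword arguments, computes the caption logits and labels only when text embeddings are available, and applies the logit-scale clamp inside the module. A hypothetical call with both modalities (variable names are assumptions, not from the PR):

import math

# Hypothetical CoCa training-step call against the reworked forward() above: both
# modalities passed as keywords, and the logit-scale cap applied inside the module
# rather than by reaching through the FSDP wrapper afterwards.
out = model(image=images, text=token_ids, embed_cls=True, clamp_logit_scale_to=math.log(100))
# out should include image_features, text_features, logits, labels, logit_scale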
15 changes: 8 additions & 7 deletions src/open_clip/model.py
@@ -285,14 +285,12 @@ def encode_text(self, text, normalize: bool = False):
 
         return F.normalize(x, dim=-1) if normalize else x
 
-    def forward(
-            self,
-            image: Optional[torch.Tensor] = None,
-            text: Optional[torch.Tensor] = None,
-    ):
+    def forward(self, image=None, text=None, clamp_logit_scale_to:float=0):
         image_features = self.encode_image(image, normalize=True) if image is not None else None
         text_features = self.encode_text(text, normalize=True) if text is not None else None
-
+        if clamp_logit_scale_to:
+            with torch.no_grad():
+                self.logit_scale.data.clamp_(0, clamp_logit_scale_to)
         if self.output_dict:
             out_dict = {
                 "image_features": image_features,
@@ -358,10 +356,13 @@ def forward(
             self,
             image: Optional[torch.Tensor] = None,
             text: Optional[torch.Tensor] = None,
+            clamp_logit_scale_to: float = 0,
     ):
         image_features = self.encode_image(image, normalize=True) if image is not None else None
         text_features = self.encode_text(text, normalize=True) if text is not None else None
-
+        if clamp_logit_scale_to:
+            with torch.no_grad():
+                self.logit_scale.data.clamp_(0, clamp_logit_scale_to)
         if self.output_dict:
             out_dict = {
                 "image_features": image_features,
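Both CLIP forward variants above gain the same clamp_logit_scale_to argument. Presumably this exists because, under FSDP, reaching into the wrapped module after optimizer.step() to clamp logit_scale touches a flattened or sharded parameter, whereas clamping inside forward keeps the operation on the gathered weights. A hypothetical training-step fragment (names assumed, loosely following open_clip's training loop):

import math

# Hypothetical training-step fragment (not from the PR): the cap is passed through
# forward() so the clamp runs inside the wrapped module, instead of the usual
# post-step model.logit_scale.clamp_(0, math.log(100)) applied from outside.
model_out = model(images, texts, clamp_logit_scale_to=math.log(100))
losses = loss(**model_out, output_dict=True)  # as in open_clip's ClipLoss-based loop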
3 changes: 2 additions & 1 deletion src/open_clip/zero_shot_classifier.py
@@ -53,7 +53,8 @@ def _process_batch(batch_classnames):
         num_batch_classes = len(batch_classnames)
         texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates]
         texts = tokenizer(texts).to(device)
-        class_embeddings = model.encode_text(texts, normalize=True)
+        output = model(text=texts)
+        class_embeddings = output['text_features'] if isinstance(output, dict) else output[1]
         class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1)
         class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True)
         class_embeddings = class_embeddings.T
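The zero-shot classifier builder now goes through the model's forward() instead of encode_text(), presumably because FSDP's parameter all-gather hooks forward, so a direct encode_text() call on the wrapped model would not see the full weights. A short sketch of the same calling pattern (assumes model and texts as in the diff):

import torch

# Sketch of the calling pattern above: route text through forward() so an FSDP-wrapped
# model triggers its all-gather, then handle both the dict output (output_dict=True)
# and the legacy (image_features, text_features, logit_scale) tuple.
with torch.no_grad():
    out = model(text=texts)  # instead of model.encode_text(texts, normalize=True)
    text_features = out["text_features"] if isinstance(out, dict) else out[1]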