From 67a0e1342c18fd716214afbde00a4f79c92f814d Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Tue, 31 Jan 2023 11:46:06 +0100 Subject: [PATCH 01/37] resolve conflicts --- src/training/main.py | 41 +++++++++++++++++++++++++++++++++-------- src/training/params.py | 6 ++++++ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 94496999f..091c99e07 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -12,6 +12,8 @@ import torch from torch import optim from torch.cuda.amp import GradScaler +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, CPUOffload, MixedPrecision +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy try: import wandb @@ -29,6 +31,8 @@ hvd = None from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss +from open_clip.transformer import VisionTransformer, TextTransformer +from open_clip.model import CLIP from training.data import get_data from training.distributed import is_master, init_distributed_device, broadcast_object from training.logger import setup_logging @@ -292,14 +296,35 @@ def main(args): if args.distributed and not args.horovod: if args.use_bn_sync: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - ddp_args = {} - if args.ddp_static_graph: - # this doesn't exist in older PyTorch, arg only added if enabled - ddp_args['static_graph'] = True - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) - - if args.distill: - dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args) + if args.distributed_engine == 'ddp': + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args['static_graph'] = True + elif args.distributed_engine == 'fsdp': + mp = MixedPrecision( + # param_dtype=torch.bfloat16, + reduce_dtype=torch.bfloat16, + # buffer_dtype=torch.bfloat16, + ) + wrapper_kwargs = dict( + mixed_precision=mp, + limit_all_gathers=True, + auto_wrap_policy=partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + VisionTransformer, + TextTransformer, + CLIP, + }, + ), + ) + model = FSDP(model, device_id=device, **wrapper_kwargs) + print(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}") + print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") + else: + print("--distrubted_engine should be either 'ddp or 'fsdp'") + sys.exit(1) # create optimizer and scaler optimizer = None diff --git a/src/training/params.py b/src/training/params.py index 3ea5a8f3b..29083481e 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -358,6 +358,12 @@ def parse_args(args): action='store_true', help="Enable static graph optimization for DDP in PyTorch >= 1.11.", ) + parser.add_argument( + "--distributed_engine", + type=str, + default="ddp", + choices=["ddp", "fsdp"], + ) parser.add_argument( "--no-set-device-rank", default=False, From 173cba4163daebcafa3fcb32db6d0d5311e8db50 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Fri, 6 Jan 2023 17:31:33 +0100 Subject: [PATCH 02/37] show before fsdp memory usage --- src/training/main.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 091c99e07..8fcda9535 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -302,6 +302,8 @@ def main(args): # this doesn't exist in older PyTorch, arg only added if 
enabled ddp_args['static_graph'] = True elif args.distributed_engine == 'fsdp': + print(f"Before FSTP parameter num: {sum(p.numel() for p in model.parameters())}") + print(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") mp = MixedPrecision( # param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, @@ -310,13 +312,13 @@ def main(args): wrapper_kwargs = dict( mixed_precision=mp, limit_all_gathers=True, - auto_wrap_policy=partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - VisionTransformer, - TextTransformer, - CLIP, - }, + auto_wrap_policy=partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + VisionTransformer, + TextTransformer, + CLIP, + }, ), ) model = FSDP(model, device_id=device, **wrapper_kwargs) From a45acae631a1c8f3013e2a0da778a82144cb5764 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Fri, 6 Jan 2023 17:40:32 +0100 Subject: [PATCH 03/37] add ddp again --- src/training/main.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 8fcda9535..0ac46873d 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -301,6 +301,7 @@ def main(args): if args.ddp_static_graph: # this doesn't exist in older PyTorch, arg only added if enabled ddp_args['static_graph'] = True + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) elif args.distributed_engine == 'fsdp': print(f"Before FSTP parameter num: {sum(p.numel() for p in model.parameters())}") print(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") @@ -310,16 +311,17 @@ def main(args): # buffer_dtype=torch.bfloat16, ) wrapper_kwargs = dict( - mixed_precision=mp, - limit_all_gathers=True, - auto_wrap_policy=partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - VisionTransformer, - TextTransformer, - CLIP, - }, - ), + #mixed_precision=mp, + #limit_all_gathers=True, + + #auto_wrap_policy=partial( + # transformer_auto_wrap_policy, + # transformer_layer_cls={ + # VisionTransformer, + # TextTransformer, + # CLIP, + # }, + #), ) model = FSDP(model, device_id=device, **wrapper_kwargs) print(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}") From fa80396d629b914a0dc6c0b043f44d84da6fef5c Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Tue, 31 Jan 2023 11:46:38 +0100 Subject: [PATCH 04/37] resolve conflicts --- src/training/main.py | 47 ++++++++++++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 0ac46873d..4d3009b60 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -31,7 +31,7 @@ hvd = None from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss -from open_clip.transformer import VisionTransformer, TextTransformer +from open_clip.transformer import VisionTransformer, TextTransformer, ResidualAttentionBlock from open_clip.model import CLIP from training.data import get_data from training.distributed import is_master, init_distributed_device, broadcast_object @@ -85,7 +85,6 @@ def main(args): # fully initialize distributed device environment device = init_distributed_device(args) - # get the name of the experiments if args.name is None: # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? 
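(For reference, the FSDP wrapping pattern these first patches introduce reduces to the minimal standalone sketch below. It is illustrative only and not part of the diffs: the toy Block module, its dimensions, and the assumption of an already-initialized process group (e.g. launched via torchrun with init_process_group already called) are mine, not the patch author's.)

import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class Block(nn.Module):
    # Stand-in for a transformer layer such as open_clip's ResidualAttentionBlock;
    # it only exists to give the auto-wrap policy a layer class to match on.
    def __init__(self, d: int = 512):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x):
        return x + self.mlp(x)


def fsdp_wrap(model: nn.Module, device: int) -> FSDP:
    # Shard every Block instance as its own FSDP unit and reduce gradients in
    # bfloat16, mirroring the MixedPrecision(reduce_dtype=...) setting above.
    policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={Block},
    )
    return FSDP(
        model,
        device_id=device,
        auto_wrap_policy=policy,
        mixed_precision=MixedPrecision(reduce_dtype=torch.bfloat16),
        limit_all_gathers=True,
    )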
@@ -280,7 +279,8 @@ def main(args): freeze_layer_norm=args.lock_text_freeze_layer_norm) if args.grad_checkpointing: - model.set_grad_checkpointing() + if args.distributed_engine != 'fsdp': + model.set_grad_checkpointing() if is_master(args): logging.info("Model:") @@ -303,29 +303,46 @@ def main(args): ddp_args['static_graph'] = True model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) elif args.distributed_engine == 'fsdp': + print(f"Before FSTP parameter num: {sum(p.numel() for p in model.parameters())}") print(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") mp = MixedPrecision( - # param_dtype=torch.bfloat16, + #param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, - # buffer_dtype=torch.bfloat16, + #buffer_dtype=torch.bfloat16, ) wrapper_kwargs = dict( - #mixed_precision=mp, - #limit_all_gathers=True, + mixed_precision=mp, + limit_all_gathers=True, - #auto_wrap_policy=partial( - # transformer_auto_wrap_policy, - # transformer_layer_cls={ - # VisionTransformer, - # TextTransformer, - # CLIP, - # }, - #), + auto_wrap_policy=partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + VisionTransformer, + TextTransformer, + CLIP, + }, + ), ) model = FSDP(model, device_id=device, **wrapper_kwargs) print(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}") print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") + if args.grad_checkpointing: + #https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/ + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + CheckpointImpl, + apply_activation_checkpointing, + ) + non_reentrant_wrapper = partial( + checkpoint_wrapper, + offload_to_cpu=False, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + check_fn = lambda submodule: isinstance(submodule, ResidualAttentionBlock) + apply_activation_checkpointing( + model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn + ) else: print("--distrubted_engine should be either 'ddp or 'fsdp'") sys.exit(1) From 9f967b7dc4aac2452cc7d2ee40d460b88ba6ddde Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Tue, 31 Jan 2023 11:47:34 +0100 Subject: [PATCH 05/37] resolve conflicts --- src/training/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/training/train.py b/src/training/train.py index a48a34593..69eca63a4 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -185,8 +185,8 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist accum_images, accum_texts, accum_features = [], [], {} # Note: we clamp to 4.6052 = ln(100), as in the original paper. 
- with torch.no_grad(): - unwrap_model(model).logit_scale.clamp_(0, math.log(100)) + #with torch.no_grad(): + # unwrap_model(model).logit_scale.clamp_(0, math.log(100)) batch_time_m.update(time.time() - end) end = time.time() From 1832e13131674fd7961c1a225d356fa1ee4ebc66 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Tue, 31 Jan 2023 11:48:33 +0100 Subject: [PATCH 06/37] resolve conflicts --- src/open_clip/model.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/open_clip/model.py b/src/open_clip/model.py index fe3aa31c9..5e3a49644 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -285,14 +285,12 @@ def encode_text(self, text, normalize: bool = False): return F.normalize(x, dim=-1) if normalize else x - def forward( - self, - image: Optional[torch.Tensor] = None, - text: Optional[torch.Tensor] = None, - ): + def forward(self, image, text, clamp_logit_scale_to=None): image_features = self.encode_image(image, normalize=True) if image is not None else None text_features = self.encode_text(text, normalize=True) if text is not None else None - + if clamp_logit_scale_to is not None: + with torch.no_grad(): + self.logit_scale.data.clamp_(0, clamp_logit_scale_to) if self.output_dict: out_dict = { "image_features": image_features, From 9d5369e1a7ced51d412a2c87a277b275a8b85546 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 7 Jan 2023 04:01:59 +0100 Subject: [PATCH 07/37] minor --- src/training/main.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 4d3009b60..506aff201 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -294,6 +294,7 @@ def main(args): f.write(f"{name}: {val}\n") if args.distributed and not args.horovod: + if args.use_bn_sync: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.distributed_engine == 'ddp': @@ -303,7 +304,10 @@ def main(args): ddp_args['static_graph'] = True model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) elif args.distributed_engine == 'fsdp': - + from torch.distributed.fsdp.wrap import ( + enable_wrap, + wrap, + ) print(f"Before FSTP parameter num: {sum(p.numel() for p in model.parameters())}") print(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") mp = MixedPrecision( @@ -315,15 +319,23 @@ def main(args): mixed_precision=mp, limit_all_gathers=True, - auto_wrap_policy=partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - VisionTransformer, - TextTransformer, - CLIP, - }, - ), + #auto_wrap_policy=partial( + # transformer_auto_wrap_policy, + # transformer_layer_cls={ + # VisionTransformer, + # TextTransformer, + # CLIP, + # }, + #), ) + + # avoid "RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory." + #model.transformer = FSDP(model.transformer, device_id=device) + #model.token_embedding = FSDP(model.token_embedding, device_id=device) + #model.tp = FSDP(model.tp, device_id=device) + #model.visual = FSDP(model.visual, device_id=device) + #model.text_projection = FSDP(model.text_projection) ??? 
+ #model.ln_final = FSDP(model.ln_final, device_id=device) model = FSDP(model, device_id=device, **wrapper_kwargs) print(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}") print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") From 08016d0d0807375c408b7ee76dcb90f1ab5bf7bc Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 7 Jan 2023 04:02:29 +0100 Subject: [PATCH 08/37] fix logit scale and eval issues on FSDP --- src/training/train.py | 9 ++++++--- src/training/zero_shot.py | 32 ++++++++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/training/train.py b/src/training/train.py index 69eca63a4..5552b31c3 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -185,8 +185,11 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist accum_images, accum_texts, accum_features = [], [], {} # Note: we clamp to 4.6052 = ln(100), as in the original paper. - #with torch.no_grad(): - # unwrap_model(model).logit_scale.clamp_(0, math.log(100)) + if args.distributed_engine == 'fsdp': + model(image=None, text=None, clamp_logit_scale_to=math.log(100)) + else: + with torch.no_grad(): + unwrap_model(model).logit_scale.clamp_(0, math.log(100)) batch_time_m.update(time.time() - end) end = time.time() @@ -250,7 +253,7 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None): metrics = {} - if not is_master(args): + if not is_master(args) and args.distributed_engine != 'fsdp': return metrics device = torch.device(args.device) model.eval() diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py index 06ce7ac09..8d957febe 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -6,6 +6,28 @@ from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES from .precision import get_autocast +from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template + + +def zero_shot_classifier(model, classnames, templates, args): + tokenizer = get_tokenizer(args.model) + with torch.no_grad(): + zeroshot_weights = [] + for classname in tqdm(classnames): + texts = [template(classname) for template in templates] # format with class + texts = tokenizer(texts).to(args.device) # tokenize + if args.distributed and not args.horovod: + if args.distributed_engine == 'fsdp': + _, class_embeddings, _ = model(image=None, text=texts) + else: + class_embeddings = model.module.encode_text(texts) + else: + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) + return zeroshot_weights def accuracy(output, target, topk=(1,)): @@ -26,8 +48,14 @@ def run(model, classifier, dataloader, args): with autocast(): # predict - output = model(image=images) - image_features = output['image_features'] if isinstance(output, dict) else output[0] + if args.distributed and not args.horovod: + if args.distributed_engine == 'fsdp': + image_features, _, _ = model(image=images, text=None) + else: + image_features = model.module.encode_image(images) + else: + image_features = model.encode_image(images) + image_features = F.normalize(image_features, dim=-1) logits = 100. 
* image_features @ classifier # measure accuracy From 8820831c1c7409e54be1b40b5b26357b71f74210 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 7 Jan 2023 13:50:43 +0100 Subject: [PATCH 09/37] support cpu offload --- src/training/main.py | 27 ++++++++++++++------------- src/training/params.py | 12 +++++++++++- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 506aff201..663296fe9 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -14,7 +14,7 @@ from torch.cuda.amp import GradScaler from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, CPUOffload, MixedPrecision from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy - +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload try: import wandb except ImportError: @@ -227,7 +227,7 @@ def main(args): args.model, args.pretrained, precision=args.precision, - device=device, + device='cpu' if args.fsdp_init_on_cpu else device, jit=args.torchscript, force_quick_gelu=args.force_quick_gelu, force_custom_text=args.force_custom_text, @@ -318,15 +318,16 @@ def main(args): wrapper_kwargs = dict( mixed_precision=mp, limit_all_gathers=True, - - #auto_wrap_policy=partial( - # transformer_auto_wrap_policy, - # transformer_layer_cls={ - # VisionTransformer, - # TextTransformer, - # CLIP, - # }, - #), + cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), + auto_wrap_policy=partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + VisionTransformer, + TextTransformer, + CLIP, + }, + ), + device_id=None if args.fsdp_init_on_cpu else device, ) # avoid "RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory." @@ -336,7 +337,7 @@ def main(args): #model.visual = FSDP(model.visual, device_id=device) #model.text_projection = FSDP(model.text_projection) ??? 
#model.ln_final = FSDP(model.ln_final, device_id=device) - model = FSDP(model, device_id=device, **wrapper_kwargs) + model = FSDP(model, **wrapper_kwargs) print(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}") print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") if args.grad_checkpointing: @@ -348,7 +349,7 @@ def main(args): ) non_reentrant_wrapper = partial( checkpoint_wrapper, - offload_to_cpu=False, + offload_to_cpu=args.fsdp_cpu_offload, checkpoint_impl=CheckpointImpl.NO_REENTRANT, ) check_fn = lambda submodule: isinstance(submodule, ResidualAttentionBlock) diff --git a/src/training/params.py b/src/training/params.py index 29083481e..126367251 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -359,11 +359,21 @@ def parse_args(args): help="Enable static graph optimization for DDP in PyTorch >= 1.11.", ) parser.add_argument( - "--distributed_engine", + "--distributed-engine", type=str, default="ddp", choices=["ddp", "fsdp"], ) + parser.add_argument( + "--fsdp-init-on-cpu", + default=False, + action="store_true", + ) + parser.add_argument( + "--fsdp-cpu-offload", + default=False, + action="store_true", + ) parser.add_argument( "--no-set-device-rank", default=False, From 188bc9c6e44cc664c337e435d2c737a31e0616ff Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 7 Jan 2023 16:43:34 +0100 Subject: [PATCH 10/37] wrap residual blocks with FSDP --- src/training/main.py | 4 ++-- src/training/params.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 663296fe9..fcc0e23e2 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -317,14 +317,14 @@ def main(args): ) wrapper_kwargs = dict( mixed_precision=mp, - limit_all_gathers=True, + limit_all_gathers=args.fsdp_limit_allgathers, cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), auto_wrap_policy=partial( transformer_auto_wrap_policy, transformer_layer_cls={ VisionTransformer, TextTransformer, - CLIP, + ResidualAttentionBlock, }, ), device_id=None if args.fsdp_init_on_cpu else device, diff --git a/src/training/params.py b/src/training/params.py index 126367251..6a709677e 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -374,6 +374,11 @@ def parse_args(args): default=False, action="store_true", ) + parser.add_argument( + "--fsdp-limit-allgathers", + default=False, + action="store_true", + ) parser.add_argument( "--no-set-device-rank", default=False, From 2782ab141c1ff82113287e1fbf0f93caa315a2f2 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sun, 8 Jan 2023 15:40:09 +0100 Subject: [PATCH 11/37] add forward trick to CustomCLIP --- src/training/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/training/main.py b/src/training/main.py index fcc0e23e2..c7e349cf2 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -309,6 +309,9 @@ def main(args): wrap, ) print(f"Before FSTP parameter num: {sum(p.numel() for p in model.parameters())}") + print(f"Before FSTP VISUAL parameter num: {sum(p.numel() for p in model.visual.parameters())}") + #print(f"Before FSTP TEXT parameter num: {sum(p.numel() for p in model.transformer.parameters())}") + print(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") mp = MixedPrecision( #param_dtype=torch.bfloat16, @@ -327,7 +330,7 @@ def main(args): ResidualAttentionBlock, }, ), - device_id=None if args.fsdp_init_on_cpu else device, + device_id=device, ) # avoid "RuntimeError: The tensor has a non-zero 
number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory." From afd8ef35126e448a351e456bc92a3660d748bb21 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Tue, 31 Jan 2023 12:20:07 +0100 Subject: [PATCH 12/37] test_training_clip_with_jit test error --- src/open_clip/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 5e3a49644..42e4aa934 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -285,10 +285,10 @@ def encode_text(self, text, normalize: bool = False): return F.normalize(x, dim=-1) if normalize else x - def forward(self, image, text, clamp_logit_scale_to=None): + def forward(self, image, text, clamp_logit_scale_to:float=0): image_features = self.encode_image(image, normalize=True) if image is not None else None text_features = self.encode_text(text, normalize=True) if text is not None else None - if clamp_logit_scale_to is not None: + if clamp_logit_scale_to: with torch.no_grad(): self.logit_scale.data.clamp_(0, clamp_logit_scale_to) if self.output_dict: From 6627268d60d6b7967edea109488cedb79bd4004d Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Tue, 31 Jan 2023 17:30:52 +0100 Subject: [PATCH 13/37] select layers to wrap in FSDP and grad checkpointing --- src/training/main.py | 24 +++++++++++++++--------- src/training/params.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index c7e349cf2..aaf713220 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -31,7 +31,6 @@ hvd = None from open_clip import create_model_and_transforms, trace_model, get_tokenizer, create_loss -from open_clip.transformer import VisionTransformer, TextTransformer, ResidualAttentionBlock from open_clip.model import CLIP from training.data import get_data from training.distributed import is_master, init_distributed_device, broadcast_object @@ -294,7 +293,6 @@ def main(args): f.write(f"{name}: {val}\n") if args.distributed and not args.horovod: - if args.use_bn_sync: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.distributed_engine == 'ddp': @@ -318,21 +316,23 @@ def main(args): reduce_dtype=torch.bfloat16, #buffer_dtype=torch.bfloat16, ) + layers = set() + for module in model.modules(): + name = module.__class__.__name__ + for layer in args.fsdp_layers_to_wrap: + if re.match(layer, name): + layers.add(module.__class__) + print("Wrapped layers", layers) wrapper_kwargs = dict( mixed_precision=mp, limit_all_gathers=args.fsdp_limit_allgathers, cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), auto_wrap_policy=partial( transformer_auto_wrap_policy, - transformer_layer_cls={ - VisionTransformer, - TextTransformer, - ResidualAttentionBlock, - }, + transformer_layer_cls=layers, ), device_id=device, ) - # avoid "RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory." 
#model.transformer = FSDP(model.transformer, device_id=device) #model.token_embedding = FSDP(model.token_embedding, device_id=device) @@ -345,6 +345,12 @@ def main(args): print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") if args.grad_checkpointing: #https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/ + layers_grad_checkpoint = set() + for module in model.modules(): + name = module.__class__.__name__ + for layer in args.fsdp_layers_to_grad_checkpoint: + if re.match(layer, name): + layers_grad_checkpoint.add(module.__class__) from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, CheckpointImpl, @@ -355,7 +361,7 @@ def main(args): offload_to_cpu=args.fsdp_cpu_offload, checkpoint_impl=CheckpointImpl.NO_REENTRANT, ) - check_fn = lambda submodule: isinstance(submodule, ResidualAttentionBlock) + check_fn = lambda submodule: (any(isinstance(submodule, layer) for layer in layers_grad_checkpoint)) apply_activation_checkpointing( model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn ) diff --git a/src/training/params.py b/src/training/params.py index 6a709677e..d4ca1790f 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -23,6 +23,8 @@ def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, kw) + + def parse_args(args): parser = argparse.ArgumentParser() parser.add_argument( @@ -379,6 +381,34 @@ def parse_args(args): default=False, action="store_true", ) + parser.add_argument( + "--fsdp-layers-to-wrap", + default=( + # Match all sort of blocks + '.*Block.*', + 'Bottleneck', + # CLIP + 'VisionTransformer', + 'Transformer', + # CLIP ModifiedResNet + 'ModifiedResNet', + # HF Text + 'HFTextEncoder', + # TIMM visual + 'TimmModel', + ), + type=str, + nargs='+' + ) + parser.add_argument( + "--fsdp-layers-to-grad-checkpoint", + default=( + '.*Block.*', + 'Bottleneck', + ), + type=str, + nargs='+' + ) parser.add_argument( "--no-set-device-rank", default=False, From fd42631b5848fcae88c1051b6feb85ae56610126 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 4 Feb 2023 05:50:03 +0100 Subject: [PATCH 14/37] support unlocking --- src/training/main.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index aaf713220..678b13963 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -267,18 +267,18 @@ def main(args): if args.trace: model = trace_model(model, batch_size=args.batch_size, device=device) - if args.lock_image: - # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 - model.lock_image_tower( - unlocked_groups=args.lock_image_unlocked_groups, - freeze_bn_stats=args.lock_image_freeze_bn_stats) - if args.lock_text: - model.lock_text_tower( - unlocked_layers=args.lock_text_unlocked_layers, - freeze_layer_norm=args.lock_text_freeze_layer_norm) - - if args.grad_checkpointing: - if args.distributed_engine != 'fsdp': + if args.distributed_engine != 'fsdp': + if args.lock_image: + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + model.lock_image_tower( + unlocked_groups=args.lock_image_unlocked_groups, + freeze_bn_stats=args.lock_image_freeze_bn_stats) + if args.lock_text: + model.lock_text_tower( + unlocked_layers=args.lock_text_unlocked_layers, + freeze_layer_norm=args.lock_text_freeze_layer_norm) + + if args.grad_checkpointing: model.set_grad_checkpointing() if is_master(args): @@ -343,6 +343,15 @@ def main(args): model = 
FSDP(model, **wrapper_kwargs) print(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}") print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") + if args.lock_image: + # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 + model.lock_image_tower( + unlocked_groups=args.lock_image_unlocked_groups, + freeze_bn_stats=args.lock_image_freeze_bn_stats) + if args.lock_text: + model.lock_text_tower( + unlocked_layers=args.lock_text_unlocked_layers, + freeze_layer_norm=args.lock_text_freeze_layer_norm) if args.grad_checkpointing: #https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/ layers_grad_checkpoint = set() @@ -365,6 +374,8 @@ def main(args): apply_activation_checkpointing( model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn ) + + else: print("--distrubted_engine should be either 'ddp or 'fsdp'") sys.exit(1) From 4f65c85ece4e5d93569e8126e8122e38c9039f7d Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 18 Feb 2023 02:06:53 +0100 Subject: [PATCH 15/37] fix hang after epoch finish --- src/training/main.py | 56 ++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 678b13963..babc09ec6 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -41,6 +41,7 @@ from training.file_utils import pt_load, check_exists, start_sync_process, remote_sync + LATEST_CHECKPOINT_NAME = "epoch_latest.pt" @@ -70,7 +71,6 @@ def get_latest_checkpoint(path: str, remote : bool): return checkpoints[-1] return None - def main(args): args = parse_args(args) @@ -323,6 +323,7 @@ def main(args): if re.match(layer, name): layers.add(module.__class__) print("Wrapped layers", layers) + wrapper_kwargs = dict( mixed_precision=mp, limit_all_gathers=args.fsdp_limit_allgathers, @@ -340,9 +341,6 @@ def main(args): #model.visual = FSDP(model.visual, device_id=device) #model.text_projection = FSDP(model.text_projection) ??? 
#model.ln_final = FSDP(model.ln_final, device_id=device) - model = FSDP(model, **wrapper_kwargs) - print(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}") - print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") if args.lock_image: # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 model.lock_image_tower( @@ -352,6 +350,9 @@ def main(args): model.lock_text_tower( unlocked_layers=args.lock_text_unlocked_layers, freeze_layer_norm=args.lock_text_freeze_layer_norm) + model = FSDP(model, **wrapper_kwargs) + print(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}") + print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") if args.grad_checkpointing: #https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/ layers_grad_checkpoint = set() @@ -374,8 +375,6 @@ def main(args): apply_activation_checkpointing( model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn ) - - else: print("--distrubted_engine should be either 'ddp or 'fsdp'") sys.exit(1) @@ -409,7 +408,6 @@ def main(args): hvd.broadcast_optimizer_state(optimizer, root_rank=0) scaler = GradScaler() if args.precision == "amp" else None - # optionally resume from a checkpoint start_epoch = 0 if args.resume is not None: @@ -430,7 +428,6 @@ def main(args): # loading a bare (model only) checkpoint for fine-tune or evaluation model.load_state_dict(checkpoint) logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") - # initialize datasets tokenizer = get_tokenizer(args.model) data = get_data( @@ -460,9 +457,8 @@ def main(args): logging.error( f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.') exit(1) - # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 - args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args) + args.save_logs = args.logs and args.logs.lower() != 'none' and (is_master(args) or args.distributed_engine == 'fsdp') writer = None if args.save_logs and args.tensorboard: assert tensorboard is not None, "Please install tensorboard." 
@@ -507,11 +503,9 @@ def main(args): return loss = create_loss(args) - for epoch in range(start_epoch, args.epochs): if is_master(args): logging.info(f'Start epoch {epoch}') - train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist_model, args, tb_writer=writer) completed_epoch = epoch + 1 @@ -528,25 +522,25 @@ def main(args): } if scaler is not None: checkpoint_dict["scaler"] = scaler.state_dict() - - if completed_epoch == args.epochs or ( - args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 - ): - torch.save( - checkpoint_dict, - os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), - ) - if args.delete_previous_checkpoint: - previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt") - if os.path.exists(previous_checkpoint): - os.remove(previous_checkpoint) - - if args.save_most_recent: - # try not to corrupt the latest checkpoint if save fails - tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt") - latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME) - torch.save(checkpoint_dict, tmp_save_path) - os.replace(tmp_save_path, latest_save_path) + if is_master(args): + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.delete_previous_checkpoint: + previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt") + if os.path.exists(previous_checkpoint): + os.remove(previous_checkpoint) + + if args.save_most_recent: + # try not to corrupt the latest checkpoint if save fails + tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt") + latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME) + torch.save(checkpoint_dict, tmp_save_path) + os.replace(tmp_save_path, latest_save_path) if args.wandb and is_master(args): wandb.finish() From 3bada34c81c7e18b28c5ae2c88d9616021945857 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sun, 19 Feb 2023 11:03:58 +0100 Subject: [PATCH 16/37] use `use_orig_params=True` (thanks to @nkflash) to use original parameter names to avoid erroneous parameter decay, and decay params by constructing a set of parameter names to decay before FSDP wrapping (thanks to @rwightman) --- src/training/main.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index babc09ec6..d0a4d94f9 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -261,6 +261,9 @@ def main(args): linear_replacement_cls = getattr(bnb.nn.triton_based_modules, args.use_bnb_linear) replace_linear(model, linear_replacement_cls) model = model.to(device) + # Prepare parameters to decay + exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n + parameters_to_decay = set(n for n, p in model.named_parameters() if not exclude(n,p)) random_seed(args.seed, args.rank) @@ -332,6 +335,7 @@ def main(args): transformer_auto_wrap_policy, transformer_layer_cls=layers, ), + use_orig_params=True, device_id=device, ) # avoid "RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory." 
@@ -385,14 +389,17 @@ def main(args): if args.train_data or args.dataset_type == "synthetic": assert not args.trace, 'Cannot train with traced model' - - exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n - include = lambda n, p: not exclude(n, p) - named_parameters = list(model.named_parameters()) - gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] - rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] - + if args.distributed_engine == "fsdp": + def _param_name_without_fsdp_prefix(n): + n = n.replace("_fsdp_wrapped_module.", "") + n = n.replace("._checkpoint_wrapped_module", "") + return n + gain_or_bias_params = [p for n, p in named_parameters if _param_name_without_fsdp_prefix(n) not in parameters_to_decay and p.requires_grad] + rest_params = [p for n, p in named_parameters if _param_name_without_fsdp_prefix(n) in parameters_to_decay and p.requires_grad] + else: + gain_or_bias_params = [p for n, p in named_parameters if n not in parameters_to_decay and p.requires_grad] + rest_params = [p for n, p in named_parameters if n in parameters_to_decay and p.requires_grad] optimizer = optim.AdamW( [ {"params": gain_or_bias_params, "weight_decay": 0.}, From f495986440d7fef127cc54bfa019a282592ff1a7 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Tue, 7 Mar 2023 10:28:37 +0100 Subject: [PATCH 17/37] fix distill --- src/training/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/training/main.py b/src/training/main.py index d0a4d94f9..4b3855aa5 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -304,6 +304,8 @@ def main(args): # this doesn't exist in older PyTorch, arg only added if enabled ddp_args['static_graph'] = True model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) + if args.distill: + dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args) elif args.distributed_engine == 'fsdp': from torch.distributed.fsdp.wrap import ( enable_wrap, From 397b8fca61698c987848511fe5e5420dbfb5e926 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Mon, 13 Mar 2023 12:27:17 +0100 Subject: [PATCH 18/37] fix FSDP optim state save/load so that we save the full optim state dict and we shard the optim state dict after loading --- src/training/main.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 4b3855aa5..64bb613b5 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -12,9 +12,10 @@ import torch from torch import optim from torch.cuda.amp import GradScaler -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, CPUOffload, MixedPrecision +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, CPUOffload, MixedPrecision from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload + try: import wandb except ImportError: @@ -419,6 +420,7 @@ def _param_name_without_fsdp_prefix(n): scaler = GradScaler() if args.precision == "amp" else None # optionally resume from a checkpoint start_epoch = 0 + if args.resume is not None: checkpoint = pt_load(args.resume, map_location='cpu') if 'epoch' in checkpoint: @@ -429,7 +431,11 @@ def _param_name_without_fsdp_prefix(n): sd = {k[len('module.'):]: v for k, v in sd.items()} model.load_state_dict(sd) if optimizer is not None: - 
optimizer.load_state_dict(checkpoint["optimizer"]) + if args.distributed_engine == 'fsdp': + sharded_state_dict = FSDP.optim_state_dict_to_load(checkpoint["optimizer"], model, optimizer) + optimizer.load_state_dict(sharded_state_dict) + else: + optimizer.load_state_dict(checkpoint["optimizer"]) if scaler is not None and 'scaler' in checkpoint: scaler.load_state_dict(checkpoint['scaler']) logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})") @@ -437,6 +443,7 @@ def _param_name_without_fsdp_prefix(n): # loading a bare (model only) checkpoint for fine-tune or evaluation model.load_state_dict(checkpoint) logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") + # initialize datasets tokenizer = get_tokenizer(args.model) data = get_data( @@ -523,11 +530,13 @@ def _param_name_without_fsdp_prefix(n): # Saving checkpoints. if args.save_logs: + checkpoint_dict = { "epoch": completed_epoch, "name": args.name, "state_dict": original_model.state_dict(), - "optimizer": optimizer.state_dict(), + "optimizer": FSDP.optim_state_dict(model, optimizer) if args.distributed_engine == 'fsdp' else optimizer.state_dict() + } if scaler is not None: checkpoint_dict["scaler"] = scaler.state_dict() From f2c72f8091d729eee3e2e6352809b09fed248b45 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Tue, 14 Mar 2023 05:44:41 +0100 Subject: [PATCH 19/37] offload to cpu when saving checkpoint to avoid OOM --- src/training/main.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 64bb613b5..e3616cd62 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -12,9 +12,9 @@ import torch from torch import optim from torch.cuda.amp import GradScaler -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, CPUOffload, MixedPrecision +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, CPUOffload, MixedPrecision +from torch.distributed.fsdp.api import StateDictType, FullStateDictConfig, FullOptimStateDictConfig from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload try: import wandb @@ -444,6 +444,13 @@ def _param_name_without_fsdp_prefix(n): model.load_state_dict(checkpoint) logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") + if args.distributed_engine == 'fsdp': + FSDP.set_state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(rank0_only=False, offload_to_cpu=True), + FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True), + ) # initialize datasets tokenizer = get_tokenizer(args.model) data = get_data( @@ -454,6 +461,7 @@ def _param_name_without_fsdp_prefix(n): ) assert len(data), 'At least one train or eval dataset must be specified.' 
+ # create scheduler if train scheduler = None if 'train' in data and optimizer is not None: From a69c0a762e970d1cc863730c898d880213946198 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Wed, 17 May 2023 13:19:32 +0200 Subject: [PATCH 20/37] - use the new ModuleWrapPolicy instead of transformer_auto_wrap_policy from pytorch nightly - fix grad checkpointing offloading to be compatible with pytorch nightly - use sync_module_states --- src/training/main.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index e3616cd62..40fdae407 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -14,8 +14,7 @@ from torch.cuda.amp import GradScaler from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, CPUOffload, MixedPrecision from torch.distributed.fsdp.api import StateDictType, FullStateDictConfig, FullOptimStateDictConfig -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy - +from torch.distributed.fsdp.wrap import ModuleWrapPolicy try: import wandb except ImportError: @@ -334,11 +333,9 @@ def main(args): mixed_precision=mp, limit_all_gathers=args.fsdp_limit_allgathers, cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), - auto_wrap_policy=partial( - transformer_auto_wrap_policy, - transformer_layer_cls=layers, - ), + auto_wrap_policy=ModuleWrapPolicy(layers), use_orig_params=True, + sync_module_states=True, device_id=device, ) # avoid "RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory." @@ -370,17 +367,21 @@ def main(args): layers_grad_checkpoint.add(module.__class__) from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, + offload_wrapper, CheckpointImpl, apply_activation_checkpointing, ) non_reentrant_wrapper = partial( checkpoint_wrapper, - offload_to_cpu=args.fsdp_cpu_offload, checkpoint_impl=CheckpointImpl.NO_REENTRANT, ) + if args.fsdp_cpu_offload: + wrapper = lambda module:offload_wrapper(non_reentrant_wrapper(module)) + else: + wrapper = non_reentrant_wrapper check_fn = lambda submodule: (any(isinstance(submodule, layer) for layer in layers_grad_checkpoint)) apply_activation_checkpointing( - model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn + model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn ) else: print("--distrubted_engine should be either 'ddp or 'fsdp'") @@ -403,6 +404,7 @@ def _param_name_without_fsdp_prefix(n): else: gain_or_bias_params = [p for n, p in named_parameters if n not in parameters_to_decay and p.requires_grad] rest_params = [p for n, p in named_parameters if n in parameters_to_decay and p.requires_grad] + optimizer = optim.AdamW( [ {"params": gain_or_bias_params, "weight_decay": 0.}, From 62980cb03547a691d767c1eff294a9bb6f89847a Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Wed, 17 May 2023 13:28:50 +0200 Subject: [PATCH 21/37] use ShardedGradScaler for fsdp, thanks to @nkflash --- src/training/main.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 40fdae407..7bc21f993 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -12,9 +12,13 @@ import torch from torch import optim from torch.cuda.amp import GradScaler + + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, CPUOffload, 
MixedPrecision from torch.distributed.fsdp.api import StateDictType, FullStateDictConfig, FullOptimStateDictConfig from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + try: import wandb except ImportError: @@ -419,10 +423,12 @@ def _param_name_without_fsdp_prefix(n): hvd.broadcast_parameters(model.state_dict(), root_rank=0) hvd.broadcast_optimizer_state(optimizer, root_rank=0) - scaler = GradScaler() if args.precision == "amp" else None + if args.distributed_engine == "fsdp": + scaler = ShardedGradScaler() + else: + scaler = GradScaler() if args.precision == "amp" else None # optionally resume from a checkpoint start_epoch = 0 - if args.resume is not None: checkpoint = pt_load(args.resume, map_location='cpu') if 'epoch' in checkpoint: From 9e47140df7651ce93b2a73ef17621107114a2c1b Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Wed, 17 May 2023 13:39:50 +0200 Subject: [PATCH 22/37] - FSDP printouts: use logging info. - Only import FSDP modules if possible to avoid import error --- src/training/main.py | 42 ++++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 7bc21f993..ef11e56c5 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -14,10 +14,20 @@ from torch.cuda.amp import GradScaler -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, CPUOffload, MixedPrecision -from torch.distributed.fsdp.api import StateDictType, FullStateDictConfig, FullOptimStateDictConfig -from torch.distributed.fsdp.wrap import ModuleWrapPolicy -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +# FSDP +major, minor, *rest = torch.__version__.split(".") +if (int(major), int(minor)) >= (2, 1): + # FSDP is only supported for torch >= 2.1 + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload, MixedPrecision + from torch.distributed.fsdp.api import StateDictType, FullStateDictConfig, FullOptimStateDictConfig + from torch.distributed.fsdp.wrap import ModuleWrapPolicy + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + offload_wrapper, + CheckpointImpl, + apply_activation_checkpointing, + ) try: import wandb @@ -311,15 +321,9 @@ def main(args): if args.distill: dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args) elif args.distributed_engine == 'fsdp': - from torch.distributed.fsdp.wrap import ( - enable_wrap, - wrap, - ) - print(f"Before FSTP parameter num: {sum(p.numel() for p in model.parameters())}") - print(f"Before FSTP VISUAL parameter num: {sum(p.numel() for p in model.visual.parameters())}") - #print(f"Before FSTP TEXT parameter num: {sum(p.numel() for p in model.transformer.parameters())}") - - print(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") + logging.info(f"Before FSTP parameter num: {sum(p.numel() for p in model.parameters())}") + logging.info(f"Before FSTP VISUAL parameter num: {sum(p.numel() for p in model.visual.parameters())}") + logging.info(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") mp = MixedPrecision( #param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, @@ -331,7 +335,7 @@ def main(args): for layer in args.fsdp_layers_to_wrap: if re.match(layer, name): layers.add(module.__class__) - print("Wrapped layers", layers) + 
logging.info(f"FSDP Wrapped layers: {layers}") wrapper_kwargs = dict( mixed_precision=mp, @@ -359,8 +363,8 @@ def main(args): unlocked_layers=args.lock_text_unlocked_layers, freeze_layer_norm=args.lock_text_freeze_layer_norm) model = FSDP(model, **wrapper_kwargs) - print(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}") - print(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") + logging.info(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}") + logging.info(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") if args.grad_checkpointing: #https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/ layers_grad_checkpoint = set() @@ -369,12 +373,6 @@ def main(args): for layer in args.fsdp_layers_to_grad_checkpoint: if re.match(layer, name): layers_grad_checkpoint.add(module.__class__) - from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper, - offload_wrapper, - CheckpointImpl, - apply_activation_checkpointing, - ) non_reentrant_wrapper = partial( checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT, From a8d644bf31e437e8904dd83b709a5ad3f51cde3f Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Wed, 17 May 2023 13:58:11 +0200 Subject: [PATCH 23/37] parametrize FSDP mixed precision --- src/training/main.py | 26 ++++++++++++-------------- src/training/params.py | 13 +++++++++++++ 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index ef11e56c5..12d901956 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -324,10 +324,17 @@ def main(args): logging.info(f"Before FSTP parameter num: {sum(p.numel() for p in model.parameters())}") logging.info(f"Before FSTP VISUAL parameter num: {sum(p.numel() for p in model.visual.parameters())}") logging.info(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") - mp = MixedPrecision( - #param_dtype=torch.bfloat16, - reduce_dtype=torch.bfloat16, - #buffer_dtype=torch.bfloat16, + type_name_to_class = { + "amp": torch.float16, + "amp_bf16": torch.bfloat16, + "amp_bfloat16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32, + } + mixed_precision = MixedPrecision( + param_dtype=type_name_to_class[args.precision], + reduce_dtype=type_name_to_class[args.fsdp_reduce_precision], + buffer_dtype=type_name_to_class[args.fsdp_buffer_precision], ) layers = set() for module in model.modules(): @@ -338,7 +345,7 @@ def main(args): logging.info(f"FSDP Wrapped layers: {layers}") wrapper_kwargs = dict( - mixed_precision=mp, + mixed_precision=mixed_precision, limit_all_gathers=args.fsdp_limit_allgathers, cpu_offload=CPUOffload(offload_params=args.fsdp_cpu_offload), auto_wrap_policy=ModuleWrapPolicy(layers), @@ -346,15 +353,7 @@ def main(args): sync_module_states=True, device_id=device, ) - # avoid "RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory." - #model.transformer = FSDP(model.transformer, device_id=device) - #model.token_embedding = FSDP(model.token_embedding, device_id=device) - #model.tp = FSDP(model.tp, device_id=device) - #model.visual = FSDP(model.visual, device_id=device) - #model.text_projection = FSDP(model.text_projection) ??? 
- #model.ln_final = FSDP(model.ln_final, device_id=device) if args.lock_image: - # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 model.lock_image_tower( unlocked_groups=args.lock_image_unlocked_groups, freeze_bn_stats=args.lock_image_freeze_bn_stats) @@ -366,7 +365,6 @@ def main(args): logging.info(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}") logging.info(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") if args.grad_checkpointing: - #https://pytorch.org/blog/efficient-large-scale-training-with-pytorch/ layers_grad_checkpoint = set() for module in model.modules(): name = module.__class__.__name__ diff --git a/src/training/params.py b/src/training/params.py index d4ca1790f..fe0e23d4a 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -409,6 +409,19 @@ def parse_args(args): type=str, nargs='+' ) + parser.add_argument( + "--fsdp-buffer-precision", + choices=["bf16", "fp16", "fp32"], + default="fp32", + help="FSDP floating point precision for buffers" + ) + parser.add_argument( + "--fsdp-reduce-precision", + choices=["bf16", "fp16", "fp32"], + default="fp16", + help="FSDP floating point precision for gradient reduction" + ) + parser.add_argument( "--no-set-device-rank", default=False, From 16013c4335dbaa5c4ebe074c931bb376073dd42f Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Wed, 17 May 2023 14:22:27 +0200 Subject: [PATCH 24/37] use a boolean param args.fsdp to match current args.horovod instead of adding args.distributed_engine --- src/training/main.py | 33 +++++++++++++++------------------ src/training/params.py | 8 ++++---- src/training/train.py | 5 +++-- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 12d901956..7d9b77ef4 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -284,7 +284,7 @@ def main(args): if args.trace: model = trace_model(model, batch_size=args.batch_size, device=device) - if args.distributed_engine != 'fsdp': + if not args.fsdp: if args.lock_image: # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 model.lock_image_tower( @@ -312,15 +312,7 @@ def main(args): if args.distributed and not args.horovod: if args.use_bn_sync: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - if args.distributed_engine == 'ddp': - ddp_args = {} - if args.ddp_static_graph: - # this doesn't exist in older PyTorch, arg only added if enabled - ddp_args['static_graph'] = True - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) - if args.distill: - dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], **ddp_args) - elif args.distributed_engine == 'fsdp': + if args.fsdp: logging.info(f"Before FSTP parameter num: {sum(p.numel() for p in model.parameters())}") logging.info(f"Before FSTP VISUAL parameter num: {sum(p.numel() for p in model.visual.parameters())}") logging.info(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") @@ -384,8 +376,13 @@ def main(args): model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn ) else: - print("--distrubted_engine should be either 'ddp or 'fsdp'") - sys.exit(1) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args['static_graph'] = True + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) + if args.distill: + dist_model = torch.nn.parallel.DistributedDataParallel(dist_model, device_ids=[device], 
**ddp_args) # create optimizer and scaler optimizer = None @@ -394,7 +391,7 @@ def main(args): if args.train_data or args.dataset_type == "synthetic": assert not args.trace, 'Cannot train with traced model' named_parameters = list(model.named_parameters()) - if args.distributed_engine == "fsdp": + if args.fsdp: def _param_name_without_fsdp_prefix(n): n = n.replace("_fsdp_wrapped_module.", "") n = n.replace("._checkpoint_wrapped_module", "") @@ -419,7 +416,7 @@ def _param_name_without_fsdp_prefix(n): hvd.broadcast_parameters(model.state_dict(), root_rank=0) hvd.broadcast_optimizer_state(optimizer, root_rank=0) - if args.distributed_engine == "fsdp": + if args.fsdp: scaler = ShardedGradScaler() else: scaler = GradScaler() if args.precision == "amp" else None @@ -435,7 +432,7 @@ def _param_name_without_fsdp_prefix(n): sd = {k[len('module.'):]: v for k, v in sd.items()} model.load_state_dict(sd) if optimizer is not None: - if args.distributed_engine == 'fsdp': + if args.fsdp: sharded_state_dict = FSDP.optim_state_dict_to_load(checkpoint["optimizer"], model, optimizer) optimizer.load_state_dict(sharded_state_dict) else: @@ -448,7 +445,7 @@ def _param_name_without_fsdp_prefix(n): model.load_state_dict(checkpoint) logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") - if args.distributed_engine == 'fsdp': + if args.fsdp: FSDP.set_state_dict_type( model, StateDictType.FULL_STATE_DICT, @@ -486,7 +483,7 @@ def _param_name_without_fsdp_prefix(n): f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.') exit(1) # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 - args.save_logs = args.logs and args.logs.lower() != 'none' and (is_master(args) or args.distributed_engine == 'fsdp') + args.save_logs = args.logs and args.logs.lower() != 'none' and (is_master(args) or args.fsdp) writer = None if args.save_logs and args.tensorboard: assert tensorboard is not None, "Please install tensorboard." @@ -547,7 +544,7 @@ def _param_name_without_fsdp_prefix(n): "epoch": completed_epoch, "name": args.name, "state_dict": original_model.state_dict(), - "optimizer": FSDP.optim_state_dict(model, optimizer) if args.distributed_engine == 'fsdp' else optimizer.state_dict() + "optimizer": FSDP.optim_state_dict(model, optimizer) if args.fsdp else optimizer.state_dict() } if scaler is not None: diff --git a/src/training/params.py b/src/training/params.py index fe0e23d4a..f46b58f32 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -361,10 +361,10 @@ def parse_args(args): help="Enable static graph optimization for DDP in PyTorch >= 1.11.", ) parser.add_argument( - "--distributed-engine", - type=str, - default="ddp", - choices=["ddp", "fsdp"], + "--fsdp", + default=False, + action="store_true", + help="Use FSDP for distributed training." ) parser.add_argument( "--fsdp-init-on-cpu", diff --git a/src/training/train.py b/src/training/train.py index 5552b31c3..3c5afaeb1 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from torch.nn.parallel.distributed import DistributedDataParallel + try: import wandb except ImportError: @@ -185,7 +186,7 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist accum_images, accum_texts, accum_features = [], [], {} # Note: we clamp to 4.6052 = ln(100), as in the original paper. 
- if args.distributed_engine == 'fsdp': + if args.fsdp: model(image=None, text=None, clamp_logit_scale_to=math.log(100)) else: with torch.no_grad(): @@ -253,7 +254,7 @@ def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, dist def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None): metrics = {} - if not is_master(args) and args.distributed_engine != 'fsdp': + if not is_master(args) and not args.fsdp: return metrics device = torch.device(args.device) model.eval() From 7735cace2b640fefcfa4200d47cbde539e7d721d Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Wed, 17 May 2023 14:31:31 +0200 Subject: [PATCH 25/37] replace last args.distributed_engine mention in the code --- src/training/zero_shot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py index 8d957febe..eaff6e009 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -17,7 +17,7 @@ def zero_shot_classifier(model, classnames, templates, args): texts = [template(classname) for template in templates] # format with class texts = tokenizer(texts).to(args.device) # tokenize if args.distributed and not args.horovod: - if args.distributed_engine == 'fsdp': + if args.fsdp: _, class_embeddings, _ = model(image=None, text=texts) else: class_embeddings = model.module.encode_text(texts) @@ -49,7 +49,7 @@ def run(model, classifier, dataloader, args): with autocast(): # predict if args.distributed and not args.horovod: - if args.distributed_engine == 'fsdp': + if args.fsdp: image_features, _, _ = model(image=images, text=None) else: image_features = model.module.encode_image(images) From f4165f777c4f291f0aa74e0bf9eba8613c657c2a Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Wed, 17 May 2023 16:47:37 +0200 Subject: [PATCH 26/37] fsdp log on rank zero only --- src/training/main.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 7d9b77ef4..6e1132cef 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -313,9 +313,9 @@ def main(args): if args.use_bn_sync: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) if args.fsdp: - logging.info(f"Before FSTP parameter num: {sum(p.numel() for p in model.parameters())}") - logging.info(f"Before FSTP VISUAL parameter num: {sum(p.numel() for p in model.visual.parameters())}") - logging.info(f"Before FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") + if is_master(args): + logging.info(f"Before FSDP number of params: {sum(p.numel() for p in model.parameters())}") + logging.info(f"Before FSDP memory allocated: {torch.cuda.memory_allocated()/1024**3:.3} GB") type_name_to_class = { "amp": torch.float16, "amp_bf16": torch.bfloat16, @@ -354,8 +354,9 @@ def main(args): unlocked_layers=args.lock_text_unlocked_layers, freeze_layer_norm=args.lock_text_freeze_layer_norm) model = FSDP(model, **wrapper_kwargs) - logging.info(f"After FSTP parameter num: {sum(p.numel() for p in model.parameters())}") - logging.info(f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB") + if is_master(args): + logging.info(f"After FSDP number of params: {sum(p.numel() for p in model.parameters())}") + logging.info(f"After FSDP memory allocated: {torch.cuda.memory_allocated()/1024**3:.3} GB") if args.grad_checkpointing: layers_grad_checkpoint = set() for module in model.modules(): From 3aa42f45d667185d98382a8d0f40665c0a0f1cae Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Wed, 17 May 2023 16:53:45 +0200 
Subject: [PATCH 27/37] minor --- src/training/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 6e1132cef..b9ac00f5c 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -315,7 +315,7 @@ def main(args): if args.fsdp: if is_master(args): logging.info(f"Before FSDP number of params: {sum(p.numel() for p in model.parameters())}") - logging.info(f"Before FSDP memory allocated: {torch.cuda.memory_allocated()/1024**3:.3} GB") + logging.info(f"Before FSDP memory allocated: {torch.cuda.memory_allocated()/1024**3:.4} GB") type_name_to_class = { "amp": torch.float16, "amp_bf16": torch.bfloat16, @@ -356,7 +356,7 @@ def main(args): model = FSDP(model, **wrapper_kwargs) if is_master(args): logging.info(f"After FSDP number of params: {sum(p.numel() for p in model.parameters())}") - logging.info(f"After FSDP memory allocated: {torch.cuda.memory_allocated()/1024**3:.3} GB") + logging.info(f"After FSDP memory allocated: {torch.cuda.memory_allocated()/1024**3:.4} GB") if args.grad_checkpointing: layers_grad_checkpoint = set() for module in model.modules(): From 5e167b2fac60d86dfee1abf02e7ce01e97cb892f Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Wed, 17 May 2023 16:56:30 +0200 Subject: [PATCH 28/37] minor --- src/training/main.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/training/main.py b/src/training/main.py index b9ac00f5c..bb745682f 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -334,7 +334,9 @@ def main(args): for layer in args.fsdp_layers_to_wrap: if re.match(layer, name): layers.add(module.__class__) - logging.info(f"FSDP Wrapped layers: {layers}") + + if is_master(args): + logging.info(f"FSDP Wrapped layers: {layers}") wrapper_kwargs = dict( mixed_precision=mixed_precision, @@ -357,6 +359,7 @@ def main(args): if is_master(args): logging.info(f"After FSDP number of params: {sum(p.numel() for p in model.parameters())}") logging.info(f"After FSDP memory allocated: {torch.cuda.memory_allocated()/1024**3:.4} GB") + if args.grad_checkpointing: layers_grad_checkpoint = set() for module in model.modules(): From 5704ada89e02a59c57972352deb662b0b43a5664 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Thu, 18 May 2023 09:18:45 +0200 Subject: [PATCH 29/37] rank0 only and offload to cpu both true as recommended --- src/training/main.py | 6 +++--- src/training/params.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index bb745682f..753b23984 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -325,7 +325,7 @@ def main(args): } mixed_precision = MixedPrecision( param_dtype=type_name_to_class[args.precision], - reduce_dtype=type_name_to_class[args.fsdp_reduce_precision], + reduce_dtype=type_name_to_class[args.fsdp_gradient_reduction_precision], buffer_dtype=type_name_to_class[args.fsdp_buffer_precision], ) layers = set() @@ -453,8 +453,8 @@ def _param_name_without_fsdp_prefix(n): FSDP.set_state_dict_type( model, StateDictType.FULL_STATE_DICT, - FullStateDictConfig(rank0_only=False, offload_to_cpu=True), - FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True), + FullStateDictConfig(rank0_only=True, offload_to_cpu=True), + FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True), ) # initialize datasets tokenizer = get_tokenizer(args.model) diff --git a/src/training/params.py b/src/training/params.py index f46b58f32..2f7be333c 100644 --- a/src/training/params.py +++ 
b/src/training/params.py @@ -416,7 +416,7 @@ def parse_args(args): help="FSDP floating point precision for buffers" ) parser.add_argument( - "--fsdp-reduce-precision", + "--fsdp-gradient-reduction-precision", choices=["bf16", "fp16", "fp32"], default="fp16", help="FSDP floating point precision for gradient reduction" From ffcf226cb1bea365951d8c9463eed428e475561a Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Thu, 18 May 2023 09:26:20 +0200 Subject: [PATCH 30/37] cli parameters description --- src/training/params.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/training/params.py b/src/training/params.py index 2f7be333c..e03000bc3 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -370,16 +370,19 @@ def parse_args(args): "--fsdp-init-on-cpu", default=False, action="store_true", + help="Initialize the model on CPUs rather than GPUs, useful for large models", ) parser.add_argument( "--fsdp-cpu-offload", default=False, action="store_true", + help="Use CPU offloading", ) parser.add_argument( "--fsdp-limit-allgathers", default=False, action="store_true", + help="Prevent too many allgathers", ) parser.add_argument( "--fsdp-layers-to-wrap", @@ -398,7 +401,8 @@ def parse_args(args): 'TimmModel', ), type=str, - nargs='+' + nargs='+', + help="Regular expression to match module names to wrap in FSDP, this affects communication and peak memory.", ) parser.add_argument( "--fsdp-layers-to-grad-checkpoint", @@ -407,7 +411,8 @@ def parse_args(args): 'Bottleneck', ), type=str, - nargs='+' + nargs='+', + help="Module names to wrap for gradient checkpointing when FSDP is used", ) parser.add_argument( "--fsdp-buffer-precision", @@ -421,7 +426,6 @@ def parse_args(args): default="fp16", help="FSDP floating point precision for gradient reduction" ) - parser.add_argument( "--no-set-device-rank", default=False, From d3ab217ffe16a65798cc359a54c7b776545f4b53 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Mon, 22 May 2023 15:11:39 +0200 Subject: [PATCH 31/37] support CoCa models --- src/open_clip/coca_model.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index 272b2cc06..8fc7bc08d 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -154,15 +154,20 @@ def encode_text(self, text, normalize: bool = True): text_latent, _ = self._encode_text(text, normalize=normalize) return text_latent - def forward( - self, - image, - text: Optional[torch.Tensor] = None, - image_latent: Optional[torch.Tensor] = None, - image_embs: Optional[torch.Tensor] = None, - ): - if image_latent is None or image_embs is None: - image_latent, image_embs = self._encode_image(image) + def forward(self, image=None, text=None, embed_cls=True, image_latent=None, image_embs=None, clamp_logit_scale_to=0): + + if text is not None: + text_latent, token_embs = self._encode_text(text, embed_cls=embed_cls) + else: + text_latent, token_embs = None, None + + + if image is not None: + if image_latent is None or image_embs is None: + image_latent, image_embs = self._encode_image(image) + else: + image_latent = None + image_embs = None if text is None: return {"image_features": image_latent, "image_embs": image_embs} @@ -170,9 +175,16 @@ def forward( text_latent, token_embs = self._encode_text(text) # TODO: add assertion to avoid bugs? 
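The --fsdp-buffer-precision and --fsdp-gradient-reduction-precision choices map onto the fields of torch.distributed.fsdp.MixedPrecision. A sketch of that mapping, with the dtype table and variable names picked here purely for illustration:

    import torch
    from torch.distributed.fsdp import MixedPrecision

    DTYPES = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}

    mixed_precision = MixedPrecision(
        # param_dtype (left unset here) would control the dtype parameters are gathered and computed in
        reduce_dtype=DTYPES["fp16"],  # --fsdp-gradient-reduction-precision: dtype of the gradient reduce-scatter
        buffer_dtype=DTYPES["fp32"],  # --fsdp-buffer-precision: dtype registered buffers are cast to
    )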
- labels = text[:, -token_embs.shape[1]:] + if text is not None and token_embs is not None: + labels = text[:, -token_embs.shape[1]:] + logits = self.text_decoder(image_embs, token_embs) + else: + labels = None + logits = None - logits = self.text_decoder(image_embs, token_embs) + if clamp_logit_scale_to: + with torch.no_grad(): + self.logit_scale.data.clamp_(0, clamp_logit_scale_to) out_dict = { "image_features": image_latent, "text_features": text_latent, From 86799c25b18135ac74743d0133e926201e0dc1f4 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Wed, 24 May 2023 08:43:52 +0200 Subject: [PATCH 32/37] fix optimizer resuming in FSDP and remove param/buffer precision --- src/training/main.py | 31 ++++++++++++++++--------------- src/training/params.py | 6 ------ 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/src/training/main.py b/src/training/main.py index 753b23984..445aa6930 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -87,7 +87,6 @@ def get_latest_checkpoint(path: str, remote : bool): def main(args): args = parse_args(args) - if torch.cuda.is_available(): # This enables tf32 on Ampere GPUs which is only 8% slower than # float16 and almost as accurate as float32 @@ -324,9 +323,7 @@ def main(args): "fp32": torch.float32, } mixed_precision = MixedPrecision( - param_dtype=type_name_to_class[args.precision], reduce_dtype=type_name_to_class[args.fsdp_gradient_reduction_precision], - buffer_dtype=type_name_to_class[args.fsdp_buffer_precision], ) layers = set() for module in model.modules(): @@ -426,6 +423,13 @@ def _param_name_without_fsdp_prefix(n): scaler = GradScaler() if args.precision == "amp" else None # optionally resume from a checkpoint start_epoch = 0 + if args.fsdp: + FSDP.set_state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(rank0_only=False, offload_to_cpu=True), + FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True), + ) if args.resume is not None: checkpoint = pt_load(args.resume, map_location='cpu') if 'epoch' in checkpoint: @@ -437,7 +441,10 @@ def _param_name_without_fsdp_prefix(n): model.load_state_dict(sd) if optimizer is not None: if args.fsdp: - sharded_state_dict = FSDP.optim_state_dict_to_load(checkpoint["optimizer"], model, optimizer) + optimizer_state_dict = checkpoint["optimizer"] + optimizer_state_dict['state']['logit_scale']['exp_avg'] = optimizer_state_dict['state']['logit_scale']['exp_avg'].view(1) + optimizer_state_dict['state']['logit_scale']['exp_avg_sq'] = optimizer_state_dict['state']['logit_scale']['exp_avg_sq'].view(1) + sharded_state_dict = FSDP.optim_state_dict_to_load(model, optimizer, optimizer_state_dict) optimizer.load_state_dict(sharded_state_dict) else: optimizer.load_state_dict(checkpoint["optimizer"]) @@ -449,13 +456,7 @@ def _param_name_without_fsdp_prefix(n): model.load_state_dict(checkpoint) logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") - if args.fsdp: - FSDP.set_state_dict_type( - model, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(rank0_only=True, offload_to_cpu=True), - FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True), - ) + # initialize datasets tokenizer = get_tokenizer(args.model) data = get_data( @@ -543,14 +544,14 @@ def _param_name_without_fsdp_prefix(n): # Saving checkpoints. 
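For reference, the optimizer-state round trip this patch settles on looks roughly like the sketch below; fsdp_model and optimizer are stand-ins for the wrapped CLIP model and its optimizer. Note that older PyTorch releases expect FSDP.optim_state_dict_to_load(optim_state_dict, model, optim) instead of the argument order used here, so the exact call depends on the installed version:

    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        StateDictType,
        FullStateDictConfig,
        FullOptimStateDictConfig,
    )

    # configure once, before any state dict is produced or consumed
    FSDP.set_state_dict_type(
        fsdp_model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
        FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
    )

    # saving: consolidate the sharded optimizer state into one dict keyed by parameter name
    full_osd = FSDP.optim_state_dict(fsdp_model, optimizer)

    # resuming: 0-dim entries (e.g. the logit_scale moments) are first viewed as 1-element tensors
    # so they line up with FSDP's flat-parameter views, then the dict is re-sharded and loaded
    sharded_osd = FSDP.optim_state_dict_to_load(fsdp_model, optimizer, full_osd)
    optimizer.load_state_dict(sharded_osd)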
if args.save_logs: - + checkpoint_dict = { "epoch": completed_epoch, "name": args.name, - "state_dict": original_model.state_dict(), - "optimizer": FSDP.optim_state_dict(model, optimizer) if args.fsdp else optimizer.state_dict() - + "state_dict": model.state_dict(), + "optimizer": FSDP.optim_state_dict(model, optimizer) if args.fsdp else optimizer.state_dict(), } + if scaler is not None: checkpoint_dict["scaler"] = scaler.state_dict() if is_master(args): diff --git a/src/training/params.py b/src/training/params.py index e03000bc3..953ed9a4d 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -414,12 +414,6 @@ def parse_args(args): nargs='+', help="Module names to wrap for gradient checkpointing when FSDP is used", ) - parser.add_argument( - "--fsdp-buffer-precision", - choices=["bf16", "fp16", "fp32"], - default="fp32", - help="FSDP floating point precision for buffers" - ) parser.add_argument( "--fsdp-gradient-reduction-precision", choices=["bf16", "fp16", "fp32"], From 0859c848008bfa7b822d8bd1a7165c94a3764198 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Fri, 3 Nov 2023 13:27:35 +0100 Subject: [PATCH 33/37] use original_model instead of model --- src/training/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/training/main.py b/src/training/main.py index 445aa6930..89acf0393 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -548,7 +548,7 @@ def _param_name_without_fsdp_prefix(n): checkpoint_dict = { "epoch": completed_epoch, "name": args.name, - "state_dict": model.state_dict(), + "state_dict": original_model.state_dict(), "optimizer": FSDP.optim_state_dict(model, optimizer) if args.fsdp else optimizer.state_dict(), } From 0a98da27012455e22a24d95408654c61338afad5 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Fri, 3 Nov 2023 13:29:26 +0100 Subject: [PATCH 34/37] delete old import --- src/training/zero_shot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py index eaff6e009..209577d0f 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -6,7 +6,6 @@ from open_clip import get_input_dtype, get_tokenizer, build_zero_shot_classifier, \ IMAGENET_CLASSNAMES, OPENAI_IMAGENET_TEMPLATES from .precision import get_autocast -from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template def zero_shot_classifier(model, classnames, templates, args): From acd5af79cc2b8498383d84d8771c43cf8a5470b8 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Fri, 3 Nov 2023 14:11:14 +0100 Subject: [PATCH 35/37] remove old zero shot classifier builder --- src/training/zero_shot.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py index 209577d0f..f7b8ff493 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -8,27 +8,6 @@ from .precision import get_autocast -def zero_shot_classifier(model, classnames, templates, args): - tokenizer = get_tokenizer(args.model) - with torch.no_grad(): - zeroshot_weights = [] - for classname in tqdm(classnames): - texts = [template(classname) for template in templates] # format with class - texts = tokenizer(texts).to(args.device) # tokenize - if args.distributed and not args.horovod: - if args.fsdp: - _, class_embeddings, _ = model(image=None, text=texts) - else: - class_embeddings = model.module.encode_text(texts) - else: - class_embeddings = model.encode_text(texts) - class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) 
- class_embedding /= class_embedding.norm() - zeroshot_weights.append(class_embedding) - zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) - return zeroshot_weights - - def accuracy(output, target, topk=(1,)): pred = output.topk(max(topk), 1, True, True)[1].t() correct = pred.eq(target.view(1, -1).expand_as(pred)) From 67bfcaa6a09cb2fa79bb5a6b2c14cc9dff6db2ba Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Fri, 3 Nov 2023 15:21:49 +0100 Subject: [PATCH 36/37] fix again zero-shot eval --- src/open_clip/model.py | 2 +- src/open_clip/zero_shot_classifier.py | 3 ++- src/training/train.py | 4 ++-- src/training/zero_shot.py | 13 +++---------- 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 42e4aa934..64084fe07 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -285,7 +285,7 @@ def encode_text(self, text, normalize: bool = False): return F.normalize(x, dim=-1) if normalize else x - def forward(self, image, text, clamp_logit_scale_to:float=0): + def forward(self, image=None, text=None, clamp_logit_scale_to:float=0): image_features = self.encode_image(image, normalize=True) if image is not None else None text_features = self.encode_text(text, normalize=True) if text is not None else None if clamp_logit_scale_to: diff --git a/src/open_clip/zero_shot_classifier.py b/src/open_clip/zero_shot_classifier.py index 535ec9696..36f99967e 100644 --- a/src/open_clip/zero_shot_classifier.py +++ b/src/open_clip/zero_shot_classifier.py @@ -53,7 +53,8 @@ def _process_batch(batch_classnames): num_batch_classes = len(batch_classnames) texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] texts = tokenizer(texts).to(device) - class_embeddings = model.encode_text(texts, normalize=True) + output = model(text=texts) + class_embeddings = output['text_features'] if isinstance(output, dict) else output[1] class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) class_embeddings = class_embeddings.T diff --git a/src/training/train.py b/src/training/train.py index 3c5afaeb1..f7ffd54dd 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -338,7 +338,7 @@ def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None): log_data = {"val/" + name: val for name, val in metrics.items()} - if args.save_logs: + if args.save_logs and is_master(args): if tb_writer is not None: for name, val in log_data.items(): tb_writer.add_scalar(name, val, epoch) @@ -347,7 +347,7 @@ def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None): f.write(json.dumps(metrics)) f.write("\n") - if args.wandb: + if args.wandb and is_master(args): assert wandb is not None, 'Please install wandb.' 
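The zero-shot classifier builder now encodes through the wrapped model's forward because, with FSDP, every rank has to enter forward() for the parameter all-gather to happen; the DDP-style model.module.encode_text path would generally see only the local shards. A small sketch of the calling convention assumed above, with illustrative names:

    import torch

    @torch.no_grad()
    def encode_texts(fsdp_model, texts):
        # run on every rank; only `text` is passed, so image features come back as None
        out = fsdp_model(text=texts)
        return out["text_features"] if isinstance(out, dict) else out[1]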
if 'train' in data: dataloader = data['train'].dataloader diff --git a/src/training/zero_shot.py b/src/training/zero_shot.py index f7b8ff493..def3a3b20 100644 --- a/src/training/zero_shot.py +++ b/src/training/zero_shot.py @@ -25,15 +25,8 @@ def run(model, classifier, dataloader, args): target = target.to(args.device) with autocast(): - # predict - if args.distributed and not args.horovod: - if args.fsdp: - image_features, _, _ = model(image=images, text=None) - else: - image_features = model.module.encode_image(images) - else: - image_features = model.encode_image(images) - image_features = F.normalize(image_features, dim=-1) + output = model(image=images) + image_features = output['image_features'] if isinstance(output, dict) else output[0] logits = 100. * image_features @ classifier # measure accuracy @@ -54,7 +47,7 @@ def zero_shot_eval(model, data, epoch, args, tokenizer=None): return {} if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: return {} - if args.distributed and not args.horovod: + if args.distributed and not args.horovod and not args.fsdp: model = model.module logging.info('Starting zero-shot imagenet.') From 4206d5642f16ade9c7f6c903559ea7059c676b12 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Sat, 4 Nov 2023 18:39:04 +0100 Subject: [PATCH 37/37] support sharded checkpointing for FSDP to handle large models, following dinov2 code --- src/open_clip/model.py | 5 +- src/training/main.py | 199 +++++++++++++++++++++++++++-------------- src/training/params.py | 6 ++ 3 files changed, 143 insertions(+), 67 deletions(-) diff --git a/src/open_clip/model.py b/src/open_clip/model.py index 64084fe07..8262c521c 100644 --- a/src/open_clip/model.py +++ b/src/open_clip/model.py @@ -356,10 +356,13 @@ def forward( self, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, + clamp_logit_scale_to: float = 0, ): image_features = self.encode_image(image, normalize=True) if image is not None else None text_features = self.encode_text(text, normalize=True) if text is not None else None - + if clamp_logit_scale_to: + with torch.no_grad(): + self.logit_scale.data.clamp_(0, clamp_logit_scale_to) if self.output_dict: out_dict = { "image_features": image_features, diff --git a/src/training/main.py b/src/training/main.py index 89acf0393..ac5039f65 100644 --- a/src/training/main.py +++ b/src/training/main.py @@ -28,7 +28,12 @@ CheckpointImpl, apply_activation_checkpointing, ) - +try: + from fvcore.common.checkpoint import Checkpointer +except ImportError: + has_checkpointer = False +else: + has_checkpointer = True try: import wandb except ImportError: @@ -55,9 +60,9 @@ from training.file_utils import pt_load, check_exists, start_sync_process, remote_sync - -LATEST_CHECKPOINT_NAME = "epoch_latest.pt" - +LATEST_CHECKPOINT = "epoch_latest" +LATEST_CHECKPOINT_NAME = f"{LATEST_CHECKPOINT}.pt" +DISTRIBUTED_CHECKPOINT_NAME_FORMAT = "{name}.rank_{rank}.pt" def random_seed(seed=42, rank=0): torch.manual_seed(seed + rank) @@ -70,7 +75,7 @@ def natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] -def get_latest_checkpoint(path: str, remote : bool): +def get_latest_checkpoint(path: str, remote : bool, search_pattern: str = "*.pt"): # as writen, this glob recurses, so can pick up checkpoints across multiple sub-folders if remote: result = subprocess.run(["aws", "s3", "ls", path + "/"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) @@ -79,7 +84,7 @@ def get_latest_checkpoint(path: str, remote : bool): return None 
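With --fsdp-use-distributed-checkpointer every rank writes and resumes its own shard, so checkpoint discovery also runs per rank instead of on the master only. A small illustration of the naming scheme, using an assumed rank and directory:

    import os

    DISTRIBUTED_CHECKPOINT_NAME_FORMAT = "{name}.rank_{rank}.pt"

    # rank 3 would save/resume e.g. "<checkpoint_path>/epoch_latest.rank_3.pt"
    basename = DISTRIBUTED_CHECKPOINT_NAME_FORMAT.format(name="epoch_latest", rank=3)
    shard_path = os.path.join("logs/example/checkpoints", basename)

    # and it searches only for its own shards when picking the latest checkpoint
    pattern = DISTRIBUTED_CHECKPOINT_NAME_FORMAT.format(name="*", rank=3)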
checkpoints = [os.path.join(path, x.split(' ')[-1]) for x in result.stdout.decode().split('\n')[:-1]] else: - checkpoints = glob.glob(path + '**/*.pt', recursive=True) + checkpoints = glob.glob(path + f'**/{search_pattern}', recursive=True) if checkpoints: checkpoints = sorted(checkpoints, key=natural_key) return checkpoints[-1] @@ -155,24 +160,31 @@ def main(args): if args.remote_sync_protocol != 's3': print('Error. Sync protocol not supported when using resume latest.') return -1 - if is_master(args): + if is_master(args) or args.fsdp_use_distributed_checkpointer: # Checking for existing checkpoint via master rank only. It is possible for # different rank processes to see different files if a shared file-system is under # stress, however it's very difficult to fully work around such situations. if args.save_most_recent: # if --save-most-recent flag is set, look for latest at a fixed filename - resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME) + resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT) + if args.fsdp_use_distributed_checkpointer: + resume_from = DISTRIBUTED_CHECKPOINT_NAME_FORMAT.format(name=resume_from, rank=args.rank) if not os.path.exists(resume_from): # If no latest checkpoint has been saved yet, don't try to resume resume_from = None else: # otherwise, list checkpoint dir contents and pick the newest checkpoint - resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None) + if args.fsdp_use_distributed_checkpointer: + pattern = DISTRIBUTED_CHECKPOINT_NAME_FORMAT.format(name="*", rank=args.rank) + resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None, search_pattern=pattern) + # TODO we need to make sure there is consistency between all checkpoints + else: + resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None, search_pattern="*.pt") if resume_from: logging.info(f'Found latest resume checkpoint at {resume_from}.') else: logging.info(f'No latest resume checkpoint found in {checkpoint_path}.') - if args.distributed: + if args.distributed and not args.fsdp_use_distributed_checkpointer: # sync found checkpoint path to all ranks resume_from = broadcast_object(args, resume_from) args.resume = resume_from @@ -297,7 +309,7 @@ def main(args): if args.grad_checkpointing: model.set_grad_checkpointing() - if is_master(args): + if is_master(args): logging.info("Model:") logging.info(f"{str(model)}") logging.info("Params:") @@ -424,37 +436,48 @@ def _param_name_without_fsdp_prefix(n): # optionally resume from a checkpoint start_epoch = 0 if args.fsdp: - FSDP.set_state_dict_type( - model, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(rank0_only=False, offload_to_cpu=True), + if args.fsdp_use_distributed_checkpointer: + checkpointer = FSDPCheckpointer(model, args.rank, args.checkpoint_path, optimizer=optimizer, save_to_disk=True) + else: + FSDP.set_state_dict_type( + model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(rank0_only=False, offload_to_cpu=True), FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True), - ) + ) + if args.resume is not None: - checkpoint = pt_load(args.resume, map_location='cpu') - if 'epoch' in checkpoint: - # resuming a train checkpoint w/ epoch and optimizer state - start_epoch = checkpoint["epoch"] - sd = checkpoint["state_dict"] - if not args.distributed and next(iter(sd.items()))[0].startswith('module'): - sd = {k[len('module.'):]: v for k, v in sd.items()} - model.load_state_dict(sd) - if optimizer is 
not None: - if args.fsdp: - optimizer_state_dict = checkpoint["optimizer"] - optimizer_state_dict['state']['logit_scale']['exp_avg'] = optimizer_state_dict['state']['logit_scale']['exp_avg'].view(1) - optimizer_state_dict['state']['logit_scale']['exp_avg_sq'] = optimizer_state_dict['state']['logit_scale']['exp_avg_sq'].view(1) - sharded_state_dict = FSDP.optim_state_dict_to_load(model, optimizer, optimizer_state_dict) - optimizer.load_state_dict(sharded_state_dict) - else: - optimizer.load_state_dict(checkpoint["optimizer"]) - if scaler is not None and 'scaler' in checkpoint: + if args.fsdp_use_distributed_checkpointer: + checkpoint = checkpointer.load(args.resume, checkpointables=[]) + if 'epoch' in checkpoint: + start_epoch = checkpoint["epoch"] + if 'scaler' in checkpoint: scaler.load_state_dict(checkpoint['scaler']) - logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})") else: - # loading a bare (model only) checkpoint for fine-tune or evaluation - model.load_state_dict(checkpoint) - logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") + checkpoint = pt_load(args.resume, map_location='cpu') + if 'epoch' in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith('module'): + sd = {k[len('module.'):]: v for k, v in sd.items()} + model.load_state_dict(sd) + if optimizer is not None: + if args.fsdp: + optimizer_state_dict = checkpoint["optimizer"] + optimizer_state_dict['state']['logit_scale']['exp_avg'] = optimizer_state_dict['state']['logit_scale']['exp_avg'].view(1) + optimizer_state_dict['state']['logit_scale']['exp_avg_sq'] = optimizer_state_dict['state']['logit_scale']['exp_avg_sq'].view(1) + sharded_state_dict = FSDP.optim_state_dict_to_load(model, optimizer, optimizer_state_dict) + optimizer.load_state_dict(sharded_state_dict) + else: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and 'scaler' in checkpoint: + scaler.load_state_dict(checkpoint['scaler']) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") + logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})") # initialize datasets @@ -544,35 +567,40 @@ def _param_name_without_fsdp_prefix(n): # Saving checkpoints. 
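The consolidated (non-sharded) save path gathers full weights before writing a single file, which is what FullStateDictConfig controls. A sketch of that pattern in isolation, assuming the process group is already initialised and fsdp_model is the wrapped model; with rank0_only=True the gathered tensors are only materialised on rank 0, and offload_to_cpu keeps them out of GPU memory:

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType, FullStateDictConfig

    cfg = FullStateDictConfig(rank0_only=True, offload_to_cpu=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        full_sd = fsdp_model.state_dict()  # unsharded weights as CPU tensors, populated on rank 0

    if dist.get_rank() == 0:
        torch.save({"state_dict": full_sd}, "epoch_1.pt")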
if args.save_logs: - - checkpoint_dict = { - "epoch": completed_epoch, - "name": args.name, - "state_dict": original_model.state_dict(), - "optimizer": FSDP.optim_state_dict(model, optimizer) if args.fsdp else optimizer.state_dict(), - } - - if scaler is not None: - checkpoint_dict["scaler"] = scaler.state_dict() - if is_master(args): - if completed_epoch == args.epochs or ( - args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 - ): - torch.save( - checkpoint_dict, - os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), - ) - if args.delete_previous_checkpoint: - previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt") - if os.path.exists(previous_checkpoint): - os.remove(previous_checkpoint) - + if args.fsdp_use_distributed_checkpointer: + if completed_epoch == args.epochs or (args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0): + checkpointer.save(path=f"epoch_{completed_epoch}", name=args.name, epoch=completed_epoch, scaler=scaler.state_dict()) if args.save_most_recent: - # try not to corrupt the latest checkpoint if save fails - tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt") - latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME) - torch.save(checkpoint_dict, tmp_save_path) - os.replace(tmp_save_path, latest_save_path) + checkpointer.save(path=LATEST_CHECKPOINT, name=args.name, epoch=completed_epoch, scaler=scaler.state_dict()) + else: + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": original_model.state_dict(), + "optimizer": FSDP.optim_state_dict(model, optimizer) if args.fsdp else optimizer.state_dict(), + } + + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + if is_master(args): + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.delete_previous_checkpoint: + previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt") + if os.path.exists(previous_checkpoint): + os.remove(previous_checkpoint) + + if args.save_most_recent: + # try not to corrupt the latest checkpoint if save fails + tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt") + latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME) + torch.save(checkpoint_dict, tmp_save_path) + os.replace(tmp_save_path, latest_save_path) if args.wandb and is_master(args): wandb.finish() @@ -609,5 +637,44 @@ def copy_codebase(args): return 1 +if has_checkpointer: + class FSDPCheckpointer(Checkpointer): + + def __init__(self, model, rank, save_dir, **kwargs): + super().__init__(model, save_dir, **kwargs) + self.rank = rank + + def save(self, path: str, **kwargs) -> None: + """ + Dump model and checkpointables to a file. + + Args: + name (str): name of the file. + kwargs (dict): extra arbitrary data to save. 
+ """ + if not self.save_dir or not self.save_to_disk: + return + + data = {} + with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + data["model"] = self.model.state_dict() + + # data["model"] = self.model.state_dict() + for key, obj in self.checkpointables.items(): + data[key] = obj.state_dict() + data.update(kwargs) + + basename = DISTRIBUTED_CHECKPOINT_NAME_FORMAT.format(name=path, rank=self.rank) + save_file = os.path.join(self.save_dir, basename) + assert os.path.basename(save_file) == basename, basename + self.logger.info("Saving checkpoint to {}".format(save_file)) + with self.path_manager.open(save_file, "wb") as f: + torch.save(data, f) + + def load(self, *args, **kwargs): + with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + return super().load(*args, **kwargs) + + if __name__ == "__main__": main(sys.argv[1:]) diff --git a/src/training/params.py b/src/training/params.py index 953ed9a4d..384fe0907 100644 --- a/src/training/params.py +++ b/src/training/params.py @@ -372,6 +372,12 @@ def parse_args(args): action="store_true", help="Initialize the model on CPUs rather than GPUs, useful for large models", ) + parser.add_argument( + "--fsdp-use-distributed-checkpointer", + default=False, + action="store_true", + help="Use distributed checkpointer for FSDP, useful for large models", + ) parser.add_argument( "--fsdp-cpu-offload", default=False,
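FSDPCheckpointer saves under StateDictType.LOCAL_STATE_DICT, i.e. each rank serialises only its own flattened shards and no cross-rank gather is needed, which is what makes it usable for models too large to consolidate. Stripped of the fvcore plumbing, the core pattern is roughly the following sketch (illustrative helper names; save and load must use the same world size and wrapping):

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

    def save_local_shard(fsdp_model, path, rank):
        # each rank writes its own file, e.g. "epoch_3.rank_0.pt", "epoch_3.rank_1.pt", ...
        with FSDP.state_dict_type(fsdp_model, StateDictType.LOCAL_STATE_DICT):
            shard = {"model": fsdp_model.state_dict()}
        torch.save(shard, f"{path}.rank_{rank}.pt")

    def load_local_shard(fsdp_model, path, rank):
        shard = torch.load(f"{path}.rank_{rank}.pt", map_location="cpu")
        with FSDP.state_dict_type(fsdp_model, StateDictType.LOCAL_STATE_DICT):
            fsdp_model.load_state_dict(shard["model"])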