Issues with FineTuning Checkpoint #2

Open
Advaid-Deepak opened this issue Apr 18, 2024 · 2 comments

@Advaid-Deepak

We were trying to fine-tune a MatFormer checkpoint (MatFormer-OLMo-180M, link in the README).

We used the following command to call the training script:

python train.py ../configs/pile-tiny.yaml \
    --matformer_factor=8 \
    --model.d_model=512 \
    --model.n_heads=16 \
    --model.n_layers=8 \
    --model.max_sequence_length=2048 \
    --device_train_microbatch_size=8 \
    --global_train_batch_size=128 \
    --max_duration=75000  \
    --optimizer.learning_rate=1.0e-3 \
    --console_log_interval=10 \
    --load_path=:"/raid/ganesh/namitha/Skill_localization_experiment/ckpt_paths/MatFormer-OLMo-180M" \
    --run_name="matformer-olmo-180M-finetune"

where the folder given in load_path was obtained by downloading from the MatFormer-OLMo-180M link in the README.

However, running this gives us the following error:

[2024-04-18 09:09:04] CRITICAL [root, rank=0] Uncaught ValueError: Must flatten tensors on the same device but got both cuda:0 and meta
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:229 in <modul │
│                                                                                                  │
│   226 │   │   raise OlmoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]")                │
│   227 │   print([clean_opt(s) for s in args_list])                                               │
│   228 │   cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list])                   │
│ ❱ 229 │   main(cfg)                                                                              │
│   230                                                                                            │
│                                                                                                  │
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:108 in main   │
│                                                                                                  │
│   105 │   log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embeddin │
│   106 │   torch.distributed.init_process_group(backend='nccl',rank=0, world_size=1)              │
│   107 │   # Wrap the model in FSDP.                                                              │
│ ❱ 108 │   fsdp_model = FSDP(                                                                     │
│   109 │   │   olmo_model,                                                                        │
│   110 │   │   sharding_strategy=cfg.fsdp.sharding_strategy,                                      │
│   111 │   │   mixed_precision=MixedPrecision(  # equivalent to MosaicML's "PURE"                 │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    474 │   │   │   │   # process groups.                                                         │
│    475 │   │   │   │   root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)  │
│    476 │   │   │                                                                                 │
│ ❱  477 │   │   │   _auto_wrap(                                                                   │
│    478 │   │   │   │   module,                                                                   │
│    479 │   │   │   │   auto_wrap_policy,                                                         │
│    480 │   │   │   │   self._ignored_modules,                                                    │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    98 │   │   )                                                                                  │
│    99 │   │   recursive_wrap_kwargs["auto_wrap_policy"] = policy                                 │
│   100 │   │   _warn_on_overridden_mixed_precision(overridden_module_classes)                     │
│ ❱ 101 │   _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]      │
│   102                                                                                            │
│   103                                                                                            │
│   104 def _check_nested_wrapping(root_module: nn.Module):                                        │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   558 │   │   │   module=module, recurse=False, nonwrapped_numel=remainder                       │
│   559 │   │   ):                                                                                 │
│   560 │   │   │   # Leaf node or final wrapping of the remainder both happen here.               │
│ ❱ 561 │   │   │   return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel                  │
│   562 │   │   else:                                                                              │
│   563 │   │   │   return module, total_wrapped_numel                                             │
│   564 │   return module, 0                                                                       │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   487 │   │   overrides = {**kwargs, **module._wrap_overrides}  # type: ignore[arg-type]         │
│   488 │   │   return wrapper_cls(module, **overrides)                                            │
│   489 │                                                                                          │
│ ❱ 490 │   return wrapper_cls(module, **kwargs)                                                   │
│   491                                                                                            │
│   492                                                                                            │
│   493 def _recursive_wrap(                                                                       │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    500 │   │   _init_buffer_state(self, module)                                                  │
│    501 │   │   # extension needs to be set before `_init_param_handle_from_module()`             │
│    502 │   │   _init_extension(self, device_mesh)                                                │
│ ❱  503 │   │   _init_param_handle_from_module(                                                   │
│    504 │   │   │   self,                                                                         │
│    505 │   │   │   module,                                                                       │
│    506 │   │   │   device_id,                                                                    │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    587 │   │   │   _sync_module_params_and_buffers(                                              │
│    588 │   │   │   │   fully_sharded_module, managed_params, state._inter_node_pg                │
│    589 │   │   │   )                                                                             │
│ ❱  590 │   _init_param_handle_from_params(state, managed_params, fully_sharded_module)           │
│    591 │   return state                                                                          │
│    592                                                                                           │
│    593                                                                                           │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    599 ):                                                                                        │
│    600 │   if len(params) == 0:                                                                  │
│    601 │   │   return                                                                            │
│ ❱  602 │   handle = FlatParamHandle(                                                             │
│    603 │   │   params,                                                                           │
│    604 │   │   fully_sharded_module,                                                             │
│    605 │   │   state.compute_device,                                                             │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    570 │   │   │   else 0                                                                        │
│    571 │   │   )                                                                                 │
│    572 │   │   self._fsdp_extension = fsdp_extension                                             │
│ ❱  573 │   │   self._init_flat_param_and_metadata(                                               │
│    574 │   │   │   params, fully_sharded_module, self._aligned_numel, use_orig_params  # type: i │
│    575 │   │   )                                                                                 │
│    576 │   │   self._use_unsharded_views(as_params=False)                                        │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    620 │   │   │   dtype,                                                                        │
│    621 │   │   │   flat_param_requires_grad,                                                     │
│    622 │   │   │   device,                                                                       │
│ ❱  623 │   │   ) = self._validate_tensors_to_flatten(params)                                     │
│    624 │   │   params_set = set(params)                                                          │
│    625 │   │   # For alignment padding, only `numels` gets strictly non-`None`                   │
│    626 │   │   # elements, and all other lists get `None` elements for padding.                  │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    773 │   │   │   │   │   "`use_orig_params=False`"                                             │
│    774 │   │   │   │   )                                                                         │
│    775 │   │   │   if device is not None and tensor.device != device:                            │
│ ❱  776 │   │   │   │   raise ValueError(                                                         │
│    777 │   │   │   │   │   "Must flatten tensors on the same device but got both "               │
│    778 │   │   │   │   │   f"{device} and {tensor.device}"                                       │
│    779 │   │   │   │   )                                                                         │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ValueError: Must flatten tensors on the same device but got both cuda:0 and meta

We are unable to resolve this issue.

We tried adding the following line to torch/distributed/fsdp/_init_utils.py:

tensor.to("cuda:0")  

But this operation gives another error:

│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    753 │   │   device: Optional[torch.device] = None                                             │
│    754 │   │   # For `use_orig_params=True`, permit non-uniform `requires_grad`                  │
│    755 │   │   for tensor in tensors:                                                            │
│ ❱  756 │   │   │   tensor.to("cuda:0")                                                           │
│    757 │   │   │   if isinstance(tensor, FlatParameter):                                         │
│    758 │   │   │   │   raise ValueError("Cannot flatten a `FlatParameter`")                      │
│    759 │   │   │   if dtype is None and not tensor.is_floating_point():                          │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
NotImplementedError: Cannot copy out of meta tensor; no data!
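
We suspect the patch cannot work as written: Tensor.to returns a new tensor rather than modifying the original in place, and a meta tensor carries only shape and dtype metadata, so there is no data to copy out of it. A minimal illustration in plain PyTorch, outside the FSDP internals (assuming a CUDA device is available):

import torch

t = torch.empty(4, device="meta")             # shape/dtype only, no storage
print(t.is_meta)                              # True
# t.to("cuda:0")                              # would raise: cannot copy out of meta tensor
real = torch.empty_like(t, device="cuda:0")   # allocate fresh storage on the target device
real.normal_()                                # then fill it with initialized (or loaded) values

So any fix presumably has to materialize the parameters at the module level before FSDP tries to flatten them, rather than copy individual meta tensors.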

We have made other changes to pile-tiny.yaml, scripts/train.py, and scripts/util.py to make them compatible for training. I am attaching a zip of those files here:
changes.zip
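
Separately, one direction we have not verified (a sketch, assuming the model is partially built on the meta device, e.g. via an init_device: meta setting in the config, and that loading the full checkpoint before wrapping is acceptable) would be to materialize every parameter on a single device before handing the model to FSDP. Here olmo_model stands for the model train.py already builds, and checkpoint_state_dict is a hypothetical name for the weights loaded from load_path:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def materialize_on(model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    # to_empty() gives every parameter/buffer real (uninitialized) storage on `device`;
    # note it also replaces existing storage, so the full checkpoint must be loaded afterwards.
    if any(p.is_meta for p in model.parameters()):
        model = model.to_empty(device=device)
    return model.to(device)

device = torch.device("cuda:0")
olmo_model = materialize_on(olmo_model, device)
# olmo_model.load_state_dict(checkpoint_state_dict)   # fill in the real weights
fsdp_model = FSDP(olmo_model, device_id=device)       # all params now share one device

If the config exposes it, overriding the model's init_device (e.g. --model.init_device=cpu) so the model is never created on the meta device might also sidestep the mismatch, though this is untested here.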

Apart from this, we were facing another issue:

[2024-04-18 09:30:56] CRITICAL [root, rank=0] Uncaught AttributeError: 'LayerNorm' object has no attribute 'reset_parameters'
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:229 in <modul │
│                                                                                                  │
│   226 │   │   raise OlmoCliError(f"Usage: {sys.argv[0]} [CONFIG_PATH] [OPTIONS]")                │
│   227 │   print([clean_opt(s) for s in args_list])                                               │
│   228 │   cfg = TrainConfig.load(yaml_path, [clean_opt(s) for s in args_list])                   │
│ ❱ 229 │   main(cfg)                                                                              │
│   230                                                                                            │
│                                                                                                  │
│ /raid/ganesh/namitha/Skill_localization_experiment/MatFormer-OLMo/scripts/train.py:108 in main   │
│                                                                                                  │
│   105 │   log.info(f"Number of non-embedding parameters: {olmo_model.num_params(include_embeddin │
│   106 │   torch.distributed.init_process_group(backend='nccl',rank=0, world_size=1)              │
│   107 │   # Wrap the model in FSDP.                                                              │
│ ❱ 108 │   fsdp_model = FSDP(                                                                     │
│   109 │   │   olmo_model,                                                                        │
│   110 │   │   sharding_strategy=cfg.fsdp.sharding_strategy,                                      │
│   111 │   │   mixed_precision=MixedPrecision(  # equivalent to MosaicML's "PURE"                 │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    474 │   │   │   │   # process groups.                                                         │
│    475 │   │   │   │   root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)  │
│    476 │   │   │                                                                                 │
│ ❱  477 │   │   │   _auto_wrap(                                                                   │
│    478 │   │   │   │   module,                                                                   │
│    479 │   │   │   │   auto_wrap_policy,                                                         │
│    480 │   │   │   │   self._ignored_modules,                                                    │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    98 │   │   )                                                                                  │
│    99 │   │   recursive_wrap_kwargs["auto_wrap_policy"] = policy                                 │
│   100 │   │   _warn_on_overridden_mixed_precision(overridden_module_classes)                     │
│ ❱ 101 │   _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]      │
│   102                                                                                            │
│   103                                                                                            │
│   104 def _check_nested_wrapping(root_module: nn.Module):                                        │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   540 │   │   for name, child in module.named_children():                                        │
│   541 │   │   │   if child in ignored_modules:                                                   │
│   542 │   │   │   │   continue                                                                   │
│ ❱ 543 │   │   │   wrapped_child, num_wrapped_params = _recursive_wrap(                           │
│   544 │   │   │   │   module=child,                                                              │
│   545 │   │   │   │   auto_wrap_policy=auto_wrap_policy,                                         │
│   546 │   │   │   │   wrapper_cls=wrapper_cls,                                                   │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   558 │   │   │   module=module, recurse=False, nonwrapped_numel=remainder                       │
│   559 │   │   ):                                                                                 │
│   560 │   │   │   # Leaf node or final wrapping of the remainder both happen here.               │
│ ❱ 561 │   │   │   return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel                  │
│   562 │   │   else:                                                                              │
│   563 │   │   │   return module, total_wrapped_numel                                             │
│   564 │   return module, 0                                                                       │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│   487 │   │   overrides = {**kwargs, **module._wrap_overrides}  # type: ignore[arg-type]         │
│   488 │   │   return wrapper_cls(module, **overrides)                                            │
│   489 │                                                                                          │
│ ❱ 490 │   return wrapper_cls(module, **kwargs)                                                   │
│   491                                                                                            │
│   492                                                                                            │
│   493 def _recursive_wrap(                                                                       │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    500 │   │   _init_buffer_state(self, module)                                                  │
│    501 │   │   # extension needs to be set before `_init_param_handle_from_module()`             │
│    502 │   │   _init_extension(self, device_mesh)                                                │
│ ❱  503 │   │   _init_param_handle_from_module(                                                   │
│    504 │   │   │   self,                                                                         │
│    505 │   │   │   module,                                                                       │
│    506 │   │   │   device_id,                                                                    │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    549 │   │   │   fully_sharded_module, param_init_fn, state._ignored_modules                   │
│    550 │   │   )                                                                                 │
│    551 │   elif is_meta_module:                                                                  │
│ ❱  552 │   │   _materialize_meta_module(                                                         │
│    553 │   │   │   fully_sharded_module, device_id, state._ignored_modules                       │
│    554 │   │   )                                                                                 │
│    555 │   elif is_torchdistX_deferred_init:                                                     │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    881 │   │   │   f"device with error {str(e)}. Please ensure that your module of"              │
│    882 │   │   │   f"type {type(module)} implements a `reset_parameters()` method."              │
│    883 │   │   )                                                                                 │
│ ❱  884 │   │   raise e                                                                           │
│    885                                                                                           │
│    886                                                                                           │
│    887 def _get_modules_to_materialize(                                                          │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/distributed/fsd │
│                                                                                                  │
│    874 │   │   │   │   has_module_states = len(list(module_state_iter)) > 0                      │
│    875 │   │   │   │   if has_module_states:                                                     │
│    876 │   │   │   │   │   module.to_empty(device=materialization_device, recurse=False)         │
│ ❱  877 │   │   │   │   │   module.reset_parameters()  # type: ignore[operator]                   │
│    878 │   except BaseException as e:                                                            │
│    879 │   │   warnings.warn(                                                                    │
│    880 │   │   │   "Unable to call `reset_parameters()` for module on meta "                     │
│                                                                                                  │
│ /raid/ganesh/namitha/miniconda3/envs/icl_as_ft/lib/python3.9/site-packages/torch/nn/modules/modu │
│                                                                                                  │
│   1685 │   │   │   modules = self.__dict__['_modules']                                           │
│   1686 │   │   │   if name in modules:                                                           │
│   1687 │   │   │   │   return modules[name]                                                      │
│ ❱ 1688 │   │   raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") │
│   1689 │                                                                                         │
│   1690 │   def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:             │
│   1691 │   │   def remove_from(*dicts_or_sets):                                                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AttributeError: 'LayerNorm' object has no attribute 'reset_parameters'

However, we circumvented this issue by commenting out the raise (within torch/distributed/fsdp/_init_utils.py) as follows:

except BaseException as e:
        warnings.warn(
            "Unable to call `reset_parameters()` for module on meta "
            f"device with error {str(e)}. Please ensure that your module of"
            f"type {type(module)} implements a `reset_parameters()` method."
        )
        #raise e

I have attached the entire modified file within changes.zip, just in case.
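
A less invasive alternative to editing torch internals (a sketch, assuming the failing module is this repository's custom LayerNorm and that the standard ones/zeros affine init is acceptable as a placeholder, since the real weights come from the checkpoint afterwards) would be to give that class a reset_parameters method before the model is wrapped. The import path below is a guess at where the class lives:

import torch
from olmo.model import LayerNorm   # assumed location of the custom LayerNorm

def _reset_parameters(self: torch.nn.Module) -> None:
    # Mirror torch.nn.LayerNorm's default init for whichever affine
    # parameters this implementation actually defines.
    if getattr(self, "weight", None) is not None:
        torch.nn.init.ones_(self.weight)
    if getattr(self, "bias", None) is not None:
        torch.nn.init.zeros_(self.bias)

LayerNorm.reset_parameters = _reset_parameters   # monkeypatch before FSDP wrapping

FSDP's param_init_fn argument is another supported hook for materializing meta modules, if changing the LayerNorm class is not desirable.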

@prateekiiest

Hi @adityakusupati, this is Prateek Chanda from GRI. @Advaid-Deepak and I were experimenting with MatFormer-OLMo to try out a few ideas externally and were facing some issues with fine-tuning from the MatFormer checkpoint shown above.

We would really appreciate it if you could point out any steps we may have missed.

Thanks 😄

@adityakusupati
Member

Hi Prateek and Advaid,

Thanks for your interest. I am unsure what is happening here as well. The MatFormer-OLMo models are not competitive enough to run experiments on (barring scaling laws) and get meaningful results.

The only good MatFormer models publicly released are the MatViT models in Scenic, which are actually SOTA as regular ViT models and a drop-in replacement.

As of now I am unable to look at this closely and can only do so after the second week of May. The script and README are what I used to restart my training runs from a checkpoint when something failed, which implies fine-tuning should work similarly.

Sorry for not being of much help here.
Aditya
