Fix errors caused by different dimensionality in get_dist function #7490

Closed
wants to merge 10 commits
344 changes: 172 additions & 172 deletions Jenkinsfile

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions examples/nlp/glue_benchmark/glue_benchmark.py
@@ -46,6 +46,10 @@

@hydra_runner(config_name="glue_benchmark_config")
def main(cfg: DictConfig) -> None:
# PTL 2.0 has find_unused_parameters set to False by default, so it needs to be set to True
# when there are unused parameters, as is the case here
if cfg.trainer.strategy == 'ddp':
cfg.trainer.strategy = "ddp_find_unused_parameters_true"
logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')
trainer = pl.Trainer(**cfg.trainer)
exp_manager_cfg = cfg.get("exp_manager", None)
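For context, a minimal sketch of what the override above does; the trainer config values here are hypothetical, only the strategy swap mirrors the diff. PyTorch Lightning 2.0 defaults find_unused_parameters to False, so the plain "ddp" strategy string is replaced with the variant that re-enables the unused-parameter search.

from omegaconf import OmegaConf

# Hypothetical trainer config; only the strategy handling mirrors the change above.
cfg = OmegaConf.create({"trainer": {"strategy": "ddp", "devices": 2, "accelerator": "gpu"}})
if cfg.trainer.strategy == "ddp":
    cfg.trainer.strategy = "ddp_find_unused_parameters_true"  # re-enable unused-parameter detection under DDP
print(OmegaConf.to_yaml(cfg))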
278 changes: 127 additions & 151 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py

Large diffs are not rendered by default.

65 changes: 57 additions & 8 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -740,7 +740,8 @@ def _load_state_dict_from_disk(self, model_weights, map_location=None):
peft_state_dict = torch.load(model_weights_path, map_location)['state_dict']
else:
peft_state_dict = {}
base_model_state_dict.update(peft_state_dict) # add the PEFT state_dict into the base model's state_dict
if base_model_state_dict:
base_model_state_dict.update(peft_state_dict) # add the PEFT state_dict into the base model's state_dict
return base_model_state_dict
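The new guard matters because this method can presumably be reached without a base state dict when the base weights come from a distributed checkpoint and are loaded elsewhere (see the restore_from change below). A toy sketch with placeholder values rather than real tensors:

# base_model_state_dict may be None/empty on the distributed-checkpointing path;
# calling .update() on None would raise AttributeError, so merge only when it exists.
base_model_state_dict = None                       # hypothetical: no base weights loaded here
peft_state_dict = {"adapter.linear_in.weight": 0}  # placeholder entry standing in for a tensor
if base_model_state_dict:
    base_model_state_dict.update(peft_state_dict)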

def restore_from(
@@ -765,13 +766,61 @@ def restore_from(
return loaded_params
conf, instance, state_dict = loaded_params

if (
self.peft_model_nemo_path is None and self.peft_model_ckpt_dir is None
): # we have this check only for training PEFT from scratch
peft_state_dict = instance.get_peft_state_dict()
state_dict.update(peft_state_dict)
state_dict = self.modify_state_dict(conf, state_dict)
self.load_instance_with_state_dict(instance, state_dict, strict)
# if we're using dist checkpointing then state_dict will be None
if state_dict is None:
# dist checkpointing needs torch.distributed to load the checkpoint
if parallel_state.is_unitialized():

def dummy():
return

if trainer.strategy.launcher is not None:
trainer.strategy.launcher.launch(dummy, trainer=trainer)
trainer.strategy.setup_environment()

with tempfile.TemporaryDirectory() as tmpdir:
# Check if self.model_extracted_dir is set, and is a valid path
if self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir):
# Log that NeMo will use the provided `model_extracted_dir`
logging.info(
f"Restoration will occur within pre-extracted directory : " f"`{self.model_extracted_dir}`."
)

# Override `tmpdir` above with the pre-extracted `model_extracted_dir`
tmpdir = self.model_extracted_dir

else:
# Extract the nemo file into the temporary directory
self._unpack_nemo_file(
path2file=restore_path, out_folder=tmpdir, extract_config_only=return_config is True
)
checkpoint = {}
sharded_state_dict = instance.sharded_state_dict()
peft_state_dict = instance.get_peft_state_dict()
for k in peft_state_dict.keys():
sharded_state_dict.pop(k)
checkpoint['state_dict'] = sharded_state_dict
# remove model weights extension
tmp_model_weights_ckpt = os.path.join(tmpdir, self.model_weights_ckpt)
tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0]
assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.'
checkpoint = dist_checkpointing.load(
sharded_state_dict=checkpoint, checkpoint_dir=tmp_model_weights_dir
)
checkpoint['state_dict'].update(peft_state_dict)
instance.on_load_checkpoint(checkpoint)
if hasattr(instance, 'setup_transformer_engine_tp_groups'):
instance.setup_transformer_engine_tp_groups()

else:
if (
self.peft_model_nemo_path is None and self.peft_model_ckpt_dir is None
): # we have this check only for training PEFT from scratch
peft_state_dict = instance.get_peft_state_dict()
state_dict.update(peft_state_dict)
state_dict = self.modify_state_dict(conf, state_dict)
self.load_instance_with_state_dict(instance, state_dict, strict)

logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.')
return instance

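The ordering in the new distributed-checkpointing branch is the part worth calling out: the PEFT (adapter) keys are removed from the sharded state dict before dist_checkpointing.load, since they are not part of the sharded base checkpoint on disk, and are merged back in afterwards. A toy illustration with plain dicts and placeholder values, no real checkpoint I/O:

# Stand-ins for instance.sharded_state_dict() and instance.get_peft_state_dict().
sharded_state_dict = {"layer.weight": "sharded", "adapter.weight": "local"}
peft_state_dict = {"adapter.weight": "local"}

# Adapter keys do not exist in the on-disk sharded checkpoint, so drop them first.
for k in peft_state_dict:
    sharded_state_dict.pop(k)

loaded = {"state_dict": dict(sharded_state_dict)}  # stands in for dist_checkpointing.load(...)
loaded["state_dict"].update(peft_state_dict)       # merge adapter weights back before on_load_checkpoint()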
2 changes: 1 addition & 1 deletion nemo/collections/tts/modules/aligner.py
@@ -98,7 +98,7 @@ def get_dist(self, keys, queries, mask=None):

self._apply_mask(dist, mask, float("inf"))

return dist
return dist.squeeze(1)

@staticmethod
def get_euclidean_dist(queries_enc, keys_enc):
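For reference, a minimal sketch of the dimensionality fix in get_dist above. The shapes are hypothetical and assume the distance tensor carries a singleton axis (e.g. from a keepdim=True reduction); the point is only that callers now receive a [batch, keys, queries] tensor.

import torch

dist = torch.rand(4, 1, 37, 120)  # hypothetical [batch, 1, n_keys, n_queries] distance tensor
print(dist.shape)                 # torch.Size([4, 1, 37, 120])
print(dist.squeeze(1).shape)      # torch.Size([4, 37, 120])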
2 changes: 1 addition & 1 deletion nemo/package_info.py
@@ -16,7 +16,7 @@
MAJOR = 1
MINOR = 21
PATCH = 0
PRE_RELEASE = 'rc0'
PRE_RELEASE = ''

# Use the following formatting: (major, minor, patch, pre-release)
VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
12 changes: 10 additions & 2 deletions nemo/utils/exp_manager.py
@@ -170,6 +170,8 @@ class ExpManagerConfig:
ema: Optional[EMAParams] = EMAParams()
# Wall clock time limit
max_time_per_run: Optional[str] = None
# time in seconds to sleep non-zero ranks during initialization
seconds_to_sleep: float = 5


class TimingCallback(Callback):
@@ -301,6 +303,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
Set this to True if you are using DDP with many GPUs and do not want many log files in your exp dir.
- max_time (str): The maximum wall clock time *per run*. This is intended to be used on clusters where you want
a checkpoint to be saved after this specified time and be able to resume from that checkpoint. Defaults to None.
- seconds_to_sleep (float): number of seconds to sleep non-rank-0 processes for. Used to give rank 0 enough time to initialize.

returns:
log_dir (Path): The final logging directory where logging files are saved. Usually the concatenation of
@@ -501,6 +504,11 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
# Add lightning file logging to global_rank zero
add_filehandlers_to_pl_logger(log_dir / 'lightning_logs.txt', log_dir / 'nemo_error_log.txt')

elif trainer.num_nodes * trainer.num_devices > 1:
# sleep the other ranks so rank 0 can finish
# initialization work such as moving files
time.sleep(cfg.seconds_to_sleep)

return log_dir
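A minimal sketch of what the new branch above is guarding against; the rank lookup here is hypothetical (in exp_manager it comes from the trainer), but the idea is simply that non-zero ranks pause while rank 0 finishes its file-system setup.

import os
import time

global_rank = int(os.environ.get("RANK", "0"))  # hypothetical source of the rank
seconds_to_sleep = 5.0
if global_rank != 0:
    time.sleep(seconds_to_sleep)  # give rank 0 time to create/move the log directories first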


@@ -578,8 +586,8 @@ def check_resume(
end_dist_checkpoints = [d for d in dist_checkpoints if d.match("*end")]
last_dist_checkpoints = [d for d in dist_checkpoints if d.match("*last")]

end_checkpoints = end_dist_checkpoints if end_dist_checkpoints else list(checkpoint_dir.glob("*end.ckpt"))
last_checkpoints = last_dist_checkpoints if last_dist_checkpoints else list(checkpoint_dir.glob("*last.ckpt"))
end_checkpoints = end_dist_checkpoints if end_dist_checkpoints else list(checkpoint_dir.rglob("*end.ckpt"))
last_checkpoints = last_dist_checkpoints if last_dist_checkpoints else list(checkpoint_dir.rglob("*last.ckpt"))

if not checkpoint_dir.exists() or (not len(end_checkpoints) > 0 and not len(last_checkpoints) > 0):
if resume_ignore_no_checkpoint:
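The glob to rglob switch in check_resume above matters when *end.ckpt or *last.ckpt entries sit below the top level of the checkpoint directory, which is presumably the case for distributed (directory-style) checkpoints. A small sketch with a hypothetical layout:

from pathlib import Path

checkpoint_dir = Path("results/checkpoints")             # hypothetical layout
top_level_only = list(checkpoint_dir.glob("*end.ckpt"))  # old behaviour: matches top level only
recursive = list(checkpoint_dir.rglob("*end.ckpt"))      # new behaviour: matches at any depth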