restore path for p-tuning (NVIDIA#6273)

* bug fix to seed prompt encoder with nemo model

Signed-off-by: arendu <adithya.r@gmail.com>

* update for continued training

Signed-off-by: arendu <adithya.r@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: arendu <adithya.r@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>
2 people authored and hsiehjackson committed Jun 2, 2023
1 parent f13f5b1 commit e8725da
Showing 3 changed files with 7 additions and 2 deletions.
@@ -33,7 +33,7 @@ exp_manager:
     monitor: val_loss
     save_top_k: 2
     mode: min
-    save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below,
+    save_nemo_on_train_end: True
     filename: 'megatron_gpt_prompt_tune--{val_loss:.3f}-{step}'
     model_parallel_size: ${model.tensor_model_parallel_size}
     save_best_model: True
@@ -44,6 +44,7 @@ exp_manager:
     min_delta: 0.001
     patience: 10
     verbose: True
+    strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.
 
 
 model:
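In NeMo, the early-stopping values in this exp_manager section are used to configure PyTorch Lightning's EarlyStopping callback, which is why the new strict: False line matters: with strict=True, Lightning raises a RuntimeError when the monitored val_loss is missing from a validation pass (as can happen on the first check of a resumed run), while strict=False only warns and skips the check. A rough standalone sketch of the equivalent callback, assuming the values shown above:

from pytorch_lightning.callbacks import EarlyStopping

# Sketch only: a standalone callback mirroring the config values above,
# not the exp_manager code itself. strict=False makes Lightning skip the
# early-stopping check (with a warning) when 'val_loss' is unavailable,
# instead of raising a RuntimeError.
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    min_delta=0.001,
    patience=10,
    verbose=True,
    strict=False,
)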
@@ -312,7 +312,8 @@ def setup(self, stage=None):
 
         if self.first_stage_of_pipeline():
             if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING:
-                self.init_prompt_encoder()
+                if self.prompt_encoder is None:
+                    self.init_prompt_encoder()
             self.freeze_existing_word_embeddings()
 
         self.setup_training_data()
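This guard addresses the restore / continued-training case described in the commit message: when a p-tuning model is restored from a saved .nemo checkpoint, the prompt encoder already exists with its trained weights, and calling init_prompt_encoder() unconditionally in setup() would re-seed it and throw those weights away. A minimal sketch of the pattern, using hypothetical names rather than the real NeMo classes:

import torch.nn as nn

class PromptLearningSketch:
    # Illustrative stand-in for the real model class.
    def __init__(self):
        self.prompt_encoder = None  # may already be populated by a checkpoint restore

    def init_prompt_encoder(self):
        # Freshly seeded encoder; re-running this after a restore would lose trained weights.
        self.prompt_encoder = nn.Linear(8, 8)

    def setup(self, stage=None):
        # Build a new encoder only if the restore did not already provide one.
        if self.prompt_encoder is None:
            self.init_prompt_encoder()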
@@ -448,6 +448,9 @@ def on_validation_epoch_start(self) -> None:
         return super().on_validation_epoch_start()
 
     def validation_epoch_end(self, outputs):
+        if len(outputs) == 0:
+            return
+
         if parallel_state.is_pipeline_last_stage():
             # only the last pipeline parallel stages return loss
             averaged_loss = torch.stack([i['loss'] for i in outputs]).mean()
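The early return added at the top of validation_epoch_end protects the reduction just below it: torch.stack() raises a RuntimeError when given an empty list, so a validation pass that produced no outputs would previously crash at this line. A small self-contained illustration of the guarded reduction, with a hypothetical helper name:

import torch

def averaged_validation_loss(outputs):
    # Mirrors the guarded reduction above: bail out when there is nothing to stack.
    if len(outputs) == 0:
        return None  # torch.stack([]) would raise a RuntimeError
    return torch.stack([o['loss'] for o in outputs]).mean()

print(averaged_validation_loss([]))  # None, no error
print(averaged_validation_loss([{'loss': torch.tensor(1.0)}, {'loss': torch.tensor(3.0)}]))  # tensor(2.)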