Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unwrap model #9480

Merged
merged 5 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 31 additions & 12 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3060,13 +3060,13 @@ jobs:
AFTER_SCRIPT: |
rm -rf /home/TestData/nlp/megatron_ir/working_dir

L2_Megatron_GPT_PEFT_Lora_PP2:
L2_Megatron_GPT_PEFT_Lora_PP2_O2:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
with:
RUNNER: self-hosted-azure
SCRIPT: |
rm -rf examples/nlp/language_modeling/gpt_peft_lora_results_pp2
rm -rf /home/TestData/nlp/lora_tuning_pp2

python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \
trainer.devices=2 \
Expand All @@ -3075,11 +3075,12 @@ jobs:
trainer.max_steps=3 \
trainer.val_check_interval=3 \
++trainer.limit_val_batches=2 \
trainer.precision=16 \
exp_manager.exp_dir=examples/nlp/language_modeling/gpt_peft_lora_results_pp2 \
trainer.precision=bf16 \
exp_manager.exp_dir=/home/TestData/nlp/lora_tuning_pp2 \
model.pipeline_model_parallel_size=2 \
model.tensor_model_parallel_size=1 \
model.restore_from_path=/home/TestData/nlp/megatron_gpt/PP2/gpt_pp2_tp1.nemo \
model.restore_from_path=/home/TestData/nlp/megatron_gpt/mcore_45M/megatron_llama.nemo \
model.megatron_amp_O2=True \
model.peft.peft_scheme=lora \
model.answer_only_loss=True \
model.micro_batch_size=1 \
Expand All @@ -3090,10 +3091,28 @@ jobs:
model.data.validation_ds.num_workers=0 \
model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \
model.data.validation_ds.names=[quarel]

python examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \
model.restore_from_path=/home/TestData/nlp/megatron_gpt/mcore_45M/megatron_llama.nemo \
model.peft.restore_from_path=/home/TestData/nlp/lora_tuning_pp2/megatron_gpt_peft_lora_tuning/checkpoints/megatron_gpt_peft_lora_tuning.nemo \
model.pipeline_model_parallel_size=2 \
model.tensor_model_parallel_size=1 \
trainer.devices=2 \
model.megatron_amp_O2=True \
model.data.test_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel_4.jsonl] \
model.data.test_ds.names=['quarel4'] \
model.global_batch_size=2 \
model.micro_batch_size=1 \
model.data.test_ds.tokens_to_generate=10 \
model.data.test_ds.write_predictions_to_file=True \
model.data.test_ds.output_file_path_prefix='/home/TestData/nlp/lora_tuning_pp2/out' \
inference.greedy=True \
inference.repetition_penalty=1.0 \
inference.outfile_path='/home/TestData/nlp/lora_tuning_pp2/out.jsonl'
AFTER_SCRIPT: |
rm -rf examples/nlp/language_modeling/gpt_peft_lora_results_pp2
rm -rf /home/TestData/nlp/lora_tuning_pp2

L2_Megatron_GPT_PEFT_Lora_TP2:
L2_Megatron_GPT_PEFT_Lora_TP2_O1:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
with:
Expand All @@ -3108,11 +3127,11 @@ jobs:
trainer.max_steps=3 \
trainer.val_check_interval=3 \
++trainer.limit_val_batches=2 \
trainer.precision=16 \
trainer.precision=bf16 \
exp_manager.exp_dir=/home/TestData/nlp/lora_tuning_tp2 \
model.pipeline_model_parallel_size=1 \
model.tensor_model_parallel_size=2 \
model.restore_from_path=/home/TestData/nlp/megatron_gpt/TP2/megatron_gpt_tp2.nemo \
model.restore_from_path=/home/TestData/nlp/megatron_gpt/mcore_45M/megatron_llama.nemo \
model.peft.peft_scheme='lora' \
model.answer_only_loss=True \
model.micro_batch_size=1 \
Expand All @@ -3125,7 +3144,7 @@ jobs:
model.data.validation_ds.names=[quarel]

python examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \
model.restore_from_path=/home/TestData/nlp/megatron_gpt/TP2/megatron_gpt_tp2.nemo \
model.restore_from_path=/home/TestData/nlp/megatron_gpt/mcore_45M/megatron_llama.nemo \
model.peft.restore_from_path=/home/TestData/nlp/lora_tuning_tp2/megatron_gpt_peft_lora_tuning/checkpoints/megatron_gpt_peft_lora_tuning.nemo \
model.tensor_model_parallel_size=2 \
trainer.devices=2 \
Expand Down Expand Up @@ -4234,8 +4253,8 @@ jobs:
- L2_Megatron_GPT_Finetuning_PP2
- L2_Megatron_GPT_Finetuning_StarCoder_PP1
- L2_Megatron_GPT_Embedding
- L2_Megatron_GPT_PEFT_Lora_PP2
- L2_Megatron_GPT_PEFT_Lora_TP2
- L2_Megatron_GPT_PEFT_Lora_PP2_O2
- L2_Megatron_GPT_PEFT_Lora_TP2_O1
- L2_Megatron_GPT_Eval
- L2_Megatron_GPT_Eval_PP2
- L2_Megatron_GPT_SFT_Eval_inference_seq_len_greaterThan_training_seq_len
Expand Down
14 changes: 7 additions & 7 deletions nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ def _get_all_keys(
"""
Returns all the keys in the model
"""
k = [n for n, p in self._unwrap_model().named_parameters()]
k = [n for n, p in self._unwrap_model().named_parameters(prefix="model")]
b = [
n
for n, p in self._unwrap_model().named_buffers()
if n.replace("model.module.", "model.", 1) in self._unwrap_model().state_dict().keys()
for n, p in self._unwrap_model().named_buffers(prefix="model")
if n.replace("model.module.", "model.", 1) in self._unwrap_model().state_dict(prefix="model.").keys()
]
# we include buffers because ptuning representations are cached in a buffer and saved to state_dict for inference time use.
return set(k + b)
Expand Down Expand Up @@ -292,13 +292,13 @@ def setup_optimizer_param_groups(self):
self.freeze(training=True) # Freeze the entire model
if not self.ptuning_only_and_non_first_stage:
opt_params = []
for _, module in self._unwrap_model().named_modules():
for _, module in self._unwrap_model().named_modules(prefix="model"):
if isinstance(module, AdapterModuleMixin) and module.is_adapter_available():
module.set_enabled_adapters(enabled=True)
module.unfreeze_enabled_adapters() # selectively unfreeze the adapter modules.
opt_params += [p for p in module.parameters() if p.requires_grad]

for name, param in self._unwrap_model().named_parameters():
for name, param in self._unwrap_model().named_parameters(prefix="model"):
if name in self.tunable_base_param_keys:
param.requires_grad = True
opt_params += [param]
Expand Down Expand Up @@ -397,11 +397,11 @@ def get_peft_state_dict(self):
"""
Gets the keys associated with the adapters only.
"""
state_dict = self._unwrap_model().state_dict()
state_dict = self._unwrap_model().state_dict(prefix="model.")
peft_state_dict = {}
for k in self.adapter_keys.union(self.tunable_base_param_keys):
# state_dict keys needs to be in non-O2 format and will be corrected in PEFTSaveRestoreConnector if O2=True
new_k = k.replace("module.", "", 1)
new_k = k.replace("model.module.", "model.", 1)
peft_state_dict[new_k] = state_dict[new_k]
return peft_state_dict

Expand Down
Loading