Zero-shot evaluation pipeline for mcore RETRO #8941

Merged: 146 commits, Apr 23, 2024
Commits
e91a66d
update branch
ericharper Jan 29, 2024
305ad9c
Add dist ckpt support for regular optimizers (#7749)
mikolajblaz Jan 31, 2024
40da002
Pin lhotse=1.19.2 in r1.23.0 (#8303)
pzelasko Feb 1, 2024
d3bad4b
Cache Aware Streaming tutorial notebook (#8296)
erastorgueva-nv Feb 1, 2024
17f09e4
fix path location and branch (#8304)
nithinraok Feb 2, 2024
991dad9
add deallocate pipeline output optimization (#8279)
JimmyZhang12 Feb 2, 2024
e9320ed
Fix memory leak caused by context parallelism hanging references by o…
JimmyZhang12 Feb 2, 2024
8b18cfc
remove assertion (#8302)
dimapihtar Feb 2, 2024
d9f1409
Update PEFT Doc (#8262)
cuichenx Feb 3, 2024
a592517
Attention encoder-decoder models for multiple speech-to-text tasks …
titu1994 Feb 3, 2024
3ef5513
add code for calling mcore_retro in NeMo
huvunvidia Nov 27, 2023
6218b8c
add code for calling mcore_retro in NeMo
huvunvidia Nov 27, 2023
c5907ac
runnable, training curve match retro mcore and nemo
huvunvidia Dec 15, 2023
5f10619
working on retro inference
huvunvidia Jan 10, 2024
ecc061e
working on megatron_retro_eval.py and megatron_retro_inference.yaml
huvunvidia Jan 10, 2024
c1e99d3
refactoring text_generation_utils code and retro inference relevant f…
Jan 20, 2024
2bf5d2c
clean PR
Jan 26, 2024
db6ffe3
resolving quick hacks (reading number of train/valid samples from wor…
Jan 30, 2024
1d1021c
clean repository
Jan 31, 2024
9b3cd36
revert changes to inference/eval code to original in main
Jan 31, 2024
31834d5
clean code
Jan 31, 2024
43c4af9
runable training code, with already implemented eval code
Jan 31, 2024
45ea217
[tutorial] fixed missing RIR scripts file. (#8257)
XuesongYang Jan 29, 2024
186a369
add values to en tts dict (#7879)
mgrafu Jan 30, 2024
7aea8c9
Add Bert HF checkpoint converter (#8088)
yaoyu-33 Jan 31, 2024
11787e7
revert to original eval code files
Jan 31, 2024
a1faebf
revert to original eval code files 2
Jan 31, 2024
78070f5
revert to original eval code files 3
Jan 31, 2024
7f2a889
revert to original eval code files 4
Jan 31, 2024
c0e4ea2
clean code
Feb 1, 2024
5dbee42
clean code
Feb 1, 2024
a9fb106
update my code to support changes from lastest main
Feb 2, 2024
769605c
commit before rebase r1.23.0
Feb 6, 2024
c3c766e
Multimodal r1.23.0 bug fix (#8315)
yaoyu-33 Feb 6, 2024
53bac6e
copy paste files from r1.23.0
Feb 6, 2024
9830475
clean PR
Feb 6, 2024
1434979
Fixes for MoE parameter passing & use of AutoTokenizer/Model for mist…
akoumpa Feb 6, 2024
ec8f413
Keep max_seqlen and cu_seqlens_argmin for later micro-batches when PP…
erhoo82 Feb 6, 2024
50864db
Remove asr webapp (#8347)
titu1994 Feb 6, 2024
498e9e4
remove _target_ at model level in aed config (#8351)
krishnacpuvvada Feb 6, 2024
c4f38cd
revert changes for tts and asr
Feb 7, 2024
2f72846
Add change_vocabulary and save_tokenizers() support to Multitask ASR …
titu1994 Feb 7, 2024
931c53c
Change default (#8371)
titu1994 Feb 8, 2024
7c75022
implement retro's own fwd_bwd_step() and validation_step() to not hav…
Feb 8, 2024
40bb4a2
adding megatron compile_helpers(), in future can be fixed with correc…
Feb 9, 2024
0e13348
bug fix in fast-conformer-aed.yaml and adding jenkins test for speech…
krishnacpuvvada Feb 9, 2024
138a7ab
Enable megatron core loggers for GPT pretraining (#8354)
ashbhandare Feb 9, 2024
4ee9c58
mcore ds fix (#8283)
dimapihtar Feb 9, 2024
de96b6e
addressing Eric's reviews
Feb 9, 2024
0e806b9
adding existing implementation RETRO files
Feb 9, 2024
09d9ce2
adding existing implementation RETRO files
Feb 9, 2024
02ec761
Add Finetuning tutorial with HF Datasets (#8356)
nithinraok Feb 9, 2024
88d7b21
release updates (#8378)
dimapihtar Feb 9, 2024
400c4a1
MCore dataset compatibility for tokenizers (#8390)
vysarge Feb 11, 2024
3112091
Mcore customization doc (#8298)
HuiyingLi Feb 12, 2024
68eba36
wer fix (#8404)
tbartley94 Feb 12, 2024
5b8f18c
updated link to pubmed (#8402)
nithinraok Feb 13, 2024
0f7b49b
Update NFA video download link (#8406)
erastorgueva-nv Feb 13, 2024
f897a77
revert changes (#8410)
cuichenx Feb 13, 2024
371de5b
Fix dreambooth data sampler issue (#8400)
yaoyu-33 Feb 13, 2024
98186c2
Fixed errors in the CTM gen functions (#8416)
tango4j Feb 14, 2024
8689bc0
add ensemble decoding fix (#8427)
nithinraok Feb 15, 2024
770f73b
SDE bugfix log (#8430)
Jorjeous Feb 15, 2024
05122bd
mcore customization doc minor fix (#8421)
HuiyingLi Feb 16, 2024
2e77f20
NeMo-Mistral to HF converter bugfix. (#8353)
akoumpa Feb 16, 2024
9588494
Fixing mcore bert for TP, PP and SP (#8336)
shanmugamr1992 Feb 16, 2024
71ce00c
Add settings to suppress bf16 compile errors in CI on V100 (#8481)
athitten Feb 22, 2024
c98b9c1
MoE parameter passing (#8255)
akoumpa Feb 23, 2024
a836fce
Update k2 version (#8478) (#8492)
artbataev Feb 23, 2024
0dc8a19
Add fp8 support for SD/Update notebook paths (#8489)
Victor49152 Feb 25, 2024
1d80d00
pin to 0.5.0 (#8465)
ericharper Feb 26, 2024
fcf1044
Update NeMo Multimodal Requirements (#8515)
yaoyu-33 Feb 26, 2024
d2283e3
update github raw content link (#8517)
cuichenx Feb 26, 2024
e6b7354
Add dep notice for notebooks (#8522)
ericharper Feb 27, 2024
ae9a2aa
Revert FP8 integration (#8520)
Victor49152 Feb 27, 2024
e772dbf
Update data prep notebook (#8532)
Victor49152 Feb 27, 2024
21984a1
before update branch with latest r1.23.0
Mar 4, 2024
52ee601
Merge remote-tracking branch 'origin/r1.23.0' into huvu/mcore_retro
Mar 4, 2024
b5d8aec
update to run with MLM ae2817b3dde4efb1515061a5311d01d8f85bd99c (runn…
Mar 5, 2024
4199bc7
remove compile_helpers
Mar 5, 2024
0ff5673
reverse changes from main branch to r1.23.0
Mar 5, 2024
74994c2
adding *_legacy files
Mar 5, 2024
061b632
update MLM commit in Jenkinsfile to latest
Mar 6, 2024
41b0178
debugging Jenkinstest: test different mcore import in retro_dataset
Mar 6, 2024
f9c3293
update Jenkinsfile edit megatron_retro_mutransfer_pretrain_legacy.py
Mar 7, 2024
88ef4d4
removing all mcore RETRO to pass the Jenkinstest
Mar 7, 2024
251762b
fixing import legacy problem for tests/collections/nlp/test_indexed_r…
Mar 7, 2024
8a6452d
update Jenkinsfile file to use TE v0.7
Mar 7, 2024
12263c0
update NeMo to work with latest mcore RETRO (solving TE problems)
Mar 20, 2024
188dd43
update TE commit Jenkinsfile to be the same with r1.23.0's Jenkinsfile
Mar 20, 2024
26c44a2
update commit for MLM
Mar 20, 2024
1c4c9a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2024
9068890
jenkinstest debugging
Mar 20, 2024
adff1f8
temporary fix RETRO's __init__ for jenkinstest
Mar 21, 2024
fec1852
edit splits_string in jenkinsfile to correct format; put RETRO test i…
Mar 21, 2024
ab4c6c0
edit splits_string in jenkinsfile to correct format; put RETRO test i…
Mar 21, 2024
5cedf92
edit splits_string in jenkinsfile to correct format; put RETRO test i…
Mar 21, 2024
fadd10b
edit splits_string in jenkinsfile to correct format; put RETRO test i…
Mar 21, 2024
08d2d73
add model.data.dataloader_type=cyclic to jenkinsfile
Mar 22, 2024
ff2c904
runnable for inference
Mar 31, 2024
92ade73
update code to work with latest megatron-lm main 81dab6067
Apr 4, 2024
4f547fb
update M-LM commit in Jenkinsfile to latest main M-LM 81dab6067
Apr 4, 2024
96ed408
cleaning inference code
Apr 4, 2024
e8a83ed
fix to by pass CI test bf16 problem (following this PR https://github…
Apr 8, 2024
b33d8a0
isort and black
Apr 8, 2024
2402e66
adjusting model.micro_batch_size to 1
Apr 8, 2024
d050eb9
fix conflicts
Apr 9, 2024
7b13e95
fix BRANCH = 'r1.23.0'
Apr 9, 2024
35dc730
replace tutorials dir from main branch to huvu/mcore_retro
Apr 9, 2024
003b4b3
fix minor merges conflict
Apr 9, 2024
964f366
update Jenkinsfile
Apr 9, 2024
0787a20
runnable with a temporary fix from Jacek (unfound -unfinished problem)
Apr 10, 2024
39b2d76
runnable with a temporary fix from Jacek (unfound -unfinished problem)
Apr 10, 2024
17beaf8
merged from main on 10apr
Apr 10, 2024
19bfae0
modified nlp_overrides.py back to original
Apr 10, 2024
1ea089d
fix checkpoint from Jacek Bieniusiewicz
Apr 10, 2024
ac3e2b9
config Jenkinsfile test
Apr 10, 2024
a522627
set RETRO Jenkins MBS to 1
Apr 11, 2024
72ff280
black fix
Apr 11, 2024
4fbb7b8
isort fix
Apr 11, 2024
62f6d7e
update TE commit
Apr 11, 2024
41de539
update to latest Jenkinsfile with latest container and commits
Apr 11, 2024
32ab269
remove new RETRO jenkinstest
Apr 11, 2024
04844cd
Merge remote-tracking branch 'origin/main' into huvu/mcore_retro
Apr 11, 2024
637448c
merge latest main
Apr 11, 2024
6022d66
put RETRO Jenkinstest to the right place
Apr 11, 2024
01ead1e
merge from latest origin/huvu/nemo_retro
Apr 11, 2024
ce7e9e0
update code for megatron_retro_pretraining_legacy.py
Apr 11, 2024
685bb1e
update Jenkins and _legacy.py
Apr 12, 2024
e5c935d
update new RETRO jenkinstest to run faster
Apr 12, 2024
2849a07
fixing errors from GitHub Advanced Security / CodeQL
Apr 12, 2024
87ecf60
fixing errors from GitHub Advanced Security / CodeQL
Apr 12, 2024
0e5bfce
update manually branch to huvu/mcore_retro
Apr 12, 2024
493d66e
remove DEBUGGING markers
Apr 12, 2024
296dc95
merging and solve conflicts
Apr 16, 2024
9455d6d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
002c49c
copy paste scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt
Apr 16, 2024
554a06d
update codes to fix Github warnings; adding cicd-main.yml action tests
Apr 16, 2024
a767066
cleaning code, addressing Shanmugam's comments
Apr 19, 2024
b69b1d6
saving before pulling from main
Apr 19, 2024
cdf3dc2
pulled from main
Apr 19, 2024
16a2cf3
cleaning code
Apr 19, 2024
896e591
adding deprecations note
Apr 19, 2024
6262e2a
Merge remote-tracking branch 'origin/main' into huvu/mcore_retro_eval
Apr 22, 2024
119c6b9
Merge remote-tracking branch 'origin/main' into huvu/mcore_retro_eval
Apr 23, 2024
07a7a73
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 23, 2024
44 changes: 21 additions & 23 deletions examples/nlp/language_modeling/conf/megatron_retro_inference.yaml
@@ -3,42 +3,40 @@ inference:
top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature: 1.0 # sampling temperature
add_BOS: True # add the bos token at the beginning of the prompt
add_BOS: False # add the bos token at the beginning of the prompt
tokens_to_generate: 30 # The maximum number of tokens to generate.
all_probs: False # whether to return the log probs for all tokens in the vocab
repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
compute_logprob: False # whether to compute the log prob of all the input text; a special case of running inference, default False

end_strings: ["<|endoftext|>"] # generation will stop when one of these tokens is generated
# RETRO-specific arguments
retro_inference:
retro_gpt_retrieved_length: 128
retro_num_neighbors: 2
ft_neighbours: 0
reuse_top: False

trainer:
devices: 1
num_nodes: 1
accelerator: gpu
logger: False # logger provided by exp_manager
precision: 16 # 16, 32, or bf16

inference_batch_size: 2
precision: 32 # 16, 32, or bf16
use_distributed_sampler: False
tensor_model_parallel_size: -1
pipeline_model_parallel_size: -1
pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others)
retro_model_file: null # RETRO nemo file path
megatron_amp_O2: False # Enable O2-level automatic mixed precision to save memory

use_predict_method: False # whether to use the predict method
retro_model_file: null # Retro nemo file path
checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the Retro training
checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading
hparams_file: null # model configuration file, only used for PTL checkpoint loading

prompts: # prompts for RETRO model inference
- "hello,"
- "good morning,"
- "good afternoon,"
- "good evening,"

########### Faiss service parameters ########
retrieval_service:
strategy: RetroModelTextGenerationStrategy # choose customized inference strategy
neighbors: 4
frequent_query: False # for the current token generation, frequently update the retrieval context. If False, update it every 64 tokens
pad_tokens: True # pad the tokens at the beginning so the prompt has at least 64 tokens, ensuring retrieval happens at least once
store_retrieved: False # whether to store the retrieved documents so they can be inspected
combo_service:
service_ip: '0.0.0.0'
service_port: 17181
# RETRO inference
prompt: "sample prompt"
neighbors:
- "neighbor text 1"
- "neighbor text 2"
@@ -0,0 +1,46 @@
# (This inference config for native NeMo RETRO will soon be deprecated. For the new mcore RETRO inference config, see ./megatron_retro_inference.yaml)

inference:
greedy: False # Whether or not to use sampling ; use greedy decoding otherwise
top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature: 1.0 # sampling temperature
add_BOS: True # add the bos token at the beginning of the prompt
tokens_to_generate: 30 # The maximum number of tokens to generate.
all_probs: False # whether to return the log probs for all tokens in the vocab
repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
compute_logprob: False # whether to compute the log prob of all the input text; a special case of running inference, default False


trainer:
devices: 1
num_nodes: 1
accelerator: gpu
logger: False # logger provided by exp_manager
precision: 16 # 16, 32, or bf16

inference_batch_size: 2
tensor_model_parallel_size: -1
pipeline_model_parallel_size: -1
pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others)
retro_model_file: null # RETRO nemo file path

use_predict_method: False # whether to use the predict method

prompts: # prompts for RETRO model inference
- "hello,"
- "good morning,"
- "good afternoon,"
- "good evening,"

########### Faiss service parameters ########
retrieval_service:
strategy: RetroModelTextGenerationStrategy # choose customized inference strategy
neighbors: 4
frequent_query: False # for the current token generation, frequently update the retrieval context. If False, update it every 64 tokens
pad_tokens: True # pad the tokens at the beginning so the prompt has at least 64 tokens, ensuring retrieval happens at least once
store_retrieved: False # whether to store the retrieved documents so they can be inspected
combo_service:
service_ip: '0.0.0.0'
service_port: 17181
40 changes: 40 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_retro_qatask.yaml
@@ -0,0 +1,40 @@
inference:
greedy: False # Whether or not to use sampling ; use greedy decoding otherwise
top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
temperature: 1.0 # sampling temperature
add_BOS: False # add the bos token at the beginning of the prompt
tokens_to_generate: 30 # The maximum number of tokens to generate.
all_probs: False # whether to return the log probs for all tokens in the vocab
repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
compute_logprob: False # whether to compute the log prob of all the input text; a special case of running inference, default False
end_strings: ["<|endoftext|>"] # generation will stop when one of these tokens is generated
# RETRO-specific arguments
retro_inference:
retro_gpt_retrieved_length: 128
retro_num_neighbors: 2
ft_neighbours: 0
reuse_top: False

trainer:
devices: 1
num_nodes: 1
accelerator: gpu
logger: False # logger provided by exp_manager
precision: 32 # 16, 32, or bf16
use_distributed_sampler: False

tensor_model_parallel_size: -1
pipeline_model_parallel_size: -1
pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others)
megatron_amp_O2: False # Enable O2-level automatic mixed precision to save memory

retro_model_file: null # Retro nemo file path
checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the Retro training
checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading
hparams_file: null # model configuration file, only used for PTL checkpoint loading

# qa tasks
qa_file_path: null
pred_file_path: null
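
The QA-task config reuses the same inference and trainer blocks and adds qa_file_path / pred_file_path; the exact file format is defined by the companion QA evaluation script in this PR, which is not shown in this excerpt. Purely as an illustration, assuming a simple JSONL layout with one question per line, a harness around such paths might look like the sketch below; the field names and format are assumptions, not the PR's actual contract.

```python
# Illustrative sketch only: the real QA file format is defined elsewhere in the PR.
# Assumes a JSONL input with a "question" field; predictions are written one per line.
import json


def run_qa_eval(qa_file_path: str, pred_file_path: str, generate_fn) -> None:
    """Read questions, generate an answer for each, and write predictions."""
    with open(qa_file_path, "r", encoding="utf-8") as fin, \
         open(pred_file_path, "w", encoding="utf-8") as fout:
        for line in fin:
            sample = json.loads(line)  # assumed shape: {"question": "..."}
            answer = generate_fn(sample["question"])
            fout.write(json.dumps({"question": sample["question"], "answer": answer}) + "\n")


# Example usage with a stand-in generator:
# run_qa_eval("qa_test.jsonl", "preds.jsonl", lambda q: "placeholder answer")
```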
183 changes: 87 additions & 96 deletions examples/nlp/language_modeling/megatron_retro_eval.py
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,128 +12,119 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os

from examples.nlp.language_modeling.megatron_gpt_eval import RequestDataSet
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
import torch
from omegaconf import OmegaConf
from pytorch_lightning.trainer.trainer import Trainer
from torch.utils.data import DataLoader, Dataset

from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel
from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector
from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy
from nemo.core.config import hydra_runner

try:
from megatron.core import parallel_state

HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):

HAVE_MEGATRON_CORE = False
from nemo.utils.app_state import AppState
from nemo.utils.model_utils import inject_model_parallel_rank

"""
This is the script to run RETRO Model text generation.
This is the script to run Retro text generation.

Usage:
Assume the model has TP=1, PP=1
run greedy inference from a nemo file:
Currently, Mcore-based RETRO only supports a batch size of 1.
Example running greedy inference from a distributed checkpoint dir:
python megatron_retro_eval.py \
checkpoint_dir=PATH_TO_CHECKPOINT \
checkpoint_name=CHECKPOINT_NAME \
inference.greedy=True \
inference.add_BOS=False \
trainer.devices=1 \
trainer.num_nodes=1 \
trainer.accelerator=gpu \
trainer.precision=16 \
inference.tokens_to_generate=128 \
inference.greedy=True \
retro_model_file=path_to_retro_nemo_file \
tensor_model_parallel_size=-1 \
pipeline_model_parallel_size=-1 \
retrieval_service.faiss_devices='0' \
retrieval_service.faiss_index=path_to_faiss_index \
retrieval_service.retrieval_index=path_to_retrieval_dataset \
retrieval_service.neighbors=20
"""
prompt="sample prompt" \
inference.retro_inference.retro_num_neighbors=2 \
neighbors=["neighbor text 1", "neighbor text 2"]


@hydra_runner(config_path="conf", config_name="megatron_retro_inference")
def main(cfg) -> None:
trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer)
```
"""

model_path = cfg.retro_model_file
if not torch.cuda.is_available():
raise EnvironmentError("GPU is needed for the inference")

save_restore_connector = NLPSaveRestoreConnector()

if os.path.isdir(model_path):
save_restore_connector.model_extracted_dir = model_path
class RequestDataSet(Dataset):
def __init__(self, sentences, neighbors):
super().__init__()
self.sentences = sentences
self.neighbors = neighbors

model_cfg = MegatronRetrievalModel.restore_from(
model_path, trainer=trainer, return_config=True, save_restore_connector=save_restore_connector,
)
def __len__(self,):
return len(self.sentences)

with open_dict(model_cfg):
model_cfg.precision = trainer.precision
model_cfg.sequence_parallel = False
model_cfg.activations_checkpoint_granularity = None
model_cfg.activations_checkpoint_method = None

if (
cfg.tensor_model_parallel_size < 0
or cfg.pipeline_model_parallel_size < 0
or cfg.get('pipeline_model_parallel_split_rank', -1) < 0
):
with open_dict(cfg):
cfg.tensor_model_parallel_size = model_cfg.get('tensor_model_parallel_size', 1)
cfg.pipeline_model_parallel_size = model_cfg.get('pipeline_model_parallel_size', 1)
cfg.pipeline_model_parallel_split_rank = model_cfg.get('pipeline_model_parallel_split_rank', 0)

model = MegatronRetrievalModel.restore_from(
model_path, trainer=trainer, save_restore_connector=save_restore_connector, override_config_path=model_cfg,
)
def __getitem__(self, idx):
return {'prompts': self.sentences[idx], 'neighbors': self.neighbors[idx]}

length_params: LengthParam = {
"max_length": cfg.inference.tokens_to_generate,
"min_length": cfg.inference.min_tokens_to_generate,
}

sampling_params: SamplingParam = {
"use_greedy": cfg.inference.greedy,
"temperature": cfg.inference.temperature,
"top_k": cfg.inference.top_k,
"top_p": cfg.inference.top_p,
"repetition_penalty": cfg.inference.repetition_penalty,
"add_BOS": cfg.inference.add_BOS,
"all_probs": cfg.inference.all_probs,
"compute_logprob": cfg.inference.compute_logprob,
}
@hydra_runner(config_path="conf", config_name="megatron_retro_inference")
def main(cfg) -> None:

# trainer required for restoring model parallel models
trainer = Trainer(
strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
**cfg.trainer,
callbacks=[CustomProgressBar()],
)

# check whether the DDP is initialized
if not parallel_state.is_initialized():
if cfg.checkpoint_dir:
app_state = AppState()
if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1:
app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size
app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size
app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size
(
app_state.tensor_model_parallel_rank,
app_state.pipeline_model_parallel_rank,
app_state.model_parallel_size,
app_state.data_parallel_size,
app_state.pipeline_model_parallel_split_rank,
app_state.virtual_pipeline_model_parallel_rank,
) = fake_initialize_model_parallel(
world_size=app_state.model_parallel_size,
rank=trainer.global_rank,
tensor_model_parallel_size_=cfg.tensor_model_parallel_size,
pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size,
pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank,
)
checkpoint_path = os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name)
# checkpoint_path is a dir in case of distributed checkpointing
if not os.path.isdir(checkpoint_path):
# legacy checkpoint needs model parallel rank injection
checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name))
model = MegatronRetroModel.load_from_checkpoint(
checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer
)
else:
raise ValueError("Requiring distributed checkpoint dir for loading Mcore RETRO.")

def dummy():
return
model.freeze()

if model.trainer.strategy.launcher is not None:
model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer)
model.trainer.strategy.setup_environment()
# Have to turn off activations_checkpoint_method for inference
try:
model.model.language_model.encoder.activations_checkpoint_method = None
except AttributeError:
[Code scanning / CodeQL notice, "Empty except": 'except' clause does nothing but pass and there is no explanatory comment.]
pass

prompt = [cfg.prompt]
neighbors = [cfg.neighbors]
ds = RequestDataSet(prompt, neighbors)
bs = 1
request_dl = DataLoader(dataset=ds, batch_size=bs)
config = OmegaConf.to_container(cfg.inference)
retrieval_service = OmegaConf.to_container(cfg.retrieval_service)
model.set_inference_config(config, retrieval_service)

if not cfg.use_predict_method:
# First method of running text generation, call model.generate method
response = model.generate(
inputs=OmegaConf.to_container(cfg.prompts),
length_params=length_params,
sampling_params=sampling_params,
strategy=model.inference_strategy,
)
else:
# Second method of running text generation, call trainer.predict
ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
request_dl = DataLoader(dataset=ds, batch_size=cfg.inference_batch_size)
response = trainer.predict(model, request_dl)
model.set_inference_config(config)

response = trainer.predict(model, request_dl)

print("***************************")
print(response)
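
For context, here is a small sketch of how the request pipeline in this script could be driven with several prompts, each paired with its own neighbor list. It assumes `model`, `trainer`, and `RequestDataSet` were already built exactly as above, keeps the batch size of 1 that mcore RETRO currently supports, and uses placeholder prompts and neighbors.

```python
# Sketch only: assumes `model`, `trainer`, and RequestDataSet from the script above.
from torch.utils.data import DataLoader

prompts = ["sample prompt 1", "sample prompt 2"]        # placeholder prompts
neighbors = [
    ["neighbor text 1a", "neighbor text 1b"],           # neighbors for prompt 1
    ["neighbor text 2a", "neighbor text 2b"],           # neighbors for prompt 2
]

ds = RequestDataSet(prompts, neighbors)
request_dl = DataLoader(dataset=ds, batch_size=1)       # mcore RETRO: batch size 1
responses = trainer.predict(model, request_dl)
print(responses)
```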