check for first or last stage (#6708)
* check for first or last stage

Signed-off-by: ericharper <complex451@gmail.com>

* remove redundant check

Signed-off-by: ericharper <complex451@gmail.com>

* fix typo

Signed-off-by: ericharper <complex451@gmail.com>

* add map_location

Signed-off-by: ericharper <complex451@gmail.com>

---------

Signed-off-by: ericharper <complex451@gmail.com>
ericharper authored and web-flow committed May 26, 2023
1 parent b50ae98 commit 900ce90
Showing 2 changed files with 36 additions and 30 deletions.
1 change: 1 addition & 0 deletions examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -203,6 +203,7 @@ def main(cfg) -> None:
             trainer=trainer,
             override_config_path=pretrained_cfg,
             save_restore_connector=save_restore_connector,
+            map_location=f'cuda:{trainer.local_rank}',  # map_location is needed for converted models
         )
     elif cfg.checkpoint_dir:
         app_state = AppState()
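For context, map_location controls which device the checkpoint tensors are deserialized onto when the model is restored, rather than the device they were saved from. A minimal sketch of the same idea with plain torch.load (the file name and local_rank value here are illustrative, not from this commit):

    import torch

    # Illustrative only: map the saved tensors onto this process's local GPU
    # instead of the device recorded in the checkpoint.
    local_rank = 0  # in the script above this comes from trainer.local_rank
    state_dict = torch.load('converted_model.ckpt', map_location=f'cuda:{local_rank}')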
65 changes: 35 additions & 30 deletions nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -135,36 +135,41 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para


def get_computeprob_response(tokenizer, response, inputs):
-    compute_prob_response = {}
-    new_token_ids = []
-    new_tokens = []
-    new_texts = []
-    log_probs = []
-    full_logprobs = []
-    offsets = []
-    for batch_id in range(len(response['tokens'])):
-        if isinstance(inputs, (list, tuple)):
-            if isinstance(inputs[0], str):
-                new_token_id = tokenizer.text_to_ids(inputs[batch_id])
-                new_text = inputs[batch_id]
-                token_len = len(new_token_id)
-            elif isinstance(inputs[0], torch.Tensor):
-                token_len = int(inputs[1][batch_id].item())
-                new_token_id = inputs[0][batch_id][:token_len].tolist()
-                new_text = tokenizer.ids_to_text(new_token_id)
-        new_token_ids.append(new_token_id)
-        new_tokens.append(response['tokens'][batch_id][:token_len])
-        new_texts.append(new_text)
-        log_probs.append(response['logprob'][batch_id][:token_len])
-        full_logprobs.append(response['full_logprob'][batch_id][:token_len])
-        offsets.append(response['offsets'][batch_id][:-1])
-    compute_prob_response['sentences'] = new_texts
-    compute_prob_response['tokens'] = new_tokens
-    compute_prob_response['token_ids'] = new_token_ids
-    compute_prob_response['logprob'] = log_probs
-    compute_prob_response['full_logprob'] = full_logprobs
-    compute_prob_response['offsets'] = offsets
-    return compute_prob_response
+    if parallel_state.is_pipeline_first_stage() or parallel_state.is_pipeline_last_stage():
+        # we only have a response on the first and last pipeline stages
+        compute_prob_response = {}
+        new_token_ids = []
+        new_tokens = []
+        new_texts = []
+        log_probs = []
+        full_logprobs = []
+        offsets = []
+        for batch_id in range(len(response['tokens'])):
+            if isinstance(inputs, (list, tuple)):
+                if isinstance(inputs[0], str):
+                    new_token_id = tokenizer.text_to_ids(inputs[batch_id])
+                    new_text = inputs[batch_id]
+                    token_len = len(new_token_id)
+                elif isinstance(inputs[0], torch.Tensor):
+                    token_len = int(inputs[1][batch_id].item())
+                    new_token_id = inputs[0][batch_id][:token_len].tolist()
+                    new_text = tokenizer.ids_to_text(new_token_id)
+            new_token_ids.append(new_token_id)
+            new_tokens.append(response['tokens'][batch_id][:token_len])
+            new_texts.append(new_text)
+            log_probs.append(response['logprob'][batch_id][:token_len])
+            full_logprobs.append(response['full_logprob'][batch_id][:token_len])
+            offsets.append(response['offsets'][batch_id][:-1])
+        compute_prob_response['sentences'] = new_texts
+        compute_prob_response['tokens'] = new_tokens
+        compute_prob_response['token_ids'] = new_token_ids
+        compute_prob_response['logprob'] = log_probs
+        compute_prob_response['full_logprob'] = full_logprobs
+        compute_prob_response['offsets'] = offsets
+        return compute_prob_response
+    else:
+        # intermediate stages
+        return None


def get_batch(model, tokenizer, context_tokens):
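For context, get_computeprob_response now returns None on intermediate pipeline stages, so any caller that previously assumed a dict needs a guard. A hypothetical caller-side sketch (not part of this commit):

    result = get_computeprob_response(tokenizer, response, inputs)
    if result is not None:
        # only the first and last pipeline stages receive the full response
        for sentence, logprob in zip(result['sentences'], result['logprob']):
            print(sentence, logprob)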
