text_generation_utils memory reduction if no logprob needed #6773

Merged
merged 31 commits from gpt_predict_mp_mem_issue into main on Jun 6, 2023
Commits (31)
6d3aa7c
repro for gpt eval mp mem issue
yzhang123 May 26, 2023
1b077de
add print statements for memory allocation
yzhang123 May 30, 2023
30cd1e2
adjusted hot fix that prevents softmax on the entire output embedding…
yzhang123 May 30, 2023
94ff7a7
Merge remote-tracking branch 'origin/main' into gpt_predict_mp_mem_issue
yzhang123 May 30, 2023
fcb8379
Merge remote-tracking branch 'origin/main' into gpt_predict_mp_mem_issue
yzhang123 May 31, 2023
2b1e2b5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
2b31551
using compute_logprob to configure inference
yzhang123 Jun 1, 2023
a7f3c66
enable compute logprob for peft
yzhang123 Jun 1, 2023
abe634c
remove print statements
yzhang123 Jun 1, 2023
4529e52
Merge branch 'gpt_predict_mp_mem_issue' of github.com:yzhang123/NeMo …
yzhang123 Jun 1, 2023
269b59d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
11dccff
Merge remote-tracking branch 'origin/main' into gpt_predict_mp_mem_issue
yzhang123 Jun 1, 2023
bb6937e
Merge branch 'gpt_predict_mp_mem_issue' of github.com:yzhang123/NeMo …
yzhang123 Jun 1, 2023
342beb0
fix ci
yzhang123 Jun 1, 2023
0d46e07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
d4c5703
added docstrings
yzhang123 Jun 1, 2023
6689ade
Merge branch 'gpt_predict_mp_mem_issue' of github.com:yzhang123/NeMo …
yzhang123 Jun 1, 2023
5957cc3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
a87891f
Merge branch 'main' into gpt_predict_mp_mem_issue
yzhang123 Jun 1, 2023
4c93fc4
add missing config
yzhang123 Jun 1, 2023
7292468
Merge branch 'gpt_predict_mp_mem_issue' of github.com:yzhang123/NeMo …
yzhang123 Jun 1, 2023
c43106c
Merge branch 'main' into gpt_predict_mp_mem_issue
yzhang123 Jun 2, 2023
34fbabc
remove truncate prompt length feature
yzhang123 Jun 2, 2023
ea9c983
resolve merge conflict
yzhang123 Jun 2, 2023
213218c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2023
c50bce7
Merge branch 'main' into gpt_predict_mp_mem_issue
yzhang123 Jun 2, 2023
d1287bc
Merge branch 'main' of github.com:NVIDIA/NeMo into gpt_predict_mp_mem…
yzhang123 Jun 3, 2023
aacb2d3
tensor before all gather needs to be contiguous
yzhang123 Jun 3, 2023
b9c285a
Merge branch 'main' into gpt_predict_mp_mem_issue
yzhang123 Jun 3, 2023
9d914c8
Merge branch 'main' into gpt_predict_mp_mem_issue
ekmb Jun 5, 2023
5c9f87b
Merge branch 'main' into gpt_predict_mp_mem_issue
MaximumEntropy Jun 5, 2023
22 changes: 12 additions & 10 deletions examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py
@@ -155,7 +155,7 @@ def main(cfg) -> None:

if os.path.isdir(cfg.model.restore_from_path):
save_restore_connector.model_extracted_dir = cfg.model.restore_from_path
model = NLPModel.restore_from(
model = MegatronGPTSFTModel.restore_from(
restore_path=cfg.model.restore_from_path,
trainer=trainer,
override_config_path=peft_model_cfg,
@@ -180,15 +180,17 @@ def main(cfg) -> None:
for batch in response:
batch_sentences = [s for s in batch['sentences']]
batch_tokens = [s for s in batch['tokens']]
batch_logprob = [s.tolist() for s in batch['logprob']]
for s, t, l in zip(batch_sentences, batch_tokens, batch_logprob):
if cfg.inference.get("verbose", False):
d = {
'sentence': s,
'tokens_with_logprobs': ', '.join([f"{_t} {_l:.4f}" for _t, _l in zip(t, l)]),
}
f.write(json.dumps(d, sort_keys=True, indent=2) + '\n')
else:
if cfg.inference.compute_logprob:
batch_logprob = [s.tolist() for s in batch['logprob']]
for s, t, l in zip(batch_sentences, batch_tokens, batch_logprob):
if cfg.inference.get("verbose", False):
d = {
'sentence': s,
'tokens_with_logprobs': ', '.join([f"{_t} {_l:.4f}" for _t, _l in zip(t, l)]),
}
f.write(json.dumps(d, sort_keys=True, indent=2) + '\n')
else:
for s in batch_sentences:
d = {'sentence': s}
f.write(json.dumps(d) + '\n')
print("predictions saved to {}".format(cfg.inference.outfile_path))
@@ -1111,7 +1111,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
inference_config = inference_config.copy()
compute_logprob = inference_config['compute_logprob']
if compute_logprob:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
inference_config['tokens_to_generate'] = 1
inference_config['all_probs'] = True
@@ -1121,7 +1120,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
compute_prob_response = get_computeprob_response(self.tokenizer, response, batch)
return compute_prob_response
else:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
return generate(self, **inference_config)

@@ -35,6 +35,7 @@
LengthParam,
SamplingParam,
generate,
get_computeprob_response,
megatron_gpt_generate,
)
from nemo.utils import AppState, logging
@@ -539,7 +540,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
inference_config = inference_config.copy()
compute_logprob = inference_config['compute_logprob']
if compute_logprob:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
inference_config['tokens_to_generate'] = 1
inference_config['all_probs'] = True
@@ -549,8 +549,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
compute_prob_response = get_computeprob_response(self.tokenizer, response, batch)
return compute_prob_response
else:
del inference_config['compute_logprob']

# for megatron_gpt_eval.py
if isinstance(batch, list):
inference_config['inputs'] = batch
@@ -464,7 +464,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
inference_config = inference_config.copy()
compute_logprob = inference_config['compute_logprob']
if compute_logprob:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
inference_config['tokens_to_generate'] = 1
inference_config['all_probs'] = True
@@ -474,7 +473,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
compute_prob_response = get_computeprob_response(self.tokenizer, response, batch)
return compute_prob_response
else:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
return generate(self, **inference_config, strategy=self.inference_strategy)
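Aside (illustration, not part of the diff): the same predict_step pattern recurs in each model class touched above, with the last variant additionally forwarding an inference strategy. Because generate() now accepts compute_logprob directly, the flag no longer has to be deleted from the config before the call. A condensed sketch of the shared control flow follows; model.get_inference_config() and the extra generation keys set in the collapsed lines are assumptions based on the surrounding NeMo code, and the import path is the utils file modified further down.

from nemo.collections.nlp.modules.common.text_generation_utils import (
    generate,
    get_computeprob_response,
)

def predict_step(model, batch, batch_idx, dataloader_idx=None):
    inference_config = model.get_inference_config()  # assumed helper returning a dict or None
    if inference_config is None:
        return None
    inference_config = inference_config.copy()

    compute_logprob = inference_config["compute_logprob"]
    inference_config["inputs"] = batch
    if compute_logprob:
        # Score the given tokens: one generation step with full probabilities,
        # then convert the raw response into per-token log-probabilities.
        inference_config["tokens_to_generate"] = 1
        inference_config["all_probs"] = True
        response = generate(model, **inference_config)
        return get_computeprob_response(model.tokenizer, response, batch)
    # Plain generation: compute_logprob stays in the config (no more `del`) because
    # generate() now consumes it as a keyword argument.
    return generate(model, **inference_config)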

@@ -53,7 +53,6 @@ def __init__(self, model):

def forward_step(self, batch, tensor_shape):
fwd_bwd_function = get_forward_backward_func()

output_tensor = fwd_bwd_function(
forward_step_func=self.model.get_forward_output_only_func(),
data_iterator=iter([batch,]),
110 changes: 70 additions & 40 deletions nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -97,6 +97,7 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para
inputs=inputs,
tokens_to_generate=length_params['max_length'],
all_probs=sampling_params['all_probs'],
compute_logprob=sampling_params['compute_logprob'],
temperature=sampling_params['temperature'],
add_BOS=sampling_params['add_BOS'],
top_k=sampling_params['top_k'],
@@ -116,6 +117,7 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para
inputs=inputs,
tokens_to_generate=length_params['max_length'],
all_probs=sampling_params['all_probs'],
compute_logprob=sampling_params['compute_logprob'],
temperature=sampling_params['temperature'],
add_BOS=sampling_params['add_BOS'],
top_k=sampling_params['top_k'],
@@ -269,6 +271,7 @@ def send_generate_info(
context_length_tensor,
tokens_to_generate,
all_probs,
compute_logprob,
temperature,
top_k,
top_p,
@@ -288,6 +291,7 @@ def send_generate_info(
context_tokens_tensor.size(1), # seq_len
tokens_to_generate,
all_probs,
compute_logprob, # whether to compute log probabilities matrix
temperature,
top_k,
top_p,
@@ -317,18 +321,19 @@ def receive_generate_info():
"""
model_parallel_group = parallel_state.get_model_parallel_group()
src = get_model_parallel_src_rank()
input_info_tensor = torch.empty(10, dtype=torch.float32, device=torch.cuda.current_device())
input_info_tensor = torch.empty(11, dtype=torch.float32, device=torch.cuda.current_device())
Reviewer comment (Collaborator): Could you add a comment here? Why the change to 11?

Author reply (yzhang123): compute_logprob was added as a new entry to input_info_tensor, so its size needs to increase by one; a comment was also added explaining what compute_logprob does.

torch.distributed.broadcast(input_info_tensor, src, model_parallel_group)
batch_size = int(input_info_tensor[0].item())
seq_len = int(input_info_tensor[1].item())
tokens_to_generate = int(input_info_tensor[2].item())
all_probs = bool(input_info_tensor[3].item())
temperature = float(input_info_tensor[4].item())
top_k = int(input_info_tensor[5].item())
top_p = float(input_info_tensor[6].item())
greedy = bool(input_info_tensor[7].item())
repetition_penalty = float(input_info_tensor[8].item())
min_tokens_to_generate = int(input_info_tensor[9].item())
compute_logprob = bool(input_info_tensor[4].item()) # whether to compute log probabilities matrix
temperature = float(input_info_tensor[5].item())
top_k = int(input_info_tensor[6].item())
top_p = float(input_info_tensor[7].item())
greedy = bool(input_info_tensor[8].item())
repetition_penalty = float(input_info_tensor[9].item())
min_tokens_to_generate = int(input_info_tensor[10].item())

context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
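Aside (illustration, not part of the diff): send_generate_info packs the scalar generation settings into a single float tensor and broadcasts it from the model-parallel source rank; receive_generate_info unpacks it in the same order. Placing compute_logprob at index 4 is what grows the tensor from 10 to 11 entries, as discussed in the review exchange above. A toy, non-distributed sketch of that packing order:

import torch

def pack_generate_info(batch_size, seq_len, tokens_to_generate, all_probs, compute_logprob,
                       temperature, top_k, top_p, greedy, repetition_penalty,
                       min_tokens_to_generate):
    # Same ordering as send_generate_info; booleans are stored as floats for the broadcast.
    return torch.tensor(
        [batch_size, seq_len, tokens_to_generate, all_probs, compute_logprob,
         temperature, top_k, top_p, greedy, repetition_penalty, min_tokens_to_generate],
        dtype=torch.float32,
    )

def unpack_generate_info(t):
    assert t.numel() == 11  # was 10 before compute_logprob was added
    return {
        "batch_size": int(t[0]),
        "seq_len": int(t[1]),
        "tokens_to_generate": int(t[2]),
        "all_probs": bool(t[3]),
        "compute_logprob": bool(t[4]),  # whether to compute the log-probability matrix
        "temperature": float(t[5]),
        "top_k": int(t[6]),
        "top_p": float(t[7]),
        "greedy": bool(t[8]),
        "repetition_penalty": float(t[9]),
        "min_tokens_to_generate": int(t[10]),
    }

info = pack_generate_info(2, 128, 30, False, False, 1.0, 0, 0.9, True, 1.0, 0)
print(unpack_generate_info(info)["compute_logprob"])  # False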
@@ -349,6 +354,7 @@ def receive_generate_info():
context_tokens_tensor,
tokens_to_generate,
all_probs,
compute_logprob,
temperature,
top_k,
top_p,
@@ -370,6 +376,7 @@ def synced_generate(
top_k=0,
top_p=0.0,
greedy=False,
compute_logprob=False,
repetition_penalty=1.2,
min_tokens_to_generate=0,
end_strings=[],
@@ -394,6 +401,7 @@ def synced_generate(
context_length_tensor,
tokens_to_generate,
all_probs,
compute_logprob=compute_logprob,
temperature=temperature,
end_strings=end_strings,
extra={
@@ -411,7 +419,8 @@ def synced_generate(
if parallel_state.is_pipeline_last_stage():
src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_embedding_group()
torch.distributed.broadcast(output_logits, src, group)
if compute_logprob:
torch.distributed.broadcast(output_logits, src, group)
if all_probs:
src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_embedding_group()
@@ -422,15 +431,18 @@ def synced_generate(
src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_embedding_group()

precision = model._trainer.precision
if precision in [16, "16"]:
dtype = torch.float16
elif precision == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
output_logits = torch.empty(tokens.size(0), context_length - 1, dtype=dtype, device=torch.device("cuda"))
torch.distributed.broadcast(output_logits, src, group)
if compute_logprob:
precision = model._trainer.precision
if precision in [16, "16"]:
dtype = torch.float16
elif precision == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
output_logits = torch.empty(
tokens.size(0), context_length - 1, dtype=dtype, device=torch.device("cuda")
)
torch.distributed.broadcast(output_logits, src, group)

if all_probs:
src = parallel_state.get_pipeline_model_parallel_last_rank()
@@ -457,6 +469,7 @@ def generate(
top_k=0,
top_p=0.0,
greedy=False,
compute_logprob=False,
repetition_penalty=1.0,
min_tokens_to_generate=0,
end_strings=['<|endoftext|>'],
@@ -504,6 +517,7 @@ def generate(
context_length_tensor,
tokens_to_generate,
all_probs,
compute_logprob,
temperature,
top_k,
top_p,
@@ -518,6 +532,7 @@ def generate(
context_tokens_tensor,
tokens_to_generate,
all_probs,
compute_logprob,
temperature,
top_k,
top_p,
@@ -535,6 +550,7 @@ def generate(
tokens_to_generate,
all_probs,
temperature,
compute_logprob=compute_logprob,
top_k=top_k,
top_p=top_p,
greedy=greedy,
@@ -619,6 +635,7 @@ def sample_sequence_batch(
context_lengths,
tokens_to_generate,
all_probs=False,
compute_logprob=False,
type_ids=None,
temperature=None,
end_strings=['<|endoftext|>'],
@@ -673,11 +690,18 @@ def sample_sequence_batch(
output = inference_strategy.forward_step(batch, tensor_shape)

if parallel_state.is_pipeline_last_stage():
output = output[0]['logits']

output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
assert output is not None
logits = output[:, -1].view(batch_size, -1).contiguous()
if compute_logprob:
output = output[0]['logits']
output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
assert output is not None
logits = output[:, -1].view(batch_size, -1).contiguous()

else:
logits = output[0]['logits'][:, -1].contiguous()
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
assert logits is not None
logits = logits.view(batch_size, -1)
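Aside (illustration, not part of the diff): this branch is the heart of the memory reduction. With compute_logprob=False, only the logits for the current position are gathered across tensor-parallel ranks, so the all-gather and the subsequent sampling operate on a [batch, vocab] tensor instead of the full [batch, context_length, vocab] logits. A back-of-the-envelope comparison with illustrative sizes (not taken from the PR):

batch_size = 8
seq_len = 2048            # context plus tokens generated so far
vocab_size = 50_257
bytes_per_element = 2     # fp16 activations

full_logits = batch_size * seq_len * vocab_size * bytes_per_element      # compute_logprob=True path
last_token_logits = batch_size * vocab_size * bytes_per_element          # compute_logprob=False path

print(f"gathered full logits:       {full_logits / 2**30:.2f} GiB per decoding step")
print(f"gathered last-token logits: {last_token_logits / 2**20:.2f} MiB per decoding step")
# Both the all-gather across tensor-parallel ranks and the later log_softmax shrink
# by roughly a factor of seq_len when log-probabilities are not requested.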

# make sure it will generate at least min_length
min_length = extra.get('min_tokens_to_generate', 0)
Expand All @@ -689,6 +713,7 @@ def sample_sequence_batch(
logits[:, tokenizer.vocab_size :] = -float('Inf')

# started indicates whether the current token step passes the context_length, so we make sure not to overwrite the context tokens

started = context_lengths <= context_length
if extra.get('greedy', False):
prev = torch.argmax(logits, dim=-1).view(-1)
@@ -716,23 +741,25 @@ def sample_sequence_batch(
# Insert either new predicted or next prompt token
tokens[:, context_length] = new_tokens

if output_logits is None:
output = F.log_softmax(output[:, :context_length, :], 2)
indices = torch.unsqueeze(tokens[:, 1 : context_length + 1], 2)
output_logits = torch.gather(output, 2, indices).squeeze(2)
all_generated_indices = indices[:, :, 0]
if all_probs:
full_logits = output
else:
output = F.log_softmax(output, 2)
indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
new_output_logits = torch.gather(output, 2, indices).squeeze(2)
if compute_logprob:
if output_logits is None:
output = F.log_softmax(output[:, :context_length, :], 2)

# TODO(rprenger) we're copying output_logits every time. Should pre-allocate
output_logits = torch.cat([output_logits, new_output_logits], 1)
all_generated_indices = torch.cat([all_generated_indices, indices[:, :, 0]], 1)
if all_probs:
full_logits = torch.cat([full_logits, output], 1)
indices = torch.unsqueeze(tokens[:, 1 : context_length + 1], 2)
output_logits = torch.gather(output, 2, indices).squeeze(2)
all_generated_indices = indices[:, :, 0]
if all_probs:
full_logits = output
else:
output = F.log_softmax(output, 2)
indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
new_output_logits = torch.gather(output, 2, indices).squeeze(2)

# TODO(rprenger) we're copying output_logits every time. Should pre-allocate
output_logits = torch.cat([output_logits, new_output_logits], 1)
all_generated_indices = torch.cat([all_generated_indices, indices[:, :, 0]], 1)
if all_probs:
full_logits = torch.cat([full_logits, output], 1)

src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_embedding_group()
@@ -752,10 +779,13 @@ def sample_sequence_batch(
src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
if all_probs:
yield tokens, lengths, output_logits, full_logits
if compute_logprob:
if all_probs:
yield tokens, lengths, output_logits, full_logits
else:
yield tokens, lengths, output_logits, None
else:
yield tokens, lengths, output_logits, None
yield tokens, lengths, None, None

else:
if parallel_state.is_pipeline_first_stage():
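Aside (usage sketch, not part of the diff): with the new keyword in place, callers can opt out of log-probabilities explicitly; the generator then yields (tokens, lengths, None, None) and the full-vocabulary log_softmax is skipped. The import path is the file modified above; prompt format, model loading, and the exact shape of the returned response are assumptions here, and the remaining generate() parameters are assumed to keep their defaults.

from nemo.collections.nlp.modules.common.text_generation_utils import generate

def greedy_generate(model, prompts, want_logprobs=False):
    # With want_logprobs=False, far less memory is used during inference;
    # with True, per-token log-probabilities are computed as before.
    return generate(
        model,
        inputs=prompts,
        tokens_to_generate=32,
        all_probs=False,
        compute_logprob=want_logprobs,  # new flag introduced by this PR
        greedy=True,
    )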