text_generation_utils memory reduction if no logprob needed #6773

Merged
merged 31 commits into from
Jun 6, 2023
Changes from 22 commits
Commits
6d3aa7c
repro for gpt eval mp mem issue
yzhang123 May 26, 2023
1b077de
add print statements for memory allocation
yzhang123 May 30, 2023
30cd1e2
adjusted hot fix that prevents softmax on the entire output embedding…
yzhang123 May 30, 2023
94ff7a7
Merge remote-tracking branch 'origin/main' into gpt_predict_mp_mem_issue
yzhang123 May 30, 2023
fcb8379
Merge remote-tracking branch 'origin/main' into gpt_predict_mp_mem_issue
yzhang123 May 31, 2023
2b1e2b5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
2b31551
using compute_logprob to configure inference
yzhang123 Jun 1, 2023
a7f3c66
enable compute logprob for peft
yzhang123 Jun 1, 2023
abe634c
remove print statements
yzhang123 Jun 1, 2023
4529e52
Merge branch 'gpt_predict_mp_mem_issue' of github.com:yzhang123/NeMo …
yzhang123 Jun 1, 2023
269b59d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
11dccff
Merge remote-tracking branch 'origin/main' into gpt_predict_mp_mem_issue
yzhang123 Jun 1, 2023
bb6937e
Merge branch 'gpt_predict_mp_mem_issue' of github.com:yzhang123/NeMo …
yzhang123 Jun 1, 2023
342beb0
fix ci
yzhang123 Jun 1, 2023
0d46e07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
d4c5703
added docstrings
yzhang123 Jun 1, 2023
6689ade
Merge branch 'gpt_predict_mp_mem_issue' of github.com:yzhang123/NeMo …
yzhang123 Jun 1, 2023
5957cc3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 1, 2023
a87891f
Merge branch 'main' into gpt_predict_mp_mem_issue
yzhang123 Jun 1, 2023
4c93fc4
add missing config
yzhang123 Jun 1, 2023
7292468
Merge branch 'gpt_predict_mp_mem_issue' of github.com:yzhang123/NeMo …
yzhang123 Jun 1, 2023
c43106c
Merge branch 'main' into gpt_predict_mp_mem_issue
yzhang123 Jun 2, 2023
34fbabc
remove truncate prompt length feature
yzhang123 Jun 2, 2023
ea9c983
resolve merge conflict
yzhang123 Jun 2, 2023
213218c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 2, 2023
c50bce7
Merge branch 'main' into gpt_predict_mp_mem_issue
yzhang123 Jun 2, 2023
d1287bc
Merge branch 'main' of github.com:NVIDIA/NeMo into gpt_predict_mp_mem…
yzhang123 Jun 3, 2023
aacb2d3
tensor before all gather needs to be contiguous
yzhang123 Jun 3, 2023
b9c285a
Merge branch 'main' into gpt_predict_mp_mem_issue
yzhang123 Jun 3, 2023
9d914c8
Merge branch 'main' into gpt_predict_mp_mem_issue
ekmb Jun 5, 2023
5c9f87b
Merge branch 'main' into gpt_predict_mp_mem_issue
MaximumEntropy Jun 5, 2023
@@ -9,7 +9,7 @@ inference:
repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False

truncate_prompt_length: -1 # if not -1 truncate prompt to this length

trainer:
devices: 1
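
A note on why the `compute_logprob` flag matters for memory: a log-softmax over the vocabulary produces a vocabulary-sized result per position, whereas greedy token selection only needs the argmax of the raw logits. The sketch below uses plain Python lists rather than tensors, and `log_softmax` and `next_token` are hypothetical helpers written for illustration, not NeMo functions. It only shows that the normalization can be skipped when no log-probability is requested.

```python
import math
from typing import List, Optional, Tuple


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary-sized row."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def next_token(logits: List[float], compute_logprob: bool) -> Tuple[int, Optional[float]]:
    """Pick the greedy token; normalize only when a log-probability is requested."""
    # log-softmax is monotonic, so the argmax of the raw logits is the same token.
    token = max(range(len(logits)), key=logits.__getitem__)
    if not compute_logprob:
        return token, None
    return token, log_softmax(logits)[token]
```

The sketch covers greedy selection only; sampling strategies that need probabilities are outside its scope.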
5 changes: 4 additions & 1 deletion examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -238,9 +238,11 @@ def main(cfg) -> None:
except AttributeError:
pass

assert cfg.inference.truncate_prompt_length == -1 or cfg.inference.truncate_prompt_length >= 0
yzhang123 marked this conversation as resolved.
length_params: LengthParam = {
"max_length": cfg.inference.tokens_to_generate,
"min_length": cfg.inference.min_tokens_to_generate,
"truncate_prompt_length": cfg.inference.truncate_prompt_length,
}

sampling_params: SamplingParam = {
@@ -265,9 +267,10 @@ def main(cfg) -> None:

# Second method of running text generation, call trainer.predict
ds = RequestDataSet(OmegaConf.to_container(cfg.prompts))
request_dl = DataLoader(dataset=ds, batch_size=2)
request_dl = DataLoader(dataset=ds, batch_size=1)
Collaborator
Is this change necessary?

Contributor Author
Looks good overall. Why do we need the truncate_prompt_length argument?

It's for when someone wants to truncate their prompt, like truncating the context with p-tuning. But this is not a deal breaker, so it's your call if you want to remove it.

Contributor Author
OK, I can revert the batch size.

config = OmegaConf.to_container(cfg.inference)
model.set_inference_config(config)

response = trainer.predict(model, request_dl)

print("***************************")
22 changes: 12 additions & 10 deletions examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py
@@ -155,7 +155,7 @@ def main(cfg) -> None:

if os.path.isdir(cfg.model.restore_from_path):
save_restore_connector.model_extracted_dir = cfg.model.restore_from_path
model = NLPModel.restore_from(
model = MegatronGPTSFTModel.restore_from(
restore_path=cfg.model.restore_from_path,
trainer=trainer,
override_config_path=peft_model_cfg,
@@ -180,15 +180,17 @@ def main(cfg) -> None:
for batch in response:
batch_sentences = [s for s in batch['sentences']]
batch_tokens = [s for s in batch['tokens']]
batch_logprob = [s.tolist() for s in batch['logprob']]
for s, t, l in zip(batch_sentences, batch_tokens, batch_logprob):
if cfg.inference.get("verbose", False):
d = {
'sentence': s,
'tokens_with_logprobs': ', '.join([f"{_t} {_l:.4f}" for _t, _l in zip(t, l)]),
}
f.write(json.dumps(d, sort_keys=True, indent=2) + '\n')
else:
if cfg.inference.compute_logprob:
batch_logprob = [s.tolist() for s in batch['logprob']]
for s, t, l in zip(batch_sentences, batch_tokens, batch_logprob):
if cfg.inference.get("verbose", False):
d = {
'sentence': s,
'tokens_with_logprobs': ', '.join([f"{_t} {_l:.4f}" for _t, _l in zip(t, l)]),
}
f.write(json.dumps(d, sort_keys=True, indent=2) + '\n')
else:
for s in batch_sentences:
d = {'sentence': s}
f.write(json.dumps(d) + '\n')
print("predictions saved to {}".format(cfg.inference.outfile_path))
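
The branching in this hunk can be sketched as a standalone function. This is a simplified variant under stated assumptions, not the script's code: `write_predictions` is a hypothetical name, the batches are plain dicts with list values (so no `.tolist()` call is needed), and every record is written on a single line in both branches.

```python
import io
import json
from typing import Any, Dict, Iterable, TextIO


def write_predictions(response: Iterable[Dict[str, Any]], f: TextIO,
                      compute_logprob: bool, verbose: bool = False) -> None:
    """Write one JSON object per sentence; touch 'logprob' only when it was computed."""
    for batch in response:
        sentences = list(batch['sentences'])
        if compute_logprob:
            # 'tokens' and 'logprob' are only read on this path.
            for s, t, l in zip(sentences, batch['tokens'], batch['logprob']):
                d = {'sentence': s}
                if verbose:
                    d['tokens_with_logprobs'] = ', '.join(
                        f"{_t} {_l:.4f}" for _t, _l in zip(t, l)
                    )
                f.write(json.dumps(d, sort_keys=True) + '\n')
        else:
            for s in sentences:
                f.write(json.dumps({'sentence': s}) + '\n')


# Usage with an in-memory file and a batch that carries no 'logprob' key.
buf = io.StringIO()
write_predictions([{'sentences': ['hello world']}], buf, compute_logprob=False)
print(buf.getvalue(), end='')
```

Keeping the `batch['logprob']` lookup inside the `compute_logprob` branch means a response without that key does not raise a `KeyError`.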
@@ -1109,7 +1109,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
inference_config = inference_config.copy()
compute_logprob = inference_config['compute_logprob']
if compute_logprob:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
inference_config['tokens_to_generate'] = 1
inference_config['all_probs'] = True
@@ -1119,7 +1118,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
compute_prob_response = get_computeprob_response(self.tokenizer, response, batch)
return compute_prob_response
else:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
return generate(self, **inference_config)
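
With the `del inference_config['compute_logprob']` lines gone, the flag appears to travel into `generate` along with the other keyword arguments. A minimal sketch of that pass-through follows. The `generate` here is a hypothetical stand-in that only echoes its arguments, and the real `predict_step` does more on the `compute_logprob` path than is shown.

```python
from typing import Any, Dict


def generate(inputs: Any, tokens_to_generate: int = 30,
             compute_logprob: bool = False, **extra: Any) -> Dict[str, Any]:
    """Hypothetical stand-in for the real generate(); it only echoes its arguments."""
    return {
        'inputs': inputs,
        'tokens_to_generate': tokens_to_generate,
        'compute_logprob': compute_logprob,
        'extra': extra,
    }


def predict_step(batch: Any, inference_config: Dict[str, Any]) -> Dict[str, Any]:
    # Shallow-copy so per-batch edits never leak into the shared config.
    cfg = inference_config.copy()
    cfg['inputs'] = batch
    if cfg['compute_logprob']:
        # Scoring the input only: a single step, keeping all probabilities.
        cfg['tokens_to_generate'] = 1
        cfg['all_probs'] = True
    # The flag is not deleted, so it is forwarded as a keyword argument.
    return generate(**cfg)
```

Because the config is unpacked with `**`, the callee has to accept every key it carries, which is why `generate` takes `compute_logprob` explicitly in this sketch.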

@@ -206,6 +206,7 @@ def init_model(self, cfg: DictConfig, trainer: Trainer):
self.length_params: LengthParam = {
"max_length": self.cfg.inference.get('tokens_to_generate', 30),
"min_length": self.cfg.inference.get('min_tokens_to_generate', 0),
"truncate_prompt_length": self.cfg.inference.get('truncate_prompt_length', -1),
}

self.sampling_params: SamplingParam = {
@@ -742,6 +743,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
length_params: LengthParam = {
"max_length": inference_config["tokens_to_generate"],
"min_length": inference_config["min_tokens_to_generate"],
"truncate_prompt_length": inference_config.get("truncate_prompt_length", -1),
}

sampling_params: SamplingParam = {
@@ -35,6 +35,7 @@
LengthParam,
SamplingParam,
generate,
get_computeprob_response,
megatron_gpt_generate,
)
from nemo.utils import AppState, logging
@@ -539,7 +540,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
inference_config = inference_config.copy()
compute_logprob = inference_config['compute_logprob']
if compute_logprob:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
inference_config['tokens_to_generate'] = 1
inference_config['all_probs'] = True
@@ -549,7 +549,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
compute_prob_response = get_computeprob_response(self.tokenizer, response, batch)
return compute_prob_response
else:
del inference_config['compute_logprob']
inference_config['inputs'] = (batch['contexts'].cuda(), batch['context_lengths'].cuda())
return generate(self, **inference_config)

@@ -464,7 +464,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
inference_config = inference_config.copy()
compute_logprob = inference_config['compute_logprob']
if compute_logprob:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
inference_config['tokens_to_generate'] = 1
inference_config['all_probs'] = True
@@ -474,7 +473,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
compute_prob_response = get_computeprob_response(self.tokenizer, response, batch)
return compute_prob_response
else:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
return generate(self, **inference_config, strategy=self.inference_strategy)

14 changes: 10 additions & 4 deletions nemo/collections/nlp/modules/common/text_generation_strategy.py
@@ -53,7 +53,6 @@ def __init__(self, model):

def forward_step(self, batch, tensor_shape):
fwd_bwd_function = get_forward_backward_func()

output_tensor = fwd_bwd_function(
forward_step_func=self.model.get_forward_output_only_func(),
data_iterator=iter([batch,]),
@@ -67,13 +66,14 @@ def forward_step(self, batch, tensor_shape):

return output_tensor

def tokenize_batch(self, sentences, max_len, add_BOS):
def tokenize_batch(self, sentences, max_len, add_BOS, truncate_prompt_length=-1):
"""
convert the sentences into lists of tokens, pad them to the same length, add bos tokens if it is needed
Args:
sentences (List[str]): list of input sentences in str format.
max_len (int): max number of tokens to generate.
add_BOS (bool): whether to add the BOS token at the beginning
truncate_prompt_length (int): if not -1 truncates sentences to this length
Returns:
Tuple[torch.Tensor], the tokenized and padded torch tensor and the token context length tensor.
"""
@@ -82,6 +82,11 @@ def tokenize_batch(self, sentences, max_len, add_BOS):
context_tokens = [[tokenizer.bos_id] + tokenizer.text_to_ids(s) for s in sentences]
else:
context_tokens = [tokenizer.text_to_ids(s) for s in sentences]
if truncate_prompt_length != -1:
res = []
for s in context_tokens:
res.append(s[:truncate_prompt_length])
context_tokens = res
context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eos_id, max_len)
context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
context_length_tensor = torch.cuda.LongTensor(context_lengths)
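
The truncate-then-pad step in this hunk can be tried in isolation with plain lists. The sketch below is an illustration under assumptions: `truncate_and_pad` is a hypothetical name, and the padding rule (longest prompt plus `max_len`) is a guess at what `pad_batch` does, which may differ in the real helper.

```python
from typing import List, Tuple


def truncate_and_pad(context_tokens: List[List[int]], pad_id: int, max_len: int,
                     truncate_prompt_length: int = -1) -> Tuple[List[List[int]], List[int]]:
    """Keep the first truncate_prompt_length tokens of each prompt, then pad."""
    if truncate_prompt_length != -1:
        # Slicing keeps the head of each prompt, as s[:truncate_prompt_length] does in the diff.
        context_tokens = [s[:truncate_prompt_length] for s in context_tokens]
    lengths = [len(s) for s in context_tokens]
    # Assumed padding target: room for the longest prompt plus the tokens to generate.
    target = max(lengths) + max_len
    padded = [s + [pad_id] * (target - len(s)) for s in context_tokens]
    return padded, lengths


padded, lengths = truncate_and_pad([[1, 2, 3, 4], [5, 6]], pad_id=0, max_len=2,
                                   truncate_prompt_length=3)
print(padded, lengths)
```

Note that a `truncate_prompt_length` of 0 passes the assert in the eval script but produces empty prompts.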
@@ -181,8 +186,9 @@ def __init__(self, model):

def clip_max_len(self, maxlen: int) -> int:
""" clip the max len based on the LM model max sequence length"""
if maxlen > self.model.cfg.encoder_seq_length + 1:
maxlen = self.model.cfg.encoder_seq_length + 1
if self.model.cfg.get("position_embedding_type", "learned_absolute") == "learned_absolute":
if maxlen > self.model.cfg.encoder_seq_length + 1:
maxlen = self.model.cfg.encoder_seq_length + 1
return maxlen
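
The new `clip_max_len` behaviour can be restated as a pure function. This is a sketch only: the free-function signature is hypothetical, and `"rope"` is used merely as an example of a value other than `"learned_absolute"`. One plausible reading of the change is that a learned absolute position table has a fixed number of entries, while other position schemes are not bound to `encoder_seq_length` in the same way.

```python
def clip_max_len(maxlen: int, encoder_seq_length: int,
                 position_embedding_type: str = "learned_absolute") -> int:
    """Clip to encoder_seq_length + 1 only for learned absolute position embeddings."""
    if position_embedding_type == "learned_absolute":
        return min(maxlen, encoder_seq_length + 1)
    # Any other position embedding type leaves the requested length unchanged.
    return maxlen


print(clip_max_len(5000, 2048))           # clipped
print(clip_max_len(5000, 2048, "rope"))   # left as requested
```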

def init_batch(self, context_tokens: torch.Tensor, context_length: int):