Use HF generate for inference (#261)
dskhudia authored Mar 29, 2023
1 parent 49b87ed commit b713261
Showing 3 changed files with 61 additions and 88 deletions.
4 changes: 1 addition & 3 deletions examples/llm/__init__.py
@@ -4,8 +4,7 @@
try:
import torch

from examples.llm.inference.inference import (MosaicGPTInference,
get_mosaicgpt_inference_model)
from examples.llm.inference.inference import get_mosaicgpt_inference_model
from examples.llm.src.model_registry import COMPOSER_MODEL_REGISTRY
from examples.llm.src.models.hf import (ComposerHFCausalLM,
ComposerHFPrefixLM, ComposerHFT5)
@@ -42,6 +41,5 @@
'GPTBlock',
'MosaicGPT',
'ComposerMosaicGPT',
'MosaicGPTInference',
'get_mosaicgpt_inference_model',
]
4 changes: 1 addition & 3 deletions examples/llm/inference/__init__.py
@@ -1,10 +1,8 @@
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0

from examples.llm.inference.inference import (MosaicGPTInference,
get_mosaicgpt_inference_model)
from examples.llm.inference.inference import get_mosaicgpt_inference_model

__all__ = [
'MosaicGPTInference',
'get_mosaicgpt_inference_model',
]
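
For context (not part of this diff): after this change, callers import only the factory function. A minimal usage sketch, assuming a MosaicGPT training YAML at a hypothetical path:

# Hypothetical usage after this commit; the YAML path is illustrative only.
from examples.llm.inference import get_mosaicgpt_inference_model

model = get_mosaicgpt_inference_model('yamls/mosaic_gpt/125m.yaml')
model.eval()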
141 changes: 59 additions & 82 deletions examples/llm/inference/inference.py
@@ -1,89 +1,16 @@
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0
import sys
import warnings
from typing import List

import torch
from composer.core import get_precision_context
from composer.utils import get_device
from omegaconf import OmegaConf as om

from examples.llm.src import COMPOSER_MODEL_REGISTRY


def sample_top_p(probs, p):
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
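
For context (not part of this diff): the removed sample_top_p helper implements nucleus (top-p) sampling — keep the smallest set of most-probable tokens whose cumulative probability exceeds p, renormalize, and sample from that set. Hugging Face's generate applies the same filtering when called with do_sample=True and top_p, which is why the helper can be dropped. A toy worked sketch of the idea, assuming only PyTorch:

import torch

probs = torch.tensor([[0.50, 0.25, 0.15, 0.07, 0.03]])  # toy, already softmax-ed
p = 0.9
sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
cumulative = torch.cumsum(sorted_probs, dim=-1)
mask = cumulative - sorted_probs > p      # True only for the 0.03 tail token
sorted_probs[mask] = 0.0
sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
next_token = torch.gather(sorted_idx, -1,
                          torch.multinomial(sorted_probs, num_samples=1))
print(next_token)  # one of indices 0-3; index 4 is never sampled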


class MosaicGPTInference:

def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer

def generate(
self,
prompts: List[str],
max_gen_len: int,
temperature: float = 0.8,
top_p: float = 0.95,
) -> List[str]:
if isinstance(prompts, str):
prompts = [prompts]

bsz = len(prompts)

prompt_tokens = [self.tokenizer.encode(x) for x in prompts]

min_prompt_size = min([len(t) for t in prompt_tokens])
max_prompt_size = max([len(t) for t in prompt_tokens])

total_len = min(self.cfg.max_seq_len, max_gen_len + max_prompt_size)

tokens = torch.full((bsz, total_len), -100).cuda().long()
for k, t in enumerate(prompt_tokens):
tokens[k, :len(t)] = torch.tensor(t).long()
input_text_mask = tokens != -100
start_pos = min_prompt_size
for cur_pos in range(start_pos, total_len):
with torch.no_grad():
with get_precision_context(self.cfg.get('precision',
'amp_bf16')):
logits = self.model.forward(
{'input_ids': tokens[:, :cur_pos]})
logits = logits[:, -1, :]
logits = logits.float()
if temperature > 0:
probs = torch.softmax(logits / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos],
tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token

decoded = []
for i, t in enumerate(tokens.tolist()):
# cut to max gen len
t = t[:len(prompt_tokens[i]) + max_gen_len]
# cut to eos tok if any
try:
t = t[:t.index(self.tokenizer.eos_token_id)]
except ValueError:
pass
decoded.append(self.tokenizer.decode(t))
return decoded


def build_composer_model(model_cfg, tokenizer_cfg):
warnings.filterwarnings(
action='ignore',
@@ -95,7 +22,7 @@ def build_composer_model(model_cfg, tokenizer_cfg):
f'Not sure how to build model with name={model_cfg.name}')


def get_mosaicgpt_inference_model(checkpoint_yaml_path: str, tokenizer):
def get_mosaicgpt_inference_model(checkpoint_yaml_path: str):
with open(checkpoint_yaml_path) as f:
cfg = om.load(f)
# set init_device to cpu for checkpoint loading
@@ -109,11 +36,61 @@ def get_mosaicgpt_inference_model(checkpoint_yaml_path: str, tokenizer):

checkpoint = torch.load(ckpt_load_path, map_location='cpu')

model.load_state_dict(checkpoint['state']['model'], strict=True)

model.cuda()
if 'state' in checkpoint.keys():
# it's a full training checkpoint
model.load_state_dict(checkpoint['state']['model'], strict=True)
else:
# it's a weights-only checkpoint
model.load_state_dict(checkpoint, strict=True)

if model.tokenizer.pad_token_id is None:
warnings.warn(
'pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id.'
)
model.tokenizer.pad_token_id = model.tokenizer.eos_token_id
model.tokenizer.padding_side = 'left'

return model
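
For context (not part of this diff): left padding matters for batched generation with a decoder-only model. With right padding, the last positions of shorter prompts are pad tokens, so generation would continue from padding rather than from the prompt; with left padding, every prompt's final real token sits at the end of the row. A minimal sketch of the effect, assuming the transformers library and gpt2 as a stand-in tokenizer (not necessarily the tokenizer configured here):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token
tokenizer.padding_side = 'left'

batch = tokenizer(['My name is', 'Hi'], return_tensors='pt', padding=True)
# Pad tokens appear at the start of the shorter row, so the model's next-token
# prediction continues each prompt instead of continuing from padding.
print(batch['input_ids'])
print(batch['attention_mask'])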


if __name__ == '__main__':
if len(sys.argv) < 2:
print('please provide a configuration yaml')
sys.exit(-1)
yaml_path = sys.argv[1]
with open(yaml_path) as f:
cfg = om.load(f)
model = get_mosaicgpt_inference_model(yaml_path)
model.eval()

generator = MosaicGPTInference(cfg, model, tokenizer)

return generator
generate_kwargs = {
'max_new_tokens': 100,
'use_cache': True,
'do_sample': True,
'top_p': 0.95,
'eos_token_id': model.tokenizer.eos_token_id,
}
prompts = [
'My name is',
'This is an explanation of deep learning to a five year old. Deep learning is',
]
device = get_device(None)
device.module_to_device(model)
encoded_inp = model.tokenizer(prompts, return_tensors='pt', padding=True)
for key, value in encoded_inp.items():
encoded_inp[key] = device.tensor_to_device(value)

with torch.no_grad():
with get_precision_context(
cfg.get(
'precision', # type: ignore
'amp_bf16')):
generation = model.model.generate(
input_ids=encoded_inp['input_ids'],
attention_mask=encoded_inp['attention_mask'],
**generate_kwargs,
)

decoded_out = model.tokenizer.batch_decode(generation,
skip_special_tokens=True)
print('\n###\n'.join(decoded_out))
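
For context (not part of this diff): the kwargs passed to generate cover the knobs the removed custom loop exposed. A sketch of how the old arguments could map onto generate, assuming model is the object returned by get_mosaicgpt_inference_model; the values mirror the removed defaults and are illustrative only:

# Illustrative mapping of the old sampling arguments onto HF generate kwargs.
generate_kwargs = {
    'max_new_tokens': 256,                         # replaces the old max_gen_len
    'do_sample': True,                             # sampling; False gives the old argmax path
    'temperature': 0.8,                            # old default temperature
    'top_p': 0.95,                                 # old default nucleus cutoff
    'use_cache': True,
    'eos_token_id': model.tokenizer.eos_token_id,
    'pad_token_id': model.tokenizer.pad_token_id,
}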
