opt-30b (#15)
920232796 authored Jun 29, 2022
1 parent 97ffea4 commit c8e8e61
Showing 10 changed files with 461 additions and 64 deletions.
23 changes: 23 additions & 0 deletions examples/opt/generate_opt_30b.py
@@ -0,0 +1,23 @@
from flagai.model.predictor.predictor import Predictor
from flagai.auto_model.auto_loader import AutoLoader
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loader = AutoLoader(task_name="lm",
model_name="opt-30b-en")

model = loader.get_model()
tokenizer = loader.get_tokenizer()
model.eval()
model.to(device)

text = "The trophy doesn’t fit in the suitcase because "
predictor = Predictor(model, tokenizer)
out = predictor.predict_generate_randomsample(text,
input_max_length=100,
out_max_length=300,
top_k=30,
top_p=0.9,
repetition_penalty=3.0)

print(f"input is {text} \n out is {out}")
78 changes: 78 additions & 0 deletions examples/opt/opt_30b_en_mutigpu.py
@@ -0,0 +1,78 @@
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,2"
import torch
import os
import argparse
from flagai import mpu
from flagai.auto_model.auto_loader import AutoLoader
import random
import numpy as np
from flagai.model.predictor.predictor import Predictor

# run script: python -m torch.distributed.launch --nproc_per_node=4 --nnodes=1 opt_30b_en_mutigpu.py
os.environ["ENV_TYPE"] = "deepspeed+mpu"
model_parallel_size = 4
world_size = 4

os.environ["MODEL_PARALLEL_SIZE"] = str(model_parallel_size)
os.environ["WORLD_SIZE"] = str(world_size)

def set_random_seed(seed):
"""Set random seed for reproducability."""
if seed is not None and seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
mpu.model_parallel_cuda_manual_seed(seed)

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank',
type=int,
default=0,
help="local_rank")

ds_args = parser.parse_args()
local_rank = ds_args.local_rank

master_addr = os.environ.get('MASTER_ADDR', '127.0.0.1')
master_port = os.environ.get('MASTER_PORT', '17501')

device = torch.device("cuda", local_rank)

def initialize_distributed():
"""Initialize torch.distributed."""
torch.backends.cudnn.enabled = False
# Manually set the device ids.
torch.cuda.set_device(device)
# Call the init process
init_method = 'tcp://' + master_addr + ':' + master_port
torch.distributed.init_process_group(
backend='nccl',  # or 'gloo' for CPU-only debugging
world_size=world_size,
rank=local_rank,
init_method=init_method)
mpu.initialize_model_parallel(model_parallel_size)

initialize_distributed()

set_random_seed(123)

loader = AutoLoader("lm", model_name="opt-30b-en")
model = loader.get_model()
model.half()
tokenizer = loader.get_tokenizer()
# model.parallel_output = False
model.eval()
model.to(device)

torch.distributed.barrier(group=mpu.get_model_parallel_group())

text = """I think The Old Man and the Sea is a very good book, what do you think? I think """

predictor = Predictor(model, tokenizer)
out = predictor.predict_generate_randomsample(text)
if mpu.get_model_parallel_rank() == 0:
print(f"pred is {out}")


10 changes: 8 additions & 2 deletions flagai/model/base_model.py
@@ -98,11 +98,17 @@ def from_pretrain(cls,
print(
"preparing the model weights for model parallel size = {:02d}"
.format(model_parallel_size))
from flagai.mp_tools import change_pytorch_model_mp_from_1_to_n, check_pytorch_model_mp_size
from flagai.auto_model.auto_loader import MODEL_DICT
from flagai.mp_tools import change_pytorch_model_mp_from_1_to_n_new, check_pytorch_model_mp_size
if model_parallel_size > 1 and not check_pytorch_model_mp_size(
download_path, model_parallel_size):
change_pytorch_model_mp_from_1_to_n(
brief_model_name = MODEL_DICT[model_name.lower()][2]
change_pytorch_model_mp_from_1_to_n_new(brief_model_name,
download_path, model_parallel_size)

from flagai import mpu
torch.distributed.barrier(group=mpu.get_model_parallel_group())

if model_parallel_size > 1:
from flagai.mpu import get_model_parallel_rank
model_parallel_rank = get_model_parallel_rank()
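The switch to change_pytorch_model_mp_from_1_to_n_new re-shards a single checkpoint into model_parallel_size pieces before loading. A minimal sketch of the core idea for one column-parallel weight follows, assuming Megatron-style partitioning; the real helper in flagai.mp_tools also handles row-parallel layers, biases, and per-rank file naming.

import torch

def shard_column_parallel(weight, mp_size):
    # PyTorch linear weights are stored as [out_features, in_features];
    # a column-parallel layer keeps out_features / mp_size rows per rank.
    assert weight.size(0) % mp_size == 0
    return torch.chunk(weight, mp_size, dim=0)

# state = torch.load("pytorch_model.bin", map_location="cpu")  # hypothetical path
# shards = shard_column_parallel(state["lm_head.weight"], mp_size=4)
# shards[r] is what rank r's checkpoint file would contain.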
2 changes: 1 addition & 1 deletion flagai/model/glm_model.py
@@ -458,7 +458,7 @@ def forward(self,
labels = kwargs['labels']
if os.getenv("ENV_TYPE") == 'deepspeed+mpu':
loss = vocab_parallel_cross_entropy(
logits.contiguous().float(), labels).mean()
logits_parallel.contiguous().float(), labels).mean()
else:

loss = F.cross_entropy(
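This one-line fix matters because vocab_parallel_cross_entropy expects logits that are still sharded along the vocabulary dimension (logits_parallel), not a gathered tensor. A condensed sketch of the Megatron-style loss, assuming each rank owns one contiguous vocabulary slice:

import torch
import torch.distributed as dist

def vocab_parallel_ce_sketch(shard_logits, targets, vocab_start, group=None):
    # shard_logits: [tokens, shard_vocab] on this rank; targets: [tokens].
    # Stabilize with the global max, then form the global softmax sum.
    logits_max = shard_logits.max(dim=-1).values
    dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
    shard_logits = shard_logits - logits_max.unsqueeze(-1)
    sum_exp = shard_logits.exp().sum(dim=-1)
    dist.all_reduce(sum_exp, group=group)

    # Only the rank that owns a target id contributes its logit.
    vocab_end = vocab_start + shard_logits.size(-1)
    mine = (targets >= vocab_start) & (targets < vocab_end)
    local_ids = (targets - vocab_start).clamp(0, shard_logits.size(-1) - 1)
    picked = shard_logits.gather(-1, local_ids.unsqueeze(-1)).squeeze(-1)
    target_logits = torch.where(mine, picked, torch.zeros_like(picked))
    dist.all_reduce(target_logits, group=group)

    # Cross entropy per token: log(sum of exps) minus the target logit.
    return (sum_exp.log() - target_logits).mean()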
76 changes: 64 additions & 12 deletions flagai/model/gpt2_model.py
@@ -8,15 +8,17 @@
from flagai.model.layers.embeddings import VocabParallelEmbedding
from flagai.model.utils import normal_init_method
from flagai.model.base_model import BaseModel
import torch.nn.functional as F

if os.getenv('ENV_TYPE') == 'deepspeed+mpu':
from flagai.mpu import get_model_parallel_world_size
from flagai.mpu import gather_from_model_parallel_region
from flagai.mpu import get_cuda_rng_tracker
from flagai.mpu.utils import divide
if os.getenv('ENV_TYPE') == 'deepspeed+mpu':
from flagai.mpu import copy_to_model_parallel_region
from flagai.mpu.random import checkpoint
from flagai.mpu import copy_to_model_parallel_region, gather_from_model_parallel_region
from flagai.mpu.cross_entropy import vocab_parallel_cross_entropy

elif os.getenv('ENV_TYPE') == 'deepspeed':
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
else:
@@ -224,6 +226,7 @@ def __init__(self, config, **kwargs):
do_layer_norm_before=self.config.get("do_layer_norm_before", True),
)
self.config = config_gpt
self.parallel_output = True

self.transformer = GPT2Stack(self.config)
self.lm_head = nn.Linear(self.config.n_embd,
@@ -266,21 +269,70 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs
logits = transformer_outputs

if os.getenv("ENV_TYPE") == 'deepspeed+mpu':
logits_parallel = copy_to_model_parallel_region(logits)
else:
logits_parallel = logits

lm_logits = self.lm_head(hidden_states)
# if self.output_predict:
# Parallel logits.
logits_parallel = F.linear(logits_parallel,
self.transformer.wte.weight)

return_data = {"logits": lm_logits}
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_logits = logits_parallel[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
return_data["loss"] = loss

return return_data
if os.getenv("ENV_TYPE") == 'deepspeed+mpu':
loss = vocab_parallel_cross_entropy(
shift_logits.contiguous().float(), shift_labels).mean()
else:
loss = F.cross_entropy(
shift_logits.contiguous().float(), shift_labels.long())

if self.parallel_output:  # leave logits sharded across model-parallel GPUs
return {
'logits': logits_parallel,
'loss': loss,
'hidden_states': None,
}
else:
return {
"logits":
gather_from_model_parallel_region(logits_parallel),
"loss":
loss,
"hidden_states":
None,
}
else:
if self.parallel_output:  # leave logits sharded across model-parallel GPUs
return {
'logits': logits_parallel,
'hidden_states': None,
}
else:
return {
"logits":
gather_from_model_parallel_region(logits_parallel),
"hidden_states":
None,
}

# lm_logits = self.lm_head(hidden_states)
# return_data = {"logits": lm_logits}
# if labels is not None:
# # Shift so that tokens < n predict n
# shift_logits = lm_logits[..., :-1, :].contiguous()
# shift_labels = labels[..., 1:].contiguous()
# loss_fct = nn.CrossEntropyLoss()
# loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
# shift_labels.view(-1))
# return_data["loss"] = loss

# return return_data

def load_weights(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path,
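The rewritten forward above drops the separate lm_head in the model-parallel path and ties the output projection to the token embedding (self.transformer.wte.weight). Stripped of the parallelism, the tying trick is a single transposed matrix product; a self-contained illustration with made-up sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, n_embd = 50257, 768                # illustrative sizes
wte = nn.Embedding(vocab_size, n_embd)         # token embedding table
hidden_states = torch.randn(2, 10, n_embd)     # mock transformer output

# wte.weight is [vocab_size, n_embd], so F.linear maps the hidden size
# back to the vocabulary without a second [n_embd, vocab_size] parameter.
lm_logits = F.linear(hidden_states, wte.weight)  # [2, 10, vocab_size]

Under model parallelism each rank produces only its vocabulary shard of these logits. parallel_output=True returns the shard as-is, which is exactly what the vocab-parallel loss consumes, while the gather_from_model_parallel_region branch reassembles the full vocabulary dimension for callers such as the predictor.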
2 changes: 2 additions & 0 deletions flagai/model/layers/attentions.py
@@ -64,6 +64,8 @@ def __init__(self,
self.scale = scale
if os.getenv("ENV_TYPE") == 'deepspeed+mpu':
world_size = get_model_parallel_world_size()
self.split_size = divide(n_state, world_size)

self.hidden_size_per_partition = divide(nx, world_size)
self.hidden_size_per_attention_head = divide(nx, self.n_head)
self.num_attention_heads_per_partition = divide(
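divide here is presumably the usual Megatron-style guard that refuses to split a dimension unevenly, along the lines of:

def divide(numerator, denominator):
    # Fail loudly if, say, the head count or hidden size is not a
    # multiple of the model-parallel world size.
    assert numerator % denominator == 0, \
        f"{numerator} is not evenly divisible by {denominator}"
    return numerator // denominator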
70 changes: 35 additions & 35 deletions flagai/model/opt_model.py
@@ -265,41 +265,41 @@ def __init__(self, config, **kwargs):
# self.config = config
self.transformer = OPTStack(self.config)

def forward(
self,
**data,
):
input_ids = data.get("input_ids", None)
# attention_mask = data.get("attention_mask", None)
# position_ids = data.get("position_ids", None)
labels = data.get("labels", None)
use_cache = data.get("use_cache", None)
output_attentions = data.get("output_attentions", None)
output_hidden_states = data.get("output_hidden_states", True)

transformer_outputs = self.transformer(
input_ids,
attention_mask=None,
position_ids=None,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
hidden_states = transformer_outputs

lm_logits = self.lm_head(hidden_states)

return_data = {"logits": lm_logits}
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
return_data["loss"] = loss

return return_data
# def forward(
# self,
# **data,
# ):
# input_ids = data.get("input_ids", None)
# # attention_mask = data.get("attention_mask", None)
# # position_ids = data.get("position_ids", None)
# labels = data.get("labels", None)
# use_cache = data.get("use_cache", None)
# output_attentions = data.get("output_attentions", None)
# output_hidden_states = data.get("output_hidden_states", True)
#
# transformer_outputs = self.transformer(
# input_ids,
# attention_mask=None,
# position_ids=None,
# use_cache=use_cache,
# output_attentions=output_attentions,
# output_hidden_states=output_hidden_states,
# )
# hidden_states = transformer_outputs
#
# lm_logits = self.lm_head(hidden_states)
#
# return_data = {"logits": lm_logits}
# if labels is not None:
# # Shift so that tokens < n predict n
# shift_logits = lm_logits[..., :-1, :].contiguous()
# shift_labels = labels[..., 1:].contiguous()
# loss_fct = nn.CrossEntropyLoss()
# loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
# shift_labels.view(-1))
# return_data["loss"] = loss
#
# return return_data


def load_weights(self, checkpoint_path):
16 changes: 2 additions & 14 deletions flagai/model/predictor/predictor.py
@@ -8,24 +8,12 @@
t5_random_sample, gpt_random_sample, \
t5_beamsearch, gpt_beamsearch, bert_random_sample, glm_beamsearch, glm_random_sample
from typing import List, Union, Dict, Tuple, Any
from flagai.model.bert_model import BertModel, BertForMaskLM, BertForSeq2seq, BertForSequenceLabeling, \
BertForSequenceLabelingGP, BertForSequenceLabelingCRF, BertForClsClassifier
from flagai.model.gpt2_model import GPT2Model
from flagai.model.t5_model import T5Model
from flagai.data.tokenizer.bert.bert_tokenizer import BertTokenizer
from flagai.data.tokenizer.t5.t5_pegasus_tokenizer import T5PegasusTokenizer
from flagai.data.tokenizer.glm_large_ch.glm_large_ch_tokenizer import GLMLargeChTokenizer


class Predictor:

def __init__(self,
model: Union[BertModel, GPT2Model, BertForSequenceLabelingGP,
BertForSequenceLabelingCRF, BertForClsClassifier,
BertForClsClassifier, BertForSequenceLabeling,
BertForSeq2seq, T5Model, BertForMaskLM],
tokenizer: Union[GLMLargeChTokenizer, BertTokenizer,
T5PegasusTokenizer]):
model,
tokenizer):
"""
Args:
model: The model loaded by the AutoLoader class.
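Dropping the exhaustive Union type hints leaves Predictor duck-typed: any model/tokenizer pair that exposes the interface the sampling helpers expect, including the new OPT models, can now be passed in without extending this import list again.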