Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#37 from wjm202/add_blip2_stage1
Browse files Browse the repository at this point in the history
Add blip2 stage1
  • Loading branch information
jerrywgz authored Aug 2, 2023
2 parents d48c387 + e7419ca commit de634ce
Show file tree
Hide file tree
Showing 14 changed files with 1,373 additions and 380 deletions.
8 changes: 2 additions & 6 deletions paddlevlp/examples/blip2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from dataclasses import dataclass, field
import os
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../..'))
from dataclasses import dataclass, field
import paddle
import requests
from paddlenlp.trainer import PdArgumentParser
Expand All @@ -23,14 +24,9 @@
from paddlevlp.utils.log import logger
import os
import yaml
import argparse
import paddle
import argparse
import os
import random

import numpy as np
# import torch
import paddle

@dataclass
Expand Down
13 changes: 1 addition & 12 deletions paddlevlp/examples/blip2/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../..'))
import numpy as np
import requests
from dataclasses import dataclass, field
Expand All @@ -33,18 +34,6 @@

import matplotlib.pyplot as plt


def show_mask(mask, ax, random_color=False):
    """Overlay *mask* on the matplotlib axes *ax* as a semi-transparent RGBA image.

    Args:
        mask: array whose last two dims are (H, W); nonzero pixels are tinted.
        ax: matplotlib axes (anything with an ``imshow`` method).
        random_color: if True use a random hue at 0.6 alpha, otherwise a
            fixed dodger-blue default.
    """
    # Pick the RGBA colour once, then broadcast it over the mask.
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    )
    height, width = mask.shape[-2:]
    ax.imshow(mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1))



class DeployConfig:
def __init__(self, path):
with codecs.open(path, 'r', 'utf-8') as file:
Expand Down
24 changes: 9 additions & 15 deletions paddlevlp/examples/blip2/run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@

import sys
import os
sys.path.insert(
0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../..'))
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../..'))
import paddle.distributed as dist
from paddle.distributed import fleet
import os
from dataclasses import dataclass, field
import numpy as np
import random
Expand All @@ -28,7 +26,6 @@
from paddlenlp.trainer import (PdArgumentParser, TrainingArguments,
get_last_checkpoint)
from paddlenlp.transformers import AutoConfig, OPTConfig, T5Config
import paddlevlp
from paddlevlp.datasets import load_dataset
from paddlevlp.models.blip2.configuration import (
Blip2Config, Blip2QFormerConfig, Blip2VisionConfig)
Expand All @@ -37,11 +34,8 @@
from paddlevlp.trainer.blip2_trainer import BLIP2Trainer as Trainer
from paddlevlp.utils.log import logger
from paddlenlp.transformers import AutoTokenizer
from paddlevlp.models.blip2.eva_vit import interpolate_pos_embed
from paddlevlp.processors.blip_processing import BlipImageProcessor, BlipTextProcessor
from paddlevlp.examples.blip2.utils import BlipCollator
from paddlevlp.examples.blip2.utils import load_pretrained_model, LLM_LIST

from paddlevlp.processors.blip_processing import BlipImageProcessor,BlipTextProcessor
from paddlevlp.examples.blip2.utils import BlipCollator,load_model,LLM_LIST

@dataclass
class DataArguments:
Expand Down Expand Up @@ -86,12 +80,12 @@ class PreTrainingArguments(TrainingArguments):
Arguments pertaining to what training options we are going to use during pretraining.
"""

pretrained_model_path: str = field(
model_path: str = field(
default="https://bj.bcebos.com/v1/paddlenlp/models/community/Salesforce/blip2-opt-2.7b/blip2_pretrained.pdparams",
metadata={
"help":
"The path to pre-trained model that we will use for pretraining."
}, )
"help": "The path to model that we will use for eval."
},
)
weight_decay: float = field(
default=0.05, metadata={"help": "Weight decay if we apply some."})
learning_rate: float = field(
Expand Down Expand Up @@ -243,8 +237,8 @@ def main():
logger.info("training_args.use_hybrid_parallel:{}".format(
training_args.use_hybrid_parallel))
# create trainer
load_pretrained_model(model, LLM_LIST[model_args.llm_name])
load_pretrained_model(model, training_args.pretrained_model_path)
load_model(training_args,model, ckpt_dir=model_args.model_path,load_language_model=False)
load_model(training_args,model.language_model, ckpt_dir=LLM_LIST[model_args.text_model_name_or_path],load_language_model=True)
trainer = Trainer(
model=model,
args=training_args,
Expand Down
202 changes: 180 additions & 22 deletions paddlevlp/examples/blip2/run_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
# limitations under the License.
import sys
import os
import paddle.distributed as dist
from paddle.distributed import fleet
from dataclasses import dataclass, field
import numpy as np
import random
import paddle
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
sys.path.insert(
0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../..'))
from dataclasses import dataclass, field
Expand All @@ -24,8 +31,10 @@
from paddlevlp.models.blip2.modeling import Blip2ForConditionalGeneration
from paddlevlp.processors.blip_processing import Blip2Processor
from paddlevlp.utils.log import logger
from paddlevlp.examples.blip2.utils import load_pretrained_model

from paddlevlp.models.blip2.configuration import Blip2VisionConfig, Blip2QFormerConfig, Blip2Config
from paddlenlp.transformers import T5Config, OPTConfig, AutoConfig
from paddlevlp.examples.blip2.utils import load_model,LLM_LIST
from paddlenlp.trainer import (PdArgumentParser, TrainingArguments)

@dataclass
class DataArguments:
Expand All @@ -36,15 +45,16 @@ class DataArguments:
the command line.
"""

input_image: str = field(metadata={
input_image: str = field(
default="http://images.cocodataset.org/val2017/000000039769.jpg",
metadata={
"help": "The name of input image."
}) # "http://images.cocodataset.org/val2017/000000039769.jpg"
prompt: str = field(
default=None,
default="describe the image",
metadata={"help": "The prompt of the image to be generated."
}) # "Question: how many cats are there? Answer:"


@dataclass
class ModelArguments:
"""
Expand All @@ -53,22 +63,114 @@ class ModelArguments:

model_name_or_path: str = field(
default="Salesforce/blip2-opt-2.7b",
metadata={"help": "Path to pretrained model or model identifier"}, )
pretrained_model_path: str = field(
default=None,
metadata={"help": "Path to pretrained model or model identifier"},
)
model_path: str = field(
default="https://bj.bcebos.com/v1/paddlenlp/models/community/Salesforce/blip2-opt-2.7b/blip2_pretrained.pdparams",
metadata={
"help":
"The path to pre-trained model that we will use for inference."
}, )
"help": "The path to model that we will use for eval."
},
)

text_model_name_or_path: str = field(
default="facebook/opt-2.7b",
metadata={"help": "The type of text model to use (OPT, T5)."},
)
image_size : int = field(
default=224, metadata={"help": " Image size for training. (default:224)"}
)

@dataclass
class PreTrainingArguments(TrainingArguments):
    """
    Arguments pertaining to what training options we are going to use during pretraining.
    """

    weight_decay: float = field(
        default=0.05, metadata={"help": "Weight decay if we apply some."}
    )
    learning_rate: float = field(
        default=0.0001, metadata={"help": "The initial learning rate."}
    )
    num_train_epochs: float = field(
        default=10.0, metadata={"help": "Total number of training epochs to perform."}
    )
    # NOTE(review): this field was previously declared twice (with the same
    # 1e-6 default); the redundant second declaration has been removed.
    warmup_start_lr: float = field(
        default=1e-6, metadata={"help": "Initial learning rate of warm up."}
    )
    eta_min: float = field(
        default=1e-5, metadata={"help": "The minimum value of learning rate."}
    )
    warmup_steps: int = field(
        default=2000, metadata={"help": "Number of warmup steps."}
    )
    lr_scheduler_name: str = field(
        default="CosineDecayWithWarmup", metadata={"help": "The scheduler name to use."}
    )
    per_device_train_batch_size: int = field(
        default=128, metadata={"help": "Batch size per GPU core/CPU for training. (default: 8)"}
    )
    per_device_eval_batch_size: int = field(
        default=128, metadata={"help": " Batch size per GPU core/CPU for evaluation. (default:8)"}
    )
    output_dir: str = field(
        default=".", metadata={"help": "The output path"}
    )
    do_eval: bool = field(default=False, metadata={"help": "Whether to evaluation."})
    do_train: bool = field(default=True, metadata={"help": "Whether to train."})

    logging_steps: int = field(default=50, metadata={"help": "Logging interval"})
    evaluation_strategy: str = field(default="no", metadata={"help": "Evaluation strategy (epoch/steps/no)"})

    fp16_opt_level: str = field(default="O1", metadata={"help": "Mixed Precision Type"})
    fp16: bool = field(default=True, metadata={"help": "Whether to use mixed Precision"})
    gradient_checkpointing: bool = field(default=False, metadata={"help": "Forward recompute for saving graphics memory"})
    tensor_parallel_degree: int = field(default=1, metadata={"help": "Set the number of tensor model parallel"})
    sharding_parallel_degree: int = field(default=1, metadata={"help": "Set the number of sharding, enable sharding parallel"})
    pipeline_parallel_degree: int = field(default=1, metadata={"help": "Enable pipeline parallel"})

def get_text_config(text_model_name_or_path):
    """Resolve and load the language-model config for the given model name.

    Dispatches on the substring of the name: "t5" -> T5Config,
    "opt" -> OPTConfig, anything else -> AutoConfig.
    """
    name = text_model_name_or_path
    if "t5" in name:
        return T5Config.from_pretrained(name)
    if "opt" in name:
        return OPTConfig.from_pretrained(name)
    return AutoConfig.from_pretrained(name)


def create_model(config):
    """Assemble a Blip2ForConditionalGeneration from its three sub-configs.

    Loads the vision and Q-Former configs from ``config.model_name_or_path``,
    the text config from ``config.text_model_name_or_path``, and overrides the
    vision input resolution with ``config.image_size``.
    """
    vision_cfg = Blip2VisionConfig.from_pretrained(config.model_name_or_path)
    qformer_cfg = Blip2QFormerConfig.from_pretrained(config.model_name_or_path)
    text_cfg = get_text_config(config.text_model_name_or_path)
    # Override the pretrained vision resolution with the CLI-requested size.
    vision_cfg.image_size = config.image_size
    blip2_cfg = Blip2Config.from_vision_qformer_text_configs(
        vision_cfg, qformer_cfg, text_cfg
    )
    model = Blip2ForConditionalGeneration(blip2_cfg)
    # Release cached GPU memory held by temporaries created during construction.
    paddle.device.cuda.empty_cache()
    return model


def main():
parser = PdArgumentParser((ModelArguments, DataArguments))
model_args, data_args = parser.parse_args_into_dataclasses()
parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments))
model_args, data_args,training_args = parser.parse_args_into_dataclasses()
url = (data_args.input_image
) # "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
training_args.print_config(model_args, "Model")
training_args.print_config(data_args, "Data")
training_args.prompt=data_args.prompt
setdistenv(training_args)


model_args.data_world_rank = training_args.data_world_rank
model_args.data_world_size = training_args.data_world_size
paddle.set_device(training_args.device)
prompt = data_args.prompt
processor = Blip2Processor.from_pretrained(
model_args.model_name_or_path) # "Salesforce/blip2-opt-2.7b"
Expand All @@ -78,21 +180,77 @@ def main():
return_tensors="pd",
return_attention_mask=True,
mode="test", )
model = Blip2ForConditionalGeneration.from_pretrained(
model_args.model_name_or_path)

# load checkpoint
if model_args.pretrained_model_path:
weight = paddle.load(model_args.pretrained_model_path)
model.set_state_dict(weight)

model = create_model(model_args)
load_model(training_args,model, ckpt_dir=model_args.model_path,load_language_model=False)
load_model(training_args,model.language_model, ckpt_dir=LLM_LIST[model_args.text_model_name_or_path])
model.eval()
model.to("gpu") # doctest: +IGNORE_RESULT
generated_ids, scores = model.generate(**inputs)
generated_text = processor.batch_decode(
generated_ids, skip_special_tokens=True)[0].strip()
logger.info("Generate text: {}".format(generated_text))
return model

def setdistenv(args):
    """Initialize the hybrid-parallel (dp/mp/sharding/pp) distributed environment.

    Mutates ``args`` in place: derives the data-parallel degree from the world
    size, builds and applies a ``fleet.DistributedStrategy``, initializes fleet,
    records this worker's rank in each parallel group, and finally seeds all
    RNGs for hybrid parallelism.
    """
    if args.tensor_parallel_degree * args.sharding_parallel_degree * args.pipeline_parallel_degree != 1:
        args.use_hybrid_parallel = True
    # Data-parallel degree is whatever remains of the world size after the
    # other three parallelism degrees are factored out.
    args.dp_degree = dist.get_world_size() // (
        args.tensor_parallel_degree
        * args.sharding_parallel_degree
        * args.pipeline_parallel_degree)
    strategy = fleet.DistributedStrategy()
    if args.tensor_parallel_degree > 1:
        strategy.tensor_parallel = True
    args.data_parallel_degree = args.dp_degree
    logger.info("args.dp_degree:{}".format(args.dp_degree))
    # Fixed: the format string previously contained a stray ')'.
    logger.info("args.sharding_parallel_degree:{}".format(
        args.sharding_parallel_degree))
    if args.sharding_parallel_degree > 1:
        args.sharding = "stage1"
    strategy.hybrid_configs = {
        "dp_degree": args.dp_degree,
        "mp_degree": args.tensor_parallel_degree,
        "sharding_degree": args.sharding_parallel_degree,
        "pp_degree": args.pipeline_parallel_degree,
    }
    # Pipeline gradient accumulation: global batch of 128 split into
    # micro-batches of 32 (accumulate_steps = 4).
    BATCH_SIZE = 128
    MICRO_BATCH_SIZE = 32
    strategy.pipeline_configs = {
        "accumulate_steps": BATCH_SIZE // MICRO_BATCH_SIZE,
        "micro_batch_size": MICRO_BATCH_SIZE
    }
    strategy.find_unused_parameters = True

    # Seed control for tensor-parallel initialization.
    strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed}

    fleet.init(is_collective=True, strategy=strategy)

    args.rank = dist.get_rank()
    # Obtain this worker's rank in each hybrid-parallel communication group.
    hcg = fleet.get_hybrid_communicate_group()
    args.mp_rank = hcg.get_model_parallel_rank()
    args.dp_rank = hcg.get_data_parallel_rank()
    args.sharding_rank = hcg.get_sharding_parallel_rank()

    # Workers that see distinct data shards: dp and sharding ranks combined.
    # (The degrees are positive, so the previous abs() was redundant.)
    args.data_world_rank = args.dp_rank * args.sharding_parallel_degree + args.sharding_rank
    args.data_world_size = dist.get_world_size() // (
        args.tensor_parallel_degree * args.pipeline_parallel_degree)

    # Seed control in hybrid parallel (helper name keeps its historical
    # 'hyrbid' typo because callers reference it by that name).
    set_hyrbid_parallel_seed(args.seed, args.data_world_rank, args.mp_rank)

def set_hyrbid_parallel_seed(basic_seed, data_world_rank, mp_rank, pp_rank=0):
    # NOTE(review): "hyrbid" is a typo for "hybrid"; the name is kept because
    # callers (setdistenv) reference it as-is.
    """Seed all RNGs for reproducible hybrid-parallel training.

    Workers in the same data-parallel slot share a base offset
    (``basic_seed + data_world_rank``); model-parallel ranks additionally get
    distinct ``local_seed`` values so dropout masks differ across mp shards,
    while ``global_seed`` is shared within a data-parallel slot.
    ``pp_rank`` is accepted but currently unused.
    """
    device_id = paddle.device.get_device()
    # NOTE(review): `assert` is stripped under `python -O`; an explicit raise
    # would be safer for this GPU-only precondition.
    assert 'gpu' in device_id

    # Seed Python, NumPy, and Paddle with the per-data-worker offset.
    random.seed(basic_seed + data_world_rank)
    np.random.seed(basic_seed + data_world_rank)
    paddle.seed(basic_seed + data_world_rank)
    #TODO add manual_seed
    # local_seed/ global_seed is used to control dropout in ModelParallel
    local_seed = 1024 + basic_seed + mp_rank * 100 + data_world_rank
    global_seed = 2048 + basic_seed + data_world_rank
    # Register both seeds with Paddle's hybrid-parallel RNG state tracker.
    tracker = get_rng_state_tracker()
    tracker.add("global_seed", global_seed)
    tracker.add("local_seed", local_seed)

if __name__ == "__main__":
main()
Loading

0 comments on commit de634ce

Please sign in to comment.