Merge pull request PaddlePaddle#29 from wjm202/change_llm
modify llm load method
lyuwenyu authored Jul 28, 2023
2 parents 76550bf + 2e66373 commit 0a9cb60
Showing 3 changed files with 155 additions and 5 deletions.
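
In brief, the change splits weight loading in two: the LLM weights are now resolved by name through a new LLM_LIST registry and loaded separately, while the remaining BLIP-2 weights still come from the user-supplied checkpoint path. A minimal sketch of the resulting call sequence, assuming model, model_args, and training_args are constructed as in the scripts below:

# Sketch of the new loading flow (objects assumed built as in run_pretrain_stage2.py).
from paddlevlp.examples.blip2.utils import load_pretrained_model, LLM_LIST

# 1) LLM weights, resolved by name to a bcebos URL and downloaded on demand
load_pretrained_model(model.language_model, LLM_LIST[model_args.llm_name])
# 2) remaining BLIP-2 weights from the user-supplied checkpoint
load_pretrained_model(model, training_args.pretrained_model_path)
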
3 changes: 2 additions & 1 deletion paddlevlp/examples/blip2/run_eval.py
@@ -40,7 +40,7 @@
from paddlevlp.models.blip2.eva_vit import interpolate_pos_embed
from paddlevlp.processors.blip_processing import BlipImageProcessor, BlipTextProcessor
from paddlevlp.examples.blip2.utils import BlipCollator
from paddlevlp.examples.blip2.utils import load_pretrained_model
from paddlevlp.examples.blip2.utils import load_pretrained_model, LLM_LIST


@dataclass
@@ -243,6 +243,7 @@ def main():
logger.info("training_args.use_hybrid_parallel:{}".format(
training_args.use_hybrid_parallel))
# create trainer
load_pretrained_model(model, LLM_LIST[model_args.llm_name])
load_pretrained_model(model, training_args.pretrained_model_path)
trainer = Trainer(
model=model,
6 changes: 5 additions & 1 deletion paddlevlp/examples/blip2/run_pretrain_stage2.py
@@ -40,7 +40,7 @@
from paddlevlp.models.blip2.eva_vit import interpolate_pos_embed
from paddlevlp.processors.blip_processing import BlipImageProcessor, BlipTextProcessor
from paddlevlp.examples.blip2.utils import BlipCollator
from paddlevlp.examples.blip2.utils import load_pretrained_model
from paddlevlp.examples.blip2.utils import load_pretrained_model, LLM_LIST


@dataclass
@@ -79,6 +79,9 @@ class ModelArguments:
image_size: int = field(
default=224,
metadata={"help": " Image size for training. (default:224)"})
llm_name: str = field(
default="opt-2.7b",
metadata={"help": "llm name which you ned to load in LLM_LIST"})


@dataclass
@@ -244,6 +247,7 @@ def main():
logger.info("training_args.use_hybrid_parallel:{}".format(
training_args.use_hybrid_parallel))
# create trainer
load_pretrained_model(model.language_model, LLM_LIST[model_args.llm_name])
load_pretrained_model(model, training_args.pretrained_model_path)
trainer = Trainer(
model=model,
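
Since LLM_LIST[model_args.llm_name] raises a bare KeyError for an unknown name, a friendlier guard before the lookup could look like the sketch below (hypothetical, not part of this commit):

# Hypothetical validation before indexing LLM_LIST (not in the diff).
if model_args.llm_name not in LLM_LIST:
    raise ValueError(
        f"Unknown llm_name '{model_args.llm_name}'; "
        f"expected one of {sorted(LLM_LIST)}")
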
151 changes: 148 additions & 3 deletions paddlevlp/examples/blip2/utils.py
@@ -1,10 +1,48 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import tarfile
import shutil
import zipfile
import requests
from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
from paddlevlp.utils.downloader import get_weights_path_from_url
from paddlevlp.utils.downloader import is_url
from paddlevlp.utils.downloader import is_url, _decompress, _md5check, WEIGHTS_HOME, DOWNLOAD_RETRY_LIMIT
from paddlevlp.models.blip2.eva_vit import interpolate_pos_embed
import paddle
from paddlevlp.utils.downloader import is_url
from paddlevlp.utils.log import logger
from tqdm.auto import tqdm
LLM_LIST = {
"opt-2.7b":
"https://bj.bcebos.com/paddlenlp/models/community/facebook/opt-2.7b/model_state.pdparams",
"t5-small":
"https://bj.bcebos.com/paddlenlp/models/transformers/t5/t5-small/model_state.pdparams",
"t5-base":
"https://bj.bcebos.com/paddlenlp/models/transformers/t5/t5-base/model_state.pdparams",
"t5-large":
"https://bj.bcebos.com/paddlenlp/models/transformers/t5/t5-large/model_state.pdparams",
"t5-3b":
"https://bj.bcebos.com/paddlenlp/models/transformers/t5/t5-3b/model_state.pdparams",
"t5-11b":
"https://bj.bcebos.com/paddlenlp/models/transformers/t5/t5-11b/model_state.pdparams",
"t5-v1_1-base":
"https://bj.bcebos.com/paddlenlp/models/transformers/t5/t5-v1_1-base/model_state.pdparams",
"t5-v1_1-large":
"https://bj.bcebos.com/paddlenlp/models/transformers/t5/t5-v1_1-large/model_state.pdparams",
}
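
The registry maps a short model name to the URL of its .pdparams weights on bcebos; load_pretrained_model recognizes URL arguments and downloads them before loading. A quick illustration, using the helpers defined further down in this file:

# Resolving a registry entry to a cached local file.
url = LLM_LIST["opt-2.7b"]
# Weights are cached under WEIGHTS_HOME, keyed by the last two URL segments,
# e.g. <WEIGHTS_HOME>/opt-2.7b/model_state.pdparams
local_path = get_weights_path_from_url(url)
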


class BlipCollator:
@@ -82,6 +120,113 @@ def load_pretrained_model(model, pretrained_model_path):
else:
assert os.path.exists(
pretrained_model_path), f"{pretrained_model_path} not exist"
state_dict = paddle.load(path)
state_dict = paddle.load(path, return_numpy=True)
interpolate_pos_embed(model, state_dict)
model.set_state_dict(state_dict)
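
The switch to return_numpy=True is worth a note: paddle.load then returns numpy arrays instead of paddle tensors, presumably so the full checkpoint stays in host memory until set_state_dict copies it onto the device. A minimal sketch, assuming a checkpoint file on disk and a constructed model:

# return_numpy=True yields numpy arrays (host memory) rather than tensors;
# set_state_dict accepts either form.
state_dict = paddle.load("model_state.pdparams", return_numpy=True)
model.set_state_dict(state_dict)
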


def _map_path(url, root_dir):
# map the url to its local download path under root_dir
fname = osp.join(url.split("/")[-2], url.split("/")[-1])
fpath = fname
return osp.join(root_dir, fpath)
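
For reference, the mapping keeps the last two URL segments, so each model lands in its own subdirectory of root_dir; for example, with a hypothetical root_dir:

# url.split("/")[-2] becomes the subdirectory, url.split("/")[-1] the filename.
url = "https://bj.bcebos.com/paddlenlp/models/community/facebook/opt-2.7b/model_state.pdparams"
_map_path(url, "/tmp/weights")
# -> '/tmp/weights/opt-2.7b/model_state.pdparams'
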


def get_weights_path_from_url(url, md5sum=None):
"""Get weights path from WEIGHT_HOME, if not exists,
download it from url.
Args:
url (str): download url
md5sum (str): md5 sum of download package
Returns:
str: a local path to save downloaded weights.
Examples:
.. code-block:: python
from paddlevlp.examples.blip2.utils import get_weights_path_from_url
resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url)
"""
path = get_path_from_url(url, WEIGHTS_HOME, md5sum)
return path


def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
"""Download from given url to root_dir.
If the file or directory specified by url already exists under
root_dir, return the path directly; otherwise download it from
url, decompress it if needed, and return the path.
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
Returns:
str: a local path to save downloaded models & weights & datasets.
"""

assert is_url(url), "{} is not a valid url".format(url)
# parse path after download to decompress under root_dir
fullpath = _map_path(url, root_dir)

if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
logger.info("Found {}".format(fullpath))
else:
fullpath = _download(url, root_dir, md5sum)

if tarfile.is_tarfile(fullpath) or zipfile.is_zipfile(fullpath):
fullpath = _decompress(fullpath)

# model tokenizer config, [file-lock]
return fullpath
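
A short usage sketch: when the target already exists under root_dir and passes _md5check, the cached copy is returned without any network traffic; archives are decompressed and the extracted path is returned instead.

# Cache-aware fetch of one of the LLM_LIST checkpoints.
local_path = get_path_from_url(LLM_LIST["t5-small"], WEIGHTS_HOME)
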


def _download(url, path, md5sum=None):
"""
Download from url, save to path.
Args:
url (str): download url
path (str): download to given path
md5sum (str): md5 sum of the download package, optional
"""
path = osp.join(path, url.split("/")[-2])
os.makedirs(path, exist_ok=True)
fname = url.split("/")[-1]
fullname = osp.join(path, fname)
retry_cnt = 0

while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))

logger.info("Downloading {} from {}".format(fname, url))

req = requests.get(url, stream=True)
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))

# To guard against interrupted downloads, write to
# tmp_fullname first and move it to fullname once the
# download has finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get("content-length")
with open(tmp_fullname, "wb") as f:
if total_size:
with tqdm(
total=int(total_size),
unit="B",
unit_scale=True,
unit_divisor=1024) as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(len(chunk))
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)

return fullname
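
The loop retries while the file is missing or fails its md5 check, up to DOWNLOAD_RETRY_LIMIT attempts, and the write-to-tmp-then-move pattern ensures a half-finished download is never left at the final name. Called directly (normally it is reached via get_path_from_url), a sketch looks like:

# md5sum may be omitted to skip checksum verification.
fullname = _download(LLM_LIST["opt-2.7b"], WEIGHTS_HOME)
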
