Skip to content

Commit

Permalink
add automatic weight download (PaddlePaddle#146)
Browse files Browse the repository at this point in the history
* add automatic weight download
  • Loading branch information
lijianshe02 authored Jan 18, 2021
1 parent edd6211 commit 43ffc94
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 10 deletions.
11 changes: 5 additions & 6 deletions applications/tools/wav2lip.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@
parser.add_argument('--checkpoint_path',
type=str,
help='Name of saved checkpoint to load weights from',
required=True)

parser.add_argument('--face',
type=str,
help='Filepath of video/image that contains faces to use',
required=True)
default=None)
parser.add_argument(
'--audio',
type=str,
help='Filepath of video/audio file to use as raw audio source',
required=True)
parser.add_argument('--face',
type=str,
help='Filepath of video/image that contains faces to use',
required=True)
parser.add_argument('--outfile',
type=str,
help='Video path to save result. See default for an e.g.',
Expand Down
8 changes: 7 additions & 1 deletion ppgan/apps/wav2lip_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from tqdm import tqdm
from glob import glob
import paddle
from paddle.utils.download import get_weights_path_from_url
from ppgan.faceutils import face_detection
from ppgan.utils import audio
from ppgan.models.generators.wav2lip import Wav2Lip
from .base_predictor import BasePredictor

WAV2LIP_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/wav2lip_hq.pdparams'
mel_step_size = 16


Expand Down Expand Up @@ -216,7 +218,11 @@ def run(self):
gen = self.datagen(full_frames.copy(), mel_chunks)

model = Wav2Lip()
weights = paddle.load(self.args.checkpoint_path)
if self.args.checkpoint_path is None:
model_weights_path = get_weights_path_from_url(WAV2LIP_WEIGHT_URL)
weights = paddle.load(model_weights_path)
else:
weights = paddle.load(self.args.checkpoint_path)
model.load_dict(weights)
model.eval()
print("Model loaded")
Expand Down
6 changes: 4 additions & 2 deletions ppgan/models/wav2lip_hq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import paddle
import paddle.nn.functional as F
from paddle.utils.download import get_weights_path_from_url
from .base_model import BaseModel

from .builder import MODELS
Expand All @@ -25,7 +26,7 @@
from ..solver import build_optimizer
from ..modules.init import init_weights

lipsync_weight_path = '/workspace/PaddleGAN/lipsync_expert.pdparams'
SYNCNET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/syncnet.pdparams'


@MODELS.register()
Expand Down Expand Up @@ -65,7 +66,8 @@ def __init__(self,
distribution='uniform')
if self.is_train:
self.nets['netDS'] = build_discriminator(discriminator_sync)
params = paddle.load(lipsync_weight_path)
weights_path = get_weights_path_from_url(SYNCNET_WEIGHT_URL)
params = paddle.load(weights_path)
self.nets['netDS'].load_dict(params)

self.nets['netDH'] = build_discriminator(discriminator_hq)
Expand Down
5 changes: 4 additions & 1 deletion ppgan/models/wav2lip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import paddle
from paddle.utils.download import get_weights_path_from_url
from .base_model import BaseModel

from .builder import MODELS
Expand All @@ -22,6 +23,7 @@
from ..solver import build_optimizer
from ..modules.init import init_weights

SYNCNET_WEIGHT_URL = 'https://paddlegan.bj.bcebos.com/models/syncnet.pdparams'
syncnet_T = 5
syncnet_mel_step_size = 16

Expand Down Expand Up @@ -74,7 +76,8 @@ def __init__(self,
init_weights(self.nets['netG'], distribution='uniform')
if self.is_train:
self.nets['netD'] = build_discriminator(discriminator)
params = paddle.load(lipsync_weight_path)
weights_path = get_weights_path_from_url(SYNCNET_WEIGHT_URL)
params = paddle.load(weights_path)
self.nets['netD'].load_dict(params)

if self.is_train:
Expand Down

0 comments on commit 43ffc94

Please sign in to comment.