diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py
index ce8821bd..f7bef133 100644
--- a/GPT_SoVITS/export_torch_script.py
+++ b/GPT_SoVITS/export_torch_script.py
@@ -330,11 +330,12 @@ def decode_next_token(
         for i in range(self.num_blocks):
             x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
         return x, k_cache, v_cache
-    
+
 class VitsModel(nn.Module):
     def __init__(self, vits_path):
         super().__init__()
-        dict_s2 = torch.load(vits_path,map_location="cpu")
+        # dict_s2 = torch.load(vits_path,map_location="cpu")
+        dict_s2 = torch.load(vits_path)
         self.hps = dict_s2["config"]
         if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
             self.hps["model"]["version"] = "v1"
@@ -527,7 +528,7 @@ def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
     phone_level_feature = torch.cat(phone_level_feature, dim=0)  # [sum(word2ph), 1024]
     return phone_level_feature
-    
+
 class MyBertModel(torch.nn.Module):
     def __init__(self, bert_model):
         super(MyBertModel, self).__init__()
@@ -535,7 +536,8 @@ def __init__(self, bert_model):
     def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor):
         outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
-        res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
+        # res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
+        res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
         return build_phone_level_feature(res, word2ph)
 
 class SSLModel(torch.nn.Module):
@@ -560,13 +562,20 @@ def resample(self,ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
         audio = resamplex(ref_audio,src_sr,dst_sr).float()
         return audio
 
-def export_bert(ref_bert_inputs):
+def export_bert(output_path):
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
-    ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
-    ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
+    text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么."
+    ref_bert_inputs = tokenizer(text, return_tensors="pt")
+    word2ph = []
+    for c in text:
+        if c in [',','。',':','?',",",".","?"]:
+            word2ph.append(1)
+        else:
+            word2ph.append(2)
+    ref_bert_inputs['word2ph'] = torch.Tensor(word2ph).int()
 
-    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
+    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
     my_bert_model = MyBertModel(bert_model)
 
     ref_bert_inputs = {
@@ -576,13 +585,17 @@ def export_bert(ref_bert_inputs):
         'word2ph': ref_bert_inputs['word2ph']
     }
 
+    torch._dynamo.mark_dynamic(ref_bert_inputs['input_ids'], 1)
+    torch._dynamo.mark_dynamic(ref_bert_inputs['attention_mask'], 1)
+    torch._dynamo.mark_dynamic(ref_bert_inputs['token_type_ids'], 1)
+    torch._dynamo.mark_dynamic(ref_bert_inputs['word2ph'], 0)
+
     my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
-    my_bert_model.save("onnx/bert_model.pt")
+    output_path = os.path.join(output_path, "bert_model.pt")
+    my_bert_model.save(output_path)
     print('#### exported bert ####')
-
-def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path):
-    # export_bert(ref_bert_inputs)
+def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device='cpu'):
     if not os.path.exists(output_path):
         os.makedirs(output_path)
         print(f"目录已创建: {output_path}")
@@ -591,45 +604,57 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path):
     ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
 
     ssl = SSLModel()
-    s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
-    ssl_path = os.path.join(output_path, "ssl_model.pt")
-    torch.jit.script(s).save(ssl_path)
-    print('#### exported ssl ####')
+    if export_bert_and_ssl:
+        s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
+        ssl_path = os.path.join(output_path, "ssl_model.pt")
+        torch.jit.script(s).save(ssl_path)
+        print('#### exported ssl ####')
+        export_bert(output_path)
+    else:
+        s = ExportSSLModel(ssl)
+
+    print(f"device: {device}")
+
     ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
-    ref_seq = torch.LongTensor([ref_seq_id])
+    ref_seq = torch.LongTensor([ref_seq_id]).to(device)
     ref_bert = ref_bert_T.T.to(ref_seq.device)
     text_seq_id,text_bert_T,norm_text = get_phones_and_bert("这是一条测试语音,说什么无所谓,只是给它一个例子","all_zh",'v2')
-    text_seq = torch.LongTensor([text_seq_id])
+    text_seq = torch.LongTensor([text_seq_id]).to(device)
     text_bert = text_bert_T.T.to(text_seq.device)
 
-    ssl_content = ssl(ref_audio)
+    ssl_content = ssl(ref_audio).to(device)
 
     # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
-    vits = VitsModel(vits_path)
+    vits = VitsModel(vits_path).to(device)
     vits.eval()
 
     # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
-    dict_s1 = torch.load(gpt_path, map_location="cpu")
-    raw_t2s = get_raw_t2s_model(dict_s1)
+    # dict_s1 = torch.load(gpt_path, map_location=device)
+    dict_s1 = torch.load(gpt_path)
+    raw_t2s = get_raw_t2s_model(dict_s1).to(device)
     print('#### get_raw_t2s_model ####')
     print(raw_t2s.config)
 
     t2s_m = T2SModel(raw_t2s)
     t2s_m.eval()
-    t2s = torch.jit.script(t2s_m)
+    t2s = torch.jit.script(t2s_m).to(device)
     print('#### script t2s_m ####')
 
     print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
-    gpt_sovits = GPT_SoVITS(t2s,vits)
+    gpt_sovits = GPT_SoVITS(t2s,vits).to(device)
     gpt_sovits.eval()
 
-    ref_audio_sr = s.resample(ref_audio,16000,32000)
-    ref_audio_sr = s.resample(ref_audio,16000,32000)
-    print('ref_audio_sr:',ref_audio_sr.shape)
-
-    ref_audio_sr = s.resample(ref_audio,16000,32000)
-    print('ref_audio_sr:',ref_audio_sr.shape)
-    gpt_sovits_export = torch.jit.trace(
+    ref_audio_sr = s.resample(ref_audio,16000,32000).to(device)
+
+    torch._dynamo.mark_dynamic(ssl_content, 2)
+    torch._dynamo.mark_dynamic(ref_audio_sr, 1)
+    torch._dynamo.mark_dynamic(ref_seq, 1)
+    torch._dynamo.mark_dynamic(text_seq, 1)
+    torch._dynamo.mark_dynamic(ref_bert, 0)
+    torch._dynamo.mark_dynamic(text_bert, 0)
+
+    with torch.no_grad():
+        gpt_sovits_export = torch.jit.trace(
         gpt_sovits,
         example_inputs=(
             ssl_content,
@@ -639,9 +664,9 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path):
             ref_bert,
             text_bert))
 
-    gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
-    gpt_sovits_export.save(gpt_sovits_path)
-    print('#### exported gpt_sovits ####')
+        gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
+        gpt_sovits_export.save(gpt_sovits_path)
+        print('#### exported gpt_sovits ####')
 
 @torch.jit.script
 def parse_audio(ref_audio):
@@ -674,6 +699,8 @@ def test():
     parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
     parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
     parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
+    parser.add_argument('--output_path', required=True, help="Path to the output directory")
+
     args = parser.parse_args()
 
     gpt_path = args.gpt_model
@@ -682,42 +709,63 @@ def test():
     ref_text = args.ref_text
 
     tokenizer = AutoTokenizer.from_pretrained(bert_path)
-    bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
-    bert = MyBertModel(bert_model)
-    # bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
-
-    # gpt_path = "GPT_weights_v2/xw-e15.ckpt"
-    dict_s1 = torch.load(gpt_path, map_location="cpu")
-    raw_t2s = get_raw_t2s_model(dict_s1)
-    t2s = T2SModel(raw_t2s)
-    t2s.eval()
+    # bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
+    # bert = MyBertModel(bert_model)
+    my_bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
+
+    # dict_s1 = torch.load(gpt_path, map_location="cuda")
+    # raw_t2s = get_raw_t2s_model(dict_s1)
+    # t2s = T2SModel(raw_t2s)
+    # t2s.eval()
     # t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')
 
     # vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
-    vits = VitsModel(vits_path)
-    vits.eval()
-
-    ssl = ExportSSLModel(SSLModel())
-    ssl.eval()
+    # vits = VitsModel(vits_path)
+    # vits.eval()
 
-    gpt_sovits = GPT_SoVITS(t2s,vits)
+    # ssl = ExportSSLModel(SSLModel()).to('cuda')
+    # ssl.eval()
+    ssl = torch.jit.load("onnx/by/ssl_model.pt",map_location='cuda')
 
-    # vits = torch.jit.load("onnx/xw/vits_model.pt",map_location='cuda')
-    # ssl = torch.jit.load("onnx/xw/ssl_model.pt",map_location='cuda')
+    # gpt_sovits = GPT_SoVITS(t2s,vits)
+    gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt",map_location='cuda')
 
     ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
     ref_seq = torch.LongTensor([ref_seq_id])
     ref_bert = ref_bert_T.T.to(ref_seq.device)
-    text_seq_id,text_bert_T,norm_text = get_phones_and_bert("问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么。","all_zh",'v2')
+    # text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
+    text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."
+
+    text_seq_id,text_bert_T,norm_text = get_phones_and_bert(text,"all_zh",'v2')
+
+    test_bert = tokenizer(text, return_tensors="pt")
+    word2ph = []
+    for c in text:
+        if c in [',','。',':','?',"?",",","."]:
+            word2ph.append(1)
+        else:
+            word2ph.append(2)
+    test_bert['word2ph'] = torch.Tensor(word2ph).int()
+
+    test_bert = my_bert(
+        test_bert['input_ids'].to('cuda'),
+        test_bert['attention_mask'].to('cuda'),
+        test_bert['token_type_ids'].to('cuda'),
+        test_bert['word2ph'].to('cuda')
+    )
+
     text_seq = torch.LongTensor([text_seq_id])
-    print('text_seq:',text_seq_id)
     text_bert = text_bert_T.T.to(text_seq.device)
-    # text_bert = torch.zeros(text_bert.shape, dtype=text_bert.dtype).to(text_bert.device)
+
+    print('text_bert:',text_bert.shape,text_bert)
+    print('test_bert:',test_bert.shape,test_bert)
+    print(torch.allclose(text_bert.to('cuda'),test_bert))
+
     print('text_seq:',text_seq.shape)
-    print('text_bert:',text_bert.shape)
+    print('text_bert:',text_bert.shape,text_bert.type())
 
     #[1,N]
-    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
+    ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to('cuda')
     print('ref_audio:',ref_audio.shape)
 
     ref_audio_sr = ssl.resample(ref_audio,16000,32000)
@@ -725,13 +773,22 @@ def test():
     ssl_content = ssl(ref_audio)
 
     print('start gpt_sovits:')
+    print('ssl_content:',ssl_content.shape)
+    print('ref_audio_sr:',ref_audio_sr.shape)
+    print('ref_seq:',ref_seq.shape)
+    ref_seq=ref_seq.to('cuda')
+    print('text_seq:',text_seq.shape)
+    text_seq=text_seq.to('cuda')
+    print('ref_bert:',ref_bert.shape)
+    ref_bert=ref_bert.to('cuda')
+    print('text_bert:',text_bert.shape)
+    text_bert=text_bert.to('cuda')
+
     with torch.no_grad():
-        audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)
+        audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert)
     print('start write wav')
     soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
 
-    # audio = vits(text_seq, pred_semantic1, ref_audio)
-    soundfile.write("out.wav", audio, 32000)
 
 import text
 import json
 
@@ -753,12 +810,23 @@ def main():
     parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
     parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
     parser.add_argument('--output_path', required=True, help="Path to the output directory")
+    parser.add_argument('--export_common_model', action='store_true', help="Export Bert and SSL model")
+    parser.add_argument('--device', help="Device to use")
 
     args = parser.parse_args()
-    export(gpt_path=args.gpt_model, vits_path=args.sovits_model, ref_audio_path=args.ref_audio, ref_text=args.ref_text, output_path=args.output_path)
+    export(
+        gpt_path=args.gpt_model,
+        vits_path=args.sovits_model,
+        ref_audio_path=args.ref_audio,
+        ref_text=args.ref_text,
+        output_path=args.output_path,
+        device=args.device,
+        export_bert_and_ssl=args.export_common_model,
+    )
 
 import inference_webui
 
 if __name__ == "__main__":
     inference_webui.is_half=False
     inference_webui.dtype=torch.float32
-    main()
\ No newline at end of file
+    main()
+    # test()
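A few notes on the patterns this patch relies on, each with a small self-contained sketch.

First, the `outputs[1]` change in `MyBertModel.forward`: loading a Hugging Face model with `torchscript=True` makes `from_pretrained` return plain tuples instead of a keyed `ModelOutput`, so the hidden states must be fetched positionally. A minimal sketch of the equivalence; the concrete checkpoint name is an assumption (any BERT-style MLM behaves the same), and the hidden states landing at index 1 matches what this diff relies on:

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

bert_path = "hfl/chinese-roberta-wwm-ext-large"  # assumption: the checkpoint GPT-SoVITS uses
tokenizer = AutoTokenizer.from_pretrained(bert_path)
inputs = tokenizer("声音,是有温度的", return_tensors="pt")

dict_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True)
tuple_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)

with torch.no_grad():
    d = dict_model(**inputs)   # ModelOutput: d["hidden_states"] is a per-layer tuple
    t = tuple_model(**inputs)  # plain tuple: (logits, hidden_states), so index 1

# [-3:-2] keeps exactly the third-to-last layer, as in the diff
print(torch.allclose(d["hidden_states"][-3], t[1][-3]))  # expected: True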
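Second, the `word2ph` loops added in `export_bert` and `test`: punctuation contributes one phone, every other character (hanzi) two, and `build_phone_level_feature` repeats each character's BERT feature that many times to produce phone-level features. The repo's loop-and-`cat` can be expressed as a `repeat_interleave` one-liner; a sketch assuming this model's 1024-dim features (the helper names here are mine, not the repo's):

import torch

def build_word2ph(text: str) -> torch.Tensor:
    # Punctuation -> 1 phone; each hanzi -> 2 (initial + final), mirroring the diff's heuristic.
    puncts = {',', '。', ':', '?', ',', '.', '?'}
    return torch.tensor([1 if c in puncts else 2 for c in text], dtype=torch.int)

def expand_to_phone_level(res: torch.Tensor, word2ph: torch.Tensor) -> torch.Tensor:
    # res: [len(text), 1024] character-level features; repeat row i word2ph[i]
    # times -> [sum(word2ph), 1024], same result as the repo's per-row repeat + cat.
    return torch.repeat_interleave(res, word2ph.long(), dim=0)

text = "木兰对着房门织布."
word2ph = build_word2ph(text)
res = torch.randn(len(text), 1024)                 # stand-in for the BERT output
print(expand_to_phone_level(res, word2ph).shape)   # torch.Size([17, 1024]): 8 hanzi * 2 + 1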
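Third, the `torch._dynamo.mark_dynamic` calls ahead of each `torch.jit.trace`: they flag the axes that vary between calls (token-sequence length, the audio time axis) so the example shapes are not taken as fixed. `mark_dynamic` is documented for the dynamo/`torch.compile` stack, so how much plain `torch.jit.trace` honors it depends on the PyTorch version; treat the calls as shape hints. A toy sketch of the pattern (the `Toy` module is made up; only the export pattern matches the diff):

import torch

class Toy(torch.nn.Module):
    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        return ids.float().mean(dim=1)

m = Toy().eval()
example = torch.zeros(1, 37, dtype=torch.long)  # arbitrary example length
torch._dynamo.mark_dynamic(example, 1)          # axis 1 = sequence length varies
traced = torch.jit.trace(m, example_inputs=(example,))

# Tracing records shape-agnostic tensor ops, so a different length still runs;
# the marks document which axes are expected to change.
print(traced(torch.zeros(1, 99, dtype=torch.long)).shape)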
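Finally, the consumer side that `test()` now exercises: once `--export_common_model` has produced `bert_model.pt` and `ssl_model.pt`, and a per-voice run has produced `gpt_sovits_model.pt`, inference needs only `torch.jit.load` plus the text front end. A condensed sketch of that flow using the paths and 32 kHz output rate from the diff; tensor preparation is elided (see `test()` above for the full version), and the `synthesize` wrapper is mine:

import torch
import soundfile

# Paths as used in test(); assumes the export steps above have already run.
ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")
gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")

def synthesize(ref_audio_16k, ref_seq, text_seq, ref_bert, text_bert):
    # ref_audio_16k: [1, T] float waveform on CUDA; the remaining tensors come
    # from get_phones_and_bert / the traced BERT, already moved to CUDA.
    ref_audio_sr = ssl.resample(ref_audio_16k, 16000, 32000)
    ssl_content = ssl(ref_audio_16k)
    with torch.no_grad():
        audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)
    soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)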