Optimize export_torch_script.py (RVC-Boss#1739)
L-jasmine authored Nov 7, 2024
1 parent 6d82050 commit a70e1ad
Showing 1 changed file with 129 additions and 61 deletions.
190 changes: 129 additions & 61 deletions GPT_SoVITS/export_torch_script.py
@@ -330,11 +330,12 @@ def decode_next_token(
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
return x, k_cache, v_cache

class VitsModel(nn.Module):
def __init__(self, vits_path):
super().__init__()
dict_s2 = torch.load(vits_path,map_location="cpu")
# dict_s2 = torch.load(vits_path,map_location="cpu")
dict_s2 = torch.load(vits_path)
self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
self.hps["model"]["version"] = "v1"
@@ -527,15 +528,16 @@ def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
phone_level_feature = torch.cat(phone_level_feature, dim=0)
# [sum(word2ph), 1024]
return phone_level_feature
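# Note: the loop-and-cat above is equivalent to one vectorized call;
# a minimal sketch of the same computation:
#     phone_level_feature = torch.repeat_interleave(res, word2ph.long(), dim=0)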

class MyBertModel(torch.nn.Module):
def __init__(self, bert_model):
super(MyBertModel, self).__init__()
self.bert = bert_model

def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
# res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
return build_phone_level_feature(res, word2ph)
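# Note: indexing outputs[1] works because the model is loaded with
# torchscript=True (see export_bert below). In that mode Hugging Face models
# return plain tuples (logits first, then the hidden-states tuple when
# output_hidden_states=True) instead of the ModelOutput dicts that
# torch.jit.trace handles poorly.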

class SSLModel(torch.nn.Module):
@@ -560,13 +562,20 @@ def resample(self,ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
audio = resamplex(ref_audio,src_sr,dst_sr).float()
return audio

def export_bert(ref_bert_inputs):
def export_bert(output_path):
tokenizer = AutoTokenizer.from_pretrained(bert_path)

ref_bert_inputs = tokenizer("声音,是有温度的.夜晚的声音,会发光", return_tensors="pt")
ref_bert_inputs['word2ph'] = torch.Tensor([2,2,1,2,2,2,2,2,1,2,2,2,2,2,1,2,2,2]).int()
text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么."
ref_bert_inputs = tokenizer(text, return_tensors="pt")
word2ph = []
for c in text:
if c in [',','。',':','?',",",".","?"]:
word2ph.append(1)
else:
word2ph.append(2)
ref_bert_inputs['word2ph'] = torch.Tensor(word2ph).int()
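# Note: this word2ph mapping is a fixed heuristic for the all_zh path:
# punctuation counts as one phone, every other character as two (an assumption
# that each Chinese character expands to an initial + final pair). The same
# rule as a standalone helper (name is mine):
#     def word2ph_for_zh(text: str) -> torch.IntTensor:
#         puncts = {',', '。', ':', '?', ',', '.', '?'}
#         return torch.tensor([1 if c in puncts else 2 for c in text],
#                             dtype=torch.int32)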

bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
my_bert_model = MyBertModel(bert_model)

ref_bert_inputs = {
@@ -576,13 +585,17 @@ def export_bert(ref_bert_inputs):
'word2ph': ref_bert_inputs['word2ph']
}

torch._dynamo.mark_dynamic(ref_bert_inputs['input_ids'], 1)
torch._dynamo.mark_dynamic(ref_bert_inputs['attention_mask'], 1)
torch._dynamo.mark_dynamic(ref_bert_inputs['token_type_ids'], 1)
torch._dynamo.mark_dynamic(ref_bert_inputs['word2ph'], 0)

my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
my_bert_model.save("onnx/bert_model.pt")
output_path = os.path.join(output_path, "bert_model.pt")
my_bert_model.save(output_path)
print('#### exported bert ####')

def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path):
# export_bert(ref_bert_inputs)

def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device='cpu'):
if not os.path.exists(output_path):
os.makedirs(output_path)
print(f"目录已创建: {output_path}")
@@ -591,45 +604,57 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path):

ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
ssl = SSLModel()
s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
ssl_path = os.path.join(output_path, "ssl_model.pt")
torch.jit.script(s).save(ssl_path)
print('#### exported ssl ####')
if export_bert_and_ssl:
s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
ssl_path = os.path.join(output_path, "ssl_model.pt")
torch.jit.script(s).save(ssl_path)
print('#### exported ssl ####')
export_bert(output_path)
else:
s = ExportSSLModel(ssl)

print(f"device: {device}")


ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
ref_seq = torch.LongTensor([ref_seq_id])
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
ref_bert = ref_bert_T.T.to(ref_seq.device)
text_seq_id,text_bert_T,norm_text = get_phones_and_bert("这是一条测试语音,说什么无所谓,只是给它一个例子","all_zh",'v2')
text_seq = torch.LongTensor([text_seq_id])
text_seq = torch.LongTensor([text_seq_id]).to(device)
text_bert = text_bert_T.T.to(text_seq.device)

ssl_content = ssl(ref_audio)
ssl_content = ssl(ref_audio).to(device)

# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path)
vits = VitsModel(vits_path).to(device)
vits.eval()

# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
dict_s1 = torch.load(gpt_path, map_location="cpu")
raw_t2s = get_raw_t2s_model(dict_s1)
# dict_s1 = torch.load(gpt_path, map_location=device)
dict_s1 = torch.load(gpt_path)
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
print('#### get_raw_t2s_model ####')
print(raw_t2s.config)
t2s_m = T2SModel(raw_t2s)
t2s_m.eval()
t2s = torch.jit.script(t2s_m)
t2s = torch.jit.script(t2s_m).to(device)
print('#### script t2s_m ####')
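# Note: t2s_m is compiled with torch.jit.script rather than trace because its
# autoregressive decode loop is data-dependent control flow that tracing would
# unroll; when the combined module is traced below, the scripted submodule is
# kept as-is inside the trace.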

print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
gpt_sovits = GPT_SoVITS(t2s,vits)
gpt_sovits = GPT_SoVITS(t2s,vits).to(device)
gpt_sovits.eval()
ref_audio_sr = s.resample(ref_audio,16000,32000)
print('ref_audio_sr:',ref_audio_sr.shape)

gpt_sovits_export = torch.jit.trace(
ref_audio_sr = s.resample(ref_audio,16000,32000).to(device)

torch._dynamo.mark_dynamic(ssl_content, 2)
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
torch._dynamo.mark_dynamic(ref_seq, 1)
torch._dynamo.mark_dynamic(text_seq, 1)
torch._dynamo.mark_dynamic(ref_bert, 0)
torch._dynamo.mark_dynamic(text_bert, 0)

with torch.no_grad():
gpt_sovits_export = torch.jit.trace(
gpt_sovits,
example_inputs=(
ssl_content,
@@ -639,9 +664,9 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path):
ref_bert,
text_bert))

gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
gpt_sovits_export.save(gpt_sovits_path)
print('#### exported gpt_sovits ####')
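The mark_dynamic calls are what keep the traced graph usable beyond the example inputs: the flagged dimensions (SSL feature length, audio length, the two sequence lengths, the two BERT lengths) are treated as variable rather than baked into gpt_sovits_model.pt, and tracing under no_grad keeps autograd state out of the export. A minimal self-contained sketch of the pattern, with a toy module standing in for the real ones:

import torch

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.sum(dim=1, keepdim=True)

example = torch.randn(1, 50)               # 50 is an arbitrary example length
torch._dynamo.mark_dynamic(example, 1)     # hint: dim 1 may vary at runtime
with torch.no_grad():                      # trace without autograd, as above
    traced = torch.jit.trace(Toy(), example_inputs=(example,))
traced.save("toy_model.pt")                # illustrative path
print(traced(torch.randn(1, 75)).shape)    # runs with a different length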

@torch.jit.script
def parse_audio(ref_audio):
@@ -674,6 +699,8 @@ def test():
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
parser.add_argument('--output_path', required=True, help="Path to the output directory")


args = parser.parse_args()
gpt_path = args.gpt_model
@@ -682,56 +709,86 @@ def test():
ref_text = args.ref_text

tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True)
bert = MyBertModel(bert_model)
# bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')

# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
dict_s1 = torch.load(gpt_path, map_location="cpu")
raw_t2s = get_raw_t2s_model(dict_s1)
t2s = T2SModel(raw_t2s)
t2s.eval()
# bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
# bert = MyBertModel(bert_model)
my_bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')

# dict_s1 = torch.load(gpt_path, map_location="cuda")
# raw_t2s = get_raw_t2s_model(dict_s1)
# t2s = T2SModel(raw_t2s)
# t2s.eval()
# t2s = torch.jit.load("onnx/xw/t2s_model.pt",map_location='cuda')

# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
vits = VitsModel(vits_path)
vits.eval()

ssl = ExportSSLModel(SSLModel())
ssl.eval()
# vits = VitsModel(vits_path)
# vits.eval()

gpt_sovits = GPT_SoVITS(t2s,vits)
# ssl = ExportSSLModel(SSLModel()).to('cuda')
# ssl.eval()
ssl = torch.jit.load("onnx/by/ssl_model.pt",map_location='cuda')

# vits = torch.jit.load("onnx/xw/vits_model.pt",map_location='cuda')
# ssl = torch.jit.load("onnx/xw/ssl_model.pt",map_location='cuda')
# gpt_sovits = GPT_SoVITS(t2s,vits)
gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt",map_location='cuda')

ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
ref_seq = torch.LongTensor([ref_seq_id])
ref_bert = ref_bert_T.T.to(ref_seq.device)
text_seq_id,text_bert_T,norm_text = get_phones_and_bert("问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么。","all_zh",'v2')
# text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."

text_seq_id,text_bert_T,norm_text = get_phones_and_bert(text,"all_zh",'v2')

test_bert = tokenizer(text, return_tensors="pt")
word2ph = []
for c in text:
if c in [',','。',':','?',"?",",","."]:
word2ph.append(1)
else:
word2ph.append(2)
test_bert['word2ph'] = torch.Tensor(word2ph).int()

test_bert = my_bert(
test_bert['input_ids'].to('cuda'),
test_bert['attention_mask'].to('cuda'),
test_bert['token_type_ids'].to('cuda'),
test_bert['word2ph'].to('cuda')
)

text_seq = torch.LongTensor([text_seq_id])
print('text_seq:',text_seq_id)
text_bert = text_bert_T.T.to(text_seq.device)
# text_bert = torch.zeros(text_bert.shape, dtype=text_bert.dtype).to(text_bert.device)

print('text_bert:',text_bert.shape,text_bert)
print('test_bert:',test_bert.shape,test_bert)
print(torch.allclose(text_bert.to('cuda'),test_bert))

print('text_seq:',text_seq.shape)
print('text_bert:',text_bert.shape)
print('text_bert:',text_bert.shape,text_bert.type())

#[1,N]
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to('cuda')
print('ref_audio:',ref_audio.shape)

ref_audio_sr = ssl.resample(ref_audio,16000,32000)
print('start ssl')
ssl_content = ssl(ref_audio)

print('start gpt_sovits:')
print('ssl_content:',ssl_content.shape)
print('ref_audio_sr:',ref_audio_sr.shape)
print('ref_seq:',ref_seq.shape)
ref_seq=ref_seq.to('cuda')
print('text_seq:',text_seq.shape)
text_seq=text_seq.to('cuda')
print('ref_bert:',ref_bert.shape)
ref_bert=ref_bert.to('cuda')
print('text_bert:',text_bert.shape)
text_bert=text_bert.to('cuda')

with torch.no_grad():
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert)
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert)
print('start write wav')
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)

# audio = vits(text_seq, pred_semantic1, ref_audio)
# soundfile.write("out.wav", audio, 32000)

import text
import json
@@ -753,12 +810,23 @@ def main():
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
parser.add_argument('--output_path', required=True, help="Path to the output directory")
parser.add_argument('--export_common_model', action='store_true', help="Export Bert and SSL model")
parser.add_argument('--device', help="Device to use")

args = parser.parse_args()
export(gpt_path=args.gpt_model, vits_path=args.sovits_model, ref_audio_path=args.ref_audio, ref_text=args.ref_text, output_path=args.output_path)
export(
gpt_path=args.gpt_model,
vits_path=args.sovits_model,
ref_audio_path=args.ref_audio,
ref_text=args.ref_text,
output_path=args.output_path,
device=args.device,
export_bert_and_ssl=args.export_common_model,
)
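# For reference, the equivalent direct call with illustrative paths (the two
# checkpoint paths appear in comments earlier in this file):
#     export(
#         gpt_path="GPT_weights_v2/xw-e15.ckpt",
#         vits_path="SoVITS_weights_v2/xw_e8_s216.pth",
#         ref_audio_path="ref.wav",
#         ref_text="声音,是有温度的。",
#         output_path="onnx/by",
#         export_bert_and_ssl=True,
#         device="cuda",
#     )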

import inference_webui
if __name__ == "__main__":
inference_webui.is_half=False
inference_webui.dtype=torch.float32
main()
# test()
