Skip to content

Commit

Permalink
Parler tts compatibility due to last updates (#2278)
Browse files Browse the repository at this point in the history
Support compatibility with last updates in Parler TTS without support of
static cache.
  • Loading branch information
aleksandr-mokrov committed Aug 9, 2024
1 parent 51c21ea commit 9bde900
Showing 1 changed file with 10 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,8 @@
"source": [
"from collections import namedtuple\n",
"\n",
"import torch.nn as nn\n",
"\n",
"EncoderOutput = namedtuple(\"EncoderOutput\", \"last_hidden_state\")\n",
"DecoderOutput = namedtuple(\"DecoderOutput\", (\"last_hidden_state\", \"past_key_values\", \"hidden_states\", \"attentions\", \"cross_attentions\"))\n",
"\n",
Expand All @@ -369,10 +371,14 @@
"\n",
"\n",
"class DecoderWrapper(torch.nn.Module):\n",
" def __init__(self, decoder_stage_1_ir_path, decoder_stage_2_ir_path):\n",
" def __init__(self, decoder_stage_1_ir_path, decoder_stage_2_ir_path, config):\n",
" super().__init__()\n",
" self.decoder_stage_1 = core.compile_model(decoder_stage_1_ir_path, device.value)\n",
" self.decoder_stage_2 = core.compile_model(decoder_stage_2_ir_path, device.value)\n",
" self.config = config\n",
" self.embed_tokens = None\n",
" embed_dim = config.vocab_size + 1 # + 1 for pad token id\n",
" self.embed_tokens = nn.ModuleList([nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)])\n",
"\n",
" def __call__(self, input_ids=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_values=None, prompt_hidden_states=None, **kwargs):\n",
" inputs = {}\n",
Expand Down Expand Up @@ -419,7 +425,9 @@
"outputs": [],
"source": [
"model.text_encoder = TextEncoderModelWrapper(TEXT_ENCODER_OV_PATH, model.text_encoder.config)\n",
"model.decoder.model.decoder = DecoderWrapper(DECODER_STAGE_1_OV_PATH, DECODER_STAGE_2_OV_PATH)"
"model.decoder.model.decoder = DecoderWrapper(DECODER_STAGE_1_OV_PATH, DECODER_STAGE_2_OV_PATH, model.decoder.model.decoder.config)\n",
"model._supports_cache_class = False\n",
"model._supports_static_cache = False"
]
},
{
Expand Down

0 comments on commit 9bde900

Please sign in to comment.