From eb875a2a2f2acb36fcfd9054746a81b655485117 Mon Sep 17 00:00:00 2001 From: Mddct Date: Fri, 8 Nov 2024 11:22:54 +0800 Subject: [PATCH] fix(export_onnx_gpu): rebuild att_cache by concatenating the (key, value) cache pair returned by the encoder layer --- wenet/bin/export_onnx_gpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wenet/bin/export_onnx_gpu.py b/wenet/bin/export_onnx_gpu.py index 568ada358..19a7b8655 100644 --- a/wenet/bin/export_onnx_gpu.py +++ b/wenet/bin/export_onnx_gpu.py @@ -171,7 +171,7 @@ def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, i_kv_cache = att_cache[i:i + 1] size = att_cache.size(-1) // 2 kv_cache = (i_kv_cache[:, :, :, :size], i_kv_cache[:, :, :, size:]) - xs, _, new_att_cache, new_cnn_cache = layer( + xs, _, new_kv_cache, new_cnn_cache = layer( xs, masks, pos_emb, @@ -180,6 +180,7 @@ def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache, ) # shape(new_att_cache) is (B, head, attention_key_size, d_k * 2), # shape(new_cnn_cache) is (B, hidden-dim, cache_t2) + new_att_cache = torch.cat(new_kv_cache, dim=-1) r_att_cache.append( new_att_cache[:, :, next_cache_start:, :].unsqueeze(1)) if not self.transformer: