diff --git a/wenet/bin/export_onnx_gpu.py b/wenet/bin/export_onnx_gpu.py
index ab6c1dbe0..e8f3fbf76 100644
--- a/wenet/bin/export_onnx_gpu.py
+++ b/wenet/bin/export_onnx_gpu.py
@@ -15,14 +15,13 @@ from __future__ import print_function
 
 import argparse
+import logging
 import os
 import sys
 
 import torch
-import yaml
-import logging
-
 import torch.nn.functional as F
+import yaml
 
 from wenet.transformer.ctc import CTC
 from wenet.transformer.decoder import TransformerDecoder
 from wenet.transformer.encoder import BaseEncoder
@@ -169,15 +168,19 @@ def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
         r_att_cache = []
         r_cnn_cache = []
         for i, layer in enumerate(self.encoder.encoders):
-            xs, _, new_att_cache, new_cnn_cache = layer(
+            i_kv_cache = att_cache[i]
+            size = att_cache.size(-1) // 2
+            kv_cache = (i_kv_cache[:, :, :, :size], i_kv_cache[:, :, :, size:])
+            xs, _, new_kv_cache, new_cnn_cache = layer(
                 xs,
                 masks,
                 pos_emb,
-                att_cache=att_cache[i],
+                att_cache=kv_cache,
                 cnn_cache=cnn_cache[i],
             )
             # shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
             # shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
+            new_att_cache = torch.cat(new_kv_cache, dim=-1)
             r_att_cache.append(
                 new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
             if not self.transformer:
@@ -1241,8 +1244,8 @@ def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
     if args.fp16:
         try:
             import onnxmltools
-            from onnxmltools.utils.float16_converter import (
-                convert_float_to_float16, )
+            from onnxmltools.utils.float16_converter import \
+                convert_float_to_float16
         except ImportError:
             print("Please install onnxmltools!")
            sys.exit(1)
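
Note: the second hunk adapts the export wrapper to encoder layers that now take and
return the attention cache as a (key, value) tuple, while the exported ONNX graph
keeps a single fused cache tensor per layer. A minimal standalone sketch of that
split/fuse round-trip, with hypothetical sizes that are not part of the patch:

    import torch

    # hypothetical per-layer cache: (B, head, cache_t1, d_k * 2),
    # K in the first half of the last dim, V in the second half
    B, head, cache_t1, d_k = 1, 4, 16, 64
    att_cache = torch.zeros(B, head, cache_t1, d_k * 2)

    # split the fused tensor into the (key, value) tuple the layer expects
    size = att_cache.size(-1) // 2
    kv_cache = (att_cache[:, :, :, :size], att_cache[:, :, :, size:])

    # ... the layer would consume kv_cache and return a new (key, value) tuple;
    # here we just pass it through for illustration ...
    new_kv_cache = kv_cache

    # fuse the returned tuple back so the graph still carries one cache tensor
    new_att_cache = torch.cat(new_kv_cache, dim=-1)
    assert new_att_cache.shape == att_cache.shape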