From 265afd664aabf67311d1559ed59a20c21dc433a4 Mon Sep 17 00:00:00 2001 From: KIRAN Date: Thu, 11 Mar 2021 16:15:43 +0530 Subject: [PATCH] added support for exporting of t5 to onnx with past_key_values --- src/transformers/models/t5/modeling_t5.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index c12a8f4a8991..63f9feed13d6 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -423,6 +423,8 @@ def forward( # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) batch_size, seq_length = hidden_states.shape[:2] + int_seq_length = int(seq_length) + real_seq_length = seq_length if past_key_value is not None: @@ -491,7 +493,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # if key and values are already calculated # we want only the last query position bias if past_key_value is not None: - position_bias = position_bias[:, :, -seq_length:, :] + position_bias = position_bias[:, :, -int_seq_length:, :] if mask is not None: position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)