Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable pointer-generator T5 models in BeamSearch #23134

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 48 additions & 18 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ namespace transformers {

Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) {
bool has_hidden_state = subgraph_inputs[2]->Name() == "encoder_hidden_states" ? true : false;
SetPastInputIndex(has_hidden_state);
bool has_encoder_input_ids = subgraph_inputs[1]->Name() == "encoder_input_ids";
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved
bool has_hidden_state = subgraph_inputs[2 + has_encoder_input_ids]->Name() == "encoder_hidden_states";
Dismissed Show dismissed Hide dismissed
SetPastInputIndex(has_hidden_state, has_encoder_input_ids);

ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3,
"kFirstPastInputIndex currently only supports 2 or 3");
ORT_RETURN_IF(first_past_input_index_ != 2 && first_past_input_index_ != 3 && first_past_input_index_ != 4,
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved
"kFirstPastInputIndex currently only supports 2, 3 or 4");

if (!past_present_share_buffer_) {
ORT_RETURN_IF(has_decoder_masked_attention_, "decoder_masked_attention shall use with past_present_share_buffer");
Expand All @@ -75,13 +76,22 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i

ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
"decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
"decoder subgraph input 1 shall be named as encoder_attention_mask, got: ",
subgraph_inputs[1]->Name());
if (first_past_input_index_ == 3) {
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "encoder_hidden_states",
"decoder subgraph input 2 shall be named as encoder_hidden_states, got: ",
subgraph_inputs[2]->Name());
const int enc_attn_mask_index = 1 + has_encoder_input_ids_;
const int enc_hidden_state_index = enc_attn_mask_index + 1;
if (has_encoder_input_ids_) {
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_input_ids",
"decoder subgraph input 1 shall be named as encoder_input_ids, got: ",
subgraph_inputs[1]->Name());
}
ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->Name() != "encoder_attention_mask",
"decoder subgraph input ", std::to_string(enc_attn_mask_index),
" shall be named as encoder_attention_mask, got: ",
subgraph_inputs[enc_attn_mask_index]->Name());
if (has_hidden_state_) {
ORT_RETURN_IF(subgraph_inputs[enc_hidden_state_index]->Name() != "encoder_hidden_states",
"decoder subgraph input ", std::to_string(enc_hidden_state_index),
" shall be named as encoder_hidden_states, got: ",
subgraph_inputs[enc_hidden_state_index]->Name());
}

// check subgraph outputs
Expand All @@ -108,12 +118,19 @@ Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i

ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 0 (input_ids) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 1 (encoder_attention_mask) shall have int32 type");

auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type();
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
"decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type");
if (has_encoder_input_ids_) {
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 1 (encoder_input_ids) shall have int32 type");
}
ORT_RETURN_IF(subgraph_inputs[enc_attn_mask_index]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input ", std::to_string(enc_attn_mask_index),
" (encoder_attention_mask) shall have int32 type");

auto float_type = subgraph_inputs[enc_hidden_state_index]->TypeAsProto()->tensor_type().elem_type();
if (has_hidden_state_) {
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
"decoder subgraph input ", std::to_string(enc_hidden_state_index), " (encoder_hidden_states) shall have float or float16 type");
}

for (int i = first_past_input_index_; i < first_past_input_index_ + 4 * num_layers; i++) {
ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
Expand Down Expand Up @@ -219,6 +236,19 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
decoder_feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
decoder_feeds.push_back(input_ids);

if (has_encoder_input_ids_) {
// The encoder_input_ids is copied from the first input of encoder.
OrtValue expanded_encoder_input_ids;
ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
encoder_feeds[0],
num_beam,
allocator,
expanded_encoder_input_ids,
false,
0 /*max_sequence_length*/));
decoder_feeds.push_back(expanded_encoder_input_ids);
}

// The encoder_attention_mask is copied from the second input of encoder.
OrtValue expanded_decoder_attention_masks;
ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
Expand All @@ -238,7 +268,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
// When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
// of encoder.
// When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
for (size_t j = static_cast<size_t>(4) - first_past_input_index_; j < encoder_fetches.size(); j++) {
for (size_t j = static_cast<size_t>(2) - has_hidden_state_; j < encoder_fetches.size(); j++) {
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved
if (j == 1) {
ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expension: has_hidden_state_ == false");
OrtValue expanded_hidden_states;
Expand Down
10 changes: 4 additions & 6 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,10 @@ class T5DecoderSubgraph : public Subgraph {
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) override;

void SetPastInputIndex(bool has_hidden_state) {
void SetPastInputIndex(bool has_hidden_state, bool has_encoder_input_ids) {
has_hidden_state_ = has_hidden_state;
if (!has_hidden_state_) {
first_past_input_index_ = 2;
} else {
first_past_input_index_ = 3;
}
has_encoder_input_ids_ = has_encoder_input_ids;
first_past_input_index_ = 2 + has_hidden_state_ + has_encoder_input_ids_;
}

int GetFirstPastInputIndex() const {
Expand All @@ -79,6 +76,7 @@ class T5DecoderSubgraph : public Subgraph {
int first_past_input_index_;
int first_present_output_index_;
bool has_hidden_state_;
bool has_encoder_input_ids_;
bool use_sequence_as_input_ids_;
};

Expand Down
22 changes: 22 additions & 0 deletions onnxruntime/test/contrib_ops/beam_search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,8 @@ TEST(BeamSearchTest, DummyT5) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
// dummy_t5.onnx model generated using following command:
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5.onnx
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
Expand All @@ -414,6 +416,8 @@ TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
// dummy_t5_with_outer_scope_initializers.onnx model generated using following command:
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_outer_scope_initializers.onnx --move-initializers
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_outer_scope_initializers.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
Expand All @@ -428,6 +432,8 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
// dummy_t5_with_sequence_input_ids.onnx model generated using following command:
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_sequence_input_ids.onnx --sequence-as-input
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
tester.AddInput("encoder_input_ids", {1, 5}, {16, 17, 1, 0, 8});
Expand All @@ -438,5 +444,21 @@ TEST(BeamSearchTest, DummyT5WithSequenceInputIds) {
tester.RunWithConfig();
}

TEST(BeamSearchTest, DummyT5PointerGenerator) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
// dummy_t5_pointer_generator.onnx model generated using following command:
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_pointer_generator.onnx --decoder-needs-input-ids
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_pointer_generator.onnx"));
tester.ConfigEp(DefaultCpuExecutionProvider());
tester.AddInput("encoder_input_ids", {1, 5}, {14, 6, 13, 9, 7});
tester.AddOutput("sequences", {1, 3, 10}, {2, 3, 6, 7, 3, 6, 7, 18, 3, 6, 2, 3, 6, 7, 18, 3, 6, 7, 18, 3, 2, 3, 6, 7, 3, 6, 7, 3, 6, 7});
#ifdef USE_CUDA
tester.ConfigEp(DefaultCudaExecutionProvider());
#endif
tester.RunWithConfig();
}

} // namespace test
} // namespace onnxruntime
Loading
Loading