Support scalar inputs to the Scan subgraph #24

Merged: 1 commit, Nov 26, 2018
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/mlvalue_tensor_slicer.cc
@@ -15,8 +15,8 @@ MLValueTensorSlicer<T> MLValueTensorSlicer<T>::Create(T& mlvalue, int64_t slice_
   ONNXRUNTIME_ENFORCE(mlvalue.IsAllocated(), "MLValue has not been allocated so can't be sliced.");
 
   auto& tensor_shape{mlvalue.template Get<Tensor>().Shape()};
-  ONNXRUNTIME_ENFORCE(gsl::narrow_cast<int64_t>(tensor_shape.NumDimensions()) > slice_dimension,
-                      "Insufficient dimensions to slice on ", slice_dimension, ". Shape:", tensor_shape);
+  ONNXRUNTIME_ENFORCE(gsl::narrow_cast<int64_t>(tensor_shape.NumDimensions()) >= slice_dimension,
+                      "Insufficient dimensions to slice on ", slice_dimension, ". Shape:", tensor_shape);
 
   auto dim0_size = tensor_shape[0];
   ONNXRUNTIME_ENFORCE(dim0_offset < dim0_size, "Invalid dim0_offset of ", dim0_offset, ". Dimension 0 is ", dim0_size);
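The only functional change here is `>` to `>=`: a value whose rank equals `slice_dimension` is now accepted, so the element left over after slicing may be rank 0. A minimal standalone sketch of that boundary condition (not onnxruntime code; the shape and `slice_dimension` value are illustrative assumptions):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical loop-state value of shape {batch_size}: one scalar per batch entry.
  std::vector<int64_t> shape{4};
  const int64_t slice_dimension = 1;
  const auto rank = static_cast<int64_t>(shape.size());

  // Old check: rank > slice_dimension  -> rejected rank 1, forcing a {batch_size, 1} shape.
  // New check: rank >= slice_dimension -> accepts rank 1; each slice is a rank-0 scalar.
  assert(rank >= slice_dimension);
  return 0;
}
```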
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/cpu/controlflow/scan.cc
@@ -303,8 +303,9 @@ static const MLValue& GetSubgraphInputMLValue(const OpKernelContextInternal& con
 // Validate that the subgraph input has valid shapes
 Status ScanImpl::ValidateSubgraphInput(int start_input, int end_input, bool has_seq_len_dim,
                                        const std::vector<const NodeArg*>& graph_inputs) {
-  // first dim is batch size. optional sequence dim. dim/s for the data
-  auto min_dims_required = has_seq_len_dim ? 3 : 2;
+  // first dim is batch size. optional sequence dim. dim/s for the data.
+  // if there is no dim for the data treat it as a scalar.
+  auto min_dims_required = has_seq_len_dim ? 2 : 1;
 
   for (int i = start_input; i < end_input; ++i) {
     auto& input_tensor = GetSubgraphInputTensor(context_, i);
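Lowering each minimum by one lets the data portion of the shape be empty: a loop-state input can now be {batch_size} and a scanned input {batch_size, sequence_len}, with the per-iteration element treated as a scalar. A self-contained sketch of the rule (illustrative only, not the onnxruntime source):

```cpp
#include <iostream>

// Sketch of the relaxed rank rule in ScanImpl::ValidateSubgraphInput:
// the batch dim is always required, the sequence dim only for scanned inputs,
// and zero remaining dims now means the per-iteration element is a scalar.
int MinDimsRequired(bool has_seq_len_dim) {
  return has_seq_len_dim ? 2 : 1;  // was `? 3 : 2` before scalar support
}

int main() {
  std::cout << "scanned input min rank: " << MinDimsRequired(true) << '\n'    // {batch, seq_len}
            << "loop state min rank:    " << MinDimsRequired(false) << '\n';  // {batch}
  return 0;
}
```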
68 changes: 46 additions & 22 deletions onnxruntime/test/providers/cpu/controlflow/scan_test.cc
@@ -17,6 +17,7 @@ struct RunOptions {
   bool include_dim_values_in_subgraph = true;
   bool include_types_in_subgraph = true;
   bool include_outer_scope_add = false;
+  bool scalar_loop_state_value = false;
   bool add_bad_shape = false;
 };

@@ -37,13 +38,13 @@ class ScanOpTester : public OpTester {
     // add outer_scope_0 node. push the value through an extra Identity node as a Constant gets lifted into an
     // initializer which results in different treatment by the allocation planner
     {
-      TypeProto float_scalar;
-      float_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
-      auto mutable_dim = float_scalar.mutable_tensor_type()->mutable_shape()->add_dim();
+      TypeProto float_single_value;
+      float_single_value.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+      auto mutable_dim = float_single_value.mutable_tensor_type()->mutable_shape()->add_dim();
       mutable_dim->set_dim_value(1);
 
       {
-        auto& outer_scope_constant = graph.GetOrCreateNodeArg("outer_scope_constant", &float_scalar);
+        auto& outer_scope_constant = graph.GetOrCreateNodeArg("outer_scope_constant", &float_single_value);
         auto* constant = graph.AddNode("outer_scope_constant", "Constant", "Constant with value kOuterNodeAddValue",
                                        {}, {&outer_scope_constant});
 
@@ -54,7 +55,7 @@ class ScanOpTester : public OpTester {
 
         constant->AddAttribute("value", value_tensor);
 
-        auto& outer_scope_node_arg = graph.GetOrCreateNodeArg("outer_scope_0", &float_scalar);
+        auto& outer_scope_node_arg = graph.GetOrCreateNodeArg("outer_scope_0", &float_single_value);
         graph.AddNode("outer_scope_id", "Identity", "Identity for outer_scope_0",
                       {&outer_scope_constant}, {&outer_scope_node_arg});
       }
@@ -66,7 +67,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
 };
 
 static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string& failure_message) {
-  bool include_shapes = options.include_dim_values_in_subgraph;
+  bool include_dim_values = options.include_dim_values_in_subgraph;
   bool include_types = options.include_types_in_subgraph;
 
   std::vector<NodeArg*> inputs;
@@ -94,21 +95,27 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
   inputs = {};
   outputs = {};
 
-  TypeProto float_scalar;
+  TypeProto float_input;
   // inputs must have type information and a rank
-  float_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
-  auto mutable_dim = float_scalar.mutable_tensor_type()->mutable_shape()->add_dim();
-  if (include_shapes)
-    mutable_dim->set_dim_value(1);
+  float_input.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
+  auto mutable_shape = float_input.mutable_tensor_type()->mutable_shape();
+  if (options.scalar_loop_state_value) {
+    // no dims
+  } else {
+    auto mutable_dim = mutable_shape->add_dim();  // set rank
+    if (include_dim_values)
+      mutable_dim->set_dim_value(1);
+  }
 
   {
-    auto& output_arg = graph.GetOrCreateNodeArg("constant_1", &float_scalar);
+    auto& output_arg = graph.GetOrCreateNodeArg("constant_1", &float_input);
     outputs.push_back(&output_arg);
 
     auto* constant = graph.AddNode("constant", "Constant", "Constant with value 1", inputs, outputs);
 
     TensorProto value_tensor;
-    value_tensor.add_dims(1);
+    if (!options.scalar_loop_state_value)
+      value_tensor.add_dims(1);
     value_tensor.add_float_data(1.f);
     value_tensor.set_data_type(onnx::TensorProto_DataType_FLOAT);
 
@@ -118,7 +125,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
     inputs = outputs;  // start with output from Constant node
     outputs = {};
 
-    auto& input_arg = graph.GetOrCreateNodeArg("loop_state_in_1", &float_scalar);
+    auto& input_arg = graph.GetOrCreateNodeArg("loop_state_in_1", &float_input);
     inputs.push_back(&input_arg);
 
     TypeProto loop_state_output_tensor;
@@ -128,15 +135,17 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
     // it has to come from here.
     bool type_and_shape_required = options.include_dim_values_in_main_graph == false;
 
-    if (include_shapes || type_and_shape_required)
-      loop_state_output_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
+    if (include_dim_values || type_and_shape_required) {
+      mutable_shape = loop_state_output_tensor.mutable_tensor_type()->mutable_shape();
+      if (!options.scalar_loop_state_value)
+        mutable_shape->add_dim()->set_dim_value(1);
+    }
 
     TypeProto* type_proto = include_types || type_and_shape_required ? &loop_state_output_tensor : nullptr;
     auto& output_arg = graph.GetOrCreateNodeArg("loop_state_out_1", type_proto);
     outputs.push_back(&output_arg);
 
-    auto* add = graph.AddNode("add", "Add", "Add 1 to the loop state", inputs, outputs);
-    (void)add;
+    graph.AddNode("add", "Add", "Add 1 to the loop state", inputs, outputs);
   }
 
   // subgraph with multiple inputs and outputs to test variadic behaviour.
@@ -152,7 +161,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
     // inputs must have type information and rank, but dimension can have no value if we're not providing shape info.
     concat_input_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
     auto mutable_dim = concat_input_tensor.mutable_tensor_type()->mutable_shape()->add_dim();
-    if (include_shapes) {
+    if (include_dim_values) {
       mutable_dim->set_dim_value(2);
 
       if (options.add_bad_shape) {
@@ -168,7 +177,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
     // one output from concatenate of {4} tensor
     TypeProto concat_output_tensor;
     concat_output_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
-    if (include_shapes)
+    if (include_dim_values)
       concat_output_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(4);
 
     TypeProto* type_proto = include_types ? &concat_output_tensor : nullptr;
@@ -277,13 +286,18 @@ void RunTest(const std::string test_name, int64_t batch_size, int64_t max_sequen
     test.AddInput<int64_t>("sequence_lens", sequence_lens_dims, *sequence_lens);
   }
 
-  test.AddInput<float>("scan_loop_state_in_0", {batch_size, 1}, loop_state_in_0);
+  std::vector<int64_t> loop_state_shape{batch_size};
+  if (!options.scalar_loop_state_value) {
+    loop_state_shape.push_back(1);
+  }
+
+  test.AddInput<float>("scan_loop_state_in_0", loop_state_shape, loop_state_in_0);
 
   std::vector<int64_t> input_shape{batch_size, max_sequence_len, input_size};
   test.AddInput<float>("scan_input_0", input_shape, input_0);
   test.AddInput<float>("scan_input_1", input_shape, input_1);
 
-  test.AddOutput<float>("scan_loop_state_out_0", {batch_size, 1}, loop_state_out_0);
+  test.AddOutput<float>("scan_loop_state_out_0", loop_state_shape, loop_state_out_0);
 
   std::vector<int64_t> output_shape{batch_size, max_sequence_len, 1};
   test.AddOutput<float>("scan_output_0", output_shape, output_0);
@@ -353,6 +367,16 @@ TEST(Scan, ShortSequenceOneInBatchOneLoopStateVar_NoShapeInMainGraph_NoTypeAndSh
   ShortSequenceOneInBatchOneLoopStateVar(options);
 }
 
+TEST(Scan, OnnxScalarLoopState) {
+  RunOptions options{};
+  options.include_dim_values_in_main_graph = true;
+  options.include_types_in_subgraph = false;
+  options.include_dim_values_in_subgraph = false;
+  options.scalar_loop_state_value = true;
+
+  ShortSequenceOneInBatchOneLoopStateVar(options);
+}
+
 // test when there is an operator in the subgraph that uses a value coming from outer scope
 TEST(Scan, OuterScopeAccess_NoShapeInMainGraph_TypeAndShapeInSubgraph) {
   RunOptions options{};
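For reference, the scalar_loop_state_value toggle above boils down to whether the loop-state TypeProto declares any dims at all. A minimal sketch of that distinction, using the same onnx protobuf calls as the test (the include path and helper name are assumptions for illustration):

```cpp
#include "onnx/onnx_pb.h"  // assumed include; the test uses these protobuf types

// Sketch: an ONNX scalar is a TypeProto whose shape message exists but has
// zero dims (rank 0); the pre-existing tests used a rank-1 {1} tensor instead.
onnx::TypeProto MakeFloatType(bool scalar) {
  onnx::TypeProto type;
  type.mutable_tensor_type()->set_elem_type(onnx::TensorProto_DataType_FLOAT);
  auto* shape = type.mutable_tensor_type()->mutable_shape();
  if (!scalar)
    shape->add_dim()->set_dim_value(1);  // rank 1, dim value 1
  // else: leave the shape empty -> rank 0, i.e. a scalar
  return type;
}
```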