diff --git a/src/core/shape_inference/include/range_shape_inference.hpp b/src/core/shape_inference/include/range_shape_inference.hpp index 3be56a4543a7dd..5d754810d9b80a 100644 --- a/src/core/shape_inference/include/range_shape_inference.hpp +++ b/src/core/shape_inference/include/range_shape_inference.hpp @@ -11,6 +11,31 @@ namespace op { namespace ShapeInferRange { +template ::value>::type* = nullptr> +void symbol_propagation(const Node* op, + std::vector& output_shapes, + const double& start, + const double& step, + bool start_val, + bool step_val) { + output_shapes[0] = ov::PartialShape::dynamic(1); + if (op->get_input_size() == 3 && step_val && step == 1) { + auto start_symbol = op->input_value(0).get_tensor().get_value_symbol(); + auto stop_symbol = op->input_value(1).get_tensor().get_value_symbol(); + if (start_val && start == 0 && !stop_symbol.empty()) { + output_shapes[0][0].set_symbol(stop_symbol[0]); + } + } +} + +template ::value>::type* = nullptr> +void symbol_propagation(const Node* op, + std::vector& output_shapes, + const double& start, + const double& step, + bool start_val, + bool step_val) {} + template > std::vector range_shape_infer(const Node* op, const std::vector& input_shapes, @@ -35,12 +60,18 @@ std::vector range_shape_infer(const Node* op, NODE_VALIDATION_CHECK(op, start_val->size() == 1); start = (*start_val)[0]; NODE_VALIDATION_CHECK(op, std::isfinite(start) && !std::isnan(start), "'start' cannot be nan or infinite."); + if (output_is_integral) + // all inputs must be casted to output_type before the rounding for casting values are done towards zero + start = std::trunc(start); } if (stop_val) { NODE_VALIDATION_CHECK(op, stop_val->size() == 1); stop = (*stop_val)[0]; NODE_VALIDATION_CHECK(op, std::isfinite(stop) && !std::isnan(stop), "'stop' cannot be nan or infinite."); + if (output_is_integral) + // all inputs must be casted to output_type before the rounding for casting values are done towards zero + stop = std::trunc(stop); } if (step_val) { @@ -52,18 +83,13 @@ std::vector range_shape_infer(const Node* op, NODE_VALIDATION_CHECK(op, std::isfinite(step) && !std::isnan(step) && step != 0, "'step' cannot be zero, nan, or infinite."); + if (output_is_integral) + // all inputs must be casted to output_type before the rounding for casting values are done towards zero + step = std::trunc(step); } auto output_shapes = std::vector(1); if (start_val && stop_val && step_val) { - // all inputs must be casted to output_type before - // the rounding for casting values are done towards zero - if (output_is_integral) { - start = std::trunc(start); - stop = std::trunc(stop); - step = std::trunc(step); - } - // the number of elements is: max(ceil((stop − start) / step), 0) double span; if ((step > 0 && start >= stop) || (step < 0 && start <= stop)) { @@ -76,7 +102,7 @@ std::vector range_shape_infer(const Node* op, output_shapes[0] = TRShape{static_cast(strided)}; } else { - output_shapes[0] = ov::PartialShape::dynamic(1); + symbol_propagation(op, output_shapes, start, step, start_val, step_val); } return output_shapes; } diff --git a/src/core/tests/type_prop/range.cpp b/src/core/tests/type_prop/range.cpp index c37f5987047a53..d44fe3c3bcc6b8 100644 --- a/src/core/tests/type_prop/range.cpp +++ b/src/core/tests/type_prop/range.cpp @@ -895,3 +895,21 @@ INSTANTIATE_TEST_SUITE_P(type_prop, RangeParams{-1, 1, 0.25, PartialShape{8}}, RangeParams{-1, 0.875, 0.25, PartialShape{8}}), PrintToDummyParamName()); + +TEST(type_prop, range_symbol_start_0_stop_A_step_1) { + auto stop_symbol = std::make_shared(); + auto source_shape = PartialShape::dynamic(1); + source_shape[0].set_symbol(stop_symbol); + auto symbol_source = + make_shared(make_shared(element::i64, source_shape)); + + auto start = make_shared(element::i64, Shape{}, 0); + auto stop = make_shared(symbol_source, + make_shared(element::i64, Shape{}, 0), + make_shared(element::i64, Shape{}, 0)); + auto step = make_shared(element::i64, Shape{}, 1); + + auto range = make_shared(start, stop, step); + + ASSERT_TRUE(ov::symbol::are_equal(range->get_output_partial_shape(0)[0].get_symbol(), stop_symbol)); +}