@@ -109,6 +109,36 @@ class OpIndexTensorOutTest : public OperatorTest {
109109
110110 ET_FORALL_REALHBF16_TYPES (TEST_ENTRY);
111111
112+ #undef TEST_ENTRY
113+ }
114+
115+ template <executorch::aten::ScalarType INPUT_DTYPE>
116+ void test_indices_with_only_null_tensors_supported () {
117+ TensorFactory<INPUT_DTYPE> tf;
118+
119+ Tensor x = tf.make ({2 , 3 }, {1 , 2 , 3 , 4 , 5 , 6 });
120+ Tensor out = tf.zeros ({2 , 3 });
121+
122+ std::array<optional<Tensor>, 1 > indices1 = {optional<Tensor>()};
123+ op_index_tensor_out (x, indices1, out);
124+ EXPECT_TENSOR_EQ (out, x);
125+
126+ out = tf.zeros ({2 , 3 });
127+ std::array<optional<Tensor>, 2 > indices2 = {
128+ optional<Tensor>(), std::optional<Tensor>()};
129+ op_index_tensor_out (x, indices2, out);
130+ EXPECT_TENSOR_EQ (out, x);
131+ }
132+
133+ /* *
134+ * Test indices with only null tensors for all input data types
135+ */
136+ void test_indices_with_only_null_tensors_enumerate_in_types () {
137+ #define TEST_ENTRY (ctype, dtype ) \
138+ test_indices_with_only_null_tensors_supported<ScalarType::dtype>();
139+
140+ ET_FORALL_REALHBF16_TYPES (TEST_ENTRY);
141+
112142#undef TEST_ENTRY
113143 }
114144
@@ -405,21 +435,19 @@ TEST_F(OpIndexTensorOutTest, IndicesWithOnlyNullTensorsSupported) {
405435 if (torch::executor::testing::SupportedFeatures::get ()->is_aten ) {
406436 GTEST_SKIP () << " ATen kernel test fails" ;
407437 }
408- TensorFactory<ScalarType::Double> tf;
438+ test_indices_with_only_null_tensors_enumerate_in_types ();
439+ }
409440
441+ TEST_F (OpIndexTensorOutTest, TooManyNullIndices) {
442+ TensorFactory<ScalarType::Double> tf;
410443 Tensor x = tf.make ({2 , 3 }, {1 ., 2 ., 3 ., 4 ., 5 ., 6 .});
411- std::array<optional<Tensor>, 1 > indices0 = {optional<Tensor>()};
412- run_test_cases (x, indices0, x);
413-
414- std::array<optional<Tensor>, 2 > indices1 = {
415- optional<Tensor>(), std::optional<Tensor>()};
416- run_test_cases (x, indices1, x);
417-
418- std::array<optional<Tensor>, 3 > indices2 = {
444+ std::array<optional<Tensor>, 3 > indices = {
419445 optional<Tensor>(), std::optional<Tensor>(), std::optional<Tensor>()};
420446 Tensor out = tf.ones ({2 , 3 });
421447 ET_EXPECT_KERNEL_FAILURE_WITH_MSG (
422- context_, op_index_tensor_out (x, indices2, out), " " );
448+ context_,
449+ op_index_tensor_out (x, indices, out),
450+ " Indexing too many dimensions" );
423451}
424452
425453TEST_F (OpIndexTensorOutTest, EmptyIndicesSupported) {
0 commit comments