diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index f699916f0c..68f851808c 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -2282,7 +2282,12 @@ def doc_string(self, value: str | None) -> None: def opset_imports(self) -> dict[str, int]: return self._opset_imports - def __getitem__(self, index: int) -> Node: + @typing.overload + def __getitem__(self, index: int) -> Node: ... + @typing.overload + def __getitem__(self, index: slice) -> Sequence[Node]: ... + + def __getitem__(self, index): return self._nodes[index] def __len__(self) -> int: @@ -2712,7 +2717,12 @@ def __init__( self._metadata_props: dict[str, str] | None = metadata_props self._nodes: tuple[Node, ...] = tuple(nodes) - def __getitem__(self, index: int) -> Node: + @typing.overload + def __getitem__(self, index: int) -> Node: ... + @typing.overload + def __getitem__(self, index: slice) -> Sequence[Node]: ... + + def __getitem__(self, index): return self._nodes[index] def __len__(self) -> int: @@ -2961,7 +2971,12 @@ def outputs(self) -> MutableSequence[Value]: def attributes(self) -> OrderedDict[str, Attr]: return self._attributes - def __getitem__(self, index: int) -> Node: + @typing.overload + def __getitem__(self, index: int) -> Node: ... + @typing.overload + def __getitem__(self, index: slice) -> Sequence[Node]: ... + + def __getitem__(self, index): return self._graph.__getitem__(index) def __len__(self) -> int: diff --git a/onnxscript/ir/_linked_list.py b/onnxscript/ir/_linked_list.py index 0db770e20e..fd425c505b 100644 --- a/onnxscript/ir/_linked_list.py +++ b/onnxscript/ir/_linked_list.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Generic, Iterable, Iterator, Sequence, TypeVar +from typing import Generic, Iterable, Iterator, Sequence, TypeVar, overload T = TypeVar("T") @@ -136,11 +136,18 @@ def __len__(self) -> int: ) return self._length - def __getitem__(self, index: int) -> T: + @overload + def __getitem__(self, index: int) -> T: ... + @overload + def __getitem__(self, index: slice) -> Sequence[T]: ... + + def __getitem__(self, index): """Get the node at the given index. Complexity is O(n). """ + if isinstance(index, slice): + return tuple(self)[index] if index >= self._length or index < -self._length: raise IndexError( f"Index out of range: {index} not in range [-{self._length}, {self._length})" diff --git a/onnxscript/ir/_linked_list_test.py b/onnxscript/ir/_linked_list_test.py index 00f03e71ea..ead022bf2e 100644 --- a/onnxscript/ir/_linked_list_test.py +++ b/onnxscript/ir/_linked_list_test.py @@ -373,6 +373,15 @@ def test_insert_after_supports_taking_elements_from_another_doubly_linked_list( self.assertEqual(len(other_linked_list), 1) self.assertEqual([elem.value for elem in other_linked_list], [42]) + @parameterized.parameterized.expand( + [(s, t, p) for s in [-2, 0, 2, 3] for t in [2, -1, -2] for p in [-3, -1, 1, 2]] + ) + def test_get_item_slice(self, start, stop, step): + elems = [_TestElement(i) for i in range(5)] + linked_list = _linked_list.DoublyLinkedSet(elems) + self.assertEqual(len(linked_list), 5) + self.assertEqual(list(linked_list[start:stop:step]), elems[start:stop:step]) + if __name__ == "__main__": unittest.main()