|
8 | 8 | from .num_nodes import maybe_num_nodes
|
9 | 9 |
|
10 | 10 |
|
| 11 | +@torch.jit._overload |
| 12 | +def is_undirected(edge_index, edge_attr=None, num_nodes=None): |
| 13 | + # type: (Tensor, Optional[Tensor], Optional[int]) -> bool # noqa |
| 14 | + pass |
| 15 | + |
| 16 | + |
| 17 | +@torch.jit._overload |
| 18 | +def is_undirected(edge_index, edge_attr=None, num_nodes=None): |
| 19 | + # type: (Tensor, List[Tensor], Optional[int]) -> bool # noqa |
| 20 | + pass |
| 21 | + |
| 22 | + |
11 | 23 | def is_undirected(
|
12 | 24 | edge_index: Tensor,
|
13 |
| - edge_attr: Optional[Union[Tensor, List[Tensor]]] = None, |
| 25 | + edge_attr: Union[Optional[Tensor], List[Tensor]] = None, |
14 | 26 | num_nodes: Optional[int] = None,
|
15 | 27 | ) -> bool:
|
16 | 28 | r"""Returns :obj:`True` if the graph given by :attr:`edge_index` is
|
@@ -42,31 +54,56 @@ def is_undirected(
|
42 | 54 | """
|
43 | 55 | num_nodes = maybe_num_nodes(edge_index, num_nodes)
|
44 | 56 |
|
45 |
| - edge_attr = [] if edge_attr is None else edge_attr |
46 |
| - edge_attr = [edge_attr] if isinstance(edge_attr, Tensor) else edge_attr |
| 57 | + edge_attrs: List[Tensor] = [] |
| 58 | + if isinstance(edge_attr, Tensor): |
| 59 | + edge_attrs.append(edge_attr) |
| 60 | + elif isinstance(edge_attr, (list, tuple)): |
| 61 | + edge_attrs = edge_attr |
47 | 62 |
|
48 |
| - edge_index1, edge_attr1 = sort_edge_index( |
| 63 | + edge_index1, edge_attrs1 = sort_edge_index( |
49 | 64 | edge_index,
|
50 |
| - edge_attr, |
| 65 | + edge_attrs, |
51 | 66 | num_nodes=num_nodes,
|
52 | 67 | sort_by_row=True,
|
53 | 68 | )
|
54 |
| - edge_index2, edge_attr2 = sort_edge_index( |
55 |
| - edge_index1, |
56 |
| - edge_attr1, |
| 69 | + edge_index2, edge_attrs2 = sort_edge_index( |
| 70 | + edge_index, |
| 71 | + edge_attrs, |
57 | 72 | num_nodes=num_nodes,
|
58 | 73 | sort_by_row=False,
|
59 | 74 | )
|
60 | 75 |
|
61 |
| - return (bool(torch.all(edge_index1[0] == edge_index2[1])) |
62 |
| - and bool(torch.all(edge_index1[1] == edge_index2[0])) and all([ |
63 |
| - torch.all(e == e_T) for e, e_T in zip(edge_attr1, edge_attr2) |
64 |
| - ])) |
| 76 | + if not torch.equal(edge_index1[0], edge_index2[1]): |
| 77 | + return False |
| 78 | + if not torch.equal(edge_index1[1], edge_index2[0]): |
| 79 | + return False |
| 80 | + for edge_attr1, edge_attr2 in zip(edge_attrs1, edge_attrs2): |
| 81 | + if not torch.equal(edge_attr1, edge_attr2): |
| 82 | + return False |
| 83 | + return True |
| 84 | + |
| 85 | + |
| 86 | +@torch.jit._overload |
| 87 | +def to_undirected(edge_index, edge_attr=None, num_nodes=None, reduce="add"): |
| 88 | + # type: (Tensor, Optional[bool], Optional[int], str) -> Tensor # noqa |
| 89 | + pass |
| 90 | + |
| 91 | + |
| 92 | +@torch.jit._overload |
| 93 | +def to_undirected(edge_index, edge_attr=None, num_nodes=None, reduce="add"): |
| 94 | + # type: (Tensor, Tensor, Optional[int], str) -> Tuple[Tensor, Tensor] # noqa |
| 95 | + pass |
| 96 | + |
| 97 | + |
| 98 | +@torch.jit._overload |
| 99 | +def to_undirected(edge_index, edge_attr=None, num_nodes=None, reduce="add"): |
| 100 | + # type: (Tensor, List[Tensor], Optional[int], str) -> Tuple[Tensor, List[Tensor]] # noqa |
| 101 | + pass |
65 | 102 |
|
66 | 103 |
|
67 | 104 | def to_undirected(
|
68 | 105 | edge_index: Tensor,
|
69 |
| - edge_attr: Optional[Union[Tensor, List[Tensor]]] = None, |
| 106 | + edge_attr: Union[Optional[Tensor], List[Tensor]] = None, |
70 | 107 | num_nodes: Optional[int] = None,
|
71 | 108 | reduce: str = "add",
|
72 | 109 | ) -> Union[Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
|
@@ -116,13 +153,13 @@ def to_undirected(
|
116 | 153 | edge_attr = None
|
117 | 154 | num_nodes = edge_attr
|
118 | 155 |
|
119 |
| - row, col = edge_index |
| 156 | + row, col = edge_index[0], edge_index[1] |
120 | 157 | row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)
|
121 | 158 | edge_index = torch.stack([row, col], dim=0)
|
122 | 159 |
|
123 |
| - if edge_attr is not None and isinstance(edge_attr, Tensor): |
| 160 | + if isinstance(edge_attr, Tensor): |
124 | 161 | edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
|
125 |
| - elif edge_attr is not None: |
| 162 | + elif isinstance(edge_attr, (list, tuple)): |
126 | 163 | edge_attr = [torch.cat([e, e], dim=0) for e in edge_attr]
|
127 | 164 |
|
128 | 165 | return coalesce(edge_index, edge_attr, num_nodes, reduce)
|
0 commit comments