From b29edcce239670a591bfa590290e5ee689b14462 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Mon, 10 Apr 2023 22:21:17 -0400 Subject: [PATCH 01/17] set up nn conv --- python/cugraph-pyg/cugraph_pyg/nn/__init__.py | 14 ++ .../cugraph_pyg/nn/conv/__init__.py | 20 ++ .../cugraph-pyg/cugraph_pyg/nn/conv/base.py | 189 ++++++++++++++++++ .../cugraph_pyg/nn/conv/gat_conv.py | 106 ++++++++++ .../cugraph_pyg/nn/conv/transformer_conv.py | 37 ++++ .../cugraph_pyg/tests/nn/test_gat_conv.py | 77 +++++++ 6 files changed, 443 insertions(+) create mode 100644 python/cugraph-pyg/cugraph_pyg/nn/__init__.py create mode 100644 python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py create mode 100644 python/cugraph-pyg/cugraph_pyg/nn/conv/base.py create mode 100644 python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py create mode 100644 python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py create mode 100644 python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py diff --git a/python/cugraph-pyg/cugraph_pyg/nn/__init__.py b/python/cugraph-pyg/cugraph_pyg/nn/__init__.py new file mode 100644 index 00000000000..331b49ebec0 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/nn/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .conv import * diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py new file mode 100644 index 00000000000..d2317f7ae53 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .gat_conv import GATConv +from .transformer_conv import TransformerConv + +__all__ = [ + "GATConv", + "TransformerConv", +] diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py new file mode 100644 index 00000000000..91f64200e47 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py @@ -0,0 +1,189 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from typing import Any, Optional, Tuple, Union
+
+from cugraph.utilities.utils import import_optional
+
+torch = import_optional("torch")
+torch_geometric = import_optional("torch_geometric")
+
+try:  # pragma: no cover
+    from pylibcugraphops.pytorch import (
+        SampledCSC,
+        SampledHeteroCSC,
+        StaticCSC,
+        StaticHeteroCSC,
+    )
+
+    HAS_PYLIBCUGRAPHOPS = True
+except ImportError:
+    HAS_PYLIBCUGRAPHOPS = False
+
+
+class BaseConv(torch.nn.Module):  # pragma: no cover
+    r"""An abstract base class for implementing cugraph message passing layers."""
+
+    def __init__(self):
+        super().__init__()
+
+        if HAS_PYLIBCUGRAPHOPS is False:
+            raise ModuleNotFoundError(
+                f"'{self.__class__.__name__}' requires " f"'pylibcugraphops>=23.04'"
+            )
+
+    def reset_parameters(self):
+        r"""Resets all learnable parameters of the module."""
+        pass
+
+    @staticmethod
+    def to_csc(
+        edge_index: torch.Tensor,
+        size: Optional[Tuple[int, int]] = None,
+        edge_attr: Optional[torch.Tensor] = None,
+    ) -> Union[
+        Tuple[torch.Tensor, torch.Tensor, int],
+        Tuple[Tuple[torch.Tensor, torch.Tensor, int], torch.Tensor],
+    ]:
+        r"""Returns a CSC representation of an :obj:`edge_index` tensor to be
+        used as input to a :class:`CuGraphModule`.
+
+        Args:
+            edge_index (torch.Tensor): The edge indices.
+            size ((int, int), optional): The shape of :obj:`edge_index` in each
+                dimension. (default: :obj:`None`)
+            edge_attr (torch.Tensor, optional): The edge features.
+                (default: :obj:`None`)
+        """
+        if size is None:
+            warnings.warn(
+                f"Inferring the graph size from 'edge_index' causes "
+                f"a decline in performance and does not work for "
+                f"bipartite graphs. To suppress this warning, pass "
+                f"the 'size' explicitly in '{__name__}.to_csc()'."
+            )
+            num_src_nodes = num_dst_nodes = int(edge_index.max()) + 1
+        else:
+            num_src_nodes, num_dst_nodes = size
+
+        row, col = edge_index
+        col, perm = torch_geometric.utils.index_sort(col, max_value=num_dst_nodes)
+        row = row[perm]
+
+        colptr = torch_geometric.utils.sparse.index2ptr(col, num_dst_nodes)
+
+        if edge_attr is not None:
+            return (row, colptr, num_src_nodes), edge_attr[perm]
+
+        return row, colptr, num_src_nodes
+
+    def get_cugraph(
+        self,
+        csc: Tuple[torch.Tensor, torch.Tensor, int],
+        max_num_neighbors: Optional[int] = None,
+    ) -> Any:
+        r"""Constructs a :obj:`cugraph` graph object from CSC representation.
+        Supports both bipartite and non-bipartite graphs.
+
+        Args:
+            csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
+                representation of a graph, given as a tuple of
+                :obj:`(row, colptr, num_src_nodes)`. Use the
+                :meth:`CuGraphModule.to_csc` method to convert an
+                :obj:`edge_index` representation to the desired format.
+            max_num_neighbors (int, optional): The maximum number of neighbors
+                of a target node. It is only effective when operating in a
+                bipartite graph. When not given, will be computed on-the-fly,
+                leading to slightly worse performance.
(default: :obj:`None`) + """ + row, colptr, num_src_nodes = csc + + if not row.is_cuda: + raise RuntimeError( + f"'{self.__class__.__name__}' requires GPU-" + f"based processing (got CPU tensor)" + ) + + if num_src_nodes != colptr.numel() - 1: # Bipartite graph: + if max_num_neighbors is None: + max_num_neighbors = int((colptr[1:] - colptr[:-1]).max()) + + return SampledCSC(colptr, row, max_num_neighbors, num_src_nodes) + + return StaticCSC(colptr, row) + + def get_typed_cugraph( + self, + csc: Tuple[torch.Tensor, torch.Tensor, int], + edge_type: torch.Tensor, + num_edge_types: Optional[int] = None, + max_num_neighbors: Optional[int] = None, + ) -> Any: + r"""Constructs a typed :obj:`cugraph` graph object from a CSC + representation where each edge corresponds to a given edge type. + Supports both bipartite and non-bipartite graphs. + + Args: + csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC + representation of a graph, given as a tuple of + :obj:`(row, colptr, num_src_nodes)`. Use the + :meth:`CuGraphModule.to_csc` method to convert an + :obj:`edge_index` representation to the desired format. + edge_type (torch.Tensor): The edge type. + num_edge_types (int, optional): The maximum number of edge types. + When not given, will be computed on-the-fly, leading to + slightly worse performance. (default: :obj:`None`) + max_num_neighbors (int, optional): The maximum number of neighbors + of a target node. It is only effective when operating in a + bipartite graph. When not given, will be computed on-the-fly, + leading to slightly worse performance. (default: :obj:`None`) + """ + if num_edge_types is None: + num_edge_types = int(edge_type.max()) + 1 + + row, colptr, num_src_nodes = csc + edge_type = edge_type.int() + + if num_src_nodes != colptr.numel() - 1: # Bipartite graph: + if max_num_neighbors is None: + max_num_neighbors = int((colptr[1:] - colptr[:-1]).max()) + + return SampledHeteroCSC( + colptr, row, edge_type, max_num_neighbors, num_src_nodes, num_edge_types + ) + + return StaticHeteroCSC(colptr, row, edge_type, num_edge_types) + + def forward( + self, + x: torch.Tensor, + csc: Tuple[torch.Tensor, torch.Tensor, int], + max_num_neighbors: Optional[int] = None, + ) -> torch.Tensor: + r"""Runs the forward pass of the module. + + Args: + x (torch.Tensor): The node features. + csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC + representation of a graph, given as a tuple of + :obj:`(row, colptr, num_src_nodes)`. Use the + :meth:`CuGraphModule.to_csc` method to convert an + :obj:`edge_index` representation to the desired format. + max_num_neighbors (int, optional): The maximum number of neighbors + of a target node. It is only effective when operating in a + bipartite graph. When not given, the value will be computed + on-the-fly, leading to slightly worse performance. + (default: :obj:`None`) + """ + raise NotImplementedError diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py new file mode 100644 index 00000000000..9382898d32b --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py @@ -0,0 +1,106 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +from cugraph.utilities.utils import import_optional + +from .base import BaseConv + +torch = import_optional("torch") +nn = import_optional("torch.nn") +torch_geometric = import_optional("torch_geometric") + +try: + from pylibcugraphops.pytorch.operators import mha_gat_n2n as GATConvAgg +except ImportError: + pass + + +class GATConv(BaseConv): # pragma: no cover + r"""The graph attentional operator from the `"Graph Attention Networks" + `_ paper. + + :class:`GATConv` is an optimized version of + :class:`~torch_geometric.nn.conv.GATConv` based on the :obj:`cugraph-ops` + package that fuses message passing computation for accelerated execution + and lower memory footprint. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + heads: int = 1, + concat: bool = True, + negative_slope: float = 0.2, + bias: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.concat = concat + self.negative_slope = negative_slope + + self.lin = nn.Linear(in_channels, heads * out_channels, bias=False) + self.att = nn.Parameter(torch.Tensor(2 * heads * out_channels)) + + if bias and concat: + self.bias = nn.Parameter(torch.Tensor(heads * out_channels)) + elif bias and not concat: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self): + self.lin.reset_parameters() + gain = torch.nn.init.calculate_gain("relu") + torch.nn.init.xavier_normal_( + self.att.view(2, self.heads, self.out_channels), gain=gain + ) + torch_geometric.nn.inits.zeros(self.bias) + + def forward( + self, + x: torch.Tensor, + csc: Tuple[torch.Tensor, torch.Tensor, int], + max_num_neighbors: Optional[int] = None, + ) -> torch.Tensor: + graph = self.get_cugraph(csc, max_num_neighbors) + + x = self.lin(x) + + out = GATConvAgg( + x, + self.att, + graph, + self.heads, + "LeakyReLU", + self.negative_slope, + self.concat, + ) + + if self.bias is not None: + out = out + self.bias + + return out + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.in_channels}, " + f"{self.out_channels}, heads={self.heads})" + ) diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py new file mode 100644 index 00000000000..d887d2df236 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
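+
+# NOTE (illustrative sketch, not part of the committed API docs): the conv
+# layers in this package share one call pattern — convert a COO
+# ``edge_index`` to the CSC format expected by cugraph-ops via ``to_csc``,
+# then call the layer with the node features and that CSC tuple. For
+# example, with the GATConv added in this patch (``num_nodes``,
+# ``in_channels``, ``out_channels`` and ``edge_index`` are placeholders;
+# cugraph-ops requires CUDA tensors):
+#
+#   import torch
+#   from cugraph_pyg.nn import GATConv
+#
+#   x = torch.rand(num_nodes, in_channels, device="cuda")
+#   csc = GATConv.to_csc(edge_index, size=(num_nodes, num_nodes))
+#   conv = GATConv(in_channels, out_channels, heads=2).cuda()
+#   out = conv(x, csc)  # shape [num_nodes, 2 * out_channels] with concat=True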
+ +from .base import BaseConv + + +class TransformerConv(BaseConv): + r"""The graph transformer operator from the `"Masked Label Prediction: + Unified Message Passing Model for Semi-Supervised Classification" + `_ paper + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + heads: int = 1, + concat: bool = True, + root_weight: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.root_weight = root_weight + self.concat = concat diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py new file mode 100644 index 00000000000..6be5a6e48ef --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py @@ -0,0 +1,77 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +try: + from torch_geometric.nn import GATConv +except ModuleNotFoundError: + pytest.skip("PyG not available", allow_module_level=True) + +from cugraph.utilities.utils import import_optional +from cugraph_pyg.nn import GATConv as CuGraphGATConv + +torch = import_optional("torch") + + +@pytest.mark.parametrize("bias", [True, False]) +@pytest.mark.parametrize("bipartite", [True, False]) +@pytest.mark.parametrize("concat", [True, False]) +@pytest.mark.parametrize("heads", [1, 2, 3]) +@pytest.mark.parametrize("max_num_neighbors", [8, None]) +def test_gat_conv_equality(bias, bipartite, concat, heads, max_num_neighbors): + in_channels, out_channels = (5, 2) + kwargs = dict(bias=bias, concat=concat) + + size = (10, 8) if bipartite else (10, 10) + x = torch.rand(size[0], in_channels, device="cuda") + edge_index = torch.tensor( + [ + [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9], + [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7], + ], + device="cuda", + ) + + conv1 = GATConv( + in_channels, out_channels, heads, add_self_loops=False, **kwargs + ).cuda() + conv2 = CuGraphGATConv(in_channels, out_channels, heads, **kwargs).cuda() + + with torch.no_grad(): + conv2.lin.weight.data[:, :] = conv1.lin_src.weight.data + conv2.att.data[: heads * out_channels] = conv1.att_src.data.flatten() + conv2.att.data[heads * out_channels :] = conv1.att_dst.data.flatten() + + if bipartite: + out1 = conv1((x, x[: size[1]]), edge_index) + else: + out1 = conv1(x, edge_index) + + csc = CuGraphGATConv.to_csc(edge_index, size) + out2 = conv2(x, csc, max_num_neighbors=max_num_neighbors) + assert torch.allclose(out1, out2, atol=1e-3) + + grad_output = torch.rand_like(out1) + out1.backward(grad_output) + out2.backward(grad_output) + + assert torch.allclose(conv1.lin_src.weight.grad, conv2.lin.weight.grad, atol=1e-3) + assert torch.allclose( + conv1.att_src.grad.flatten(), conv2.att.grad[: heads * out_channels], atol=1e-3 + ) + assert torch.allclose( + conv1.att_dst.grad.flatten(), conv2.att.grad[heads * out_channels :], atol=1e-3 + ) + if bias: + assert torch.allclose(conv1.bias.grad, conv2.bias.grad, atol=1e-3) From 
f687207be339c8ce00d247b79e38d066f9e81cbc Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Tue, 11 Apr 2023 21:48:47 -0400 Subject: [PATCH 02/17] test backward accuracy --- .../cugraph_pyg/nn/conv/gat_conv.py | 10 +-- .../cugraph_pyg/nn/conv/transformer_conv.py | 53 ++++++++++++++ .../tests/nn/test_transformer_conv.py | 73 +++++++++++++++++++ 3 files changed, 130 insertions(+), 6 deletions(-) create mode 100644 python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py index 9382898d32b..53dabac6651 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py @@ -13,18 +13,15 @@ from typing import Optional, Tuple +from pylibcugraphops.pytorch.operators import mha_gat_n2n as GATConvAgg + from cugraph.utilities.utils import import_optional from .base import BaseConv torch = import_optional("torch") nn = import_optional("torch.nn") -torch_geometric = import_optional("torch_geometric") -try: - from pylibcugraphops.pytorch.operators import mha_gat_n2n as GATConvAgg -except ImportError: - pass class GATConv(BaseConv): # pragma: no cover @@ -72,7 +69,8 @@ def reset_parameters(self): torch.nn.init.xavier_normal_( self.att.view(2, self.heads, self.out_channels), gain=gain ) - torch_geometric.nn.inits.zeros(self.bias) + if self.bias is not None: + self.bias.data.fill_(0.) def forward( self, diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py index d887d2df236..f37b1f93657 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py @@ -10,9 +10,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
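+# Example usage of the layer implemented below (a minimal sketch;
+# ``num_nodes``, ``in_channels``, ``out_channels`` and ``edge_index`` are
+# illustrative placeholders, and a CUDA device is assumed):
+#
+#   conv = TransformerConv(in_channels, out_channels, heads=2).cuda()
+#   csc = TransformerConv.to_csc(edge_index, size=(num_nodes, num_nodes))
+#   out = conv(x, csc)  # x: CUDA tensor of shape [num_nodes, in_channels]
+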
+from typing import Optional, Tuple, Union + +from pylibcugraphops.pytorch.operators import mha_simple_n2n as TransformerConvAgg + +from cugraph.utilities.utils import import_optional from .base import BaseConv +torch = import_optional("torch") +nn = import_optional("torch.nn") + class TransformerConv(BaseConv): r"""The graph transformer operator from the `"Masked Label Prediction: @@ -26,6 +34,7 @@ def __init__( out_channels: int, heads: int = 1, concat: bool = True, + edge_dim: Optional[int] = None, root_weight: bool = True, ): super().__init__() @@ -35,3 +44,47 @@ def __init__( self.heads = heads self.root_weight = root_weight self.concat = concat + self.edge_dim = edge_dim + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lin_key = nn.Linear(in_channels[0], heads * out_channels) + self.lin_query = nn.Linear(in_channels[1], heads * out_channels) + self.lin_value = nn.Linear(in_channels[0], heads * out_channels) + + if edge_dim is not None: + self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) + else: + self.lin_edge = self.register_parameter('lin_edge', None) + + self.reset_parameters() + + def reset_parameters(self): + self.lin_key.reset_parameters() + self.lin_query.reset_parameters() + self.lin_value.reset_parameters() + if self.lin_edge is not None: + self.lin_edge.reset_parameters() + + def forward(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + csc: Tuple[torch.Tensor, torch.Tensor, int], + edge_attr: Optional[torch.Tensor] = None, + max_num_neighbors: Optional[int] = None): + graph = self.get_cugraph(csc, max_num_neighbors) + + if isinstance(x, torch.Tensor): + x = (x, x) + + query = self.lin_query(x[1]) + key = self.lin_key(x[0]) + value = self.lin_value(x[0]) + + out = TransformerConvAgg(key, query, value, graph, self.heads, self.concat, + edge_emb=edge_attr, norm_by_dim=False, score_bias=None) + + return out + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, heads={self.heads})') diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py new file mode 100644 index 00000000000..919ff6d6005 --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
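+
+# These tests check numerical parity against the reference PyG layer: the
+# weights of torch_geometric.nn.TransformerConv are copied into the
+# cugraph-ops-based layer, then forward outputs and parameter gradients
+# are compared with torch.allclose. A CUDA device and pylibcugraphops are
+# required to run them.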
+ +import pytest + +try: + from torch_geometric.nn import TransformerConv +except ModuleNotFoundError: + pytest.skip("PyG not available", allow_module_level=True) + +from cugraph.utilities.utils import import_optional +from cugraph_pyg.nn import TransformerConv as CuGraphTransformerConv + +torch = import_optional("torch") + +@pytest.mark.parametrize("bipartite", [False]) +@pytest.mark.parametrize("concat", [True]) +@pytest.mark.parametrize("heads", [1, 2, 3]) +@pytest.mark.parametrize("max_num_neighbors", [8, None]) +def test_transformer_conv_equality(bipartite, concat, heads, max_num_neighbors): + in_channels, out_channels = (5, 2) + kwargs = dict(concat=concat) + + size = (10, 8) if bipartite else (10, 10) + x = torch.rand(size[0], in_channels, device="cuda") + edge_index = torch.tensor( + [ + [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9], + [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7], + ], + device="cuda", + ) + + conv1 = TransformerConv(in_channels, out_channels, heads, bias=False, root_weight=False, **kwargs).cuda() + conv2 = CuGraphTransformerConv(in_channels, out_channels, heads, **kwargs).cuda() + + with torch.no_grad(): + conv2.lin_query.weight.data[:, :] = conv1.lin_query.weight.data + conv2.lin_key.weight.data[:, :] = conv1.lin_key.weight.data + conv2.lin_value.weight.data[:, :] = conv1.lin_value.weight.data + conv2.lin_query.bias.data[:] = conv1.lin_query.bias.data + conv2.lin_key.bias.data[:] = conv1.lin_key.bias.data + conv2.lin_value.bias.data[:] = conv1.lin_value.bias.data + + if bipartite: + out1 = conv1((x, x[: size[1]]), edge_index) + else: + out1 = conv1(x, edge_index) + + csc = CuGraphTransformerConv.to_csc(edge_index, size) + out2 = conv2(x, csc, max_num_neighbors=max_num_neighbors) + assert torch.allclose(out1, out2, atol=1e-2) + + grad_output = torch.rand_like(out1) + out1.backward(grad_output) + out2.backward(grad_output) + + assert torch.allclose(conv1.lin_query.weight.grad, conv2.lin_query.weight.grad, atol=1e-3) + assert torch.allclose(conv1.lin_key.weight.grad, conv2.lin_key.weight.grad, atol=1e-3) + assert torch.allclose(conv1.lin_value.weight.grad, conv2.lin_value.weight.grad, atol=1e-3) + assert torch.allclose(conv1.lin_query.bias.grad, conv2.lin_query.bias.grad, atol=1e-3) + assert torch.allclose(conv1.lin_key.bias.grad, conv2.lin_key.bias.grad, atol=1e-3) + assert torch.allclose(conv1.lin_value.bias.grad, conv2.lin_value.bias.grad, atol=1e-3) From 30883e12145d39dac12ecad0ecd31ff5b019049b Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Wed, 12 Apr 2023 22:18:56 -0400 Subject: [PATCH 03/17] test bipartite variant --- .../cugraph-pyg/cugraph_pyg/nn/conv/base.py | 9 +++ .../cugraph_pyg/nn/conv/gat_conv.py | 3 +- .../cugraph_pyg/nn/conv/transformer_conv.py | 39 ++++++++---- .../tests/nn/test_transformer_conv.py | 60 ++++++++++++------- 4 files changed, 77 insertions(+), 34 deletions(-) diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py index 91f64200e47..df6cc0f65c5 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py @@ -21,6 +21,7 @@ try: # pragma: no cover from pylibcugraphops.pytorch import ( + BipartiteCSC, SampledCSC, SampledHeteroCSC, StaticCSC, @@ -92,6 +93,7 @@ def get_cugraph( self, csc: Tuple[torch.Tensor, torch.Tensor, int], max_num_neighbors: Optional[int] = None, + bipartite: Optional[bool] = False, ) -> Any: r"""Constructs a :obj:`cugraph` graph object from CSC representation. Supports both bipartite and non-bipartite graphs. 
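# Note on the new `bipartite` flag introduced by this patch: when the source
# and destination node sets differ (e.g. a sampled neighborhood block),
# passing `bipartite=True` makes `get_cugraph` wrap the CSC tuple in
# `pylibcugraphops.pytorch.BipartiteCSC` instead of `StaticCSC`/`SampledCSC`.
# A minimal sketch (the sizes are illustrative placeholders):
#
#   csc = conv.to_csc(edge_index, size=(10, 8))   # 10 src nodes, 8 dst nodes
#   graph = conv.get_cugraph(csc, bipartite=True)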
@@ -115,6 +117,9 @@ def get_cugraph( f"based processing (got CPU tensor)" ) + if bipartite: + return BipartiteCSC(colptr, row, num_src_nodes) + if num_src_nodes != colptr.numel() - 1: # Bipartite graph: if max_num_neighbors is None: max_num_neighbors = int((colptr[1:] - colptr[:-1]).max()) @@ -129,6 +134,7 @@ def get_typed_cugraph( edge_type: torch.Tensor, num_edge_types: Optional[int] = None, max_num_neighbors: Optional[int] = None, + bipartite: Optional[bool] = False, ) -> Any: r"""Constructs a typed :obj:`cugraph` graph object from a CSC representation where each edge corresponds to a given edge type. @@ -155,6 +161,9 @@ def get_typed_cugraph( row, colptr, num_src_nodes = csc edge_type = edge_type.int() + if bipartite: + raise NotImplementedError + if num_src_nodes != colptr.numel() - 1: # Bipartite graph: if max_num_neighbors is None: max_num_neighbors = int((colptr[1:] - colptr[:-1]).max()) diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py index 53dabac6651..586a7599567 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py @@ -23,7 +23,6 @@ nn = import_optional("torch.nn") - class GATConv(BaseConv): # pragma: no cover r"""The graph attentional operator from the `"Graph Attention Networks" `_ paper. @@ -70,7 +69,7 @@ def reset_parameters(self): self.att.view(2, self.heads, self.out_channels), gain=gain ) if self.bias is not None: - self.bias.data.fill_(0.) + self.bias.data.fill_(0.0) def forward( self, diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py index f37b1f93657..5b267f8459a 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py @@ -30,7 +30,7 @@ class TransformerConv(BaseConv): def __init__( self, - in_channels: int, + in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, @@ -56,7 +56,7 @@ def __init__( if edge_dim is not None: self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) else: - self.lin_edge = self.register_parameter('lin_edge', None) + self.lin_edge = self.register_parameter("lin_edge", None) self.reset_parameters() @@ -67,24 +67,39 @@ def reset_parameters(self): if self.lin_edge is not None: self.lin_edge.reset_parameters() - def forward(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - csc: Tuple[torch.Tensor, torch.Tensor, int], - edge_attr: Optional[torch.Tensor] = None, - max_num_neighbors: Optional[int] = None): - graph = self.get_cugraph(csc, max_num_neighbors) + def forward( + self, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + csc: Tuple[torch.Tensor, torch.Tensor, int], + edge_attr: Optional[torch.Tensor] = None, + max_num_neighbors: Optional[int] = None, + ): + bipartite = not isinstance(x, torch.Tensor) + graph = self.get_cugraph(csc, max_num_neighbors, bipartite=bipartite) - if isinstance(x, torch.Tensor): + if not bipartite: x = (x, x) query = self.lin_query(x[1]) key = self.lin_key(x[0]) value = self.lin_value(x[0]) - out = TransformerConvAgg(key, query, value, graph, self.heads, self.concat, - edge_emb=edge_attr, norm_by_dim=False, score_bias=None) + out = TransformerConvAgg( + key, + query, + value, + graph, + self.heads, + self.concat, + edge_emb=edge_attr, + norm_by_dim=False, + score_bias=None, + ) return out def __repr__(self) -> str: - return 
(f'{self.__class__.__name__}({self.in_channels}, ' - f'{self.out_channels}, heads={self.heads})') + return ( + f"{self.__class__.__name__}({self.in_channels}, " + f"{self.out_channels}, heads={self.heads})" + ) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py index 919ff6d6005..949d7b8fc1c 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py @@ -23,16 +23,25 @@ torch = import_optional("torch") -@pytest.mark.parametrize("bipartite", [False]) + +@pytest.mark.parametrize("bipartite", [True, False]) @pytest.mark.parametrize("concat", [True]) @pytest.mark.parametrize("heads", [1, 2, 3]) -@pytest.mark.parametrize("max_num_neighbors", [8, None]) -def test_transformer_conv_equality(bipartite, concat, heads, max_num_neighbors): - in_channels, out_channels = (5, 2) +def test_transformer_conv_equality(bipartite, concat, heads): + out_channels = 2 + size = (10, 10) kwargs = dict(concat=concat) - size = (10, 8) if bipartite else (10, 10) - x = torch.rand(size[0], in_channels, device="cuda") + if bipartite: + in_channels = (5, 3) + x = ( + torch.rand(size[0], in_channels[0], device="cuda"), + torch.rand(size[1], in_channels[1], device="cuda"), + ) + else: + in_channels = 5 + x = torch.rand(size[0], in_channels, device="cuda") + edge_index = torch.tensor( [ [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9], @@ -41,7 +50,9 @@ def test_transformer_conv_equality(bipartite, concat, heads, max_num_neighbors): device="cuda", ) - conv1 = TransformerConv(in_channels, out_channels, heads, bias=False, root_weight=False, **kwargs).cuda() + conv1 = TransformerConv( + in_channels, out_channels, heads, bias=False, root_weight=False, **kwargs + ).cuda() conv2 = CuGraphTransformerConv(in_channels, out_channels, heads, **kwargs).cuda() with torch.no_grad(): @@ -52,22 +63,31 @@ def test_transformer_conv_equality(bipartite, concat, heads, max_num_neighbors): conv2.lin_key.bias.data[:] = conv1.lin_key.bias.data conv2.lin_value.bias.data[:] = conv1.lin_value.bias.data - if bipartite: - out1 = conv1((x, x[: size[1]]), edge_index) - else: - out1 = conv1(x, edge_index) - + out1 = conv1(x, edge_index) csc = CuGraphTransformerConv.to_csc(edge_index, size) - out2 = conv2(x, csc, max_num_neighbors=max_num_neighbors) - assert torch.allclose(out1, out2, atol=1e-2) + out2 = conv2(x, csc) + + atol = 1e-2 + + assert torch.allclose(out1, out2, atol=atol) grad_output = torch.rand_like(out1) out1.backward(grad_output) out2.backward(grad_output) - assert torch.allclose(conv1.lin_query.weight.grad, conv2.lin_query.weight.grad, atol=1e-3) - assert torch.allclose(conv1.lin_key.weight.grad, conv2.lin_key.weight.grad, atol=1e-3) - assert torch.allclose(conv1.lin_value.weight.grad, conv2.lin_value.weight.grad, atol=1e-3) - assert torch.allclose(conv1.lin_query.bias.grad, conv2.lin_query.bias.grad, atol=1e-3) - assert torch.allclose(conv1.lin_key.bias.grad, conv2.lin_key.bias.grad, atol=1e-3) - assert torch.allclose(conv1.lin_value.bias.grad, conv2.lin_value.bias.grad, atol=1e-3) + assert torch.allclose( + conv1.lin_query.weight.grad, conv2.lin_query.weight.grad, atol=atol + ) + assert torch.allclose( + conv1.lin_key.weight.grad, conv2.lin_key.weight.grad, atol=atol + ) + assert torch.allclose( + conv1.lin_value.weight.grad, conv2.lin_value.weight.grad, atol=atol + ) + assert torch.allclose( + conv1.lin_query.bias.grad, conv2.lin_query.bias.grad, atol=atol + ) + 
assert torch.allclose(conv1.lin_key.bias.grad, conv2.lin_key.bias.grad, atol=atol) + assert torch.allclose( + conv1.lin_value.bias.grad, conv2.lin_value.bias.grad, atol=atol + ) From 082d69f6b4a9d7d16df762c464dccd65bd378b7c Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Thu, 13 Apr 2023 20:51:34 -0400 Subject: [PATCH 04/17] add self connection --- .../cugraph_pyg/nn/conv/transformer_conv.py | 36 ++++++++++++++++--- .../tests/nn/test_transformer_conv.py | 26 +++++++------- 2 files changed, 44 insertions(+), 18 deletions(-) diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py index 5b267f8459a..f8b45674552 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py @@ -34,7 +34,9 @@ def __init__( out_channels: int, heads: int = 1, concat: bool = True, + beta: bool = False, edge_dim: Optional[int] = None, + bias: bool = True, root_weight: bool = True, ): super().__init__() @@ -42,6 +44,7 @@ def __init__( self.in_channels = in_channels self.out_channels = out_channels self.heads = heads + self.beta = beta and root_weight self.root_weight = root_weight self.concat = concat self.edge_dim = edge_dim @@ -58,6 +61,19 @@ def __init__( else: self.lin_edge = self.register_parameter("lin_edge", None) + if concat: + self.lin_skip = nn.Linear(in_channels[1], heads * out_channels, bias=bias) + if self.beta: + self.lin_beta = nn.Linear(3 * heads * out_channels, 1, bias=bias) + else: + self.lin_beta = self.register_parameter("lin_beta", None) + else: + self.lin_skip = nn.Linear(in_channels[1], out_channels, bias=bias) + if self.beta: + self.lin_beta = nn.Linear(3 * out_channels, 1, bias=False) + else: + self.lin_beta = self.register_parameter("lin_beta", None) + self.reset_parameters() def reset_parameters(self): @@ -66,16 +82,19 @@ def reset_parameters(self): self.lin_value.reset_parameters() if self.lin_edge is not None: self.lin_edge.reset_parameters() + if self.lin_skip is not None: + self.lin_skip.reset_parameters() + if self.lin_beta is not None: + self.lin_beta.reset_parameters() def forward( self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], csc: Tuple[torch.Tensor, torch.Tensor, int], edge_attr: Optional[torch.Tensor] = None, - max_num_neighbors: Optional[int] = None, - ): + ) -> torch.Tensor: bipartite = not isinstance(x, torch.Tensor) - graph = self.get_cugraph(csc, max_num_neighbors, bipartite=bipartite) + graph = self.get_cugraph(csc, bipartite=bipartite) if not bipartite: x = (x, x) @@ -92,10 +111,19 @@ def forward( self.heads, self.concat, edge_emb=edge_attr, - norm_by_dim=False, + norm_by_dim=True, score_bias=None, ) + if self.root_weight: + x_r = self.lin_skip(x[1]) + if self.lin_beta is not None: + beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1)) + beta = beta.sigmoid() + out = beta * x_r + (1 - beta) * out + else: + out = out + x_r + return out def __repr__(self) -> str: diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py index 949d7b8fc1c..c36df1ea410 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py @@ -25,12 +25,12 @@ @pytest.mark.parametrize("bipartite", [True, False]) -@pytest.mark.parametrize("concat", [True]) +@pytest.mark.parametrize("concat", [True, False]) @pytest.mark.parametrize("heads", [1, 2, 3]) def 
test_transformer_conv_equality(bipartite, concat, heads): out_channels = 2 size = (10, 10) - kwargs = dict(concat=concat) + kwargs = dict(concat=concat, bias=False, root_weight=False) if bipartite: in_channels = (5, 3) @@ -44,30 +44,28 @@ def test_transformer_conv_equality(bipartite, concat, heads): edge_index = torch.tensor( [ - [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9], - [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7], + [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9, 3, 4, 5], + [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7, 5, 5, 6], ], device="cuda", ) - conv1 = TransformerConv( - in_channels, out_channels, heads, bias=False, root_weight=False, **kwargs - ).cuda() + conv1 = TransformerConv(in_channels, out_channels, heads, **kwargs).cuda() conv2 = CuGraphTransformerConv(in_channels, out_channels, heads, **kwargs).cuda() with torch.no_grad(): - conv2.lin_query.weight.data[:, :] = conv1.lin_query.weight.data - conv2.lin_key.weight.data[:, :] = conv1.lin_key.weight.data - conv2.lin_value.weight.data[:, :] = conv1.lin_value.weight.data - conv2.lin_query.bias.data[:] = conv1.lin_query.bias.data - conv2.lin_key.bias.data[:] = conv1.lin_key.bias.data - conv2.lin_value.bias.data[:] = conv1.lin_value.bias.data + conv2.lin_query.weight.data = conv1.lin_query.weight.data.detach().clone() + conv2.lin_key.weight.data = conv1.lin_key.weight.data.detach().clone() + conv2.lin_value.weight.data = conv1.lin_value.weight.data.detach().clone() + conv2.lin_query.bias.data = conv1.lin_query.bias.data.detach().clone() + conv2.lin_key.bias.data = conv1.lin_key.bias.data.detach().clone() + conv2.lin_value.bias.data = conv1.lin_value.bias.data.detach().clone() out1 = conv1(x, edge_index) csc = CuGraphTransformerConv.to_csc(edge_index, size) out2 = conv2(x, csc) - atol = 1e-2 + atol = 1e-5 assert torch.allclose(out1, out2, atol=atol) From d20e7137f330b43a1567e07d235e64f96c6766a2 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Fri, 14 Apr 2023 14:51:41 -0400 Subject: [PATCH 05/17] docstring --- .../cugraph-pyg/cugraph_pyg/nn/conv/base.py | 38 +++++----- .../cugraph_pyg/nn/conv/transformer_conv.py | 74 +++++++++++++++++++ 2 files changed, 92 insertions(+), 20 deletions(-) diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py index df6cc0f65c5..2be6851460a 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py @@ -34,7 +34,7 @@ class BaseConv(torch.nn.Module): # pragma: no cover - r"""An abstract base class for implementing cugraph message passing layers.""" + r"""An abstract base class for implementing cugraph-ops message passing layers.""" def __init__(self): super().__init__() @@ -58,7 +58,7 @@ def to_csc( Tuple[Tuple[torch.Tensor, torch.Tensor, int], torch.Tensor], ]: r"""Returns a CSC representation of an :obj:`edge_index` tensor to be - used as input to a :class:`CuGraphModule`. + used as input to cugraph-ops conv layers. Args: edge_index (torch.Tensor): The edge indices. @@ -93,21 +93,23 @@ def get_cugraph( self, csc: Tuple[torch.Tensor, torch.Tensor, int], max_num_neighbors: Optional[int] = None, - bipartite: Optional[bool] = False, + bipartite: bool = False, ) -> Any: - r"""Constructs a :obj:`cugraph` graph object from CSC representation. + r"""Constructs a :obj:`cugraph-ops` graph object from CSC representation. Supports both bipartite and non-bipartite graphs. 
Args: csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC representation of a graph, given as a tuple of :obj:`(row, colptr, num_src_nodes)`. Use the - :meth:`CuGraphModule.to_csc` method to convert an - :obj:`edge_index` representation to the desired format. + :meth:`to_csc` method to convert an :obj:`edge_index` + representation to the desired format. max_num_neighbors (int, optional): The maximum number of neighbors of a target node. It is only effective when operating in a bipartite graph. When not given, will be computed on-the-fly, leading to slightly worse performance. (default: :obj:`None`) + bipartite (bool): If set to :obj:`True`, will create the bipartite + structure in cugraph-ops. (default: :obj:`False`) """ row, colptr, num_src_nodes = csc @@ -120,7 +122,7 @@ def get_cugraph( if bipartite: return BipartiteCSC(colptr, row, num_src_nodes) - if num_src_nodes != colptr.numel() - 1: # Bipartite graph: + if num_src_nodes != colptr.numel() - 1: if max_num_neighbors is None: max_num_neighbors = int((colptr[1:] - colptr[:-1]).max()) @@ -134,7 +136,7 @@ def get_typed_cugraph( edge_type: torch.Tensor, num_edge_types: Optional[int] = None, max_num_neighbors: Optional[int] = None, - bipartite: Optional[bool] = False, + bipartite: bool = False, ) -> Any: r"""Constructs a typed :obj:`cugraph` graph object from a CSC representation where each edge corresponds to a given edge type. @@ -144,8 +146,8 @@ def get_typed_cugraph( csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC representation of a graph, given as a tuple of :obj:`(row, colptr, num_src_nodes)`. Use the - :meth:`CuGraphModule.to_csc` method to convert an - :obj:`edge_index` representation to the desired format. + :meth:`to_csc` method to convert an :obj:`edge_index` + representation to the desired format. edge_type (torch.Tensor): The edge type. num_edge_types (int, optional): The maximum number of edge types. When not given, will be computed on-the-fly, leading to @@ -154,6 +156,8 @@ def get_typed_cugraph( of a target node. It is only effective when operating in a bipartite graph. When not given, will be computed on-the-fly, leading to slightly worse performance. (default: :obj:`None`) + bipartite (bool): If set to :obj:`True`, will create the bipartite + structure in cugraph-ops. (default: :obj:`False`) """ if num_edge_types is None: num_edge_types = int(edge_type.max()) + 1 @@ -164,7 +168,7 @@ def get_typed_cugraph( if bipartite: raise NotImplementedError - if num_src_nodes != colptr.numel() - 1: # Bipartite graph: + if num_src_nodes != colptr.numel() - 1: if max_num_neighbors is None: max_num_neighbors = int((colptr[1:] - colptr[:-1]).max()) @@ -176,9 +180,8 @@ def get_typed_cugraph( def forward( self, - x: torch.Tensor, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], csc: Tuple[torch.Tensor, torch.Tensor, int], - max_num_neighbors: Optional[int] = None, ) -> torch.Tensor: r"""Runs the forward pass of the module. @@ -187,12 +190,7 @@ def forward( csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC representation of a graph, given as a tuple of :obj:`(row, colptr, num_src_nodes)`. Use the - :meth:`CuGraphModule.to_csc` method to convert an - :obj:`edge_index` representation to the desired format. - max_num_neighbors (int, optional): The maximum number of neighbors - of a target node. It is only effective when operating in a - bipartite graph. When not given, the value will be computed - on-the-fly, leading to slightly worse performance. 
- (default: :obj:`None`) + :meth:`to_csc` method to convert an :obj:`edge_index` + representation to the desired format. """ raise NotImplementedError diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py index f8b45674552..273e4993053 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py @@ -26,6 +26,68 @@ class TransformerConv(BaseConv): r"""The graph transformer operator from the `"Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification" `_ paper + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j}, + + where the attention coefficients :math:`\alpha_{i,j}` are computed via + multi-head dot product attention: + + .. math:: + \alpha_{i,j} = \textrm{softmax} \left( + \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)} + {\sqrt{d}} \right) + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + concat (bool, optional): If set to :obj:`False`, the multi-head + attentions are averaged instead of concatenated. + (default: :obj:`True`) + beta (bool, optional): If set, will combine aggregation and + skip information via + + .. math:: + \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i + + (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)} + \alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i} + + with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top} + [ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1 + \mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`) + edge_dim (int, optional): Edge feature dimensionality (in case + there are any). Edge features are added to the keys after + linear transformation, that is, prior to computing the + attention dot product. They are also added to final values + after the same linear transformation. The model is: + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left( + \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij} + \right), + + where the attention coefficients :math:`\alpha_{i,j}` are now + computed via: + + .. math:: + \alpha_{i,j} = \textrm{softmax} \left( + \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} + (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})} + {\sqrt{d}} \right) + + (default :obj:`None`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + root_weight (bool, optional): If set to :obj:`False`, the layer will + not add the transformed root node features to the output and the + option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`) """ def __init__( @@ -93,6 +155,18 @@ def forward( csc: Tuple[torch.Tensor, torch.Tensor, int], edge_attr: Optional[torch.Tensor] = None, ) -> torch.Tensor: + r"""Runs the forward pass of the module. + + Args: + x (torch.Tensor or tuple): The node features. Can be a tuple of + tensors denoting source and destination node features. 
+ csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC + representation of a graph, given as a tuple of + :obj:`(row, colptr, num_src_nodes)`. Use the + :meth:`to_csc` method to convert an :obj:`edge_index` + representation to the desired format. + edge_attr: (torch.Tensor, optional) The edge features. + """ bipartite = not isinstance(x, torch.Tensor) graph = self.get_cugraph(csc, bipartite=bipartite) From 8f46dfc1a07591301242db2ae32b2fc5271c2d13 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Fri, 14 Apr 2023 17:54:47 -0400 Subject: [PATCH 06/17] support bipartite input features in GATConv --- .../cugraph-pyg/cugraph_pyg/nn/conv/base.py | 12 +- .../cugraph_pyg/nn/conv/gat_conv.py | 158 +++++++++++++++--- .../cugraph_pyg/nn/conv/transformer_conv.py | 2 +- 3 files changed, 139 insertions(+), 33 deletions(-) diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py index 2be6851460a..bec50792131 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/base.py @@ -92,8 +92,8 @@ def to_csc( def get_cugraph( self, csc: Tuple[torch.Tensor, torch.Tensor, int], - max_num_neighbors: Optional[int] = None, bipartite: bool = False, + max_num_neighbors: Optional[int] = None, ) -> Any: r"""Constructs a :obj:`cugraph-ops` graph object from CSC representation. Supports both bipartite and non-bipartite graphs. @@ -104,12 +104,12 @@ def get_cugraph( :obj:`(row, colptr, num_src_nodes)`. Use the :meth:`to_csc` method to convert an :obj:`edge_index` representation to the desired format. + bipartite (bool): If set to :obj:`True`, will create the bipartite + structure in cugraph-ops. (default: :obj:`False`) max_num_neighbors (int, optional): The maximum number of neighbors of a target node. It is only effective when operating in a bipartite graph. When not given, will be computed on-the-fly, leading to slightly worse performance. (default: :obj:`None`) - bipartite (bool): If set to :obj:`True`, will create the bipartite - structure in cugraph-ops. (default: :obj:`False`) """ row, colptr, num_src_nodes = csc @@ -135,8 +135,8 @@ def get_typed_cugraph( csc: Tuple[torch.Tensor, torch.Tensor, int], edge_type: torch.Tensor, num_edge_types: Optional[int] = None, - max_num_neighbors: Optional[int] = None, bipartite: bool = False, + max_num_neighbors: Optional[int] = None, ) -> Any: r"""Constructs a typed :obj:`cugraph` graph object from a CSC representation where each edge corresponds to a given edge type. @@ -152,12 +152,12 @@ def get_typed_cugraph( num_edge_types (int, optional): The maximum number of edge types. When not given, will be computed on-the-fly, leading to slightly worse performance. (default: :obj:`None`) + bipartite (bool): If set to :obj:`True`, will create the bipartite + structure in cugraph-ops. (default: :obj:`False`) max_num_neighbors (int, optional): The maximum number of neighbors of a target node. It is only effective when operating in a bipartite graph. When not given, will be computed on-the-fly, leading to slightly worse performance. (default: :obj:`None`) - bipartite (bool): If set to :obj:`True`, will create the bipartite - structure in cugraph-ops. 
(default: :obj:`False`) """ if num_edge_types is None: num_edge_types = int(edge_type.max()) + 1 diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py index 586a7599567..a0807b81183 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py @@ -10,10 +10,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Tuple, Union -from typing import Optional, Tuple - -from pylibcugraphops.pytorch.operators import mha_gat_n2n as GATConvAgg +from pylibcugraphops.pytorch.operators import mha_gat_n2n, mha_gat_n2n_bipartite from cugraph.utilities.utils import import_optional @@ -23,23 +22,68 @@ nn = import_optional("torch.nn") -class GATConv(BaseConv): # pragma: no cover +class GATConv(BaseConv): r"""The graph attentional operator from the `"Graph Attention Networks" `_ paper. - :class:`GATConv` is an optimized version of - :class:`~torch_geometric.nn.conv.GATConv` based on the :obj:`cugraph-ops` - package that fuses message passing computation for accelerated execution - and lower memory footprint. + .. math:: + \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j}, + + where the attention coefficients :math:`\alpha_{i,j}` are computed as + + .. math:: + \alpha_{i,j} = + \frac{ + \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} + [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] + \right)\right)} + {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} + \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} + [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] + \right)\right)}. + + If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`, + the attention coefficients :math:`\alpha_{i,j}` are computed as + + .. math:: + \alpha_{i,j} = + \frac{ + \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} + [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j + \, \Vert \, \mathbf{\Theta}_{e} \mathbf{e}_{i,j}]\right)\right)} + {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} + \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} + [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k + \, \Vert \, \mathbf{\Theta}_{e} \mathbf{e}_{i,k}]\right)\right)}. + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + concat (bool, optional): If set to :obj:`False`, the multi-head + attentions are averaged instead of concatenated. + (default: :obj:`True`) + negative_slope (float, optional): LeakyReLU angle of the negative + slope. (default: :obj:`0.2`) + edge_dim (int, optional): Edge feature dimensionality (in case + there are any). (default: :obj:`None`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. 
(default: :obj:`True`) """ def __init__( self, - in_channels: int, + in_channels: Union[int, Tuple[int, int]], out_channels: int, heads: int = 1, concat: bool = True, negative_slope: float = 0.2, + edge_dim: Optional[int] = None, bias: bool = True, ): super().__init__() @@ -49,9 +93,20 @@ def __init__( self.heads = heads self.concat = concat self.negative_slope = negative_slope + self.edge_dim = edge_dim - self.lin = nn.Linear(in_channels, heads * out_channels, bias=False) - self.att = nn.Parameter(torch.Tensor(2 * heads * out_channels)) + if isinstance(in_channels, int): + self.lin = nn.Linear(in_channels, heads * out_channels, bias=False) + else: + self.lin_src = nn.Linear(in_channels[0], heads * out_channels, bias=False) + self.lin_dst = nn.Linear(in_channels[1], heads * out_channels, bias=False) + + if edge_dim is not None: + self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False) + self.att = nn.Parameter(torch.Tensor(3 * heads * out_channels)) + else: + self.register_parameter("lin_edge", None) + self.att = nn.Parameter(torch.Tensor(2 * heads * out_channels)) if bias and concat: self.bias = nn.Parameter(torch.Tensor(heads * out_channels)) @@ -63,34 +118,85 @@ def __init__( self.reset_parameters() def reset_parameters(self): - self.lin.reset_parameters() + if isinstance(self.in_channels, int): + self.lin.reset_parameters() + else: + self.lin_src.reset_parameters() + self.lin_dst.reset_parameters() + gain = torch.nn.init.calculate_gain("relu") torch.nn.init.xavier_normal_( - self.att.view(2, self.heads, self.out_channels), gain=gain + self.att.view(-1, self.heads, self.out_channels), gain=gain ) + + if self.lin_edge is not None: + self.lin_edge.reset_parameters() if self.bias is not None: self.bias.data.fill_(0.0) def forward( self, - x: torch.Tensor, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], csc: Tuple[torch.Tensor, torch.Tensor, int], + edge_attr: Optional[torch.Tensor] = None, max_num_neighbors: Optional[int] = None, ) -> torch.Tensor: - graph = self.get_cugraph(csc, max_num_neighbors) - - x = self.lin(x) - - out = GATConvAgg( - x, - self.att, - graph, - self.heads, - "LeakyReLU", - self.negative_slope, - self.concat, + r"""Runs the forward pass of the module. + + Args: + x (torch.Tensor or tuple): The node features. Can be a tuple of + tensors denoting source and destination node features. + csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC + representation of a graph, given as a tuple of + :obj:`(row, colptr, num_src_nodes)`. Use the + :meth:`to_csc` method to convert an :obj:`edge_index` + representation to the desired format. + edge_attr: (torch.Tensor, optional) The edge features. + max_num_neighbors (int, optional): The maximum number of neighbors + of a target node. It is only effective when operating in a + bipartite graph. When not given, will be computed on-the-fly, + leading to slightly worse performance. 
(default: :obj:`None`) + """ + bipartite = not isinstance(x, torch.Tensor) + graph = self.get_cugraph( + csc, bipartite=bipartite, max_num_neighbors=max_num_neighbors ) + if edge_attr is not None and self.lin_edge is not None: + if edge_attr.dim() == 1: + edge_attr = edge_attr.view(-1, 1) + edge_attr = self.lin_edge(edge_attr) + + if bipartite: + x_src = self.lin_src(x[0]) + x_dst = self.lin_dst(x[1]) + + out = mha_gat_n2n_bipartite( + x_src, + x_dst, + self.att, + graph, + num_heads=self.heads, + activation="LeakyReLU", + negative_slope=self.negative_slope, + concat_heads=self.concat, + edge_feat=edge_attr, + ) + + else: + x = self.lin(x) + + out = mha_gat_n2n( + x, + self.att, + graph, + num_heads=self.heads, + activation="LeakyReLU", + negative_slope=self.negative_slope, + concat_heads=self.concat, + edge_feat=edge_attr, + ) + if self.bias is not None: out = out + self.bias diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py index 273e4993053..23e4cefd7da 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py @@ -25,7 +25,7 @@ class TransformerConv(BaseConv): r"""The graph transformer operator from the `"Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification" - `_ paper + `_ paper. .. math:: \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + From 389ef2d4ce1f37cd82bf9bb9e476b3262364ae20 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Thu, 20 Apr 2023 21:54:40 -0400 Subject: [PATCH 07/17] fix edge feat computation --- python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py index 23e4cefd7da..90552c5b1f7 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py @@ -177,6 +177,9 @@ def forward( key = self.lin_key(x[0]) value = self.lin_value(x[0]) + if self.lin_edge is not None and edge_attr is not None: + edge_attr = self.lin_edge(edge_attr) + out = TransformerConvAgg( key, query, From 183416acd29b723ee79eb96679b8a2b335121ec7 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Thu, 20 Apr 2023 22:02:21 -0400 Subject: [PATCH 08/17] update test script --- ci/test_python.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ci/test_python.sh b/ci/test_python.sh index 2a6be338819..bc69c9364d2 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -179,6 +179,7 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then --channel "${PYTHON_CHANNEL}" \ libcugraph \ pylibcugraph \ + pylibcugraphops \ cugraph \ cugraph-pyg @@ -198,13 +199,13 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then --cov-report=term \ . 
From dee2fe5ffe699da6f33710a8910f126e5798a172 Mon Sep 17 00:00:00 2001
From: Tingyu Wang
Date: Fri, 21 Apr 2023 17:24:48 -0400
Subject: [PATCH 09/17] add option to use edge feat in GATConv test

---
 .../cugraph_pyg/tests/nn/test_gat_conv.py | 86 ++++++++++++++-----
 1 file changed, 65 insertions(+), 21 deletions(-)

diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py
index 6be5a6e48ef..3b74cb155f3 100644
--- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py
@@ -27,51 +27,95 @@
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("bipartite", [True, False])
 @pytest.mark.parametrize("concat", [True, False])
-@pytest.mark.parametrize("heads", [1, 2, 3])
+@pytest.mark.parametrize("heads", [1, 2, 3, 5, 10, 16])
 @pytest.mark.parametrize("max_num_neighbors", [8, None])
-def test_gat_conv_equality(bias, bipartite, concat, heads, max_num_neighbors):
-    in_channels, out_channels = (5, 2)
-    kwargs = dict(bias=bias, concat=concat)
-
-    size = (10, 8) if bipartite else (10, 10)
-    x = torch.rand(size[0], in_channels, device="cuda")
+@pytest.mark.parametrize("use_edge_attr", [True, False])
+def test_gat_conv_equality(
+    bias, bipartite, concat, heads, max_num_neighbors, use_edge_attr
+):
     edge_index = torch.tensor(
         [
             [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9],
             [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7],
         ],
-        device="cuda",
-    )
+    ).cuda()
+    size = (10, 10)
+
+    if bipartite:
+        in_channels = (5, 3)
+        x = (
+            torch.rand(size[0], in_channels[0]).cuda(),
+            torch.rand(size[1], in_channels[1]).cuda(),
+        )
+    else:
+        in_channels = 5
+        x = torch.rand(size[0], in_channels).cuda()
+    out_channels = 2
+
+    if use_edge_attr:
+        edge_dim = 3
+        edge_attr = torch.rand(edge_index.size(1), edge_dim).cuda()
+        csc, edge_attr = CuGraphGATConv.to_csc(edge_index, size, edge_attr=edge_attr)
+    else:
+        edge_dim = None
+        edge_attr = None
+        csc = CuGraphGATConv.to_csc(edge_index, size)
+
+    kwargs = dict(bias=bias, concat=concat, edge_dim=edge_dim)

     conv1 = GATConv(
         in_channels, out_channels, heads, add_self_loops=False, **kwargs
     ).cuda()
     conv2 = CuGraphGATConv(in_channels, out_channels, heads, **kwargs).cuda()

+    out_dim = heads * out_channels
     with torch.no_grad():
-        conv2.lin.weight.data[:, :] = conv1.lin_src.weight.data
-        conv2.att.data[: heads * out_channels] = conv1.att_src.data.flatten()
-        conv2.att.data[heads * out_channels :] = conv1.att_dst.data.flatten()
+        if bipartite:
+            conv2.lin_src.weight.data = conv1.lin_src.weight.data.detach().clone()
+            conv2.lin_dst.weight.data = conv1.lin_dst.weight.data.detach().clone()
+        else:
+            conv2.lin.weight.data = conv1.lin_src.weight.data.detach().clone()

-    if bipartite:
-        out1 = conv1((x, x[: size[1]]), edge_index)
-    else:
-        out1 = conv1(x, edge_index)
+        conv2.att.data[:out_dim] = conv1.att_src.data.flatten()
+        conv2.att.data[out_dim : 2 * out_dim] = conv1.att_dst.data.flatten()
+        if use_edge_attr:
+            conv2.att.data[2 * out_dim :] = conv1.att_edge.data.flatten()
+            conv2.lin_edge.weight.data = conv1.lin_edge.weight.data.detach().clone()

-    csc = CuGraphGATConv.to_csc(edge_index, size)
-    out2 = conv2(x, csc, max_num_neighbors=max_num_neighbors)
+    out1 = conv1(x, edge_index, edge_attr=edge_attr)
+    out2 = conv2(x, csc, edge_attr=edge_attr, max_num_neighbors=max_num_neighbors)
     assert torch.allclose(out1, out2, atol=1e-3)

     grad_output = torch.rand_like(out1)
     out1.backward(grad_output)
     out2.backward(grad_output)

-    assert torch.allclose(conv1.lin_src.weight.grad, conv2.lin.weight.grad, atol=1e-3)
+    if bipartite:
+        assert torch.allclose(
+            conv1.lin_src.weight.grad, conv2.lin_src.weight.grad, atol=1e-3
+        )
+        assert torch.allclose(
+            conv1.lin_dst.weight.grad, conv2.lin_dst.weight.grad, atol=1e-3
+        )
+    else:
+        assert torch.allclose(
+            conv1.lin_src.weight.grad, conv2.lin.weight.grad, atol=1e-3
+        )
+
     assert torch.allclose(
-        conv1.att_src.grad.flatten(), conv2.att.grad[: heads * out_channels], atol=1e-3
+        conv1.att_src.grad.flatten(), conv2.att.grad[:out_dim], atol=1e-3
     )
     assert torch.allclose(
-        conv1.att_dst.grad.flatten(), conv2.att.grad[heads * out_channels :], atol=1e-3
+        conv1.att_dst.grad.flatten(), conv2.att.grad[out_dim : 2 * out_dim], atol=1e-3
     )
+
+    if use_edge_attr:
+        assert torch.allclose(
+            conv1.att_edge.grad.flatten(), conv2.att.grad[2 * out_dim :], atol=1e-3
+        )
+        assert torch.allclose(
+            conv1.lin_edge.weight.grad, conv2.lin_edge.weight.grad, atol=1e-3
+        )
+
     if bias:
         assert torch.allclose(conv1.bias.grad, conv2.bias.grad, atol=1e-3)
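Editor's note: there is a subtlety in the test above that the next commit addresses. `to_csc` sorts edges by destination and returns the edge features in that permuted order, so only the cugraph layer should consume the permuted copy, while the PyG reference layer must keep the original edge ordering. The corrected pattern, using the test's own names, looks like this:

    # Keep both orderings around: `edge_attr` for PyG, `edge_attr_perm` for cugraph.
    csc, edge_attr_perm = CuGraphGATConv.to_csc(edge_index, size, edge_attr=edge_attr)

    out1 = conv1(x, edge_index, edge_attr=edge_attr)  # original order
    out2 = conv2(x, csc, edge_attr=edge_attr_perm)    # CSC (destination-sorted) order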
From 72d145bc544709879bcb841c4c34f27fd0e27e8f Mon Sep 17 00:00:00 2001
From: Tingyu Wang
Date: Wed, 26 Apr 2023 16:23:20 -0400
Subject: [PATCH 10/17] fix gatconv test

---
 .../cugraph_pyg/tests/nn/test_gat_conv.py | 27 ++++++++++---------
 .../tests/nn/test_transformer_conv.py | 4 +--
 2 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py
index 3b74cb155f3..ae5fd73c438 100644
--- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py
@@ -33,6 +33,7 @@ def test_gat_conv_equality(
     bias, bipartite, concat, heads, max_num_neighbors, use_edge_attr
 ):
+    atol = 1e-6
     edge_index = torch.tensor(
         [
             [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9],
             [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7],
         ],
     ).cuda()
     size = (10, 10)
@@ -55,10 +56,12 @@ def test_gat_conv_equality(
     if use_edge_attr:
         edge_dim = 3
         edge_attr = torch.rand(edge_index.size(1), edge_dim).cuda()
-        csc, edge_attr = CuGraphGATConv.to_csc(edge_index, size, edge_attr=edge_attr)
+        csc, edge_attr_perm = CuGraphGATConv.to_csc(
+            edge_index, size, edge_attr=edge_attr
+        )
     else:
         edge_dim = None
-        edge_attr = None
+        edge_attr = edge_attr_perm = None
         csc = CuGraphGATConv.to_csc(edge_index, size)

     kwargs = dict(bias=bias, concat=concat, edge_dim=edge_dim)
@@ -83,8 +86,8 @@ def test_gat_conv_equality(
         conv2.lin_edge.weight.data = conv1.lin_edge.weight.data.detach().clone()

     out1 = conv1(x, edge_index, edge_attr=edge_attr)
-    out2 = conv2(x, csc, edge_attr=edge_attr, max_num_neighbors=max_num_neighbors)
-    assert torch.allclose(out1, out2, atol=1e-3)
+    out2 = conv2(x, csc, edge_attr=edge_attr_perm, max_num_neighbors=max_num_neighbors)
+    assert torch.allclose(out1, out2, atol=atol)

     grad_output = torch.rand_like(out1)
     out1.backward(grad_output)
     out2.backward(grad_output)

     if bipartite:
         assert torch.allclose(
-            conv1.lin_src.weight.grad, conv2.lin_src.weight.grad, atol=1e-3
+            conv1.lin_src.weight.grad, conv2.lin_src.weight.grad, atol=atol
         )
         assert torch.allclose(
-            conv1.lin_dst.weight.grad, conv2.lin_dst.weight.grad, atol=1e-3
+            conv1.lin_dst.weight.grad, conv2.lin_dst.weight.grad, atol=atol
         )
     else:
         assert torch.allclose(
-            conv1.lin_src.weight.grad, conv2.lin.weight.grad, atol=1e-3
+            conv1.lin_src.weight.grad, conv2.lin.weight.grad, atol=atol
         )

     assert torch.allclose(
-        conv1.att_src.grad.flatten(), conv2.att.grad[:out_dim], atol=1e-3
+        conv1.att_src.grad.flatten(), conv2.att.grad[:out_dim], atol=atol
     )
     assert torch.allclose(
-        conv1.att_dst.grad.flatten(), conv2.att.grad[out_dim : 2 * out_dim], atol=1e-3
+        conv1.att_dst.grad.flatten(), conv2.att.grad[out_dim : 2 * out_dim], atol=atol
     )

     if use_edge_attr:
         assert torch.allclose(
-            conv1.att_edge.grad.flatten(), conv2.att.grad[2 * out_dim :], atol=1e-3
+            conv1.att_edge.grad.flatten(), conv2.att.grad[2 * out_dim :], atol=atol
         )
         assert torch.allclose(
-            conv1.lin_edge.weight.grad, conv2.lin_edge.weight.grad, atol=1e-3
+            conv1.lin_edge.weight.grad, conv2.lin_edge.weight.grad, atol=atol
         )

     if bias:
-        assert torch.allclose(conv1.bias.grad, conv2.bias.grad, atol=1e-3)
+        assert torch.allclose(conv1.bias.grad, conv2.bias.grad, atol=atol)

diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py
index c36df1ea410..a2153ee7891 100644
--- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py
@@ -26,7 +26,7 @@
 @pytest.mark.parametrize("bipartite", [True, False])
 @pytest.mark.parametrize("concat", [True, False])
-@pytest.mark.parametrize("heads", [1, 2, 3])
+@pytest.mark.parametrize("heads", [1, 2, 3, 5, 10, 16])
 def test_transformer_conv_equality(bipartite, concat, heads):
     out_channels = 2
     size = (10, 10)
@@ -65,7 +65,7 @@ def test_transformer_conv_equality(bipartite, concat, heads):
     csc = CuGraphTransformerConv.to_csc(edge_index, size)
     out2 = conv2(x, csc)

-    atol = 1e-5
+    atol = 1e-6
     assert torch.allclose(out1, out2, atol=atol)
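Editor's note: the slicing in these tests encodes the layout of the fused attention parameter. `att` stores the source, destination, and (when `edge_dim` is set) edge attention vectors back to back, each of length `heads * out_channels`. The mapping the tests rely on, spelled out:

    out_dim = heads * out_channels
    att_src = conv2.att[:out_dim]                # == conv1.att_src.flatten()
    att_dst = conv2.att[out_dim : 2 * out_dim]   # == conv1.att_dst.flatten()
    att_edge = conv2.att[2 * out_dim :]          # == conv1.att_edge.flatten(),
                                                 #    present only with edge_dim set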
From 6e8e92b12f6d0ec27b89b7df300e0b5dba7030fb Mon Sep 17 00:00:00 2001
From: Tingyu Wang
Date: Wed, 26 Apr 2023 18:28:02 -0400
Subject: [PATCH 11/17] add GATv2Conv

---
 .../cugraph_pyg/nn/conv/__init__.py | 2 +
 .../cugraph_pyg/nn/conv/gat_conv.py | 5 +-
 .../cugraph_pyg/nn/conv/gatv2_conv.py | 218 ++++++++++++++++++
 .../cugraph_pyg/tests/nn/test_gatv2_conv.py | 95 ++++++++
 4 files changed, 318 insertions(+), 2 deletions(-)
 create mode 100644 python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py
 create mode 100644 python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py

diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py
index d2317f7ae53..0c94be5e12b 100644
--- a/python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py
@@ -12,9 +12,11 @@
 # limitations under the License.

 from .gat_conv import GATConv
+from .gatv2_conv import GATv2Conv
 from .transformer_conv import TransformerConv

 __all__ = [
     "GATConv",
+    "GATv2Conv",
     "TransformerConv",
 ]

diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
index a0807b81183..03136445181 100644
--- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
@@ -132,7 +132,7 @@ def reset_parameters(self):
         if self.lin_edge is not None:
             self.lin_edge.reset_parameters()
         if self.bias is not None:
-            self.bias.data.fill_(0.0)
+            nn.init.zeros_(self.bias.data)

     def forward(
         self,
@@ -162,7 +162,8 @@ def forward(
             csc, bipartite=bipartite, max_num_neighbors=max_num_neighbors
         )

-        if edge_attr is not None and self.lin_edge is not None:
+        if edge_attr is not None:
+            assert self.lin_edge is not None
             if edge_attr.dim() == 1:
                 edge_attr = edge_attr.view(-1, 1)
             edge_attr = self.lin_edge(edge_attr)

diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py
new file mode 100644
index 00000000000..471c927c7e1
--- /dev/null
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py
@@ -0,0 +1,218 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Tuple, Union
+
+from pylibcugraphops.pytorch.operators import mha_gat_v2_n2n, mha_gat_v2_n2n_bipartite
+
+from cugraph.utilities.utils import import_optional
+
+from .base import BaseConv
+
+torch = import_optional("torch")
+nn = import_optional("torch.nn")
+
+
+class GATv2Conv(BaseConv):
+    r"""The GATv2 operator from the `"How Attentive are Graph Attention
+    Networks?" <https://arxiv.org/abs/2105.14491>`_ paper, which fixes the
+    static attention problem of the standard
+    :class:`~torch_geometric.conv.GATConv` layer.
+    Since the linear layers in the standard GAT are applied right after each
+    other, the ranking of attended nodes is unconditioned on the query node.
+    In contrast, in :class:`GATv2`, every node can attend to any other node.
+
+    .. math::
+        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
+        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},
+
+    where the attention coefficients :math:`\alpha_{i,j}` are computed as
+
+    .. math::
+        \alpha_{i,j} =
+        \frac{
+        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta}
+        [\mathbf{x}_i \, \Vert \, \mathbf{x}_j]
+        \right)\right)}
+        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
+        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta}
+        [\mathbf{x}_i \, \Vert \, \mathbf{x}_k]
+        \right)\right)}.
+
+    If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`,
+    the attention coefficients :math:`\alpha_{i,j}` are computed as
+
+    .. math::
+        \alpha_{i,j} =
+        \frac{
+        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta}
+        [\mathbf{x}_i \, \Vert \, \mathbf{x}_j \, \Vert \, \mathbf{e}_{i,j}]
+        \right)\right)}
+        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
+        \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left(\mathbf{\Theta}
+        [\mathbf{x}_i \, \Vert \, \mathbf{x}_k \, \Vert \, \mathbf{e}_{i,k}]
+        \right)\right)}.
+
+    Args:
+        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
+            derive the size from the first input(s) to the forward method.
+            A tuple corresponds to the sizes of source and target
+            dimensionalities.
+        out_channels (int): Size of each output sample.
+        heads (int, optional): Number of multi-head-attentions.
+            (default: :obj:`1`)
+        concat (bool, optional): If set to :obj:`False`, the multi-head
+            attentions are averaged instead of concatenated.
+            (default: :obj:`True`)
+        negative_slope (float, optional): LeakyReLU angle of the negative
+            slope. (default: :obj:`0.2`)
+        edge_dim (int, optional): Edge feature dimensionality (in case
+            there are any). (default: :obj:`None`)
+        bias (bool, optional): If set to :obj:`False`, the layer will not learn
+            an additive bias. (default: :obj:`True`)
+        share_weights (bool, optional): If set to :obj:`True`, the same matrix
+            will be applied to the source and the target node of every edge.
+            (default: :obj:`False`)
+    """
+
+    def __init__(
+        self,
+        in_channels: Union[int, Tuple[int, int]],
+        out_channels: int,
+        heads: int = 1,
+        concat: bool = True,
+        negative_slope: float = 0.2,
+        edge_dim: Optional[int] = None,
+        bias: bool = True,
+        share_weights: bool = False,
+    ):
+        super().__init__()
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.heads = heads
+        self.concat = concat
+        self.negative_slope = negative_slope
+        self.edge_dim = edge_dim
+        self.share_weights = share_weights
+
+        if isinstance(in_channels, int):
+            self.lin_src = nn.Linear(in_channels, heads * out_channels, bias=False)
+
+            if share_weights:
+                self.lin_dst = self.lin_src
+            else:
+                self.lin_dst = nn.Linear(in_channels, heads * out_channels, bias=False)
+        else:
+            self.lin_src = nn.Linear(in_channels[0], heads * out_channels, bias=False)
+            self.lin_dst = nn.Linear(in_channels[1], heads * out_channels, bias=False)
+
+        self.att = nn.Parameter(torch.Tensor(heads * out_channels))
+
+        if edge_dim is not None:
+            self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False)
+        else:
+            self.register_parameter("lin_edge", None)
+
+        if bias and concat:
+            self.bias = nn.Parameter(torch.Tensor(heads * out_channels))
+        elif bias and not concat:
+            self.bias = nn.Parameter(torch.Tensor(out_channels))
+        else:
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.lin_src.reset_parameters()
+        self.lin_dst.reset_parameters()
+
+        gain = torch.nn.init.calculate_gain("relu")
+        torch.nn.init.xavier_normal_(
+            self.att.view(-1, self.heads, self.out_channels), gain=gain
+        )
+
+        if self.lin_edge is not None:
+            self.lin_edge.reset_parameters()
+        if self.bias is not None:
+            nn.init.zeros_(self.bias.data)
+
+    def forward(
+        self,
+        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        csc: Tuple[torch.Tensor, torch.Tensor, int],
+        edge_attr: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""Runs the forward pass of the module.
+
+        Args:
+            x (torch.Tensor or tuple): The node features. Can be a tuple of
+                tensors denoting source and destination node features.
+            csc ((torch.Tensor, torch.Tensor, int)): A tuple containing the CSC
+                representation of a graph, given as a tuple of
+                :obj:`(row, colptr, num_src_nodes)`. Use the
+                :meth:`to_csc` method to convert an :obj:`edge_index`
+                representation to the desired format.
+            edge_attr (torch.Tensor, optional): The edge features.
+        """
+        bipartite = not isinstance(x, torch.Tensor)
+        graph = self.get_cugraph(csc, bipartite=bipartite or not self.share_weights)
+
+        if edge_attr is not None:
+            assert self.lin_edge is not None
+            if edge_attr.dim() == 1:
+                edge_attr = edge_attr.view(-1, 1)
+            edge_attr = self.lin_edge(edge_attr)
+
+        if not bipartite and self.share_weights:
+            x = self.lin_src(x)
+
+            out = mha_gat_v2_n2n(
+                x,
+                self.att,
+                graph,
+                num_heads=self.heads,
+                activation="LeakyReLU",
+                negative_slope=self.negative_slope,
+                concat_heads=self.concat,
+                edge_feat=edge_attr,
+            )
+        else:
+            if bipartite:
+                x_src = self.lin_src(x[0])
+                x_dst = self.lin_dst(x[1])
+            else:
+                x_src = self.lin_src(x)
+                x_dst = self.lin_dst(x)
+
+            out = mha_gat_v2_n2n_bipartite(
+                x_src,
+                x_dst,
+                self.att,
+                graph,
+                num_heads=self.heads,
+                activation="LeakyReLU",
+                negative_slope=self.negative_slope,
+                concat_heads=self.concat,
+                edge_feat=edge_attr,
+            )
+
+        if self.bias is not None:
+            out = out + self.bias
+
+        return out
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}({self.in_channels}, "
+            f"{self.out_channels}, heads={self.heads})"
+        )
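Editor's note: observe how `share_weights` interacts with the graph mode in the forward pass above — the non-bipartite kernel is used only when the input is a single tensor and the source/destination projections are shared; otherwise both projections are materialized and the bipartite kernel runs. A hedged usage sketch (toy sizes and names, for illustration only):

    import torch
    from cugraph_pyg.nn import GATv2Conv

    x = torch.rand(10, 5).cuda()
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]).cuda()
    csc = GATv2Conv.to_csc(edge_index, (10, 10))

    # Shared weights + single input tensor -> mha_gat_v2_n2n (one projection).
    conv = GATv2Conv(5, 2, heads=2, share_weights=True).cuda()
    out = conv(x, csc)

    # Distinct weights -> mha_gat_v2_n2n_bipartite, even on a non-bipartite graph.
    conv2 = GATv2Conv(5, 2, heads=2, share_weights=False).cuda()
    out2 = conv2(x, csc)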
diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py
new file mode 100644
index 00000000000..1c4f241304e
--- /dev/null
+++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+try:
+    from torch_geometric.nn import GATv2Conv
+except ModuleNotFoundError:
+    pytest.skip("PyG not available", allow_module_level=True)
+
+from cugraph.utilities.utils import import_optional
+from cugraph_pyg.nn import GATv2Conv as CuGraphGATv2Conv
+
+torch = import_optional("torch")
+
+
+@pytest.mark.parametrize("bipartite", [True, False])
+@pytest.mark.parametrize("concat", [True, False])
+@pytest.mark.parametrize("heads", [1, 2, 3, 5, 10, 16])
+@pytest.mark.parametrize("use_edge_attr", [True, False])
+def test_gatv2_conv_equality(bipartite, concat, heads, use_edge_attr):
+    atol = 1e-6
+    edge_index = torch.tensor(
+        [
+            [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9],
+            [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7],
+        ],
+    ).cuda()
+    size = (10, 10)
+
+    if bipartite:
+        in_channels = (5, 3)
+        x = (
+            torch.rand(size[0], in_channels[0]).cuda(),
+            torch.rand(size[1], in_channels[1]).cuda(),
+        )
+    else:
+        in_channels = 5
+        x = torch.rand(size[0], in_channels).cuda()
+    out_channels = 2
+
+    if use_edge_attr:
+        edge_dim = 3
+        edge_attr = torch.rand(edge_index.size(1), edge_dim).cuda()
+        csc, edge_attr_perm = CuGraphGATv2Conv.to_csc(
+            edge_index, size, edge_attr=edge_attr
+        )
+    else:
+        edge_dim = None
+        edge_attr = edge_attr_perm = None
+        csc = CuGraphGATv2Conv.to_csc(edge_index, size)
+
+    kwargs = dict(bias=False, concat=concat, edge_dim=edge_dim)
+
+    conv1 = GATv2Conv(
+        in_channels, out_channels, heads, add_self_loops=False, **kwargs
+    ).cuda()
+    conv2 = CuGraphGATv2Conv(in_channels, out_channels, heads, **kwargs).cuda()
+
+    with torch.no_grad():
+        conv2.lin_src.weight.data = conv1.lin_l.weight.data.detach().clone()
+        conv2.lin_dst.weight.data = conv1.lin_r.weight.data.detach().clone()
+
+        conv2.att.data = conv1.att.data.flatten().detach().clone()
+
+        if use_edge_attr:
+            conv2.lin_edge.weight.data = conv1.lin_edge.weight.data.detach().clone()
+
+    out1 = conv1(x, edge_index, edge_attr=edge_attr)
+    out2 = conv2(x, csc, edge_attr=edge_attr_perm)
+    assert torch.allclose(out1, out2, atol=atol)
+
+    grad_output = torch.rand_like(out1)
+    out1.backward(grad_output)
+    out2.backward(grad_output)
+
+    assert torch.allclose(conv1.lin_l.weight.grad, conv2.lin_src.weight.grad, atol=atol)
+    assert torch.allclose(conv1.lin_r.weight.grad, conv2.lin_dst.weight.grad, atol=atol)
+
+    assert torch.allclose(conv1.att.grad.flatten(), conv2.att.grad, atol=atol)
+
+    if use_edge_attr:
+        assert torch.allclose(
+            conv1.lin_edge.weight.grad, conv2.lin_edge.weight.grad, atol=atol
+        )

From aba9237ba80d67c692914ad1e6e7826e389946bc Mon Sep 17 00:00:00 2001
From: Tingyu Wang
Date: Wed, 26 Apr 2023 18:28:47 -0400
Subject: [PATCH 12/17] Update ci/test_python.sh

Co-authored-by: AJ Schmidt
---
 ci/test_python.sh | 1 -
 1 file changed, 1 deletion(-)

diff --git a/ci/test_python.sh b/ci/test_python.sh
index 569edb4ab21..f1ddf08df97 100755
--- a/ci/test_python.sh
+++ b/ci/test_python.sh
@@ -179,7 +179,6 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then
     --channel "${PYTHON_CHANNEL}" \
     libcugraph \
     pylibcugraph \
-    pylibcugraphops \
     cugraph \
     cugraph-pyg
From 00fc751b5afa84aa47de6c68a026cec16371e6a6 Mon Sep 17 00:00:00 2001
From: Tingyu Wang
Date: Fri, 28 Apr 2023 14:15:06 -0400
Subject: [PATCH 13/17] use pyg linear operator and init functions

---
 .../cugraph_pyg/nn/conv/gat_conv.py | 40 ++++++++++++----
 .../cugraph_pyg/nn/conv/gatv2_conv.py | 47 ++++++++++++++-----
 .../cugraph_pyg/nn/conv/transformer_conv.py | 24 +++++-----
 3 files changed, 79 insertions(+), 32 deletions(-)

diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
index 03136445181..a0ec2909071 100644
--- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
@@ -20,6 +20,7 @@
 torch = import_optional("torch")
 nn = import_optional("torch.nn")
+torch_geometric = import_optional("torch_geometric")

 class GATConv(BaseConv):
@@ -95,14 +96,36 @@ def __init__(
         self.negative_slope = negative_slope
         self.edge_dim = edge_dim

+        Linear = torch_geometric.nn.Linear
+
         if isinstance(in_channels, int):
-            self.lin = nn.Linear(in_channels, heads * out_channels, bias=False)
+            self.lin = Linear(
+                in_channels,
+                heads * out_channels,
+                bias=False,
+                weight_initializer="glorot",
+            )
         else:
-            self.lin_src = nn.Linear(in_channels[0], heads * out_channels, bias=False)
-            self.lin_dst = nn.Linear(in_channels[1], heads * out_channels, bias=False)
+            self.lin_src = Linear(
+                in_channels[0],
+                heads * out_channels,
+                bias=False,
+                weight_initializer="glorot",
+            )
+            self.lin_dst = Linear(
+                in_channels[1],
+                heads * out_channels,
+                bias=False,
+                weight_initializer="glorot",
+            )

         if edge_dim is not None:
-            self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False)
+            self.lin_edge = Linear(
+                edge_dim,
+                heads * out_channels,
+                bias=False,
+                weight_initializer="glorot",
+            )
             self.att = nn.Parameter(torch.Tensor(3 * heads * out_channels))
         else:
             self.register_parameter("lin_edge", None)
@@ -124,15 +147,14 @@ def reset_parameters(self):
         self.lin_src.reset_parameters()
         self.lin_dst.reset_parameters()

-        gain = torch.nn.init.calculate_gain("relu")
-        torch.nn.init.xavier_normal_(
-            self.att.view(-1, self.heads, self.out_channels), gain=gain
+        torch_geometric.nn.inits.glorot(
+            self.att.view(-1, self.heads, self.out_channels)
         )

         if self.lin_edge is not None:
             self.lin_edge.reset_parameters()
-        if self.bias is not None:
-            nn.init.zeros_(self.bias.data)
+
+        torch_geometric.nn.inits.zeros(self.bias)

     def forward(
         self,

diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py
index 471c927c7e1..abca817fdae 100644
--- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py
@@ -20,6 +20,7 @@
 torch = import_optional("torch")
 nn = import_optional("torch.nn")
+torch_geometric = import_optional("torch_geometric")

 class GATv2Conv(BaseConv):
@@ -105,21 +106,45 @@ def __init__(
         self.edge_dim = edge_dim
         self.share_weights = share_weights

+        Linear = torch_geometric.nn.Linear
+
         if isinstance(in_channels, int):
-            self.lin_src = nn.Linear(in_channels, heads * out_channels, bias=False)
+            self.lin_src = Linear(
+                in_channels,
+                heads * out_channels,
+                bias=bias,
+                weight_initializer="glorot",
+            )

             if share_weights:
                 self.lin_dst = self.lin_src
             else:
-                self.lin_dst = nn.Linear(in_channels, heads * out_channels, bias=False)
+                self.lin_dst = Linear(
+                    in_channels,
+                    heads * out_channels,
+                    bias=bias,
+                    weight_initializer="glorot",
+                )
         else:
-            self.lin_src = nn.Linear(in_channels[0], heads * out_channels, bias=False)
-            self.lin_dst = nn.Linear(in_channels[1], heads * out_channels, bias=False)
+            self.lin_src = Linear(
+                in_channels[0],
+                heads * out_channels,
+                bias=bias,
+                weight_initializer="glorot",
+            )
+            self.lin_dst = Linear(
+                in_channels[1],
+                heads * out_channels,
+                bias=bias,
+                weight_initializer="glorot",
+            )

         self.att = nn.Parameter(torch.Tensor(heads * out_channels))

         if edge_dim is not None:
-            self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False)
+            self.lin_edge = Linear(
+                edge_dim, heads * out_channels, bias=False, weight_initializer="glorot"
+            )
         else:
             self.register_parameter("lin_edge", None)
@@ -135,16 +160,14 @@ def __init__(
     def reset_parameters(self):
         self.lin_src.reset_parameters()
         self.lin_dst.reset_parameters()
+        if self.lin_edge is not None:
+            self.lin_edge.reset_parameters()

-        gain = torch.nn.init.calculate_gain("relu")
-        torch.nn.init.xavier_normal_(
-            self.att.view(-1, self.heads, self.out_channels), gain=gain
+        torch_geometric.nn.inits.glorot(
+            self.att.view(-1, self.heads, self.out_channels)
         )

-        if self.lin_edge is not None:
-            self.lin_edge.reset_parameters()
-        if self.bias is not None:
-            nn.init.zeros_(self.bias.data)
+        torch_geometric.nn.inits.zeros(self.bias)

     def forward(
         self,

diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py
index 90552c5b1f7..d4d984467ae 100644
--- a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py
@@ -20,6 +20,7 @@
 torch = import_optional("torch")
 nn = import_optional("torch.nn")
+torch_geometric = import_optional("torch_geometric")

 class TransformerConv(BaseConv):
@@ -114,25 +115,26 @@ def __init__(
         if isinstance(in_channels, int):
             in_channels = (in_channels, in_channels)

-        self.lin_key = nn.Linear(in_channels[0], heads * out_channels)
-        self.lin_query = nn.Linear(in_channels[1], heads * out_channels)
-        self.lin_value = nn.Linear(in_channels[0], heads * out_channels)
+        Linear = torch_geometric.nn.Linear
+
+        self.lin_key = Linear(in_channels[0], heads * out_channels)
+        self.lin_query = Linear(in_channels[1], heads * out_channels)
+        self.lin_value = Linear(in_channels[0], heads * out_channels)

         if edge_dim is not None:
-            self.lin_edge = nn.Linear(edge_dim, heads * out_channels, bias=False)
+            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)
         else:
             self.lin_edge = self.register_parameter("lin_edge", None)

         if concat:
-            self.lin_skip = nn.Linear(in_channels[1], heads * out_channels, bias=bias)
+            self.lin_skip = Linear(in_channels[1], heads * out_channels, bias=bias)
             if self.beta:
-                self.lin_beta = nn.Linear(3 * heads * out_channels, 1, bias=bias)
+                self.lin_beta = Linear(3 * heads * out_channels, 1, bias=False)
             else:
                 self.lin_beta = self.register_parameter("lin_beta", None)
         else:
-            self.lin_skip = nn.Linear(in_channels[1], out_channels, bias=bias)
+            self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)
             if self.beta:
-                self.lin_beta = nn.Linear(3 * out_channels, 1, bias=False)
+                self.lin_beta = Linear(3 * out_channels, 1, bias=False)
             else:
                 self.lin_beta = self.register_parameter("lin_beta", None)
@@ -144,8 +146,7 @@ def reset_parameters(self):
         self.lin_value.reset_parameters()
         if self.lin_edge is not None:
             self.lin_edge.reset_parameters()
-        if self.lin_skip is not None:
-            self.lin_skip.reset_parameters()
+        self.lin_skip.reset_parameters()
         if self.lin_beta is not None:
             self.lin_beta.reset_parameters()
@@ -177,7 +178,8 @@ def forward(
         key = self.lin_key(x[0])
         value = self.lin_value(x[0])

-        if self.lin_edge is not None and edge_attr is not None:
+        if self.lin_edge is not None:
+            assert edge_attr is not None
             edge_attr = self.lin_edge(edge_attr)

         out = TransformerConvAgg(
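Editor's note: patch 13 delegates initialization to PyG's helpers rather than hand-rolled Xavier-normal calls. For reference, `weight_initializer="glorot"` and `torch_geometric.nn.inits.glorot` draw from a uniform distribution scaled by the tensor's fan-in and fan-out; a rough plain-PyTorch equivalent, shown only to document the intent (PyG's own helper remains the source of truth):

    import math
    import torch

    def glorot_(tensor: torch.Tensor) -> None:
        # Uniform(-a, a) with a = sqrt(6 / (fan_in + fan_out)),
        # i.e. Xavier/Glorot uniform initialization.
        a = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        with torch.no_grad():
            tensor.uniform_(-a, a)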
From 6a79b175560365b248de127aff2de1485d87cbb2 Mon Sep 17 00:00:00 2001
From: Tingyu Wang
Date: Fri, 28 Apr 2023 16:31:35 -0400
Subject: [PATCH 14/17] assert linear operators in gatconv

---
 python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
index a0ec2909071..322e1d00635 100644
--- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
@@ -191,6 +191,7 @@ def forward(
             edge_attr = self.lin_edge(edge_attr)

         if bipartite:
+            assert hasattr(self, "lin_src")
             x_src = self.lin_src(x[0])
             x_dst = self.lin_dst(x[1])

@@ -207,6 +208,7 @@ def forward(
             )

         else:
+            assert hasattr(self, "lin")
             x = self.lin(x)

From 6044257246def57f9cac6b34504bb5e62e63cc6c Mon Sep 17 00:00:00 2001
From: Tingyu Wang
Date: Tue, 9 May 2023 14:08:45 -0400
Subject: [PATCH 15/17] add back pylibcugraphops for cugraph_pyg testing

---
 ci/test_python.sh | 1 +
 1 file changed, 1 insertion(+)

diff --git a/ci/test_python.sh b/ci/test_python.sh
index ca216d662e9..2ee340f6f5c 100755
--- a/ci/test_python.sh
+++ b/ci/test_python.sh
@@ -179,6 +179,7 @@ if [[ "${RAPIDS_CUDA_VERSION}" == "11.8.0" ]]; then
     --channel "${PYTHON_CHANNEL}" \
     libcugraph \
     pylibcugraph \
+    pylibcugraphops \
     cugraph \
     cugraph-pyg

From f5a712ef61a10b0ba2b0e5ce28a31408f78121a0 Mon Sep 17 00:00:00 2001
From: Tingyu Wang
Date: Wed, 10 May 2023 17:33:32 -0400
Subject: [PATCH 16/17] remove assertions, raise RuntimeError instead

---
 .../cugraph_pyg/nn/conv/gat_conv.py | 19 ++++++++++++++++---
 .../cugraph_pyg/nn/conv/gatv2_conv.py | 6 +++++-
 .../cugraph_pyg/nn/conv/transformer_conv.py | 8 ++++++--
 3 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
index 322e1d00635..4bf37cf3e72 100644
--- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py
@@ -185,13 +185,22 @@ def forward(
         )

         if edge_attr is not None:
-            assert self.lin_edge is not None
+            if self.lin_edge is None:
+                raise RuntimeError(
+                    f"{self.__class__.__name__}.edge_dim must be set to accept "
+                    f"edge features."
+                )
             if edge_attr.dim() == 1:
                 edge_attr = edge_attr.view(-1, 1)
             edge_attr = self.lin_edge(edge_attr)

         if bipartite:
-            assert hasattr(self, "lin_src")
+            if not hasattr(self, "lin_src"):
+                raise RuntimeError(
+                    f"{self.__class__.__name__}.in_channels must be a pair of "
+                    f"integers to allow bipartite node features, but got "
+                    f"{self.in_channels}."
+                )
             x_src = self.lin_src(x[0])
             x_dst = self.lin_dst(x[1])

@@ -208,7 +217,11 @@ def forward(
         else:
-            assert hasattr(self, "lin")
+            if not hasattr(self, "lin"):
+                raise RuntimeError(
+                    f"{self.__class__.__name__}.in_channels is expected to be an "
+                    f"integer, but got {self.in_channels}."
+                )
             x = self.lin(x)

             out = mha_gat_n2n(

diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py
index abca817fdae..66d962b3f86 100644
--- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py
@@ -191,7 +191,11 @@ def forward(
         graph = self.get_cugraph(csc, bipartite=bipartite or not self.share_weights)

         if edge_attr is not None:
-            assert self.lin_edge is not None
+            if self.lin_edge is None:
+                raise RuntimeError(
+                    f"{self.__class__.__name__}.edge_dim must be set to accept "
+                    f"edge features."
+                )
             if edge_attr.dim() == 1:
                 edge_attr = edge_attr.view(-1, 1)
             edge_attr = self.lin_edge(edge_attr)

diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py
index d4d984467ae..aeb51c028ae 100644
--- a/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/transformer_conv.py
@@ -178,8 +178,12 @@ def forward(
         key = self.lin_key(x[0])
         value = self.lin_value(x[0])

-        if self.lin_edge is not None:
-            assert edge_attr is not None
+        if edge_attr is not None:
+            if self.lin_edge is None:
+                raise RuntimeError(
+                    f"{self.__class__.__name__}.edge_dim must be set to accept "
+                    f"edge features."
+                )
             edge_attr = self.lin_edge(edge_attr)

         out = TransformerConvAgg(

From 94579a884e02e1ff110a37e65fa2bb1c3c91526f Mon Sep 17 00:00:00 2001
From: Tingyu Wang
Date: Thu, 11 May 2023 14:10:40 -0400
Subject: [PATCH 17/17] empty commit
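Editor's note: taken together, the series leaves cugraph_pyg.nn with three drop-in layers whose misconfigurations (edge features without `edge_dim`, bipartite input with a scalar `in_channels`) now raise descriptive `RuntimeError`s instead of tripping bare asserts. An end-to-end sketch of the final API — toy graph and sizes, assuming a CUDA device and pylibcugraphops>=23.04:

    import torch
    from cugraph_pyg.nn import GATConv

    num_nodes, in_dim, out_dim, edge_dim = 10, 5, 2, 3
    x = torch.rand(num_nodes, in_dim).cuda()
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]).cuda()
    edge_attr = torch.rand(edge_index.size(1), edge_dim).cuda()

    # Convert once; reuse the CSC tuple across forward passes.
    csc, edge_attr_perm = GATConv.to_csc(
        edge_index, (num_nodes, num_nodes), edge_attr=edge_attr
    )

    conv = GATConv(in_dim, out_dim, heads=2, edge_dim=edge_dim).cuda()
    out = conv(x, csc, edge_attr=edge_attr_perm)  # (num_nodes, 2 * out_dim)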