
[MessagePassing] Investigate jittable conversion on instantiation of GNN layers #8745

Merged: 10 commits, Jan 12, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -45,6 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- GNN layers are now jittable by default ([#8745](https://github.com/pyg-team/pytorch_geometric/pull/8745))
- Sparse node features in `NELL` and `AttributedGraphDataset` are now represented as `torch.sparse_csr_tensor` instead of `torch_sparse.SparseTensor` ([#8679](https://github.com/pyg-team/pytorch_geometric/pull/8679))
- Accelerated mini-batching of `torch.sparse` tensors ([#8670](https://github.com/pyg-team/pytorch_geometric/pull/8670))
- Fixed RPC timeout due to worker closing in `DistLoader` with `atexit` not executed correctly in `worker_init_fn` ([#8605](https://github.com/pyg-team/pytorch_geometric/pull/8605))
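To illustrate the first changelog entry above, a minimal sketch of the intended user-facing difference, assuming the opt-in `jit_on_init` flag introduced in this PR eventually becomes the default behavior (channel sizes are arbitrary):

```python
import torch
from torch_geometric.nn import SAGEConv

# Current workflow: an explicit conversion step is required before scripting.
conv = SAGEConv(16, 32).jittable()
scripted = torch.jit.script(conv)

# Workflow targeted by this PR: the layer builds a scriptable `propagate`
# on instantiation, so no separate `.jittable()` call should be needed.
conv = SAGEConv(16, 32)
scripted = torch.jit.script(conv)
```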
17 changes: 17 additions & 0 deletions test/nn/conv/test_propagate.py
@@ -0,0 +1,17 @@
import torch

from torch_geometric.nn import SAGEConv


def test_propagate():
    # Instantiate the layer with on-the-fly TorchScript conversion enabled:
    SAGEConv.jit_on_init = True
    conv = SAGEConv(10, 16)
    SAGEConv.jit_on_init = False

    # Temporarily install the generated `propagate`/`_collect` functions on
    # the class so the module can be scripted, then restore the originals:
    old_propagate = conv.__class__.propagate
    old_collect = conv.__class__._collect
    conv.__class__.propagate = conv.propagate_jit
    conv.__class__._collect = conv._collect_jit
    torch.jit.script(conv)
    conv.__class__.propagate = old_propagate
    conv.__class__._collect = old_collect
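The test relies on temporarily swapping the class-level `propagate` and `_collect` for the generated `propagate_jit`/`_collect_jit` functions. A hedged sketch of how that dance could be wrapped into a reusable helper (hypothetical; not part of this PR):

```python
import torch

from torch_geometric.nn import SAGEConv


def script_with_generated_propagate(conv):
    # Temporarily install the generated `propagate`/`_collect` functions on
    # the class, script the module, and restore the originals afterwards.
    cls = conv.__class__
    old_propagate, old_collect = cls.propagate, cls._collect
    cls.propagate, cls._collect = conv.propagate_jit, conv._collect_jit
    try:
        return torch.jit.script(conv)
    finally:
        cls.propagate, cls._collect = old_propagate, old_collect


SAGEConv.jit_on_init = True  # Opt-in flag introduced by this PR.
conv = SAGEConv(10, 16)
SAGEConv.jit_on_init = False
scripted = script_with_generated_propagate(conv)
```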
4 changes: 1 addition & 3 deletions torch_geometric/home.py
@@ -17,9 +17,7 @@ def get_home_dir() -> str:
    if _home_dir is not None:
        return _home_dir

    home_dir = os.getenv(ENV_PYG_HOME, DEFAULT_CACHE_DIR)
    home_dir = osp.expanduser(home_dir)
    return home_dir
    return osp.expanduser(os.getenv(ENV_PYG_HOME, DEFAULT_CACHE_DIR))


def set_home_dir(path: str) -> None:
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/fa_conv.py
@@ -201,7 +201,8 @@ def forward( # noqa: F811
        alpha_l = self.att_l(x)
        alpha_r = self.att_r(x)

        # propagate_type: (x: Tensor, alpha: PairTensor, edge_weight: OptTensor) # noqa
        # propagate_type: (x: Tensor, alpha: PairTensor,
        #                  edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, alpha=(alpha_l, alpha_r),
                             edge_weight=edge_weight)

3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/heat_conv.py
@@ -107,7 +107,8 @@ def forward(self, x: Tensor, edge_index: Adj, node_type: Tensor,
        edge_type_emb = F.leaky_relu(self.edge_type_emb(edge_type),
                                     self.negative_slope)

        # propagate_type: (x: Tensor, edge_type_emb: Tensor, edge_attr: OptTensor) # noqa
        # propagate_type: (x: Tensor, edge_type_emb: Tensor,
        #                  edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_type_emb=edge_type_emb,
                             edge_attr=edge_attr)

93 changes: 63 additions & 30 deletions torch_geometric/nn/conv/message_passing.py
@@ -22,6 +22,11 @@
from torch.utils.hooks import RemovableHandle

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.nn.conv.propagate import (
    find_parenthesis_content,
    module_from_template,
    type_hint_to_str,
)
from torch_geometric.nn.conv.utils.inspector import Inspector
from torch_geometric.nn.conv.utils.jit import class_from_module_repr
from torch_geometric.nn.conv.utils.typing import sanitize, split_types_repr
@@ -173,6 +178,60 @@
        self._edge_update_forward_pre_hooks: HookDict = OrderedDict()
        self._edge_update_forward_hooks: HookDict = OrderedDict()

        # Test code for performing on-the-fly TorchScript support:
        if getattr(self, 'jit_on_init', False):
            root_dir = osp.dirname(osp.realpath(__file__))
            prop_types, prop_return_type = self._get_propagate_types()
            module = module_from_template(
                module_name=f'{self.__module__}_propagate',
                module=self.__module__,
                template_path=osp.join(root_dir, 'propagate.jinja'),
                propagate_types=prop_types,
                propagate_return_type=prop_return_type,
                collect_types=self.inspector.types(
                    ['message', 'aggregate', 'update']),
                message_args=self.inspector.keys(['message']),
                aggregate_args=self.inspector.keys(['aggregate']),
                message_and_aggregate_args=self.inspector.keys(
                    ['message_and_aggregate']),
                update_args=self.inspector.keys(['update']),
            )

            self.propagate_jit = module.propagate
            self._collect_jit = module._collect

    def _get_propagate_types(self) -> Tuple[Dict[str, str], str]:
        # Parse `propagate_types` and generate an efficient `propagate` method:
        if hasattr(self, 'propagate_type'):
            assert isinstance(self.propagate_type, dict)
            propagate_types = {
                name: type_hint_to_str(type_hint)
                for name, type_hint in self.propagate_type.items()
            }
        else:
            source = inspect.getsource(self.forward)
            match = find_parenthesis_content(source, prefix='propagate_type:')
            if match is not None:
                propagate_types = dict(
                    [re.split(r'\s*:\s*', t) for t in split_types_repr(match)])
            else:
                match = find_parenthesis_content(source, 'self.propagate')
                if match is None:  # No `self.propagate` call:
                    return

                args = [  # Find all keyword argument names:
                    arg[:arg.find('=')].strip() for arg in match.split(',')
                    if arg.find('=') >= 0
                ]
                propagate_types = {
                    arg: 'Tensor'
                    for arg in args if arg not in {'edge_index', 'size'}
                }

        propagate_return_type = type_hint_to_str(
            get_type_hints(self.update).get('return', Tensor))

        return propagate_types, propagate_return_type

    def reset_parameters(self) -> None:
        r"""Resets all learnable parameters of the module."""
        if self.aggr_module is not None:
@@ -845,26 +904,7 @@
        source = inspect.getsource(self.__class__)

        # Find and parse `propagate()` types to format `{arg1: type1, ...}`.
        if hasattr(self, 'propagate_type'):
            assert isinstance(self.propagate_type, dict)
            prop_types = {
                k: sanitize(str(v))
                for k, v in self.propagate_type.items()
            }
        else:
            match = re.search(r'#\s*propagate_type:\s*\((.*)\)', source)
            if match is None:
                raise TypeError(
                    'TorchScript support requires the definition of the types '
                    'passed to `propagate()`. Please specify them via\n\n'
                    'propagate_type = {"arg1": type1, "arg2": type2, ... }\n\n'
                    'or via\n\n'
                    '# propagate_type: (arg1: type1, arg2: type2, ...)\n\n'
                    'inside the `MessagePassing` module.')
            prop_types = dict([
                re.split(r'\s*:\s*', t)
                for t in split_types_repr(match.group(1))
            ])
        prop_types, prop_return_type = self._get_propagate_types()

        # Find and parse `edge_updater` types to format `{arg1: type1, ...}`.
        if 'edge_update' in self.__class__.__dict__.keys():
@@ -875,7 +915,7 @@
                    for k, v in self.edge_updater_type.items()
                }
            else:
                match = re.search(r'#\s*edge_updater_type:\s*\((.*)\)', source)
                match = find_parenthesis_content(source, 'edge_updater_type:')
                if match is None:
                    raise TypeError(
                        'TorchScript support requires the definition of the '
@@ -884,18 +924,11 @@
                        '"arg2": type2, ... }\n\n or via\n\n'
                        '# edge_updater_type: (arg1: type1, arg2: type2, ...)'
                        '\n\ninside the `MessagePassing` module.')
                edge_updater_types = dict([
                    re.split(r'\s*:\s*', t)
                    for t in split_types_repr(match.group(1))
                ])
                edge_updater_types = dict(
                    [re.split(r'\s*:\s*', t) for t in split_types_repr(match)])
        else:
            edge_updater_types = {}

        type_hints = get_type_hints(self.__class__.update)
        prop_return_type = type_hints.get('return', 'Tensor')
        if str(prop_return_type)[:6] == '<class':
            prop_return_type = prop_return_type.__name__

        type_hints = get_type_hints(self.__class__.edge_update)
        edge_updater_return_type = type_hints.get('return', 'Tensor')
        if str(edge_updater_return_type)[:6] == '<class':
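To make the parsing logic above concrete, a small sketch (not part of the diff) of the two supported ways a layer can declare its `propagate()` argument types; the `MyConv` class is hypothetical:

```python
from typing import Optional

from torch import Tensor

from torch_geometric.nn import MessagePassing
from torch_geometric.typing import OptTensor


class MyConv(MessagePassing):
    # Option 1: an explicit class attribute holding real type objects.
    # This branch takes precedence in `_get_propagate_types()`.
    propagate_type = {'x': Tensor, 'edge_weight': Optional[Tensor]}

    def forward(self, x: Tensor, edge_index: Tensor,
                edge_weight: OptTensor = None) -> Tensor:
        # Option 2: a (possibly multi-line) comment inside `forward`, now
        # extracted via `find_parenthesis_content`:
        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        return self.propagate(edge_index, x=x, edge_weight=edge_weight)

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j


# With Option 1, `_get_propagate_types()` would return strings derived from
# the type objects, roughly {'x': 'torch.Tensor',
# 'edge_weight': 'typing.Optional[torch.Tensor]'}; with Option 2, it keeps
# the literal annotations from the comment, i.e. {'x': 'Tensor',
# 'edge_weight': 'OptTensor'}.  If neither is present, it falls back to
# inspecting the keyword arguments of the `self.propagate(...)` call itself.
```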
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/ppf_conv.py
@@ -121,7 +121,8 @@ def forward(
        elif isinstance(edge_index, SparseTensor):
            edge_index = torch_sparse.set_diag(edge_index)

        # propagate_type: (x: PairOptTensor, pos: PairTensor, normal: PairTensor) # noqa
        # propagate_type: (x: PairOptTensor, pos: PairTensor,
        #                  normal: PairTensor)
        out = self.propagate(edge_index, x=x, pos=pos, normal=normal)

        if self.global_nn is not None:
105 changes: 105 additions & 0 deletions torch_geometric/nn/conv/propagate.jinja
@@ -0,0 +1,105 @@
import typing
from typing import *

import torch
from torch import Tensor

from torch_geometric.typing import *
from torch_geometric import is_compiling
from torch_geometric.utils import is_sparse

from {{module}} import *


class CollectArgs(NamedTuple):
{%- for name, type_hint in collect_types.items() %}
    {{name}}: {{type_hint}}
{%- endfor %}


def _collect(
    self,
    edge_index: Union[Tensor, SparseTensor],
{%- for name, type_hint in propagate_types.items() %}
    {{name}}: {{type_hint}},
{%- endfor %}
    size: List[Optional[int]],
) -> CollectArgs:

    i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)
{% for name in collect_types %}
{%- if name.endswith('_i') or name.endswith('_j') %}
    # Collect `{{name}}`:
    if isinstance({{name[:-2]}}, (tuple, list)):
        assert len({{name[:-2]}}) == 2
        _{{name[:-2]}}_0, _{{name[:-2]}}_1 = {{name[:-2]}}[0], {{name[:-2]}}[1]
        if isinstance(_{{name[:-2]}}_0, Tensor):
            self._set_size(size, 0, _{{name[:-2]}}_0)
{%- if name.endswith('_j') %}
            {{name}} = self._lift(_{{name[:-2]}}_0, edge_index, {{name[-1]}})
        else:
            {{name}} = None
{%- endif %}
        if isinstance(_{{name[:-2]}}_1, Tensor):
            self._set_size(size, 1, _{{name[:-2]}}_1)
{%- if name.endswith('_i') %}
            {{name}} = self._lift(_{{name[:-2]}}_1, edge_index, {{name[-1]}})
        else:
            {{name}} = None
{%- endif %}
    elif isinstance({{name[:-2]}}, Tensor):
        self._set_size(size, 0, {{name[:-2]}})
        self._set_size(size, 1, {{name[:-2]}})
        {{name}} = self._lift({{name[:-2]}}, edge_index, {{name[-1]}})
    else:
        {{name}} = None
{%- endif %}
{%- endfor %}

    assert isinstance(edge_index, Tensor)
    adj_t = None
    edge_index_i = edge_index[i]
    edge_index_j = edge_index[j]
    ptr = None
    index = edge_index_i

    size_i = size[i]
    size_j = size[j]
    dim_size = size_i

    return CollectArgs({% for name in collect_types %}{{name}}{% if not loop.last %}, {% endif %}{% endfor %})


def propagate(
    self,
    edge_index: Union[Tensor, SparseTensor],
{%- for name, type_hint in propagate_types.items() %}
    {{name}}: {{type_hint}},
{%- endfor %}
    size: Size = None,
) -> {{propagate_return_type}}:

    decomposed_layers = 1 if self.explain is True else self.decomposed_layers

    if not torch.jit.is_scripting() and not is_compiling():
        for hook in self._propagate_forward_pre_hooks.values():
            kwargs = dict({% for name in propagate_types %}{{name}}={{name}}{% if not loop.last %}, {% endif %}{% endfor %})
            res = hook(self, (edge_index, size, kwargs))
            if res is not None:
                edge_index, size, kwargs = res
{%- for name, type_hint in propagate_types.items() %}
                {{name}} = kwargs['{{name}}']
{%- endfor %}

    mutable_size = self._check_input(edge_index, size)
    fuse = is_sparse(edge_index) and self.fuse and self.explain is not True

    if fuse:
        raise NotImplementedError

    else:
        kwargs = self._collect(edge_index, {% for name in propagate_types %}{{name}}, {% endfor %}mutable_size)
        out = self.message({% for name in message_args %}{{name}}=kwargs.{{name}}{% if not loop.last %}, {% endif %}{% endfor %})
        out = self.aggregate(out{% for name in aggregate_args %}, {{name}}=kwargs.{{name}}{% endfor %})
        out = self.update(out{% for name in update_args %}, {{name}}=kwargs.{{name}}{% endfor %})
    return out
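As a rough illustration of what the template renders to (approximate, hand-written output rather than actual generated code), consider a hypothetical layer whose `propagate_types` are `{'x': 'Tensor', 'edge_weight': 'OptTensor'}`, whose `message()` takes `x_j` and `edge_weight`, and whose `aggregate()` uses the default `index`, `ptr` and `dim_size` arguments; names such as `Tensor`, `OptTensor` and `Size` come from the star imports in the template header:

```python
class CollectArgs(NamedTuple):
    x_j: Tensor
    edge_weight: OptTensor
    index: Tensor
    ptr: Optional[Tensor]
    dim_size: Optional[int]


def propagate(
    self,
    edge_index: Union[Tensor, SparseTensor],
    x: Tensor,
    edge_weight: OptTensor,
    size: Size = None,
) -> Tensor:
    # ... hooks and input checks as in the template above ...
    mutable_size = self._check_input(edge_index, size)
    kwargs = self._collect(edge_index, x, edge_weight, mutable_size)
    out = self.message(x_j=kwargs.x_j, edge_weight=kwargs.edge_weight)
    out = self.aggregate(out, index=kwargs.index, ptr=kwargs.ptr,
                         dim_size=kwargs.dim_size)
    out = self.update(out)
    return out
```

Because all argument names and types are fixed at render time, the generated functions need no `**kwargs` handling, which is what makes them straightforward for TorchScript to compile.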
65 changes: 65 additions & 0 deletions torch_geometric/nn/conv/propagate.py
@@ -0,0 +1,65 @@
import importlib
import os
import os.path as osp
import re
import sys
from typing import Any, Optional

from jinja2 import Template

from torch_geometric import get_home_dir


def module_from_template(
    module_name: str,
    template_path: str,
    **kwargs: Any,
) -> Any:

    if module_name in sys.modules:  # If module is already loaded, return it:
        return sys.modules[module_name]

    with open(template_path, 'r') as f:
        template = Template(f.read())
    module_repr = template.render(**kwargs)

    instance_dir = osp.join(get_home_dir(), 'propagate')
    os.makedirs(instance_dir, exist_ok=True)
    instance_path = osp.join(instance_dir, f'{module_name}.py')
    with open(instance_path, 'w') as f:
        f.write(module_repr)

    spec = importlib.util.spec_from_file_location(module_name, instance_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


def type_hint_to_str(type_hint: Any) -> str:
    type_repr = str(type_hint)
    type_repr = re.sub(r'<class \'(.*)\'>', r'\1', type_repr)
    return type_repr


def find_parenthesis_content(text: str, prefix: str) -> Optional[str]:
    match = re.search(prefix, text)
    if match is None:
        return

    offset = text[match.start():].find('(')
    if offset < 0:
        return

    content = text[match.start() + offset:]

    num_opened = num_closed = 0
    for end, char in enumerate(content):
        if char == '(':
            num_opened += 1
        if char == ')':
            num_closed += 1
        if num_opened > 0 and num_opened == num_closed:
            content = content[1:end]
            content = content.replace('\n', ' ').replace('#', ' ')
            return re.sub(' +', ' ', content.replace('\n', ' ')).strip()
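A short illustration of how the two helpers above behave (the sample source string is made up; the outputs shown in comments follow from the implementations in this file):

```python
from typing import Optional

from torch import Tensor

from torch_geometric.nn.conv.propagate import (
    find_parenthesis_content,
    type_hint_to_str,
)

source = """
    # propagate_type: (x: Tensor,
    #                  edge_weight: OptTensor)
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight)
"""

# Matching on the comment prefix extracts the balanced parenthesized content,
# with newlines and '#' characters stripped and whitespace collapsed:
find_parenthesis_content(source, prefix='propagate_type:')
# -> 'x: Tensor, edge_weight: OptTensor'

# Matching on the call site extracts its argument list instead:
find_parenthesis_content(source, 'self.propagate')
# -> 'edge_index, x=x, edge_weight=edge_weight'

# `type_hint_to_str` turns type objects into clean strings:
type_hint_to_str(Tensor)            # -> 'torch.Tensor'
type_hint_to_str(Optional[Tensor])  # -> 'typing.Optional[torch.Tensor]'
```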
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/res_gated_graph_conv.py
@@ -123,7 +123,8 @@ def forward(
        else:
            k, q, v = x[1], x[0], x[0]

        # propagate_type: (k: Tensor, q: Tensor, v: Tensor, edge_attr: OptTensor) # noqa
        # propagate_type: (k: Tensor, q: Tensor, v: Tensor,
        #                  edge_attr: OptTensor)
        out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr)

        if self.root_weight:
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/rgat_conv.py
@@ -345,7 +345,8 @@ def forward(
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        # propagate_type: (x: Tensor, edge_type: OptTensor, edge_attr: OptTensor) # noqa
        # propagate_type: (x: Tensor, edge_type: OptTensor,
        #                  edge_attr: OptTensor)
        out = self.propagate(edge_index=edge_index, edge_type=edge_type, x=x,
                             size=size, edge_attr=edge_attr)

3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/transformer_conv.py
@@ -224,7 +224,8 @@ def forward( # noqa: F811
        key = self.lin_key(x[0]).view(-1, H, C)
        value = self.lin_value(x[0]).view(-1, H, C)

        # propagate_type: (query: Tensor, key:Tensor, value: Tensor, edge_attr: OptTensor) # noqa
        # propagate_type: (query: Tensor, key:Tensor, value: Tensor,
        #                  edge_attr: OptTensor)
        out = self.propagate(edge_index, query=query, key=key, value=value,
                             edge_attr=edge_attr)
