Commit 27828f8

Merge branch 'master' into sparse_tensor

rusty1s authored Mar 15, 2023
2 parents a386ea1 + 324a4f3

Showing 6 changed files with 149 additions and 53 deletions.
benchmark/inference/inference_benchmark.py (6 additions, 0 deletions)

@@ -102,6 +102,7 @@ def run(args: argparse.ArgumentParser):
             num_neighbors=[-1],  # layer-wise inference
             input_nodes=mask,
             sampler=sampler,
+            filter_per_worker=args.filter_per_worker,
             **kwargs,
         ) if with_loader else None
         if args.evaluate and not args.full_batch:
@@ -110,6 +111,7 @@ def run(args: argparse.ArgumentParser):
                 num_neighbors=[-1],  # layer-wise inference
                 input_nodes=test_mask,
                 sampler=None,
+                filter_per_worker=args.filter_per_worker,
                 **kwargs,
             )

@@ -122,6 +124,7 @@ def run(args: argparse.ArgumentParser):
             num_neighbors=num_neighbors,
             input_nodes=mask,
             sampler=sampler,
+            filter_per_worker=args.filter_per_worker,
             **kwargs,
         ) if with_loader else None
         if args.evaluate and not args.full_batch:
@@ -130,6 +133,7 @@ def run(args: argparse.ArgumentParser):
                 num_neighbors=num_neighbors,
                 input_nodes=test_mask,
                 sampler=None,
+                filter_per_worker=args.filter_per_worker,
                 **kwargs,
             )

@@ -269,6 +273,8 @@ def run(args: argparse.ArgumentParser):
         help='Use DataLoader affinitization.')
     add('--loader-cores', nargs='+', default=[], type=int,
         help="List of CPU core IDs to use for DataLoader workers")
+    add('--filter-per-worker', action='store_true',
+        help='Enable filter-per-worker feature of the dataloader.')
     add('--measure-load-time', action='store_true')
     add('--full-batch', action='store_true', help='Use full batch mode')
     add('--evaluate', action='store_true')
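Note (not part of the commit): a minimal sketch of how the new flag reaches the loader. In PyG, `filter_per_worker=True` moves the filtering of sampled nodes into features into each DataLoader worker process instead of the main process. The dataset and sizes below are placeholders.

    import argparse

    from torch_geometric.datasets import FakeDataset
    from torch_geometric.loader import NeighborLoader

    parser = argparse.ArgumentParser()
    parser.add_argument('--filter-per-worker', action='store_true',
                        help='Enable filter-per-worker feature of the dataloader.')
    args = parser.parse_args(['--filter-per-worker'])  # enabled for illustration

    data = FakeDataset()[0]  # placeholder graph, not the benchmark dataset
    loader = NeighborLoader(
        data,
        num_neighbors=[10, 10],
        batch_size=128,
        num_workers=2,
        filter_per_worker=args.filter_per_worker,
    )
    for batch in loader:
        pass  # run the model on each sampled subgraph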
benchmark/loader/neighbor_loader.py (11 additions, 8 deletions)

@@ -47,7 +47,7 @@ def run(args: argparse.ArgumentParser):
         data, num_neighbors=num_neighbors,
         input_nodes=train_idx, batch_size=batch_size,
         shuffle=True, num_workers=args.num_workers,
-        filter_per_worker=args.filter)
+        filter_per_worker=args.filter_per_worker)
     cpu_affinity = train_loader.enable_cpu_affinity(
         args.loader_cores
     ) if args.cpu_affinity else nullcontext()
@@ -71,12 +71,15 @@ def run(args: argparse.ArgumentParser):
     if eval_batch_sizes is not None:
         print('Evaluation sampling with all neighbors')
         for batch_size in eval_batch_sizes:
-            subgraph_loader = NeighborLoader(data, num_neighbors=[-1],
-                                             input_nodes=eval_idx,
-                                             batch_size=batch_size,
-                                             shuffle=False,
-                                             num_workers=args.num_workers,
-                                             filter_per_worker=args.filter)
+            subgraph_loader = NeighborLoader(
+                data,
+                num_neighbors=[-1],
+                input_nodes=eval_idx,
+                batch_size=batch_size,
+                shuffle=False,
+                num_workers=args.num_workers,
+                filter_per_worker=args.filter_per_worker,
+            )
             cpu_affinity = subgraph_loader.enable_cpu_affinity(
                 args.loader_cores) if args.cpu_affinity else nullcontext()
             runtimes = []
@@ -120,7 +123,7 @@ def run(args: argparse.ArgumentParser):
         help="Number of iterations for each test setting.")
     add('--profile', default=False, action='store_true',
         help="Run torch.profiler.")
-    add('--filter', default=False, action='store_true',
+    add('--filter-per-worker', default=False, action='store_true',
         help="Use filter per worker.")
     add('--cpu-affinity', default=False, action='store_true',
         help="Use DataLoader affinitization.")
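Note: the `... if args.cpu_affinity else nullcontext()` lines above use a conditional context manager so the timed loop is identical whether or not worker pinning is active. A minimal sketch of the pattern, assuming `loader` is a NeighborLoader built as in the benchmark and the core IDs are placeholders:

    from contextlib import nullcontext

    use_affinity = True    # stands in for args.cpu_affinity
    loader_cores = [0, 1]  # hypothetical CPU core IDs for the workers

    cpu_affinity = (loader.enable_cpu_affinity(loader_cores)
                    if use_affinity else nullcontext())
    with cpu_affinity:     # pins DataLoader workers only when requested
        for batch in loader:
            pass           # timed sampling loop goes here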
benchmark/training/training_benchmark.py (23 additions, 7 deletions)

@@ -161,14 +161,28 @@ def run(args: argparse.ArgumentParser):
                 'shuffle': shuffle,
                 'num_workers': args.num_workers,
             }
-            subgraph_loader = NeighborLoader(data, input_nodes=mask,
-                                             sampler=sampler, **kwargs)
+            subgraph_loader = NeighborLoader(
+                data,
+                input_nodes=mask,
+                sampler=sampler,
+                filter_per_worker=args.filter_per_worker,
+                **kwargs,
+            )
             if args.evaluate:
-                val_loader = NeighborLoader(data, input_nodes=val_mask,
-                                            sampler=None, **kwargs)
-                test_loader = NeighborLoader(data,
-                                             input_nodes=test_mask,
-                                             sampler=None, **kwargs)
+                val_loader = NeighborLoader(
+                    data,
+                    input_nodes=val_mask,
+                    sampler=None,
+                    filter_per_worker=args.filter_per_worker,
+                    **kwargs,
+                )
+                test_loader = NeighborLoader(
+                    data,
+                    input_nodes=test_mask,
+                    sampler=None,
+                    filter_per_worker=args.filter_per_worker,
+                    **kwargs,
+                )
             for hidden_channels in args.num_hidden_channels:
                 print('----------------------------------------------')
                 print(f'Batch size={batch_size}, '
@@ -301,6 +315,8 @@ def run(args: argparse.ArgumentParser):
         help="Use DataLoader affinitization.")
     add('--loader-cores', nargs='+', default=[], type=int,
         help="List of CPU core IDs to use for DataLoader workers.")
+    add('--filter-per-worker', action='store_true',
+        help='Enable filter-per-worker feature of the dataloader.')
     add('--measure-load-time', action='store_true')
     add('--evaluate', action='store_true')
     add('--write-csv', action='store_true', help='Write benchmark data to csv')
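Note: the hunk above keeps the three loaders consistent by holding shared settings in one `kwargs` dict and splatting it into each call, while split-specific options stay explicit. A compact sketch of the same pattern with placeholder values:

    from torch_geometric.datasets import FakeDataset
    from torch_geometric.loader import NeighborLoader

    data = FakeDataset()[0]  # placeholder graph
    kwargs = {'num_neighbors': [10, 5], 'batch_size': 64, 'num_workers': 0}

    # Shared settings come from `kwargs`; only the split-specific bits differ.
    train_loader = NeighborLoader(data, shuffle=True,
                                  filter_per_worker=False, **kwargs)
    eval_loader = NeighborLoader(data, shuffle=False,
                                 filter_per_worker=False, **kwargs)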
examples/multi_gpu/distributed_batching.py (5 additions, 4 deletions)

@@ -84,18 +84,19 @@ def run(rank, world_size: int, dataset_name: str, root: str):
     for epoch in range(1, 51):
         model.train()

-        total_loss = 0
+        total_loss = torch.zeros(2).to(rank)
         for data in train_loader:
             data = data.to(rank)
             optimizer.zero_grad()
             logits = model(data.x, data.adj_t, data.batch)
             loss = criterion(logits, data.y.to(torch.float))
             loss.backward()
             optimizer.step()
-            total_loss += float(loss) * logits.size(0)
-        loss = total_loss / len(train_loader.dataset)
+            total_loss[0] += float(loss) * logits.size(0)
+            total_loss[1] += data.num_graphs

-        dist.barrier()
+        dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+        loss = float(total_loss[0] / total_loss[1])

         if rank == 0:  # We evaluate on a single GPU for now.
             model.eval()
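Note: this fix replaces a purely local average (each rank previously divided by the full dataset size while only seeing its own shard) with a single all_reduce over a two-element tensor holding the loss sum and the graph count, so every rank reports the same global mean. A standalone sketch of the idea, assuming the process group is already initialized:

    import torch
    import torch.distributed as dist

    def global_mean_loss(loss_sum: float, num_samples: int,
                         device: torch.device) -> float:
        # Pack the sum and the count together so one collective call suffices.
        stats = torch.tensor([loss_sum, float(num_samples)], device=device)
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)  # sum across all ranks
        return float(stats[0] / stats[1])             # global weighted mean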
test/explain/algorithm/test_captum_explainer.py (50 additions, 17 deletions)

@@ -15,6 +15,25 @@
 from torch_geometric.nn import GCNConv, global_add_pool
 from torch_geometric.testing import withPackage

+methods = [
+    'Saliency',
+    'InputXGradient',
+    'Deconvolution',
+    'ShapleyValueSampling',
+    'IntegratedGradients',
+    'GuidedBackprop',
+]
+
+unsupported_methods = [
+    'FeatureAblation',
+    'Occlusion',
+    'DeepLift',
+    'DeepLiftShap',
+    'GradientShap',
+    'KernelShap',
+    'Lime',
+]
+

 class GCN(torch.nn.Module):
     def __init__(self, model_config: ModelConfig):
@@ -78,19 +97,35 @@ def check_explanation(


 @withPackage('captum')
+@pytest.mark.parametrize('method', unsupported_methods)
+def test_unsupported_methods(method):
+    model_config = ModelConfig(mode='regression', task_level='node')
+
+    with pytest.raises(ValueError, match="does not support attribution"):
+        Explainer(
+            GCN(model_config),
+            algorithm=CaptumExplainer(method),
+            explanation_type='model',
+            edge_mask_type='object',
+            node_mask_type='attributes',
+            model_config=model_config,
+        )
+
+
+@withPackage('captum')
+@pytest.mark.parametrize('method', methods)
 @pytest.mark.parametrize('node_mask_type', node_mask_types)
 @pytest.mark.parametrize('edge_mask_type', edge_mask_types)
 @pytest.mark.parametrize('task_level', ['node', 'edge', 'graph'])
 @pytest.mark.parametrize('index', [1, torch.arange(2)])
 def test_captum_explainer_multiclass_classification(
+    method,
     data,
     node_mask_type,
     edge_mask_type,
     task_level,
     index,
 ):
-    import captum
-
     if node_mask_type is None and edge_mask_type is None:
         return
@@ -103,11 +138,9 @@ def test_captum_explainer_multiclass_classification(
         return_type='probs',
     )

-    model = GCN(model_config)
-
     explainer = Explainer(
-        model,
-        algorithm=CaptumExplainer(captum.attr.IntegratedGradients),
+        GCN(model_config),
+        algorithm=CaptumExplainer(method),
         explanation_type='model',
         edge_mask_type=edge_mask_type,
         node_mask_type=node_mask_type,
@@ -125,26 +158,26 @@ def test_captum_explainer_multiclass_classification(


 @withPackage('captum')
+@pytest.mark.parametrize('method', methods)
 @pytest.mark.parametrize('node_mask_type', node_mask_types)
 @pytest.mark.parametrize('edge_mask_type', edge_mask_types)
 @pytest.mark.parametrize('index', [1, torch.arange(2)])
-def test_captum_hetero_data(node_mask_type, edge_mask_type, index, hetero_data,
-                            hetero_model):
+def test_captum_hetero_data(method, node_mask_type, edge_mask_type, index,
+                            hetero_data, hetero_model):

-    if node_mask_type is None or edge_mask_type is None:
+    if method == 'ShapleyValueSampling':
+        # This currently takes too long to test and is already covered
+        # by the homogeneous graph test case.
         return

-    model_config = ModelConfig(
-        mode='regression',
-        task_level='node',
-        return_type='raw',
-    )
+    if node_mask_type is None or edge_mask_type is None:
+        return

-    model = hetero_model(hetero_data.metadata())
+    model_config = ModelConfig(mode='regression', task_level='node')

     explainer = Explainer(
-        model,
-        algorithm=CaptumExplainer('IntegratedGradients'),
+        hetero_model(hetero_data.metadata()),
+        algorithm=CaptumExplainer(method),
         edge_mask_type=edge_mask_type,
         node_mask_type=node_mask_type,
         model_config=model_config,
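Note: the stacked `pytest.mark.parametrize` decorators above multiply the test matrix, so adding the `method` parameter runs every existing mask-type/index combination once per attribution method. A tiny self-contained illustration of that multiplication (names here are hypothetical, not from the test suite):

    import pytest

    @pytest.mark.parametrize('method', ['Saliency', 'InputXGradient'])
    @pytest.mark.parametrize('index', [0, 1])
    def test_matrix(method, index):
        # 2 methods x 2 indices -> 4 independent test cases are collected.
        assert isinstance(method, str) and index in (0, 1)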
torch_geometric/explain/algorithm/captum_explainer.py (54 additions, 17 deletions)

@@ -1,3 +1,4 @@
+import inspect
 import logging
 import warnings
 from typing import Any, Dict, Optional, Union
@@ -26,11 +27,29 @@ class CaptumExplainer(ExplainerAlgorithm):
     This explainer algorithm uses :captum:`null` `Captum <https://captum.ai/>`_
     to compute attributions.

+    Currently, the following attribution methods are supported:
+
+    * :class:`captum.attr.IntegratedGradients`
+    * :class:`captum.attr.Saliency`
+    * :class:`captum.attr.InputXGradient`
+    * :class:`captum.attr.Deconvolution`
+    * :class:`captum.attr.ShapleyValueSampling`
+    * :class:`captum.attr.GuidedBackprop`
+
     Args:
         attribution_method (Attribution or str): The Captum attribution method
             to use. Can be a string or a :class:`captum.attr` method.
         **kwargs: Additional arguments for the Captum attribution method.
     """
+    SUPPORTED_METHODS = [  # TODO: Add support for more methods.
+        'IntegratedGradients',
+        'Saliency',
+        'InputXGradient',
+        'Deconvolution',
+        'ShapleyValueSampling',
+        'GuidedBackprop',
+    ]

     def __init__(
         self,
         attribution_method: Union[str, Any],
@@ -40,8 +59,6 @@ def __init__(

         import captum.attr  # noqa

-        self.kwargs = kwargs
-
         if isinstance(attribution_method, str):
             self.attribution_method = getattr(
                 captum.attr,
@@ -50,6 +67,19 @@ def __init__(
         else:
             self.attribution_method = attribution_method

+        if not self._is_supported_attribution_method():
+            raise ValueError(f"{self.__class__.__name__} does not support "
+                             f"attribution method "
+                             f"{self.attribution_method.__name__}")
+
+        if kwargs.get('internal_batch_size', 1) != 1:
+            warnings.warn("Overriding 'internal_batch_size' to 1")
+
+        if 'internal_batch_size' in self._get_attribute_parameters():
+            kwargs['internal_batch_size'] = 1
+
+        self.kwargs = kwargs
+
     def _get_mask_type(self) -> MaskLevelType:
         r"""Based on the explainer config, return the mask type."""
         node_mask_type = self.explainer_config.node_mask_type
@@ -65,9 +95,28 @@ def _get_mask_type(self) -> MaskLevelType:
                 "edge mask type is specified.")
         return mask_type

-    def _support_multiple_indices(self) -> bool:
-        r"""Checks if the method supports multiple indices."""
-        return self.attribution_method.__name__ == 'IntegratedGradients'
+    def _get_attribute_parameters(self) -> Dict[str, Any]:
+        r"""Returns the attribute arguments."""
+        signature = inspect.signature(self.attribution_method.attribute)
+        return signature.parameters
+
+    def _needs_baseline(self) -> bool:
+        r"""Checks if the method needs a baseline."""
+        parameters = self._get_attribute_parameters()
+        if 'baselines' in parameters:
+            param = parameters['baselines']
+            if param.default is inspect.Parameter.empty:
+                return True
+        return False
+
+    def _is_supported_attribution_method(self) -> bool:
+        r"""Returns :obj:`True` if `self.attribution_method` is supported."""
+        # This is redundant for now since all supported methods need a baseline
+        if self._needs_baseline():
+            return False
+        elif self.attribution_method.__name__ in self.SUPPORTED_METHODS:
+            return True
+        return False

     def forward(
         self,
@@ -80,18 +129,6 @@ def forward(
         **kwargs,
     ) -> Union[Explanation, HeteroExplanation]:

-        if isinstance(index, Tensor) and index.numel() > 1:
-            if not self._support_multiple_indices():
-                raise ValueError(
-                    f"{self.attribution_method.__name__} does not support "
-                    "multiple indices. Please use a single index or a "
-                    "different attribution method.")
-
-        # TODO (matthias) Check if `internal_batch_size` can be passed.
-        if self.kwargs.get('internal_batch_size', 1) != 1:
-            warnings.warn("Overriding 'internal_batch_size' to 1")
-        self.kwargs['internal_batch_size'] = 1
-
         mask_type = self._get_mask_type()

         inputs, add_forward_args = to_captum_input(
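Note: a standalone sketch (assuming `captum` is installed) of the gating idea behind the new `_needs_baseline` helper: a Captum method whose `attribute` signature declares `baselines` without a default cannot run without one, so `CaptumExplainer` now rejects such methods at construction time instead of failing later in `forward`.

    import inspect

    import captum.attr

    def needs_baseline(method_cls) -> bool:
        # True if `attribute` has a required (no-default) `baselines` parameter.
        params = inspect.signature(method_cls.attribute).parameters
        baselines = params.get('baselines')
        return (baselines is not None
                and baselines.default is inspect.Parameter.empty)

    print(needs_baseline(captum.attr.DeepLiftShap))  # True: baselines required
    print(needs_baseline(captum.attr.Saliency))      # False: gradient-only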
