
Support for LSTM aggregator in SAGEConv #4379

Merged · 20 commits · Apr 8, 2022

Conversation

@hunarbatra (Contributor) commented Mar 30, 2022

This pull request adds LSTM, max-pool, and GCN aggregation to SAGEConv, matching the aggregators implemented in the original paper/code, Inductive Representation Learning on Large Graphs. The absence of these aggregator implementations was raised in issue #1147. Adding them would help users who want these aggregators with PyTorch Geometric's SAGEConv, since DGL supports them too!

Thank you to @rusty1s for guiding me with the LSTM implementation! :)

Notes

  • I have removed kwargs.setdefault('aggr', 'mean'), because it was making SAGEConv inherit the 'mean' aggr from MessagePassing.
  • I've added an aggregator_type argument, which defaults to 'mean' (see the usage sketch below).
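
For context, a minimal usage sketch. Hedged: by the time the PR was merged the option ended up exposed through the aggr constructor argument, as the discussion and the title change below show, and the channel sizes here are illustrative:

    from torch_geometric.nn import SAGEConv

    # 'mean' remains the default; 'lstm' enables the LSTM aggregator
    conv = SAGEConv(in_channels=16, out_channels=32, aggr='lstm')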

@rusty1s (Member) left a comment

Thank you :) Left some comments.

[6 inline review comments on torch_geometric/nn/conv/sage_conv.py, resolved]
@codecov bot commented Apr 1, 2022

Codecov Report

Merging #4379 (3b5aa4c) into master (8d8bbd8) will decrease coverage by 0.09%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #4379      +/-   ##
==========================================
- Coverage   82.68%   82.58%   -0.10%     
==========================================
  Files         312      312              
  Lines       16118    16135      +17     
==========================================
- Hits        13327    13325       -2     
- Misses       2791     2810      +19     
Impacted Files                                   Coverage             Δ
torch_geometric/nn/conv/sage_conv.py             100.00% <100.00%>    (ø)
torch_geometric/nn/conv/utils/typing.py           81.25% <0.00%>      (-17.50%) ⬇️
torch_geometric/io/tu.py                          93.58% <0.00%>      (-2.57%) ⬇️
torch_geometric/nn/models/mlp.py                  98.41% <0.00%>      (-1.59%) ⬇️
torch_geometric/transforms/gdc.py                 78.17% <0.00%>      (-1.02%) ⬇️
torch_geometric/data/dataset.py                   96.80% <0.00%>      (-0.80%) ⬇️
torch_geometric/nn/conv/rgat_conv.py              83.76% <0.00%>      (-0.53%) ⬇️
torch_geometric/graphgym/utils/comp_budget.py     15.51% <0.00%>      (+0.51%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@hunarbatra (Contributor, Author) commented:
I have converted the LSTM's output from a tuple to a tensor. Kindly review it and let me know if it looks good to you or if something should be changed.

Also, setting self.aggr = 'lstm' throws an error, since self.aggr is passed to the parent MessagePassing constructor, which validates it against AGGRS = {'add', 'sum', 'mean', 'min', 'max', 'mul'}.

So, with my current changes, LSTM is only available when edge_index is a SparseTensor and message_and_aggregate() is implemented.

Please let me know if you have any suggestions for making LSTM work with plain Tensors as well.
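
To make the SparseTensor-only state concrete, here is a rough sketch of what LSTM aggregation inside message_and_aggregate() can look like. Hedged: names such as self.lstm and the exact handling are assumptions for illustration, not the PR's code as submitted:

    from torch import Tensor
    from torch_sparse import SparseTensor
    from torch_geometric.typing import OptPairTensor
    from torch_geometric.utils import to_dense_batch

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor) -> Tensor:
        row, col, _ = adj_t.coo()  # row: target nodes, col: source nodes
        x_j = x[0][col]            # gather source-node features per edge
        # Group each target's neighbors into a padded [N, max_deg, F] batch;
        # row is sorted for a SparseTensor, as to_dense_batch() requires:
        batch, mask = to_dense_batch(x_j, row, batch_size=adj_t.size(0))
        _, (h, _) = self.lstm(batch)  # LSTM over each neighborhood sequence
        return h.squeeze(0)           # final hidden state per target node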

@rusty1s (Member) left a comment

Thanks for the updates! One thing I think we still need to fix is that this currently only works when the input is a SparseTensor. IMO, it would be great to have more general support for this directly inside aggregate rather than inside message_and_aggregate. For aggr='lstm', we would need to set self.fuse = False such that message_and_aggregate is not called. Furthermore, we would need to ensure that edge_index_i is sorted, and fail with a ValueError otherwise. We can easily check this via (edge_index_i[:-1] <= edge_index_i[1:]).all(). WDYT?
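
A minimal sketch of the constructor change this implies. Hedged: self.fuse is the MessagePassing flag that gates message_and_aggregate(); the remaining names are assumptions for illustration:

    from torch.nn import LSTM

    # inside SAGEConv.__init__, after resolving in_channels:
    if aggr == 'lstm':
        # keep propagate() on the message()/aggregate() path so that
        # message_and_aggregate() is never called for LSTM aggregation
        self.fuse = False
        self.lstm = LSTM(in_channels, in_channels, batch_first=True)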

torch_geometric/nn/conv/sage_conv.py (inline review):

        out = rst.squeeze(0)
        return out

    elif self.aggr == 'mean':

@rusty1s (Member) suggested change:

    - elif self.aggr == 'mean':
    + else:

should work here as well.

@wilcoln (Contributor) commented Apr 2, 2022

> Thanks for the updates! One thing I think we still need to fix is that this currently only works when the input is a SparseTensor. [...] WDYT?

Hey @rusty1s, @hunarbatra, how about converting edge_index into a SparseTensor before passing it to propagate, as follows:

        # Convert edge_index to a SparseTensor; this is required
        # for propagate() to call message_and_aggregate():
        if isinstance(edge_index, Tensor):
            num_nodes = int(edge_index.max()) + 1
            edge_index = SparseTensor(row=edge_index[0], col=edge_index[1],
                                      sparse_sizes=(num_nodes, num_nodes))

        # propagate_type: (x: OptPairTensor)
        # propagate() internally calls message_and_aggregate()
        out = self.propagate(edge_index, x=x, size=size)
        out = self.lin_l(out)

@rusty1s (Member) commented Apr 2, 2022

@wilcoln Yes, this works as well. The problem with this approach is that we would need to convert to a SparseTensor in every GNN layer (an expensive op that ideally only needs to be done once).

@hunarbatra (Contributor, Author) commented Apr 6, 2022

@rusty1s Thank you! Yes, having more general support in aggregate() would definitely be better, and I have been working on it. However, just checking whether edge_index_i is sorted is not enough: it still fails with dimension issues for some common datasets like Cora (when adding the root embeddings in update, out += self.lin_r(x_r)), even when edge_index_i is sorted. After debugging, I noticed that the first dimension of out (the batch-size dimension produced by to_dense_batch()) is not correct, and that leads to the dimension issues in the subsequent steps. Do you have any suggestions for this?

And thanks @wilcoln, that approach definitely works, but I noticed that when I convert to a SparseTensor and then apply the LSTM, the F1 score is pretty low (e.g., just 14% for supervised Cora).
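
For reference, this mismatch typically appears when the highest-numbered target nodes receive no edges: to_dense_batch() then infers a batch dimension from index.max() + 1 that is smaller than the number of target nodes. Pinning the first dimension explicitly avoids it; a hedged one-liner, assuming the dim_size argument that aggregate() receives:

    # batch_size pins the first dimension to the number of target nodes,
    # even when trailing nodes have no incoming edges:
    x, mask = to_dense_batch(x_j, edge_index_i, batch_size=dim_size)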

@rusty1s (Member) commented Apr 7, 2022

@hunarbatra Can you show me an example of how you tried to integrate it into message() and aggregate()?

@hunarbatra (Contributor, Author) commented Apr 7, 2022

Sure, @rusty1s. This is how I've implemented it in message() and aggregate():

    def message(self, x_j: Tensor, x_i: Tensor) -> Tensor:
        return x_j

    def aggregate(self, inputs: Tensor, edge_index_i, edge_index_j) -> Tensor:
        # LSTM
        if (edge_index_i[:-1] <= edge_index_i[1:]).all():
            x_j = inputs[edge_index_j]
            # batch=edge_index_i must be "ordered" for to_dense_batch():
            x, mask = to_dense_batch(x_j, edge_index_i)
            _, (hidden, _) = self.lstm(x)
            out = hidden.squeeze(0)
            return out
        else:
            raise ValueError("edge_index_i is not ordered")

@rusty1s changed the title from "Added support for LSTM, Max Pool & GCN aggregators for SAGEConv" to "Support for LSTM aggregator in SAGEConv" on Apr 8, 2022
@rusty1s (Member) left a comment

Thank you. I pushed the logic to aggregate and added a basic test case.
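
For readers following along, a sketch of the aggregate()-based version. Hedged: this is a reconstruction consistent with the sage_conv.py frames quoted in the traceback below, not necessarily the exact merged code:

    from typing import Optional

    import torch
    from torch import Tensor
    from torch_scatter import scatter
    from torch_geometric.utils import to_dense_batch

    def aggregate(self, x: Tensor, index: Tensor, ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor:
        if self.aggr is not None:  # 'mean', 'max', ... are handled by scatter()
            return scatter(x, index, dim=self.node_dim, dim_size=dim_size,
                           reduce=self.aggr)

        # LSTM aggregation (self.aggr is None in this case):
        if not torch.all(index[:-1] <= index[1:]):
            raise ValueError("'lstm' aggregation requires edge_index to be "
                             "sorted by target node")
        x, mask = to_dense_batch(x, index, batch_size=dim_size)
        _, (h, _) = self.lstm(x)
        return h.squeeze(0)  # one aggregated vector per target node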

@germayneng commented Jun 22, 2022

Hi guys,

I am using the latest master branch and cannot seem to get the LSTM aggr to work; I am getting the error below.

Based on the error, it seems like something to do with the dimension of the hidden layer when set via LSTM?

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_207344/1053175680.py in <module>
----> 1 trainer.fit(model,datamodule=datamodule)

[... pytorch_lightning trainer/loop and torch optimizer frames elided ...]

~/base-training/base_training/models/models.py in training_step(self, batch, batch_idx)
    334 
    335     def training_step(self, batch, batch_idx):
--> 336         loss, pos_loss, neg_loss = self.__share_step(batch, "train")
    337 
    338         return {"loss": loss, "pos_loss": pos_loss, "neg_loss": neg_loss}

~/base-training/base_training/models/models.py in __share_step(self, batch, mode)
    350             # do forward
    351             with torch.cuda.amp.autocast(enabled=True):
--> 352                 logits = self.forward(x, adjs).squeeze(1)
    353                 loss, pos_loss, neg_loss = self._criterion(logits)
    354         else:

~/base-training/base_training/models/models.py in forward(self, x, adjs, embedding)
    269             x_target = x[:size]
    270             x = self.convs[i](
--> 271                 (x.to(self.device), x_target.to(self.device)), edge_index
    272             )  # x_j flow to x_i
    273             if i != self.k_hops - 1:

/opt/conda/envs/base/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

/opt/conda/envs/base/lib/python3.7/site-packages/torch_geometric/nn/conv/sage_conv.py in forward(self, x, edge_index, size)
    121 
    122         # propagate_type: (x: OptPairTensor)
--> 123         out = self.propagate(edge_index, x=x, size=size)
    124         out = self.lin_l(out)
    125 

/opt/conda/envs/base/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py in propagate(self, edge_index, size, **kwargs)
    350 
    351                 if len(self.aggrs) == 0:
--> 352                     out = self.aggregate(out, **aggr_kwargs)
    353                 else:
    354                     outs = []

/opt/conda/envs/base/lib/python3.7/site-packages/torch_geometric/nn/conv/sage_conv.py in aggregate(self, x, index, ptr, dim_size)
    145         if self.aggr is not None:
    146             return scatter(x, index, dim=self.node_dim, dim_size=dim_size,
--> 147                            reduce=self.aggr)
    148 
    149         # LSTM aggregation:

/opt/conda/envs/base/lib/python3.7/site-packages/torch_scatter/scatter.py in scatter(src, index, dim, out, dim_size, reduce)
    160         return scatter_max(src, index, dim, out, dim_size)[0]
    161     else:
--> 162         raise ValueError

ValueError: 

@rusty1s (Member) commented Jun 22, 2022

Do you have a minimal script to reproduce? Please also have a look at https://github.com/pyg-team/pytorch_geometric/blob/master/test/nn/conv/test_sage_conv.py#L59-L73.
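
That test exercises roughly the following pattern (a hedged sketch with illustrative shapes; note the edge_index sorted by its second, target row):

    import torch
    from torch_geometric.nn import SAGEConv

    conv = SAGEConv(8, 32, aggr='lstm')
    x = torch.randn(4, 8)
    # sorted by target row, as LSTM aggregation requires:
    edge_index = torch.tensor([[1, 2, 3, 0],
                               [0, 0, 1, 2]])
    out = conv(x, edge_index)  # shape: [4, 32]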

@germayneng commented Jun 22, 2022

Thanks for the test reference. I found the issue: on the latest 2.0.4 release, I was setting aggr after construction via:

    self.conv[i].aggr = "lstm"

Using the aggr constructor argument fixed the issue:

    SAGEConv(aggr="lstm")
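
That outcome is consistent with the aggregate() sketch above: the LSTM module and the aggregation bookkeeping are set up in __init__ only when aggr='lstm' is passed in, so assigning the attribute afterwards bypasses that setup (a hedged contrast sketch; channel sizes are illustrative):

    # works: the constructor sets up self.lstm and routes LSTM aggregation
    conv = SAGEConv(8, 32, aggr='lstm')

    # fails: assigning after construction skips that setup, so aggregate()
    # falls through to scatter() with an unsupported reduce op, raising the
    # bare ValueError seen in the traceback above
    conv = SAGEConv(8, 32)
    conv.aggr = 'lstm'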
