diff --git a/CHANGELOG.md b/CHANGELOG.md
index 924167abb840..6d138eba1fcf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -45,6 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
 - Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
 ### Changed
+- Fix the default arguments of `DataParallel` class ([#6376](https://github.com/pyg-team/pytorch_geometric/pull/6376))
 - Fix `ImbalancedSampler` on sliced `InMemoryDataset` ([#6374](https://github.com/pyg-team/pytorch_geometric/pull/6374))
 - Breaking Change: Changed the interface and implementation of `GraphMultisetTransformer` ([#6343](https://github.com/pyg-team/pytorch_geometric/pull/6343))
 - Fixed the approximate PPR variant in `transforms.GDC` to not crash on graphs with isolated nodes ([#6242](https://github.com/pyg-team/pytorch_geometric/pull/6242))
diff --git a/torch_geometric/nn/data_parallel.py b/torch_geometric/nn/data_parallel.py
index 63e1289ca38b..4ff960e0a196 100644
--- a/torch_geometric/nn/data_parallel.py
+++ b/torch_geometric/nn/data_parallel.py
@@ -34,16 +34,16 @@ class DataParallel(torch.nn.DataParallel):
         output_device (int or torch.device): Device location of output.
             (default: :obj:`device_ids[0]`)
         follow_batch (list or tuple, optional): Creates assignment batch
-            vectors for each key in the list. (default: :obj:`[]`)
+            vectors for each key in the list. (default: :obj:`None`)
         exclude_keys (list or tuple, optional): Will exclude each key in the
-            list. (default: :obj:`[]`)
+            list. (default: :obj:`None`)
     """
     def __init__(self, module, device_ids=None, output_device=None,
-                 follow_batch=[], exclude_keys=[]):
+                 follow_batch=None, exclude_keys=None):
         super().__init__(module, device_ids, output_device)
         self.src_device = torch.device(f'cuda:{self.device_ids[0]}')
-        self.follow_batch = follow_batch
-        self.exclude_keys = exclude_keys
+        self.follow_batch = follow_batch or []
+        self.exclude_keys = exclude_keys or []

     def forward(self, data_list):
         """"""