70 changes: 70 additions & 0 deletions docs/README_mosaic.md
@@ -0,0 +1,70 @@
# Subgraph Learning with SALIENT++ via MOSAIC


This README describes how to apply the MOSAIC transformation so that subgraph classification workloads can run within the SALIENT++ system. It also shows how to compose subgraph-aware labeling techniques such as [GLASS](https://openreview.net/forum?id=XLxhEjKNbXj) with modular nodewise architectures.

![](./figs/mosaic.png)


## Preprocessing Steps

### 0. Prepare the datasets

The preprocessing script supports the following datasets:

Synthetic
- coreness
- cut_ratio
- density
- component

Real-world
- em_user
- hpo_metab
- ppi_bp
- hpo_neuro
- elliptic2

See the [GLASS](https://github.com/Xi-yuanWang/GLASS/tree/main) repo for instructions on accessing these.

For the Elliptic2 dataset, see [Elliptic2](https://github.com/MITIBMxGraph/Elliptic2/tree/main?tab=readme-ov-file) for access; that repository provides instructions for producing the `edge_list.txt` and `subgraphs.pth` files. The preprocessing script expects the dataset in these file formats.
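
If you want a quick sanity check before preprocessing, a minimal sketch such as the one below can confirm the two files are readable. The paths, and the assumption that `edge_list.txt` holds one `src dst` pair per line, are illustrative rather than prescribed by the script.

```python
import torch

# Hypothetical paths; adjust to wherever you placed the Elliptic2 export.
EDGE_LIST = "dataset/elliptic2/edge_list.txt"
SUBGRAPHS = "dataset/elliptic2/subgraphs.pth"

# Assumed format: one edge per line ("src dst"); see the Elliptic2 repo for details.
with open(EDGE_LIST) as f:
    num_edges = sum(1 for line in f if line.strip())
print(f"edges: {num_edges}")

# subgraphs.pth is assumed to be a torch-serialized object describing the
# labeled subgraphs; inspecting its type confirms the export worked.
subgraphs = torch.load(SUBGRAPHS, map_location="cpu")
print(type(subgraphs))
```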

### 1. Preprocess the dataset

Run the following command to generate a directory containing the `FastDataset` torch files. You should see a directory structure like `dataset/ppi_bp/ppi_bp`, containing `rowptr.pt`, `x.pt`, etc.

```
python -m scripts.preprocess_SALIENT --dataset_dir DATASET_NAME
```

`DATASET_NAME` refers to a directory relative to `SALIENT_plusplus/dataset`, e.g. `ppi_bp`.
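
As a quick check that preprocessing succeeded, you can load a couple of the generated tensors directly. A minimal sketch, assuming the `dataset/ppi_bp/ppi_bp` layout described above and that `rowptr.pt` is the CSR row-pointer array:

```python
import torch

# Directory produced by scripts.preprocess_SALIENT (dataset/<name>/<name>).
root = "dataset/ppi_bp/ppi_bp"

rowptr = torch.load(f"{root}/rowptr.pt", map_location="cpu")
x = torch.load(f"{root}/x.pt", map_location="cpu")

# rowptr has num_nodes + 1 entries in CSR form; x is the node feature matrix.
print(f"nodes: {rowptr.numel() - 1}, feature dim: {x.size(1)}")
```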

Now, you are ready to run the training pipeline!

## Running within SALIENT

Run the experiment driver as described in the [installation instructions](./INSTALL.md).

For example, to run on the `ppi_bp` dataset with the same hyperparameters as the GLASS paper (see their [configs](https://github.com/Xi-yuanWang/GLASS/blob/main/config/ppi_bp.yml)), run the following command.

`python -m utils.exp_driver --num_machines 1 --num_gpus_per_machine 1 --gpu_percent 0.999 --replication_factor 15 --run_local --train_fanouts -1 -1 -1 --test_fanouts -1 -1 -1 --num_hidden 64 --train_batch_size 80 --learning_rate 0.0005 --dataset_name ppi_bp --dataset_dir ./dataset/ppi_bp --job_name test-job --model_name sageresinception --num_epochs 300 --use_subgraph_label`

There are two key modifications to be aware of:

### 1. Fanout modification
The MOSAIC transformation introduces an additional message-passing layer for subgraph representative nodes. Therefore, specify an extra fanout of -1 at the start of the fanout lists (train and test).

In this example, the command specifies three fanouts even though the architecture uses two convolution layers; the extra entry accounts for the subgraph representatives.
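
The relationship between convolution layers and fanouts can be written as a small helper; this is only a sketch, and `mosaic_fanouts` is an illustrative name rather than part of the codebase:

```python
def mosaic_fanouts(conv_fanouts):
    """Prepend the extra -1 fanout for the subgraph-representative layer."""
    return [-1] + list(conv_fanouts)

# Two convolution layers -> three fanouts passed to --train_fanouts/--test_fanouts.
print(mosaic_fanouts([-1, -1]))  # [-1, -1, -1]
```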

### 2. Subgraph labeling
We now have the optional flag `--use_subgraph_label`, which modifies batch preparation to set a 1-valued flag feature on every subgraph representative in the batch. Specifically, this happens in the `DevicePrefetcher` class in `fast_trainer/transferers.py`.

**Note** - the preprocessing script automatically appends a 0-valued feature to every node feature vector to accommodate subgraph labeling, avoiding the need to keep multiple versions of the dataset.
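
Conceptually, the flag works like the sketch below; it is illustrative only, and the actual logic lives in `DevicePrefetcher` in `fast_trainer/transferers.py`:

```python
import torch

def mark_subgraph_representatives(x, representative_idx):
    # x: [num_batch_nodes, num_features] feature matrix whose final column is
    # the 0-valued label feature appended by the preprocessing script.
    # representative_idx: indices of the subgraph representatives in this batch.
    x = x.clone()
    x[representative_idx, -1] = 1.0
    return x

# Example: a 4-node batch with 3 original features plus the appended label feature.
x = torch.zeros(4, 4)
print(mark_subgraph_representatives(x, torch.tensor([0, 2])))
```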

## Future work

- Adding support for distributed execution across multiple machines
- Hyperparameter tuning
- Incorporating GCN aggregation to better emulate GLASS's configuration
- Evaluation on additional datasets or model backbones

Binary file added docs/figs/mosaic.png
2 changes: 1 addition & 1 deletion driver/dataset.py
@@ -97,7 +97,7 @@ def from_path_if_exists(cls, path, name, skip_features=False):
path = Path(path).joinpath(name)
assert path.exists() and path.is_dir()
data = {
field: torch.load(path.joinpath(field + '.pt'))
field: torch.load(path.joinpath(field + '.pt'), map_location=torch.device('cpu'))
for field in cls._fields if not skip_features or (field != 'y' and field != 'x')
}
if not skip_features:
130 changes: 89 additions & 41 deletions driver/drivers/base.py
@@ -38,13 +38,15 @@ class BaseDriver:
lr: float
train_loader: ABCNeighborSampler
train_transferer: Type[DeviceIterator]
test_loader: ABCNeighborSampler
test_transferer: Type[DeviceIterator]
train_impl: TrainImpl
train_max_num_batches: int
model: torch.nn.Module
make_subgraph_loader: Callable[[torch.Tensor], Iterable[PreparedBatch]]
#evaluator: Evaluator
log_file: Path
binary_class: bool


def get_num_trainers(self):
@@ -65,6 +67,7 @@ def __init__(self, args, devices: List[torch.device],
self.logs = []
self.firstRun = True
self.TRIAL_NUM = 0
self.binary_class = self.dataset.y.unique().size(0) == 2
assert len(self.devices) > 0

if args.train_type == 'serial' and len(self.devices) > 1:
@@ -102,6 +105,7 @@ def __init__(self, args, devices: List[torch.device],
count_remote_frequency=True,
# Cannot use cache if not yet created.
use_cache=False,
use_subgraph_label=args.use_subgraph_label
)

if self.args.distribute_data:
@@ -143,11 +147,12 @@ def __init__(self, args, devices: List[torch.device],
count_remote_frequency=False,
# If a cache is created then use it.
use_cache=self.create_cache,
use_subgraph_label=args.use_subgraph_label
)
elif self.compute and self.args.distribute_data:
# TODO: Add 1D version of serial_idx kernel
train_cfg = FastSamplerConfig(
x_cpu=self.x_cpu, x_gpu=self.x_gpu, y=self.dataset.y.unsqueeze(-1),
x_cpu=self.x_cpu, x_gpu=self.x_gpu, y=self.dataset.y,
rowptr=self.dataset.rowptr, col=self.dataset.col,
# After the initial creation, this config object with a shuffled idx, can use placeholder data.
idx=torch.zeros(1, dtype=torch.int64),
@@ -165,11 +170,13 @@ def __init__(self, args, devices: List[torch.device],
count_remote_frequency=False,
# If a cache is created then use it.
use_cache=self.create_cache,
use_subgraph_label=args.use_subgraph_label
)
elif self.compute:
# TODO: Add 1D version of serial_idx kernel
train_cfg = FastSamplerConfig(
x_cpu=self.dataset.x, x_gpu=self.x_gpu, y=self.dataset.y.unsqueeze(-1),
# If multilabel, y already has a 2nd dimension; otherwise we add a [1] dimension at the end
x_cpu=self.dataset.x, x_gpu=self.x_gpu, y=self.dataset.y,
rowptr=self.dataset.rowptr, col=self.dataset.col,
# After the initial creation, this config object with a shuffled idx, can use placeholder data.
idx=torch.empty(self.dataset.split_idx['train'].numel(), dtype=torch.int64),
@@ -187,6 +194,7 @@ def __init__(self, args, devices: List[torch.device],
count_remote_frequency=False,
# If a cache is created then use it.
use_cache=self.create_cache,
use_subgraph_label=args.use_subgraph_label
)
else:
raise ValueError(f'Did not create a valid FastSamplerConfig for {args.execution_mode=}')
@@ -224,15 +232,28 @@ def make_loader(sampler, cfg: FastSamplerConfig):
else:
self.train_transferer = DeviceDistributedPrefetcher
self.test_transferer = DeviceDistributedPrefetcher
# print("Train loader:", self.train_loader)
# print("Train transferer:", self.train_transferer)
print("Train type ", args.train_type)
self.train_impl = {'dp': train.data_parallel_train,
'serial': train.serial_train}[args.train_type]

# From GLASS
if self.dataset.y.unique().shape[0] == 2:
if self.dataset.y.ndim > 1:
output_channels = self.dataset.y.shape[1]
else:
output_channels = 1
else:
output_channels = self.dataset.y.unique().shape[0]

self.model = self.model_type(
self.dataset.num_features, args.hidden_features,
self.dataset.num_classes,
output_channels,
num_layers=args.num_layers).to(self.main_device)
self.model_noddp = self.model_type(
self.dataset.num_features, args.hidden_features,
self.dataset.num_classes,
output_channels,
num_layers=args.num_layers).to(self.main_device)
self.idx_arange = torch.arange(self.dataset.y.numel())
#self.evaluator = Evaluator(name=args.dataset_name)
@@ -279,9 +300,13 @@ def main_device(self) -> torch.device:
def get_idx_test(self) -> None:
return self.dataset.split_idx['test']

# Added flag for subgraph label here
def make_train_devit(self) -> DeviceIterator:
return self.train_transferer(self.devices, iter(self.train_loader), pipeline_on = not self.args.pipeline_disabled)
return self.train_transferer(self.devices, iter(self.train_loader), pipeline_on = not self.args.pipeline_disabled, use_subgraph_label=self.args.use_subgraph_label)

def make_test_devit(self) -> DeviceIterator:
return self.test_transferer([self.main_device], iter(self.test_loader), pipeline_on = not self.args.pipeline_disabled, use_subgraph_label=self.args.use_subgraph_label)

def log(self, t) -> None:
self.logs.append(t)
if self.is_main_proc and self.args.verbose:
@@ -341,58 +366,69 @@ def record_sampler_init_time(x):
append_runtime_stats("Sampler init", devit.it.get_stats(
).total_blocked_dur.total_seconds() * 1000)
if self.is_main_proc:
if self.args.train_sampler == 'NeighborSampler':
pbar = tqdm(total=self.train_loader.node_idx.numel())
else:
pbar = tqdm(total=self.train_loader.idx.numel())
pbar.set_description(f'Epoch {epoch}')
pass # TODO remove
# if self.args.train_sampler == 'NeighborSampler':
# pbar = tqdm(total=self.train_loader.node_idx.numel())
# else:
# pbar = tqdm(total=self.train_loader.idx.numel())
# pbar.set_description(f'Epoch {epoch}')

def cb(inputs, results):
if self.is_main_proc:
pbar.update(sum(batch.batch_size for batch in inputs))
return # TODO taking out update
# if self.is_main_proc:
# pbar.update(sum(batch.batch_size for batch in inputs))

def cb_NS(inputs, results):
if self.is_main_proc:
pbar.update(sum(bs[0] for bs in inputs))
return # TODO taking out update
# if self.is_main_proc:
# pbar.update(sum(bs[0] for bs in inputs))

def log_total_compute_time(x):
append_runtime_stats("total", x.nanos/1000000)
self.log(x)

# print("Entering train call")
with Timer((epoch, 'Compute'), log_total_compute_time) as timer:
if self.args.train_sampler == 'NeighborSampler':
self.train_impl(self.model, train.barebones_train_core,
devit, self.optimizer, lr_scheduler,
self.binary_class,
cb_NS, dataset=self.dataset,
devices=self.devices)
else:
self.train_impl(self.model, train.barebones_train_core,
devit, self.optimizer, lr_scheduler,
self.binary_class,
cb, dataset=None, devices=None)
# Barrier is not needed for correctness. I'm also not sure it is needed for accurate
# timing either because of synchronization in DDP model. In any case, including it
# here to make sure there is a synchronization point inside the compute region.
if dist.is_initialized():
dist.barrier()
timer.stop()
timer.stop()

max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3) # GB
print(f"Max GPU memory allocated: {max_gpu_memory:.2f} GB")

runtime_stats_cuda.report_stats({'total': 'Total', 'data_transfer': 'Data Transfer', 'sampling': 'Sampling + Slicing', 'train': 'Train', 'sampling2': 'Sampling Blocking'})
# TODO - TOOK OUT THE TIMING
# runtime_stats_cuda.report_stats({'total': 'Total', 'data_transfer': 'Data Transfer', 'sampling': 'Sampling + Slicing', 'train': 'Train', 'sampling2': 'Sampling Blocking'})

# Log amount of communication during training.
if self.args.distribute_data:
# NOTE: These values are off by a factor of 4. Only useful for relative comparisons.
self.log(f"NUM_SENT_BYTES(name={epoch}, bytes={devit.NUMBER_OF_SENT_BYTES})")
#print(f"NUM_SENT_BYTES(name={epoch}, bytes={devit.NUMBER_OF_SENT_BYTES})", flush=True)
if self.is_main_proc:
if self.args.train_sampler != 'NeighborSampler' and \
isinstance(devit.it, FastSamplerIter):

self.log((epoch, devit.it.get_stats()))
# append runtime stats. Convert units to milliseconds
append_runtime_stats("Sampling block time", devit.it.get_stats(
).total_blocked_dur.total_seconds()*1000)
pbar.close()
del pbar
pass # TODO remove pbars
# if self.args.train_sampler != 'NeighborSampler' and \
# isinstance(devit.it, FastSamplerIter):

# self.log((epoch, devit.it.get_stats()))
# # append runtime stats. Convert units to milliseconds
# append_runtime_stats("Sampling block time", devit.it.get_stats(
# ).total_blocked_dur.total_seconds()*1000)
# pbar.close()
# del pbar
if self.args.train_sampler == 'FastSampler':

"""
@@ -464,7 +500,7 @@ def batchwise_test(self, sets=None) -> Mapping[str, float]:


cfg = FastSamplerConfig(
x_cpu=self.x_cpu, x_gpu=self.x_gpu, y=self.dataset.y.unsqueeze(-1),
x_cpu=self.x_cpu, x_gpu=self.x_gpu, y=self.dataset.y,
rowptr=self.dataset.rowptr, col=self.dataset.col,
idx=self.get_idx_test(name),
batch_size=local_batchsize,
@@ -482,23 +518,27 @@ def batchwise_test(self, sets=None) -> Mapping[str, float]:
count_remote_frequency=False,
# Cache for inference not yet supported.
use_cache=False,
use_subgraph_label=self.args.use_subgraph_label
)

loader = FastSampler(self.args.num_workers,
self.test_loader = FastSampler(self.args.num_workers,
self.args.test_max_num_batches, cfg)
devit = self.test_transferer([self.main_device], iter(loader), pipeline_on = not self.args.pipeline_disabled)
devit = self.make_test_devit()

if self.is_main_proc:
pbar = tqdm(total=cfg.idx.numel())
if not dist.is_initialized():
pbar.set_description(f'Batchwise eval (one proc)')
else:
pbar.set_description(
'Batchwise eval (multi proc, showing main proc progress)')
pass # TODO taking out update
# pbar = tqdm(total=cfg.idx.numel())
# if not dist.is_initialized():
# pbar.set_description(f'Batchwise eval (one proc)')
# else:
# pbar.set_description(
# 'Batchwise eval (multi proc, showing main proc progress)')

def cb(batch):
if self.is_main_proc:
pbar.update(batch.batch_size)
pass
# if self.is_main_proc:
# TODO TAKE OUT pbar updates
# pbar.update(batch.batch_size)

with Timer((name, 'Compute'), self.log) as timer:
if hasattr(self.model, 'module'):
Expand All @@ -507,19 +547,27 @@ def cb(batch):
else:
self.model_noddp.load_state_dict(self.model.state_dict())
result = test.batchwise_test(
self.model_noddp, len(loader), devit, cb)

self.model_noddp, len(self.test_loader), devit, self.binary_class, cb)
results["microf1"] = result[2]
results["binaryf1"] = result[3]
results["auroc"] = result[4]
results["pr_roc"] = result[5]
# print("In base.py, got result", result)

timer.stop()
if self.is_main_proc:
pbar.close()
del pbar
pass # SkIp the below
# pbar.close()
# del pbar

if dist.is_initialized():
output_0 = torch.tensor([result[0]]).to(self.main_device)
output_1 = torch.tensor([result[1]]).to(self.main_device)
_ = dist.all_reduce(output_0)
_ = dist.all_reduce(output_1)
result = (output_0.item(), output_1.item())
results[name] = result[0] / result[1]

results[name] = result[0] / result[1]
# print(self.dataset.split_idx)
return results