diff --git a/benchmarks/cugraph/standalone/bulk_sampling/.gitignore b/benchmarks/cugraph/standalone/bulk_sampling/.gitignore new file mode 100644 index 00000000000..19cbd00ebe0 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/.gitignore @@ -0,0 +1 @@ +mg_utils/ diff --git a/benchmarks/cugraph/standalone/bulk_sampling/README.md b/benchmarks/cugraph/standalone/bulk_sampling/README.md index f48eea5c556..bb01133c52f 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/README.md +++ b/benchmarks/cugraph/standalone/bulk_sampling/README.md @@ -1,11 +1,13 @@ -# cuGraph Bulk Sampling +# cuGraph Sampling Benchmarks -## Overview +## cuGraph Bulk Sampling + +### Overview The `cugraph_bulk_sampling.py` script runs the bulk sampler for a variety of datasets, including both generated (rmat) datasets and disk (ogbn_papers100M, etc.) datasets. It can also load replicas of these datasets to create a larger benchmark (i.e. ogbn_papers100M x2). -## Arguments +### Arguments The script takes a variety of arguments to control sampling behavior. Required: --output_root @@ -51,14 +53,8 @@ Optional: Seed for random number generation. Defaults to '62' - --persist - Whether to aggressively use persist() in dask to make the ETL steps (NOT PART OF SAMPLING) faster. - Will probably make this script finish sooner at the expense of memory usage, but won't affect - sampling time. - Changing this is not recommended unless you know what you are doing. - Defaults to False. -## Input Format +### Input Format The script expects its input data in the following format: ``` @@ -103,7 +99,7 @@ the parquet files. It must have the following format: } ``` -## Output Meta +### Output Meta The script, in addition to the samples, will also output a file named `output_meta.json`. This file contains various statistics about the sampling run, including the runtime, as well as information about the dataset and system that the samples were produced from. @@ -111,6 +107,56 @@ as well as information about the dataset and system that the samples were produc This metadata file can be used to gather the results from the sampling and training stages together. -## Other Notes +### Other Notes For rmat datasets, you will need to generate your own bogus features in the training stage. Since that is trivial, that is not done in this sampling script. + +## cuGraph MNMG Training + +### Overview +The script `run_train_job.sh` is run with the `sbatch` command to launch a series of Slurm jobs. +First, for a given number of epochs, the script will produce samples for a given graph. +Then the training process starts: samples are loaded and training iterations are +run. + +### Important Notes +Downloading the dataset files before running the Slurm jobs is highly recommended. Even though +the script will attempt to download the files if they are not available, this can often +lead to a timeout that crashes the scripts. This applies regardless of whether you are training +with native PyG or cuGraph-PyG. You can download data as follows: + +``` +from ogb.nodeproppred import NodePropPredDataset +dataset = NodePropPredDataset('ogbn-papers100M', root='/home/username/datasets') +``` + +For datasets other than ogbn-papers100M, follow the same process, changing only the dataset name. +The dataset will be preprocessed automatically when you run training. If you have a slow system, you +can instead run preprocessing ahead of time by launching the training script on a single worker, which +avoids a timeout that would otherwise crash the script.
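+For example, to pre-download a different dataset, only the name passed to `NodePropPredDataset`
+changes (`ogbn-products` below is purely illustrative; any OGB node-property dataset works the same way):
+
+```
+from ogb.nodeproppred import NodePropPredDataset
+dataset = NodePropPredDataset('ogbn-products', root='/home/username/datasets')
+```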
+ +The multi-GPU utilities are in `mg_utils` in the top level of the cuGraph repository. You should either +copy them to this directory or symlink to them before running the scripts. + +### Arguments +You will need to modify the bash scripts to run appropriately for your environment and +desired training workflow. The standard sbatch arguments, such as +job name and queue, are at the top of the script. These will need to be modified for your SLURM cluster. + +Next are arguments for the container image (required), +and directories where the data and outputs are stored. The directories default to subdirectories +of the current working directory, but if a high-throughput storage system is available, +using that storage for the samples and datasets is highly recommended. + +Next are standard GNN training arguments such as `FANOUT`, `BATCH_SIZE`, etc. You can also set +the number of training epochs here. These are followed by the `REPLICATION_FACTOR` argument, which +can be used to create replicas of the dataset for scale testing purposes. + +The final two arguments are `FRAMEWORK`, which can be either "cuGraphPyG" or "PyG", and `GPUS_PER_NODE`, +which must be set to the correct value even if it is also provided by a SLURM argument. If `GPUS_PER_NODE` +is not set to the correct number of GPUs, the script will hang until it times out. Mismatched +GPU counts per node are currently unsupported by this script but should be possible in practice. + +### Output +The results of training will be written to the logs directory, with an `output.txt` file for each worker. +These will be overwritten upon each run. Accuracy is only reported on rank 0. \ No newline at end of file diff --git a/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py new file mode 100644 index 00000000000..c9e347b261d --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/bench_cugraph_training.py @@ -0,0 +1,251 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +os.environ["RAPIDS_NO_INITIALIZE"] = "1" +os.environ["CUDF_SPILL"] = "1" +os.environ["LIBCUDF_CUFILE_POLICY"] = "KVIKIO" +os.environ["KVIKIO_NTHREADS"] = "8" + +import argparse +import json +import warnings + +import torch +import numpy as np +import pandas + +import torch.distributed as dist + +from datasets import OGBNPapers100MDataset + +from cugraph.testing.mg_utils import enable_spilling + + +def init_pytorch_worker(rank: int, use_rmm_torch_allocator: bool = False) -> None: + import cupy + import rmm + from pynvml.smi import nvidia_smi + + smi = nvidia_smi.getInstance() + pool_size = 16e9 # FIXME calculate this + + rmm.reinitialize( + devices=[rank], + pool_allocator=True, + initial_pool_size=pool_size, + ) + + if use_rmm_torch_allocator: + warnings.warn( + "Using the rmm pytorch allocator is currently unsupported." + " The default allocator will be used instead." 
+ ) + # FIXME somehow get the pytorch allocator to work + # from rmm.allocators.torch import rmm_torch_allocator + # torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + + from rmm.allocators.cupy import rmm_cupy_allocator + + cupy.cuda.set_allocator(rmm_cupy_allocator) + + cupy.cuda.Device(rank).use() + torch.cuda.set_device(rank) + + # Pytorch training worker initialization + torch.distributed.init_process_group(backend="nccl") + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--gpus_per_node", + type=int, + default=8, + help="# GPUs per node", + required=False, + ) + + parser.add_argument( + "--num_epochs", + type=int, + default=1, + help="Number of training epochs", + required=False, + ) + + parser.add_argument( + "--batch_size", + type=int, + default=512, + help="Batch size", + required=False, + ) + + parser.add_argument( + "--fanout", + type=str, + default="10_10_10", + help="Fanout", + required=False, + ) + + parser.add_argument( + "--sample_dir", + type=str, + help="Directory with stored bulk samples (required for cuGraph run)", + required=False, + ) + + parser.add_argument( + "--output_file", + type=str, + help="File to store results", + required=True, + ) + + parser.add_argument( + "--framework", + type=str, + help="The framework to test (PyG, cuGraphPyG)", + required=True, + ) + + parser.add_argument( + "--model", + type=str, + default="GraphSAGE", + help="The model to use (currently only GraphSAGE supported)", + required=False, + ) + + parser.add_argument( + "--replication_factor", + type=int, + default=1, + help="The replication factor for the dataset", + required=False, + ) + + parser.add_argument( + "--dataset_dir", + type=str, + help="The directory where datasets are stored", + required=True, + ) + + parser.add_argument( + "--train_split", + type=float, + help="The percentage of the labeled data to use for training. 
The remainder is used for testing/validation.", + default=0.8, + required=False, + ) + + parser.add_argument( + "--val_split", + type=float, + help="The percentage of the testing/validation data to allocate for validation.", + default=0.5, + required=False, + ) + + return parser.parse_args() + + +def main(args): + import logging + + logging.basicConfig( + level=logging.INFO, + ) + logger = logging.getLogger("bench_cugraph_training") + logger.setLevel(logging.INFO) + + local_rank = int(os.environ["LOCAL_RANK"]) + global_rank = int(os.environ["RANK"]) + + init_pytorch_worker( + local_rank, use_rmm_torch_allocator=(args.framework == "cuGraph") + ) + enable_spilling() + print(f"worker initialized") + dist.barrier() + + world_size = int(os.environ["SLURM_JOB_NUM_NODES"]) * args.gpus_per_node + + dataset = OGBNPapers100MDataset( + replication_factor=args.replication_factor, + dataset_dir=args.dataset_dir, + train_split=args.train_split, + val_split=args.val_split, + load_edge_index=(args.framework == "PyG"), + ) + + if global_rank == 0: + dataset.download() + dist.barrier() + + fanout = [int(f) for f in args.fanout.split("_")] + + if args.framework == "PyG": + from trainers.pyg import PyGNativeTrainer + + trainer = PyGNativeTrainer( + model=args.model, + dataset=dataset, + device=local_rank, + rank=global_rank, + world_size=world_size, + num_epochs=args.num_epochs, + shuffle=True, + replace=False, + num_neighbors=fanout, + batch_size=args.batch_size, + ) + elif args.framework == "cuGraphPyG": + sample_dir = os.path.join( + args.sample_dir, + f"ogbn_papers100M[{args.replication_factor}]_b{args.batch_size}_f{fanout}", + ) + from trainers.pyg import PyGCuGraphTrainer + + trainer = PyGCuGraphTrainer( + model=args.model, + dataset=dataset, + sample_dir=sample_dir, + device=local_rank, + rank=global_rank, + world_size=world_size, + num_epochs=args.num_epochs, + shuffle=True, + replace=False, + num_neighbors=fanout, + batch_size=args.batch_size, + ) + else: + raise ValueError("unsupported framework") + + logger.info(f"Trainer ready on rank {global_rank}") + stats = trainer.train() + logger.info(stats) + + with open(f"{args.output_file}[{global_rank}]", "w") as f: + json.dump(stats, f) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/benchmarks/cugraph/standalone/bulk_sampling/bulk_sampling.sh b/benchmarks/cugraph/standalone/bulk_sampling/bulk_sampling.sh deleted file mode 100755 index e62cb3cda29..00000000000 --- a/benchmarks/cugraph/standalone/bulk_sampling/bulk_sampling.sh +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -export RAPIDS_NO_INITIALIZE="1" -export CUDF_SPILL="1" -export LIBCUDF_CUFILE_POLICY=OFF - - -dataset_name=$1 -dataset_root=$2 -output_root=$3 -batch_sizes=$4 -fanouts=$5 -reverse_edges=$6 - -rm -rf $output_root -mkdir -p $output_root - -# Change to 2 in Selene -gpu_per_replica=4 -#--add_edge_ids \ - -# Expand to 1, 4, 8 in Selene -for i in 1,2,3,4: -do - for replication in 2; - do - dataset_name_with_replication="${dataset_name}[${replication}]" - dask_worker_devices=$(seq -s, 0 $((gpu_per_replica*replication-1))) - echo "Sampling dataset = $dataset_name_with_replication on devices = $dask_worker_devices" - python3 cugraph_bulk_sampling.py --datasets $dataset_name_with_replication \ - --dataset_root $dataset_root \ - --batch_sizes $batch_sizes \ - --output_root $output_root \ - --dask_worker_devices $dask_worker_devices \ - --fanouts $fanouts \ - --batch_sizes $batch_sizes \ - --reverse_edges - done -done \ No newline at end of file diff --git a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py index 9de6c3a2b01..e3a5bba3162 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -97,19 +97,15 @@ def symmetrize_ddf(dask_dataframe): return new_ddf -def renumber_ddf(dask_df, persist=False): +def renumber_ddf(dask_df): vertices = ( dask_cudf.concat([dask_df["src"], dask_df["dst"]]) .unique() .reset_index(drop=True) ) - if persist: - vertices = vertices.persist() vertices.name = "v" vertices = vertices.reset_index().set_index("v").rename(columns={"index": "m"}) - if persist: - vertices = vertices.persist() src = dask_df.merge(vertices, left_on="src", right_on="v", how="left").m.rename( "src" @@ -170,7 +166,7 @@ def _replicate_df( if replication_factor > 1: for r in range(1, replication_factor): - df_replicated = original_df + df_replicated = original_df.copy() for col, offset in col_item_counts.items(): df_replicated[col] += offset * r @@ -189,46 +185,75 @@ def sample_graph( seeds_per_call=400000, batches_per_partition=100, fanout=[5, 5, 5], + num_epochs=1, + train_perc=0.8, + val_perc=0.5, sampling_kwargs={}, ): cupy.random.seed(seed) - - sampler = BulkSampler( - batch_size=batch_size, - output_path=output_path, - graph=G, - fanout_vals=fanout, - with_replacement=False, - random_state=seed, - seeds_per_call=seeds_per_call, - batches_per_partition=batches_per_partition, - log_level=logging.INFO, - **sampling_kwargs, + train_df, test_df = label_df.random_split( + [train_perc, 1 - train_perc], random_state=seed, shuffle=True + ) + val_df, test_df = test_df.random_split( + [val_perc, 1 - val_perc], random_state=seed, shuffle=True ) - n_workers = len(default_client().scheduler_info()["workers"]) + total_time = 0.0 + for epoch in range(num_epochs): + steps = [("train", train_df), ("test", test_df)] + if epoch == num_epochs - 1: + steps.append(("val", val_df)) - meta = cudf.DataFrame( - {"node": cudf.Series(dtype="int64"), "batch": cudf.Series(dtype="int32")} - ) + for step, batch_df in steps: + batch_df = batch_df.sample(frac=1.0, random_state=seed) - batch_df = label_df.map_partitions( - _make_batch_ids, batch_size, 
n_workers, meta=meta - ) - # batch_df = batch_df.sort_values(by='node') + if step == "val": + output_sample_path = os.path.join(output_path, "val", "samples") + else: + output_sample_path = os.path.join( + output_path, f"epoch={epoch}", f"{step}", "samples" + ) + os.makedirs(output_sample_path) + + sampler = BulkSampler( + batch_size=batch_size, + output_path=output_sample_path, + graph=G, + fanout_vals=fanout, + with_replacement=False, + random_state=seed, + seeds_per_call=seeds_per_call, + batches_per_partition=batches_per_partition, + log_level=logging.INFO, + **sampling_kwargs, + ) - # should always persist the batch dataframe or performance may be suboptimal - batch_df = batch_df.persist() + n_workers = len(default_client().scheduler_info()["workers"]) - del label_df - print("created batches") + meta = cudf.DataFrame( + { + "node": cudf.Series(dtype="int64"), + "batch": cudf.Series(dtype="int32"), + } + ) + + batch_df = batch_df.map_partitions( + _make_batch_ids, batch_size, n_workers, meta=meta + ) + + # should always persist the batch dataframe or performance may be suboptimal + batch_df = batch_df.persist() + + print("created batches") - start_time = perf_counter() - sampler.add_batches(batch_df, start_col_name="node", batch_col_name="batch") - sampler.flush() - end_time = perf_counter() - print("flushed all batches") - return end_time - start_time + start_time = perf_counter() + sampler.add_batches(batch_df, start_col_name="node", batch_col_name="batch") + sampler.flush() + end_time = perf_counter() + print("flushed all batches") + total_time += end_time - start_time + + return total_time def assign_offsets_pyg(node_counts: Dict[str, int], replication_factor: int = 1): @@ -253,7 +278,6 @@ def generate_rmat_dataset( labeled_percentage=0.01, num_labels=256, reverse_edges=False, - persist=False, add_edge_types=False, ): """ @@ -282,12 +306,8 @@ def generate_rmat_dataset( dask_edgelist_df = dask_edgelist_df.reset_index(drop=True) dask_edgelist_df = renumber_ddf(dask_edgelist_df).persist() - if persist: - dask_edgelist_df = dask_edgelist_df.persist() dask_edgelist_df = symmetrize_ddf(dask_edgelist_df).persist() - if persist: - dask_edgelist_df = dask_edgelist_df.persist() if add_edge_types: dask_edgelist_df["etp"] = cupy.int32( @@ -329,7 +349,6 @@ def load_disk_dataset( dataset_dir=".", reverse_edges=True, replication_factor=1, - persist=False, add_edge_types=False, ): from pathlib import Path @@ -363,8 +382,6 @@ def load_disk_dataset( ] edge_index_dict[can_edge_type] = edge_index_dict[can_edge_type] - if persist: - edge_index_dict = edge_index_dict.persist() if replication_factor > 1: edge_index_dict[can_edge_type] = edge_index_dict[ @@ -384,11 +401,6 @@ def load_disk_dataset( ), ) - if persist: - edge_index_dict[can_edge_type] = edge_index_dict[ - can_edge_type - ].persist() - gc.collect() if reverse_edges: @@ -396,9 +408,6 @@ def load_disk_dataset( columns={"src": "dst", "dst": "src"} ) - if persist: - edge_index_dict[can_edge_type] = edge_index_dict[can_edge_type].persist() - # Assign numeric edge type ids based on lexicographic order edge_offsets = {} edge_count = 0 @@ -410,9 +419,6 @@ def load_disk_dataset( all_edges_df = dask_cudf.concat(list(edge_index_dict.values())) - if persist: - all_edges_df = all_edges_df.persist() - del edge_index_dict gc.collect() @@ -440,15 +446,9 @@ def load_disk_dataset( meta=cudf.DataFrame({"node": cudf.Series(dtype="int64")}), ) - if persist: - node_labels[node_type] = node_labels[node_type].persist() - gc.collect() - node_labels_df = 
dask_cudf.concat(list(node_labels.values())) - - if persist: - node_labels_df = node_labels_df.persist() + node_labels_df = dask_cudf.concat(list(node_labels.values())).reset_index(drop=True) del node_labels gc.collect() @@ -475,8 +475,8 @@ def benchmark_cugraph_bulk_sampling( replication_factor=1, num_labels=256, labeled_percentage=0.001, - persist=False, add_edge_types=False, + num_epochs=1, ): """ Entry point for the benchmark. @@ -506,14 +506,17 @@ def benchmark_cugraph_bulk_sampling( labeled_percentage: float The percentage of the data that is labeled (only for rmat datasets) Defaults to 0.001 to match papers100M - persist: bool - Whether to aggressively persist data in dask in attempt to speed up ETL. - Defaults to False. add_edge_types: bool Whether to add edge types to the edgelist. Defaults to False. + sampling_target_framework: str + The framework to sample for. + num_epochs: int + The number of epochs to sample for. """ - print(dataset) + + logger = logging.getLogger("__main__") + logger.info(str(dataset)) if dataset[0:4] == "rmat": ( dask_edgelist_df, @@ -527,7 +530,6 @@ def benchmark_cugraph_bulk_sampling( seed=seed, labeled_percentage=labeled_percentage, num_labels=num_labels, - persist=persist, add_edge_types=add_edge_types, ) @@ -543,28 +545,25 @@ def benchmark_cugraph_bulk_sampling( dataset_dir=dataset_dir, reverse_edges=reverse_edges, replication_factor=replication_factor, - persist=persist, add_edge_types=add_edge_types, ) num_input_edges = len(dask_edgelist_df) - print(f"Number of input edges = {num_input_edges:,}") + logger.info(f"Number of input edges = {num_input_edges:,}") G = construct_graph(dask_edgelist_df) del dask_edgelist_df - print("constructed graph") + logger.info("constructed graph") input_memory = G.edgelist.edgelist_df.memory_usage().sum().compute() - print(f"input memory: {input_memory}") + logger.info(f"input memory: {input_memory}") output_subdir = os.path.join( - output_path, f"{dataset}[{replication_factor}]_b{batch_size}_f{fanout}" + output_path, + f"{dataset}[{replication_factor}]_b{batch_size}_f{fanout}", ) os.makedirs(output_subdir) - output_sample_path = os.path.join(output_subdir, "samples") - os.makedirs(output_sample_path) - if sampling_target_framework == "cugraph_dgl_csr": sampling_kwargs = { "deduplicate_sources": True, @@ -587,11 +586,12 @@ def benchmark_cugraph_bulk_sampling( "include_hop_column": True, } - batches_per_partition = 400_000 // batch_size + batches_per_partition = 600_000 // batch_size execution_time, allocation_counts = sample_graph( G=G, label_df=dask_label_df, - output_path=output_sample_path, + output_path=output_subdir, + num_epochs=num_epochs, seed=seed, batch_size=batch_size, seeds_per_call=seeds_per_call, @@ -620,8 +620,8 @@ def benchmark_cugraph_bulk_sampling( with open(os.path.join(output_subdir, "output_meta.json"), "w") as f: json.dump(output_meta, f, indent="\t") - print("allocation counts b:") - print(allocation_counts.values()) + logger.info("allocation counts b:") + logger.info(allocation_counts.values()) ( input_to_peak_ratio, @@ -631,8 +631,8 @@ def benchmark_cugraph_bulk_sampling( ) = get_memory_statistics( allocation_counts=allocation_counts, input_memory=input_memory ) - print(f"Number of edges in final graph = {G.number_of_edges():,}") - print("-" * 80) + logger.info(f"Number of edges in final graph = {G.number_of_edges():,}") + logger.info("-" * 80) return ( num_input_edges, input_to_peak_ratio, @@ -693,12 +693,20 @@ def get_args(): required=True, ) + parser.add_argument( + "--num_epochs", + 
type=int, + help="Number of epochs to run for", + required=False, + default=1, + ) + parser.add_argument( "--fanouts", type=str, - help="Comma separated list of fanouts (i.e. 10_25,5_5_5)", + help='Comma separated list of fanouts (i.e. "10_25,5_5_5")', required=False, - default="10_25", + default="10_10_10", ) parser.add_argument( @@ -743,28 +751,14 @@ def get_args(): "--random_seed", type=int, help="Random seed", required=False, default=62 ) - parser.add_argument( - "--persist", - action="store_true", - help="Will add additional persist() calls to speed up ETL. Does not affect sampling runtime.", - required=False, - default=False, - ) - - parser.add_argument( - "--add_edge_types", - action="store_true", - help="Adds edge types to the edgelist. Required for PyG if not providing edge ids.", - required=False, - default=False, - ) - return parser.parse_args() # call __main__ function if __name__ == "__main__": logging.basicConfig() + logger = logging.getLogger("__main__") + logger.setLevel(logging.INFO) args = get_args() if args.sampling_target_framework not in ["cugraph_dgl_csr", None]: @@ -781,29 +775,28 @@ def get_args(): seeds_per_call_opts = [int(s) for s in args.seeds_per_call_opts.split(",")] dask_worker_devices = [int(d) for d in args.dask_worker_devices.split(",")] - client, cluster = start_dask_client( - dask_worker_devices=dask_worker_devices, - jit_unspill=False, - rmm_pool_size=28e9, - rmm_async=True, - ) + logger.info("starting dask client") + client, cluster = start_dask_client() enable_spilling() stats_ls = [] client.run(enable_spilling) + logger.info("dask client started") for dataset in datasets: - if re.match(r"([A-z]|[0-9])+\[[0-9]+\]", dataset): - replication_factor = int(dataset[-2]) - dataset = dataset[:-3] + m = re.match(r"(\w+)\[([0-9]+)\]", dataset) + if m: + replication_factor = int(m.groups()[1]) + dataset = m.groups()[0] else: replication_factor = 1 for fanout in fanouts: for batch_size in batch_sizes: for seeds_per_call in seeds_per_call_opts: - print(f"dataset: {dataset}") - print(f"batch size: {batch_size}") - print(f"fanout: {fanout}") - print(f"seeds_per_call: {seeds_per_call}") + logger.info(f"dataset: {dataset}") + logger.info(f"batch size: {batch_size}") + logger.info(f"fanout: {fanout}") + logger.info(f"seeds_per_call: {seeds_per_call}") + logger.info(f"num epochs: {args.num_epochs}") try: stats_d = {} @@ -816,6 +809,7 @@ def get_args(): ) = benchmark_cugraph_bulk_sampling( dataset=dataset, output_path=args.output_root, + num_epochs=args.num_epochs, seed=args.random_seed, batch_size=batch_size, seeds_per_call=seeds_per_call, @@ -824,8 +818,6 @@ def get_args(): dataset_dir=args.dataset_root, reverse_edges=args.reverse_edges, replication_factor=replication_factor, - persist=args.persist, - add_edge_types=args.add_edge_types, ) stats_d["dataset"] = dataset stats_d["num_input_edges"] = num_input_edges diff --git a/benchmarks/cugraph/standalone/bulk_sampling/datasets/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/datasets/__init__.py new file mode 100644 index 00000000000..0f4b516cd80 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/datasets/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .dataset import Dataset +from .ogbn_papers100M import OGBNPapers100MDataset diff --git a/benchmarks/cugraph/standalone/bulk_sampling/datasets/dataset.py b/benchmarks/cugraph/standalone/bulk_sampling/datasets/dataset.py new file mode 100644 index 00000000000..f914f69fa4e --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/datasets/dataset.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from typing import Dict, Tuple + + +class Dataset: + @property + def edge_index_dict(self) -> Dict[Tuple[str, str, str], Dict[str, torch.Tensor]]: + raise NotImplementedError() + + @property + def x_dict(self) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + + @property + def y_dict(self) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + + @property + def train_dict(self) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + + @property + def test_dict(self) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + + @property + def val_dict(self) -> Dict[str, torch.Tensor]: + raise NotImplementedError() + + @property + def num_input_features(self) -> int: + raise NotImplementedError() + + @property + def num_labels(self) -> int: + raise NotImplementedError() + + def num_nodes(self, node_type: str) -> int: + raise NotImplementedError() + + def num_edges(self, edge_type: Tuple[str, str, str]) -> int: + raise NotImplementedError() diff --git a/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py b/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py new file mode 100644 index 00000000000..a50e40f6d55 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/datasets/ogbn_papers100M.py @@ -0,0 +1,345 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
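The `Dataset` class in `dataset.py` above is a plain property-based interface that the trainers consume. As a rough sketch only (a hypothetical `ToyDataset`, not part of this PR, assuming the package layout introduced here), an implementation simply materializes those dictionaries:

```python
import torch
from typing import Dict

from datasets import Dataset


class ToyDataset(Dataset):
    """Hypothetical single-node-type dataset, for illustration only."""

    @property
    def x_dict(self) -> Dict[str, torch.Tensor]:
        return {"paper": torch.rand(10, 4)}  # 10 nodes, 4 features each

    @property
    def y_dict(self) -> Dict[str, torch.Tensor]:
        return {"paper": torch.randint(0, 3, (10,))}  # 3 label classes

    @property
    def num_input_features(self) -> int:
        return int(self.x_dict["paper"].shape[1])

    @property
    def num_labels(self) -> int:
        return int(self.y_dict["paper"].max()) + 1

    def num_nodes(self, node_type: str) -> int:
        return 10
```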
+ +from .dataset import Dataset +from typing import Dict, Tuple, Union + +import pandas +import torch +import numpy as np + +from sklearn.model_selection import train_test_split + +import gc +import os +import json + + +class OGBNPapers100MDataset(Dataset): + def __init__( + self, + *, + replication_factor=1, + dataset_dir=".", + train_split=0.8, + val_split=0.5, + load_edge_index=True, + ): + self.__replication_factor = replication_factor + self.__disk_x = None + self.__y = None + self.__train = None + self.__test = None + self.__val = None + self.__edge_index = None + self.__dataset_dir = dataset_dir + self.__train_split = train_split + self.__val_split = val_split + self.__load_edge_index = load_edge_index + + def download(self): + import logging + + logger = logging.getLogger("OGBNPapers100MDataset") + logger.info("Processing dataset...") + + dataset_path = os.path.join(self.__dataset_dir, "ogbn_papers100M") + + meta_json_path = os.path.join(dataset_path, "meta.json") + if not os.path.exists(meta_json_path): + j = { + "num_nodes": {"paper": 111059956}, + "num_edges": {"paper__cites__paper": 1615685872}, + } + with open(meta_json_path, "w") as file: + json.dump(j, file) + + dataset = None + if not os.path.exists(dataset_path): + from ogb.nodeproppred import NodePropPredDataset + + dataset = NodePropPredDataset( + name="ogbn-papers100M", root=self.__dataset_dir + ) + + features_path = os.path.join(dataset_path, "npy", "paper") + os.makedirs(features_path, exist_ok=True) + + logger.info("Processing node features...") + if self.__replication_factor == 1: + replication_path = os.path.join(features_path, "node_feat.npy") + else: + replication_path = os.path.join( + features_path, f"node_feat_{self.__replication_factor}x.npy" + ) + if not os.path.exists(replication_path): + if dataset is None: + from ogb.nodeproppred import NodePropPredDataset + + dataset = NodePropPredDataset( + name="ogbn-papers100M", root=self.__dataset_dir + ) + + node_feat = dataset[0][0]["node_feat"] + if self.__replication_factor != 1: + node_feat_replicated = np.concatenate( + [node_feat] * self.__replication_factor + ) + node_feat = node_feat_replicated + np.save(replication_path, node_feat) + + logger.info("Processing edge index...") + edge_index_parquet_path = os.path.join( + dataset_path, "parquet", "paper__cites__paper" + ) + os.makedirs(edge_index_parquet_path, exist_ok=True) + + edge_index_parquet_file_path = os.path.join( + edge_index_parquet_path, "edge_index.parquet" + ) + if not os.path.exists(edge_index_parquet_file_path): + if dataset is None: + from ogb.nodeproppred import NodePropPredDataset + + dataset = NodePropPredDataset( + name="ogbn-papers100M", root=self.__dataset_dir + ) + + edge_index = dataset[0][0]["edge_index"] + eidf = pandas.DataFrame({"src": edge_index[0], "dst": edge_index[1]}) + eidf.to_parquet(edge_index_parquet_file_path) + + edge_index_npy_path = os.path.join(dataset_path, "npy", "paper__cites__paper") + os.makedirs(edge_index_npy_path, exist_ok=True) + + edge_index_npy_file_path = os.path.join(edge_index_npy_path, "edge_index.npy") + if not os.path.exists(edge_index_npy_file_path): + if dataset is None: + from ogb.nodeproppred import NodePropPredDataset + + dataset = NodePropPredDataset( + name="ogbn-papers100M", root=self.__dataset_dir + ) + + edge_index = dataset[0][0]["edge_index"] + np.save(edge_index_npy_file_path, edge_index) + + logger.info("Processing labels...") + node_label_path = os.path.join(dataset_path, "parquet", "paper") + os.makedirs(node_label_path, exist_ok=True) + + node_label_file_path = os.path.join(node_label_path, 
"node_label.parquet") + if not os.path.exists(node_label_file_path): + if dataset is None: + from ogb.nodeproppred import NodePropPredDataset + + dataset = NodePropPredDataset( + name="ogbn-papers100M", root=self.__dataset_dir + ) + + ldf = pandas.Series(dataset[0][1].T[0]) + ldf = ( + ldf[ldf >= 0] + .reset_index() + .rename(columns={"index": "node", 0: "label"}) + ) + ldf.to_parquet(node_label_file_path) + + @property + def edge_index_dict( + self, + ) -> Dict[Tuple[str, str, str], Union[Dict[str, torch.Tensor], int]]: + import logging + + logger = logging.getLogger("OGBNPapers100MDataset") + + if self.__edge_index is None: + if self.__load_edge_index: + npy_path = os.path.join( + self.__dataset_dir, + "ogbn_papers100M", + "npy", + "paper__cites__paper", + "edge_index.npy", + ) + + logger.info(f"loading edge index from {npy_path}") + ei = np.load(npy_path, mmap_mode="r") + ei = torch.as_tensor(ei) + ei = { + "src": ei[1], + "dst": ei[0], + } + + logger.info("sorting edge index...") + ei["dst"], ix = torch.sort(ei["dst"]) + ei["src"] = ei["src"][ix] + del ix + gc.collect() + + logger.info("processing replications...") + orig_num_nodes = self.num_nodes("paper") // self.__replication_factor + if self.__replication_factor > 1: + orig_src = ei["src"].clone().detach() + orig_dst = ei["dst"].clone().detach() + for r in range(1, self.__replication_factor): + ei["src"] = torch.concat( + [ + ei["src"], + orig_src + int(r * orig_num_nodes), + ] + ) + + ei["dst"] = torch.concat( + [ + ei["dst"], + orig_dst + int(r * orig_num_nodes), + ] + ) + + del orig_src + del orig_dst + + ei["src"] = ei["src"].contiguous() + ei["dst"] = ei["dst"].contiguous() + gc.collect() + + logger.info(f"# edges: {len(ei['src'])}") + self.__edge_index = {("paper", "cites", "paper"): ei} + else: + self.__edge_index = { + ("paper", "cites", "paper"): self.num_edges( + ("paper", "cites", "paper") + ) + } + + return self.__edge_index + + @property + def x_dict(self) -> Dict[str, torch.Tensor]: + node_type_path = os.path.join( + self.__dataset_dir, "ogbn_papers100M", "npy", "paper" + ) + + if self.__disk_x is None: + if self.__replication_factor == 1: + full_path = os.path.join(node_type_path, "node_feat.npy") + else: + full_path = os.path.join( + node_type_path, f"node_feat_{self.__replication_factor}x.npy" + ) + + self.__disk_x = {"paper": np.load(full_path, mmap_mode="r")} + + return self.__disk_x + + @property + def y_dict(self) -> Dict[str, torch.Tensor]: + if self.__y is None: + self.__get_labels() + + return self.__y + + @property + def train_dict(self) -> Dict[str, torch.Tensor]: + if self.__train is None: + self.__get_labels() + return self.__train + + @property + def test_dict(self) -> Dict[str, torch.Tensor]: + if self.__test is None: + self.__get_labels() + return self.__test + + @property + def val_dict(self) -> Dict[str, torch.Tensor]: + if self.__val is None: + self.__get_labels() + return self.__val + + @property + def num_input_features(self) -> int: + return int(self.x_dict["paper"].shape[1]) + + @property + def num_labels(self) -> int: + return int(self.y_dict["paper"].max()) + 1 + + def num_nodes(self, node_type: str) -> int: + if node_type != "paper": + raise ValueError(f"Invalid node type {node_type}") + + return 111_059_956 * self.__replication_factor + + def num_edges(self, edge_type: Tuple[str, str, str]) -> int: + if edge_type != ("paper", "cites", "paper"): + raise ValueError(f"Invalid edge type {edge_type}") + + return 1_615_685_872 * self.__replication_factor + + def __get_labels(self): + label_path = 
os.path.join( + self.__dataset_dir, + "ogbn_papers100M", + "parquet", + "paper", + "node_label.parquet", + ) + + node_label = pandas.read_parquet(label_path) + + if self.__replication_factor > 1: + orig_num_nodes = self.num_nodes("paper") // self.__replication_factor + dfr = pandas.DataFrame( + { + "node": pandas.concat( + [ + node_label.node + (r * orig_num_nodes) + for r in range(1, self.__replication_factor) + ] + ), + "label": pandas.concat( + [node_label.label for r in range(1, self.__replication_factor)] + ), + } + ) + node_label = pandas.concat([node_label, dfr]).reset_index(drop=True) + + num_nodes = self.num_nodes("paper") + node_label_tensor = torch.full( + (num_nodes,), -1, dtype=torch.float32, device="cpu" + ) + node_label_tensor[ + torch.as_tensor(node_label.node.values, device="cpu") + ] = torch.as_tensor(node_label.label.values, device="cpu") + + self.__y = {"paper": node_label_tensor.contiguous()} + + train_ix, test_val_ix = train_test_split( + torch.as_tensor(node_label.node.values), + train_size=self.__train_split, + random_state=num_nodes, + ) + test_ix, val_ix = train_test_split( + test_val_ix, test_size=self.__val_split, random_state=num_nodes + ) + + train_tensor = torch.full((num_nodes,), 0, dtype=torch.bool, device="cpu") + train_tensor[train_ix] = 1 + self.__train = {"paper": train_tensor} + + test_tensor = torch.full((num_nodes,), 0, dtype=torch.bool, device="cpu") + test_tensor[test_ix] = 1 + self.__test = {"paper": test_tensor} + + val_tensor = torch.full((num_nodes,), 0, dtype=torch.bool, device="cpu") + val_tensor[val_ix] = 1 + self.__val = {"paper": val_tensor} diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/models/__init__.py new file mode 100644 index 00000000000..c2002fd3fb9 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/__init__.py new file mode 100644 index 00000000000..337cb0fa243 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
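The `models/pyg` package introduced below exposes the two GraphSAGE variants used by the trainers. As a construction sketch (the dimensions are assumptions matching ogbn-papers100M, which has 128-dimensional features and 172 classes; `hidden_channels=64` and three layers mirror the trainers' `get_model` and the default "10_10_10" fanout):

```python
from models.pyg import CuGraphSAGE

# Assumed dimensions for ogbn-papers100M with a 3-hop fanout (sketch only).
model = CuGraphSAGE(
    in_channels=128,     # node feature width
    hidden_channels=64,  # matches get_model() in the trainers
    out_channels=172,    # number of label classes
    num_layers=3,        # one layer per fanout hop
)
```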
+ +from .models_cugraph_pyg import CuGraphSAGE +from .models_pyg import GraphSAGE diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_cugraph_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_cugraph_pyg.py new file mode 100644 index 00000000000..1de791bf588 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_cugraph_pyg.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from cugraph_pyg.nn.conv import SAGEConv as CuGraphSAGEConv + +try: + from torch_geometric.utils.trim_to_layer import TrimToLayer +except ModuleNotFoundError: + from torch_geometric.utils._trim_to_layer import TrimToLayer + +import torch.nn as nn +import torch.nn.functional as F + + +def extend_tensor(t: torch.Tensor, l: int): + return torch.concat([t, torch.zeros(l - len(t), dtype=t.dtype, device=t.device)]) + + +class CuGraphSAGE(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, num_layers): + super().__init__() + + self.convs = torch.nn.ModuleList() + self.convs.append(CuGraphSAGEConv(in_channels, hidden_channels, aggr="mean")) + for _ in range(num_layers - 2): + conv = CuGraphSAGEConv(hidden_channels, hidden_channels, aggr="mean") + self.convs.append(conv) + + self.convs.append(CuGraphSAGEConv(hidden_channels, out_channels, aggr="mean")) + + self._trim = TrimToLayer() + + def forward(self, x, edge, num_sampled_nodes, num_sampled_edges): + if isinstance(edge, torch.Tensor): + edge = list( + CuGraphSAGEConv.to_csc( + edge.cuda(), (x.shape[0], num_sampled_nodes.sum()) + ) + ) + else: + edge = edge.csr() + edge = [edge[1], edge[0], x.shape[0]] + + x = x.cuda().to(torch.float32) + + for i, conv in enumerate(self.convs): + if i > 0: + new_num_edges = edge[1][-2] + edge[0] = edge[0].narrow( + dim=0, + start=0, + length=new_num_edges, + ) + edge[1] = edge[1].narrow( + dim=0, start=0, length=edge[1].size(0) - num_sampled_nodes[-i - 1] + ) + edge[2] = x.shape[0] + + x = conv(x, edge) + + x = F.relu(x) + x = F.dropout(x, p=0.5) + + x = x.narrow(dim=0, start=0, length=num_sampled_nodes[0]) + + return x diff --git a/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_pyg.py new file mode 100644 index 00000000000..37f98d5362d --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/models/pyg/models_pyg.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from torch_geometric.nn import SAGEConv + +try: + from torch_geometric.utils.trim_to_layer import TrimToLayer +except ModuleNotFoundError: + from torch_geometric.utils._trim_to_layer import TrimToLayer + +import torch.nn as nn +import torch.nn.functional as F + + +class GraphSAGE(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, num_layers): + super().__init__() + + self.convs = torch.nn.ModuleList() + self.convs.append(SAGEConv(in_channels, hidden_channels, aggr="mean")) + for _ in range(num_layers - 2): + conv = SAGEConv(hidden_channels, hidden_channels, aggr="mean") + self.convs.append(conv) + + self.convs.append(SAGEConv(hidden_channels, out_channels, aggr="mean")) + + self._trim = TrimToLayer() + + def forward(self, x, edge, num_sampled_nodes, num_sampled_edges): + edge = edge.cuda() + x = x.cuda().to(torch.float32) + + for i, conv in enumerate(self.convs): + x, edge, _ = self._trim( + i, num_sampled_nodes, num_sampled_edges, x, edge, None + ) + + s = x.shape[0] + x = conv(x, edge, size=(s, s)) + x = F.relu(x) + x = F.dropout(x, p=0.5) + + x = x.narrow(dim=0, start=0, length=x.shape[0] - num_sampled_nodes[1]) + + # assert x.shape[0] == num_sampled_nodes[0] + return x diff --git a/benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh b/benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh new file mode 100644 index 00000000000..41792c0b63a --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/run_sampling.sh @@ -0,0 +1,111 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +conda init +source ~/.bashrc +conda activate rapids + +BATCH_SIZE=$1 +FANOUT=$2 +REPLICATION_FACTOR=$3 +SCRIPTS_DIR=$4 +NUM_EPOCHS=$5 + +SAMPLES_DIR=/samples +DATASET_DIR=/datasets +LOGS_DIR=/logs + +MG_UTILS_DIR=${SCRIPTS_DIR}/mg_utils +SCHEDULER_FILE=${MG_UTILS_DIR}/dask_scheduler.json + +export WORKER_RMM_POOL_SIZE=28G +export UCX_MAX_RNDV_RAILS=1 +export RAPIDS_NO_INITIALIZE=1 +export CUDF_SPILL=1 +export LIBCUDF_CUFILE_POLICY="OFF" +export GPUS_PER_NODE=8 + +export SCHEDULER_FILE=$SCHEDULER_FILE +export LOGS_DIR=$LOGS_DIR + +function handleTimeout { + seconds=$1 + eval "timeout --signal=2 --kill-after=60 $*" + LAST_EXITCODE=$? 
+ if (( $LAST_EXITCODE == 124 )); then + logger "ERROR: command timed out after ${seconds} seconds" + elif (( $LAST_EXITCODE == 137 )); then + logger "ERROR: command timed out after ${seconds} seconds, and had to be killed with signal 9" + fi + ERRORCODE=$((ERRORCODE | ${LAST_EXITCODE})) +} + +DASK_STARTUP_ERRORCODE=0 +if [[ $SLURM_NODEID == 0 ]]; then + ${MG_UTILS_DIR}/run-dask-process.sh scheduler workers & +else + ${MG_UTILS_DIR}/run-dask-process.sh workers & +fi + +echo "properly waiting for workers to connect" +NUM_GPUS=$(python -c "import os; print(int(os.environ['SLURM_JOB_NUM_NODES'])*int(os.environ['GPUS_PER_NODE']))") +handleTimeout 120 python ${MG_UTILS_DIR}/wait_for_workers.py \ + --num-expected-workers ${NUM_GPUS} \ + --scheduler-file-path ${SCHEDULER_FILE} + + +DASK_STARTUP_ERRORCODE=$LAST_EXITCODE + +echo $SLURM_NODEID +if [[ $SLURM_NODEID == 0 ]]; then + echo "Launching Python Script" + python ${SCRIPTS_DIR}/cugraph_bulk_sampling.py \ + --output_root ${SAMPLES_DIR} \ + --dataset_root ${DATASET_DIR} \ + --datasets "ogbn_papers100M["$REPLICATION_FACTOR"]" \ + --fanouts $FANOUT \ + --batch_sizes $BATCH_SIZE \ + --seeds_per_call_opts "524288" \ + --num_epochs $NUM_EPOCHS \ + --random_seed 42 + + echo "DONE" > ${SAMPLES_DIR}/status.txt +fi + +while [ ! -f "${SAMPLES_DIR}"/status.txt ] +do + sleep 1 +done + +sleep 3 + +# At this stage there should be no running processes except /usr/lpp/mmfs/bin/mmsysmon.py +dask_processes=$(pgrep -la dask) +python_processes=$(pgrep -la python) +echo "$dask_processes" +echo "$python_processes" + +if [[ ${#python_processes[@]} -gt 1 || $dask_processes ]]; then + logger "The client was not shutdown properly, killing dask/python processes for Node $SLURM_NODEID" + # This can be caused by a job timeout + pkill python + pkill dask + pgrep -la python + pgrep -la dask +fi +sleep 2 + +if [[ $SLURM_NODEID == 0 ]]; then + rm ${SAMPLES_DIR}/status.txt +fi \ No newline at end of file diff --git a/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh b/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh new file mode 100755 index 00000000000..977745a9593 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/run_train_job.sh @@ -0,0 +1,84 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
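+# NOTE: the #SBATCH directives below (account, partition, job name, node count,
+# and time limit) are cluster-specific defaults; edit them for your own Slurm
+# cluster before submitting this script with `sbatch`.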
+ +#SBATCH -A datascience_rapids_cugraphgnn +#SBATCH -p luna +#SBATCH -J datascience_rapids_cugraphgnn-papers:bulkSamplingPyG +#SBATCH -N 1 +#SBATCH -t 00:25:00 + +CONTAINER_IMAGE=${CONTAINER_IMAGE:="please_specify_container"} +SCRIPTS_DIR=$(pwd) +LOGS_DIR=${LOGS_DIR:=$(pwd)"/logs"} +SAMPLES_DIR=${SAMPLES_DIR:=$(pwd)/samples} +DATASETS_DIR=${DATASETS_DIR:=$(pwd)/datasets} + +mkdir -p $LOGS_DIR +mkdir -p $SAMPLES_DIR +mkdir -p $DATASETS_DIR + +BATCH_SIZE=512 +FANOUT="10_10_10" +NUM_EPOCHS=1 +REPLICATION_FACTOR=1 + +# options: PyG or cuGraphPyG +FRAMEWORK="cuGraphPyG" +GPUS_PER_NODE=8 + +nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) +nodes_array=($nodes) +head_node=${nodes_array[0]} +head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) + +echo Node IP: $head_node_ip + +nnodes=$SLURM_JOB_NUM_NODES +echo Num Nodes: $nnodes + +gpus_per_node=$GPUS_PER_NODE +echo Num GPUs Per Node: $gpus_per_node + +set -e + +# First run without cuGraph to get data + +if [[ "$FRAMEWORK" == "cuGraphPyG" ]]; then + # Generate samples + srun \ + --container-image $CONTAINER_IMAGE \ + --container-mounts=${LOGS_DIR}":/logs",${SAMPLES_DIR}":/samples",${SCRIPTS_DIR}":/scripts",${DATASETS_DIR}":/datasets" \ + bash /scripts/run_sampling.sh $BATCH_SIZE $FANOUT $REPLICATION_FACTOR "/scripts" $NUM_EPOCHS +fi + +# Train +srun \ + --container-image $CONTAINER_IMAGE \ + --container-mounts=${LOGS_DIR}":/logs",${SAMPLES_DIR}":/samples",${SCRIPTS_DIR}":/scripts",${DATASETS_DIR}":/datasets" \ + torchrun \ + --nnodes $nnodes \ + --nproc-per-node $gpus_per_node \ + --rdzv-id $RANDOM \ + --rdzv-backend c10d \ + --rdzv-endpoint $head_node_ip:29500 \ + /scripts/bench_cugraph_training.py \ + --output_file "/logs/output.txt" \ + --framework $FRAMEWORK \ + --dataset_dir "/datasets" \ + --sample_dir "/samples" \ + --batch_size $BATCH_SIZE \ + --fanout $FANOUT \ + --replication_factor $REPLICATION_FACTOR \ + --num_epochs $NUM_EPOCHS + diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/__init__.py new file mode 100644 index 00000000000..5f8f4c2b868 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .trainer import Trainer +from .trainer import extend_tensor diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/__init__.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/__init__.py new file mode 100644 index 00000000000..def6110b8e5 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .trainers_cugraph_pyg import PyGCuGraphTrainer +from .trainers_pyg import PyGNativeTrainer diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py new file mode 100644 index 00000000000..71151e9ba59 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_cugraph_pyg.py @@ -0,0 +1,184 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .trainers_pyg import PyGTrainer +from models.pyg import CuGraphSAGE + +import torch +import numpy as np + +from torch.nn.parallel import DistributedDataParallel as ddp + +from cugraph.gnn import FeatureStore +from cugraph_pyg.data import CuGraphStore +from cugraph_pyg.loader import BulkSampleLoader + +import os + + +class PyGCuGraphTrainer(PyGTrainer): + def __init__( + self, + dataset, + model="GraphSAGE", + device=0, + rank=0, + world_size=1, + num_epochs=1, + sample_dir=".", + **kwargs, + ): + self.__data = None + self.__device = device + self.__rank = rank + self.__world_size = world_size + self.__num_epochs = num_epochs + self.__dataset = dataset + self.__sample_dir = sample_dir + self.__loader_kwargs = kwargs + self.__model = self.get_model(model) + self.__optimizer = None + + @property + def rank(self): + return self.__rank + + @property + def model(self): + return self.__model + + @property + def dataset(self): + return self.__dataset + + @property + def optimizer(self): + if self.__optimizer is None: + self.__optimizer = torch.optim.Adam( + self.model.parameters(), lr=0.01, weight_decay=0.0005 + ) + return self.__optimizer + + @property + def num_epochs(self) -> int: + return self.__num_epochs + + def get_loader(self, epoch: int = 0, stage="train") -> int: + import logging + + logger = logging.getLogger("PyGCuGraphTrainer") + + logger.info(f"getting loader for epoch {epoch}, {stage} stage") + + # TODO support online sampling + if stage == "val": + path = os.path.join(self.__sample_dir, "val", "samples") + else: + path = os.path.join(self.__sample_dir, f"epoch={epoch}", stage, "samples") + + loader = BulkSampleLoader( + self.data, + self.data, + None, # FIXME get input nodes properly + directory=path, + input_files=self.get_input_files(path, epoch=epoch, stage=stage), + **self.__loader_kwargs, + ) + + logger.info(f"got loader successfully on rank {self.rank}") + return loader + + @property + def data(self): + import logging + + logger = logging.getLogger("PyGCuGraphTrainer") + logger.info("getting data") + + if self.__data is None: + # 
FIXME wholegraph + fs = FeatureStore(backend="torch") + num_nodes_dict = {} + + for node_type, x in self.__dataset.x_dict.items(): + logger.debug(f"getting x for {node_type}") + fs.add_data(x, node_type, "x") + num_nodes_dict[node_type] = self.__dataset.num_nodes(node_type) + + for node_type, y in self.__dataset.y_dict.items(): + logger.debug(f"getting y for {node_type}") + fs.add_data(y, node_type, "y") + + for node_type, train in self.__dataset.train_dict.items(): + logger.debug(f"getting train for {node_type}") + fs.add_data(train, node_type, "train") + + for node_type, test in self.__dataset.test_dict.items(): + logger.debug(f"getting test for {node_type}") + fs.add_data(test, node_type, "test") + + for node_type, val in self.__dataset.val_dict.items(): + logger.debug(f"getting val for {node_type}") + fs.add_data(val, node_type, "val") + + # TODO support online sampling if the edge index is provided + num_edges_dict = self.__dataset.edge_index_dict + if not isinstance(list(num_edges_dict.values())[0], int): + num_edges_dict = {k: len(v["src"]) for k, v in num_edges_dict.items()} + + self.__data = CuGraphStore( + fs, + num_edges_dict, + num_nodes_dict, + ) + + logger.info(f"got data successfully on rank {self.rank}") + + return self.__data + + def get_model(self, name="GraphSAGE"): + if name != "GraphSAGE": + raise ValueError("only GraphSAGE is currently supported") + + num_input_features = self.__dataset.num_input_features + num_output_features = self.__dataset.num_labels + num_layers = len(self.__loader_kwargs["num_neighbors"]) + + with torch.cuda.device(self.__device): + model = ( + CuGraphSAGE( + in_channels=num_input_features, + hidden_channels=64, + out_channels=num_output_features, + num_layers=num_layers, + ) + .to(torch.float32) + .to(self.__device) + ) + + model = ddp(model, device_ids=[self.__device]) + print("done creating model") + + return model + + def get_input_files(self, path, epoch=0, stage="train"): + file_list = np.array(os.listdir(path)) + file_list.sort() + + if stage == "train": + splits = np.array_split(file_list, self.__world_size) + np.random.seed(epoch) + np.random.shuffle(splits) + return splits[self.rank] + else: + return file_list diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py new file mode 100644 index 00000000000..bddd6ae2644 --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/pyg/trainers_pyg.py @@ -0,0 +1,430 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
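For reference, `bench_cugraph_training.py` constructs the cuGraph-PyG trainer defined above roughly as follows (condensed from that script; `dataset`, the ranks, and `sample_dir` are produced by its `main()`):

```python
trainer = PyGCuGraphTrainer(
    model="GraphSAGE",
    dataset=dataset,          # an OGBNPapers100MDataset
    sample_dir=sample_dir,    # bulk sampler output for this batch size/fanout
    device=local_rank,
    rank=global_rank,
    world_size=world_size,
    num_epochs=args.num_epochs,
    shuffle=True,
    replace=False,
    num_neighbors=[10, 10, 10],
    batch_size=args.batch_size,
)
stats = trainer.train()
```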
+ + +from trainers import Trainer +from trainers import extend_tensor +from datasets import OGBNPapers100MDataset +from models.pyg import GraphSAGE + +import torch +import numpy as np + +import torch.distributed as td +from torch.nn.parallel import DistributedDataParallel as ddp +import torch.nn.functional as F + +from torch_geometric.utils.sparse import index2ptr +from torch_geometric.data import HeteroData +from torch_geometric.loader import NeighborLoader + +import gc +import os +import time + + +def pyg_num_workers(world_size): + num_workers = None + if hasattr(os, "sched_getaffinity"): + try: + num_workers = len(os.sched_getaffinity(0)) / (2 * world_size) + except Exception: + pass + if num_workers is None: + num_workers = os.cpu_count() / (2 * world_size) + return int(num_workers) + + +class PyGTrainer(Trainer): + def train(self): + import logging + + logger = logging.getLogger("PyGTrainer") + logger.info("Entered train loop") + + total_loss = 0.0 + num_batches = 0 + + time_forward = 0.0 + time_backward = 0.0 + time_loader = 0.0 + time_feature_transfer = 0.0 + start_time = time.perf_counter() + end_time_backward = start_time + + for epoch in range(self.num_epochs): + with td.algorithms.join.Join( + [self.model], divide_by_initial_world_size=False + ): + self.model.train() + for iter_i, data in enumerate( + self.get_loader(epoch=epoch, stage="train") + ): + loader_time_iter = time.perf_counter() - end_time_backward + time_loader += loader_time_iter + + time_feature_transfer_start = time.perf_counter() + + num_sampled_nodes = sum( + [ + torch.as_tensor(n) + for n in data.num_sampled_nodes_dict.values() + ] + ) + num_sampled_edges = sum( + [ + torch.as_tensor(e) + for e in data.num_sampled_edges_dict.values() + ] + ) + + # FIXME find a way to get around this and not have to call extend_tensor + num_layers = len(self.model.module.convs) + num_sampled_nodes = extend_tensor(num_sampled_nodes, num_layers + 1) + num_sampled_edges = extend_tensor(num_sampled_edges, num_layers) + + data = data.to_homogeneous().cuda() + time_feature_transfer_end = time.perf_counter() + time_feature_transfer += ( + time_feature_transfer_end - time_feature_transfer_start + ) + + num_batches += 1 + if iter_i % 20 == 1: + time_forward_iter = time_forward / num_batches + time_backward_iter = time_backward / num_batches + + total_time_iter = ( + time.perf_counter() - start_time + ) / num_batches + logger.info(f"epoch {epoch}, iteration {iter_i}") + logger.info(f"num sampled nodes: {num_sampled_nodes}") + logger.info(f"num sampled edges: {num_sampled_edges}") + logger.info(f"time forward: {time_forward_iter}") + logger.info(f"time backward: {time_backward_iter}") + logger.info(f"loader time: {loader_time_iter}") + logger.info( + f"feature transfer time: {time_feature_transfer / num_batches}" + ) + logger.info(f"total time: {total_time_iter}") + + y_true = data.y + x = data.x.to(torch.float32) + + start_time_forward = time.perf_counter() + edge_index = data.edge_index if "edge_index" in data else data.adj_t + + self.optimizer.zero_grad() + y_pred = self.model( + x, + edge_index, + num_sampled_nodes, + num_sampled_edges, + ) + + end_time_forward = time.perf_counter() + time_forward += end_time_forward - start_time_forward + + if y_pred.shape[0] > len(y_true): + raise ValueError( + f"illegal shape: {y_pred.shape}; {y_true.shape}" + ) + + y_true = y_true[: y_pred.shape[0]] + + y_true = F.one_hot( + y_true.to(torch.int64), num_classes=self.dataset.num_labels + ).to(torch.float32) + + if y_true.shape != y_pred.shape: + 
raise ValueError( + f"y_true shape was {y_true.shape} " + f"but y_pred shape was {y_pred.shape} " + f"in iteration {iter_i} " + f"on rank {y_pred.device.index}" + ) + + start_time_backward = time.perf_counter() + loss = F.cross_entropy(y_pred, y_true) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + total_loss += loss.item() + end_time_backward = time.perf_counter() + time_backward += end_time_backward - start_time_backward + + end_time = time.perf_counter() + + # test + from torchmetrics import Accuracy + + acc = Accuracy( + task="multiclass", num_classes=self.dataset.num_labels + ).cuda() + + with td.algorithms.join.Join( + [self.model], divide_by_initial_world_size=False + ): + self.model.eval() + if self.rank == 0: + acc_sum = 0.0 + with torch.no_grad(): + for i, batch in enumerate( + self.get_loader(epoch=epoch, stage="test") + ): + num_sampled_nodes = sum( + [ + torch.as_tensor(n) + for n in batch.num_sampled_nodes_dict.values() + ] + ) + num_sampled_edges = sum( + [ + torch.as_tensor(e) + for e in batch.num_sampled_edges_dict.values() + ] + ) + batch_size = num_sampled_nodes[0] + + batch = batch.to_homogeneous().cuda() + + batch.y = batch.y.to(torch.long) + out = self.model.module( + batch.x, + batch.edge_index, + num_sampled_nodes, + num_sampled_edges, + ) + acc_sum += acc( + out[:batch_size].softmax(dim=-1), batch.y[:batch_size] + ) + print( + f"Accuracy: {acc_sum / (i + 1) * 100.0:.4f}%", + ) + + td.barrier() + + with td.algorithms.join.Join([self.model], divide_by_initial_world_size=False): + self.model.eval() + if self.rank == 0: + acc_sum = 0.0 + with torch.no_grad(): + for i, batch in enumerate( + self.get_loader(epoch=epoch, stage="val") + ): + num_sampled_nodes = sum( + [ + torch.as_tensor(n) + for n in batch.num_sampled_nodes_dict.values() + ] + ) + num_sampled_edges = sum( + [ + torch.as_tensor(e) + for e in batch.num_sampled_edges_dict.values() + ] + ) + batch_size = num_sampled_nodes[0] + + batch = batch.to_homogeneous().cuda() + + batch.y = batch.y.to(torch.long) + out = self.model.module( + batch.x, + batch.edge_index, + num_sampled_nodes, + num_sampled_edges, + ) + acc_sum += acc( + out[:batch_size].softmax(dim=-1), batch.y[:batch_size] + ) + print( + f"Validation Accuracy: {acc_sum / (i + 1) * 100.0:.4f}%", + ) + + stats = { + "Accuracy": float(acc_sum / (i + 1) * 100.0) if self.rank == 0 else 0.0, + "# Batches": num_batches, + "Loader Time": time_loader, + "Feature Transfer Time": time_feature_transfer, + "Forward Time": time_forward, + "Backward Time": time_backward, + } + return stats + + +class PyGNativeTrainer(PyGTrainer): + def __init__( + self, + dataset, + model="GraphSAGE", + device=0, + rank=0, + world_size=1, + num_epochs=1, + **kwargs, + ): + self.__dataset = dataset + self.__device = device + self.__data = None + self.__rank = rank + self.__num_epochs = num_epochs + self.__world_size = world_size + self.__loader_kwargs = kwargs + self.__model = self.get_model(model) + self.__optimizer = None + + @property + def rank(self): + return self.__rank + + @property + def model(self): + return self.__model + + @property + def dataset(self): + return self.__dataset + + @property + def data(self): + import logging + + logger = logging.getLogger("PyGNativeTrainer") + logger.info("getting data") + + if self.__data is None: + self.__data = HeteroData() + + for node_type, x in self.__dataset.x_dict.items(): + logger.debug(f"getting x for {node_type}") + self.__data[node_type].x = x + self.__data[node_type]["num_nodes"] = self.__dataset.num_nodes( + node_type
+ ) + + for node_type, y in self.__dataset.y_dict.items(): + logger.debug(f"getting y for {node_type}") + self.__data[node_type]["y"] = y + + for node_type, train in self.__dataset.train_dict.items(): + logger.debug(f"getting train for {node_type}") + self.__data[node_type]["train"] = train + + for node_type, test in self.__dataset.test_dict.items(): + logger.debug(f"getting test for {node_type}") + self.__data[node_type]["test"] = test + + for node_type, val in self.__dataset.val_dict.items(): + logger.debug(f"getting val for {node_type}") + self.__data[node_type]["val"] = val + + for can_edge_type, ei in self.__dataset.edge_index_dict.items(): + logger.info("converting to csc...") + ei["dst"] = index2ptr( + ei["dst"], self.__dataset.num_nodes(can_edge_type[2]) + ) + + logger.info("updating data structure...") + self.__data.put_edge_index( + layout="csc", + edge_index=list(ei.values()), + edge_type=can_edge_type, + size=( + self.__dataset.num_nodes(can_edge_type[0]), + self.__dataset.num_nodes(can_edge_type[2]), + ), + is_sorted=True, + ) + gc.collect() + + return self.__data + + @property + def optimizer(self): + if self.__optimizer is None: + self.__optimizer = torch.optim.Adam( + self.model.parameters(), lr=0.01, weight_decay=0.0005 + ) + return self.__optimizer + + @property + def num_epochs(self) -> int: + return self.__num_epochs + + def get_loader(self, epoch: int = 0, stage="train"): + import logging + + logger = logging.getLogger("PyGNativeTrainer") + logger.info(f"Getting loader for epoch {epoch}") + + if stage == "train": + mask_dict = self.__dataset.train_dict + elif stage == "test": + mask_dict = self.__dataset.test_dict + elif stage == "val": + mask_dict = self.__dataset.val_dict + else: + raise ValueError(f"Invalid stage {stage}") + + input_nodes_dict = { + node_type: np.array_split(np.arange(len(mask))[mask], self.__world_size)[ + self.__rank + ] + for node_type, mask in mask_dict.items() + } + + input_nodes = list(input_nodes_dict.items()) + if len(input_nodes) > 1: + raise ValueError("Multiple input node types currently unsupported") + else: + input_nodes = tuple(input_nodes[0]) + + # get loader + loader = NeighborLoader( + self.data, + input_nodes=input_nodes, + is_sorted=True, + disjoint=False, + num_workers=pyg_num_workers(self.__world_size), # FIXME change this + persistent_workers=True, + **self.__loader_kwargs, # batch size, num neighbors, replace, shuffle, etc. + ) + + logger.info("done creating loader") + return loader + + def get_model(self, name="GraphSAGE"): + if name != "GraphSAGE": + raise ValueError("only GraphSAGE is currently supported") + + num_input_features = self.__dataset.num_input_features + num_output_features = self.__dataset.num_labels + num_layers = len(self.__loader_kwargs["num_neighbors"]) + + with torch.cuda.device(self.__device): + model = ( + GraphSAGE( + in_channels=num_input_features, + hidden_channels=64, + out_channels=num_output_features, + num_layers=num_layers, + ) + .to(torch.float32) + .to(self.__device) + ) + model = ddp(model, device_ids=[self.__device]) + print("done creating model") + + return model diff --git a/benchmarks/cugraph/standalone/bulk_sampling/trainers/trainer.py b/benchmarks/cugraph/standalone/bulk_sampling/trainers/trainer.py new file mode 100644 index 00000000000..321edbea96e --- /dev/null +++ b/benchmarks/cugraph/standalone/bulk_sampling/trainers/trainer.py @@ -0,0 +1,54 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from typing import Union, List + + +def extend_tensor(t: Union[List[int], torch.Tensor], l: int): + # Right-pad t with zeros so that its length is l. + t = torch.as_tensor(t) + + return torch.concat([t, torch.zeros(l - len(t), dtype=t.dtype, device=t.device)]) + + +class Trainer: + @property + def rank(self): + raise NotImplementedError() + + @property + def model(self): + raise NotImplementedError() + + @property + def dataset(self): + raise NotImplementedError() + + @property + def data(self): + raise NotImplementedError() + + @property + def optimizer(self): + raise NotImplementedError() + + @property + def num_epochs(self) -> int: + raise NotImplementedError() + + def get_loader(self, epoch: int = 0, stage="train"): + raise NotImplementedError() + + def train(self): + raise NotImplementedError() diff --git a/cpp/src/community/flatten_dendrogram.hpp b/cpp/src/community/flatten_dendrogram.hpp index c0186983904..a4299f17d52 100644 --- a/cpp/src/community/flatten_dendrogram.hpp +++ b/cpp/src/community/flatten_dendrogram.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/mg_utils/wait_for_workers.py b/mg_utils/wait_for_workers.py new file mode 100644 index 00000000000..fa75c90d4ad --- /dev/null +++ b/mg_utils/wait_for_workers.py @@ -0,0 +1,124 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import time +import yaml + +from dask.distributed import Client + + +def initialize_dask_cuda(communication_type): + communication_type = communication_type.lower() + if "ucx" in communication_type: + os.environ["UCX_MAX_RNDV_RAILS"] = "1" + + if communication_type == "ucx-ib": + os.environ["UCX_MEMTYPE_REG_WHOLE_ALLOC_TYPES"] = "cuda" + os.environ["DASK_RMM__POOL_SIZE"] = "0.5GB" + os.environ["DASK_DISTRIBUTED__COMM__UCX__CREATE_CUDA_CONTEXT"] = "True" + + +def wait_for_workers( + num_expected_workers, scheduler_file_path, communication_type, timeout_after=0 +): + """ + Waits until num_expected_workers workers are available in the Dask + cluster identified by scheduler_file_path, then returns 0. If + timeout_after is specified, returns 1 if num_expected_workers + workers are not available before the timeout.
+ """ + # FIXME: use scheduler file path from global environment if none + # supplied in configuration yaml + + print("wait_for_workers.py - initializing client...", end="") + sys.stdout.flush() + initialize_dask_cuda(communication_type) + print("done.") + sys.stdout.flush() + + ready = False + start_time = time.time() + while not ready: + if timeout_after and ((time.time() - start_time) >= timeout_after): + print( + f"wait_for_workers.py timed out after {timeout_after} seconds before finding {num_expected_workers} workers." + ) + sys.stdout.flush() + break + with Client(scheduler_file=scheduler_file_path) as client: + num_workers = len(client.scheduler_info()["workers"]) + if num_workers < num_expected_workers: + print( + f"wait_for_workers.py expected {num_expected_workers} but got {num_workers}, waiting..." + ) + sys.stdout.flush() + time.sleep(5) + else: + print(f"wait_for_workers.py got {num_workers} workers, done.") + sys.stdout.flush() + ready = True + + if ready is False: + return 1 + return 0 + + +if __name__ == "__main__": + import argparse + + ap = argparse.ArgumentParser() + ap.add_argument( + "--num-expected-workers", + type=int, + required=False, + help="Number of workers to wait for. If not specified, " + "uses the NUM_WORKERS env var if set, otherwise defaults " + "to 16.", + ) + ap.add_argument( + "--scheduler-file-path", + type=str, + required=True, + help="Path to shared scheduler file to read.", + ) + ap.add_argument( + "--communication-type", + type=str, + default="tcp", + required=False, + help="Initiliaze dask_cuda based on the cluster communication type." + "Supported values are tcp(default), ucx, ucxib, ucx-ib.", + ) + ap.add_argument( + "--timeout-after", + type=int, + default=0, + required=False, + help="Number of seconds to wait for workers. " + "Default is 0 which means wait forever.", + ) + args = ap.parse_args() + + if args.num_expected_workers is None: + args.num_expected_workers = os.environ.get("NUM_WORKERS", 16) + + exitcode = wait_for_workers( + num_expected_workers=args.num_expected_workers, + scheduler_file_path=args.scheduler_file_path, + communication_type=args.communication_type, + timeout_after=args.timeout_after, + ) + + sys.exit(exitcode) diff --git a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py index 8a1db4edf29..bcfaf579820 100644 --- a/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py +++ b/python/cugraph-pyg/cugraph_pyg/loader/cugraph_node_loader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -151,9 +151,25 @@ def __init__( self.__input_files = iter(input_files) return - input_type, input_nodes = torch_geometric.loader.utils.get_input_nodes( - (feature_store, graph_store), input_nodes + # To accommodate DLFW/PyG 2.5 + get_input_nodes = torch_geometric.loader.utils.get_input_nodes + get_input_nodes_kwargs = {} + if "input_id" in get_input_nodes.__annotations__: + get_input_nodes_kwargs["input_id"] = None + input_node_info = get_input_nodes( + (feature_store, graph_store), input_nodes, **get_input_nodes_kwargs ) + + # PyG 2.4 + if len(input_node_info) == 2: + input_type, input_nodes = input_node_info + # PyG 2.5 + elif len(input_node_info) == 3: + input_type, input_nodes, input_id = input_node_info + # Invalid + else: + raise ValueError("Invalid output from get_input_nodes") + if input_type is not None: input_nodes = graph_store._get_sample_from_vertex_groups( {input_type: input_nodes} @@ -439,7 +455,12 @@ def __next__(self): start_time_feature = perf_counter() # Create a PyG HeteroData object, loading the required features if self.__coo: - out = torch_geometric.loader.utils.filter_custom_store( + pyg_filter_fn = ( + torch_geometric.loader.utils.filter_custom_hetero_store + if hasattr(torch_geometric.loader.utils, "filter_custom_hetero_store") + else torch_geometric.loader.utils.filter_custom_store + ) + out = pyg_filter_fn( self.__feature_store, self.__graph_store, sampler_output.node, diff --git a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py index 300ca9beb5a..65cb63d25e0 100644 --- a/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py +++ b/python/cugraph-pyg/cugraph_pyg/sampler/cugraph_sampler.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -216,7 +216,6 @@ def _sampler_output_from_sampling_results_homogeneous_csr( if renumber_map is None: raise ValueError("Renumbered input is expected for homogeneous graphs") - node_type = graph_store.node_types[0] edge_type = graph_store.edge_types[0] diff --git a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py index b39ebad8254..c99fd447aa0 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/test_cugraph_store.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -365,10 +365,20 @@ def test_get_input_nodes(karate_gnn): F, G, N = karate_gnn cugraph_store = CuGraphStore(F, G, N) - node_type, input_nodes = torch_geometric.loader.utils.get_input_nodes( + input_node_info = torch_geometric.loader.utils.get_input_nodes( (cugraph_store, cugraph_store), "type0" ) + # PyG 2.4 + if len(input_node_info) == 2: + node_type, input_nodes = input_node_info + # PyG 2.5 + elif len(input_node_info) == 3: + node_type, input_nodes, input_id = input_node_info + # Invalid + else: + raise ValueError("Invalid output from get_input_nodes") + assert node_type == "type0" assert input_nodes.tolist() == torch.arange(17, dtype=torch.int32).tolist() diff --git a/python/cugraph/cugraph/experimental/__init__.py b/python/cugraph/cugraph/experimental/__init__.py index d809e28c92e..7e8fd666972 100644 --- a/python/cugraph/cugraph/experimental/__init__.py +++ b/python/cugraph/cugraph/experimental/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at