Fix out of index errors encountered with sampling on out of index samples (#2825)

This PR does the following:

- [x] Ensure we don't sample on out-of-range values
Issue: #2828
- [x] Add tests for the sampling error
- [x] Ensure all the DGL examples here pass
https://github.com/rapidsai/dgl/blob/6ece904c69687adcd35a5ea41d1f5ca4ea01c0e2/examples/cugraph-pytorch/cugraph-local/rgcn-hetero/README.MD
- [x] Refactor out the non-class-specific utilities in preparation for the DGL graph service class

Authors:
  - Vibhu Jawa (https://github.com/VibhuJawa)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)
  - Joseph Nke (https://github.com/jnke2016)

URL: #2825
VibhuJawa authored Oct 20, 2022
1 parent 74ead42 commit 50ba399
Showing 9 changed files with 468 additions and 273 deletions.
2 changes: 1 addition & 1 deletion python/cugraph/cugraph/gnn/__init__.py
@@ -12,4 +12,4 @@
# limitations under the License.

from .graph_store import CuGraphStore
-from .graph_store import CuFeatureStorage
+from .dgl_extensions.feature_storage import CuFeatureStorage
Empty file.
100 changes: 100 additions & 0 deletions python/cugraph/cugraph/gnn/dgl_extensions/feature_storage.py
@@ -0,0 +1,100 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import cudf
import dask_cudf
import cupy as cp
from cugraph.experimental import MGPropertyGraph


class CuFeatureStorage:
"""
Storage for node/edge feature data.
"""

def __init__(
self, pg, columns, storage_type, backend_lib="torch", indices_offset=0
):
self.pg = pg
self.columns = columns
if backend_lib == "torch":
from torch.utils.dlpack import from_dlpack
elif backend_lib == "tf":
from tensorflow.experimental.dlpack import from_dlpack
elif backend_lib == "cupy":
from cupy import from_dlpack
else:
raise NotImplementedError(
f"Only PyTorch ('torch'), TensorFlow ('tf'), and CuPy ('cupy') "
f"backends are currently supported, got {backend_lib=}"
)
if storage_type not in ["edge", "node"]:
raise NotImplementedError("Only edge and node storage is supported")

self.storage_type = storage_type

self.from_dlpack = from_dlpack
self.indices_offset = indices_offset

def fetch(self, indices, device=None, pin_memory=False, **kwargs):
"""Fetch the features of the given node/edge IDs to the
given device.
Parameters
----------
indices : Tensor
Node or edge IDs.
device : Device
Device context.
pin_memory : bool, optional
    Unused by this default synchronous implementation.
Returns
-------
Tensor
Feature data as a tensor of the configured backend (PyTorch, TensorFlow, or CuPy).
"""
# Default implementation uses synchronous fetch.

indices = cp.asarray(indices)
if isinstance(self.pg, MGPropertyGraph):
# dask_cudf loc breaks if we provide cudf series/cupy array
# https://github.com/rapidsai/cudf/issues/11877
indices = indices.get()
else:
indices = cudf.Series(indices)

indices = indices + self.indices_offset

if self.storage_type == "node":
subset_df = self.pg.get_vertex_data(
vertex_ids=indices, columns=self.columns
)
else:
subset_df = self.pg.get_edge_data(edge_ids=indices, columns=self.columns)

subset_df = subset_df[self.columns]

if isinstance(subset_df, dask_cudf.DataFrame):
subset_df = subset_df.compute()

if len(subset_df) == 0:
raise ValueError(f"indices = {indices} not found in FeatureStorage")
cap = subset_df.to_dlpack()
tensor = self.from_dlpack(cap)
del cap
if device:
if not isinstance(tensor, cp.ndarray):
# Can't transfer to a different device for CuPy arrays
tensor = tensor.to(device)
return tensor
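
For reference, a minimal usage sketch of the new CuFeatureStorage class. The PropertyGraph contents, the "vid"/"feat0" column names, and the node IDs below are made up for illustration and are not part of this commit:

import cudf
from cugraph.experimental import PropertyGraph
from cugraph.gnn import CuFeatureStorage

# Build a tiny property graph with one hypothetical node feature column.
pg = PropertyGraph()
pg.add_vertex_data(
    cudf.DataFrame({"vid": [0, 1, 2], "feat0": [0.1, 0.2, 0.3]}),
    vertex_col_name="vid",
)

# Wrap the "feat0" property column as node feature storage backed by CuPy.
fs = CuFeatureStorage(
    pg=pg, columns=["feat0"], storage_type="node", backend_lib="cupy"
)

# Fetch the "feat0" values for node IDs 0 and 2 as a CuPy array.
feats = fs.fetch([0, 2])
print(feats)
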
12 changes: 12 additions & 0 deletions python/cugraph/cugraph/gnn/dgl_extensions/utils/__init__.py
@@ -0,0 +1,12 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
62 changes: 62 additions & 0 deletions python/cugraph/cugraph/gnn/dgl_extensions/utils/add_data.py
@@ -0,0 +1,62 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Utils for adding data to cugraph graphstore objects


def _update_feature_map(
pg_feature_map, feat_name_obj, contains_vector_features, columns
):
"""
Update the existing feature map `pg_feature_map` based on `feat_name_obj`
"""
if contains_vector_features:
if feat_name_obj is None:
raise ValueError(
"feature name must be provided when wrapping"
+ " multiple columns under a single feature name"
+ " or a feature map"
)

if isinstance(feat_name_obj, str):
pg_feature_map[feat_name_obj] = columns

elif isinstance(feat_name_obj, dict):
covered_columns = []
for col in feat_name_obj.keys():
current_cols = feat_name_obj[col]
# Handle strings too
if isinstance(current_cols, str):
current_cols = [current_cols]
covered_columns = covered_columns + current_cols

if set(covered_columns) != set(columns):
raise ValueError(
f"All the columns {columns} not covered in {covered_columns} "
f"Please check the feature_map {feat_name_obj} provided"
)

for key, cols in feat_name_obj.items():
if isinstance(cols, str):
cols = [cols]
pg_feature_map[key] = cols
else:
raise ValueError(f"{feat_name_obj} should be str or dict")
else:
if feat_name_obj:
raise ValueError(
f"feat_name {feat_name_obj} is only valid when "
"wrapping multiple columns under feature names"
)
for col in columns:
pg_feature_map[col] = [col]
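
To illustrate the behavior of _update_feature_map (a private helper), a short sketch; the feature and column names are invented for this example, and the import path assumes the file above maps to cugraph.gnn.dgl_extensions.utils.add_data:

from cugraph.gnn.dgl_extensions.utils.add_data import _update_feature_map

# Wrapping multiple columns under named (vector) features via a dict.
feature_map = {}
_update_feature_map(
    feature_map,
    feat_name_obj={"feat_a": ["x", "y"], "feat_b": "z"},
    contains_vector_features=True,
    columns=["x", "y", "z"],
)
print(feature_map)  # {'feat_a': ['x', 'y'], 'feat_b': ['z']}

# Without vector features, each column becomes its own single-column feature.
feature_map = {}
_update_feature_map(feature_map, None, False, ["x", "y"])
print(feature_map)  # {'x': ['x'], 'y': ['y']}
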
181 changes: 181 additions & 0 deletions python/cugraph/cugraph/gnn/dgl_extensions/utils/sampling.py
@@ -0,0 +1,181 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Utils for sampling on graphstore like objects
import cugraph
import cudf
import cupy as cp
import dask_cudf
from cugraph.experimental import PropertyGraph

src_n = PropertyGraph.src_col_name
dst_n = PropertyGraph.dst_col_name
type_n = PropertyGraph.type_col_name
eid_n = PropertyGraph.edge_id_col_name
vid_n = PropertyGraph.vertex_col_name


def get_subgraph_and_src_range_from_edgelist(edge_list, is_mg, reverse_edges=False):
if reverse_edges:
edge_list = edge_list.rename(columns={src_n: dst_n, dst_n: src_n})

subgraph = cugraph.MultiGraph(directed=True)
if is_mg:
# FIXME: Cannot switch to renumber=False for MNMG algos.
# Remove when https://github.com/rapidsai/cugraph/issues/2437 lands.
create_subgraph_f = subgraph.from_dask_cudf_edgelist
renumber = True
edge_list = edge_list.persist()
src_range = edge_list[src_n].min().compute(), edge_list[src_n].max().compute()

else:
# Note: We have to keep renumber=False
# to handle cases when the seed_nodes are not present in the subgraph
create_subgraph_f = subgraph.from_cudf_edgelist
renumber = False
src_range = edge_list[src_n].min(), edge_list[src_n].max()

create_subgraph_f(
edge_list,
source=src_n,
destination=dst_n,
edge_attr=eid_n,
renumber=renumber,
# FIXME: renumber=False is not supported for MNMG algos
legacy_renum_only=True,
)

return subgraph, src_range


def sample_multiple_sgs(
sgs,
sample_f,
start_list_d,
start_list_dtype,
edge_dir,
fanout,
with_replacement,
):
start_list_types = list(start_list_d.keys())
output_dfs = []
for can_etype, (sg, start_list_range) in sgs.items():
can_etype = _convert_can_etype_s_to_tup(can_etype)
if _edge_types_contains_canonical_etype(can_etype, start_list_types, edge_dir):
if edge_dir == "in":
subset_type = can_etype[2]
else:
subset_type = can_etype[0]
output = sample_single_sg(
sg,
sample_f,
start_list_d[subset_type],
start_list_dtype,
start_list_range,
fanout,
with_replacement,
)
output_dfs.append(output)
if len(output_dfs) == 0:
empty_df = cudf.DataFrame({"sources": [], "destinations": [], "indices": []})
return empty_df.astype(cp.int32)

if isinstance(output_dfs[0], dask_cudf.DataFrame):
return dask_cudf.concat(output_dfs, ignore_index=True)
else:
return cudf.concat(output_dfs, ignore_index=True)


def sample_single_sg(
sg,
sample_f,
start_list,
start_list_dtype,
start_list_range,
fanout,
with_replacement,
):
if isinstance(start_list, dict):
start_list = cudf.concat(list(start_list.values()))

# Uniform sampling fails when the dtype of the seeds
# is not the same as the node dtype
start_list = start_list.astype(start_list_dtype)

# Filter the start list by range
# to ensure the seeds are within the index values;
# see:
# https://github.com/rapidsai/cugraph/blob/branch-22.12/cpp/src/prims/per_v_random_select_transform_outgoing_e.cuh
start_list = start_list[
(start_list >= start_list_range[0]) & (start_list <= start_list_range[1])
]
sampled_df = sample_f(
sg,
start_list=start_list,
fanout_vals=[fanout],
with_replacement=with_replacement,
)
return sampled_df


def _edge_types_contains_canonical_etype(can_etype, edge_types, edge_dir):
src_type, _, dst_type = can_etype
if edge_dir == "in":
return dst_type in edge_types
else:
return src_type in edge_types


def _convert_can_etype_s_to_tup(canonical_etype_s):
src_type, etype, dst_type = canonical_etype_s.split(",")
src_type = src_type[2:-1]
dst_type = dst_type[2:-2]
etype = etype[2:-1]
return (src_type, etype, dst_type)


def create_dlpack_d(d):
dlpack_d = {}
for k, df in d.items():
if len(df) == 0:
dlpack_d[k] = (None, None, None)
else:
dlpack_d[k] = (
df[src_n].to_dlpack(),
df[dst_n].to_dlpack(),
df[eid_n].to_dlpack(),
)

return dlpack_d


def get_underlying_dtype_from_sg(sg):
"""
Returns the underlying dtype of the subgraph
"""
# FIXME: Remove after we have consistent naming
# https://github.com/rapidsai/cugraph/issues/2618
sg_columns = sg.edgelist.edgelist_df.columns
if "src" in sg_columns:
# src for single node graph
sg_node_dtype = sg.edgelist.edgelist_df["src"].dtype
elif src_n in sg_columns:
# _SRC_ for multi-node graphs
sg_node_dtype = sg.edgelist.edgelist_df[src_n].dtype
else:
raise ValueError(f"Source column {src_n} not found in the subgraph")

return sg_node_dtype
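
The range filter in sample_single_sg above is the core of the out-of-index fix. A minimal standalone sketch of the same filtering idea, with made-up seed values and a made-up (min, max) range:

import cudf

start_list = cudf.Series([0, 3, 7, 42], dtype="int32")
start_list_range = (0, 9)  # min/max source vertex IDs of the subgraph

# Drop seeds outside the subgraph's vertex ID range so that the
# underlying sampler never receives out-of-index vertices.
filtered = start_list[
    (start_list >= start_list_range[0]) & (start_list <= start_list_range[1])
]
print(filtered.to_pandas().tolist())  # [0, 3, 7] -- 42 is dropped
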