
Commit 8d068c9

separate GNNGraphs from GNNlib
1 parent acf4b6a commit 8d068c9


45 files changed: +516 −113 lines
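The graph data structures move into a new standalone GNNGraphs subpackage. A minimal sketch of what the split enables, assuming GNNGraphs is installed or `dev`ed from this repository (the graph API itself is unchanged by the move):

```julia
# Graph construction utilities can now be loaded without the rest of GNNlib.
using GNNGraphs

g = rand_graph(10, 30)                          # random graph with 10 nodes and 30 edges
g = GNNGraph(g, ndata = rand(Float32, 3, 10))   # attach a 3-dimensional node feature array
edge_index(g)                                   # source and target node vectors
```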

GNNGraphs/Project.toml (new file, +67)

@@ -0,0 +1,67 @@
+name = "GNNGraphs"
+uuid = "aed8fd31-079b-4b5a-b342-a13352159b8c"
+authors = ["Carlo Lucibello and contributors"]
+version = "0.1.0"
+
+[deps]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
+Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
+KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+
+[weakdeps]
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
+
+[extensions]
+GNNGraphsCUDAExt = "CUDA"
+GNNGraphsSimpleWeightedGraphsExt = "SimpleWeightedGraphs"
+
+[compat]
+Adapt = "4"
+CUDA = "5"
+ChainRulesCore = "1"
+Functors = "0.4.1"
+Graphs = "1.4"
+KrylovKit = "0.8"
+LinearAlgebra = "1"
+LuxDeviceUtils = "0.1.24"
+MLDatasets = "0.7"
+MLUtils = "0.4"
+NNlib = "0.9"
+NearestNeighbors = "0.4"
+Random = "1"
+SimpleWeightedGraphs = "1.4.0"
+SparseArrays = "1"
+Statistics = "1"
+StatsBase = "0.34"
+cuDNN = "1"
+julia = "1.9"
+
+[extras]
+Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
+
+[targets]
+test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SimpleWeightedGraphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]
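For local development in the monorepo, the new subpackage can be `dev`ed by path; a sketch, assuming the working directory is the repository root:

```julia
using Pkg

# Make the in-repo GNNGraphs subpackage available in the active environment.
Pkg.develop(path = "GNNGraphs")
using GNNGraphs
```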
New file (+14): the GNNGraphsCUDAExt extension module (under GNNGraphs/ext/GNNGraphsCUDAExt/)

@@ -0,0 +1,14 @@
+module GNNGraphsCUDAExt
+
+using CUDA
+using Random, Statistics, LinearAlgebra
+using GNNGraphs
+using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T
+
+const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
+
+include("query.jl")
+include("transform.jl")
+include("utils.jl")
+
+end #module
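Per the `[weakdeps]` and `[extensions]` entries above, this module is loaded automatically once both GNNGraphs and CUDA are present in the session (Julia >= 1.9 package extensions). A small sketch of checking that:

```julia
using GNNGraphs   # the extension is not loaded yet
using CUDA        # loading the weak dependency triggers GNNGraphsCUDAExt

# Returns the extension module, or `nothing` if it has not been loaded.
Base.get_extension(GNNGraphs, :GNNGraphsCUDAExt)
```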

GNNlib/ext/GNNlibCUDAExt/GNNGraphs/utils.jl → GNNGraphs/ext/GNNGraphsCUDAExt/utils.jl (+5 −2)

@@ -3,6 +3,9 @@ GNNGraphs.iscuarray(x::AnyCuArray) = true
 
 
 function sort_edge_index(u::AnyCuArray, v::AnyCuArray)
+    dev = get_device(u)
+    cdev = cpu_device()
+    u, v = u |> cdev, v |> cdev
     #TODO proper cuda friendly implementation
-    sort_edge_index(u |> Flux.cpu, v |> Flux.cpu) |> Flux.gpu
-end
+    sort_edge_index(u, v) |> dev
+end
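The change replaces the Flux `cpu`/`gpu` round trip with LuxDeviceUtils. A sketch of the same pattern in user code, assuming a CUDA-capable session (the edge index values are arbitrary):

```julia
using CUDA, LuxDeviceUtils
using GNNGraphs

u, v = cu([3, 1, 2]), cu([30, 10, 20])

dev  = get_device(u)     # remembers where the input lives (the GPU here)
cdev = cpu_device()

# Sort on the CPU for now, then move the result back to the original device.
su, sv = sort_edge_index(u |> cdev, v |> cdev)
su, sv = su |> dev, sv |> dev
```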

GNNlib/ext/GNNlibSimpleWeightedGraphsExt/GNNlibSimpleWeightedGraphsExt.jl → GNNGraphs/ext/GNNGraphsSimpleWeightedGraphsExt/GNNGraphsSimpleWeightedGraphsExt.jl (+3 −3)

@@ -1,10 +1,10 @@
-module GNNlibSimpleWeightedGraphsExt
+module GNNGraphsSimpleWeightedGraphsExt
 
-using GNNlib
 using Graphs
+using GNNGraphs
 using SimpleWeightedGraphs
 
-function GNNlib.GNNGraph(g::T; kws...) where
+function GNNGraphs.GNNGraph(g::T; kws...) where
                          {T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}}
     return GNNGraph(g.weights, kws...)
 end
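A minimal usage sketch of the conversion this extension provides (the edge weights are arbitrary; the weighted adjacency matrix of the source graph becomes the edge weights of the `GNNGraph`):

```julia
using Graphs, SimpleWeightedGraphs, GNNGraphs

swg = SimpleWeightedGraph(3)
add_edge!(swg, 1, 2, 0.5)
add_edge!(swg, 2, 3, 2.0)

g = GNNGraph(swg)      # conversion provided by GNNGraphsSimpleWeightedGraphsExt
get_edge_weight(g)     # the nonzero weights are carried over as edge weights
```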

GNNlib/src/GNNGraphs/GNNGraphs.jl → GNNGraphs/src/GNNGraphs.jl (+6 −5)

@@ -5,16 +5,16 @@ using Functors: @functor
 import Graphs
 using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
               has_self_loops, is_directed
-import MLUtils
-using MLUtils: getobs, numobs, ones_like, zeros_like, batch
 import NearestNeighbors
 import NNlib
 import StatsBase
 import KrylovKit
 using ChainRulesCore
 using LinearAlgebra, Random, Statistics
 import MLUtils
+using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch
 import Functors
+using LuxDeviceUtils: get_device, cpu_device, LuxCPUDevice
 
 include("chainrules.jl") # hacks for differentiability
 
@@ -78,7 +78,9 @@ export add_nodes,
        to_unidirected,
        random_walk_pe,
        remove_nodes,
-       # from Flux
+       ppr_diffusion,
+       drop_nodes,
+       # from MLUtils
        batch,
        unbatch,
        # from SparseArrays

@@ -101,8 +103,7 @@ include("operators.jl")
 
 include("convert.jl")
 include("utils.jl")
-export sort_edge_index,
-       color_refinement
+export sort_edge_index, color_refinement
 
 include("gatherscatter.jl")
 # _gather, _scatter
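The export-list comment now credits `batch`/`unbatch` to MLUtils rather than Flux. A minimal usage sketch (graph sizes are arbitrary):

```julia
using GNNGraphs

g1, g2 = rand_graph(4, 6), rand_graph(5, 8)

gbatch = batch([g1, g2])   # one GNNGraph containing both graphs (num_graphs == 2)
gs = unbatch(gbatch)       # back to a vector of two separate graphs
```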
File renamed without changes.
File renamed without changes.
File renamed without changes.

GNNlib/src/GNNGraphs/datastore.jl → GNNGraphs/src/datastore.jl (+16 −16)

@@ -9,44 +9,44 @@ At construction time, the `data` can be provided as any iterables of pairs
 of symbols and arrays or as keyword arguments:
 
 ```jldoctest
-julia> ds = DataStore(3, x = rand(Float32, 2, 3), y = rand(Float32, 3))
+julia> ds = DataStore(3, x = rand(2, 3), y = rand(3))
 DataStore(3) with 2 elements:
-  y = 3-element Vector{Float32}
-  x = 2×3 Matrix{Float32}
+  y = 3-element Vector{Float64}
+  x = 2×3 Matrix{Float64}
 
-julia> ds = DataStore(3, Dict(:x => rand(Float32, 2, 3), :y => rand(Float32, 3))); # equivalent to above
+julia> ds = DataStore(3, Dict(:x => rand(2, 3), :y => rand(3))); # equivalent to above
 
-julia> ds = DataStore(3, (x = rand(Float32, 2, 3), y = rand(Float32, 30)))
+julia> ds = DataStore(3, (x = rand(2, 3), y = rand(30)))
 ERROR: AssertionError: DataStore: data[y] has 30 observations, but n = 3
 Stacktrace:
  [1] DataStore(n::Int64, data::Dict{Symbol, Any})
-   @ GNNlib.GNNGraphs ~/.julia/dev/GNNlib/src/GNNGraphs/datastore.jl:54
- [2] DataStore(n::Int64, data::NamedTuple{(:x, :y), Tuple{Matrix{Float32}, Vector{Float32}}})
-   @ GNNlib.GNNGraphs ~/.julia/dev/GNNlib/src/GNNGraphs/datastore.jl:73
+   @ GraphNeuralNetworks.GNNGraphs ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/datastore.jl:54
+ [2] DataStore(n::Int64, data::NamedTuple{(:x, :y), Tuple{Matrix{Float64}, Vector{Float64}}})
+   @ GraphNeuralNetworks.GNNGraphs ~/.julia/dev/GraphNeuralNetworks/src/GNNGraphs/datastore.jl:73
 [3] top-level scope
   @ REPL[13]:1
 
-julia> ds = DataStore(x = rand(Float32, 2, 3), y = rand(Float32, 30)) # no checks
+julia> ds = DataStore(x = rand(2, 3), y = rand(30)) # no checks
 DataStore() with 2 elements:
-  y = 30-element Vector{Float32}
-  x = 2×3 Matrix{Float32}
+  y = 30-element Vector{Float64}
+  x = 2×3 Matrix{Float64}
 ```
 
 The `DataStore` has an interface similar to both dictionaries and named tuples.
 Arrays can be accessed and added using either the indexing or the property syntax:
 
 ```jldoctest
-julia> ds = DataStore(x = ones(Float32, 2, 3), y = zeros(Float32, 3))
+julia> ds = DataStore(x = ones(2, 3), y = zeros(3))
 DataStore() with 2 elements:
-  y = 3-element Vector{Float32}
-  x = 2×3 Matrix{Float32}
+  y = 3-element Vector{Float64}
+  x = 2×3 Matrix{Float64}
 
 julia> ds.x # same as `ds[:x]`
-2×3 Matrix{Float32}:
+2×3 Matrix{Float64}:
  1.0  1.0  1.0
  1.0  1.0  1.0
 
-julia> ds.z = zeros(Float32, 3) # Add new feature array `z`. Same as `ds[:z] = rand(Float32, 3)`
+julia> ds.z = zeros(3) # Add new feature array `z`. Same as `ds[:z] = zeros(3)`
 3-element Vector{Float64}:
 0.0
 0.0
File renamed without changes.

GNNlib/src/GNNGraphs/generate.jl → GNNGraphs/src/generate.jl (+3 −3)

@@ -26,7 +26,7 @@ julia> edge_index(g)
 ([1, 3, 3, 4], [5, 4, 5, 2])
 
 # In the bidirected case, edge data will be duplicated on the reverse edges if needed.
-julia> g = rand_graph(5, 4, edata=rand(Float32, 16, 2))
+julia> g = rand_graph(5, 4, edata=rand(16, 2))
 GNNGraph:
   num_nodes = 5
   num_edges = 4

@@ -173,7 +173,7 @@ to its `k` closest `points`.
 ```jldoctest
 julia> n, k = 10, 3;
 
-julia> x = rand(Float32, 3, n);
+julia> x = rand(3, n);
 
 julia> g = knn_graph(x, k)
 GNNGraph:

@@ -254,7 +254,7 @@ to its neighbors within a given distance `r`.
 ```jldoctest
 julia> n, r = 10, 0.75;
 
-julia> x = rand(Float32, 3, n);
+julia> x = rand(3, n);
 
 julia> g = radius_graph(x, r)
 GNNGraph:

GNNlib/src/GNNGraphs/gnngraph.jl → GNNGraphs/src/gnngraph.jl (+10 −7)

@@ -66,7 +66,7 @@ functionality from that library.
 # Examples
 
 ```julia
-using Flux, GraphNeuralNetworks, CUDA
+using GraphNeuralNetworks
 
 # Construct from adjacency list representation
 data = [[2,3], [1,4,5], [1], [2,5], [2,4]]

@@ -86,24 +86,27 @@ g = GNNGraph(s, t)
 g = GNNGraph(erdos_renyi(100, 20))
 
 # Add 2 node feature arrays at creation time
-g = GNNGraph(g, ndata = (x=rand(Float32,100,g.num_nodes), y=rand(Float32,g.num_nodes)))
+g = GNNGraph(g, ndata = (x=rand(100, g.num_nodes), y=rand(g.num_nodes)))
 
 # Add 1 edge feature array, after the graph creation
-g.edata.z = rand(Float32,16,g.num_edges)
+g.edata.z = rand(16, g.num_edges)
 
 # Add node features and edge features with default names `x` and `e`
-g = GNNGraph(g, ndata = rand(Float32,100,g.num_nodes), edata = rand(Float32,16,g.num_edges))
+g = GNNGraph(g, ndata = rand(100, g.num_nodes), edata = rand(16, g.num_edges))
 
 g.ndata.x # or just g.x
 g.edata.e # or just g.e
 
-# Send to gpu
-g = g |> gpu
-
 # Collect edges' source and target nodes.
 # Both source and target are vectors of length num_edges
 source, target = edge_index(g)
 ```
+A `GNNGraph` can be sent to the GPU, for example with Flux's `gpu` function:
+```
+# Send to gpu
+using Flux, CUDA
+g = g |> Flux.gpu
+```
 """
 struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
     graph::T
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

GNNlib/src/GNNGraphs/transform.jl → GNNGraphs/src/transform.jl (+82)

@@ -306,6 +306,38 @@ function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector)
                     ndata, edata, g.gdata)
 end
 
+"""
+    drop_nodes(g::GNNGraph{<:COO_T}, p)
+
+Randomly drop nodes (and their associated edges) from a GNNGraph with a given probability.
+Dropping nodes is a technique for graph data augmentation; see the [DropNode](https://arxiv.org/pdf/2008.12578.pdf) paper.
+
+# Arguments
+- `g`: The input graph from which nodes (and their associated edges) will be dropped.
+- `p`: The probability of dropping each node. Default value is `0.5`.
+
+# Returns
+A modified GNNGraph with nodes (and their associated edges) dropped based on the given probability.
+
+# Example
+```julia
+using GraphNeuralNetworks
+# Construct a GNNGraph
+g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1], num_nodes=3)
+# Drop nodes with a probability of 0.5
+g_new = drop_nodes(g, 0.5)
+println(g_new)
+```
+"""
+function drop_nodes(g::GNNGraph{<:COO_T}, p = 0.5)
+    num_nodes = g.num_nodes
+    nodes_to_remove = filter(_ -> rand() < p, 1:num_nodes)
+
+    new_g = remove_nodes(g, nodes_to_remove)
+
+    return new_g
+end
+
 """
     add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
     add_edges(g::GNNGraph, (s, t); [edata])
@@ -1028,6 +1060,9 @@ function negative_sample(g::GNNGraph;
 
     s, t = edge_index(g)
     n = g.num_nodes
+    dev = get_device(s)
+    cdev = cpu_device()
+    s, t = s |> cdev, t |> cdev
     idx_pos, maxid = edge_encoding(s, t, n)
     if bidirected
         num_neg_edges = num_neg_edges ÷ 2

@@ -1051,6 +1086,7 @@ function negative_sample(g::GNNGraph;
     if bidirected
         s_neg, t_neg = [s_neg; t_neg], [t_neg; s_neg]
     end
+    s_neg, t_neg = s_neg |> dev, t_neg |> dev
     return GNNGraph(s_neg, t_neg, num_nodes = n)
 end
 
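With this change, `negative_sample` moves the edge index to the CPU before the encoding step and moves the sampled negatives back to the original device. A minimal usage sketch (graph size and number of negatives are arbitrary):

```julia
using GNNGraphs

g = rand_graph(10, 30, bidirected=false)
g_neg = negative_sample(g, num_neg_edges = 10)   # graph made of 10 non-edges of g
```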
@@ -1129,3 +1165,49 @@ ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci
 @non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
 @non_differentiable dense_zeros_like(x...)
 
+"""
+    ppr_diffusion(g::GNNGraph{<:COO_T}; alpha = 0.85f0) -> GNNGraph
+
+Calculates the Personalized PageRank (PPR) diffusion based on the edge weight matrix of a GNNGraph and updates the graph with new edge weights derived from the PPR matrix.
+Reference: [The pagerank citation ranking: Bringing order to the web](http://ilpubs.stanford.edu:8090/422).
+
+The function performs the following steps:
+1. Constructs a modified adjacency matrix `A` from the graph's edge weights, adjusted as `(α - 1) * A + I`, with `α` the damping factor (`alpha`) and `I` the identity matrix.
+2. Applies the PPR formula `α * (I + (α - 1) * A)^-1` to compute the diffusion matrix.
+3. Updates the original edge weights of the graph based on the PPR diffusion matrix, assigning each edge its entry of the PPR matrix.
+
+# Arguments
+- `g::GNNGraph`: The input graph for which PPR diffusion is to be calculated. It should have edge weights available.
+- `alpha`: The damping factor used in the PPR calculation, controlling the teleport probability in the random walk. Defaults to `0.85f0`.
+
+# Returns
+- A new `GNNGraph` instance with the same structure as `g` but with updated edge weights according to the PPR diffusion calculation.
+"""
+function ppr_diffusion(g::GNNGraph{<:COO_T}; alpha = 0.85f0)
+    s, t = edge_index(g)
+    w = get_edge_weight(g)
+    if isnothing(w)
+        w = ones(Float32, g.num_edges)
+    end
+
+    N = g.num_nodes
+
+    initial_A = sparse(t, s, w, N, N)
+    scaled_A = (Float32(alpha) - 1) * initial_A
+
+    I_sparse = sparse(Diagonal(ones(Float32, N)))
+    A_sparse = I_sparse + scaled_A
+
+    A_dense = Matrix(A_sparse)
+
+    PPR = alpha * inv(A_dense)
+
+    new_w = [PPR[dst, src] for (src, dst) in zip(s, t)]
+
+    return GNNGraph((s, t, new_w),
+                    g.num_nodes, length(s), g.num_graphs,
+                    g.graph_indicator,
+                    g.ndata, g.edata, g.gdata)
+end
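A minimal usage sketch of the new `ppr_diffusion` function (the small weighted graph is arbitrary):

```julia
using GNNGraphs

s, t = [1, 2, 3], [2, 3, 1]
w = Float32[0.5, 1.0, 2.0]
g = GNNGraph((s, t, w))            # directed 3-cycle with edge weights

g_ppr = ppr_diffusion(g; alpha = 0.85f0)
get_edge_weight(g_ppr)             # edge weights replaced by PPR-derived values
```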

GNNlib/src/GNNGraphs/utils.jl → GNNGraphs/src/utils.jl (−1)

@@ -65,7 +65,6 @@ function sort_edge_index(u, v)
 end
 
 
-
 cat_features(x1::Nothing, x2::Nothing) = nothing
 cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims = ndims(x1))
 function cat_features(x1::Union{Number, AbstractVector}, x2::Union{Number, AbstractVector})
File renamed without changes.
File renamed without changes.
File renamed without changes.
