Commit db923c0

separate GNNGraphs from GNNlib (#446)
* separate GNNGraphs from GNNlib

* complete factorization

* rebase
CarloLucibello authored Jul 21, 2024
1 parent e2623eb commit db923c0
Showing 50 changed files with 550 additions and 147 deletions.
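In practice, the split means the graph types can now be loaded without pulling in the rest of GNNlib. A minimal sketch of post-split usage (assuming the new, still-unregistered GNNGraphs package has been `Pkg.develop`ed from this repository):

```julia
# Sketch: constructing a graph with only the new GNNGraphs package loaded.
using GNNGraphs

s = [1, 1, 2, 3]   # edge sources
t = [2, 3, 3, 1]   # edge targets
g = GNNGraph(s, t; ndata = (; x = rand(Float32, 4, 3)))  # 3 nodes, 4 features each
g.num_nodes, g.num_edges   # (3, 4)
```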
67 changes: 67 additions & 0 deletions GNNGraphs/Project.toml
@@ -0,0 +1,67 @@
name = "GNNGraphs"
uuid = "aed8fd31-079b-4b5a-b342-a13352159b8c"
authors = ["Carlo Lucibello and contributors"]
version = "0.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"

[extensions]
GNNGraphsCUDAExt = "CUDA"
GNNGraphsSimpleWeightedGraphsExt = "SimpleWeightedGraphs"

[compat]
Adapt = "4"
CUDA = "5"
ChainRulesCore = "1"
Functors = "0.4.1"
Graphs = "1.4"
KrylovKit = "0.8"
LinearAlgebra = "1"
LuxDeviceUtils = "0.1.24"
MLDatasets = "0.7"
MLUtils = "0.4"
NNlib = "0.9"
NearestNeighbors = "0.4"
Random = "1"
SimpleWeightedGraphs = "1.4.0"
SparseArrays = "1"
Statistics = "1"
StatsBase = "0.34"
cuDNN = "1"
julia = "1.9"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SimpleWeightedGraphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]
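The `[weakdeps]`/`[extensions]` tables above use the package-extension mechanism introduced in Julia 1.9 (matching the `julia = "1.9"` compat bound): the CUDA and SimpleWeightedGraphs code paths are compiled and loaded only when those packages enter the session. A hedged sketch of how this behaves, not part of the commit itself:

```julia
# Sketch: GNNGraphsCUDAExt loads lazily, only once CUDA is in the session.
using GNNGraphs          # base package: no CUDA code loaded at this point
using CUDA               # triggers loading of the GNNGraphsCUDAExt extension

ext = Base.get_extension(GNNGraphs, :GNNGraphsCUDAExt)
@assert ext !== nothing  # the extension module is now active
```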
7 changes: 7 additions & 0 deletions GNNGraphs/README.md
@@ -0,0 +1,7 @@
# GNNGraphs.jl

A package implementing graph types for deep learning on graphs.

This package is currently under development and may break frequently.
It is not meant for end users but for developers of GNN libraries.
End users should use GraphNeuralNetworks.jl instead.
14 changes: 14 additions & 0 deletions GNNGraphs/ext/GNNGraphsCUDAExt/GNNGraphsCUDAExt.jl
@@ -0,0 +1,14 @@
module GNNGraphsCUDAExt

using CUDA
using Random, Statistics, LinearAlgebra
using GNNGraphs
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T

const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}

include("query.jl")
include("transform.jl")
include("utils.jl")

end #module
File renamed without changes.
File renamed without changes.
@@ -3,6 +3,9 @@ GNNGraphs.iscuarray(x::AnyCuArray) = true


 function sort_edge_index(u::AnyCuArray, v::AnyCuArray)
+    dev = get_device(u)
+    cdev = cpu_device()
+    u, v = u |> cdev, v |> cdev
     #TODO proper cuda friendly implementation
-    sort_edge_index(u |> Flux.cpu, v |> Flux.cpu) |> Flux.gpu
+    sort_edge_index(u, v) |> dev
 end
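The new method is a stopgap: as the TODO notes, there is no CUDA-native sort yet, so the arrays round-trip through the CPU while remembering their original device. The same pattern written generically (an illustrative helper, not part of the commit):

```julia
# Sketch: run a CPU-only function on device arrays, then move the result back.
using LuxDeviceUtils: get_device, cpu_device

function with_cpu_roundtrip(f, xs...)
    dev = get_device(first(xs))             # remember the original device
    cdev = cpu_device()
    result = f(map(x -> x |> cdev, xs)...)  # compute on the CPU
    return result |> dev                    # restore the original device
end
```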
@@ -1,10 +1,10 @@
-module GNNlibSimpleWeightedGraphsExt
+module GNNGraphsSimpleWeightedGraphsExt

-using GNNlib
 using Graphs
+using GNNGraphs
 using SimpleWeightedGraphs

-function GNNlib.GNNGraph(g::T; kws...) where
+function GNNGraphs.GNNGraph(g::T; kws...) where
                  {T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}}
     return GNNGraph(g.weights, kws...)
 end
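With the extension moved and renamed, a `GNNGraph` can be built directly from a weighted graph; construction goes through `g.weights`, so edge weights carry over. A hedged usage sketch, not taken from the commit:

```julia
# Sketch: the extension method above dispatches on SimpleWeightedGraph.
using Graphs, SimpleWeightedGraphs, GNNGraphs

swg = SimpleWeightedGraph(3)
add_edge!(swg, 1, 2, 0.5)
add_edge!(swg, 2, 3, 2.0)

g = GNNGraph(swg)   # built from swg.weights, keeping the edge weights
```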
13 changes: 7 additions & 6 deletions GNNlib/src/GNNGraphs/GNNGraphs.jl → GNNGraphs/src/GNNGraphs.jl
@@ -3,18 +3,18 @@ module GNNGraphs
 using SparseArrays
 using Functors: @functor
 import Graphs
 using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
               has_self_loops, is_directed
-import MLUtils
-using MLUtils: getobs, numobs, ones_like, zeros_like, batch
 import NearestNeighbors
 import NNlib
 import StatsBase
 import KrylovKit
 using ChainRulesCore
 using LinearAlgebra, Random, Statistics
+import MLUtils
+using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch
+import Functors
+using LuxDeviceUtils: get_device, cpu_device, LuxCPUDevice

 include("chainrules.jl") # hacks for differentiability
@@ -78,7 +78,9 @@ export add_nodes,
        to_unidirected,
        random_walk_pe,
        remove_nodes,
-       # from Flux
+       ppr_diffusion,
+       drop_nodes,
+       # from MLUtils
        batch,
        unbatch,
        # from SparseArrays
@@ -101,8 +103,7 @@ include("operators.jl")

 include("convert.jl")
 include("utils.jl")
-export sort_edge_index,
-       color_refinement
+export sort_edge_index, color_refinement

 include("gatherscatter.jl")
 # _gather, _scatter
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -20,16 +20,18 @@ julia> ds = DataStore(3, (x = rand(Float32, 2, 3), y = rand(Float32, 30)))
 ERROR: AssertionError: DataStore: data[y] has 30 observations, but n = 3
 Stacktrace:
  [1] DataStore(n::Int64, data::Dict{Symbol, Any})
-   @ GNNlib.GNNGraphs ~/.julia/dev/GNNlib/src/GNNGraphs/datastore.jl:54
+   @ GNNGraphs ~/.julia/dev/GNNGraphs/datastore.jl:54
  [2] DataStore(n::Int64, data::NamedTuple{(:x, :y), Tuple{Matrix{Float32}, Vector{Float32}}})
-   @ GNNlib.GNNGraphs ~/.julia/dev/GNNlib/src/GNNGraphs/datastore.jl:73
+   @ GNNGraphs ~/.julia/dev/GNNGraphs/datastore.jl:73
  [3] top-level scope
    @ REPL[13]:1

-julia> ds = DataStore(x = rand(2, 3), y = rand(30)) # no checks
+julia> ds = DataStore(x = rand(Float32, 2, 3), y = rand(Float32, 30)) # no checks
 DataStore() with 2 elements:
-  y = 30-element Vector{Float64}
-  x = 2×3 Matrix{Float64}
+  y = 30-element Vector{Float32}
+  x = 2×3 Matrix{Float32}
 ```

 The `DataStore` has an interface similar to both dictionaries and named tuples.
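As that last docstring line says, `DataStore` mixes dictionary and named-tuple semantics. A short hedged sketch of the dual interface (access styles assumed from the docstrings in this diff, e.g. `g.edata.z = ...` below):

```julia
# Sketch: property-style and index-style access to the same DataStore.
ds = DataStore(3, (x = rand(Float32, 2, 3), y = rand(Float32, 3)))

ds.x === ds[:x]           # the two access styles see the same array
ds.z = rand(Float32, 3)   # new entries are checked against n = 3 on insertion
```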
File renamed without changes.
@@ -254,7 +254,7 @@ to its neighbors within a given distance `r`.
 ```jldoctest
 julia> n, r = 10, 0.75;

-julia> x = Float32, 3, n);
+julia> x = rand(Float32, 3, n);

 julia> g = radius_graph(x, r)
 GNNGraph:
17 changes: 10 additions & 7 deletions GNNlib/src/GNNGraphs/gnngraph.jl → GNNGraphs/src/gnngraph.jl
@@ -66,7 +66,7 @@ functionality from that library.
 # Examples

 ```julia
-using Flux, GraphNeuralNetworks, CUDA
+using GraphNeuralNetworks

 # Construct from adjacency list representation
 data = [[2,3], [1,4,5], [1], [2,5], [2,4]]
@@ -86,24 +86,27 @@ g = GNNGraph(s, t)
 g = GNNGraph(erdos_renyi(100, 20))

 # Add 2 node feature arrays at creation time
-g = GNNGraph(g, ndata = (x=rand(Float32,100,g.num_nodes), y=rand(Float32,g.num_nodes)))
+g = GNNGraph(g, ndata = (x=rand(100, g.num_nodes), y=rand(g.num_nodes)))

 # Add 1 edge feature array, after the graph creation
-g.edata.z = rand(Float32,16,g.num_edges)
+g.edata.z = rand(16, g.num_edges)

 # Add node features and edge features with default names `x` and `e`
-g = GNNGraph(g, ndata = rand(Float32,100,g.num_nodes), edata = rand(Float32,16,g.num_edges))
+g = GNNGraph(g, ndata = rand(100, g.num_nodes), edata = rand(16, g.num_edges))

 g.ndata.x # or just g.x
 g.edata.e # or just g.e

-# Send to gpu
-g = g |> gpu
-
 # Collect edges' source and target nodes.
 # Both source and target are vectors of length num_edges
 source, target = edge_index(g)
 ```

+A `GNNGraph` can be sent to the GPU using e.g. Flux's `gpu` function:
+```
+# Send to gpu
+using Flux, CUDA
+g = g |> Flux.gpu
+```
 """
 struct GNNGraph{T <: Union{COO_T, ADJMAT_T}} <: AbstractGNNGraph{T}
     graph::T
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
82 changes: 82 additions & 0 deletions GNNlib/src/GNNGraphs/transform.jl → GNNGraphs/src/transform.jl
@@ -306,6 +306,38 @@ function remove_nodes(g::GNNGraph{<:COO_T}, nodes_to_remove::AbstractVector)
                     ndata, edata, g.gdata)
 end

"""
    drop_nodes(g::GNNGraph{<:COO_T}, p)

Randomly drop nodes (and their associated edges) from a GNNGraph, each with probability `p`.
Dropping nodes is a technique for graph data augmentation, following the paper [DropNode](https://arxiv.org/pdf/2008.12578.pdf).

# Arguments
- `g`: The input graph from which nodes (and their associated edges) will be dropped.
- `p`: The probability of dropping each node. Default value is `0.5`.

# Returns
A modified GNNGraph with nodes (and their associated edges) dropped based on the given probability.

# Example
```julia
using GraphNeuralNetworks

# Construct a GNNGraph
g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1], num_nodes = 3)

# Drop nodes with a probability of 0.5
g_new = drop_nodes(g, 0.5)
println(g_new)
```
"""
function drop_nodes(g::GNNGraph{<:COO_T}, p = 0.5)
    num_nodes = g.num_nodes
    nodes_to_remove = filter(_ -> rand() < p, 1:num_nodes)

    new_g = remove_nodes(g, nodes_to_remove)

    return new_g
end
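Since each node survives independently with probability `1 - p`, the size of the resulting graph is itself random; seeding the global RNG makes an augmentation pipeline reproducible. A hedged sketch, not part of the commit:

```julia
# Sketch: drop_nodes draws from the global RNG, so seed it for reproducibility.
using Random

Random.seed!(42)
g = GNNGraph([1, 1, 2, 2, 3], [2, 3, 1, 3, 1], num_nodes = 3)
g_aug = drop_nodes(g, 0.3)   # each node dropped independently with prob 0.3
```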

"""
add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
add_edges(g::GNNGraph, (s, t); [edata])
@@ -1028,6 +1060,9 @@ function negative_sample(g::GNNGraph;

     s, t = edge_index(g)
     n = g.num_nodes
+    dev = get_device(s)
+    cdev = cpu_device()
+    s, t = s |> cdev, t |> cdev
     idx_pos, maxid = edge_encoding(s, t, n)
     if bidirected
         num_neg_edges = num_neg_edges ÷ 2
@@ -1051,6 +1086,7 @@
     if bidirected
         s_neg, t_neg = [s_neg; t_neg], [t_neg; s_neg]
     end
+    s_neg, t_neg = s_neg |> dev, t_neg |> dev
     return GNNGraph(s_neg, t_neg, num_nodes = n)
 end

@@ -1129,3 +1165,49 @@ ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci
 @non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
 @non_differentiable dense_zeros_like(x...)

"""
    ppr_diffusion(g::GNNGraph{<:COO_T}; alpha = 0.85f0) -> GNNGraph

Calculates the Personalized PageRank (PPR) diffusion based on the edge weight matrix of a GNNGraph and updates the graph with new edge weights derived from the PPR matrix.
Reference: [The PageRank citation ranking: Bringing order to the web](http://ilpubs.stanford.edu:8090/422).

The function performs the following steps:
1. Constructs a sparse adjacency matrix `A` from the graph's edge weights.
2. Forms the matrix `I + (α - 1) * A`, with `α` the damping factor (`alpha`) and `I` the identity matrix.
3. Computes the PPR matrix `α * (I + (α - 1) * A)^-1`.
4. Updates the original edge weights of the graph, assigning each edge the corresponding entry of the PPR matrix.

# Arguments
- `g::GNNGraph`: The input graph for which PPR diffusion is to be calculated. It should have edge weights available.
- `alpha`: The damping factor used in the PPR calculation, controlling the teleport probability in the random walk. Defaults to `0.85f0`.

# Returns
- A new `GNNGraph` instance with the same structure as `g` but with updated edge weights according to the PPR diffusion calculation.
"""
function ppr_diffusion(g::GNNGraph{<:COO_T}; alpha = 0.85f0)
    s, t = edge_index(g)
    w = get_edge_weight(g)
    if isnothing(w)
        w = ones(Float32, g.num_edges)
    end

    N = g.num_nodes

    initial_A = sparse(t, s, w, N, N)
    scaled_A = (Float32(alpha) - 1) * initial_A

    I_sparse = sparse(Diagonal(ones(Float32, N)))
    A_sparse = I_sparse + scaled_A

    A_dense = Matrix(A_sparse)

    PPR = alpha * inv(A_dense)

    new_w = [PPR[dst, src] for (src, dst) in zip(s, t)]

    return GNNGraph((s, t, new_w),
                    g.num_nodes, length(s), g.num_graphs,
                    g.graph_indicator,
                    g.ndata, g.edata, g.gdata)
end
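Note that the implementation materializes and inverts a dense `N × N` matrix, so it is only practical for small graphs. A hedged usage sketch, not part of the commit:

```julia
# Sketch: ppr_diffusion reweights the existing edges; the topology is unchanged.
using GNNGraphs

s, t = [1, 2, 3], [2, 3, 1]              # a directed 3-cycle
g = GNNGraph((s, t, ones(Float32, 3)))   # unit edge weights

g2 = ppr_diffusion(g; alpha = 0.85f0)
get_edge_weight(g2)                      # new weights read off the PPR matrix
```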
1 change: 0 additions & 1 deletion GNNlib/src/GNNGraphs/utils.jl → GNNGraphs/src/utils.jl
@@ -65,7 +65,6 @@ function sort_edge_index(u, v)
 end


-
 cat_features(x1::Nothing, x2::Nothing) = nothing
 cat_features(x1::AbstractArray, x2::AbstractArray) = cat(x1, x2, dims = ndims(x1))
 function cat_features(x1::Union{Number, AbstractVector}, x2::Union{Number, AbstractVector})
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions test/GNNGraphs/generate.jl → GNNGraphs/test/generate.jl
@@ -107,8 +107,8 @@ end
 end

 @testset "rand_temporal_hyperbolic_graph" begin
-    @test GraphNeuralNetworks.GNNGraphs._hyperbolic_distance([1.0,1.0],[1.0,1.0];ζ=1)==0
-    @test GraphNeuralNetworks.GNNGraphs._hyperbolic_distance([0.23,0.11],[0.98,0.55];ζ=1)==GraphNeuralNetworks.GNNGraphs._hyperbolic_distance([0.98,0.55],[0.23,0.11];ζ=1)
+    @test GNNGraphs._hyperbolic_distance([1.0,1.0],[1.0,1.0];ζ=1)==0
+    @test GNNGraphs._hyperbolic_distance([0.23,0.11],[0.98,0.55];ζ=1) == GNNGraphs._hyperbolic_distance([0.98,0.55],[0.23,0.11];ζ=1)
     number_nodes = 30
     number_snapshots = 5
     α, R, speed, ζ = 1, 1, 0.1, 1