
Commit

complete factorization
CarloLucibello committed Jul 20, 2024
1 parent 6f12763 commit f979b7d
Showing 10 changed files with 51 additions and 53 deletions.
7 changes: 7 additions & 0 deletions GNNGraphs/README.md
@@ -0,0 +1,7 @@
# GNNGraphs.jl

A package implementing graph types for graph deep learning.

This package is currently under development and may break frequently.
It is not meant for end users but for developers of GNN libraries.
End users should use GraphNeuralNetworks.jl instead.
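
A minimal usage sketch (not part of this commit; it assumes the `GNNGraph` constructor, the `ndata` keyword, and the `edge_index` accessor exercised in the test changes below):

```julia
using GNNGraphs

# COO edge list: four directed edges over three nodes
s = [1, 1, 2, 3]
t = [2, 3, 3, 1]

# attach a 5-dimensional feature vector to each node
g = GNNGraph(s, t; ndata = (; x = rand(Float32, 5, 3)))

g.num_nodes    # 3
g.num_edges    # 4
edge_index(g)  # ([1, 1, 2, 3], [2, 3, 3, 1])
g.ndata.x      # 5×3 feature matrix
```
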
2 changes: 1 addition & 1 deletion GNNGraphs/src/GNNGraphs.jl
@@ -3,7 +3,7 @@ module GNNGraphs
using SparseArrays
using Functors: @functor
import Graphs
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
has_self_loops, is_directed
import NearestNeighbors
import NNlib
16 changes: 8 additions & 8 deletions GNNGraphs/test/gnngraph.jl
@@ -57,7 +57,8 @@ end
# core functionality
g = GNNGraph(s, t; graph_type = GRAPH_T)
if TEST_GPU
g_gpu = g |> gpu
dev = LuxCUDADevice() #TODO replace with gpu_device()
g_gpu = g |> dev
end

@test g.num_edges == 8
@@ -99,12 +100,10 @@ end
end

@testset "scaled_laplacian" begin if TEST_GPU
@test_broken begin
mat = scaled_laplacian(g)
mat_gpu = scaled_laplacian(g_gpu)
@test mat_gpu isa ACUMatrix{Float32}
@test Array(mat_gpu) == mat
end
mat = scaled_laplacian(g)
mat_gpu = scaled_laplacian(g_gpu)
@test mat_gpu isa ACUMatrix{Float32}
@test Array(mat_gpu) ≈ mat
end end

@testset "constructors" begin
@@ -142,7 +141,8 @@ end
# core functionality
g = GNNGraph(s, t; graph_type = GRAPH_T)
if TEST_GPU
g_gpu = g |> gpu
dev = LuxCUDADevice() #TODO replace with `gpu_device()`
g_gpu = g |> dev
end

@test g.num_edges == 4
6 changes: 4 additions & 2 deletions GNNGraphs/test/query.jl
@@ -59,7 +59,8 @@ end
@test eltype(degree(g, Float32)) == Float32

if TEST_GPU
g_gpu = g |> gpu
dev = LuxCUDADevice() #TODO replace with `gpu_device()`
g_gpu = g |> dev
d = degree(g)
d_gpu = degree(g_gpu)
@test d_gpu isa CuVector{Int}
@@ -86,7 +87,8 @@ end
@test degree(g, edge_weight = 2 * eweight) ≈ [4.4, 2.4, 2.0, 0.0] broken = (GRAPH_T != :coo)

if TEST_GPU
g_gpu = g |> gpu
dev = LuxCUDADevice() #TODO replace with `gpu_device()`
g_gpu = g |> dev
d = degree(g)
d_gpu = degree(g_gpu)
@test d_gpu isa CuVector{Float32}
4 changes: 3 additions & 1 deletion GNNGraphs/test/runtests.jl
@@ -1,4 +1,4 @@
using CUDA
using CUDA, cuDNN
using GNNGraphs
using GNNGraphs: getn, getdata
using Functors
@@ -13,6 +13,8 @@ using Test
using MLDatasets
using InlineStrings # not used but with the import we test #98 and #104
using SimpleWeightedGraphs
using LuxDeviceUtils: gpu_device, cpu_device, get_device
using LuxDeviceUtils: LuxCUDADevice # remove after https://github.com/LuxDL/LuxDeviceUtils.jl/pull/58

CUDA.allowscalar(false)

2 changes: 1 addition & 1 deletion GNNGraphs/test/temporalsnapshotsgnngraph.jl
@@ -107,7 +107,7 @@ if TEST_GPU
snapshots = [rand_graph(10, 20; ndata = rand(5,10)) for i in 1:5]
tsg = TemporalSnapshotsGNNGraph(snapshots)
tsg.tgdata.x = rand(5)
dev = gpu_device()
dev = LuxCUDADevice() #TODO replace with `gpu_device()`
tsg = tsg |> dev
@test tsg.snapshots[1].ndata.x isa CuArray
@test tsg.snapshots[end].ndata.x isa CuArray
3 changes: 1 addition & 2 deletions GNNlib/README.md
@@ -1,8 +1,7 @@
# GNNlib.jl

This package contains a collection of deep-learning framework agnostic
building blocks for graph neural networks such as graph convolutional layers and the implementation
of GraphGNN.
building blocks for graph neural networks such as message passing operators and implementations of graph convolutional layers.

In the future it will serve as the foundation of GraphNeuralNetworks.jl (based on Flux.jl).
GNNlib.jl will be to GraphNeuralNetworks.jl what NNlib.jl is to Flux.jl and Lux.jl.
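
A hedged sketch of the message-passing primitives exported here (API assumed from the `export` list in GNNlib/src/GNNlib.jl in this diff; GNNGraphs.jl supplies the graph type):

```julia
using GNNGraphs, GNNlib

g = rand_graph(10, 30)                # 10 nodes, 30 random edges
x = rand(Float32, 4, g.num_nodes)     # one 4-dimensional feature column per node

# copy_xj emits the source-node features as the message on each edge;
# `+` sums the incoming messages at every destination node.
h = propagate(copy_xj, g, +; xj = x)  # 4×10 matrix of aggregated neighbor features
```
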
19 changes: 6 additions & 13 deletions GNNlib/src/GNNlib.jl
@@ -11,28 +11,21 @@ using ChainRulesCore
using SparseArrays, Graphs # not needed but if removed Documenter will complain
using DataStructures: nlargest
using Reexport: @reexport

include("GNNGraphs/GNNGraphs.jl")

@reexport using .GNNGraphs

using GNNGraphs
using .GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
check_num_nodes, check_num_edges,
EType, NType # for heteroconvs

export


# utils
# utils
reduce_nodes,
reduce_edges,
softmax_nodes,
softmax_edges,
broadcast_nodes,
broadcast_edges,
softmax_edge_neighbors,

# msgpass
# msgpass
apply_edges,
aggregate_neighbors,
propagate,
@@ -43,8 +36,7 @@ export
xj_sub_xi,
e_mul_xj,
w_mul_xj,

# mldatasets
# mldatasets
mldataset2gnngraph

## The following methods are defined but not exported
@@ -92,4 +84,5 @@ include("layers/pool.jl")
include("msgpass.jl")
include("mldatasets.jl")

end
end #module

6 changes: 2 additions & 4 deletions src/GNNGraphs/GNNGraphs.jl
@@ -5,16 +5,14 @@ using Functors: @functor
import Graphs
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
has_self_loops, is_directed
import Flux
using Flux: batch
import NearestNeighbors
import NNlib
import StatsBase
import KrylovKit
using ChainRulesCore
using LinearAlgebra, Random, Statistics
import MLUtils
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch
import Functors

include("chainrules.jl") # hacks for differentiability
@@ -81,7 +79,7 @@ export add_nodes,
remove_nodes,
ppr_diffusion,
drop_nodes,
# from Flux
# from MLUtils.jl
batch,
unbatch,
# from SparseArrays
39 changes: 18 additions & 21 deletions src/GNNGraphs/transform.jl
@@ -700,7 +700,7 @@ end
"""
blockdiag(xs::GNNGraph...)
Equivalent to [`Flux.batch`](@ref).
Equivalent to [`MLUtils.batch`](@ref).
"""
function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
g = g1
@@ -717,7 +717,7 @@ Batch together multiple `GNNGraph`s into a single one
containing the total number of original nodes and edges.
Equivalent to [`SparseArrays.blockdiag`](@ref).
See also [`Flux.unbatch`](@ref).
See also [`MLUtils.unbatch`](@ref).
# Examples
@@ -736,7 +736,7 @@ GNNGraph:
ndata:
x => (8, 7)
julia> g12 = Flux.batch([g1, g2])
julia> g12 = MLUtils.batch([g1, g2])
GNNGraph:
num_nodes = 11
num_edges = 10
@@ -756,18 +756,18 @@ julia> g12.ndata.x
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
```
"""
function Flux.batch(gs::AbstractVector{<:GNNGraph})
function MLUtils.batch(gs::AbstractVector{<:GNNGraph})
Told = eltype(gs)
# try to restrict the eltype
gs = [g for g in gs]
if eltype(gs) != Told
return Flux.batch(gs)
return MLUtils.batch(gs)
else
return blockdiag(gs...)
end
end

function Flux.batch(gs::AbstractVector{<:GNNGraph{T}}) where {T <: COO_T}
function MLUtils.batch(gs::AbstractVector{<:GNNGraph{T}}) where {T <: COO_T}
v_num_nodes = [g.num_nodes for g in gs]
edge_indices = [edge_index(g) for g in gs]
nodesum = cumsum([0; v_num_nodes])[1:(end - 1)]
@@ -796,12 +796,12 @@ function Flux.batch(gs::AbstractVector{<:GNNGraph{T}}) where {T <: COO_T}
cat_features([g.gdata for g in gs]))
end

function Flux.batch(g::GNNGraph)
function MLUtils.batch(g::GNNGraph)
throw(ArgumentError("Cannot batch a `GNNGraph` (containing $(g.num_graphs) graphs). Pass a vector of `GNNGraph`s instead."))
end


function Flux.batch(gs::AbstractVector{<:GNNHeteroGraph})
function MLUtils.batch(gs::AbstractVector{<:GNNHeteroGraph})
function edge_index_nullable(g::GNNHeteroGraph{<:COO_T}, edge_t::EType)
if haskey(g.graph, edge_t)
g.graph[edge_t][1:2]
@@ -871,21 +871,21 @@ end
"""
unbatch(g::GNNGraph)
Opposite of the [`Flux.batch`](@ref) operation, returns
Opposite of the [`MLUtils.batch`](@ref) operation, returns
an array of the individual graphs batched together in `g`.
See also [`Flux.batch`](@ref) and [`getgraph`](@ref).
See also [`MLUtils.batch`](@ref) and [`getgraph`](@ref).
# Examples
```jldoctest
julia> gbatched = Flux.batch([rand_graph(5, 6), rand_graph(10, 8), rand_graph(4,2)])
julia> gbatched = MLUtils.batch([rand_graph(5, 6), rand_graph(10, 8), rand_graph(4,2)])
GNNGraph:
num_nodes = 19
num_edges = 16
num_graphs = 3
julia> Flux.unbatch(gbatched)
julia> MLUtils.unbatch(gbatched)
3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
GNNGraph:
num_nodes = 5
@@ -900,7 +900,7 @@ julia> Flux.unbatch(gbatched)
num_edges = 2
```
"""
function Flux.unbatch(g::GNNGraph{T}) where {T <: COO_T}
function MLUtils.unbatch(g::GNNGraph{T}) where {T <: COO_T}
g.num_graphs == 1 && return [g]

nodemasks = _unbatch_nodemasks(g.graph_indicator, g.num_graphs)
@@ -939,7 +939,7 @@ function Flux.unbatch(g::GNNGraph{T}) where {T <: COO_T}
return [build_graph(i) for i in 1:(g.num_graphs)]
end

function Flux.unbatch(g::GNNGraph)
function MLUtils.unbatch(g::GNNGraph)
return [getgraph(g, i) for i in 1:(g.num_graphs)]
end

@@ -1060,13 +1060,10 @@ function negative_sample(g::GNNGraph;

s, t = edge_index(g)
n = g.num_nodes
if iscuarray(s)
# Convert to gpu since set operations and sampling are not supported by CUDA.jl
device = Flux.gpu
s, t = Flux.cpu(s), Flux.cpu(t)
else
device = Flux.cpu
end
device = get_device(s)
cdevice = cpu_device()
# Convert to cpu since set operations and sampling are not supported by CUDA.jl
s, t = cdevice(s), cdevice(t)
idx_pos, maxid = edge_encoding(s, t, n)
if bidirected
num_neg_edges = num_neg_edges ÷ 2
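
For reference, a sketch of the device-handling pattern this commit adopts in place of `Flux.gpu`/`Flux.cpu`, using the LuxDeviceUtils names imported in GNNGraphs/test/runtests.jl above (illustrative only, not taken from the diff):

```julia
using LuxDeviceUtils: gpu_device, cpu_device, get_device

x = rand(Float32, 3, 4)

gdev = gpu_device()          # a CUDA device if one is functional, otherwise a CPU fallback
x_dev = gdev(x)              # devices are callable and move arrays

dev = get_device(x_dev)      # remember where the data currently lives
x_cpu = cpu_device()(x_dev)  # hop to the CPU for CPU-only work (e.g. sampling)
x_back = dev(x_cpu)          # move the result back to the original device
```
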
