[GNNLux] GCNConv, ChebConv, GNNChain (#462)
* add gcnconv and chebconv

* gnn chain
CarloLucibello authored Jul 26, 2024
1 parent 79515e9 commit 3b42087
Showing 8 changed files with 281 additions and 64 deletions.
13 changes: 10 additions & 3 deletions GNNLux/src/GNNLux.jl
@@ -1,15 +1,22 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib
using LuxCore: LuxCore, AbstractExplicitLayer
using Lux: glorot_uniform, zeros32
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using Lux: Lux, glorot_uniform, zeros32
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
@reexport using GNNGraphs

include("layers/basic.jl")
export GNNLayer,
GNNContainerLayer,
GNNChain

include("layers/conv.jl")
export GraphConv
export GCNConv,
ChebConv,
GraphConv

end #module

61 changes: 61 additions & 0 deletions GNNLux/src/layers/basic.jl
@@ -0,0 +1,61 @@
"""
abstract type GNNLayer <: AbstractExplicitLayer end
An abstract type from which graph neural network layers are derived.
It is derived from Lux's `AbstractExplicitLayer` type.
See also [`GNNChain`](@ref GNNLux.GNNChain).
"""
abstract type GNNLayer <: AbstractExplicitLayer end

abstract type GNNContainerLayer{T} <: AbstractExplicitContainerLayer{T} end

@concrete struct GNNChain <: GNNContainerLayer{(:layers,)}
layers <: NamedTuple
end

GNNChain(xs...) = GNNChain(; (Symbol("layer_", i) => x for (i, x) in enumerate(xs))...)

function GNNChain(; kw...)
:layers in Base.keys(kw) &&
throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
nt = NamedTuple{keys(kw)}(values(kw))
nt = map(_wrapforchain, nt)
return GNNChain(nt)
end

_wrapforchain(l::AbstractExplicitLayer) = l
_wrapforchain(l) = Lux.WrappedFunction(l)

Base.keys(c::GNNChain) = Base.keys(getfield(c, :layers))
Base.getindex(c::GNNChain, i::Int) = c.layers[i]
Base.getindex(c::GNNChain, i::AbstractVector) = GNNChain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))

function Base.getproperty(c::GNNChain, name::Symbol)
hasfield(typeof(c), name) && return getfield(c, name)
layers = getfield(c, :layers)
hasfield(typeof(layers), name) && return getfield(layers, name)
throw(ArgumentError("$(typeof(c)) has no field or layer $name"))
end

Base.length(c::GNNChain) = length(c.layers)
Base.lastindex(c::GNNChain) = lastindex(c.layers)
Base.firstindex(c::GNNChain) = firstindex(c.layers)

LuxCore.outputsize(c::GNNChain) = LuxCore.outputsize(c.layers[end])

(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps, st)

function _applychain(layers, g::GNNGraph, x, ps, st) # type-unstable path, helps compile times
newst = (;)
for (name, l) in pairs(layers)
x, s′ = _applylayer(l, g, x, getproperty(ps, name), getproperty(st, name))
newst = merge(newst, (; name => s′))
end
return x, newst
end

_applylayer(l, g::GNNGraph, x, ps, st) = l(x), (;)
_applylayer(l::AbstractExplicitLayer, g::GNNGraph, x, ps, st) = l(x, ps, st)
_applylayer(l::GNNLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
_applylayer(l::GNNContainerLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
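
Reviewer note: a minimal usage sketch of the new `GNNChain`, assembled from the definitions above and the test in `basic_tests.jl` further down. The graph size, feature dimension, and the plain broadcasting function in the chain are illustrative choices, not part of this commit; plain callables are wrapped via `_wrapforchain` into `Lux.WrappedFunction` as defined above.

```julia
using GNNLux, Lux, LuxCore, Random   # GNNGraphs is reexported by GNNLux
using NNlib: relu

rng = Random.default_rng()
g = rand_graph(10, 40)               # 10 nodes, 40 edges (illustrative)
x = randn(rng, Float32, 3, 10)       # 3 features per node

# GNN layers and a plain function mixed in one chain;
# the function is auto-wrapped by `_wrapforchain`.
c = GNNChain(GraphConv(3 => 5, relu),
             x -> 2 .* x,
             GCNConv(5 => 3))

ps = LuxCore.initialparameters(rng, c)   # explicit parameters
st = LuxCore.initialstates(rng, c)       # explicit states

y, st′ = c(g, x, ps, st)                 # y has size (3, 10)
```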
152 changes: 115 additions & 37 deletions GNNLux/src/layers/conv.jl
@@ -1,54 +1,132 @@
# Missing Layers

# | Layer |Sparse Ops|Edge Weight|Edge Features| Heterograph | TemporalSnapshotsGNNGraphs |
# | :-------- | :---: |:---: |:---: | :---: | :---: |
# | [`AGNNConv`](@ref) | | | ✓ | | |
# | [`CGConv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`EGNNConv`](@ref) | | | ✓ | | |
# | [`EdgeConv`](@ref) | | | | ✓ | |
# | [`GATConv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`GATv2Conv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`GatedGraphConv`](@ref) | ✓ | | | | ✓ |
# | [`GINConv`](@ref) | ✓ | | | ✓ | ✓ |
# | [`GMMConv`](@ref) | | | ✓ | | |
# | [`MEGNetConv`](@ref) | | | ✓ | | |
# | [`NNConv`](@ref) | | | ✓ | | |
# | [`ResGatedGraphConv`](@ref) | | | | ✓ | ✓ |
# | [`SAGEConv`](@ref) | ✓ | | | ✓ | ✓ |
# | [`SGConv`](@ref) | ✓ | | | | ✓ |
# | [`TransformerConv`](@ref) | | | ✓ | | |


@concrete struct GCNConv <: GNNLayer
in_dims::Int
out_dims::Int
use_bias::Bool
add_self_loops::Bool
use_edge_weight::Bool
init_weight
init_bias
σ
end

@doc raw"""
GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform)
function GCNConv(ch::Pair{Int, Int}, σ = identity;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
add_self_loops::Bool = true,
use_edge_weight::Bool = false,
allow_fast_activation::Bool = true)
in_dims, out_dims = ch
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
return GCNConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ)
end

Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244).
function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv)
weight = l.init_weight(rng, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; weight, bias)
else
return (; weight)
end
end

Performs:
```math
\mathbf{x}_i' = W_1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W_2 \mathbf{x}_j
```
LuxCore.parameterlength(l::GCNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims
LuxCore.statelength(d::GCNConv) = 0
LuxCore.outputsize(d::GCNConv) = (d.out_dims,)

where the aggregation type is selected by `aggr`.
function Base.show(io::IO, l::GCNConv)
print(io, "GCNConv(", l.in_dims, " => ", l.out_dims)
l.σ == identity || print(io, ", ", l.σ)
l.use_bias || print(io, ", use_bias=false")
l.add_self_loops || print(io, ", add_self_loops=false")
!l.use_edge_weight || print(io, ", use_edge_weight=true")
print(io, ")")
end

# Arguments
# TODO norm_fn should be keyword argument only
(l::GCNConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing, norm_fn= d -> 1 ./ sqrt.(d)) =
l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
(l::GCNConv)(g, x, edge_weight, ps, st; conv_weight=nothing, norm_fn = d -> 1 ./ sqrt.(d)) =
l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
(l::GCNConv)(g, x, edge_weight, norm_fn, ps, st; conv_weight=nothing) =
GNNlib.gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight, ps), st

- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `bias`: Add learnable bias.
- `init`: Weights' initializer.
@concrete struct ChebConv <: GNNLayer
in_dims::Int
out_dims::Int
use_bias::Bool
k::Int
init_weight
init_bias
σ
end

# Examples
function ChebConv(ch::Pair{Int, Int}, k::Int, σ = identity;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
allow_fast_activation::Bool = true)
in_dims, out_dims = ch
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias, σ)
end

```julia
# create data
s = [1,1,2,3]
t = [2,3,1,1]
in_channel = 3
out_channel = 5
g = GNNGraph(s, t)
x = randn(Float32, 3, g.num_nodes)
function LuxCore.initialparameters(rng::AbstractRNG, l::ChebConv)
weight = l.init_weight(rng, l.out_dims, l.in_dims, l.k)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; weight, bias)
else
return (; weight)
end
end

LuxCore.parameterlength(l::ChebConv) = l.use_bias ? l.in_dims * l.out_dims * l.k + l.out_dims :
l.in_dims * l.out_dims * l.k
LuxCore.statelength(d::ChebConv) = 0
LuxCore.outputsize(d::ChebConv) = (d.out_dims,)

function Base.show(io::IO, l::ChebConv)
print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", K=", l.K)
l.σ == identity || print(io, ", ", l.σ)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

# create layer
l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean)
(l::ChebConv)(g, x, ps, st) = GNNlib.cheb_conv(l, g, x, ps), st

# forward pass
y = l(g, x)
```
"""
@concrete struct GraphConv <: AbstractExplicitLayer
@concrete struct GraphConv <: GNNLayer
in_dims::Int
out_dims::Int
use_bias::Bool
init_weight::Function
init_bias::Function
init_weight
init_bias
σ
aggr
end


function GraphConv(ch::Pair{Int, Int}, σ = identity;
aggr = +,
init_weight = glorot_uniform,
@@ -65,10 +143,10 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GraphConv)
weight2 = l.init_weight(rng, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; weight1, weight2, bias)
else
bias = false
return (; weight1, weight2)
end
return (; weight1, weight2, bias)
end

function LuxCore.parameterlength(l::GraphConv)
@@ -90,4 +168,4 @@ function Base.show(io::IO, l::GraphConv)
print(io, ")")
end

(l::GraphConv)(g::GNNGraph, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st
(l::GraphConv)(g, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st
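
Reviewer note: a minimal sketch of calling the new `GCNConv` and `ChebConv` layers with Lux-style explicit parameters, pieced together from the constructors and call methods above (the `conv_tests.jl` changes below exercise the same interface). The graph, feature sizes, and the random edge-weight vector are made up for illustration; `edge_weight` is passed as a keyword following the call methods defined above.

```julia
using GNNLux, Lux, LuxCore, Random
using NNlib: relu

rng = Random.default_rng()
g = rand_graph(10, 40)
x = randn(rng, Float32, 3, 10)

# GCNConv, optionally consuming per-edge weights
l = GCNConv(3 => 5, relu; use_edge_weight = true)
ps = LuxCore.initialparameters(rng, l)   # (; weight, bias)
st = LuxCore.initialstates(rng, l)       # empty: the layer is stateless
ew = rand(Float32, g.num_edges)          # one weight per edge (illustrative)
y, _ = l(g, x, ps, st; edge_weight = ew) # size (5, 10)

# ChebConv with Chebyshev polynomial order k = 2
lc = ChebConv(3 => 5, 2, relu)
psc = LuxCore.initialparameters(rng, lc) # weight has size (5, 3, 2)
stc = LuxCore.initialstates(rng, lc)
yc, _ = lc(g, x, psc, stc)               # size (5, 10)
```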
24 changes: 24 additions & 0 deletions GNNLux/test/layers/basic_tests.jl
@@ -0,0 +1,24 @@
@testitem "layers/basic" setup=[SharedTestSetup] begin
rng = StableRNG(17)
g = rand_graph(10, 40, seed=17)
x = randn(rng, Float32, 3, 10)

@testset "GNNLayer" begin
@test GNNLayer <: LuxCore.AbstractExplicitLayer
end

@testset "GNNChain" begin
@test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)}
@test GNNChain <: GNNContainerLayer
c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3))
ps = LuxCore.initialparameters(rng, c)
st = LuxCore.initialstates(rng, c)
@test LuxCore.parameterlength(c) == LuxCore.parameterlength(ps)
@test LuxCore.statelength(c) == LuxCore.statelength(st)
y, st′ = c(g, x, ps, st)
@test LuxCore.outputsize(c) == (3,)
@test size(y) == (3, 10)
loss = (x, ps) -> sum(first(c(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
end
end
35 changes: 33 additions & 2 deletions GNNLux/test/layers/conv_tests.jl
@@ -1,10 +1,41 @@
@testitem "layers/conv" setup=[SharedTestSetup] begin
rng = StableRNG(1234)
g = rand_graph(10, 30, seed=1234)
g = rand_graph(10, 40, seed=1234)
x = randn(rng, Float32, 3, 10)

@testset "GCNConv" begin
l = GCNConv(3 => 5, relu)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
end

@testset "ChebConv" begin
l = ChebConv(3 => 5, 2, relu)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
end

@testset "GraphConv" begin
l = GraphConv(3 => 5, relu)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@@ -14,6 +45,6 @@
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
end
end