-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[GNNLux] GCNConv, ChebConv, GNNChain (#462)
* add gcconv and chebconv * gnn chain
- Loading branch information
1 parent
79515e9
commit 3b42087
Showing
8 changed files
with
281 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,22 @@ | ||
module GNNLux | ||
using ConcreteStructs: @concrete | ||
using NNlib: NNlib | ||
using LuxCore: LuxCore, AbstractExplicitLayer | ||
using Lux: glorot_uniform, zeros32 | ||
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer | ||
using Lux: Lux, glorot_uniform, zeros32 | ||
using Reexport: @reexport | ||
using Random: AbstractRNG | ||
using GNNlib: GNNlib | ||
@reexport using GNNGraphs | ||
|
||
include("layers/basic.jl") | ||
export GNNLayer, | ||
GNNContainerLayer, | ||
GNNChain | ||
|
||
include("layers/conv.jl") | ||
export GraphConv | ||
export GCNConv, | ||
ChebConv, | ||
GraphConv | ||
|
||
end #module | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
""" | ||
abstract type GNNLayer <: AbstractExplicitLayer end | ||
An abstract type from which graph neural network layers are derived. | ||
It is Derived from Lux's `AbstractExplicitLayer` type. | ||
See also [`GNNChain`](@ref GNNLux.GNNChain). | ||
""" | ||
abstract type GNNLayer <: AbstractExplicitLayer end | ||
|
||
abstract type GNNContainerLayer{T} <: AbstractExplicitContainerLayer{T} end | ||
|
||
@concrete struct GNNChain <: GNNContainerLayer{(:layers,)} | ||
layers <: NamedTuple | ||
end | ||
|
||
GNNChain(xs...) = GNNChain(; (Symbol("layer_", i) => x for (i, x) in enumerate(xs))...) | ||
|
||
function GNNChain(; kw...) | ||
:layers in Base.keys(kw) && | ||
throw(ArgumentError("a GNNChain cannot have a named layer called `layers`")) | ||
nt = NamedTuple{keys(kw)}(values(kw)) | ||
nt = map(_wrapforchain, nt) | ||
return GNNChain(nt) | ||
end | ||
|
||
_wrapforchain(l::AbstractExplicitLayer) = l | ||
_wrapforchain(l) = Lux.WrappedFunction(l) | ||
|
||
Base.keys(c::GNNChain) = Base.keys(getfield(c, :layers)) | ||
Base.getindex(c::GNNChain, i::Int) = c.layers[i] | ||
Base.getindex(c::GNNChain, i::AbstractVector) = GNNChain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i])) | ||
|
||
function Base.getproperty(c::GNNChain, name::Symbol) | ||
hasfield(typeof(c), name) && return getfield(c, name) | ||
layers = getfield(c, :layers) | ||
hasfield(typeof(layers), name) && return getfield(layers, name) | ||
throw(ArgumentError("$(typeof(c)) has no field or layer $name")) | ||
end | ||
|
||
Base.length(c::GNNChain) = length(c.layers) | ||
Base.lastindex(c::GNNChain) = lastindex(c.layers) | ||
Base.firstindex(c::GNNChain) = firstindex(c.layers) | ||
|
||
LuxCore.outputsize(c::GNNChain) = LuxCore.outputsize(c.layers[end]) | ||
|
||
(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps, st) | ||
|
||
function _applychain(layers, g::GNNGraph, x, ps, st) # type-unstable path, helps compile times | ||
newst = (;) | ||
for (name, l) in pairs(layers) | ||
x, s′ = _applylayer(l, g, x, getproperty(ps, name), getproperty(st, name)) | ||
newst = merge(newst, (; name => s′)) | ||
end | ||
return x, newst | ||
end | ||
|
||
_applylayer(l, g::GNNGraph, x, ps, st) = l(x), (;) | ||
_applylayer(l::AbstractExplicitLayer, g::GNNGraph, x, ps, st) = l(x, ps, st) | ||
_applylayer(l::GNNLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st) | ||
_applylayer(l::GNNContainerLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
@testitem "layers/basic" setup=[SharedTestSetup] begin | ||
rng = StableRNG(17) | ||
g = rand_graph(10, 40, seed=17) | ||
x = randn(rng, Float32, 3, 10) | ||
|
||
@testset "GNNLayer" begin | ||
@test GNNLayer <: LuxCore.AbstractExplicitLayer | ||
end | ||
|
||
@testset "GNNChain" begin | ||
@test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)} | ||
@test GNNChain <: GNNContainerLayer | ||
c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3)) | ||
ps = LuxCore.initialparameters(rng, c) | ||
st = LuxCore.initialstates(rng, c) | ||
@test LuxCore.parameterlength(c) == LuxCore.parameterlength(ps) | ||
@test LuxCore.statelength(c) == LuxCore.statelength(st) | ||
y, st′ = c(g, x, ps, st) | ||
@test LuxCore.outputsize(c) == (3,) | ||
@test size(y) == (3, 10) | ||
loss = (x, ps) -> sum(first(c(g, x, ps, st))) | ||
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.