[GNNLux] Adding NNConv Layer #478

Status: Closed (not merged). Wants to merge 47 commits; changes shown from 40 commits.

Commits (47, all by rbSparky):
786b200  WIP (Aug 2, 2024)
01767ee  WIP (Aug 2, 2024)
b8c4db6  Update conv.jl (Aug 3, 2024)
b6c1a27  fix (Aug 4, 2024)
fb0bb1d  Update conv.jl (Aug 4, 2024)
32ee61d  fix (Aug 4, 2024)
2585cb6  Merge branch 'nn-lux' of https://github.com/rbSparky/GraphNeuralNetwo… (Aug 4, 2024)
cd28e97  Update conv.jl (Aug 4, 2024)
7f1a07a  Update conv.jl (Aug 4, 2024)
70674a2  added tests (Aug 19, 2024)
17627f1  Merge branch 'nn-lux' of https://github.com/rbSparky/GraphNeuralNetwo… (Aug 19, 2024)
8f081cd  Delete GNNLux/test/layers/temp.jl (Aug 19, 2024)
af0b78b  Merge branch 'CarloLucibello:master' into nn-lux (Aug 19, 2024)
890fcda  add to lux (Aug 19, 2024)
fc2db99  fix test (Aug 19, 2024)
90fc120  fixing (Aug 19, 2024)
0dae0bc  Delete data.txt (Aug 19, 2024)
01ec78b  Delete redundant file (Aug 19, 2024)
ff012bb  trying test fix (Aug 19, 2024)
cf7d30a  trying test fix (Aug 19, 2024)
1c60d1c  Update conv_tests.jl (Aug 19, 2024)
39b9c74  Update basic_tests.jl (Aug 19, 2024)
894bdb3  Update conv_tests.jl (Aug 19, 2024)
6e610c1  Update conv_tests.jl (Aug 19, 2024)
caf355c  Update conv_tests.jl: edata issues (Aug 19, 2024)
3e32261  Update conv_tests.jl (Aug 19, 2024)
24da4c4  Update conv_tests.jl: edata (Aug 19, 2024)
f2ff073  Update conv_tests.jl (Aug 19, 2024)
93affd2  change lux testing (Aug 19, 2024)
3547d9f  Merge branch 'nn-lux' of https://github.com/rbSparky/GraphNeuralNetwo… (Aug 19, 2024)
f0481b4  Update conv_tests.jl: Trying to fix tests (Aug 22, 2024)
b3e2649  Update conv.jl: trying to fix (Aug 23, 2024)
23b89c2  Update conv.jl: reverted (Aug 23, 2024)
d136de0  Merge branch 'master' into nn-lux (Aug 25, 2024)
4b32e2f  fixing (Aug 25, 2024)
6227cd3  Update shared_testsetup.jl: dont make other tests fail (Aug 25, 2024)
b1d185f  Update shared_testsetup.jl: fixing other tests (Aug 25, 2024)
ef68f79  gitignore (Aug 25, 2024)
4f0d60f  ignore (Aug 25, 2024)
e7661f2  remove useless params (Aug 25, 2024)
67bc8fd  Update GNNLux.jl: ordering (Aug 30, 2024)
b94b1f6  Update Project.toml: fixed (Sep 3, 2024)
91fed90  Update conv_tests.jl: checking test (Sep 4, 2024)
a587553  Update conv_tests.jl (Sep 4, 2024)
232a1b4  Update conv_tests.jl (Sep 4, 2024)
e2de74c  checking tests (Sep 7, 2024)
faa4df3  Update conv_tests.jl: typo in test (Sep 12, 2024)
14 changes: 14 additions & 0 deletions GNNLux/Project.toml
@@ -4,15 +4,29 @@ authors = ["Carlo Lucibello and contributors"]
version = "0.1.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ConcreteStructs = "0.2.3"
2 changes: 1 addition & 1 deletion GNNLux/src/GNNLux.jl
@@ -31,8 +31,8 @@ export AGNNConv,
  GINConv,
  # GMMConv,
  GraphConv,
+ NNConv,
  MEGNetConv,
- # NNConv,
  # ResGatedGraphConv,
  # SAGEConv,
  SGConv
40 changes: 39 additions & 1 deletion GNNLux/src/layers/conv.jl
@@ -629,6 +629,44 @@ function Base.show(io::IO, l::GINConv)
print(io, ")")
end

@concrete struct NNConv <: GNNContainerLayer{(:nn,)}
nn <: AbstractExplicitLayer
aggr
in_dims::Int
out_dims::Int
use_bias::Bool
init_weight
init_bias
σ
end

function NNConv(ch::Pair{Int, Int}, nn, σ = identity;
aggr = +,
init_bias = zeros32,
use_bias::Bool = true,
init_weight = glorot_uniform,
allow_fast_activation::Bool = true)
in_dims, out_dims = ch
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
return NNConv(nn, aggr, in_dims, out_dims, use_bias, init_weight, init_bias, σ)
end

function (l::NNConv)(g, x, edge_weight, ps, st)
nn = StatefulLuxLayer{true}(l.nn, ps, st)

m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), l.σ)
y = GNNlib.nn_conv(m, g, x, edge_weight)
stnew = _getstate(nn)
return y, stnew
end

function Base.show(io::IO, l::NNConv)
print(io, "NNConv($(l.nn)")
l.σ == identity || print(io, ", ", l.σ)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

@concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)}
in_dims::Int
out_dims::Int
@@ -669,4 +707,4 @@ function Base.show(io::IO, l::MEGNetConv)
nout = l.out_dims
print(io, "MEGNetConv(", nin, " => ", nout)
print(io, ")")
-end
\ No newline at end of file
+end
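For context, here is a minimal usage sketch of the new Lux-style NNConv, based on the constructor and call signature added in this diff; the graph, dimensions, and activation are illustrative and mirror the tests below, not part of the PR itself:

    using GNNLux, GNNGraphs, Lux, LuxCore, StableRNGs

    rng = StableRNG(1234)
    in_dims, out_dims, edim = 3, 5, 10

    g = rand_graph(rng, 10, 40)                    # 10 nodes, 40 edges
    x = randn(rng, Float32, in_dims, g.num_nodes)  # node features
    e = randn(rng, Float32, edim, g.num_edges)     # edge features

    # The inner network maps each edge's features to in_dims * out_dims values,
    # which NNConv uses as a per-edge weight matrix for the neighbor features.
    nn = Dense(edim, in_dims * out_dims)
    l = NNConv(in_dims => out_dims, nn, tanh; aggr = +)

    ps, st = LuxCore.setup(rng, l)
    y, st′ = l(g, x, e, ps, st)                    # size(y) == (out_dims, g.num_nodes)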
10 changes: 10 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
@@ -1,10 +1,13 @@
@testitem "layers/conv" setup=[SharedTestSetup] begin
rng = StableRNG(1234)
edim = 10
g = rand_graph(rng, 10, 40)
in_dims = 3
out_dims = 5
x = randn(rng, Float32, in_dims, 10)

g2 = GNNGraph(g, edata = rand(Float32, edim, g.num_edges))

@testset "GCNConv" begin
l = GCNConv(in_dims => out_dims, tanh)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
@@ -94,6 +97,13 @@
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
end

@testset "NNConv" begin
edim = 10
nn = Dense(edim, in_dims * out_dims)
l = NNConv(in_dims => out_dims, nn, tanh, aggr = +)
test_lux_layer(rng, l, g2, x, sizey=(out_dims, g2.num_nodes), container=true, edge_weight=g2.edata.e)
end

@testset "MEGNetConv" begin
l = MEGNetConv(in_dims => out_dims)

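A note on the Dense(edim, in_dims * out_dims) sizing in the NNConv test above: NNConv (edge-conditioned convolution) turns each edge's features into a weight matrix applied to the neighbor's node features, so the inner network must emit in_dims * out_dims values per edge. A standalone shape sketch, assuming the usual (out_dims, in_dims) reshape convention; the PR itself delegates this computation to GNNlib.nn_conv:

    using Lux, LuxCore, Random

    edim, in_dims, out_dims = 10, 3, 5
    nn = Dense(edim, in_dims * out_dims)
    ps, st = LuxCore.setup(Random.default_rng(), nn)

    e_ij = randn(Float32, edim)            # features of one edge j -> i
    x_j  = randn(Float32, in_dims)         # features of source node j

    w, _ = nn(e_ij, ps, st)                # out_dims * in_dims values
    W    = reshape(w, out_dims, in_dims)   # per-edge weight matrix
    m    = W * x_j                         # message to node i, length out_dims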
13 changes: 9 additions & 4 deletions GNNLux/test/shared_testsetup.jl
@@ -14,7 +14,7 @@ export test_lux_layer

function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
                        outputsize=nothing, sizey=nothing, container=false,
-                       atol=1.0f-2, rtol=1.0f-2)
+                       atol=1.0f-2, rtol=1.0f-2, edge_weight=nothing)

if container
@test l isa GNNContainerLayer
@@ -26,8 +26,13 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
    st = LuxCore.initialstates(rng, l)
    @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
    @test LuxCore.statelength(l) == LuxCore.statelength(st)

-   y, st′ = l(g, x, ps, st)
+   if edge_weight !== nothing
+       y, st′ = l(g, x, edge_weight, ps, st)
+   else
+       y, st′ = l(g, x, ps, st)
+   end

    @test eltype(y) == eltype(x)
    if outputsize !== nothing
        @test LuxCore.outputsize(l) == outputsize
@@ -42,4 +47,4 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

-end
\ No newline at end of file
+end
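Since the new edge_weight keyword defaults to nothing, existing call sites are unchanged and only layers that consume edge features opt in. Illustrative calls, with names and sizes as in the tests above:

    # node-features-only layer: call is unchanged
    test_lux_layer(rng, GCNConv(in_dims => out_dims, tanh), g, x, outputsize=(out_dims,))

    # edge-feature layer: route the edge data through the new keyword
    test_lux_layer(rng, l, g2, x, sizey=(out_dims, g2.num_nodes),
                   container=true, edge_weight=g2.edata.e)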