Skip to content

Commit

Permalink
[GNNLux] fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jul 29, 2024
1 parent fb394d1 commit f56bfa0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 89 deletions.
15 changes: 5 additions & 10 deletions GNNLux/test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,13 @@
@test GNNLayer <: LuxCore.AbstractExplicitLayer
end

@testset "GNNContainerLayer" begin
@test GNNContainerLayer <: LuxCore.AbstractExplicitContainerLayer
end

@testset "GNNChain" begin
@test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)}
@test GNNChain <: GNNContainerLayer
c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3))
ps = LuxCore.initialparameters(rng, c)
st = LuxCore.initialstates(rng, c)
@test LuxCore.parameterlength(c) == LuxCore.parameterlength(ps)
@test LuxCore.statelength(c) == LuxCore.statelength(st)
y, st′ = c(g, x, ps, st)
@test LuxCore.outputsize(c) == (3,)
@test size(y) == (3, 10)
loss = (x, ps) -> sum(first(c(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
test_lux_layer(rng, c, g, x, outputsize=(3,), container=true)
end
end
71 changes: 7 additions & 64 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,89 +5,32 @@

@testset "GCNConv" begin
l = GCNConv(3 => 5, relu)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
test_lux_layer(rng, l, g, x, outputsize=(5,))
end

@testset "ChebConv" begin
l = ChebConv(3 => 5, 2)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
test_lux_layer(rng, l, g, x, outputsize=(5,))
end

@testset "GraphConv" begin
l = GraphConv(3 => 5, relu)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
test_lux_layer(rng, l, g, x, outputsize=(5,))
end

@testset "AGNNConv" begin
l = AGNNConv(init_beta=1.0f0)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(ps) == 1
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test size(y) == size(x)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
test_lux_layer(rng, l, g, x, sizey=(3,10))
end

@testset "EdgeConv" begin
nn = Chain(Dense(6 => 5, relu), Dense(5 => 5))
l = EdgeConv(nn, aggr = +)
@test l isa GNNContainerLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)
y, st′ = l(g, x, ps, st)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
test_lux_layer(rng, l, g, x, sizey=(5,10), container=true)
end

@testset "CGConv" begin
l = CGConv(3 => 5, residual = true)
@test l isa GNNContainerLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)
y, st′ = l(g, x, ps, st)
@test size(y) == (5, 10)
@test Lux.outputsize(l) == (5,)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
l = CGConv(3 => 3, residual = true)
test_lux_layer(rng, l, g, x, outputsize=(3,), container=true)
end
end
50 changes: 35 additions & 15 deletions GNNLux/test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,43 @@

import Reexport: @reexport

@reexport using Test
@reexport using GNNLux
@reexport using Lux, Functors
@reexport using ComponentArrays, LuxCore, LuxTestUtils, Random, StableRNGs, Test,
Zygote, Statistics
@reexport using LuxTestUtils: @jet, @test_gradients, check_approx
using MLDataDevices

# Some Helper Functions
function get_default_rng(mode::String)
dev = mode == "cpu" ? CPUDevice() :
mode == "cuda" ? CUDADevice() : mode == "amdgpu" ? AMDGPUDevice() : nothing
rng = default_device_rng(dev)
return rng isa TaskLocalRNG ? copy(rng) : deepcopy(rng)
end
@reexport using Lux
@reexport using StableRNGs
@reexport using Random, Statistics

using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme

export test_lux_layer

export get_default_rng
function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
outputsize=nothing, sizey=nothing, container=false,
atol=1.0f-2, rtol=1.0f-2)

# export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing
if container
@test l isa GNNContainerLayer
else
@test l isa GNNLayer
end

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
@test LuxCore.statelength(l) == LuxCore.statelength(st)

y, st′ = l(g, x, ps, st)
if outputsize !== nothing
@test LuxCore.outputsize(l) == outputsize
end
if sizey !== nothing
@test size(y) == sizey
elseif outputsize !== nothing
@test size(y) == (outputsize..., g.num_nodes)
end

loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

end

0 comments on commit f56bfa0

Please sign in to comment.