create GNNLux.jl package #460

Merged · 3 commits · Jul 26, 2024
48 changes: 48 additions & 0 deletions .github/workflows/test_GNNLux.yml
@@ -0,0 +1,48 @@
name: GNNLux
on:
  pull_request:
    branches:
      - master
  push:
    branches:
      - master
jobs:
  test:
    name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        version:
          - '1.10' # Replace this with the minimum Julia version that your package supports.
          # - '1'  # '1' will automatically expand to the latest stable 1.x release of Julia.
          # - 'pre'
        os:
          - ubuntu-latest
        arch:
          - x64

    steps:
      - uses: actions/checkout@v4
      - uses: julia-actions/setup-julia@v2
        with:
          version: ${{ matrix.version }}
          arch: ${{ matrix.arch }}
      - uses: julia-actions/cache@v2
      - uses: julia-actions/julia-buildpkg@v1
      - name: Install Julia dependencies and run tests
        shell: julia --project=monorepo {0}
        run: |
          using Pkg
          # dev monorepo versions
          pkg"registry up"
          Pkg.update()
          pkg"dev ./GNNGraphs ./GNNlib ./GNNLux"
          Pkg.test("GNNLux"; coverage=true)
      - uses: julia-actions/julia-processcoverage@v1
        with:
          # directories: ./GNNLux/src, ./GNNLux/ext
          directories: ./GNNLux/src
      - uses: codecov/codecov-action@v4
        with:
          files: lcov.info
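For reference, a rough local equivalent of the test step above, assuming it is run from the monorepo root (the `monorepo` project name mirrors the workflow's `--project=monorepo` flag):

```julia
using Pkg
Pkg.activate("monorepo")                  # same throwaway project the CI step uses
pkg"registry up"
Pkg.update()
pkg"dev ./GNNGraphs ./GNNlib ./GNNLux"    # dev the local monorepo packages
Pkg.test("GNNLux"; coverage = true)
```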
21 changes: 21 additions & 0 deletions GNNLux/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Carlo Lucibello <carlo.lucibello@gmail.com> and contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
35 changes: 35 additions & 0 deletions GNNLux/Project.toml
@@ -0,0 +1,35 @@
name = "GNNLux"
uuid = "e8545f4d-a905-48ac-a8c4-ca114b98986d"
authors = ["Carlo Lucibello and contributors"]
version = "0.1.0"

[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ConcreteStructs = "0.2.3"
Lux = "0.5.61"
LuxCore = "0.1.20"
NNlib = "0.9.21"
Reexport = "1.2"
julia = "1.10"

[extras]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "ComponentArrays", "Functors", "LuxTestUtils", "ReTestItems", "StableRNGs", "Zygote"]
2 changes: 2 additions & 0 deletions GNNLux/README.md
@@ -0,0 +1,2 @@
# GNNLux.jl

15 changes: 15 additions & 0 deletions GNNLux/src/GNNLux.jl
@@ -0,0 +1,15 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib
using LuxCore: LuxCore, AbstractExplicitLayer
using Lux: glorot_uniform, zeros32
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
@reexport using GNNGraphs

include("layers/conv.jl")
export GraphConv

end #module

93 changes: 93 additions & 0 deletions GNNLux/src/layers/conv.jl
@@ -0,0 +1,93 @@

@doc raw"""
    GraphConv(in => out, σ = identity; aggr = +, init_weight = glorot_uniform, init_bias = zeros32, use_bias = true)

Graph convolution layer from the paper [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244).

Performs:
```math
\mathbf{x}_i' = W_1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W_2 \mathbf{x}_j
```

where the aggregation type is selected by `aggr`.

# Arguments

- `in`: The dimension of input features.
- `out`: The dimension of output features.
- `σ`: Activation function.
- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`).
- `use_bias`: Add a learnable bias. Default `true`.
- `init_weight`: Weights' initializer. Default `glorot_uniform`.
- `init_bias`: Bias initializer. Default `zeros32`.

# Examples

```julia
# create data
s = [1, 1, 2, 3]
t = [2, 3, 1, 1]
in_channel = 3
out_channel = 5
g = GNNGraph(s, t)
x = randn(Float32, in_channel, g.num_nodes)

# create layer
l = GraphConv(in_channel => out_channel, relu, use_bias = false, aggr = mean)

# setup layer
rng = Random.default_rng()
ps, st = LuxCore.setup(rng, l)

# forward pass
y, st = l(g, x, ps, st)
```
"""
@concrete struct GraphConv <: AbstractExplicitLayer
    in_dims::Int
    out_dims::Int
    use_bias::Bool
    init_weight::Function
    init_bias::Function
    σ
    aggr
end


function GraphConv(ch::Pair{Int, Int}, σ = identity;
                   aggr = +,
                   init_weight = glorot_uniform,
                   init_bias = zeros32,
                   use_bias::Bool = true,
                   allow_fast_activation::Bool = true)
    in_dims, out_dims = ch
    σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
    return GraphConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::GraphConv)
    weight1 = l.init_weight(rng, l.out_dims, l.in_dims)
    weight2 = l.init_weight(rng, l.out_dims, l.in_dims)
    if l.use_bias
        bias = l.init_bias(rng, l.out_dims)
    else
        bias = false  # sentinel: no trainable bias when use_bias = false
    end
    return (; weight1, weight2, bias)
end

function LuxCore.parameterlength(l::GraphConv)
    if l.use_bias
        return 2 * l.in_dims * l.out_dims + l.out_dims
    else
        return 2 * l.in_dims * l.out_dims
    end
end

LuxCore.statelength(l::GraphConv) = 0
LuxCore.outputsize(l::GraphConv) = (l.out_dims,)

function Base.show(io::IO, l::GraphConv)
    print(io, "GraphConv(", l.in_dims, " => ", l.out_dims)
    (l.σ == identity) || print(io, ", ", l.σ)
    (l.aggr == +) || print(io, ", aggr=", l.aggr)
    l.use_bias || print(io, ", use_bias=false")
    print(io, ")")
end

(l::GraphConv)(g::GNNGraph, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st
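A minimal usage sketch for the new layer (not part of the diff; names like `rng` are illustrative). `Lux.setup` bundles `initialparameters` and `initialstates`, and the parameter count follows the formula in `parameterlength`:

```julia
using GNNLux, Lux, Random

rng = Random.default_rng()
g = rand_graph(10, 30)          # GNNGraph with 10 nodes and 30 edges (from GNNGraphs)
x = randn(Float32, 3, 10)       # 3 input features per node

l = GraphConv(3 => 5, relu)
ps, st = Lux.setup(rng, l)      # explicit parameters and (empty) state
y, st = l(g, x, ps, st)         # forward pass returns (output, state)

@assert size(y) == (5, 10)
@assert Lux.parameterlength(l) == 2 * 3 * 5 + 5  # weight1 + weight2 + bias
```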
19 changes: 19 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
@@ -0,0 +1,19 @@
@testitem "layers/conv" setup=[SharedTestSetup] begin
rng = StableRNG(1234)
g = rand_graph(10, 30, seed=1234)
x = randn(rng, Float32, 3, 10)

@testset "GraphConv" begin
l = GraphConv(3 => 5, relu)
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test Lux.outputsize(l) == (5,)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3
end
end
10 changes: 10 additions & 0 deletions GNNLux/test/runtests.jl
@@ -0,0 +1,10 @@
using Test
using Lux
using GNNLux
using Random, Statistics

using ReTestItems
# using Pkg, Preferences, Test
# using InteractiveUtils, Hwloc

runtests(GNNLux)
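A usage note, not part of the diff: `ReTestItems.runtests` also accepts a `name` filter, which is convenient when iterating on a single layer:

```julia
using ReTestItems, GNNLux
runtests(GNNLux; name = "layers/conv")  # run only the @testitem of that name
```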
23 changes: 23 additions & 0 deletions GNNLux/test/shared_testsetup.jl
@@ -0,0 +1,23 @@
@testsetup module SharedTestSetup

import Reexport: @reexport

@reexport using Lux, Functors
@reexport using ComponentArrays, LuxCore, LuxTestUtils, Random, StableRNGs, Test,
                Zygote, Statistics
@reexport using LuxTestUtils: @jet, @test_gradients, check_approx

# Some Helper Functions
function get_default_rng(mode::String)
    dev = mode == "cpu" ? LuxCPUDevice() :
          mode == "cuda" ? LuxCUDADevice() :
          mode == "amdgpu" ? LuxAMDGPUDevice() : nothing
    rng = default_device_rng(dev)
    return rng isa TaskLocalRNG ? copy(rng) : deepcopy(rng)
end

export get_default_rng

# export BACKEND_GROUP, MODES, cpu_testing, cuda_testing, amdgpu_testing, get_default_rng,
#        StableRNG, maybe_rewrite_to_crosscor

end
9 changes: 6 additions & 3 deletions GNNlib/src/layers/conv.jl
@@ -90,12 +90,15 @@ function cheb_conv(c, g::GNNGraph, X::AbstractMatrix{T}) where {T}
     return Y .+ c.bias
 end
 
-function graph_conv(l, g::AbstractGNNGraph, x)
+function graph_conv(l, g::AbstractGNNGraph, x, ps)
     check_num_nodes(g, x)
     xj, xi = expand_srcdst(g, x)
     m = propagate(copy_xj, g, l.aggr, xj = xj)
-    x = l.σ.(l.weight1 * xi .+ l.weight2 * m .+ l.bias)
-    return x
+    x = ps.weight1 * xi .+ ps.weight2 * m
+    if l.use_bias
+        x = x .+ ps.bias
+    end
+    return l.σ.(x)
 end
 
 function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing)
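A dense sanity check for the refactored `graph_conv` (a sketch, assuming `aggr = +`; the small graph and features are made up for illustration): with sum aggregation, `propagate(copy_xj, g, +, xj = xj)` reduces to a product with the adjacency matrix.

```julia
using GNNGraphs, GNNlib

g = rand_graph(6, 12)
x = randn(Float32, 3, 6)

A = adjacency_matrix(g)       # A[s, t] == 1 for every edge s -> t
m_dense = x * A               # column i sums x[:, j] over in-neighbors j -> i
m = GNNlib.propagate(GNNlib.copy_xj, g, +; xj = x)
@assert m_dense ≈ m
```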