Skip to content

Commit

Permalink
fix zygote error
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jan 13, 2025
1 parent 4d50023 commit d7fd26d
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 47 deletions.
2 changes: 1 addition & 1 deletion GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import NearestNeighbors
import NNlib
import StatsBase
import KrylovKit
using ChainRulesCore
import ChainRulesCore as CRC
using LinearAlgebra, Random, Statistics
import MLUtils
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like
Expand Down
11 changes: 6 additions & 5 deletions GNNGraphs/src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Taken from https://github.com/JuliaDiff/ChainRules.jl/pull/648
# Remove when merged

function ChainRulesCore.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
function CRC.rrule(::Type{T}, ps::Pair...) where {T<:Dict}
ks = map(first, ps)
project_ks, project_vs = map(ProjectTo, ks), map(ProjectTolast, ps)
project_ks, project_vs = map(CRC.ProjectTo, ks), map(CRC.ProjectTo last, ps)
function Dict_pullback(ȳ)
dy = CRC.unthunk(ȳ)
dps = map(ks, project_ks, project_vs) do k, proj_k, proj_v
dk, dv = proj_k(getkey(, k, NoTangent())), proj_v(get(, k, NoTangent()))
Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
dk, dv = proj_k(getkey(dy, k, CRC.NoTangent())), proj_v(get(dy, k, CRC.NoTangent()))
CRC.Tangent{Pair{typeof(dk), typeof(dv)}}(first = dk, second = dv)
end
return (NoTangent(), dps...)
return (CRC.NoTangent(), dps...)
end
return T(ps...), Dict_pullback
end
2 changes: 1 addition & 1 deletion GNNGraphs/src/convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function _findnz_idx(A)
return s, t, nz
end

@non_differentiable _findnz_idx(A)
CRC.@non_differentiable _findnz_idx(A)

function to_coo(A::ADJMAT_T; dir = :out, num_nodes = nothing, weighted = true)
s, t, nz = _findnz_idx(A)
Expand Down
4 changes: 2 additions & 2 deletions GNNGraphs/src/gnnheterograph/gnnheterograph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ end

# TODO this is not correct but Zygote cannot differentiate
# through dictionary generation
# @non_differentiable edge_type_subgraph(::Any...)
# CRC.@non_differentiable edge_type_subgraph(::Any...)

function _ntypes_from_edges(edge_ts::AbstractVector{<:EType})
ntypes = Symbol[]
Expand All @@ -285,7 +285,7 @@ function _ntypes_from_edges(edge_ts::AbstractVector{<:EType})
return ntypes
end

@non_differentiable _ntypes_from_edges(::Any...)
CRC.@non_differentiable _ntypes_from_edges(::Any...)

function Base.getindex(g::GNNHeteroGraph, node_t::NType)
return g.ndata[node_t]
Expand Down
62 changes: 36 additions & 26 deletions GNNGraphs/src/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,37 +241,46 @@ function Graphs.adjacency_matrix(g::GNNGraph{<:ADJMAT_T}, T::DataType = eltype(g
return dir == :out ? A : A'
end

function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
dir = :out, weighted = true) where {G <: GNNGraph{<:ADJMAT_T}}
A = adjacency_matrix(g, T; dir, weighted)
if !weighted
function adjacency_matrix_pullback_noweight(Δ)
return (NoTangent(), ZeroTangent(), NoTangent())
return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent())
end
return A, adjacency_matrix_pullback_noweight
else
function adjacency_matrix_pullback_weighted(Δ)
dg = Tangent{G}(; graph = Δ .* binarize(A))
return (NoTangent(), dg, NoTangent())
dy = CRC.unthunk(Δ)
dg = CRC.Tangent{G}(; graph = dy .* binarize(dy))
return (CRC.NoTangent(), dg, CRC.NoTangent())
end
return A, adjacency_matrix_pullback_weighted
end
end

function ChainRulesCore.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
function CRC.rrule(::typeof(adjacency_matrix), g::G, T::DataType;
dir = :out, weighted = true) where {G <: GNNGraph{<:COO_T}}
A = adjacency_matrix(g, T; dir, weighted)
w = get_edge_weight(g)
if !weighted || w === nothing
function adjacency_matrix_pullback_noweight(Δ)
return (NoTangent(), ZeroTangent(), NoTangent())
return (CRC.NoTangent(), CRC.ZeroTangent(), CRC.NoTangent())
end
return A, adjacency_matrix_pullback_noweight
else
function adjacency_matrix_pullback_weighted(Δ)
dy = CRC.unthunk(Δ)
s, t = edge_index(g)
dg = Tangent{G}(; graph = (NoTangent(), NoTangent(), NNlib.gather(Δ, s, t)))
return (NoTangent(), dg, NoTangent())
@show dy s t
#TODO using CRC.@thunk gives an error
#TODO use gather when https://github.com/FluxML/NNlib.jl/issues/625 is fixed
dw = zeros_like(w)
idx = CartesianIndex.(s, t) #TODO remove when https://github.com/FluxML/NNlib.jl/issues/626 is fixed
NNlib.gather!(dw, dy, idx)
@show dw
dg = CRC.Tangent{G}(; graph = (CRC.NoTangent(), CRC.NoTangent(), dw))
return (CRC.NoTangent(), dg, CRC.NoTangent())
end
return A, adjacency_matrix_pullback_weighted
end
Expand Down Expand Up @@ -378,34 +387,35 @@ function _degree(A::AbstractMatrix, T::Type, dir::Symbol, edge_weight::Bool, num
vec(sum(A, dims = 1)) .+ vec(sum(A, dims = 2))
end

function ChainRulesCore.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes)
function CRC.rrule(::typeof(_degree), graph, T, dir, edge_weight::Nothing, num_nodes)
degs = _degree(graph, T, dir, edge_weight, num_nodes)
function _degree_pullback(Δ)
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
return ntuple(i -> (CRC.NoTangent(),), 6)
end
return degs, _degree_pullback
end

function ChainRulesCore.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes)
function CRC.rrule(::typeof(_degree), A::ADJMAT_T, T, dir, edge_weight::Bool, num_nodes)
degs = _degree(A, T, dir, edge_weight, num_nodes)
if edge_weight === false
function _degree_pullback_noweights(Δ)
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent())
return ntuple(i -> (CRC.NoTangent(),), 6)
end
return degs, _degree_pullback_noweights
else
function _degree_pullback_weights(Δ)
dy = CRC.unthunk(Δ)
# We propagate the gradient only to the non-zero elements
# of the adjacency matrix.
bA = binarize(A)
if dir == :in
dA = bA .* Δ'
dA = bA .* dy'
elseif dir == :out
dA = Δ .* bA
dA = dy .* bA
else # dir == :both
dA = Δ .* bA + Δ' .* bA
dA = dy .* bA + dy' .* bA
end
return (NoTangent(), dA, NoTangent(), NoTangent(), NoTangent(), NoTangent())
return (CRC.NoTangent(), dA, CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent())
end
return degs, _degree_pullback_weights
end
Expand Down Expand Up @@ -452,7 +462,7 @@ function normalized_adjacency(g::GNNGraph, T::DataType = Float32;
A = A + I
end
degs = vec(sum(A; dims = 2))
ChainRulesCore.ignore_derivatives() do
CRC.ignore_derivatives() do
@assert all(!iszero, degs) "Graph contains isolated nodes, cannot compute `normalized_adjacency`."
end
inv_sqrtD = Diagonal(inv.(sqrt.(degs)))
Expand Down Expand Up @@ -609,12 +619,12 @@ function laplacian_lambda_max(g::GNNGraph, T::DataType = Float32;
end
end

@non_differentiable edge_index(x...)
@non_differentiable adjacency_list(x...)
@non_differentiable graph_indicator(x...)
@non_differentiable has_multi_edges(x...)
@non_differentiable Graphs.has_self_loops(x...)
@non_differentiable is_bidirected(x...)
@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
@non_differentiable scaled_laplacian(x...) # TODO remove this in the future
CRC.@non_differentiable edge_index(x...)
CRC.@non_differentiable adjacency_list(x...)
CRC.@non_differentiable graph_indicator(x...)
CRC.@non_differentiable has_multi_edges(x...)
CRC.@non_differentiable Graphs.has_self_loops(x...)
CRC.@non_differentiable is_bidirected(x...)
CRC.@non_differentiable normalized_adjacency(x...) # TODO remove this in the future
CRC.@non_differentiable normalized_laplacian(x...) # TODO remove this in the future
CRC.@non_differentiable scaled_laplacian(x...) # TODO remove this in the future
12 changes: 6 additions & 6 deletions GNNGraphs/src/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -808,8 +808,8 @@ function _unbatch_edgemasks(s, t, num_graphs, cumnum_nodes)
return edgemasks
end

@non_differentiable _unbatch_nodemasks(::Any...)
@non_differentiable _unbatch_edgemasks(::Any...)
CRC.@non_differentiable _unbatch_nodemasks(::Any...)
CRC.@non_differentiable _unbatch_edgemasks(::Any...)

"""
getgraph(g::GNNGraph, i; nmap=false)
Expand Down Expand Up @@ -998,10 +998,10 @@ dense_zeros_like(x, sz = size(x)) = dense_zeros_like(x, eltype(x), sz)
# """
ci2t(ci::AbstractVector{<:CartesianIndex}, dims) = ntuple(i -> map(x -> x[i], ci), dims)

@non_differentiable negative_sample(x...)
@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
@non_differentiable dense_zeros_like(x...)
CRC.@non_differentiable negative_sample(x...)
CRC.@non_differentiable add_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
CRC.@non_differentiable remove_self_loops(x...) # TODO this is wrong, since g carries feature arrays, needs rrule
CRC.@non_differentiable dense_zeros_like(x...)

"""
ppr_diffusion(g::GNNGraph{<:COO_T}, alpha =0.85f0) -> GNNGraph
Expand Down
12 changes: 6 additions & 6 deletions GNNGraphs/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,9 +292,9 @@ end

binarize(x) = map(>(0), x)

@non_differentiable binarize(x...)
@non_differentiable edge_encoding(x...)
@non_differentiable edge_decoding(x...)
CRC.@non_differentiable binarize(x...)
CRC.@non_differentiable edge_encoding(x...)
CRC.@non_differentiable edge_decoding(x...)

### PRINTING #####

Expand Down Expand Up @@ -330,11 +330,11 @@ function dims2string(d)
join(map(string, d), '×')
end

@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}})
@non_differentiable normalize_graphdata(::Nothing)
CRC.@non_differentiable normalize_graphdata(::NamedTuple{(), Tuple{}})
CRC.@non_differentiable normalize_graphdata(::Nothing)

iscuarray(x::AbstractArray) = false
@non_differentiable iscuarray(::Any)
CRC.@non_differentiable iscuarray(::Any)


@doc raw"""
Expand Down
1 change: 1 addition & 0 deletions GNNGraphs/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down

0 comments on commit d7fd26d

Please sign in to comment.