diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 95e439c81..ca3c5cfe0 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -71,9 +71,10 @@ function gcn_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, norm_fn::F, conv_w end # when we also have edge_weight we need to convert the graph to COO -function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector; kws...) +function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, norm_fn::F, conv_weight::CW) where + {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO - return gcn_conv(l, g, x, edge_weight; kws...) + return gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight) end function cheb_conv(l, g::GNNGraph, X::AbstractMatrix{T}) where {T}