Skip to content

Commit

Permalink
Add lazy kronecker product for matrix kernels, if Kronecker.jl is loa…
Browse files Browse the repository at this point in the history
…ded (#364)

* Restore additions

* Improvements for lazy kron

* Remove unneeded lines

* Small experiment with overwriting

* Reorder and overwrite

* Format and kernelmatrix!

* Reinstate separate method

* Adding tests

* Duplicate code for readability

* Format

* Remove comment and patch bump

* Change kernelmatrix!

* Change to output covariance type

* Change to output covariance type - revert

* Revert "Change to output covariance type - revert"

This reverts commit 09cd20e.

* Revert "Change kernelmatrix!"

This reverts commit f46bd61.

* Add kernelmatrix! changes again

* Change input types for pairwise pullback

* Missing changes to Any
  • Loading branch information
Crown421 authored Sep 23, 2021
1 parent 671f960 commit c76b27d
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.17"
version = "0.10.18"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ include("kernels/neuralkernelnetwork.jl")
include("approximations/nystrom.jl")
include("generic.jl")

include("mokernels/mokernel.jl")
include("mokernels/moinput.jl")
include("mokernels/mokernel.jl")
include("mokernels/independent.jl")
include("mokernels/slfm.jl")
include("mokernels/intrinsiccoregion.jl")
Expand Down
12 changes: 6 additions & 6 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function ChainRulesCore.rrule(
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix; dims=2
)
P = Distances.pairwise(d, X, Y; dims=dims)
function pairwise_pullback(::AbstractMatrix)
function pairwise_pullback(::Any)
return NoTangent(), NoTangent(), ZeroTangent(), ZeroTangent()
end
return P, pairwise_pullback
Expand All @@ -36,7 +36,7 @@ function ChainRulesCore.rrule(
::typeof(Distances.pairwise), d::Delta, X::AbstractMatrix; dims=2
)
P = Distances.pairwise(d, X; dims=dims)
function pairwise_pullback(::AbstractMatrix)
function pairwise_pullback(::Any)
return NoTangent(), NoTangent(), ZeroTangent()
end
return P, pairwise_pullback
Expand All @@ -46,7 +46,7 @@ function ChainRulesCore.rrule(
::typeof(Distances.colwise), d::Delta, X::AbstractMatrix, Y::AbstractMatrix
)
C = Distances.colwise(d, X, Y)
function colwise_pullback(::AbstractVector)
function colwise_pullback(::Any)
return NoTangent(), NoTangent(), ZeroTangent(), ZeroTangent()
end
return C, colwise_pullback
Expand All @@ -70,7 +70,7 @@ function ChainRulesCore.rrule(
dims=2,
)
P = Distances.pairwise(d, X, Y; dims=dims)
function pairwise_pullback_cols::AbstractMatrix)
function pairwise_pullback_cols::Any)
if dims == 1
return NoTangent(), NoTangent(), Δ * Y, Δ' * X
else
Expand All @@ -84,7 +84,7 @@ function ChainRulesCore.rrule(
::typeof(Distances.pairwise), d::DotProduct, X::AbstractMatrix; dims=2
)
P = Distances.pairwise(d, X; dims=dims)
function pairwise_pullback_cols::AbstractMatrix)
function pairwise_pullback_cols::Any)
if dims == 1
return NoTangent(), NoTangent(), 2 * Δ * X
else
Expand All @@ -98,7 +98,7 @@ function ChainRulesCore.rrule(
::typeof(Distances.colwise), d::DotProduct, X::AbstractMatrix, Y::AbstractMatrix
)
C = Distances.colwise(d, X, Y)
function colwise_pullback::AbstractVector)
function colwise_pullback::Any)
return NoTangent(), NoTangent(), Δ' .* Y, Δ' .* X
end
return C, colwise_pullback
Expand Down
40 changes: 40 additions & 0 deletions src/matrix/kernelkroneckermat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using .Kronecker: Kronecker

export kernelkronmat
export kronecker_kernelmatrix

function kernelkronmat::Kernel, X::AbstractVector, dims::Int)
@assert iskroncompatible(κ) "The chosen kernel is not compatible for kroenecker matrices (see [`iskroncompatible`](@ref))"
Expand All @@ -25,3 +26,42 @@ end
k(x,x') = ∏ᵢᴰ k(xᵢ,x'ᵢ)
"""
@inline iskroncompatible::Kernel) = false # Default return for kernels

function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
return Kronecker.kronecker(Kfeatures, Koutputs)
end

function _kernelmatrix_kroneckerjl_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
return Kronecker.kronecker(Koutputs, Kfeatures)
end

function kronecker_kernelmatrix(
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel},
x::IsotopicMOInputsUnion,
y::IsotopicMOInputsUnion,
)
@assert x.out_dim == y.out_dim
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs)
end

function kronecker_kernelmatrix(
k::Union{IndependentMOKernel,IntrinsicCoregionMOKernel}, x::IsotopicMOInputsUnion
)
Kfeatures = kernelmatrix(k.kernel, x.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kroneckerjl_helper(x, Kfeatures, Koutputs)
end

function kronecker_kernelmatrix(
k::MOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
)
return throw(
ArgumentError("This kernel does not support a lazy kronecker kernelmatrix.")
)
end

function kronecker_kernelmatrix(k::MOKernel, x::IsotopicMOInputsUnion)
return kronecker_kernelmatrix(k, x, x)
end
39 changes: 13 additions & 26 deletions src/mokernels/independent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,41 +27,28 @@ function (κ::IndependentMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,I
return κ.kernel(x, y) * (px == py)
end

function _kernelmatrix_kron_helper(::MOInputIsotopicByFeatures, Kfeatures, B)
return kron(Kfeatures, B)
end

function _kernelmatrix_kron_helper(::MOInputIsotopicByOutputs, Kfeatures, B)
return kron(B, Kfeatures)
end
_mo_output_covariance(k::IndependentMOKernel, out_dim) = Eye{Bool}(out_dim)

function kernelmatrix(
k::IndependentMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
k::IndependentMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
)
@assert x.out_dim == y.out_dim
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
mtype = eltype(Kfeatures)
return _kernelmatrix_kron_helper(x, Kfeatures, Eye{mtype}(x.out_dim))
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
end

if VERSION >= v"1.6"
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByFeatures, Kfeatures, B)
return kron!(K, Kfeatures, B)
end

function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByOutputs, Kfeatures, B)
return kron!(K, B, Kfeatures)
end

function kernelmatrix!(
K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
K::AbstractMatrix,
k::IndependentMOKernel,
x::IsotopicMOInputsUnion,
y::IsotopicMOInputsUnion,
)
@assert x.out_dim == y.out_dim
Ktmp = kernelmatrix(k.kernel, x.x, y.x)
mtype = eltype(Ktmp)
return _kernelmatrix_kron_helper!(
K, x, Ktmp, Matrix{mtype}(I, x.out_dim, x.out_dim)
)
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
end
end

Expand Down
22 changes: 16 additions & 6 deletions src/mokernels/intrinsiccoregion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,31 @@ function (k::IntrinsicCoregionMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{
return k.B[px, py] * k.kernel(x, y)
end

function _mo_output_covariance(k::IntrinsicCoregionMOKernel, out_dim)
@assert size(k.B) == (out_dim, out_dim)
return k.B
end

function kernelmatrix(
k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
k::IntrinsicCoregionMOKernel, x::IsotopicMOInputsUnion, y::IsotopicMOInputsUnion
)
@assert x.out_dim == y.out_dim
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
return _kernelmatrix_kron_helper(x, Kfeatures, k.B)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper(x, Kfeatures, Koutputs)
end

if VERSION >= v"1.6"
function kernelmatrix!(
K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
K::AbstractMatrix,
k::IntrinsicCoregionMOKernel,
x::IsotopicMOInputsUnion,
y::IsotopicMOInputsUnion,
)
@assert x.out_dim == y.out_dim
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
return _kernelmatrix_kron_helper!(K, x, Kfeatures, k.B)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper!(K, x, Kfeatures, Koutputs)
end
end

Expand Down
18 changes: 18 additions & 0 deletions src/mokernels/mokernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,21 @@
Abstract type for kernels with multiple outpus.
"""
abstract type MOKernel <: Kernel end

function _kernelmatrix_kron_helper(::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
return kron(Kfeatures, Koutputs)
end

function _kernelmatrix_kron_helper(::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
return kron(Koutputs, Kfeatures)
end

if VERSION >= v"1.6"
function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByFeatures, Kfeatures, Koutputs)
return kron!(K, Kfeatures, Koutputs)
end

function _kernelmatrix_kron_helper!(K, ::MOInputIsotopicByOutputs, Kfeatures, Koutputs)
return kron!(K, Koutputs, Kfeatures)
end
end
46 changes: 46 additions & 0 deletions test/matrix/kernelkroneckermat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,50 @@
@test all(collect(kernelkronmat(k, collect(x), 2)) .≈ kernelmatrix(k, X; obsdim=1))
@test all(collect(kernelkronmat(k, [x, x])) .≈ kernelmatrix(k, X; obsdim=1))
@test_throws AssertionError kernelkronmat(LinearKernel(), collect(x), 2)

@testset "lazy kernelmatrix" begin
rng = MersenneTwister(123)

dims = (in=3, out=2, obs=3)
r = 1

A = randn(dims.out, r)
B = A * transpose(A) + Diagonal(rand(dims.out))

# XIF = [(rand(dims.in), rand(1:(dims.out))) for i in 1:(dims.obs)]
x = [rand(dims.in) for _ in 1:2]
XIF = KernelFunctions.MOInputIsotopicByFeatures(x, dims.out)
XIO = KernelFunctions.MOInputIsotopicByOutputs(x, dims.out)
y = [rand(dims.in) for _ in 1:2]
YIF = KernelFunctions.MOInputIsotopicByFeatures(y, dims.out)
YIO = KernelFunctions.MOInputIsotopicByOutputs(y, dims.out)

skernel = GaussianKernel()
kIndMO = IndependentMOKernel(skernel)

A = randn(dims.out, r)
B = A * transpose(A) + Diagonal(rand(dims.out))
icoregionkernel = IntrinsicCoregionMOKernel(skernel, B)

function test_kronecker_kernelmatrix(k, x)
res = kronecker_kernelmatrix(k, x)
@test typeof(res) <: Kronecker.KroneckerProduct
@test res == kernelmatrix(k, x)
end
function test_kronecker_kernelmatrix(k, x, y)
res = kronecker_kernelmatrix(k, x, y)
@test typeof(res) <: Kronecker.KroneckerProduct
@test res == kernelmatrix(k, x, y)
end

for k in [kIndMO, icoregionkernel], x in [XIF, XIO]
test_kronecker_kernelmatrix(k, x)
end
for k in [kIndMO, icoregionkernel], (x, y) in ([XIF, YIF], [XIO, YIO])
test_kronecker_kernelmatrix(k, x, y)
end

struct TestMOKernel <: MOKernel end
@test_throws ArgumentError kronecker_kernelmatrix(TestMOKernel(), XIF)
end
end

2 comments on commit c76b27d

@willtebbutt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/45409

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.18 -m "<description of version>" c76b27d178d767c304fbf3bcb105252d130ef49f
git push origin v0.10.18

Please sign in to comment.