diff --git a/docs/src/api.md b/docs/src/api.md index 10a244e5..3de93747 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -172,6 +172,12 @@ mean_vector We provide standard mean functions like [`ZeroMean`](@ref) and [`ConstMean`](@ref) as well as [`CustomMean`](@ref) to simply wrap a function. +```@docs +ZeroMean +ConstMean +CustomMean +``` + ## Testing Utilities AbstractGPs provides several consistency tests in the `AbstractGPs.TestUtils` module. diff --git a/src/mean_function.jl b/src/mean_function.jl index efb3a954..56ca290d 100644 --- a/src/mean_function.jl +++ b/src/mean_function.jl @@ -1,4 +1,6 @@ """ + abstract type MeanFunction end + `MeanFunction` introduces an API for treating the prior mean function appropriately. On the abstract level, all `MeanFunction` are functions. However we generally want to evaluate them on a collection of inputs. @@ -42,6 +44,20 @@ mean_vector(m::ConstMean, x::AbstractVector) = Fill(m.c, length(x)) A wrapper around whatever unary function you fancy. Must be able to be mapped over an `AbstractVector` of inputs. + +# Warning +`CustomMean` is generally sufficient for testing purposes, but care should be taken if +attempting to differentiate through `mean_vector` with a `CustomMean` when using +`Zygote.jl`. In particular, `mean_vector(m::CustomMean, x)` is implemented as `map(m.f, x)`, +which when `x` is a `ColVecs` or `RowVecs` will not differentiate correctly. + +In such cases, you should implement `mean_vector` directly for your custom mean. +For example, if `f(x) = sum(x)`, you might implement `mean_vector` as +```julia +mean_vector(::CustomMean{typeof(f)}, x::ColVecs) = vec(sum(x.X; dims=1)) +mean_vector(::CustomMean{typeof(f)}, x::RowVecs) = vec(sum(x.X; dims=2)) +``` +which avoids ever applying `map` to a `ColVecs` or `RowVecs`. """ struct CustomMean{Tf} <: MeanFunction f::Tf