Skip to content

Commit

Permalink
Docs improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Sep 14, 2023
1 parent 1b7c135 commit 463e49f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ mean_vector
We provide standard mean functions like [`ZeroMean`](@ref) and [`ConstMean`](@ref)
as well as [`CustomMean`](@ref) to simply wrap a function.

```@docs
ZeroMean
ConstMean
CustomMean
```

## Testing Utilities

AbstractGPs provides several consistency tests in the `AbstractGPs.TestUtils` module.
Expand Down
16 changes: 16 additions & 0 deletions src/mean_function.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""
abstract type MeanFunction end
`MeanFunction` introduces an API for treating the prior mean function appropriately.
On the abstract level, all `MeanFunction` are functions.
However we generally want to evaluate them on a collection of inputs.
Expand Down Expand Up @@ -42,6 +44,20 @@ mean_vector(m::ConstMean, x::AbstractVector) = Fill(m.c, length(x))
A wrapper around whatever unary function you fancy. Must be able to be mapped over an
`AbstractVector` of inputs.
# Warning
`CustomMean` is generally sufficient for testing purposes, but care should be taken if
attempting to differentiate through `mean_vector` with a `CustomMean` when using
`Zygote.jl`. In particular, `mean_vector(m::CustomMean, x)` is implemented as `map(m.f, x)`,
which when `x` is a `ColVecs` or `RowVecs` will not differentiate correctly.
In such cases, you should implement `mean_vector` directly for your custom mean.
For example, if `f(x) = sum(x)`, you might implement `mean_vector` as
```julia
mean_vector(::CustomMean{typeof(f)}, x::ColVecs) = vec(sum(x.X; dims=1))
mean_vector(::CustomMean{typeof(f)}, x::RowVecs) = vec(sum(x.X; dims=2))
```
which avoids ever applying `map` to a `ColVecs` or `RowVecs`.
"""
struct CustomMean{Tf} <: MeanFunction
f::Tf
Expand Down

0 comments on commit 463e49f

Please sign in to comment.