Skip to content

Commit

Permalink
update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
avibryant committed Mar 23, 2020
1 parent f67872a commit 6e45737
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 12 deletions.
6 changes: 6 additions & 0 deletions docs/jupyter.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ import com.stripe.rainier.core._
import com.stripe.rainier.notebook._
```

Finally, you can register a custom pretty-printer for better notebook output of Rainier objects:

```scala
PrettyPrint.register(repl)
```

## Using a standard Almond kernel

If you are not using the custom kernel installer, make sure to use the `Scala 2.12` kernel, and add the following to the top of your notebook in its own cell:
Expand Down
4 changes: 2 additions & 2 deletions docs/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ Strips away the observations, for ease of checking prior predictions.

Combines two models.

* `sample(sampler: Sampler, nChains: Int = 4): Trace`
* `sample(config: SamplerConfig, nChains: Int = 4): Trace`

Run inference using the provided sampler.
Run inference using the provided sampler configuration.

* `optimize[T](value: Generator[T]): T`

Expand Down
39 changes: 37 additions & 2 deletions docs/samplers.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,44 @@ id: samplers
title: Samplers
---

These are found in `com.stripe.rainier.sampler`, and extend `Sampler`.
Samplers and related classes are found in `com.stripe.rainier.sampler`.

## HMC
## SamplerConfig

Calls to `sample()` can optionally provide a custom configuration object that implements the following trait:

```scala
trait SamplerConfig {
def iterations: Int
def warmupIterations: Int
def statsWindow: Int

def stepSizeTuner(): StepSizeTuner
def massMatrixTuner(): MassMatrixTuner
def sampler(): Sampler
}
```

Your configuration can directly implement every method of this trait, or extend `DefaultConfiguration` and selectively override.

* `iterations` is the number of samples per chain to keep, after the warmup period (default: 1000)
* `warmupIterations` is the number of samples per chain to use for warmup (default: 1000)
* `statsWindow` is the number of training samples to keep for diagnostics like average acceptance rate (default: 100)
* `stepSizeTuner()` should create a new `StepSizeTuner` object. This can be:
* `new DualAvgTuner(targetAcceptRate)` (the default is this with 0.8)
* `StaticStepSize(stepSize)`
* `massMatrixTuner()` should create a new `MassMatrixTuner` object. This can be:
* `new IdentityMassMatrix`
* `new DiagonalMassMatrixTuner(initialWindowSize, windowExpansion, skipFirst, skipLast)` (the default is `new DiagonalMassMatrixTuner(50, 1.5, 50, 50)`)
* `new DenseMassMatrixTuner(initialWindowSize, windowExpansion, skipFirst, skipLast)`
* `sampler()` should create a new `Sampler` object. This can be:
* `new HMCSampler(nSteps)`
* `new EHMCSampler(maxSteps, minSteps)` (the default is `new EHMCSampler(1024, 1)`)


For backwards compatibility with `0.3.0`, you can also build configs using the old `HMC` and `EHMC` constructors below.

## HMC
Hamiltonian Monte Carlo with dual averaging as per [Gelman & Hoffman](http://www.stat.columbia.edu/~gelman/research/published/nuts.pdf)

`HMC(warmupIterations: Int, iterations: Int, nSteps: Int)`
Expand All @@ -25,3 +59,4 @@ Empirical HMC as per [Wu et al](https://arxiv.org/pdf/1810.04449.pdf)
* `iterations` produce usable samples
* `k` is the number of iterations used to build an empirical distribution of steps to U-turn
* `l0` is the number of leap-frog steps used during this empirical phase

Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,18 @@ trait WindowedMassMatrixTuner extends MassMatrixTuner {
}
}

case class DiagonalMassMatrixTuner(initialWindowSize: Int,
windowExpansion: Double,
skipFirst: Int,
skipLast: Int)
class DiagonalMassMatrixTuner(val initialWindowSize: Int,
val windowExpansion: Double,
val skipFirst: Int,
val skipLast: Int)
extends WindowedMassMatrixTuner {
def initializeEstimator(size: Int) = new VarianceEstimator(size)
}

case class DenseMassMatrixTuner(initialWindowSize: Int,
windowExpansion: Double,
skipFirst: Int,
skipLast: Int)
class DenseMassMatrixTuner(val initialWindowSize: Int,
val windowExpansion: Double,
val skipFirst: Int,
val skipLast: Int)
extends WindowedMassMatrixTuner {
def initializeEstimator(size: Int) = new CovarianceEstimator(size)
}

0 comments on commit 6e45737

Please sign in to comment.