---
output: github_document
---
<!-- README.md is generated from README.Rmd. Please edit that file -->
```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.path = "man/figures/README-",
out.width = "100%"
)
```
# StanEstimators
<!-- badges: start -->
[![CRAN status](https://www.r-pkg.org/badges/version/StanEstimators)](https://CRAN.R-project.org/package=StanEstimators)
[![R-CMD-check](https://github.com/andrjohns/StanEstimators/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/andrjohns/StanEstimators/actions/workflows/R-CMD-check.yaml)
[![StanEstimators status badge](https://andrjohns.r-universe.dev/badges/StanEstimators)](https://andrjohns.r-universe.dev/StanEstimators)
<!-- badges: end -->
The `StanEstimators` package provides an estimation back-end for R functions,
similar to that provided by the base R `optim()` function, using the algorithms
provided by the Stan probabilistic programming language.
As Stan's algorithms are gradient-based, function gradients can be automatically
calculated using finite-differencing or the user can provide a function for
analytical calculation.
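To illustrate what finite-differencing involves, here is a minimal base R sketch of a central-difference gradient. The helper `fd_gradient()`, the step size `h`, and the toy function `toy_fn()` are hypothetical names chosen for this example; the scheme used internally by the package may differ.

```{r}
# Central finite-difference approximation to the gradient of a scalar function.
# `fd_gradient` and the default step `h` are illustrative assumptions only.
fd_gradient <- function(fn, v, ..., h = 1e-6) {
  g <- numeric(length(v))
  for (i in seq_along(v)) {
    step <- numeric(length(v))
    step[i] <- h
    # Perturb one parameter at a time, in both directions
    g[i] <- (fn(v + step, ...) - fn(v - step, ...)) / (2 * h)
  }
  g
}

# Toy function with a known gradient: f(v) = -sum((v - c(1, 2))^2),
# so the exact gradient is -2 * (v - c(1, 2))
toy_fn <- function(v) -sum((v - c(1, 2))^2)

fd_gradient(toy_fn, c(0, 0))
-2 * (c(0, 0) - c(1, 2))
```

Each gradient evaluation in this sketch costs two function calls per parameter, which is why an analytical gradient can be noticeably cheaper.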
## Installation
You can install pre-built binaries using:
```{r, eval=FALSE}
# we recommend running this in a fresh R session or restarting your current session
install.packages('StanEstimators', repos = c('https://andrjohns.r-universe.dev', 'https://cloud.r-project.org'))
```
Or you can build from source using:
```{r, eval=FALSE}
# install.packages("remotes")
remotes::install_github("andrjohns/StanEstimators")
```
## Usage
Consider the goal of estimating the mean and standard deviation of a normal
distribution, with uniform uninformative priors on both parameters:
$$
y \sim \textbf{N}(\mu, \sigma)
$$
$$
\mu \sim \textbf{U}[-\infty, \infty]
$$
$$
\sigma \sim \textbf{U}[0, \infty]
$$
We simulate data with known true values (mean 10, standard deviation 2) for verification:
```{r}
y <- rnorm(500, 10, 2)
```
As with other estimation routines provided in R, we need to specify this as a
function which takes a vector of parameters as its first argument and returns a
single scalar value (the log-likelihood), as well as initial values for the
parameters:
```{r}
loglik_fun <- function(v, x) {
  sum(dnorm(x, v[1], v[2], log = TRUE))
}
inits <- c(0, 5)
```
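For comparison, the same calling convention (parameters first, data passed as an extra argument) is what base R's `optim()` expects. The sketch below uses a small hypothetical data set `y_demo` and hypothetical object names (`loglik_demo`, `fit_optim`), so it can be run without `StanEstimators` and does not touch the objects defined above.

```{r}
# Hypothetical toy data, used only for this illustration
y_demo <- c(8.1, 9.4, 10.2, 11.7, 12.3, 9.9, 10.6, 7.8)

# Same convention as above: parameter vector first, data second
loglik_demo <- function(v, x) {
  sum(dnorm(x, v[1], v[2], log = TRUE))
}

# optim() minimises by default, so fnscale = -1 turns it into a maximiser.
# L-BFGS-B accepts box constraints, which keeps the SD strictly positive.
fit_optim <- optim(par = c(0, 5), fn = loglik_demo, x = y_demo,
                   method = "L-BFGS-B",
                   lower = c(-Inf, 1e-6),
                   control = list(fnscale = -1))

# Closed-form maximum-likelihood estimates for comparison
mle_demo <- c(mean(y_demo), sqrt(mean((y_demo - mean(y_demo))^2)))

fit_optim$par
mle_demo
```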
Estimation time can also be significantly reduced by providing a gradient function,
rather than relying on finite-differencing:
```{r}
grad <- function(v, x) {
  inv_sigma <- 1 / v[2]
  y_scaled <- (x - v[1]) * inv_sigma
  scaled_diff <- inv_sigma * y_scaled
  c(sum(scaled_diff),
    sum(inv_sigma * (y_scaled * y_scaled) - inv_sigma)
  )
}
```
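For reference, the two elements returned by the gradient function correspond to the partial derivatives of the normal log-likelihood. This is a short derivation added for clarity, written with $z_i = (y_i - \mu) / \sigma$ (the `y_scaled` term in the code):

$$
\log L(\mu, \sigma) = \sum_{i=1}^{n} \left[ -\log \sigma - \tfrac{1}{2} \log(2\pi) - \frac{(y_i - \mu)^2}{2\sigma^2} \right]
$$

$$
\frac{\partial \log L}{\partial \mu} = \sum_{i=1}^{n} \frac{y_i - \mu}{\sigma^2} = \sum_{i=1}^{n} \frac{z_i}{\sigma}
$$

$$
\frac{\partial \log L}{\partial \sigma} = \sum_{i=1}^{n} \left[ \frac{(y_i - \mu)^2}{\sigma^3} - \frac{1}{\sigma} \right] = \sum_{i=1}^{n} \left[ \frac{z_i^2}{\sigma} - \frac{1}{\sigma} \right]
$$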
### MCMC Estimation
Full MCMC estimation is provided by the `stan_sample()` function, which uses
Stan's default No U-Turn Sampler (NUTS) unless otherwise specified:
```{r, results=FALSE, message=FALSE, warning=FALSE}
library(StanEstimators)
fit <- stan_sample(loglik_fun, inits, additional_args = list(y),
                   lower = c(-Inf, 0), # Enforce a positivity constraint for SD
                   num_chains = 1, seed = 1234)
```
We can see that the parameters were recovered accurately and that the estimation
was relatively fast: ~1 second for 1000 warmup and 1000 sampling iterations.
```{r}
unlist(fit@timing)
summary(fit)
```
Estimation time can be improved further by providing a gradient function:
```{r, results=FALSE, message=FALSE, warning=FALSE}
fit_grad <- stan_sample(loglik_fun, inits, additional_args = list(y),
                        grad_fun = grad,
                        lower = c(-Inf, 0),
                        num_chains = 1,
                        seed = 1234)
```
The timings show that estimation is now dramatically faster:
~0.15 seconds for 1000 warmup and 1000 sampling iterations.
```{r}
unlist(fit_grad@timing)
summary(fit_grad)
```
### Optimization
```{r, results=FALSE, message=FALSE, warning=FALSE}
opt_fd <- stan_optimize(loglik_fun, inits, additional_args = list(y),
                        lower = c(-Inf, 0),
                        seed = 1234)
opt_grad <- stan_optimize(loglik_fun, inits, additional_args = list(y),
                          grad_fun = grad,
                          lower = c(-Inf, 0),
                          seed = 1234)
```
```{r}
summary(opt_fd)
summary(opt_grad)
```
### Laplace Approximation
```{r, results=FALSE, message=FALSE, warning=FALSE}
# Can provide the mode as a numeric vector:
lapl_num <- stan_laplace(loglik_fun, inits, additional_args = list(y),
                         mode = c(10, 2),
                         lower = c(-Inf, 0),
                         seed = 1234)
# Can provide the mode as a StanOptimize object:
lapl_opt <- stan_laplace(loglik_fun, inits, additional_args = list(y),
                         mode = opt_fd,
                         lower = c(-Inf, 0),
                         seed = 1234)
# Can estimate the mode before sampling:
lapl_est <- stan_laplace(loglik_fun, inits, additional_args = list(y),
                         lower = c(-Inf, 0),
                         seed = 1234)
```
```{r}
summary(lapl_num)
summary(lapl_opt)
summary(lapl_est)
```
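The general idea of a Laplace approximation can be sketched in base R: locate the mode, measure the curvature (Hessian) of the log-density there, and draw from a normal distribution centred at the mode with covariance equal to the inverse of the negative Hessian. The toy data `y_lap` and every `*_lap` name below are hypothetical, and this is only an illustration of the technique, not a description of how `stan_laplace()` is implemented. For simplicity it works directly on the constrained parameter scale, which may differ from how Stan handles bounded parameters.

```{r}
# Hypothetical toy data, used only for this illustration
y_lap <- c(8.1, 9.4, 10.2, 11.7, 12.3, 9.9, 10.6, 7.8)
loglik_lap <- function(v) sum(dnorm(y_lap, v[1], v[2], log = TRUE))

# Step 1: find the mode, also requesting a numerical Hessian at the solution
opt_lap <- optim(par = c(0, 5), fn = loglik_lap,
                 method = "L-BFGS-B",
                 lower = c(-Inf, 1e-6),
                 control = list(fnscale = -1),
                 hessian = TRUE)

# Step 2: covariance of the normal approximation is the inverse
# of the negative Hessian of the log-density at the mode
cov_lap <- solve(-opt_lap$hessian)

# Step 3: draw from N(mode, cov_lap) using the Cholesky factor
n_draws <- 1000
z_lap <- matrix(rnorm(n_draws * 2), ncol = 2)
draws_lap <- sweep(z_lap %*% chol(cov_lap), 2, opt_lap$par, "+")

opt_lap$par
colMeans(draws_lap)
```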
### Variational Inference
```{r, results=FALSE, message=FALSE, warning=FALSE}
var_fd <- stan_variational(loglik_fun, inits, additional_args = list(y),
                           lower = c(-Inf, 0),
                           seed = 1234)
var_grad <- stan_variational(loglik_fun, inits, additional_args = list(y),
                             grad_fun = grad,
                             lower = c(-Inf, 0),
                             seed = 1234)
```
```{r}
summary(var_fd)
summary(var_grad)
```
### Pathfinder
```{r, results=FALSE, message=FALSE, warning=FALSE}
path_fd <- stan_pathfinder(loglik_fun, inits, additional_args = list(y),
                           lower = c(-Inf, 0),
                           seed = 1234)
path_grad <- stan_pathfinder(loglik_fun, inits, additional_args = list(y),
                             grad_fun = grad,
                             lower = c(-Inf, 0),
                             seed = 1234)
```
```{r}
summary(path_fd)
summary(path_grad)
```