Skip to content

Commit 0f46635

Browse files
committed
allow saving of weighted particles and their weights in 'pfilter' computation
1 parent 030e7e9 commit 0f46635

9 files changed

+128
-31
lines changed

DESCRIPTION

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
Package: pomp
22
Type: Package
33
Title: Statistical Inference for Partially Observed Markov Processes
4-
Version: 4.4.0.1
5-
Date: 2022-11-29
4+
Version: 4.4.1.0
5+
Date: 2022-11-30
66
Authors@R: c(person(given=c("Aaron","A."),family="King",
77
role=c("aut","cre"),email="kingaa@umich.edu"),
88
person(given=c("Edward","L."),family="Ionides",role=c("aut")),

R/pfilter.R

+34-13
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,12 @@
3232
##' @param filter.traj logical; if \code{TRUE}, a filtered trajectory is returned for the state variables and parameters.
3333
##' See \code{\link{filter.traj}} for more information.
3434
##'
35-
##' @param save.states logical.
36-
##' If \code{save.states=TRUE}, the state-vector for each particle at each time is saved.
35+
##' @param save.states character;
36+
##' If \code{save.states="unweighted"}, the state-vector for each unweighted particle at each time is saved.
37+
##' If \code{save.states="weighted"}, the state-vector for each weighted particle at each time is saved, along with the corresponding weight.
38+
##' If \code{save.states="no"}, information on the latent states is not saved.
39+
##' \code{"FALSE"} is a synonym for \code{"no"} and \code{"TRUE"} is a synonym for \code{"unweighted"}.
40+
##' To retrieve the saved states, applying \code{\link{saved.states}} to the result of the \code{pfilter} computation.
3741
##'
3842
##' @return
3943
##' An object of class \sQuote{pfilterd_pomp}, which extends class \sQuote{pomp}.
@@ -51,8 +55,8 @@
5155
##' retrieve one particle trajectory.
5256
##' Useful for building up the smoothing distribution.
5357
##' }
54-
##' \item{\code{\link{saved.states}}}{retrieve list of saved states.}
55-
##' \item{\code{\link{as.data.frame}}}{ coerce to a data frame }
58+
##' \item{\code{\link{saved.states}}}{retrieve saved states}
59+
##' \item{\code{\link{as.data.frame}}}{coerce to a data frame}
5660
##' \item{\code{\link{plot}}}{diagnostic plots}
5761
##' }
5862
##'
@@ -132,7 +136,7 @@ setMethod(
132136
pred.var = FALSE,
133137
filter.mean = FALSE,
134138
filter.traj = FALSE,
135-
save.states = FALSE,
139+
save.states = c("no", "weighted", "unweighted", "FALSE", "TRUE"),
136140
...,
137141
verbose = getOption("verbose", FALSE)) {
138142

@@ -170,7 +174,7 @@ setMethod(
170174
pred.var = FALSE,
171175
filter.mean = FALSE,
172176
filter.traj = FALSE,
173-
save.states = FALSE,
177+
save.states = c("no", "weighted", "unweighted", "FALSE", "TRUE"),
174178
...,
175179
verbose = getOption("verbose", FALSE)) {
176180

@@ -209,8 +213,9 @@ setMethod(
209213

210214
pfilter.internal <- function (object, Np,
211215
pred.mean = FALSE, pred.var = FALSE, filter.mean = FALSE,
212-
filter.traj = FALSE, cooling, cooling.m, save.states = FALSE, ...,
213-
.gnsi = TRUE, verbose = FALSE) {
216+
filter.traj = FALSE, cooling, cooling.m,
217+
save.states = c("no", "weighted", "unweighted", "FALSE", "TRUE"),
218+
..., .gnsi = TRUE, verbose = FALSE) {
214219

215220
verbose <- as.logical(verbose)
216221

@@ -224,7 +229,8 @@ pfilter.internal <- function (object, Np,
224229
pred.var <- as.logical(pred.var)
225230
filter.mean <- as.logical(filter.mean)
226231
filter.traj <- as.logical(filter.traj)
227-
save.states <- as.logical(save.states)
232+
save.states <- as.character(save.states)
233+
save.states <- match.arg(save.states)
228234

229235
params <- coef(object)
230236
times <- time(object,t0=TRUE)
@@ -241,8 +247,11 @@ pfilter.internal <- function (object, Np,
241247
x <- init.x
242248

243249
## set up storage for saving samples from filtering distributions
244-
if (save.states || filter.traj) {
250+
stsav <- save.states %in% c("unweighted","TRUE")
251+
wtsav <- save.states == "weighted"
252+
if (stsav || wtsav || filter.traj) {
245253
xparticles <- setNames(vector(mode="list",length=ntimes),time(object))
254+
if (wtsav) xweights <- xparticles
246255
}
247256
if (filter.traj) {
248257
pedigree <- vector(mode="list",length=ntimes+1)
@@ -298,6 +307,12 @@ pfilter.internal <- function (object, Np,
298307
times=times[nt+1],params=params,log=TRUE,.gnsi=gnsi)
299308
gnsi <- FALSE
300309

310+
## store unweighted particles and their weights
311+
if (wtsav) {
312+
xparticles[[nt]] <- x
313+
xweights[[nt]] <- weights
314+
}
315+
301316
## compute prediction mean, prediction variance, filtering mean,
302317
## effective sample size, log-likelihood.
303318
## also do resampling.
@@ -328,7 +343,7 @@ pfilter.internal <- function (object, Np,
328343
if (filter.mean) filt.m[,nt] <- xx$fm
329344
if (filter.traj) pedigree[[nt]] <- xx$ancestry
330345

331-
if (save.states || filter.traj) {
346+
if (stsav || filter.traj) {
332347
xparticles[[nt]] <- x
333348
dimnames(xparticles[[nt]]) <- setNames(dimnames(xparticles[[nt]]),
334349
c("variable",".id"))
@@ -354,7 +369,13 @@ pfilter.internal <- function (object, Np,
354369
}
355370
}
356371

357-
if (!save.states) xparticles <- list()
372+
if (stsav) {
373+
stsav <- xparticles
374+
} else if (wtsav) {
375+
stsav <- list(states=xparticles,weights=xweights)
376+
} else {
377+
stsav <- list()
378+
}
358379

359380
new(
360381
"pfilterd_pomp",
@@ -366,7 +387,7 @@ pfilter.internal <- function (object, Np,
366387
paramMatrix=array(data=numeric(0),dim=c(0,0)),
367388
eff.sample.size=eff.sample.size,
368389
cond.logLik=loglik,
369-
saved.states=xparticles,
390+
saved.states=stsav,
370391
Np=as.integer(Np),
371392
loglik=sum(loglik)
372393
)

R/saved_states.R

+37-4
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,20 @@
1313
##' @family particle filter methods
1414
##' @family extraction methods
1515
##' @inheritParams filter.mean
16+
##' @param format character;
17+
##' format of the returned object (see below).
1618
##'
17-
##' @return The saved states are returned in the form of a list, with one element per time-point.
18-
##' Each element consists of a matrix, with one row for each state variable and one column for each particle.
19+
##' @return According to the \code{format} argument, the saved states are returned either as a list or a data frame.
1920
##'
21+
##' If \code{format="data.frame"}, then the returned data frame holds the state variables and (optionally) the log weight of each particle at each observation time.
22+
##' The \code{.id} variable distinguishes particles.
23+
##'
24+
##' If \code{format="list"} and \code{\link{pfilter}} was called with \code{save.states="unweighted"} or \code{save.states="TRUE"}, the returned list contains one element per observation time.
25+
##' Each element consists of a matrix, with one row for each state variable and one column for each particle.
26+
##' If \code{\link{pfilter}} was called with \code{save.states="weighted"}, the list itself contains two lists:
27+
##' the first holds the particles as above, the second holds the corresponding log weights.
28+
##' In particular, it has one element per observation time; each element is the vector of per-particle log weights.
29+
##'
2030
NULL
2131

2232
setGeneric(
@@ -45,8 +55,31 @@ setMethod(
4555
setMethod(
4656
"saved.states",
4757
signature=signature(object="pfilterd_pomp"),
48-
definition=function (object, ...) {
49-
object@saved.states
58+
definition=function (object, ...,
59+
format = c("list","data.frame")) {
60+
format <- match.arg(format)
61+
if (format=="list") {
62+
object@saved.states
63+
} else if (length(object@saved.states)==2L) {
64+
s <- melt(object@saved.states$states)
65+
w <- melt(object@saved.states$weights)
66+
s$time <- time(object)[as.integer(s$L1)]
67+
w$time <- time(object)[as.integer(w$L1)]
68+
w$variable <- "weight"
69+
x <- rbind(
70+
s[,c("time",".id","variable","value")],
71+
w[,c("time",".id","variable","value")]
72+
)
73+
x <- x[order(x$time,x$.id),]
74+
rownames(x) <- NULL
75+
x
76+
} else {
77+
s <- melt(object@saved.states)
78+
s$time <- time(object)[as.integer(s$L1)]
79+
s <- s[,c("time",".id","variable","value")]
80+
rownames(s) <- NULL
81+
s
82+
}
5083
}
5184
)
5285

inst/NEWS

+9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
_N_e_w_s _f_o_r _p_a_c_k_a_g_e '_p_o_m_p'
22

3+
_C_h_a_n_g_e_s _i_n '_p_o_m_p' _v_e_r_s_i_o_n _4._4._1:
4+
5+
• It is now possible to retrieve the weighted particles
6+
computed in the course of a ‘pfilter’ computation. To
7+
accomplish this, set ‘save.states="weighted"’ in the call to
8+
‘pfilter’ and retrieve the particles and their weights using
9+
‘saved.states’. Previously, one could obtain only the
10+
unweighted particles.
11+
312
_C_h_a_n_g_e_s _i_n '_p_o_m_p' _v_e_r_s_i_o_n _4._4._0:
413

514
• Some documentation improvements.

inst/NEWS.Rd

+7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
\name{NEWS}
22
\title{News for package `pomp'}
3+
\section{Changes in \pkg{pomp} version 4.4.1}{
4+
\itemize{
5+
\item It is now possible to retrieve the weighted particles computed in the course of a \code{pfilter} computation.
6+
To accomplish this, set \code{save.states="weighted"} in the call to \code{pfilter} and retrieve the particles and their weights using \code{saved.states}.
7+
Previously, one could obtain only the unweighted particles.
8+
}
9+
}
310
\section{Changes in \pkg{pomp} version 4.4.0}{
411
\itemize{
512
\item Some documentation improvements.

man/pfilter.Rd

+10-6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/saved_states.Rd

+13-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/pfilter.R

+6-2
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,15 @@ pfilter(pf,dmeasure=function(log,...) -Inf)
6767
pfilter(pf,dmeasure=function(log,...) -Inf,filter.mean=TRUE)
6868

6969
pf1 <- pfilter(pf,save.states=TRUE,filter.traj=TRUE)
70-
pf2 <- pfilter(pf,pred.mean=TRUE,pred.var=TRUE,filter.mean=TRUE,save.states=TRUE)
70+
pf2 <- pfilter(pf,pred.mean=TRUE,pred.var=TRUE,filter.mean=TRUE,save.states="unweighted")
7171
pf3 <- pfilter(pf,t0=1,filter.traj=TRUE)
7272
pf4 <- pfilter(pf,dmeasure=Csnippet("lik = (give_log) ? R_NegInf : 0;"),
7373
filter.traj=TRUE)
74+
pf5 <- pfilter(pf,save.states="weighted")
7475
pf1 %>% saved.states() %>% melt() %>% names()
7576
pf1 %>% saved.states() %>% melt() %>% dim()
77+
pf1 %>% saved.states(format="data") %>% names()
78+
pf1 %>% saved.states(format="data") %>% dim()
7679
c(A=pf1,B=pf2) %>% saved.states() %>% melt() %>% names()
7780
c(A=pf1,B=pf2) %>% saved.states() %>% melt() %>% sapply(class)
7881
c(A=pf1,B=pf2) %>% as.data.frame() %>% sapply(class)
@@ -87,7 +90,8 @@ names(dimnames(filter.traj(c(pf1,pf4))))
8790
names(melt(as(c(pf1,pf4),"data.frame")))
8891
pf2 %>% melt() %>% names()
8992
pf2 %>% melt(id="time") %>% names()
90-
93+
pf5 %>% saved.states(format="d") %>% names()
94+
pf5 %>% saved.states(format="d") %>% dim()
9195
try(saved.states())
9296
try(saved.states(NULL))
9397
try(saved.states("bob"))

tests/pfilter.Rout.save

+10-2
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,19 @@ Error : in 'pfilter': in 'dmeasure': ouch!
148148
<object of class 'pfilterd_pomp'>
149149
>
150150
> pf1 <- pfilter(pf,save.states=TRUE,filter.traj=TRUE)
151-
> pf2 <- pfilter(pf,pred.mean=TRUE,pred.var=TRUE,filter.mean=TRUE,save.states=TRUE)
151+
> pf2 <- pfilter(pf,pred.mean=TRUE,pred.var=TRUE,filter.mean=TRUE,save.states="unweighted")
152152
> pf3 <- pfilter(pf,t0=1,filter.traj=TRUE)
153153
> pf4 <- pfilter(pf,dmeasure=Csnippet("lik = (give_log) ? R_NegInf : 0;"),
154154
+ filter.traj=TRUE)
155+
> pf5 <- pfilter(pf,save.states="weighted")
155156
> pf1 %>% saved.states() %>% melt() %>% names()
156157
[1] "variable" ".id" "value" "L1"
157158
> pf1 %>% saved.states() %>% melt() %>% dim()
158159
[1] 20000 4
160+
> pf1 %>% saved.states(format="data") %>% names()
161+
[1] "time" ".id" "variable" "value"
162+
> pf1 %>% saved.states(format="data") %>% dim()
163+
[1] 20000 4
159164
> c(A=pf1,B=pf2) %>% saved.states() %>% melt() %>% names()
160165
[1] "variable" ".id" "value" "L2" "L1"
161166
> c(A=pf1,B=pf2) %>% saved.states() %>% melt() %>% sapply(class)
@@ -211,7 +216,10 @@ No id variables; using all as measure variables
211216
[1] "variable" "value"
212217
> pf2 %>% melt(id="time") %>% names()
213218
[1] "time" "variable" "value"
214-
>
219+
> pf5 %>% saved.states(format="d") %>% names()
220+
[1] "time" ".id" "variable" "value"
221+
> pf5 %>% saved.states(format="d") %>% dim()
222+
[1] 30000 4
215223
> try(saved.states())
216224
Error : in 'saved.states': 'object' is a required argument.
217225
> try(saved.states(NULL))

0 commit comments

Comments
 (0)