Skip to content

Commit

Permalink
Now it is possible to specify a different base density/probability ma…
Browse files Browse the repository at this point in the history
…ss function than the uniform one. If none is specified, the uniform density (either discrete or continuous) is assumed for the case of discrete or continuous random variables, respectively.
  • Loading branch information
prdm0 committed Apr 14, 2024
1 parent 93fa5c6 commit 436b193
Show file tree
Hide file tree
Showing 33 changed files with 400 additions and 204 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ BugReports: https://github.com/prdm0/AcceptReject/issues/
RoxygenNote: 7.2.3
VignetteBuilder: knitr
Imports:
assertthat,
cli,
ggplot2,
glue,
Expand All @@ -28,4 +29,4 @@ Imports:
Suggests:
knitr,
rmarkdown,
patchwork
cowplot
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ S3method(plot,accept_reject)
S3method(print,accept_reject)
export(accept_reject)
import(rlang)
importFrom(assertthat,assert_that)
importFrom(cli,cli_alert_danger)
importFrom(cli,cli_alert_info)
importFrom(cli,cli_alert_success)
importFrom(cli,cli_alert_warning)
importFrom(cli,cli_h1)
importFrom(ggplot2,aes)
importFrom(ggplot2,after_stat)
Expand Down
14 changes: 13 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,16 @@

# AcceptReject 0.1.1

* In generating observations of continuous random variables, using histogram with the same breaks as the R graphics `hist()` function, in the histogram created by **ggplot2**.
* Now it is possible to specify a different base density/probability mass function than the uniform one. If none is specified, the uniform density (either discrete or continuous) is assumed for the case of discrete or continuous random variables, respectively;

* In generating observations of continuous random variables, using histogram with the same breaks as the R graphics `hist()` function, in the histogram created by **ggplot2**;

* Providing alerts regarding the limits passed to the `xlim` argument of the `accept_reject()` function. If a significant density/probability mass is present, a warning will be issued. The alert can be omitted by setting `warning = FALSE`;

* Improved performance;

* In the `plot.accept_reject()` function, there's an additional argument `hist = TRUE` (default). If `hist = TRUE`, a histogram is plotted along with the base density, in the case of generating pseudo-random observations of a continuous random variable. If `hist = FALSE`, the theoretical density is plotted alongside the observed density;

* The `print.accept_reject()` function now informs whether the case is discrete or continuous and the `xlim`;

* Bug fix.
105 changes: 85 additions & 20 deletions R/accept_reject.r
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,28 @@
#' @param n The number of random numbers to generate.
#' @param continuous A logical value indicating whether the pdf is continuous or discrete. Default is \code{TRUE}.
#' @param f The probability density function (`continuous = TRUE`), in the continuous case or the probability mass function, in the discrete case (`continuous = FALSE`).
#' @param args_f A list of arguments to be passed to the pdf function.
#' @param args_f A list of arguments to be passed to the `f` function. It refers to the list of arguments of the target distribution.
#' @param f_base Base probability density function (for continuous case).If `f_base = NULL`,
#' a uniform distribution will be used. In the discrete case, this argument is ignored,
#' and a uniform probability mass function will be used as the base.
#' @param random_base Random number generation function for the base distribution passed as an argument to `f_base`.
#' If `random_base = NULL` (default), the uniform generator will be used. In the discrete case, this argument is
#' disregarded, and the uniform random number generator function will be used.
#' @param args_f_base A list of arguments for the base distribution. This refers to the list of arguments that will be passed to the function `f_base`.
#' It will be disregarded in the discrete case.
#' @param xlim A vector specifying the range of values for the random numbers in the form `c(min, max)`. Default is \code{c(0, 100)}.
#' @param c A constant value used in the acceptance-rejection method. If \code{NULL}, it will be estimated using the [lbfgs::lbfgs()] optimization algorithm. Default is \code{NULL}.
#' @param linesearch_algorithm The linesearch algorithm to be used in the [lbfgs::lbfgs()] optimization. Default is \code{"LBFGS_LINESEARCH_BACKTRACKING_ARMIJO"}.
#' @param max_iterations The maximum number of iterations for the [lbfgs::lbfgs()] optimization. Default is \code{1000}.
#' @param epsilon The convergence criterion for the [lbfgs::lbfgs()] optimization. Default is \code{1e-6}.
#' @param start_c The initial value for the constant \code{c} in the [lbfgs::lbfgs()] optimization. Default is \code{25}.
#' @param parallel A logical value indicating whether to use parallel processing for generating random numbers. Default is \code{FALSE}.
#' @param warning A logical value indicating whether to show warnings. Default is \code{TRUE}.
#' @param ... Additional arguments to be passed to the [lbfgs::lbfgs()] optimization algorithm. For details, see [lbfgs::lbfgs()].
#'
#' @return A vector of random numbers generated using the acceptance-rejection method.
#' The return is an object of `class accept_reject`, but it can be treated as an atomic vector.
#'
#' @details
#' In situations where we cannot use the inversion method (situations where it is not possible to obtain the quantile function) and we do not know a transformation that involves a random variable from which we can generate observations, we can use the acceptance and rejection method.
#' Suppose that \eqn{X} and \eqn{Y} are random variables with probability density function (pdf) or probability function (pf) \eqn{f} and \eqn{g}, respectively. In addition, suppose that there is a constant \eqn{c} such that
Expand Down Expand Up @@ -51,7 +62,18 @@
#'
#' In Unix-based operating systems, the function [accept_reject()] can be executed in parallel. To do this, simply set the argument `parallel = TRUE`.
#' The function [accept_reject()] utilizes the [parallel::mclapply()] function to execute the acceptance and rejection method in parallel.
#' On Windows operating systems, the code will be seral even if `parallel = TRUE` is set.
#' On Windows operating systems, the code will not be parallelized even if `parallel = TRUE` is set.
#'
#' For the continuous case, a base density function can be used, where the arguments
#' `f_base`, `random_base` and `args_f_base` need to be passed. If at least one of
#' them is `NULL`, the function will assume a uniform density function over the
#' interval `xlim`.
#'
#' For the discrete case, the arguments `f_base`, `random_base` and `args_f_base`
#' should be `NULL`, and if they are passed, they will be disregarded, as for
#' the discrete case, the discrete uniform distribution will always be
#' considered as the base. Sampling from the discrete uniform distribution
#' has shown good performance for the discrete case.
#'
#' @seealso [parallel::mclapply()] and [lbfgs::lbfgs()].
#'
Expand Down Expand Up @@ -86,46 +108,89 @@
#' @importFrom parallel detectCores
#' @importFrom stats dunif runif dweibull
#' @importFrom utils capture.output
#'
#' @importFrom assertthat assert_that
#' @importFrom cli cli_alert_danger cli_alert_warning
#' @importFrom glue glue
#' @export
accept_reject <-
function(
n = 1L,
continuous = TRUE,
f = dweibull,
args_f = list(shape = 1, scale = 1),
xlim = c(0, 100),
f = NULL,
args_f = NULL,
f_base = NULL,
random_base = NULL,
args_f_base = NULL,
xlim = NULL,
c = NULL,
linesearch_algorithm = "LBFGS_LINESEARCH_BACKTRACKING_ARMIJO",
max_iterations = 1000L,
epsilon = 1e-6,
start_c = 25,
parallel = FALSE,
warning = TRUE,
...) {

pdf <- purrr::partial(.f = f, !!!args_f)
if(xlim[1L] == 0 && continuous) xlim[1L] <- .Machine$double.xmin
assertthat::assert_that(
!is.null(f),
msg = cli::cli_alert_danger("You need to pass the argument f referring to the probability density or mass function that you want to generate observations.")
)

if (continuous) {
step <- 1e-5
pdf_base <- purrr::partial(.f = dunif, min = xlim[1L], max = xlim[2L])
base_generator <- purrr::partial(.f = runif, min = xlim[1L], max = xlim[2L])
} else {
assertthat::assert_that(
!is.null(args_f),
msg = cli::cli_alert_danger("You need to pass args_f with the parameters that index f.")
)

assertthat::assert_that(
!is.null(xlim),
msg = cli::cli_alert_danger("You must provide the vector xlim argument (generation support).")
)

f <- purrr::partial(.f = f, !!!args_f)

if(warning && f(xlim[2L]) > 0.01) {
cli::cli_alert_warning(
glue::glue("Warning: xlim[2L] is {f(xlim[2L])}. Trying a better upper limit might be better.")
)
}

if(warning && xlim[1L] < 0 && f(xlim[1L]) > 0.01) {
cli::cli_alert_warning(
glue::glue("Warning: xlim[1L] is {f(xlim[1L])}. Trying a lower limit might be better.")
)
}

# Uniform distribution will be used if not all information from the base
# distribution is provided.
any_null <- any(is.null(c(f_base, random_base, args_f_base)))
if(continuous && any_null) {
step <- 1e-3
f_base <- purrr::partial(.f = dunif, min = xlim[1L], max = xlim[2L])
random_base <- purrr::partial(.f = runif, min = xlim[1L], max = xlim[2L])
}

# Is it a discrete random variable?
if(!continuous){
step <- 1L
pdf_base <- \(x) 1/ (xlim[2L] - xlim[1L] + 1)
base_generator <- \(n) sample(x = xlim[1L]:xlim[2L], size = n, replace = TRUE)
f_base <- function(x) 1/ (xlim[2L] - xlim[1L] + 1)
random_base <- function(n) sample(x = xlim[1L]:xlim[2L], size = n, replace = TRUE)
} else if(continuous && !any_null) {
step <- 1e-3
if(xlim[1L] == 0) xlim[1L] <- .Machine$double.xmin
f_base <- purrr::partial(.f = f_base, !!!args_f_base)
random_base <- purrr::partial(.f = random_base, !!!args_f_base)
}

x <- seq(from = xlim[1L], to = xlim[2L], by = step)

a <- purrr::map_dbl(.x = x, .f = pdf)
b <- purrr::map_dbl(.x = x, .f = pdf_base)
a <- purrr::map_dbl(.x = x, .f = f)
b <- purrr::map_dbl(.x = x, .f = f_base)

x_max <- x[which.max((a/b)[!is.infinite(a/b)])]

objective_c <- function(c) {
differences <-
(pdf(x_max) - c * pdf_base(x_max))^2
(f(x_max) - c * f_base(x_max))^2

if(is.infinite(differences)) return(.Machine$double.xmax)
else return(differences)
Expand Down Expand Up @@ -160,9 +225,9 @@ accept_reject <-
}
one_step <- function(i) {
repeat{
x <- base_generator(n = 1L)
x <- random_base(n = 1L)
u <- runif(n = 1L)
if (u <= pdf(x = x) / (c * pdf_base(x = x))) {
if (u <= f(x = x) / (c * f_base(x = x))) {
return(x)
}
}
Expand Down
83 changes: 41 additions & 42 deletions R/plot.r
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
#' @param color_real_point Color of real probability points (discrete case)
#' @param alpha Bar chart transparency (discrete case) and observed density
#' (continuous case)
#' @param hist If TRUE, a histogram will be plotted in the continuous case,
#' comparing the theoretical density with the observed one. If FALSE,
#' [ggplot2::geom_density()] will be used instead of the histogram.
#'
#' @param ... Additional arguments.
#'
#' @details
Expand Down Expand Up @@ -66,61 +70,56 @@ plot.accept_reject <-
function(
x,
color_observed_density = "#BB9FC9", #"#E65A65", # "#FBBA78",
color_true_density = "#E9796D",
color_true_density = "#FE4F0E",
color_bar = "#BB9FC9", #"#E65A65", #"#FCEFC3",
color_observable_point = "#7BBDB3",
color_real_point = "#FE4F0E",
alpha = .3,
hist = TRUE,
...
){

y <-
do.call(
attr(x, "f"),
rlang::list2(
as.vector(x),
!!!attr(x, "args_f")
)
)

y <- do.call(attr(x, "f"), list(as.vector(x)))
data <- data.frame(x = as.vector(x), y = y)

graphic <- function(x){
p <- ggplot2::ggplot(data, ggplot2::aes(x = x))
if(attr(x, "continuous")){
capture.output(
if(hist){
p <-
p +
ggplot2::geom_histogram(aes(y = after_stat(density), color = "Observed density"), fill = color_observed_density, alpha = alpha, breaks = hist(data$x, plot = FALSE)$breaks) +
#ggplot2::geom_density(aes(y = after_stat(density), color = "Observed density"), position = "stack", fill = color_observed_density, alpha = alpha) +
ggplot2::geom_line(aes(y = y, color = "True density")) +
ggplot2::scale_color_manual(values = c("True density" = color_true_density, "Observed density" = color_observed_density)) +
ggplot2::labs(
x = "x",
y = "f(x)",
title = "Probability density function",
subtitle = "True x Observed",
color = "Legend"
)
)
ggplot2::ggplot(data, ggplot2::aes(x = x)) +
ggplot2::geom_histogram(ggplot2::aes(y = after_stat(density), color = "Observed density"), fill = color_observed_density, alpha = alpha, breaks = hist(data$x, plot = FALSE)$breaks)
} else {
p <- ggplot2::ggplot(data, ggplot2::aes(x = x)) +
ggplot2::geom_density(ggplot2::aes(y = after_stat(density), color = "Observed density"), position = "stack", fill = color_observed_density, alpha = alpha)
}
p <-
p +
# ggplot2::ggplot(data, ggplot2::aes(x = x)) +
ggplot2::geom_line(aes(y = y, color = "True density")) +
ggplot2::scale_color_manual(values = c("True density" = color_true_density, "Observed density" = color_observed_density)) +
ggplot2::labs(
x = "x",
y = "f(x)",
title = "Probability density function",
subtitle = "True x Observed",
color = "Legend"
)
} else {
capture.output(
p <-
p +
ggplot2::geom_bar(aes(y = after_stat(prop), group = 1L), fill = color_bar, alpha = alpha) +
ggplot2::geom_line(aes(y = y), linetype = "dotted") +
ggplot2::geom_point(aes(y = y, color = "Observable Probability")) +
ggplot2::geom_point(aes(y = after_stat(prop), group = 1L, color = "Real Probability"), stat = "count") +
ggplot2::scale_color_manual(values = c("Observable Probability" = color_observable_point, "Real Probability" = color_real_point)) +
ggplot2::scale_y_continuous(labels = scales::percent) +
ggplot2::labs(
x = "x",
y = "P(X = x)",
title = "Probability Function",
subtitle = "True x Observed",
color = "Legend"
)
)
p <-
ggplot2::ggplot(data, ggplot2::aes(x = x)) +
ggplot2::geom_bar(aes(y = after_stat(prop), group = 1L), fill = color_bar, alpha = alpha) +
ggplot2::geom_line(aes(y = y), linetype = "dotted") +
ggplot2::geom_point(aes(y = y, color = "Observable Probability")) +
ggplot2::geom_point(aes(y = after_stat(prop), group = 1L, color = "True Probability"), stat = "count") +
ggplot2::scale_color_manual(values = c("Observable Probability" = color_observable_point, "True Probability" = color_real_point)) +
ggplot2::scale_y_continuous(labels = scales::percent) +
ggplot2::labs(
x = "x",
y = "P(X = x)",
title = "Probability Function",
subtitle = "True x Observed",
color = "Legend"
)
}
p <- p +
ggplot2::theme(
Expand Down
18 changes: 14 additions & 4 deletions R/print.r
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#' or when executing the function [print()] on an object of class
#' `accept_reject`, returned by the function [accept_reject()].
#'
#' @seealso [accept_reject()] and [plot.accept_reject().
#' @seealso [accept_reject()] and [plot.accept_reject()].
#'
#' @importFrom cli cli_h1 cli_alert_success
#' @importFrom glue glue
Expand All @@ -45,13 +45,23 @@ print.accept_reject <- function(x, n_min = 10L, ...) {
cli::cli_alert_info("It's not necessary, but if you want to extract the observations, use as.vector().")
cat('\n')
n <- length(x)

case <- if (attr(x, "continuous")) "continuous" else "discrete"

cli_alert_success(glue("Case: {case}"))
cli_alert_success(glue("Number of observations: {n}"))
cli_alert_success(glue("c: {attr(x, 'c')}"))
cli_alert_success(glue("Probability of acceptance (1/c): {1/attr(x, 'c')}"))
cli_alert_success(glue("c: {round(attr(x, 'c'), digits = 4L)}"))
cli_alert_success(glue("Probability of acceptance (1/c): {round(1/attr(x, 'c'), digits = 4L)}"))
if (n <= n_min) {
cli_alert_success(glue("Observations: {paste(round(x[1L:n], 4L), collapse = ' '))}"))
cli_alert_success(glue("Observations: {paste(round(x[1L:n], 4L), collapse = ' ')}"))
} else {
cli_alert_success(glue("Observations: {paste(round(x[1L:n_min], 4L), collapse = ' ')}..."))
}

limits <- attr(x, "xlim")

if(limits[1L] >= 0 && limits[1L] < 1e-10) limits[1L] <- 0

cli_alert_success(glue("xlim = {paste(limits, collapse = ' ')}"))
cli_h1("")
}
Loading

0 comments on commit 436b193

Please sign in to comment.