Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Feb 13, 2025
1 parent ea5e360 commit 5a4aad5
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 39 deletions.
24 changes: 16 additions & 8 deletions R/Rush.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
#' The only requirement is that the worker can connect to the Redis server.
#' The script is created with the `$worker_script()` method.
#'
#'
#'
#' @template param_network_id
#' @template param_config
#' @template param_worker_loop
Expand All @@ -31,6 +29,7 @@
#' @template param_heartbeat_expire
#' @template param_seed
#' @template param_data_format
#' @template param_consistent
#'
#' @return Object of class [R6::R6Class] and `Rush` with controller methods.
#' @export
Expand Down Expand Up @@ -64,11 +63,16 @@ Rush = R6::R6Class("Rush",
#' List of mirai processes started with `$start_remote_workers()`.
processes_mirai = NULL,

#' @field consistent (`logical(1)`)\cr
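#' Whether all tasks and results share the same structure.
#' If `TRUE`, fetched tasks are combined by position, without matching names or filling missing columns.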
consistent = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(network_id = NULL, config = NULL, seed = NULL) {
initialize = function(network_id = NULL, config = NULL, seed = NULL, consistent = FALSE) {
self$network_id = assert_string(network_id, null.ok = TRUE) %??% uuid::UUIDgenerate()
self$config = assert_class(config, "redis_config", null.ok = TRUE) %??% rush_env$config
self$consistent = assert_flag(consistent)
if (is.null(self$config)) self$config = redux::redis_config()
if (!redux::redis_available(self$config)) {
stop("Can't connect to Redis. Check the configuration.")
Expand Down Expand Up @@ -958,8 +962,13 @@ Rush = R6::R6Class("Rush",
private$.n_seen_results = private$.n_seen_results + n_new_results

# fetch finished tasks
data = self$fetch_finished_tasks(fields, data_format = data_format)
tail(data, n_new_results)
data = self$fetch_finished_tasks(fields, data_format = "list")
data = tail(data, n_new_results)
if (data_format == "list") return(data)
# it is much faster to only convert the new results to data.table instead of doing it in fetch_finished_tasks
tab = rbindlist(data, use.names = !self$consistent, fill = !self$consistent)
tab[, keys := names(data)]
tab[]
},

#' @description
Expand Down Expand Up @@ -1618,12 +1627,11 @@ Rush = R6::R6Class("Rush",
lg$debug("Fetching %i task(s)", length(data))

if (data_format == "list") return(set_names(data, keys))
tab = rbindlist(data, use.names = TRUE, fill = TRUE)
tab = rbindlist(data, use.names = !self$consistent, fill = !self$consistent)
tab[, keys := unlist(keys)]
tab[]
},


# fetch and cache tasks
.fetch_cached_tasks = function(new_keys, fields, reset_cache = FALSE, data_format = "data.table") {
r = self$connector
Expand All @@ -1646,7 +1654,7 @@ Rush = R6::R6Class("Rush",
lg$debug("Fetching %i task(s)", length(private$.cached_tasks))

if (data_format == "list") return(private$.cached_tasks)
tab = rbindlist(private$.cached_tasks, use.names = TRUE, fill = TRUE)
tab = rbindlist(private$.cached_tasks, use.names = !self$consistent, fill = !self$consistent)
if (nrow(tab)) tab[, keys := names(private$.cached_tasks)]
tab[]
}
Expand Down
6 changes: 4 additions & 2 deletions R/RushWorker.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#' @template param_seed
#' @template param_heartbeat_period
#' @template param_heartbeat_expire
#' @template param_consistent
#'
#' @return Object of class [R6::R6Class] and `RushWorker` with worker methods.
#' @export
Expand Down Expand Up @@ -42,9 +43,10 @@ RushWorker = R6::R6Class("RushWorker",
worker_id = NULL,
heartbeat_period = NULL,
heartbeat_expire = NULL,
seed = NULL
seed = NULL,
consistent = FALSE
) {
super$initialize(network_id = network_id, config = config, seed = seed)
super$initialize(network_id = network_id, config = config, seed = seed, consistent)

self$remote = assert_flag(remote)
self$worker_id = assert_string(worker_id %??% ids::adjective_animal(1))
Expand Down
3 changes: 3 additions & 0 deletions man-roxygen/param_consistent.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
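#' @param consistent (`logical(1)`)\cr
#' Whether all tasks and results share the same structure, i.e. the same fields in the same order.
#' When the tasks are consistent, fetching tasks is faster because they are combined by position, without matching names or filling missing columns.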
9 changes: 8 additions & 1 deletion man/Rush.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion man/RushWorker.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions vignettes/articles/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/.quarto/
114 changes: 114 additions & 0 deletions vignettes/articles/benchmark.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
---
title: "Rush Benchmarks"
vignette: >
%\VignetteIndexEntry{Rush Benchmarks}
%\VignetteEngine{quarto::html}
%\VignetteEncoding{UTF-8}
---


## Push Task

Pushing a running task to the database is a fast operation.

```{r}
# Stack the small- and large-task benchmark results into one table;
# the `task_size` column records which list entry each row came from.
data = rbindlist(
  list(
    small = results[["push_running_tasks_small"]],
    large = results[["push_running_tasks_large"]]
  ),
  idcol = "task_size"
)

# Median runtime against `size` on log-log axes, one colored line per task size.
ggplot(data, aes(size, median_runtime, color = task_size)) +
  geom_point() +
  geom_line() +
  scale_x_log10() +
  scale_y_log10() +
  labs(
    x = "Number of cached tasks",
    y = "Median runtime (ms)"
  ) +
  theme_minimal()
```


## Fetch Results

The fastest way to retrieve new results is to use the `$fetch_new_results(fields = "ys")` method with `data_format = "list"`.
If we need the results as a `data.table`, we can use `data_format = "data.table"`.
The `$fetch_new_results()` method adds new results to the cache to minimize the runtime when fetching the results again.

We measure the runtime of fetching one new result depending on the cache size.
The runtime increases slightly with the number of cached tasks.
The runtime difference between a `list` and a `data.table` is negligible.

```{r}
# Stack the list and data.table benchmark results into one table;
# the `data_format` column records which list entry each row came from.
data = rbindlist(
  list(
    list = results[["fetch_new_results_cache_list"]],
    data.table = results[["fetch_new_results_cache_data_table"]]
  ),
  idcol = "data_format"
)

# Median runtime against `size` on log-log axes, one colored line per data format.
ggplot(data, aes(size, median_runtime, color = data_format)) +
  geom_point() +
  geom_line() +
  scale_x_log10() +
  scale_y_log10() +
  labs(
    x = "Number of cached tasks",
    y = "Median runtime (ms)"
  ) +
  theme_minimal()
```

We can also retrieve new results and return them together with the cached tasks using the `$fetch_finished_tasks(fields = "ys")` method.

We measure the time to retrieve one new result and the n cached tasks.
The `data_format = "data.table"` runs `rbindlist()` on the tasks.
This operation gets more expensive with the number of cached tasks.

```{r}
# Stack the list and data.table benchmark results into one table;
# the `data_format` column records which list entry each row came from.
data = rbindlist(
  list(
    list = results[["fetch_results_cache_list"]],
    data.table = results[["fetch_results_cache_data_table"]]
  ),
  idcol = "data_format"
)

# Median runtime against `size` on log-log axes, one colored line per data format.
ggplot(data, aes(size, median_runtime, color = data_format)) +
  geom_point() +
  geom_line() +
  scale_x_log10() +
  scale_y_log10() +
  labs(
    x = "Number of cached tasks",
    y = "Median runtime (ms)",
    title = "Fetch new results with cache and data.table"
  ) +
  theme_minimal()
```

## Fetch Tasks And Results

When we need the results and the tasks we use the `$fetch_finished_tasks()` method without the `fields` argument.

We measure the time to retrieve one new task and the n cached tasks.
The runtime with `data_format = "data.table"` increases more because the `rbindlist()` operation gets more expensive with more cached tasks.

```{r}
# Stack the list and data.table benchmark results into one table;
# the `data_format` column records which list entry each row came from.
data = rbindlist(
  list(
    list = results[["fetch_tasks_cache_list"]],
    data.table = results[["fetch_tasks_cache_data_table"]]
  ),
  idcol = "data_format"
)

# Median runtime against `size` on log-log axes, one colored line per data format.
ggplot(data, aes(size, median_runtime, color = data_format)) +
  geom_point() +
  geom_line() +
  scale_x_log10() +
  scale_y_log10() +
  labs(
    x = "Number of cached tasks",
    y = "Median runtime (ms)"
  ) +
  theme_minimal()
```


The overhead of the `rbindlist()` operation can be minimized by using `consistent = TRUE`.

```{r}
# Stack the benchmark results with and without `consistent = TRUE` into one table;
# the `consistent` column records which list entry each row came from.
data = rbindlist(
  list(
    "TRUE" = results[["fetch_tasks_cache_data_table_consistent"]],
    "FALSE" = results[["fetch_tasks_cache_data_table"]]
  ),
  idcol = "consistent"
)

# Median runtime against `size` on log-log axes, one colored line per setting.
ggplot(data, aes(size, median_runtime, color = consistent)) +
  geom_point() +
  geom_line() +
  scale_x_log10() +
  scale_y_log10() +
  labs(
    x = "Number of cached tasks",
    y = "Median runtime (ms)",
    title = "Fetch new results with cache and data.table consistent"
  ) +
  theme_minimal()
```
40 changes: 39 additions & 1 deletion vignettes/articles/rush_advanced.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,42 @@ vignette: >
%\VignetteEncoding{UTF-8}
---

```{r}
library(rush)
# Worker loop performing a random search.
#
# Each iteration samples a point, registers it as a running task, evaluates
# `branin()` on it and pushes the result. The loop ends once the network
# reports at least 100 finished tasks.
#
# @param rush Object providing `$push_running_tasks()`, `$push_results()`
#   and `$n_finished_tasks`.
# @return `NULL`, invisibly; called for its side effects.
wl_random_search = function(rush) {
  n_target = 100
  repeat {
    # sample x1 from [-5, 10] and x2 from [0, 15]
    point = list(x1 = runif(1, -5, 10), x2 = runif(1, 0, 15))

    # register the point as a running task and keep its key
    task_key = rush$push_running_tasks(xss = list(point))

    # evaluate the objective at the sampled point
    outcome = list(y = branin(point$x1, point$x2))

    # store the result under the task's key
    rush$push_results(task_key, yss = list(outcome))

    # the check runs after each evaluation, so at least one task is always evaluated
    if (rush$n_finished_tasks >= n_target) break
  }
}
# Connection to the Redis database
config = redux::redis_config()
# Initialize the rush controller.
# The argument is spelled out as `network_id` (as in the other `rsh()` calls
# of this article) instead of relying on partial matching of `network`.
rush = rsh(
  network_id = "test-random-search",
  config = config)
# Start one local worker that runs the random search loop
worker_ids = rush$start_local_workers(
  worker_loop = wl_random_search,
  n_workers = 1)
```

# Retrieve Results {#sec-retrieve-results}

The `$fetch_finished_tasks()` method retrieves the results of finished tasks.
Expand All @@ -26,7 +62,7 @@ The default of `$fetch_finished_tasks()` is `c("xs", "xs_extra", "worker_extra",
If we don't want all that extra information, we can just query `"xs"` and `"ys"`.

```{r}
rush$fetch_finished_tasks(fields = c("xs", "ys"))
rush$fetch_finished_tasks(fields = c("xs", "ys"), reset_cache = TRUE)
```

The option `data_format = "list"` returns a `list` instead of a `data.table`.
Expand Down Expand Up @@ -73,6 +109,7 @@ The occurring error leads to the task being labeled as `"failed"`, with the corr
This process is demonstrated in the following example.

```{r}
#| eval: false
rush = rsh(network_id = "simple_error")
fun = function(x) {
Expand All @@ -99,6 +136,7 @@ As an example, we define a function that simulates a segmentation fault by killi
The package includes the method `$detect_lost_workers()` designed to identify and manage such instances effectively.

```{r}
#| eval: false
rush = rsh(network_id = "segmenation_fault")
fun = function(x) {
Expand Down
Loading

0 comments on commit 5a4aad5

Please sign in to comment.