-
-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy paththreading-repro.R
77 lines (63 loc) · 2.38 KB
/
threading-repro.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
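# Test num_threads: time 5-fold CV of a tabnet learner (num_threads = 1) run
# sequentially and under future's multisession plan, save each result to
# attic/, and (interactively) check whether the saved runs are reproducible.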
library(mlr3)
library(mlr3torch)
library(mlr3benchmark)
# Learner under test: tabnet trained for 5 epochs; num_threads = 1L is the
# threading setting this script is checking. Note that the object shadows the
# name of mlr3's lrn() constructor (calls to lrn(...) still resolve to it).
lrn <- lrn(
  "classif.tabnet",
  epochs = 5L,
  num_threads = 1L
)
# Without parallelisation ---------------------------------------------------
# Baseline: plain sequential 5-fold CV, timed with tictoc. Only R's RNG is
# seeded here; the torch seed call is left disabled.
cli::cli_h1("CV without parallelisation")
tictoc::tic()
set.seed(247)
# torch::torch_manual_seed(247)
rr_single <- resample(
  task = tsk("spam"),
  learner = lrn,
  resampling = rsmp("cv", folds = 5)
)
tictoc::toc()
# Timestamped file name so repeated runs accumulate in attic/ for comparison.
saveRDS(
  rr_single,
  here::here(
    "attic",
    paste0("rr_single_", format(Sys.time(), format = "%Y%m%d%H%M%S"), ".rds")
  )
)
# With multisession -------------------------------------------------------
# Same CV as the sequential run, but with future's multisession plan active.
# plan() returns the previously active plan, which is restored afterwards so
# code later in the session does not silently keep running in parallel.
cli::cli_h1("CV with plan('multisession')")
tictoc::tic()
old_plan <- future::plan("multisession")
set.seed(247)
# torch::torch_manual_seed(247)
rr_multisession <- resample(
  task = tsk("spam"),
  learner = lrn,
  resampling = rsmp("cv", folds = 5)
)
tictoc::toc()
future::plan(old_plan)
# Save the multisession result (previously rr_single was written here by
# mistake, so the multisession files contained the sequential result).
saveRDS(
  rr_multisession,
  here::here("attic", glue::glue("rr_multisession_{format(Sys.time(), format = '%Y%m%d%H%M%S')}.rds"))
)
# With multicore ----------------------------------------------------------
# Disabled: with plan("multicore") (fork-based workers) training fails with
# Error in (function (self, inputs, gradient, retain_graph, create_graph) :
# Unable to handle autograd's threading in combination with fork-based multiprocessing. See https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork
# future::plan("multicore")
# set.seed(247)
# rr_multicore <- resample(
# task = tsk("spam"),
# learner = lrn,
# resampling = rsmp("cv", folds = 5)
# )
# Compare every element of `xs` with the first one using all.equal().
#
# Returns TRUE when every element matches the first (also when `xs` has fewer
# than two elements, since there is nothing to compare); otherwise a character
# vector of the all.equal() difference messages.
#
# Why not purrr::reduce(xs, all.equal)? all.equal() returns TRUE or a message,
# not an object, so after the first step the fold compares that *result* with
# the next element, e.g. all.equal(TRUE, xs[[3]]), which reports a mismatch
# even when all elements are identical.
all_equal_to_first <- function(xs) {
  if (length(xs) < 2L) {
    return(TRUE)
  }
  first <- xs[[1L]]
  diffs <- lapply(xs[-1L], function(x) all.equal(first, x))
  mismatches <- unlist(Filter(Negate(isTRUE), diffs), use.names = FALSE)
  if (length(mismatches) == 0L) TRUE else mismatches
}

if (interactive()) {
  # Each run of this script saves timestamped results to attic/; re-running it
  # and checking here shows whether results are reproducible across runs.
  # Predictions are converted to data.tables per fold so the comparison is on
  # prediction values rather than on the R6 Prediction objects themselves.
  # Results are interpolated with {} so cli does not glue-parse the
  # all.equal() messages.
  cli::cli_h1("Checking sequential resampling")
  rr_singles <- purrr::map(fs::dir_ls(here::here("attic"), glob = "*rr_single*rds"), readRDS)
  cli::cli_text("Comparing {length(rr_singles)} saved run{?s}")
  single_pred_equal <- all_equal_to_first(
    purrr::map(rr_singles, function(rr) purrr::map(rr$predictions(), as.data.table))
  )
  cli::cli_text("Checking if predictions are the same: {single_pred_equal}")
  single_ce_equal <- all_equal_to_first(
    purrr::map(rr_singles, function(rr) rr$score()[["classif.ce"]])
  )
  cli::cli_text("Checking if scored CE is the same: {single_ce_equal}")

  cli::cli_h1("Checking multisession resampling")
  rr_multisess <- purrr::map(fs::dir_ls(here::here("attic"), glob = "*rr_multisession*rds"), readRDS)
  cli::cli_text("Comparing {length(rr_multisess)} saved run{?s}")
  multisess_pred_equal <- all_equal_to_first(
    purrr::map(rr_multisess, function(rr) purrr::map(rr$predictions(), as.data.table))
  )
  cli::cli_text("Checking if predictions are the same: {multisess_pred_equal}")
  multisess_ce_equal <- all_equal_to_first(
    purrr::map(rr_multisess, function(rr) rr$score()[["classif.ce"]])
  )
  cli::cli_text("Checking if scored CE is the same: {multisess_ce_equal}")
}