set_num_threads on Linux does not seem to work #1178

MrDomani opened this issue Jun 29, 2024 · 0 comments

Hi,

I'm using a fairly up-to-date Manjaro Linux. I've noticed that R's torch does not seem to utilize my CPU (AMD Ryzen 7, series 5000) to its full extent. Further, calling torch_set_num_threads does not seem to have any effect: the code takes roughly the same amount of time regardless of the thread count. Equivalent Python code does not have these issues.

I'm attaching a reproducible example, modelled after one in the documentation of the Python torch library. Let me know whether the setup is correct and whether you observe a similar effect on your side.

size <- 1024
set.seed(2024)
X <- matrix(runif(size^2), size, size)
Y <- matrix(runif(size^2), size, size)

for(n_threads in c(1,2,3,4,5,6,7,8)){
  torch::torch_set_num_threads(n_threads)
  library(torch)
  message <- paste0("Number of threads: ", torch_get_num_threads(), "\n")
  cat(message)
  t1 <- torch_tensor(X)
  t2 <- torch_tensor(Y)
  time_start <- Sys.time()
  out <- microbenchmark::microbenchmark(torch_matmul(t1,t2), times = 7)
  time_stop <- Sys.time()
  print((time_stop - time_start) / 7)
  print(out)
  detach(package:torch)
}

print(sessionInfo())

Output:

Number of threads: 1
Time difference of 0.4125318 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 385.0982 386.6031 392.8908 387.9618 390.5209 422.9279     7
Number of threads: 2
Time difference of 0.4029093 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 383.9286 385.3504 387.0242 387.8626 388.6715 389.3341     7
Number of threads: 3
Time difference of 0.412262 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 390.5602 393.3822 395.6617 394.6437 397.4655 402.7326     7
Number of threads: 4
Time difference of 0.4133101 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 389.1392 389.7427 398.4571 390.4203 390.6904 448.7739     7
Number of threads: 5
Time difference of 0.4015262 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 383.1957 385.1875 386.9126 387.8565 388.3859 390.1893     7
Number of threads: 6
Time difference of 0.4085842 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 387.7209 390.1149 391.9693 391.7547 393.8607 396.3586     7
Number of threads: 7
Time difference of 0.4061831 secs
Unit: milliseconds
                 expr      min       lq     mean   median       uq      max neval
 torch_matmul(t1, t2) 389.0678 389.5376 391.5841 390.2344 392.4632 397.7849     7
Number of threads: 8
Time difference of 0.4225501 secs
Unit: milliseconds
                 expr      min    lq    mean   median       uq      max neval
 torch_matmul(t1, t2) 396.3549 398.5 406.618 400.3949 404.1815 444.2131     7
R version 4.4.0 (2024-04-24)
Platform: x86_64-pc-linux-gnu
Running under: Manjaro Linux

Matrix products: default
BLAS:   /usr/lib/libblas.so.3.12.0 
LAPACK: /usr/lib/liblapack.so.3.12.0

locale:
 [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C               LC_TIME=pl_PL.UTF-8       
 [4] LC_COLLATE=en_GB.UTF-8     LC_MONETARY=pl_PL.UTF-8    LC_MESSAGES=en_GB.UTF-8   
 [7] LC_PAPER=pl_PL.UTF-8       LC_NAME=C                  LC_ADDRESS=C              
[10] LC_TELEPHONE=C             LC_MEASUREMENT=pl_PL.UTF-8 LC_IDENTIFICATION=C       

time zone: Europe/Warsaw
tzcode source: system (glibc)

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
 [1] microbenchmark_1.4.10 processx_3.8.4        bit_4.0.5             compiler_4.4.0       
 [5] magrittr_2.0.3        cli_3.6.2             tools_4.4.0           rstudioapi_0.16.0    
 [9] torch_0.13.0          Rcpp_1.0.12           bit64_4.0.5           coro_1.0.4           
[13] callr_3.7.6           ps_1.7.6              rlang_1.1.3  
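
To rule out the possibility that the thread count only takes effect when it is set before any tensor work in a given process, a variant of the benchmark could run each configuration in a fresh R process. The following is a minimal sketch under stated assumptions, not a confirmed fix: it assumes the callr package is installed, `run_once()` is a hypothetical helper name, and whether the backend honours the `OMP_NUM_THREADS` environment variable is an assumption that has not been verified here.

```r
# Sketch: time torch_matmul in a fresh R process for each thread count.
# run_once() is a hypothetical helper, not part of torch or callr.
run_once <- function(n_threads, size = 1024, reps = 7) {
  callr::r(
    function(n_threads, size, reps) {
      # Set the thread count before any tensor work happens in this process.
      torch::torch_set_num_threads(n_threads)
      set.seed(2024)
      t1 <- torch::torch_tensor(matrix(stats::runif(size^2), size, size))
      t2 <- torch::torch_tensor(matrix(stats::runif(size^2), size, size))
      # Collect the elapsed wall-clock time of each repetition.
      elapsed <- vapply(seq_len(reps), function(i) {
        system.time(torch::torch_matmul(t1, t2))[["elapsed"]]
      }, numeric(1))
      list(
        threads = torch::torch_get_num_threads(),
        median_sec = stats::median(elapsed)
      )
    },
    args = list(n_threads = n_threads, size = size, reps = reps),
    # Assumption: the backend may read OMP_NUM_THREADS at start-up.
    env = c(callr::rcmd_safe_env(), OMP_NUM_THREADS = as.character(n_threads))
  )
}

# Only run the comparison when both packages and the torch backend are available.
if (requireNamespace("callr", quietly = TRUE) &&
    requireNamespace("torch", quietly = TRUE) &&
    torch::torch_is_installed()) {
  results <- lapply(c(1, 2, 4, 8), run_once)
  print(do.call(rbind, lapply(results, as.data.frame)))
}
```

If the median time still does not change across thread counts in separate processes, that would suggest the problem is not related to when torch_set_num_threads is called within a session.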