
Introducing data.table dependency to significantly improve read_stan_csv read times #1018

Open
SimonCMills opened this issue Jun 20, 2022 · 6 comments

SimonCMills commented Jun 20, 2022

read_stan_csv becomes very slow once models have thousands of parameters, with the bottleneck occurring at the read stage (see e.g. paul-buerkner/brms#1331). A very simple solution would be to replace the code block spanning lines 143:161 of stan_csv.R with code that instead reads in the csv via data.table::fread(). I've worked on this a bit with @jsocolar and found that the readLines() approach that rstan currently uses is faster up to around ~1000 parameters, at which point it becomes increasingly slow relative to fread(). Across the range of csv sizes for which the readLines() approach is faster, fread() is also fast (<1-2 seconds for a single model), which I think is probably a trivial slowdown for most purposes. Conversely, by the time you are reading a csv with around 5000 parameters you save >10 seconds by using fread() (a ~30% saving relative to readLines()), and both the proportional and absolute savings continue to widen as the number of parameters increases.

Would you be willing to consider introducing a data.table dependency in order to achieve speedups?

Code checking the timings and equivalence of the two methods is below (comparing just the initial code block, which is then used downstream in the rest of the function). Checking equivalence of the two methods is complicated by occasional minor floating point differences between data.table and base R (discussed e.g. here). The new code borrows heavily from the cmdstanr implementation of this function.

# comparing header section of rstan::read_stan_csv (in read_initial), and 
# updated version that makes use of data.table::fread (in read_initial2). 
# Check for equivalence and compare timings.

library(cmdstanr)
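# data.table and rstan also need to be installed; they are called below via
# data.table::fread() and rstan::: helpers, so no library() call is required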

model_code <- "
parameters {
  vector[10] x;
}
model {
  x ~ std_normal();
}
"
stan_file <- cmdstanr::write_stan_file(code = model_code)
mod <- cmdstan_model(stan_file)
fit <- mod$sample(parallel_chains = 4)
fit_warmup <- mod$sample(parallel_chains = 4, save_warmup = TRUE)
# note: sampling with a dense metric becomes slow when the number of parameters 
# gets large (e.g. around 500 is already slow on my machine). 
fit_dense_warmup <- mod$sample(parallel_chains = 4, metric = "dense_e", save_warmup = TRUE)
fit_optimize <- mod$optimize()
fit_variational <- mod$variational()

test_suite <- list(
  fit$output_files()[[1]], 
  fit$output_files(), 
  fit_warmup$output_files(),
  fit_dense_warmup$output_files(),
  fit_optimize$output_files(),
  fit_variational$output_files()
)

# code from read_stan_csv, taken verbatim from lines 127:161 of stan_csv.R 
read_initial <- function(csvfiles) {
  # Read the csv files saved from Stan (or RStan) to a stanfit object
  # Args:
  #   csvfiles: csv files fitted for the same model; each file contains 
  #     the sample of one chain 
  if (length(csvfiles) < 1) 
    stop("csvfiles does not contain any CSV file name")
  
  g_skip <- 10 # g_skip is never used anywhere.
  
  ss_lst <- vector("list", length(csvfiles))
  cs_lst2 <- vector("list", length(csvfiles))
  
  for (i in seq_along(csvfiles)) {
    f = csvfiles[i]    
    header <- rstan:::read_csv_header(f)
    lineno <- attr(header, 'lineno') 
    vnames <- strsplit(header, ",")[[1]] 
    iter.count <- attr(header,"iter.count") 
    variable.count <- length(vnames) 
    
    lines = readLines(f)
    comment_lines = grep("^#", lines)
    comments = lines[comment_lines]
    con = textConnection(lines[-comment_lines])
    on.exit(close(con))
    df = read.csv(con, colClasses = "numeric")
    cs_lst2[[i]] <- rstan:::parse_stancsv_comments(comments)
    if("output_samples" %in% names(cs_lst2[[i]])) 
      df <- df[-1,] # remove the means 
    ss_lst[[i]] <- df
  } 
  list(csvfiles = csvfiles, ss_lst = ss_lst, cs_lst2 = cs_lst2,
       f = f) 
}

# updated version to use fread, code inherited from cmdstanr

# repair path helper function 
# verbatim from cmdstanr:::repair_path
repair_path <- function(path) {
  if (!length(path) || !is.character(path)) {
    return(path)
  }
  path <- path.expand(path)
  path <- gsub("\\\\", "/", path)
  path <- gsub("//", "/", path)
  if (endsWith(path, "/")) {
    path <- substr(path, 1, nchar(path) - 1)
  }
  path
}

read_initial2 <- function(csvfiles) {
  if (length(csvfiles) < 1) 
    stop("csvfiles does not contain any CSV file name")
  
  ss_lst <- vector("list", length(csvfiles))
  cs_lst2 <- vector("list", length(csvfiles))
  
  for (i in seq_along(csvfiles)) {
    f = csvfiles[i]    
    
    # get non-comment component, "df" (code from cmdstanr:::read_cmdstan_csv)
    if (isTRUE(.Platform$OS.type == "windows")) {
      grep_path <- repair_path(Sys.which("grep.exe"))
      fread_cmd <- paste0(grep_path, " -v '^#' --color=never '", 
                          f, "'")
    } else {
      fread_cmd <- paste0("grep -v '^#' --color=never '", 
                          f, "'")
    }
    
    df <- data.table::fread(cmd = fread_cmd, data.table = FALSE, 
                            colClasses = "numeric")
    
    # get comments (code from cmdstanr:::read_csv_metadata)
    if (isTRUE(.Platform$OS.type == "windows")) {
      grep_path <- repair_path(Sys.which("grep.exe"))
      fread_cmd <- paste0(grep_path, " '^[#a-zA-Z]' --color=never '", 
                          f, "'")
    } else {
      fread_cmd <- paste0("grep '^[#a-zA-Z]' --color=never '", 
                          f, "'")
    }
    
    suppressWarnings(metadata <- data.table::fread(cmd = fread_cmd, 
                                                   colClasses = "character", 
                                                   stringsAsFactors = FALSE, 
                                                   fill = TRUE, sep = "", 
                                                   header = FALSE, 
                                                   data.table=FALSE))
    
    # minor reformatting
    metadata2 <- metadata[,1][grepl("#", metadata[,1])]
    metadata3 <- sub("#$", "# ", metadata2)
    
    cs_lst2[[i]] <- rstan:::parse_stancsv_comments(metadata3)
    if("output_samples" %in% names(cs_lst2[[i]])) df <- df[-1,] # remove the means 
    ss_lst[[i]] <- df
  } 
  
  list(csvfiles = csvfiles, ss_lst = ss_lst, cs_lst2 = cs_lst2,
       f = f)
}

time1 <- system.time(ts1 <- lapply(test_suite, read_initial))
time2 <- system.time(ts2 <- lapply(test_suite, read_initial2))

# often identical, but not always, due to floating point differences in non-comment
# component
identical(ts1, ts2)
all.equal(ts1, ts2, tolerance = 1e-15)

# check components individually: everything except the draws should be identical;
# the draws (the non-comment component) should be exactly or nearly equal
for (i in seq_along(ts1)) {
  print(identical(ts1[[i]][-2], ts2[[i]][-2]))
}

for (i in seq_along(ts1)) {
  print(paste0("exactly 0 (", i, "): ", all(ts1[[i]][2]$ss_lst[[1]] - 
                                              ts2[[i]][2]$ss_lst[[1]] == 0)))
}

for (i in seq_along(ts1)) {
  print(paste0("almost 0 (", i, "): ", all(abs(ts1[[i]][2]$ss_lst[[1]] - 
                                                 ts2[[i]][2]$ss_lst[[1]]) < 1e-14)))
}
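
# for convenience (not part of the original comparison): elapsed seconds for 
# each approach side by side; absolute numbers will vary with model size, and 
# the gap widens as the number of parameters grows
rbind(readLines = time1, fread = time2)[, "elapsed"]
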
@jsocolar

Just to chime in that by the time you get up to 1e+5 parameters, the readLines approach takes hours at least, while the fread approach remains fast. Tagging @bgoodri and @hsbadr: if the data.table dependency is a problem in rstan, then we will re-implement the fast version of read_stan_csv in brms, where it gets used to generate brmsfit objects when using the cmdstanr backend.

helske commented Apr 4, 2024

Is this still under development, or have there been other solutions somewhere that I have missed? This is quite a big issue with large models, where the actual sampling is sometimes faster than returning the object to R.

jsocolar commented Apr 4, 2024

No action on the rstan side, but a function that achieves exactly this has been implemented in brms here: paul-buerkner/brms#1400

helske commented Apr 22, 2024

Unfortunately, I do not have access to cmdstanr on our server, so solutions based on that do not help. I probably need to write my own function which uses fread then.

@SimonCMills (Author)

If you can't make use of cmdstanr, then it's easy enough to construct your own version that is also fast (so long as you can make use of data.table's fread()). It's just a case of parsing the header comments and repackaging the draws into whatever format you like. If you want to repackage it into a stanfit object, then you might find this useful: paul-buerkner/brms@ad9b9bd
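
A minimal sketch of that idea, reusing the same grep + fread() trick as in read_initial2 above (an illustration only, assuming data.table is installed and a grep binary is on the PATH; read_stan_csv_fast is a made-up name):

read_stan_csv_fast <- function(f) {
  # draws: drop the commented lines and let fread() parse the numeric table
  draws <- data.table::fread(cmd = paste0("grep -v '^#' --color=never '", f, "'"),
                             data.table = FALSE, colClasses = "numeric")
  # metadata: keep only the comment lines, to be repackaged however you like
  suppressWarnings(
    comments <- data.table::fread(cmd = paste0("grep '^#' --color=never '", f, "'"),
                                  sep = "", header = FALSE, fill = TRUE,
                                  colClasses = "character", data.table = FALSE)[[1]]
  )
  list(draws = draws, comments = comments)
}

# usage: one draws/comments pair per chain
# samples <- lapply(csvfiles, read_stan_csv_fast)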

helske commented Apr 23, 2024

Yeah, I wrote a function for reading the CSVs with fread for this purpose; fortunately, I don't need to recreate the stanfit object here.
