Skip to content

Commit 8f5515e

Browse files
authored
Merge pull request #156 from drieslab/dev
enh: allow explicit col_class setting with fread colmatch
2 parents a2b6a73 + e126dff commit 8f5515e

File tree

2 files changed

+53
-17
lines changed

2 files changed

+53
-17
lines changed

R/file_read.R

+46-17
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ dir_manifest <- function(
8484
#' @param drop Vector of column names or numbers to drop, keep the rest.
8585
#' @param schema_detect_nrow numeric. how many rows to sample to guess the
8686
#' arrow schema to use.
87+
#' @param col_classes character vector (optional). R types each column is
88+
#' expected to be. These will be translated to arrow schema. Only necessary
89+
#' if the schema autodetection from `schema_detect_nrow` is insufficient.
90+
#' Select one or multiple of "integer", "double", "raw", "character",
91+
#' "logical", "Date", "POSIXct"
8792
#' @param verbose be verbose
8893
#' @param ... additional parameters to pass to [arrow::open_delim_dataset()]
8994
#' @keywords internal
@@ -106,11 +111,13 @@ read_colmatch <- function(file,
106111
values_to_match,
107112
drop = NULL,
108113
schema_detect_nrow = 1000,
114+
col_classes = NULL,
109115
verbose = FALSE,
110116
...) {
111117
# check dependencies
112118
package_check("dplyr")
113119
.arrow_codec_check(file)
120+
checkmate::assert_character(col)
114121

115122
file <- normalizePath(file)
116123
# get colnames
@@ -125,9 +132,15 @@ read_colmatch <- function(file,
125132
stop("read_colmatch: sep param cannot be guessed", call. = FALSE)
126133
}
127134
}
135+
136+
if (is.null(col_classes)) {
137+
s <- .arrow_infer_schema(file, n_rows = schema_detect_nrow)
138+
} else {
139+
s <- .arrow_infer_schema(file, col_classes = col_classes)
140+
}
128141

129142
a <- arrow::read_delim_arrow(file,
130-
schema = .arrow_infer_schema(file, n_rows = schema_detect_nrow),
143+
schema = s,
131144
skip = 1L,
132145
delim = sep,
133146
...
@@ -234,27 +247,32 @@ fread_colmatch <- function(...) {
234247
}
235248

236249
# Use data.table to get a sample and infer schema
237-
.arrow_infer_schema <- function(file, n_rows = 1000) {
250+
.arrow_infer_schema <- function(file, n_rows = 1000, col_classes = NULL) {
251+
# reduce if 'col_classes' given since we then only care about colnames
252+
if (!is.null(col_classes)) n_rows <- 10
238253
lines <- readLines(file, n = n_rows)
239254
# Parse with fread as string input
240255
sample_dt <- data.table::fread(paste(lines, collapse = "\n"))
241256

242-
# Map data.table/R types to Arrow types
243-
type_mapping <- list(
244-
"integer" = arrow::int32(),
245-
"double" = arrow::float64(),
246-
"raw" = arrow::binary(),
247-
"character" = arrow::string(),
248-
"logical" = arrow::boolean(),
249-
"Date" = arrow::date32(),
250-
"POSIXct" = arrow::timestamp("us")
251-
)
252-
253257
# Create schema
254-
fields <- lapply(sample_dt, function(col) {
255-
col_type <- type_mapping[[typeof(col)]] %null% arrow::string()
256-
col_type
257-
})
258+
if (!is.null(col_classes)) { # exact class given
259+
if (length(col_classes) != ncol(sample_dt)) {
260+
message("data preview:")
261+
print(sample_dt)
262+
stop(sprintf(
263+
".arrow_infer_schema: 'col_classes' incorrect:\n %s\n %s",
264+
paste(length(col_classes), "col_classes provided"),
265+
paste(ncol(sample_dt), "columns in data")
266+
), call. = FALSE)
267+
}
268+
fields <- lapply(col_classes, function(col_class) {
269+
.arrow_type_map()[[col_class]] %null% arrow::string()
270+
})
271+
} else { # discern class based on sampled dt
272+
fields <- lapply(sample_dt, function(col) {
273+
.arrow_type_map()[[typeof(col)]] %null% arrow::string()
274+
})
275+
}
258276

259277
# If there were no column names, generate them
260278
if (all(grepl("^V[0-9]+$", names(sample_dt)))) {
@@ -265,3 +283,14 @@ fread_colmatch <- function(...) {
265283

266284
arrow::schema(!!!fields)
267285
}
286+
287+
# Map R types to Arrow types
288+
.arrow_type_map <- function(timestamp_locale = "us") list(
289+
"integer" = arrow::int32(),
290+
"double" = arrow::float64(),
291+
"raw" = arrow::binary(),
292+
"character" = arrow::string(),
293+
"logical" = arrow::boolean(),
294+
"Date" = arrow::date32(),
295+
"POSIXct" = arrow::timestamp(timestamp_locale)
296+
)

man/read_colmatch.Rd

+7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)