Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dbWriteTable method for the CSV #254

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ connection_copy_data <- function(con, sql, df) {
invisible(.Call(`_RPostgres_connection_copy_data`, con, sql, df))
}

connection_copy_file <- function(con, sql, file) {
invisible(.Call(`_RPostgres_connection_copy_file`, con, sql, file))
}

connection_wait_for_notify <- function(con, timeout_secs) {
.Call(`_RPostgres_connection_wait_for_notify`, con, timeout_secs)
}
Expand Down
79 changes: 79 additions & 0 deletions R/tables.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,85 @@ setMethod("dbWriteTable", c("PqConnection", "character", "data.frame"),
)


#' @param header is a logical indicating whether the first data line (but see
#' `skip`) has a header or not. If missing, it value is determined
#' following [read.table()] convention, namely, it is set to TRUE if
#' and only if the first row has one fewer field that the number of columns.
#' @param sep The field separator, defaults to `','`.
#' @param eol The end-of-line delimiter, defaults to `'\n'`.
#' @param skip number of lines to skip before reading the data. Defaults to 0.
#' @param nrows Number of rows to read to determine types.
#' @param colClasses Character vector of R type names, used to override
#' defaults when imputing classes from on-disk file.
#' @param na.strings a character vector of strings which are to be interpreted as NA values.
#' @export
#' @rdname postgres-tables
setMethod("dbWriteTable", c("PqConnection", "character", "character"),
function(conn, name, value, ..., field.types = NULL, overwrite = FALSE,
append = FALSE, header = TRUE, colClasses = NA, row.names = FALSE,
nrows = 50, sep = ",", na.strings = "NA", eol = "\n", skip = 0, temporary = FALSE) {

if (!is.logical(overwrite) || length(overwrite) != 1L || is.na(overwrite)) {
stopc("`overwrite` must be a logical scalar")
}
if (!is.logical(append) || length(append) != 1L || is.na(append)) {
stopc("`append` must be a logical scalar")
}
if (!is.logical(temporary) || length(temporary) != 1L) {
stopc("`temporary` must be a logical scalar")
}
if (overwrite && append) {
stopc("overwrite and append cannot both be TRUE")
}
if (!is.null(field.types) && !(is.character(field.types) && !is.null(names(field.types)) && !anyDuplicated(names(field.types)))) {
stopc("`field.types` must be a named character vector with unique names, or NULL")
}
if (append && !is.null(field.types)) {
stopc("Cannot specify `field.types` with `append = TRUE`")
}

found <- dbExistsTable(conn, name)
if (found && !overwrite && !append) {
stop("Table ", name, " exists in database, and both overwrite and",
" append are FALSE", call. = FALSE)
}
if (found && overwrite) {
dbRemoveTable(conn, name)
}

if (!found || overwrite) {
if (is.null(field.types)) {
tmp_value <- utils::read.table(
value, sep = sep, header = header, skip = skip, nrows = nrows,
na.strings = na.strings, comment.char = "", colClasses = colClasses,
stringsAsFactors = FALSE)
field.types <- lapply(tmp_value, dbDataType, dbObj = conn)
}

dbCreateTable(
conn = conn,
name = name,
fields = field.types,
temporary = temporary
)
}

value <- path.expand(value)
fields <- dbQuoteIdentifier(conn, names(field.types))

skip <- skip + as.integer(header)
sql <- paste0(
"COPY ", dbQuoteIdentifier(conn, name),
" (", paste(fields, collapse = ","), ") ",
"FROM STDIN ", "(FORMAT CSV, DELIMITER '", sep, "', HEADER '", header, "')"
)

connection_copy_file(conn@ptr, sql, value)

invisible(TRUE)
}
)

#' @export
#' @inheritParams DBI::sqlRownamesToColumn
#' @param ... Ignored.
Expand Down
38 changes: 38 additions & 0 deletions man/postgres-tables.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 44 additions & 0 deletions src/DbConnection.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "pch.h"
#include <fstream>
#include "DbConnection.h"
#include "encode.h"
#include "DbResult.h"
Expand Down Expand Up @@ -164,6 +165,49 @@ void DbConnection::copy_data(std::string sql, List df) {
PQclear(pComplete);
}

void DbConnection::copy_csv(std::string sql, std::string file) {
LOG_DEBUG << sql;

if (file.size() == 0)
return;

PGresult* pInit = PQexec(pConn_, sql.c_str());
if (PQresultStatus(pInit) != PGRES_COPY_IN) {
PQclear(pInit);
conn_stop("Failed to initialise COPY");
}
PQclear(pInit);


const size_t buffer_size = 1024 * 64;
std::string buffer;
buffer.reserve(buffer_size);

std::ifstream fs(file.c_str(), std::ios::in);
if (!fs.is_open()) {
stop("Can not open file '%s'.", file);
}

while (!fs.eof()) {
buffer.clear();
fs.read(&buffer[0], buffer_size);
if (PQputCopyData(pConn_, buffer.data(), static_cast<int>(fs.gcount())) != 1) {
conn_stop("Failed to put data");
}
}

if (PQputCopyEnd(pConn_, NULL) != 1) {
conn_stop("Failed to finish COPY");
}

PGresult* pComplete = PQgetResult(pConn_);
if (PQresultStatus(pComplete) != PGRES_COMMAND_OK) {
PQclear(pComplete);
conn_stop("COPY returned error");
}
PQclear(pComplete);
}

void DbConnection::check_connection() {
if (!pConn_) {
stop("Disconnected");
Expand Down
2 changes: 2 additions & 0 deletions src/DbConnection.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class DbConnection : boost::noncopyable {

void copy_data(std::string sql, List df);

void copy_csv(std::string sql, std::string file);

void check_connection();
List info();

Expand Down
12 changes: 12 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ BEGIN_RCPP
return R_NilValue;
END_RCPP
}
// connection_copy_file
void connection_copy_file(DbConnection* con, std::string sql, std::string file);
RcppExport SEXP _RPostgres_connection_copy_file(SEXP conSEXP, SEXP sqlSEXP, SEXP fileSEXP) {
BEGIN_RCPP
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< DbConnection* >::type con(conSEXP);
Rcpp::traits::input_parameter< std::string >::type sql(sqlSEXP);
Rcpp::traits::input_parameter< std::string >::type file(fileSEXP);
connection_copy_file(con, sql, file);
return R_NilValue;
}
// connection_wait_for_notify
List connection_wait_for_notify(DbConnection* con, int timeout_secs);
RcppExport SEXP _RPostgres_connection_wait_for_notify(SEXP conSEXP, SEXP timeout_secsSEXP) {
Expand Down Expand Up @@ -293,6 +304,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_RPostgres_connection_is_transacting", (DL_FUNC) &_RPostgres_connection_is_transacting, 1},
{"_RPostgres_connection_set_transacting", (DL_FUNC) &_RPostgres_connection_set_transacting, 2},
{"_RPostgres_connection_copy_data", (DL_FUNC) &_RPostgres_connection_copy_data, 3},
{"_RPostgres_connection_copy_file", (DL_FUNC) &_RPostgres_connection_copy_file, 3},
{"_RPostgres_connection_wait_for_notify", (DL_FUNC) &_RPostgres_connection_wait_for_notify, 2},
{"_RPostgres_encode_vector", (DL_FUNC) &_RPostgres_encode_vector, 1},
{"_RPostgres_encode_data_frame", (DL_FUNC) &_RPostgres_encode_data_frame, 1},
Expand Down
4 changes: 4 additions & 0 deletions src/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ void connection_copy_data(DbConnection* con, std::string sql, List df) {
}

// [[Rcpp::export]]
void connection_copy_file(DbConnection* con, std::string sql, std::string file) {
return con->copy_csv(sql, file);
}

List connection_wait_for_notify(DbConnection* con, int timeout_secs) {
return con->wait_for_notify(timeout_secs);
}
Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/test-dbWriteTable.R
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ with_database_connection({
expect_equal(dbGetQuery(con, "SELECT * FROM xy"), value)
})
})

test_that("Writing CSV to the database", {
with_table(con, "iris", {
tmp <- tempfile()
iris2 <- transform(iris, Species = as.character(Species))
write.csv(iris2, tmp, row.names = FALSE)
dbWriteTable(con, "iris", tmp, temporary = TRUE)
expect_equal(dbReadTable(con, "iris"), iris2)
})
})
})

describe("Inf values", {
Expand Down