diff --git a/R-package/R/xgb.DMatrix.R b/R-package/R/xgb.DMatrix.R
index d9335405c68a..3aade2396d0f 100644
--- a/R-package/R/xgb.DMatrix.R
+++ b/R-package/R/xgb.DMatrix.R
@@ -54,7 +54,10 @@ xgb.DMatrix <- function(data, info = list(), missing = NA, silent = FALSE, nthre
     stop("xgb.DMatrix does not support construction from ", typeof(data))
   }
   dmat <- handle
-  attributes(dmat) <- list(.Dimnames = list(NULL, cnames), class = "xgb.DMatrix")
+  attributes(dmat) <- list(class = "xgb.DMatrix")
+  if (!is.null(cnames)) {
+    setinfo(dmat, "feature_name", cnames)
+  }
 
   info <- append(info, list(...))
   for (i in seq_along(info)) {
@@ -144,7 +147,9 @@ dim.xgb.DMatrix <- function(x) {
 #' @rdname dimnames.xgb.DMatrix
 #' @export
 dimnames.xgb.DMatrix <- function(x) {
-  attr(x, '.Dimnames')
+  fn <- getinfo(x, "feature_name")
+  ## row names is null.
+  list(NULL, fn)
 }
 
 #' @rdname dimnames.xgb.DMatrix
@@ -155,13 +160,13 @@ dimnames.xgb.DMatrix <- function(x) {
   if (!is.null(value[[1L]]))
     stop("xgb.DMatrix does not have rownames")
   if (is.null(value[[2]])) {
-    attr(x, '.Dimnames') <- NULL
+    setinfo(x, "feature_name", NULL)
     return(x)
   }
-  if (ncol(x) != length(value[[2]]))
-    stop("can't assign ", length(value[[2]]), " colnames to a ",
-         ncol(x), " column xgb.DMatrix")
-  attr(x, '.Dimnames') <- value
+  if (ncol(x) != length(value[[2]])) {
+    stop("can't assign ", length(value[[2]]), " colnames to a ", ncol(x), " column xgb.DMatrix")
+  }
+  setinfo(x, "feature_name", value[[2]])
   x
 }
 
@@ -203,13 +208,17 @@ getinfo <- function(object, ...) UseMethod("getinfo")
 #' @export
 getinfo.xgb.DMatrix <- function(object, name, ...) {
   if (typeof(name) != "character" ||
-      length(name) != 1 ||
-      !name %in% c('label', 'weight', 'base_margin', 'nrow',
-                   'label_lower_bound', 'label_upper_bound')) {
-    stop("getinfo: name must be one of the following\n",
-         " 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound'")
+        length(name) != 1 ||
+        !name %in% c('label', 'weight', 'base_margin', 'nrow',
+                     'label_lower_bound', 'label_upper_bound', "feature_type", "feature_name")) {
+    stop(
+      "getinfo: name must be one of the following\n",
+      " 'label', 'weight', 'base_margin', 'nrow', 'label_lower_bound', 'label_upper_bound', 'feature_type', 'feature_name'"
+    )
   }
-  if (name != "nrow"){
+  if (name == "feature_name" || name == "feature_type") {
+    ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
+  } else if (name != "nrow"){
     ret <- .Call(XGDMatrixGetInfo_R, object, name)
   } else {
     ret <- nrow(object)
@@ -294,6 +303,30 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
     .Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
     return(TRUE)
   }
+
+  set_feat_info <- function(name) {
+    msg <- sprintf(
+      "The number of %s must equal to the number of columns in the input data. %s vs. %s",
+      name,
+      length(info),
+      ncol(object)
+    )
+    if (!is.null(info)) {
+      info <- as.list(info)
+      if (length(info) != ncol(object)) {
+        stop(msg)
+      }
+    }
+    .Call(XGDMatrixSetStrFeatureInfo_R, object, name, info)
+  }
+  if (name == "feature_name") {
+    set_feat_info("feature_name")
+    return(TRUE)
+  }
+  if (name == "feature_type") {
+    set_feat_info("feature_type")
+    return(TRUE)
+  }
   stop("setinfo: unknown info name ", name)
   return(FALSE)
 }
diff --git a/R-package/src/init.c b/R-package/src/init.c
index 4e38f8220a86..13b21fd96c92 100644
--- a/R-package/src/init.c
+++ b/R-package/src/init.c
@@ -42,10 +42,12 @@ extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
 extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
 extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
 extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
+extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
 extern SEXP XGDMatrixNumCol_R(SEXP);
 extern SEXP XGDMatrixNumRow_R(SEXP);
 extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
 extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
+extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
 extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
 extern SEXP XGBSetGlobalConfig_R(SEXP);
 extern SEXP XGBGetGlobalConfig_R();
@@ -78,10 +80,12 @@ static const R_CallMethodDef CallEntries[] = {
   {"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
   {"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
   {"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
+  {"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
   {"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
   {"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
   {"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
   {"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
+  {"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3},
   {"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
   {"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
   {"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc
index 2383eb9a6ec4..0fe56be1711d 100644
--- a/R-package/src/xgboost_R.cc
+++ b/R-package/src/xgboost_R.cc
@@ -249,15 +249,59 @@ XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
   return R_NilValue;
 }
 
+// Stores a string-valued feature field (selected by `field`) on the DMatrix behind `handle`.
+// A NULL `array` is forwarded to the C API with length 0.
+XGB_DLL SEXP XGDMatrixSetStrFeatureInfo_R(SEXP handle, SEXP field, SEXP array) {
+  R_API_BEGIN();
+  size_t len{0};
+  if (!isNull(array)) {
+    len = length(array);
+  }
+
+  const char *name = CHAR(asChar(field));
+  // Own the strings here; `vec` below only borrows their c_str() pointers.
+  std::vector<std::string> str_info;
+  for (size_t i = 0; i < len; ++i) {
+    str_info.emplace_back(CHAR(asChar(VECTOR_ELT(array, i))));
+  }
+  std::vector<char const *> vec(len);
+  std::transform(str_info.cbegin(), str_info.cend(), vec.begin(),
+                 [](auto const &str) { return str.c_str(); });
+  CHECK_CALL(XGDMatrixSetStrFeatureInfo(R_ExternalPtrAddr(handle), name, vec.data(), len));
+  R_API_END();
+  return R_NilValue;
+}
+
+// Reads the string-valued feature field selected by `field` from the DMatrix behind `handle`.
+// Returns a character vector, or NULL when the field has no entries.
+XGB_DLL SEXP XGDMatrixGetStrFeatureInfo_R(SEXP handle, SEXP field) {
+  SEXP ret;
+  R_API_BEGIN();
+  char const **out_features{nullptr};
+  bst_ulong len{0};
+  const char *name = CHAR(asChar(field));
+  // Route the status code through CHECK_CALL, as the sibling wrappers do, rather than dropping it.
+  CHECK_CALL(XGDMatrixGetStrFeatureInfo(R_ExternalPtrAddr(handle), name, &len, &out_features));
+
+  if (len > 0) {
+    ret = PROTECT(allocVector(STRSXP, len));
+    for (size_t i = 0; i < len; ++i) {
+      SET_STRING_ELT(ret, i, mkChar(out_features[i]));
+    }
+  } else {
+    ret = PROTECT(R_NilValue);
+  }
+  R_API_END();
+  UNPROTECT(1);
+  return ret;
+}
+
 XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
   SEXP ret;
   R_API_BEGIN();
   bst_ulong olen;
   const float *res;
-  CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle),
-                                   CHAR(asChar(field)),
-                                   &olen,
-                                   &res));
+  CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), &olen, &res));
   ret = PROTECT(allocVector(REALSXP, olen));
   for (size_t i = 0; i < olen; ++i) {
     REAL(ret)[i] = res[i];
diff --git a/R-package/tests/testthat/test_dmatrix.R b/R-package/tests/testthat/test_dmatrix.R
index b4f2b6ff31a3..eb83544d8d16 100644
--- a/R-package/tests/testthat/test_dmatrix.R
+++ b/R-package/tests/testthat/test_dmatrix.R
@@ -42,6 +42,20 @@ test_that("xgb.DMatrix: saving, loading", {
   dtest4 <- xgb.DMatrix(tmp_file, silent = TRUE)
   expect_equal(dim(dtest4), c(3, 4))
   expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))
+
+  # check that feature info is saved
+  data(agaricus.train, package = 'xgboost')
+  dtrain <- xgb.DMatrix(data = agaricus.train$data, label = agaricus.train$label)
+  cnames <- colnames(dtrain)
+  expect_equal(length(cnames), 126)
+  tmp_file <- tempfile('xgb.DMatrix_')
+  xgb.DMatrix.save(dtrain, tmp_file)
+  dtrain <- xgb.DMatrix(tmp_file)
+  expect_equal(colnames(dtrain), cnames)
+
+  ft <- rep(c("c", "q"), each=length(cnames)/2)
+  setinfo(dtrain, "feature_type", ft)
+  expect_equal(ft, getinfo(dtrain, "feature_type"))
 })
 
 test_that("xgb.DMatrix: getinfo & setinfo", {