Add tests.
trivialfis committed Jun 7, 2021
1 parent 48414aa commit eb1ee9a
Showing 3 changed files with 69 additions and 3 deletions.
8 changes: 6 additions & 2 deletions R-package/R/xgb.Booster.R
@@ -410,10 +410,14 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
  cnames <- if (!is.null(colnames(newdata))) c(colnames(newdata), "BIAS") else NULL
  if (predcontrib) {
    dimnames(arr) <- list(cnames, NULL, NULL)
-   arr <- aperm(a = arr, perm = c(2, 3, 1)) # [group, row, col]
+   if (!strict_shape) {
+     arr <- aperm(a = arr, perm = c(2, 3, 1)) # [group, row, col]
+   }
  } else if (predinteraction) {
    dimnames(arr) <- list(cnames, cnames, NULL, NULL)
-   arr <- aperm(a = arr, perm = c(3, 4, 1, 2)) # [group, row, col, col]
+   if (!strict_shape) {
+     arr <- aperm(a = arr, perm = c(3, 4, 1, 2)) # [group, row, col, col]
+   }
  }

  if (!strict_shape) {
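
For readers more used to numpy, the two `aperm` permutations that this hunk now applies only when `strict_shape` is off can be sketched as follows. This is an analogue with made-up sizes, not code from the package. The helper names are hypothetical, and the input dimensions are assumed to be the ones the new tests in this commit assert for `strict_shape = TRUE`. Note that numpy axes are 0-based, so `c(2, 3, 1)` becomes `(1, 2, 0)`.

```python
import numpy as np


def legacy_contrib_layout(arr):
    """numpy analogue of aperm(arr, perm = c(2, 3, 1)) -> [group, row, col]."""
    return np.transpose(arr, (1, 2, 0))


def legacy_interaction_layout(arr):
    """numpy analogue of aperm(arr, perm = c(3, 4, 1, 2)) -> [group, row, col, col]."""
    return np.transpose(arr, (2, 3, 0, 1))


# Illustrative sizes only (assumptions, not values from the commit).
n_cols, n_groups, n_rows = 4, 3, 5
contrib = np.zeros((n_cols + 1, n_groups, n_rows))
interact = np.zeros((n_cols + 1, n_cols + 1, n_groups, n_rows))

print(legacy_contrib_layout(contrib).shape)       # group, row, col
print(legacy_interaction_layout(interact).shape)  # group, row, col, col
```

With `strict_shape = TRUE` the permutation is skipped, so the array keeps the dimension order it had on entry to this branch.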
55 changes: 55 additions & 0 deletions R-package/tests/testthat/test_basic.R
@@ -390,3 +390,58 @@ test_that("Configuration works", {
    reloaded_config <- xgb.config(bst)
    expect_equal(config, reloaded_config);
  })
+
+ test_that("strict_shape works", {
+   n_rounds = 2
+
+   test_strict_shape <- function(bst, X, n_groups) {
+     predt = predict(bst, X, strict_shape = TRUE)
+     margin = predict(bst, X, outputmargin = TRUE, strict_shape = TRUE)
+     contri = predict(bst, X, predcontrib = TRUE, strict_shape = TRUE)
+     interact = predict(bst, X, predinteraction = TRUE, strict_shape = TRUE)
+     leaf = predict(bst, X, predleaf = TRUE, strict_shape = TRUE)
+
+     n_rows <- nrow(X)
+     n_cols <- ncol(X)
+
+     expect_equal(dim(predt), c(n_groups, n_rows))
+     expect_equal(dim(margin), c(n_groups, n_rows))
+     expect_equal(dim(contri), c(n_cols + 1, n_groups, n_rows))
+     expect_equal(dim(interact), c(n_cols + 1, n_cols + 1, n_groups, n_rows))
+     expect_equal(dim(leaf), c(1, n_groups, n_rounds, n_rows))
+
+     if (n_groups != 1) {
+       print(seq_len(n_groups))
+       for (g in seq_len(n_groups)) {
+         expect_lt(max(abs(colSums(contri[, g, ]) - margin[g, ])), 1e-5)
+       }
+     }
+   }
+
+   test_iris <- function() {
+     y <- as.numeric(iris$Species) - 1
+     X <- as.matrix(iris[, -5])
+
+     bst <- xgboost(data = X, label = y,
+                    max_depth = 2, nrounds = n_rounds,
+                    objective = "multi:softprob", num_class = 3, eval_metric = "merror")
+
+     test_strict_shape(bst, X, 3)
+   }
+
+
+   test_agaricus <- function() {
+     data(agaricus.train, package = 'xgboost')
+     X <- agaricus.train$data
+     y <- agaricus.train$label
+
+     bst <- xgboost(data = X, label = y, max_depth = 2,
+                    nrounds = n_rounds, objective = "binary:logistic",
+                    eval_metric = 'error', eval_metric = 'auc', eval_metric = "logloss")
+
+     test_strict_shape(bst, X, 1)
+   }
+
+   test_iris()
+   test_agaricus()
+ })
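
The consistency check in `test_strict_shape` (per-group contributions, including the bias term, should sum to the margin) can be sketched in numpy. The arrays here are synthetic stand-ins, not XGBoost output, and `max_contrib_gap` is a hypothetical helper. The layout assumed is the strict one asserted in the test: `contri` is `(n_cols + 1, n_groups, n_rows)` and `margin` is `(n_groups, n_rows)`.

```python
import numpy as np


def max_contrib_gap(contri, margin):
    """Largest absolute gap between summed contributions and the margin.

    contri: (n_cols + 1, n_groups, n_rows), margin: (n_groups, n_rows).
    Summing over axis 0 mirrors colSums(contri[, g, ]) for every group g.
    """
    return float(np.max(np.abs(contri.sum(axis=0) - margin)))


# Synthetic data for illustration only: the margin is defined as the sum
# of the contributions, so the gap is zero by construction.
rng = np.random.default_rng(0)
contri = rng.normal(size=(5, 3, 10))
margin = contri.sum(axis=0)

assert max_contrib_gap(contri, margin) < 1e-5

# Perturbing one margin entry makes the check fail, as expected.
broken = margin.copy()
broken[1, 4] += 0.5
assert max_contrib_gap(contri, broken) > 1e-5
```

The R test loops over groups and compares one slice at a time. The vectorized form above checks all groups in one call.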
9 changes: 8 additions & 1 deletion doc/prediction.rst
@@ -6,7 +6,7 @@ Prediction

There are a number of prediction functions in XGBoost with various parameters. This
document attempts to clarify some of the confusion around prediction, with a focus on the
- Python binding.
+ Python binding. The R package is similar when ``strict_shape`` is specified (see below).

******************
Prediction Options
@@ -58,6 +58,13 @@ After 1.4 release, we added a new parameter called ``strict_shape``, one can set
``apply`` method in scikit learn interface, this is set to False by default.


+ For the R package, when ``strict_shape`` is specified, an ``array`` is returned with the
+ same values as in Python. Because R arrays are column-major while Python numpy arrays are
+ row-major, all the dimensions are reversed. For example, a Python ``predict_leaf`` output
+ obtained with ``strict_shape=True`` has 4 dimensions: ``(n_samples, n_iterations,
+ n_classes, n_trees_in_forest)``, while R with ``strict_shape=TRUE`` outputs
+ ``(n_trees_in_forest, n_classes, n_iterations, n_samples)``.

Other than these prediction types, there's also a parameter called ``iteration_range``,
which is similar to model slicing. But instead of actually splitting up the model into
multiple stacks, it simply returns the prediction formed by the trees within range.
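
The dimension reversal described for the R package in the hunk above can be illustrated with numpy alone. This is a sketch that assumes the same flat buffer is simply read in column-major order on the R side; it does not call XGBoost. `reinterpret_as_column_major` is a hypothetical helper and the sizes are made up.

```python
import numpy as np


def reinterpret_as_column_major(arr):
    """Read a row-major buffer as column-major with the dimensions reversed."""
    flat = np.ascontiguousarray(arr).ravel(order="C")
    return flat.reshape(arr.shape[::-1], order="F")


# Illustrative sizes only (assumptions).
n_samples, n_iterations, n_classes, n_trees_in_forest = 5, 2, 3, 1
py_shape = (n_samples, n_iterations, n_classes, n_trees_in_forest)

# Stand-in for a Python predict_leaf output with strict_shape=True.
leaf = np.arange(np.prod(py_shape)).reshape(py_shape)

r_like = reinterpret_as_column_major(leaf)
assert r_like.shape == (n_trees_in_forest, n_classes, n_iterations, n_samples)

# Same values, indexed in reverse order: r_like[t, c, i, s] == leaf[s, i, c, t].
assert np.array_equal(r_like, leaf.T)
```

Under that assumption no data is rearranged between the two bindings; only the order in which the indices are written changes.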