Skip to content

Commit

Permalink
fix: task's cbind works with non-standard pk (#1079)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer authored Aug 17, 2024
1 parent b165d9e commit dab0b33
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
* refactor: Optimize runtime of setting row roles.
* refactor: Optimize runtime of marshalling.
* refactor: Optimize runtime of `Task$col_info`
* fix: `Task$cbind()` now works with non-standard primary keys
for `data.frames` (#961).
* fix: Triggering of fallback learner now has log-level "info"
instead of "debug" (#972)


# mlr3 0.20.2

* refactor: Move RhpcBLASctl to suggest.
Expand Down
7 changes: 5 additions & 2 deletions R/Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -562,8 +562,11 @@ Task = R6Class("Task",
return(invisible(self))
}

row_ids = if (pk %in% names(data)) pk else self$row_ids
data = as_data_backend(data, primary_key = row_ids)
row_ids = if (pk %nin% names(data)) {
data[[pk]] = self$row_ids
}

data = as_data_backend(data, primary_key = pk)
} else {
assert_backend(data)
if (data$ncol <= 1L) {
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/test_Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -634,3 +634,11 @@ test_that("divide requires ratio in (0, 1)", {
test_that("divide requires ids to be row_ids", {
expect_error(tsk("iris")$divide(ids = 0.5))
})

test_that("cbind supports non-standard primary key (#961)", {
tbl = data.table(x = runif(10), y = runif(10), myid = 1:10)
b = as_data_backend(tbl, primary_key = "myid")
task = as_task_regr(b, target = "y")
task$cbind(data.table(x1 = 10:1))
expect_true("x1" %in% task$feature_names)
})

0 comments on commit dab0b33

Please sign in to comment.