fix(cv): number of rows must be at least the number of folds by sebffischer · Pull Request #1294 · mlr-org/mlr3 · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

fix(cv): number of rows must be at least the number of folds #1294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
* feat: `benchmark_grid()` will now throw a warning if you mix different predict types in the
design (#1273).
* feat: Converting a `BenchmarkResult` to a `data.table` now includes the `task_id`, `learner_id`, and `resampling_id` columns (#1275).
* fix: Instantiating (repeated) CV on a task with fewer observations than folds
(or, for grouped tasks, fewer groups than folds) now fails with an informative error (#1294).

# mlr3 0.23.0

Expand Down
4 changes: 4 additions & 0 deletions R/Resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ Resampling = R6Class("Resampling",
#' the object in its previous state.
instantiate = function(task) {
task = assert_task(as_task(task))
private$.check(task)
strata = task$strata
groups = task$groups

Expand Down Expand Up @@ -257,6 +258,9 @@ Resampling = R6Class("Resampling",
.id = NULL,
.hash = NULL,
.groups = NULL,
# Validation hook invoked by `instantiate()` right after the task has been
# asserted. This default accepts every task; subclasses may override it to
# signal an error for tasks they cannot be instantiated on.
.check = function(task) TRUE,

.get_set = function(getter, i) {
if (!self$is_instantiated) {
Expand Down
15 changes: 14 additions & 1 deletion R/ResamplingCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,25 @@ ResamplingCV = R6Class("ResamplingCV", inherit = Resampling,

private = list(
# Assign every id in `ids` to one of `folds` folds.
# Fold labels 1..folds are cycled over the ids and then shuffled, so the fold
# sizes differ by at most one. The table is keyed by `fold`, the column that
# `.get_train()` looks ids up by.
.sample = function(ids, ...) {
  pvs = self$param_set$get_values()
  data.table(
    row_id = ids,
    fold = shuffle(seq_along0(ids) %% as.integer(pvs$folds) + 1L),
    key = "fold"
  )
},
# Reject tasks that cannot be split into `folds` non-empty folds:
# a grouped task needs at least `folds` distinct groups, and every task needs
# at least `folds` rows. Invoked from `instantiate()`.
.check = function(task) {
  n_folds = self$param_set$get_values()$folds
  grp = task$groups

  if (!is.null(grp)) {
    n_groups = length(unique(grp$group))
    if (n_folds > n_groups) {
      stopf("Cannot instantiate ResamplingCV with %i folds on a grouped task with %i groups.", n_folds, n_groups)
    }
  }

  if (n_folds > task$nrow) {
    stopf("Cannot instantiate ResamplingCV with %i folds on a task with %i rows.", n_folds, task$nrow)
  }
},

.get_train = function(i) {
self$instance[!list(i), "row_id", on = "fold"][[1L]]
Expand Down
14 changes: 13 additions & 1 deletion R/ResamplingRepeatedCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,25 @@ ResamplingRepeatedCV = R6Class("ResamplingRepeatedCV", inherit = Resampling,

private = list(
# Draw `repeats` CV partitions of `ids`. Each repetition cycles the fold
# labels 1..folds over the ids and shuffles them, yielding one row per
# (id, repetition); the per-repetition tables are stacked via `map_dtr()`.
.sample = function(ids, ...) {
  pv = self$param_set$get_values()
  n = length(ids)
  folds = as.integer(pv$folds)
  map_dtr(seq_len(pv$repeats), function(i) {
    data.table(row_id = ids, rep = i, fold = shuffle(seq_len0(n) %% folds + 1L))
  })
},
# Reject tasks that cannot be split into `folds` non-empty folds in every
# repetition: a grouped task needs at least `folds` distinct groups, and every
# task needs at least `folds` rows. Invoked from `instantiate()`.
.check = function(task) {
  n_folds = self$param_set$get_values()$folds
  grp = task$groups

  if (!is.null(grp)) {
    n_groups = length(unique(grp$group))
    if (n_folds > n_groups) {
      stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a grouped task with %i groups.", n_folds, n_groups)
    }
  }

  if (n_folds > task$nrow) {
    stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a task with %i rows.", n_folds, task$nrow)
  }
},

.get_train = function(i) {
i = as.integer(i) - 1L
Expand Down
14 changes: 14 additions & 0 deletions tests/testthat/test_Resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,17 @@ test_that("task_row_hash in Resampling works correctly", {
resampling$instantiate(task)
expect_identical(resampling$task_row_hash, task$row_hash)
})

test_that("folds must be <= task size", {
  task = tsk("iris")

  # more folds than rows is rejected
  cv = rsmp("cv", folds = 151)
  rep_cv = rsmp("repeated_cv", folds = 151)
  expect_error(cv$instantiate(task), "Cannot instantiate ResamplingCV with 151 folds on a task with 150 rows")
  expect_error(rep_cv$instantiate(task), "Cannot instantiate ResamplingRepeatedCV with 151 folds on a task with 150 rows")

  # exactly as many folds as rows (leave-one-out) is still allowed
  cv$param_set$set_values(folds = 150L)
  rep_cv$param_set$set_values(folds = 150L, repeats = 1L)
  cv$instantiate(task)
  rep_cv$instantiate(task)
  expect_true(cv$is_instantiated)
  expect_true(rep_cv$is_instantiated)

  # grouped tasks: the number of groups bounds the number of folds
  task$col_roles$group = "Species"
  cv$param_set$set_values(folds = 4L)
  rep_cv$param_set$set_values(folds = 4L)
  expect_error(cv$instantiate(task), "on a grouped task with 3 groups")
  expect_error(rep_cv$instantiate(task), "on a grouped task with 3 groups")

  # exactly as many folds as groups is still allowed
  cv$param_set$set_values(folds = 3L)
  rep_cv$param_set$set_values(folds = 3L)
  cv$instantiate(task)
  rep_cv$instantiate(task)
  expect_true(cv$is_instantiated)
  expect_true(rep_cv$is_instantiated)
})
0