diff --git a/NEWS.md b/NEWS.md
index aa54a18e5..3e3978424 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -8,6 +8,8 @@
 
 * feat: `benchmark_grid()` will now throw a warning if you mix different predict types in the design (#1273).
 * feat: Converting a `BenchmarkResult` to a `data.table` now includes the `task_id`, `learner_id`, and `resampling_id` columns (#1275).
+* fix: Instantiating (repeated) CV on a task with fewer observations (or, for
+grouped tasks, fewer groups) than folds now throws an error.
 
 # mlr3 0.23.0
 
diff --git a/R/Resampling.R b/R/Resampling.R
index 6cf54d97d..2986a7b58 100644
--- a/R/Resampling.R
+++ b/R/Resampling.R
@@ -170,6 +170,7 @@ Resampling = R6Class("Resampling",
     #' the object in its previous state.
     instantiate = function(task) {
       task = assert_task(as_task(task))
+      private$.check(task)
       strata = task$strata
       groups = task$groups
 
@@ -257,6 +258,11 @@ Resampling = R6Class("Resampling",
     .id = NULL,
     .hash = NULL,
     .groups = NULL,
+    # Hook for subclasses: validate `task` against the current parameter values.
+    # Called at the start of `$instantiate()`; its return value is ignored, so signal an error to reject the task.
+    .check = function(task) {
+      TRUE
+    },
 
     .get_set = function(getter, i) {
       if (!self$is_instantiated) {
diff --git a/R/ResamplingCV.R b/R/ResamplingCV.R
index 34a0e60fc..9922b492d 100644
--- a/R/ResamplingCV.R
+++ b/R/ResamplingCV.R
@@ -61,12 +61,26 @@ ResamplingCV = R6Class("ResamplingCV", inherit = Resampling,
 
   private = list(
     .sample = function(ids, ...) {
+      pvs = self$param_set$get_values()
       data.table(
         row_id = ids,
-        fold = shuffle(seq_along0(ids) %% as.integer(self$param_set$values$folds) + 1L),
+        fold = shuffle(seq_along0(ids) %% as.integer(pvs$folds) + 1L),
         key = "fold"
       )
     },
+    .check = function(task) {
+      # Reject tasks with fewer groups (if grouped) or rows than folds: `.sample()` would leave some folds empty.
+      pvs = self$param_set$get_values()
+      if (!is.null(task$groups)) {
+        n_groups = length(unique(task$groups$group))
+        if (n_groups < pvs$folds) {
+          stopf("Cannot instantiate ResamplingCV with %i folds on a grouped task with %i groups.", pvs$folds, n_groups)
+        }
+      }
+      if (task$nrow < pvs$folds) {
+        stopf("Cannot instantiate ResamplingCV with %i folds on a task with %i rows.", pvs$folds, task$nrow)
+      }
+    },
 
     .get_train = function(i) {
       self$instance[!list(i), "row_id", on = "fold"][[1L]]
diff --git a/R/ResamplingRepeatedCV.R b/R/ResamplingRepeatedCV.R
index bbf3b492c..d27ee43e4 100644
--- a/R/ResamplingRepeatedCV.R
+++ b/R/ResamplingRepeatedCV.R
@@ -93,13 +93,26 @@ ResamplingRepeatedCV = R6Class("ResamplingRepeatedCV", inherit = Resampling,
 
   private = list(
     .sample = function(ids, ...) {
-      pv = self$param_set$values
+      pv = self$param_set$get_values()
       n = length(ids)
       folds = as.integer(pv$folds)
       map_dtr(seq_len(pv$repeats), function(i) {
         data.table(row_id = ids, rep = i, fold = shuffle(seq_len0(n) %% folds + 1L))
       })
     },
+    .check = function(task) {
+      # Reject tasks with fewer groups (if grouped) or rows than folds: `.sample()` would leave some folds empty.
+      pvs = self$param_set$get_values()
+      if (!is.null(task$groups)) {
+        n_groups = length(unique(task$groups$group))
+        if (n_groups < pvs$folds) {
+          stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a grouped task with %i groups.", pvs$folds, n_groups)
+        }
+      }
+      if (task$nrow < pvs$folds) {
+        stopf("Cannot instantiate ResamplingRepeatedCV with %i folds on a task with %i rows.", pvs$folds, task$nrow)
+      }
+    },
 
     .get_train = function(i) {
       i = as.integer(i) - 1L
diff --git a/tests/testthat/test_Resampling.R b/tests/testthat/test_Resampling.R
index db7f67d7b..38076622f 100644
--- a/tests/testthat/test_Resampling.R
+++ b/tests/testthat/test_Resampling.R
@@ -158,3 +158,21 @@ test_that("task_row_hash in Resampling works correctly", {
   resampling$instantiate(task)
   expect_identical(resampling$task_row_hash, task$row_hash)
 })
+
+test_that("folds must be <= task size", {
+  cv = rsmp("cv", folds = 151)
+  rep_cv = rsmp("repeated_cv", folds = 151)
+  task = tsk("iris")
+  expect_error(cv$instantiate(task), "Cannot instantiate ResamplingCV with 151 folds on a task with 150 rows")
+  expect_error(rep_cv$instantiate(task), "Cannot instantiate ResamplingRepeatedCV with 151 folds on a task with 150 rows")
+
+  task$col_roles$group = "Species"
+  cv$param_set$set_values(folds = 4L)
+  rep_cv$param_set$set_values(folds = 4L)
+  expect_error(cv$instantiate(task), "on a grouped task with 3 groups")
+  expect_error(rep_cv$instantiate(task), "on a grouped task with 3 groups")
+  # boundary: exactly as many groups as folds is still allowed
+  cv$param_set$set_values(folds = 3L)
+  cv$instantiate(task)
+  expect_true(cv$is_instantiated)
+})