8000 New deforest() function for removing trees from a fitted random forest by bgreenwell · Pull Request #571 · imbs-hl/ranger · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

New deforest() function for removing trees from a fitted random forest #571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
# Generated by roxygen2: do not edit by hand

S3method(deforest,ranger)
S3method(importance,ranger)
S3method(predict,ranger)
S3method(predict,ranger.forest)
S3method(predictions,ranger)
S3method(predictions,ranger.prediction)
S3method(print,deforest.ranger)
S3method(print,ranger)
S3method(print,ranger.forest)
S3method(print,ranger.prediction)
S3method(timepoints,ranger)
S3method(timepoints,ranger.prediction)
export(csrf)
export(deforest)
export(getTerminalNodeIDs)
export(holdoutRF)
export(importance)
Expand Down
175 changes: 175 additions & 0 deletions R/deforest.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#' Deforesting a random forest
#'
#' The main purpose of this function is to allow for post-processing of
#' ensembles via L1 regularized regression (i.e., the LASSO), as described in
#' Friedman and Popescu (2003). The basic idea is to use the LASSO to
#' post-process the predictions from the individual base learners in an ensemble
#' (i.e., decision trees) in the hopes of producing a much smaller model without
#' sacrificing much in the way of accuracy, and in some cases, improving it.
#' Friedman and Popescu (2003) describe conditions under which tree-based
#' ensembles, like random forest, can potentially benefit from such
#' post-processing (e.g., using shallower trees trained on much smaller samples
#' of the training data without replacement). However, the computational
#' benefits of such post-processing can only be realized if the base learners
#' "zeroed out" by the LASSO can actually be removed from the original ensemble,
#' hence the purpose of this function. A complete example using
#' \code{\link{ranger}} can be found at
#' \url{https://github.com/imbs-hl/ranger/issues/568}.
#'
#' @param object A fitted random forest (e.g., a \code{\link{ranger}}
#' object).
#'
#' @param which.trees Vector giving the indices of the trees to remove. The
#' default, \code{NULL}, removes no trees.
#'
#' @param warn Logical indicating whether or not to warn users that some of the
#' standard output of a typical \code{\link{ranger}} object is no longer
#' available after deforestation. Default is \code{TRUE}.
#'
#' @param ... Additional (optional) arguments. (Currently ignored.)
#'
#' @return An object of class \code{"deforest.ranger"}; essentially, a
#' \code{\link{ranger}} object with certain components replaced with
#' \code{NA}s (e.g., out-of-bag (OOB) predictions, variable importance scores
#' (if requested), and OOB-based error metrics).
#'
#' @note This function is a generic and can be extended by other packages.
#'
#' @references
#' Friedman, J. and Popescu, B. (2003). Importance sampled learning ensembles,
#' Technical report, Stanford University, Department of Statistics.
#' \url{https://statweb.stanford.edu/~jhf/ftp/isle.pdf}.
#'
#' @rdname deforest
#'
#' @export
#'
#' @author Brandon M. Greenwell
#'
#' @examples
#' ## Example of deforesting a random forest
#' rfo <- ranger(Species ~ ., data = iris, probability = TRUE, num.trees = 100)
#' dfo <- deforest(rfo, which.trees = c(1, 3, 5))
#' dfo # same as `rfo` but with trees 1, 3, and 5 removed
#'
#' ## Sanity check
#' preds.rfo <- predict(rfo, data = iris, predict.all = TRUE)$predictions
#' preds.dfo <- predict(dfo, data = iris, predict.all = TRUE)$predictions
#' identical(preds.rfo[, , -c(1, 3, 5)], y = preds.dfo)
# S3 generic: dispatches on the class of `object` (e.g., to `deforest.ranger()`).
deforest <- function(object, which.trees = NULL, ...) {
UseMethod("deforest")
}


#' @rdname deforest
#'
#' @export
deforest.ranger <- function(object, which.trees = NULL, warn = TRUE, ...) {

  # Trees can only be removed from a forest that was stored with the fit;
  # without this check the assignments below would silently create an empty
  # `forest` component.
  if (is.null(object$forest)) {
    stop("No saved forest in ranger object. Please set write.forest to TRUE ",
         "when calling ranger.", call. = FALSE)
  }

  # Resolve `which.trees` against the trees that are actually present. Indexing
  # `seq_len()` keeps ordinary R subsetting semantics (positive, negative, or
  # logical indices; `NULL` selects nothing), while anything that does not
  # identify an existing tree (out of range, `NA`, character) shows up as `NA`
  # and is rejected instead of being passed on to list sub-assignment.
  num.trees <- length(object$forest$child.nodeIDs)
  drop.trees <- unique(seq_len(num.trees)[which.trees])
  if (anyNA(drop.trees)) {
    stop("`which.trees` must only refer to trees in the forest (indices ",
         "between 1 and ", num.trees, ").", call. = FALSE)
  }

  # Warn users about `predictions` and `prediction.error` components
  if (isTRUE(warn)) {
    warning("Many of the components of a typical \"ranger\" object are ",
            "not available after deforestation and are instead replaced with ",
            "`NA` (e.g., out-of-bag (OOB) predictions, variable importance ",
            "scores (if requested), and OOB-based error metrics).",
            call. = FALSE)
  }

  # "Remove trees" by dropping the matching entries from every per-tree list
  # in `forest`. `terminal.class.counts` is only present for probability
  # forests and `chf` only for survival forests, hence the `NULL` guard.
  # NOTE(review): other per-tree components that may exist on the object
  # (e.g., in-bag counts kept at fit time) are not trimmed here -- confirm
  # whether they need the same treatment.
  per.tree <- c("child.nodeIDs", "split.values", "split.varIDs",
                "terminal.class.counts", "chf")
  for (component in per.tree) {
    if (!is.null(object$forest[[component]])) {
      object$forest[[component]][drop.trees] <- NULL
    }
  }

  # Update `num.trees` components so `predict.ranger()` works
  object$forest$num.trees <- object$num.trees <-
    length(object$forest$child.nodeIDs)

  # Coerce other components to `NA` as needed
  if (!is.null(object$prediction.error)) {
    object$prediction.error <- NA
  }
  if (!is.null(object$predictions)) {  # classification and regression
    object$predictions[] <- NA
  }
  if (!is.null(object$r.squared)) {  # regression
    object$r.squared <- NA
  }
  if (!is.null(object$chf)) {  # survival forests
    object$chf[] <- NA
  }
  if (!is.null(object$survival)) {  # survival forests
    object$survival[] <- NA
  }
  # `isTRUE()` makes this safe when `importance.mode` is missing or already
  # `NA` (i.e., when `object` has been deforested before); a bare
  # `if (NA != "none")` would raise an error.
  if (isTRUE(object$importance.mode != "none")) {  # variable importance
    object$importance.mode <- NA
    object$variable.importance[] <- NA
  }

  # Return "deforested" forest; tag the class only once so that repeated
  # deforestation does not stack duplicate class entries.
  if (!inherits(object, "deforest.ranger")) {
    class(object) <- c("deforest.ranger", class(object))
  }
  object

}


#' Print deforested ranger summary
#'
#' Print basic information about a deforested \code{\link{ranger}} object.
#'
#' @param x A \code{\link{deforest}} object (i.e., an object that inherits from
#' class \code{"deforest.ranger"}).
#'
#' @param ... Further arguments passed to or from other methods.
#'
#' @return The input \code{x}, invisibly.
#'
#' @note Many of the components of a typical \code{\link{ranger}} object are not
#' available after deforestation and are instead replaced with \code{NA} (e.g.,
#' out-of-bag (OOB) predictions, variable importance scores (if requested), and
#' OOB-based error metrics).
#'
#' @seealso \code{\link{deforest}}.
#'
#' @author Brandon M. Greenwell
#'
#' @export
print.deforest.ranger <- function(x, ...) {
  cat("Ranger (deforested) result\n\n")
  cat("Note that many of the components of a typical \"ranger\" object are",
      "not available after deforestation and are instead replaced with `NA`",
      "(e.g., out-of-bag (OOB) predictions, variable importance scores (if",
      "requested), and OOB-based error metrics)",
      "\n\n")

  # General description of the fit
  cat("Type: ", x$treetype, "\n")
  cat("Number of trees: ", x$num.trees, "\n")
  cat("Sample size: ", x$num.samples, "\n")
  cat("Number of independent variables: ", x$num.independent.variables, "\n")
  cat("Mtry: ", x$mtry, "\n")
  cat("Target node size: ", x$min.node.size, "\n")
  cat("Variable importance mode: ", x$importance.mode, "\n")
  cat("Splitrule: ", x$splitrule, "\n")

  # Lines that only apply to particular forest types / split rules
  if (x$treetype == "Survival") {
    cat("Number of unique death times: ", length(x$unique.death.times), "\n")
  }
  if (!is.null(x$splitrule) && x$splitrule == "extratrees" &&
      !is.null(x$num.random.splits)) {
    cat("Number of random splits: ", x$num.random.splits, "\n")
  }

  # The label of the OOB error depends on the tree type; classification and
  # any unrecognized type share the generic label (the unnamed fallback).
  error.label <- switch(
    x$treetype,
    "Regression" = "OOB prediction error (MSE): ",
    "Survival" = "OOB prediction error (1-C): ",
    "Probability estimation" = "OOB prediction error (Brier s.): ",
    "OOB prediction error: "
  )
  cat(error.label, x$prediction.error, "\n")

  if (x$treetype == "Regression") {
    cat("R squared (OOB): ", x$r.squared, "\n")
  }

  # Print methods conventionally hand back their argument invisibly
  invisible(x)
}
68 changes: 68 additions & 0 deletions man/deforest.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 29 additions & 0 deletions man/print.deforest.ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 43 additions & 0 deletions tests/testthat/test_deforest.R
50C7
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
library(ranger)
library(survival)
context("ranger_deforest")


test_that("deforest works as expected for probability estimation", {
  # Indices of the trees dropped from the 10-tree probability forest
  removed <- c(1, 3, 5)
  forest.full <- ranger(Species ~ ., data = iris, num.trees = 10,
                        probability = TRUE)
  forest.small <- deforest(forest.full, which.trees = removed, warn = FALSE)

  # Per-tree predictions; the last array dimension is subset by tree index
  per.tree.full <- predict(forest.full, data = iris,
                           predict.all = TRUE)$predictions
  per.tree.small <- predict(forest.small, data = iris,
                            predict.all = TRUE)$predictions

  # The remaining trees must predict exactly as they did before
  expect_identical(per.tree.full[, , -removed], per.tree.small)
})

test_that("deforest works as expected for classification", {
  # Indices of the trees dropped from the 10-tree classification forest
  removed <- c(1, 3, 5)
  forest.full <- ranger(Species ~ ., data = iris, num.trees = 10)
  forest.small <- deforest(forest.full, which.trees = removed, warn = FALSE)

  # Per-tree predictions; columns are subset by tree index
  per.tree.full <- predict(forest.full, data = iris,
                           predict.all = TRUE)$predictions
  per.tree.small <- predict(forest.small, data = iris,
                            predict.all = TRUE)$predictions

  # The remaining trees must predict exactly as they did before
  expect_identical(per.tree.full[, -removed], per.tree.small)
})

test_that("deforest works as expected for regression", {
  # Small simulated data set: noisy sine curve on [0, 2 * pi]
  n <- 50
  x <- runif(n, min = 0, max = 2*pi)
  sim <- data.frame(x = x, y = sin(x) + rnorm(n, sd = 0.1))

  # Indices of the trees dropped from the 10-tree regression forest
  removed <- c(1, 3, 5)
  forest.full <- ranger(y ~ ., data = sim, num.trees = 10)
  forest.small <- deforest(forest.full, which.trees = removed, warn = FALSE)

  # Per-tree predictions; columns are subset by tree index
  per.tree.full <- predict(forest.full, data = sim,
                           predict.all = TRUE)$predictions
  per.tree.small <- predict(forest.small, data = sim,
                            predict.all = TRUE)$predictions

  # The remaining trees must predict exactly as they did before
  expect_identical(per.tree.full[, -removed], per.tree.small)
})

test_that("deforest works as expected for censored outcomes", {
  # Simulated right-censored data with a single binary covariate
  sim <- data.frame(time = runif(100, 1, 10), status = rbinom(100, 1, .5),
                    x = rbinom(100, 1, .5))

  # Indices of the trees dropped from the 10-tree survival forest
  removed <- c(1, 3, 5)
  forest.full <- ranger(Surv(time, status) ~ x, data = sim, num.trees = 10,
                        splitrule = "logrank")
  forest.small <- deforest(forest.full, which.trees = removed, warn = FALSE)

  # Keep the whole prediction objects: both `chf` and `survival` are compared
  out.full <- predict(forest.full, data = sim, predict.all = TRUE)
  out.small <- predict(forest.small, data = sim, predict.all = TRUE)

  # The remaining trees must predict exactly as they did before
  expect_identical(out.full$chf[, , -removed], out.small$chf)
  expect_identical(out.full$survival[, , -removed], out.small$survival)
})
0