Random Forest
Mathias Riechel, Michael Rustler
Source: vignettes/prediction_random-forest.Rmd
prediction_random-forest.Rmd
Install R Package
# Register the KWB-R universe next to CRAN as package repositories
options(
  repos = c(
    kwbr = "https://kwb-r.r-universe.dev",
    CRAN = "https://cloud.r-project.org"
  )
)

# Install the R package "dwc.wells" including its dependencies
install.packages(pkgs = "dwc.wells", dependencies = TRUE)
Input Dataset
# Attach the package and use its `model_data_reduced` object as model input
library(dwc.wells)
df <- dwc.wells::model_data_reduced

# Print the structure (columns and their types) of the input data
utils::str(df)
#> 'data.frame': 6308 obs. of 27 variables:
#> $ Qs_rel : num 100 97.8 75.1 78.5 52.2 ...
#> $ well_id : int 872166 872166 872166 872166 872166 872166 872166 872166 872166 872166 ...
#> $ well_age_years : num 0 23.2 29.3 29.5 35 ...
#> $ construction_year : num 1970 1970 1970 1970 1970 1970 1970 1970 1970 1970 ...
#> $ screen_material : Factor w/ 6 levels "3fe9575b","5bad81ff",..: 2 2 2 2 2 2 2 2 2 2 ...
#> $ diameter : num 300 300 300 300 300 300 300 300 300 300 ...
#> $ drilling_method : Factor w/ 6 levels "418861c3","484a778f",..: 1 1 1 1 1 1 1 1 1 1 ...
#> $ admissible_discharge : num 176 176 176 176 176 176 176 176 176 176 ...
#> $ operational_start.Qs : num 39.1 39.1 39.1 39.1 39.1 ...
#> $ aquifer_coverage : Factor w/ 5 levels "confined","edges confined",..: 1 1 1 1 1 1 1 1 1 1 ...
#> $ W_static.sd : num 0.766 0.766 0.766 0.766 0.766 ...
#> $ surface_water.distance: Factor w/ 8 levels "0-25","25-50",..: 1 1 1 1 1 1 1 1 1 1 ...
#> $ n_rehab : int 0 1 1 2 2 3 3 4 4 4 ...
#> $ time_since_rehab_years: num 0 0.0411 6.193 0.0794 5.5524 ...
#> $ volume_m3_d.mean : num 1203 1203 1203 1203 1203 ...
#> $ volume_m3_d.cv : num 1.29 1.29 1.29 1.29 1.29 ...
#> $ quality.EC : num 824 824 824 824 824 ...
#> $ quality.DO : num 0.235 0.235 0.235 0.235 0.235 0.235 0.235 0.235 0.235 0.235 ...
#> $ quality.Temp : num 11.7 11.7 11.7 11.7 11.7 11.7 11.7 11.7 11.7 11.7 ...
#> $ quality.pH : num 7.4 7.4 7.4 7.4 7.4 7.4 7.4 7.4 7.4 7.4 ...
#> $ quality.Redox : num 86 86 86 86 86 86 86 86 86 86 ...
#> $ quality.Fe_tot : num 1.9 1.9 1.9 1.9 1.9 1.9 1.9 1.9 1.9 1.9 ...
#> $ quality.Mn : num 0.24 0.24 0.24 0.24 0.24 0.24 0.24 0.24 0.24 0.24 ...
#> $ quality.NO3 : num 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 ...
#> $ quality.PO4 : num 0.966 0.966 0.966 0.966 0.966 ...
#> $ quality.SO4 : num 94 94 94 94 94 94 94 94 94 94 ...
#> $ quality.TSS : num 4.1 4.1 4.1 4.1 4.1 4.1 4.1 4.1 4.1 4.1 ...
Resampling
# Resampling mode used to create training and test data:
# - "random":  stratified random split of the individual observations
# - "by_well": well ids are split, so all rows of a well share one partition
resampling <- "random"
#resampling <- "by_well"

# Fail early on an unknown mode; otherwise neither branch below would run and
# df_training / df_test would silently stay undefined
stopifnot(resampling %in% c("random", "by_well"))

set.seed(1)

if (resampling == "random") {
  # for regression: 80 % training data, stratified by the response Qs_rel;
  # the well id is dropped so that it is not used as a predictor
  data_split <- rsample::initial_split(
    df %>% dplyr::select(-well_id),
    prop = 0.8,
    strata = Qs_rel
  )
  df_training <- data_split %>% rsample::training()
  df_test <- data_split %>% rsample::testing()
}

# version 2: splitting per well ids
if (resampling == "by_well") {
  well_ids <- unique(df$well_id)

  # the sample size has to be a whole number of ids; floor() makes the
  # rounding explicit instead of relying on implicit coercion
  train_ids <- sample(well_ids, floor(0.8 * length(well_ids)))
  test_ids <- setdiff(well_ids, train_ids)

  df_training <- df %>%
    dplyr::filter(well_id %in% train_ids) %>%
    dplyr::select(-well_id)

  df_test <- df %>%
    dplyr::filter(well_id %in% test_ids) %>%
    dplyr::select(-well_id)

  # Build the split object from both partitions so that code relying on
  # `data_split` (e.g. tune::last_fit()) also works in this resampling mode
  data_split <- rsample::make_splits(df_training, assessment = df_test)
}

tibble::as_tibble(df_training)
#> # A tibble: 5,044 × 26
#> Qs_rel well_age_years const…¹ scree…² diame…³ drill…⁴ admis…⁵ opera…⁶ aquif…⁷
#> <dbl> <dbl> <dbl> <fct> <dbl> <fct> <dbl> <dbl> <fct>
#> 1 31.5 42.3 1970 5bad81… 300 418861… 176 39.1 confin…
#> 2 20.7 44.2 1970 5bad81… 300 418861… 176 39.1 confin…
#> 3 16.3 47.9 1970 5bad81… 300 418861… 176 39.1 confin…
#> 4 22.4 47.9 1970 5bad81… 300 418861… 176 39.1 confin…
#> 5 25.6 25.8 1993 93242c… 393 418861… 93.5 44.7 unconf…
#> 6 19.0 16.2 1993 93242c… 393 418861… 42 24.8 confin…
#> 7 28.9 16.3 1993 93242c… 393 418861… 42 24.8 confin…
#> 8 37.5 15.9 1993 93242c… 393 418861… 40 21.3 confin…
#> 9 35.1 24.4 1993 93242c… 393 418861… 40 21.3 confin…
#> 10 30.4 20.7 1993 93242c… 393 418861… 119 28.5 confin…
#> # … with 5,034 more rows, 17 more variables: W_static.sd <dbl>,
#> # surface_water.distance <fct>, n_rehab <int>, time_since_rehab_years <dbl>,
#> # volume_m3_d.mean <dbl>, volume_m3_d.cv <dbl>, quality.EC <dbl>,
#> # quality.DO <dbl>, quality.Temp <dbl>, quality.pH <dbl>,
#> # quality.Redox <dbl>, quality.Fe_tot <dbl>, quality.Mn <dbl>,
#> # quality.NO3 <dbl>, quality.PO4 <dbl>, quality.SO4 <dbl>, quality.TSS <dbl>,
#> # and abbreviated variable names ¹construction_year, ²screen_material, …
# Show the test data as tibble (compact print of the first rows)
df_test %>% tibble::as_tibble()
#> # A tibble: 1,264 × 26
#> Qs_rel well_age_years const…¹ scree…² diame…³ drill…⁴ admis…⁵ opera…⁶ aquif…⁷
#> <dbl> <dbl> <dbl> <fct> <dbl> <fct> <dbl> <dbl> <fct>
#> 1 97.8 23.2 1970 5bad81… 300 418861… 176 39.1 confin…
#> 2 78.5 29.5 1970 5bad81… 300 418861… 176 39.1 confin…
#> 3 52.2 35.0 1970 5bad81… 300 418861… 176 39.1 confin…
#> 4 27.2 42.9 1970 5bad81… 300 418861… 176 39.1 confin…
#> 5 35.1 20.9 1993 93242c… 393 418861… 93.5 44.7 unconf…
#> 6 82.7 6.07 1993 93242c… 393 418861… 42 24.8 confin…
#> 7 26.8 20.9 1993 93242c… 393 418861… 40 21.3 confin…
#> 8 16.9 24.4 1993 93242c… 393 418861… 40 21.3 confin…
#> 9 56.1 16.5 1993 93242c… 393 418861… 119 28.5 confin…
#> 10 68.5 6.11 1993 93242c… 393 418861… 75 33.5 edges …
#> # … with 1,254 more rows, 17 more variables: W_static.sd <dbl>,
#> # surface_water.distance <fct>, n_rehab <int>, time_since_rehab_years <dbl>,
#> # volume_m3_d.mean <dbl>, volume_m3_d.cv <dbl>, quality.EC <dbl>,
#> # quality.DO <dbl>, quality.Temp <dbl>, quality.pH <dbl>,
#> # quality.Redox <dbl>, quality.Fe_tot <dbl>, quality.Mn <dbl>,
#> # quality.NO3 <dbl>, quality.PO4 <dbl>, quality.SO4 <dbl>, quality.TSS <dbl>,
#> # and abbreviated variable names ¹construction_year, ²screen_material, …
Regression
Hyperparameter Tuning
# Specify model ----------------------------------------------------------------

# Random forest with 500 trees; mtry and min_n are marked for tuning
rf_reg_tune_model <- parsnip::rand_forest(
  trees = 500,
  mtry = tune::tune(),
  min_n = tune::tune()
) %>%
  parsnip::set_engine(engine = "randomForest") %>%
  parsnip::set_mode(mode = "regression")

# Recipe: Qs_rel is the outcome, all other columns of the training data are
# predictors
rec <- recipes::recipe(Qs_rel ~ ., data = df_training)

# Workflow bundling recipe and model specification
rf_reg_tune_wflow <- workflows::workflow() %>%
  workflows::add_recipe(rec) %>%
  workflows::add_model(rf_reg_tune_model)

# 5-fold cross validation on the training data
cv_folds <- rsample::vfold_cv(df_training, v = 5)

# Hyperparameter grid (alternative grid definitions kept for reference)
#rf_reg_grid <- grid_random(parameters(rf_reg_tune_model), size = 100)
#rf_reg_grid <- grid_regular(parameters(rf_reg_tune_model), c(5,5))
rf_reg_grid <- dials::grid_regular(
  dials::mtry(range = c(3, 15)),
  dials::min_n(range = c(5, 15)),
  levels = 13
)

# Register a parallel backend, then evaluate every grid candidate on the
# cross validation folds (seed fixed for reproducibility)
doParallel::registerDoParallel()
set.seed(345)

rf_reg_tuning <- tune::tune_grid(
  rf_reg_tune_wflow,
  resamples = cv_folds,
  grid = rf_reg_grid
)
# Collect the resampled performance metrics of all grid candidates and save
# them; the file name carries the resampling mode
metrics <- tune::collect_metrics(rf_reg_tuning)

dwc.wells::save_data(
  metrics,
  path = getwd(),
  filename = sprintf(
    "rf-regression_tuning-grid_resampling-%s_metric",
    resampling
  )
)

# Plot the mean RMSE against each tuned hyperparameter, one panel per parameter
metrics %>%
  dplyr::filter(.metric == "rmse") %>%
  dplyr::select(mean, min_n, mtry) %>%
  tidyr::pivot_longer(
    cols = min_n:mtry,
    names_to = "parameter",
    values_to = "value"
  ) %>%
  ggplot2::ggplot(
    mapping = ggplot2::aes(x = value, y = mean, color = parameter)
  ) +
  ggplot2::geom_point(show.legend = FALSE) +
  ggplot2::scale_x_continuous(breaks = seq.int(1, 15, 2)) +
  ggplot2::facet_wrap(~parameter, scales = "free_x") +
  ggplot2::labs(x = NULL, y = "RMSE [%]") +
  sema.berlin.utils::my_theme()

# Save the plot created above
ggplot2::ggsave(
  filename = "rf_reg_regression_hyperparameter_tuning_plot_regular_random_resampling.png",
  width = 6,
  height = 3,
  dpi = 600
)
# raster heatmap plot: mean RMSE for each combination of min_n and mtry
metrics %>%
  dplyr::filter(.metric == "rmse") %>%
  dplyr::select(mean, min_n, mtry) %>%
  # aes() must be namespaced: ggplot2 is never attached in this script, so a
  # bare aes() fails with "could not find function"
  ggplot2::ggplot(ggplot2::aes(x = min_n, y = mtry, fill = mean)) +
  ggplot2::geom_raster() +
  ggplot2::scale_x_continuous(breaks = seq.int(1, 15, 2)) +
  ggplot2::scale_y_continuous(breaks = seq.int(1, 15, 2)) +
  ggplot2::labs(fill = "RMSE [%]") +
  sema.berlin.utils::my_theme()

# Save the heatmap created above
ggplot2::ggsave(
  "rf_reg_regression_hyperparameter_tuning_plot_regular_random_resampling_heatmap.png",
  width = 5,
  height = 3,
  dpi = 600
)
# determine best model: hyperparameter combination with the best RMSE.
# `metric` is passed by name because current tune releases no longer accept
# it as a positional argument
best_rmse <- tune::select_best(rf_reg_tuning, metric = "rmse")

# `formats` is named explicitly, matching the other save_data() calls in this
# script, instead of relying on argument position
dwc.wells::save_data(
  best_rmse,
  path = getwd(),
  filename = "rf_reg_regression_best_model_regular",
  formats = "RData"
)

# Insert the selected hyperparameters into the model specification
final_rf_reg <- tune::finalize_model(rf_reg_tune_model, best_rmse)

# update workflow
rf_reg_final_wflow <- workflows::workflow() %>%
  workflows::add_recipe(rec) %>%
  workflows::add_model(final_rf_reg)

# train and test model / workflow
# NOTE: `data_split` must exist in the session at this point
rf_reg_final_fit <- rf_reg_final_wflow %>% tune::last_fit(data_split)

# get metrics
rf_reg_final_fit %>% tune::collect_metrics()

# get predictions
df_pred <- rf_reg_final_fit %>% tune::collect_predictions()

# Evaluate model performance ---
# scatterplot() is the helper from dwc.wells (ggplot2 exports no function of
# that name)
dwc.wells::scatterplot(df_pred)

ggplot2::ggsave(
  "random_forest_regression_tuned_regular.png",
  width = 3.5,
  height = 3
)
Best-Fit Model
# Random forest regression model with fixed hyperparameters
rf_reg_model <- parsnip::rand_forest(trees = 500, mtry = 6, min_n = 10) %>%
  parsnip::set_engine(engine = "randomForest") %>%
  parsnip::set_mode(mode = "regression")

# Model training and assessment (regression) -----------------------------------

# Fit the model on the training data (seed fixed for reproducibility)
set.seed(26)
rf_reg_fit <- parsnip::fit(rf_reg_model, Qs_rel ~ ., data = df_training)

# Predict the test data
predictions <- predict(rf_reg_fit, new_data = df_test)

# Combine observed and predicted values, then compute RMSE and R squared
df_pred <- dplyr::bind_cols(dplyr::select(df_test, Qs_rel), predictions)

yardstick::rmse(df_pred, truth = Qs_rel, estimate = .pred)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 rmse standard 16.3
yardstick::rsq(df_pred, truth = Qs_rel, estimate = .pred)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 rsq standard 0.747

# Scatter plot of the prediction data
dwc.wells::scatterplot(
  df_pred,
  pointsize = 0.9,
  alpha = 1,
  lines_80perc = FALSE
)
#> Warning: Removed 142 rows containing missing values (geom_point).

# Save the scatter plot created above
ggplot2::ggsave(
  filename = "scatterplot_rf-regression_numeric.png",
  width = 3.5,
  height = 3,
  dpi = 600
)
#> Warning: Removed 142 rows containing missing values (geom_point).
Classification
# classification performance ---------------------------------------------------

# Classify observed and predicted Qs values with the same helper
df_pred <- dplyr::mutate(
  df_pred,
  Qs_rel_class = dwc.wells::classify_Qs(Qs_rel),
  .pred_class = dwc.wells::classify_Qs(.pred)
)

# Confusion matrix of predicted versus observed classes
# (note: the object name `matrix` masks base::matrix when used as a variable)
matrix <- yardstick::conf_mat(
  data = df_pred,
  truth = Qs_rel_class,
  estimate = .pred_class
)
matrix
#> Truth
#> Prediction low high
#> low 743 139
#> high 27 355

# Performance metrics derived from the confusion matrix
metrics <- summary(matrix)
metrics
#> # A tibble: 13 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy binary 0.869
#> 2 kap binary 0.713
#> 3 sens binary 0.965
#> 4 spec binary 0.719
#> 5 ppv binary 0.842
#> 6 npv binary 0.929
#> 7 mcc binary 0.726
#> 8 j_index binary 0.684
#> 9 bal_accuracy binary 0.842
#> 10 detection_prevalence binary 0.698
#> 11 precision binary 0.842
#> 12 recall binary 0.965
#> 13 f_meas binary 0.900

# Save the confusion matrix in the working directory
dwc.wells::save_data(
  matrix,
  path = getwd(),
  filename = "rf_numeric_to_class_matrix_split80",
  formats = "RData"
)