【R】tidymodels の Tuning
1. はじめに
機械学習でどのようにパラメータのチューニングをすればよいかのお勉強。TidymodelsのページとJulia Silge氏のYouTube「Tuning random forest hyperparameters with tidymodels」と「Hyperparameter tuning using tidymodels」を参考にしました。
Swiss Bank Note のパラメータ=6か所の長さの測定結果から、ニセ札かどうかを分類する機械学習を行います。
2. データ
使用するデータは、Wolfram Data RepositoryからSwiss Bank Noteというデータを用います。これは、1988年に、B. Flury and H. Riedwylが発表したデータで、100のニセ札と100の本物のお札について、お札の6か所の長さを測ったものです。
上記のWolfram Data RepositoryのデータをCSV形式でダウンロードしましたが、そのままでは使いにくいので、少し修正しました。
# libraries --------------------------------------------------------------- library(tidyverse) library(tidymodels) # data -------------------------------------------------------------------- banknote_df <- read.csv("http://www.dinov.tokyo/Data/JP_Pref/Sample-Data-Swiss-Bank-Notes.csv") head(banknote_df)
Length Left Right Bottom Top Diagonal Genuine.Counterfeit
1 214.8 131.0 131.1 9.0 9.7 141.0 counterfeit
2 214.6 129.7 129.7 8.1 9.5 141.7 counterfeit
3 214.8 129.7 129.7 8.7 9.6 142.2 counterfeit
4 214.8 129.7 129.6 7.5 10.4 142.0 counterfeit
5 215.0 129.6 129.7 10.4 7.7 141.8 counterfeit
6 215.7 130.8 130.5 9.0 10.1 141.4 counterfeit
3. 機械学習
3.1 データ準備
トレーニングデータとテストデータを準備します。元のデータを3:1に分割します。
# 1: Initial Split -------------------------------------------------------- set.seed(1234) banknote_split <- initial_split(banknote_df, strada = Genuine.Counterfeit) banknote_train <- training(banknote_split) head(banknote_train) banknote_test <- testing(banknote_split) head(banknote_test)
> banknote_train <- training(banknote_split)
> head(banknote_train)
Length Left Right Bottom Top Diagonal Genuine.Counterfeit
1 214.8 131.0 131.1 9.0 9.7 141.0 counterfeit
2 214.6 129.7 129.7 8.1 9.5 141.7 counterfeit
3 214.8 129.7 129.7 8.7 9.6 142.2 counterfeit
4 214.8 129.7 129.6 7.5 10.4 142.0 counterfeit
5 215.0 129.6 129.7 10.4 7.7 141.8 counterfeit
6 215.7 130.8 130.5 9.0 10.1 141.4 counterfeit
> banknote_test <- testing(banknote_split)
> head(banknote_test)
Length Left Right Bottom Top Diagonal Genuine.Counterfeit
8 214.5 129.6 129.2 7.2 10.7 141.7 counterfeit
10 215.2 130.4 130.3 9.2 10.0 140.7 counterfeit
17 214.6 129.9 130.1 8.2 9.8 141.7 counterfeit
19 215.2 129.6 129.6 7.4 11.5 141.5 counterfeit
20 214.7 130.2 129.9 8.6 10.0 141.9 counterfeit
22 215.6 130.5 130.0 8.1 10.3 141.6 counterfeit
3.2 レシピ
目的変数を”Genuine.Counterfeit”として、レシピを準備します。
# 2: Preprocessing -------------------------------------------------------- banknote_rec <- recipe(Genuine.Counterfeit ~ ., data = banknote_train) banknote_prep <- prep(banknote_rec) juiced <- juice(banknote_prep)
> juiced
# A tibble: 150 x 7
Length Left Right Bottom Top Diagonal Genuine.Counterfeit
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>
1 215. 131 131. 9 9.7 141 counterfeit
2 215. 130. 130. 8.1 9.5 142. counterfeit
3 215. 130. 130. 8.7 9.6 142. counterfeit
4 215 130. 130. 10.4 7.7 142. counterfeit
5 216. 131. 130. 9 10.1 141. counterfeit
6 216. 130. 130. 7.9 9.6 142. counterfeit
7 214. 130. 129. 7.2 10.7 142. counterfeit
8 215. 129. 130. 8.2 11 142. counterfeit
9 215. 130. 130. 9.2 10 141. counterfeit
10 215. 130. 130. 7.9 11.7 142. counterfeit
# ... with 140 more rows
3.3 モデルとチューニングの仕様
モデルとチューニングの仕様を決めます。ランダムフォレストを使用し、modelは分類、エンジンはrangerとします。mtryとmin_nをチューニングします。
# 3, 4: Model Specification & Hyperparameter Tuning Specification -------------------------------------------------- banknote_spec <- rand_forest( mtry = tune(), trees = 1000, min_n = tune() ) %>% set_mode("classification") %>% set_engine("ranger")
> banknote_spec
Random Forest Model Specification (classification)
Main Arguments:
mtry = tune()
trees = 1000
min_n = tune()
Computational engine: ranger
3.4 ワークフロー
これらをまとめて、ワークフローに定義します。
# 5: Bundle into Workflow ------------------------------------------------- tune_wf <- workflow() %>% add_recipe(banknote_rec) %>% add_model(banknote_spec)
> tune_wf
== Workflow ====================================================================
Preprocessor: Recipe
Model: rand_forest()
-- Preprocessor ----------------------------------------------------------------
0 Recipe Steps
-- Model -----------------------------------------------------------------------
Random Forest Model Specification (classification)
Main Arguments:
mtry = tune()
trees = 1000
min_n = tune()
Computational engine: ranger
3.5 クロスバリデーション準備
# 6: Cross Validation ----------------------------------------------------- set.seed(234) banknote_fold <- vfold_cv(banknote_train)
> banknote_fold
# 10-fold cross-validation
# A tibble: 10 x 2
splits id
<named list> <chr>
1 <split [135/15]> Fold01
2 <split [135/15]> Fold02
3 <split [135/15]> Fold03
4 <split [135/15]> Fold04
5 <split [135/15]> Fold05
6 <split [135/15]> Fold06
7 <split [135/15]> Fold07
8 <split [135/15]> Fold08
9 <split [135/15]> Fold09
10 <split [135/15]> Fold10
3.6 チューニング
20個のグリッドに分け、チューニングしてみます。
# 7: Tune ----------------------------------------------------------------- doParallel::registerDoParallel() set.seed(345) tune_res <- tune_grid( tune_wf, resamples = banknote_fold, grid = 20 )
結果を見ますと。
> tune_res %>%
+ collect_metrics()
# A tibble: 40 x 7
mtry min_n .metric .estimator mean n std_err
<int> <int> <chr> <chr> <dbl> <int> <dbl>
1 1 6 accuracy binary 0.993 10 0.00667
2 1 6 roc_auc binary 1 10 0
3 1 39 accuracy binary 0.993 10 0.00667
4 1 39 roc_auc binary 0.998 10 0.002
5 2 15 accuracy binary 0.993 10 0.00667
6 2 15 roc_auc binary 1 10 0
7 2 17 accuracy binary 0.993 10 0.00667
8 2 17 roc_auc binary 1 10 0
9 2 30 accuracy binary 0.993 10 0.00667
10 2 30 roc_auc binary 1 10 0
# ... with 30 more rows
この時点で、これ以上のチューニングは不要では。。。?
tune_res %>% collect_metrics() %>% filter(.metric == "roc_auc") %>% select(mean, min_n, mtry) %>% pivot_longer(min_n:mtry, values_to = "value", names_to = "parameter") %>% ggplot(aes(value, mean, color=parameter)) + geom_point(show.legend = FALSE) + facet_wrap(~ parameter, scales = "free_x")

練習のため、少しパラメータをいじって、再度実施してみます。
rf_grid <- grid_regular( mtry(range = c(2, 5)), min_n( range = c(10, 20)), levels = 5 ) set.seed(456) regular_res <- tune_grid( tune_wf, resamples = banknote_fold, grid = rf_grid )
> regular_res %>%
+ collect_metrics()
# A tibble: 40 x 7
mtry min_n .metric .estimator mean n std_err
<int> <int> <chr> <chr> <dbl> <int> <dbl>
1 2 10 accuracy binary 0.987 10 0.00889
2 2 10 roc_auc binary 1 10 0
3 2 12 accuracy binary 0.993 10 0.00667
4 2 12 roc_auc binary 1 10 0
5 2 15 accuracy binary 0.993 10 0.00667
6 2 15 roc_auc binary 1 10 0
7 2 17 accuracy binary 0.993 10 0.00667
8 2 17 roc_auc binary 1 10 0
9 2 20 accuracy binary 0.993 10 0.00667
10 2 20 roc_auc binary 1 10 0
# ... with 30 more rows
結果は、良くはなりましたが。。。
3.7 ファイナライズ
このモデルで良さそうなので、モデルのファイナライズをします。
# 9: Finalize Workflow ---------------------------------------------------- best_auc <- select_best(regular_res, "roc_auc") final_rf <- finalize_model( banknote_spec, best_auc ) # 10: Final Fit ----------------------------------------------------------- library(vip) final_rf %>% set_engine("ranger", importance = "permutation") %>% fit(Genuine.Counterfeit ~ ., data = juice(banknote_prep)) %>% vip(geom = "point") final_wf <- workflow() %>% add_recipe(banknote_rec) %>% add_model(final_rf) final_res <- final_wf %>% last_fit(banknote_split)
3.8 評価
最終的な評価をします。
# 11: Evaluate ------------------------------------------------------------ final_res %>% collect_metrics() final_res %>% collect_predictions() %>% conf_mat(truth = Genuine.Counterfeit, estimate = .pred_class)
> final_res %>%
+ collect_metrics()
# A tibble: 2 x 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 accuracy binary 1
2 roc_auc binary 1
> final_res %>%
+ collect_predictions() %>%
+ conf_mat(truth = Genuine.Counterfeit, estimate = .pred_class)
Truth
Prediction counterfeit genuine
counterfeit 25 0
genuine 0 25
すごいですね。認識率100%!
4.さいごに
今回は、あまりチューニングの意味がなかったとおもいますが、勉強にはなりました。