【R】Tidymodelsで機械学習 モデル・アルゴリズム

1.はじめに

機械学習の分類でモデルやアルゴリズムでどれだけ結果が異なるかを見てみます。

2.データ

使用するデータは、kernlabパッケージに含まれるspamです。このデータは、Hewlett-Packard Labsで集められた4601の電子メールがスパムかそうでないかを記録したもので、57の変数とspam/nonspamのデータからなっています。変数には、電子メールの内容に数字や文字がどれだけあるかや、特定の文字がどの頻度であるかなどの情報があります。

次のように、準備します。

library(kernlab)
data(spam)

head(spam)

3.事前準備

機械学習を行う準備です。Introduction to Machine Learning with the Tidyverseを参考にしました。

library(tidyverse)
library(tidymodels)
library(tune)
library(parsnip)

fit_split <- function(formula, model, split, ...) {
  wf <- workflows::add_model(workflows::add_formula(workflows::workflow(), formula, blueprint = hardhat::default_formula_blueprint(indicators = FALSE, allow_novel_levels = TRUE)), model)
  tune::last_fit(wf, split, ...)
}

set.seed(100) 
spam_split  <- initial_split(spam)

4.やってみる

実際に、モデルやアルゴリズムを変更して、結果を見てみます。

4.1 決定木-rpart

モデルとして決定木を使い、アルゴリズムをrpartで機械学習を実行してみます。

spam_spec_rpart <- 
  decision_tree() %>%         
  set_engine("rpart") %>%      
  set_mode("classification") 

dt_fit_rpart <- fit_split(type ~ .,
                    model = spam_spec_rpart,
                    split = spam_split) 

dt_fit_rpart %>%   
  collect_predictions() %>% 
  conf_mat(truth = type, estimate = .pred_class)
          Truth
Prediction nonspam spam
   nonspam     646   80
   spam         47  377

正解率は89.0%で、誤認識率は11%でした。

4.2 決定木-C5.0

認識率が良いといわれるC5.0アルゴリズムを使ってみます。これは最も有名な実装の一つでコンピュータ科学者のJ. Ross Quinlanが自分が開発したC4.5を改良して作ったものらしいです。商用として販売していますが、アルゴリズムのソースコードを公開しているので様々なプログラムに移植されているようです。大規模でも小規模でもほとんどの問題で高い性能を発揮する汎用分類器らしいです。

spam_spec_C50 <- decision_tree() %>%
  set_engine("C5.0") %>%
  set_mode("classification")

dt_fit_C50 <- fit_split(type ~ .,
                    model = spam_spec_C50,
                    split = spam_split) 

dt_fit_C50 %>%   
  collect_predictions() %>% 
  conf_mat(truth = type, estimate = .pred_class)

          Truth
Prediction nonspam spam
   nonspam     646   42
   spam         47  415

正解率は92.3%、誤認識率は7.7%でした。rpartより良いですね。

4.3 ランダムフォレスト-ranger

ランダムフォレストを使ってみます。

rf_spec <- 
  rand_forest() %>%         
  set_engine("ranger") %>%      
  set_mode("classification") 

rf_fit <- fit_split(type ~ ., 
                    model = rf_spec, 
                    split = spam_split) 

rf_fit %>%   
  collect_predictions() %>% 
  conf_mat(truth = type, estimate = .pred_class)
          Truth
Prediction nonspam spam
   nonspam     670   37
   spam         23  420

正解率は94.8%、誤認識率は5.2%でした。決定木よりランダムフォレストの方が良い結果ですね。

5.おわりに

機械学習のモデルやアルゴリズムの違いで、どれくらい結果が異なるか比べてみました。本来であれば、学習・評価の後でリサンプリングやチューニング等をするのですが、今回の結果からもモデルやアルゴリズムの違いによる結果の差の指標を得られたと思います。

Add a Comment

メールアドレスが公開されることはありません。