Building and Operating Models with tidymodels / tidymodels

Slides from my talk at Fukuoka.R#15 (https://fukuoka-r.connpass.com/event/139211/)
Repository: http://github.com/uribo/190831_fukuokar15

Uryu Shinya

August 31, 2019
Transcript

  1. Uryu Shinya @u_ribo / Building and Operating Models with tidymodels / 190831 Fukuoka.R@LINE Fukuoka

  2. Contents: (1) The data modeling workflow, (2) Model building with tidymodels, (3) Improving and operating models

  3. The data modeling workflow

  4. The data analysis workflow (figure adapted from Garrett and Hadley (2016)); a callout marks where tidymodels fits in.

  5. A typical model-building framework (figure adapted from Max and Kjell (2019)): the model is refined step by step through iterative work.

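  6. B: Visualize the data to learn its characteristics and the relationships among variables, and to find a "starting point" for the initial model.
    C: Compute summary statistics, identify variables strongly correlated with the outcome, and form hypotheses about the model. Keep visualizing relationships and running further quantitative analyses until you can say you understand the data well.
    D: Prepare the data for fitting the initial model.
    You can't just jump straight into "OK, let's run the model!"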

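  7. E: Fit the initial model. Using the same data as the initial model, fit and compare several other models too. The hyperparameter search also happens here.
    F: Analyze the results of the repeated parameter tuning.
    G: Visualize the model results and compare performance across models.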

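  8. H: Feature engineering to improve the initial model (search for features that improve the model).
    I: Tune the final candidate model.
    J: Evaluate generalization performance on the assessment set.
    K: Operation.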

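  9. Evaluating model performance: use evaluation metrics suited to the task.
    Regression problems: coefficient of determination (R2, RSQ), root mean square error (RMSE), mean absolute error (MAE)
    Classification problems: confusion matrix, accuracy, precision and recall, ROC curve and AUC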
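To make the regression metrics above concrete, here is a minimal base-R sketch that computes them for a small, made-up set of observations and predictions (the numbers are hypothetical). As far as I know, yardstick's rsq() is the squared correlation between truth and estimate, while rsq_trad() uses the traditional 1 - SSE/SST definition; both versions are computed here.

```r
# Hypothetical observed values and model predictions for a regression task
truth    <- c(3.0, 5.0, 2.5, 7.0)
estimate <- c(2.5, 5.0, 3.0, 8.0)

err  <- truth - estimate
rmse <- sqrt(mean(err^2))   # root mean square error
mae  <- mean(abs(err))      # mean absolute error

# Coefficient of determination, traditional definition: 1 - SSE / SST
rsq_trad <- 1 - sum(err^2) / sum((truth - mean(truth))^2)
# Squared Pearson correlation between truth and estimate
rsq_cor  <- cor(truth, estimate)^2

round(c(rmse = rmse, mae = mae, rsq_trad = rsq_trad, rsq_cor = rsq_cor), 3)
```

RMSE squares the errors before averaging, so the single error of 1.0 weighs more heavily in it than in MAE.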
  10. For details, see http://bit.ly/slide-fe-recipes and http://bit.ly/practical-ds (Tweets and Stars are very encouraging!)

  11. Challenges in data modeling with R: many packages have been developed, but their interfaces are not consistent, and the formula interface has drawbacks when handling model objects. See "The R Formula Method: The Good Parts" and "The R Formula Method: The Bad Parts" (R Views).

  12. Overview of tidymodels

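  13. tidymodels: the team includes Max Kuhn (RStudio, developer of {caret}) and others. It provides a collection of packages for tidy statistical modeling and machine learning, with a pipe-friendly design that is kind to both humans and programs.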

  14. Packages included in {tidymodels}: {rsample} for resampling, {recipes} for data preprocessing and feature generation, {parsnip} for model building and fitting, and {yardstick} for model performance evaluation.

        library(tidymodels)
        #> Registered S3 method overwritten by 'xts':
        #>   method     from
        #>   as.zoo.xts zoo
        #> ─ Attaching packages ───── tidymodels 0.0.2 ─
        #> ✔ broom     0.5.2     ✔ purrr     0.3.2
        #> ✔ dials     0.0.2     ✔ recipes   0.1.6
        #> ✔ dplyr     0.8.3     ✔ rsample   0.0.5
        #> ✔ ggplot2   3.2.1     ✔ tibble    2.1.3
        #> ✔ infer     0.4.0.1   ✔ yardstick 0.0.3
        #> ✔ parsnip   0.0.3.1
        #> ─ Conflicts ──────── tidymodels_conflicts() ─
        #> ✖ purrr::discard() masks scales::discard()
        #> ✖ dplyr::filter()  masks stats::filter()
        #> ✖ dplyr::lag()     masks stats::lag()
        #> ✖ recipes::step()  masks stats::step()

    Loading it prints the attached packages, their versions, and any function-name conflicts.
  15. Classifying iris species: treat the species (Species) as the label and solve a classification problem. The other four variables are features describing the flower's shape.

        iris <- as_tibble(iris)
        glimpse(iris)
        #> Observations: 150
        #> Variables: 5
        #> $ Sepal.Length <dbl> 5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, …
        #> $ Sepal.Width  <dbl> 3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, …
        #> $ Petal.Length <dbl> 1.4, 1.4, 1.3, 1.5, 1.4, 1.7, 1.4, …
        #> $ Petal.Width  <dbl> 0.2, 0.2, 0.2, 0.2, 0.2, 0.4, 0.3, …
        #> $ Species      <fct> setosa, setosa, setosa, setosa, se…
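  16. Data splitting {rsample}: split the dataset into an analysis set (train) and an assessment set (test). The analysis set is used to train the model; the assessment set is held back as unseen data to measure the model's performance.
    Reference: "A modern way to split data in R: cross-validation with the rsample package" (HOXO-M Inc. blog, blog.hoxo-m.com)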
  17. Data splitting {rsample}: the split is random. Use 60% of the dataset as the analysis set.

        iris_split <- initial_split(iris, prop = 0.6)
        iris_split
        #> <90/60/150>
        iris_train <- training(iris_split)
        iris_test <- testing(iris_split)
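For intuition, this is roughly what a random 60/40 split amounts to in base R. It is a sketch, not rsample's implementation; the seed value and the *_sketch object names are arbitrary choices for this example.

```r
# Base-R sketch of a random 60/40 split, similar in spirit to
# rsample::initial_split(iris, prop = 0.6) (not its actual implementation)
set.seed(42)                               # arbitrary seed, for reproducibility
n         <- nrow(iris)                    # 150 rows
n_train   <- floor(n * 0.6)                # 90 rows for the analysis set
train_idx <- sample(seq_len(n), n_train)   # rows drawn at random, without replacement

iris_train_sketch <- iris[train_idx, ]     # analysis (training) set
iris_test_sketch  <- iris[-train_idx, ]    # assessment (test) set

c(train = nrow(iris_train_sketch), test = nrow(iris_test_sketch))
```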
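  18. Preprocessing and feature engineering {recipes}: turn the data-processing procedure used for the model into a "recipe".
    1. recipe(): define the relationships among the variables to use → choose the ingredients
    2. step_*(): specify the data-processing steps → write down the cooking method
    3. prep(): combine the step_*() processing → finalize the recipe
    4. bake()/juice(): apply it to a dataset → do the cooking
    Reference: "Preprocessing model data with recipes" (HOXO-M Inc. blog, blog.hoxo-m.com)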
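Separating prep() from bake() matters because statistics are learned from the analysis set only and then reused on new data. The sketch below mimics that for normalization in base R; prep_normalize() and bake_normalize() are hypothetical helpers, not recipes functions, and the 90-row split is made up for the example.

```r
# Hypothetical helpers illustrating the prep()/bake() idea for normalization:
# statistics are learned once from the training data, then applied to any data.
prep_normalize <- function(data, cols) {
  list(
    cols  = cols,
    means = vapply(data[cols], mean, numeric(1)),
    sds   = vapply(data[cols], sd,   numeric(1))
  )
}

bake_normalize <- function(prepped, new_data) {
  for (col in prepped$cols) {
    new_data[[col]] <- (new_data[[col]] - prepped$means[[col]]) / prepped$sds[[col]]
  }
  new_data
}

set.seed(1)
idx   <- sample(nrow(iris), 90)
train <- iris[idx, ]
test  <- iris[-idx, ]

num_cols <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")
prepped  <- prep_normalize(train, num_cols)  # "prep": estimate from training data only
train_n  <- bake_normalize(prepped, train)   # training data: mean 0, sd 1
test_n   <- bake_normalize(prepped, test)    # test data: uses the training mean/sd
```

The test columns end up close to, but not exactly at, mean 0 and sd 1, which is the intended behavior: the test set must not influence the preprocessing.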
  19. Shameless plug (break time): https://uribo.github.io/dpp-cookbook/

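  20. Add the processing steps to the model with the pipe operator. The formula takes the form outcome ~ predictors, which gives each variable its role.

        iris_recipe <- iris_train %>%
          # Model with Species as the outcome and all other variables as predictors
          recipe(formula = Species ~ .) %>%
          # Step 1: across all predictors, drop one variable from each
          # highly correlated pair (correlation of 0.9 or more)
          step_corr(all_predictors(), threshold = 0.9) %>%
          # Step 2: normalize all numeric predictors
          step_normalize(all_predictors(), -all_outcomes())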
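To show what step_corr(threshold = 0.9) aims to do, here is a simplified correlation filter in base R. corr_filter() is a hypothetical helper, and its rule for choosing which variable of a pair to drop (the one with the larger mean absolute correlation) is my approximation, not necessarily the exact algorithm recipes uses. It runs on the full iris data rather than iris_train.

```r
# Simplified correlation filter: while some pair has |r| above the threshold,
# drop the member of that pair with the larger mean absolute correlation.
corr_filter <- function(data, threshold = 0.9) {
  keep <- names(data)
  repeat {
    r <- abs(cor(data[keep]))
    diag(r) <- 0                        # ignore each variable's self-correlation
    if (max(r) <= threshold) break
    pair <- which(r == max(r), arr.ind = TRUE)[1, ]   # most correlated pair
    candidates <- keep[pair]
    # mean |r| of each candidate (diagonal is 0 for both, so the ranking holds)
    avg <- colMeans(r[, candidates])
    keep <- setdiff(keep, candidates[which.max(avg)])
  }
  setdiff(names(data), keep)            # names of the removed variables
}

removed <- corr_filter(iris[, 1:4], threshold = 0.9)
removed
#> [1] "Petal.Length"
```

On iris only Petal.Length and Petal.Width exceed 0.9, and Petal.Length is also more correlated with the sepal measurements, so it is the one dropped, matching what prep() reports later in the slides.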
  21. Specifying variables in step_*()
    1. By name, as strings: "Species", "Sepal.Length"
    2. With tidyselect functions: starts_with(), contains(), and so on
    3. By role in the model: all_predictors(), all_outcomes()
    4. By data type: all_nominal(), all_numeric()
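These four kinds of selection have rough base-R counterparts, shown below on iris for comparison only; recipes itself resolves selectors through tidyselect and variable roles, not through these calls.

```r
# Base-R analogues of the selection styles listed above (illustrative only)
outcome <- "Species"

# 1. By name (string)
by_name <- "Sepal.Length"
# 2. By name pattern, similar to starts_with("Sepal")
by_prefix <- grep("^Sepal", names(iris), value = TRUE)
# 3. By role, similar to all_predictors(): every column except the outcome
by_role <- setdiff(names(iris), outcome)
# 4. By type, similar to all_numeric() / all_nominal()
numeric_cols <- names(Filter(is.numeric, iris))
nominal_cols <- names(Filter(is.factor, iris))
```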
  22. The recipe object shows the procedure defined with recipe() and step_*().

        iris_recipe
        #> Data Recipe
        #>
        #> Inputs:
        #>
        #>       role #variables
        #>    outcome          1
        #>  predictor          4
        #>
        #> Operations:
        #>
        #> Correlation filter on all_predictors
        #> Centering and scaling for all_predictors, -, all_outcomes()
  23. What happens when the recipe is applied to the analysis set. At this stage no data has been generated yet; the result is still a recipe object.

        iris_recipe_comp <- iris_recipe %>%
          prep()
        #> … (output omitted)
        #> Training data contained 90 data points and no missing data.
        #>
        #> Operations:
        #>
        #> Correlation filter removed Petal.Length [trained]
        #> Centering and scaling for Sepal.Length, Sepal.Width, Petal.Width [trained]
  24. Applying the recipe to the datasets: juice() for the analysis set, bake() for the assessment set.

        iris_training <- juice(iris_recipe_comp)
        glimpse(iris_training)
        #> Observations: 90
        #> Variables: 4
        #> $ Sepal.Length <dbl> -1.35774587, -1.71228299…
        #> $ Sepal.Width  <dbl> 0.230356130, -0.438419732…
        #> $ Petal.Width  <dbl> -1.2122997, -1.2122997…
        #> $ Species      <fct> setosa, se…
        iris_testing <- iris_recipe_comp %>%
          bake(new_data = iris_test)

    Petal.Length is gone, and the same is true for iris_testing.
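  25. Checking what {recipes} does: standardizing each variable to mean 0 and variance 1.

        all.equal(
          mean(scale(iris_train$Sepal.Length)),
          mean(iris_training$Sepal.Length))
        #> [1] TRUE

        iris_train %>%
          recipe(formula = Species ~ .) %>%
          step_corr(all_predictors(), threshold = 0.9) %>%
          step_normalize(all_predictors(), -all_outcomes())

    Because both means are essentially zero, this comparison is a loose check; comparing the whole vectors, e.g. all.equal(as.vector(scale(iris_train$Sepal.Length)), iris_training$Sepal.Length), tests the transformation itself.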
  26. Checking what {recipes} does: one variable of each highly correlated pair is removed from the dataset.

        corrr::correlate(iris_test %>% select_if(is.numeric),
                         method = "pearson",
                         use = "pairwise.complete.obs",
                         quiet = TRUE)
        #> # A tibble: 4 x 5
        #>   rowname      Sepal.Length Sepal.Width Petal.Length Petal.Width
        #>   <chr>               <dbl>       <dbl>        <dbl>       <dbl>
        #> 1 Sepal.Length      NA          -0.0685        0.859       0.798
        #> 2 Sepal.Width       -0.0685     NA            -0.395      -0.333
        #> 3 Petal.Length       0.859      -0.395        NA           0.958
        #> 4 Petal.Width        0.798      -0.333         0.958      NA

    These correlations are computed on iris_test; prep() estimates them on the analysis set (iris_train), so the exact values step_corr() uses differ slightly.
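  27. Steps provided by recipes: step_*() functions cover numeric transformations, encoding, dates and times, filtering, imputation of missing values, normalization, dimensionality reduction, and more.

        ls("package:recipes", pattern = "^step_")
        #> # 69 step_* functions (version 0.1.6)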
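As one example from the encoding category, a dummy-variable step turns a factor into indicator columns. The base-R sketch below uses model.matrix() to do the same kind of thing; recipes' step_dummy() produces similar columns but with its own naming (for example Species_versicolor).

```r
# Indicator (dummy) columns for the Species factor. "setosa", the first level,
# is the reference category and gets no column of its own.
dummies <- model.matrix(~ Species, data = iris)[, -1]  # drop the intercept column
colnames(dummies)
#> [1] "Speciesversicolor" "Speciesvirginica"
head(dummies, 3)
```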
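  28. Extending and implementing step_*(): some packages provide step_*() functions for particular data types, such as {textrecipes} (strings), {embed} (categories), and {tfdatasets} (for use with TensorFlow models).
    Japanese-language preprocessing: {washoku} (planned) would cover normalization, half-width/full-width conversion, splitting addresses, and so on.

        remotes::install_github("uribo/washoku")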
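  29. Model building {parsnip}: wraps a wide range of modeling packages so they can all be operated in a uniform way.
    1. Define the specification: choose a model suited to the task (regression? classification?), e.g. rand_forest(), linear_reg(), logistic_reg()
    2. set_engine(): specify the engine (package)
    3. fit(): fit the model by training it on the analysis set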
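  30. How functions are written (how arguments are specified) differs from package to package. Without {parsnip}:

        Package        Features sampled at each split   Number of trees   Minimum samples in a node
        ranger         mtry                             num.trees         min.node.size
        randomForest   mtry                             ntree             nodesize
        sparklyr       feature_subset_strategy          num_trees         min_instances_per_node

    With {parsnip}, a single interface covers them all:

        rand_forest(mode = "classification", mtry, trees, min_n)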
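Conceptually, {parsnip} hides these differences by translating its unified argument names into each engine's own names. The lookup table and translate_args() below are a hypothetical base-R sketch of that idea, limited to the ranger and randomForest rows of the table; they are not parsnip internals (parsnip's translate() function shows the real engine-specific call).

```r
# Hypothetical lookup: unified argument names (as in rand_forest()) mapped to
# each engine's own argument names, from the ranger and randomForest rows above.
arg_map <- list(
  ranger       = c(mtry = "mtry", trees = "num.trees", min_n = "min.node.size"),
  randomForest = c(mtry = "mtry", trees = "ntree",     min_n = "nodesize")
)

# Hypothetical helper: rename a list of unified arguments for one engine
translate_args <- function(args, engine) {
  map <- arg_map[[engine]]
  names(args) <- unname(map[names(args)])
  args
}

translate_args(list(trees = 100, min_n = 5), "randomForest")
```

The user writes trees and min_n once; only the lookup changes when the engine changes.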
  31. A model using random forests. Engine- (package-) specific options can be passed through set_engine().

        rf_spec <- rand_forest(mode = "classification", trees = 100)
        # ranger::ranger
        iris_ranger <- rf_spec %>%
          set_engine("ranger", seed = 123) %>%
          fit(Species ~ ., data = iris_training)
        # randomForest::randomForest
        iris_rf <- rf_spec %>%
          set_engine("randomForest") %>%
          fit(Species ~ ., data = iris_training)
  32. The fitted model: the processing specified with rand_forest() and set_engine() has been executed.

        iris_rf
        #> Call:
        #>  randomForest(x = as.data.frame(x), y = y, ntree = ~100)
        #>                Type of random forest: classification
        #>                      Number of trees: 100
        #> No. of variables tried at each split: 1
        #>
        #>         OOB estimate of  error rate: 2.22%
        #> Confusion matrix:
        #>            setosa versicolor virginica class.error
        #> setosa         36          0         0  0.00000000
        #> versicolor      0         26         1  0.03703704
        #> virginica       0          1        26  0.03703704
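The OOB error rate in this printout can be reproduced from the confusion matrix: it is the share of off-diagonal (misclassified) cases. The base-R arithmetic below uses the counts shown above.

```r
# OOB confusion matrix from the printout (rows = true class, columns = OOB prediction)
cm <- matrix(c(36,  0,  0,
                0, 26,  1,
                0,  1, 26),
             nrow = 3, byrow = TRUE,
             dimnames = list(truth = c("setosa", "versicolor", "virginica"),
                             pred  = c("setosa", "versicolor", "virginica")))

class_error <- 1 - diag(cm) / rowSums(cm)   # per-class error, as in class.error
oob_error   <- 1 - sum(diag(cm)) / sum(cm)  # overall OOB error rate: 2 / 90

round(class_error, 8)
round(100 * oob_error, 2)
#> [1] 2.22
```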
  33. Improving and operating models

  34. {yardstick} Model performance evaluation: predict the labels for the test data (new data).

        iris_rf_pred <- predict(iris_rf, iris_testing) %>%
          bind_cols(iris_testing)
        rf_metrics <- iris_rf_pred %>%
          metrics(truth = Species, estimate = .pred_class)
        rf_metrics
        #> # A tibble: 2 x 3
        #>   .metric  .estimator .estimate
        #>   <chr>    <chr>          <dbl>
        #> 1 accuracy multiclass     0.95
        #> 2 kap      multiclass     0.925

    For classification tasks, accuracy and the kappa coefficient are provided as performance metrics.
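Accuracy and kappa both come from the confusion matrix. The matrix below is hypothetical (the slide does not show the test-set confusion matrix); it only illustrates Cohen's kappa, which corrects accuracy for the agreement expected by chance.

```r
# Hypothetical 3-class confusion matrix for 60 test rows
# (rows = truth, columns = prediction); not the actual result from the slide.
cm <- matrix(c(20,  0,  0,
                0, 19,  1,
                0,  2, 18),
             nrow = 3, byrow = TRUE)

n        <- sum(cm)
accuracy <- sum(diag(cm)) / n                     # observed agreement p_o
p_e      <- sum(rowSums(cm) * colSums(cm)) / n^2  # agreement expected by chance
kappa    <- (accuracy - p_e) / (1 - p_e)          # Cohen's kappa

c(accuracy = accuracy, kappa = kappa)
```

With balanced true classes the chance agreement is 1/3, so an accuracy of 0.95 corresponds to a kappa of (0.95 - 1/3) / (2/3) = 0.925.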
  35. {yardstick} Model performance evaluation: you can combine performance metrics or pick specific ones. For classification tasks, the performance functions take truth and estimate.

        iris_rf_pred %>%
          accuracy(Species, .pred_class)
        perf_metrics <- metric_set(accuracy, recall, precision)
        iris_rf_pred %>%
          perf_metrics(truth = Species, estimate = .pred_class)
        #> # A tibble: 3 x 3
        #>   .metric   .estimator .estimate
        #>   <chr>     <chr>          <dbl>
        #> 1 accuracy  multiclass     0.983
        #> 2 recall    macro          0.981
        #> 3 precision macro          0.986
        iris_rf_pred %>%
          kap(Species, .pred_class)
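With "macro" averaging, recall and precision are computed per class and then averaged without weights. A base-R sketch on a hypothetical confusion matrix (not the slide's data):

```r
# Hypothetical confusion matrix (rows = truth, columns = prediction)
cm <- matrix(c(20,  0,  0,
                0, 19,  1,
                0,  2, 18),
             nrow = 3, byrow = TRUE,
             dimnames = list(truth = c("setosa", "versicolor", "virginica"),
                             pred  = c("setosa", "versicolor", "virginica")))

recall_by_class    <- diag(cm) / rowSums(cm)  # share of each true class that was found
precision_by_class <- diag(cm) / colSums(cm)  # share of each predicted class that was right

# Macro averaging: unweighted mean over the classes
c(recall = mean(recall_by_class), precision = mean(precision_by_class))
```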
  36. ggplot-based visualization: heatmap of the confusion matrix

        iris_rf_pred %>%
          conf_mat(truth = Species, estimate = .pred_class) %>%
          autoplot(type = "heatmap")

  37. ggplot-based visualization: ROC curve

        iris_ranger %>%
          predict(iris_testing, type = "prob") %>%
          bind_cols(iris_testing) %>%
          roc_curve(Species, starts_with(".pred")) %>%
          autoplot()
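For a multiclass outcome, roc_curve() builds one curve per class (one-vs-all). The base-R sketch below shows the binary building block on made-up scores: sweep the threshold to get (false positive rate, sensitivity) points, and compute AUC as the share of positive/negative pairs that are ranked correctly.

```r
# Hypothetical predicted probabilities for the positive class, and true labels
score <- c(0.9, 0.4, 0.7, 0.2)
label <- c(TRUE, TRUE, FALSE, FALSE)

# ROC points: predict "positive" when score >= threshold, for each threshold
thresholds <- c(Inf, sort(unique(score), decreasing = TRUE))
roc <- data.frame(
  threshold   = thresholds,
  sensitivity = sapply(thresholds, function(t) mean(score[label] >= t)),  # TPR
  fpr         = sapply(thresholds, function(t) mean(score[!label] >= t))  # 1 - specificity
)

# AUC as the share of (positive, negative) pairs ranked correctly (ties count 1/2)
pairs <- outer(score[label], score[!label], "-")
auc   <- mean((pairs > 0) + 0.5 * (pairs == 0))
auc
#> [1] 0.75
```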
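  38. Data splitting happens only once: resampling for cross-validation {rsample}

        iris_split <- initial_split(iris, prop = 0.6)

    A single analysis/assessment split yields only one estimate. The model's generalization performance needs to be evaluated on multiple resampled datasets → cross-validation.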
  39. Choose a splitting method suited to the data and the goal. Take care with time series, spatial data, and the like.
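With time series, for instance, random splitting lets the model peek at the future. rsample provides rolling_origin() for this case; the hypothetical rolling_splits() below is only a base-R sketch of the fixed-width (non-cumulative) variant of that idea.

```r
# Hypothetical helper: rolling-origin splits for time-ordered rows 1..n.
# Each split trains on `initial` consecutive rows and assesses on the next
# `assess` rows, so the assessment rows always lie in the "future".
rolling_splits <- function(n, initial, assess) {
  starts <- seq_len(n - initial - assess + 1)
  lapply(starts, function(s) {
    list(analysis   = s:(s + initial - 1),
         assessment = (s + initial):(s + initial + assess - 1))
  })
}

splits <- rolling_splits(n = 10, initial = 5, assess = 2)
length(splits)   # 4 splits
splits[[1]]      # analysis rows 1-5, assessment rows 6-7
```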

  40. Choose a splitting method suited to the data and the goal. Each row of the resulting data frame holds one split of the dataset; analysis() and assessment() return the analysis and assessment sets resampled into each fold.

        iris_cv <- vfold_cv(iris, v = 4)
        iris_cv
        #> # A tibble: 4 x 2
        #>   splits           id
        #>   <named list>     <chr>
        #> 1 <split [112/38]> Fold1
        #> 2 <split [112/38]> Fold2
        #> 3 <split [113/37]> Fold3
        #> 4 <split [113/37]> Fold4
        analysis(iris_cv$splits[[1]])    # analysis set of Fold1
        assessment(iris_cv$splits[[1]])  # assessment set of Fold1
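What vfold_cv() sets up can be sketched in base R: give every row one fold id, then for each fold use the other folds as the analysis set and that fold as the assessment set. This is an illustration, not rsample's implementation; the seed is arbitrary.

```r
# Base-R sketch of 4-fold cross-validation splits on iris
set.seed(2019)  # arbitrary seed
v     <- 4
n     <- nrow(iris)
folds <- sample(rep(seq_len(v), length.out = n))  # each row gets one fold id

splits <- lapply(seq_len(v), function(k) {
  list(analysis   = which(folds != k),   # rows used for fitting
       assessment = which(folds == k))   # held-out rows for evaluation
})

# analysis/assessment sizes, like the <split [112/38]> labels
sapply(splits, function(s) c(analysis = length(s$analysis),
                             assessment = length(s$assessment)))
```

Since 150 = 4 × 37 + 2, two folds hold 38 rows and two hold 37, which gives the same 112/38 and 113/37 sizes as the slide.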
  41. Train on each fold and compute the performance metric, then average across folds.

        #> # 4-fold cross-validation
        #> # A tibble: 4 x 4
        #>   splits           id    recipes           accuracy
        #> * <named list>     <chr> <named list>         <dbl>
        #> 1 <split [112/38]> Fold1 <tibble [38 × 5]>    0.974
        #> 2 <split [112/38]> Fold2 <tibble [38 × 5]>    0.895
        #> 3 <split [113/37]> Fold3 <tibble [37 × 5]>    0.919
        #> 4 <split [113/37]> Fold4 <tibble [37 × 5]>    0.973
        .Last.value %>% pull(accuracy) %>% mean()
        #> [1] 0.9400782
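The per-fold loop behind such a table can be sketched as follows. To keep it runnable in base R, a simple nearest-centroid classifier (hypothetical helpers nearest_centroid_fit() and nearest_centroid_predict()) stands in for the random forest, so the accuracies will differ from the slide.

```r
# Hypothetical helper: one row of feature means per class
nearest_centroid_fit <- function(x, y) {
  t(sapply(split(x, y), colMeans))
}

# Hypothetical helper: assign each row to the class with the closest centroid
nearest_centroid_predict <- function(centroids, x) {
  d <- apply(centroids, 1, function(ctr) colSums((t(x) - ctr)^2))  # squared distances
  factor(rownames(centroids)[max.col(-d, ties.method = "first")],
         levels = rownames(centroids))
}

set.seed(2019)
v     <- 4
folds <- sample(rep(seq_len(v), length.out = nrow(iris)))
x     <- iris[, 1:4]
y     <- iris$Species

fold_accuracy <- sapply(seq_len(v), function(k) {
  train     <- folds != k
  centroids <- nearest_centroid_fit(x[train, ], y[train])
  pred      <- nearest_centroid_predict(centroids, x[!train, ])
  mean(pred == y[!train])   # accuracy on the held-out fold
})

fold_accuracy
mean(fold_accuracy)         # cross-validated accuracy estimate
```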
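  42. Hyperparameter search {dials}: the values of the hyperparameters affect the model's accuracy. A random forest model has three hyperparameters that can be specified:

        rand_forest(mode = "classification", mtry, trees, min_n)

    Performance has to be evaluated while varying these parameter values.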
  43. Grid search: specify the parameter names and the range of values to draw at random; size random rows are generated. In the {parsnip} model-creation stage, mark the parameters that are not fixed.

        set.seed(1234)
        bst_grid <- grid_random(
          range_set(trees, c(1, 50)),
          range_set(min_n, c(2, 30)),
          size = 3)
        bst_grid
        #> # A tibble: 3 x 2
        #>   trees min_n
        #>   <int> <int>
        #> 1    28     6
        #> 2    16    13
        #> 3    22    16

        mod_obj <- rand_forest(
          mode = "classification",
          trees = varying(),
          min_n = varying(),
          mtry = 3)
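A random grid simply draws parameter combinations uniformly within the given ranges. The base-R random_grid() below is a hypothetical helper for illustration and will not reproduce the slide's values. As far as I know, later {dials} releases express the same idea with parameter functions, e.g. grid_random(trees(range = c(1L, 50L)), min_n(range = c(2L, 30L)), size = 3), and {parsnip}'s varying() was superseded by tune().

```r
# Hypothetical helper: draw `size` combinations of two integer hyperparameters
# uniformly within the given ranges (names mirror the slide).
random_grid <- function(size, trees_range, min_n_range) {
  data.frame(
    trees = sample(seq(trees_range[1], trees_range[2]), size, replace = TRUE),
    min_n = sample(seq(min_n_range[1], min_n_range[2]), size, replace = TRUE)
  )
}

set.seed(1234)
grid <- random_grid(size = 3, trees_range = c(1, 50), min_n_range = c(2, 30))
grid
```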
  44. lapply() or a for() loop work just as well. Combined with purrr:

        seq.int(nrow(bst_grid)) %>%
          purrr::map(~ merge(mod_obj, bst_grid[.x, ])) %>%
          purrr::flatten()
        #> Random Forest Model Specification (classification)
        #>
        #> Main Arguments:
        #>   mtry = 3
        #>   trees = 28
        #>   min_n = 6
        #> Main Arguments:
        #>   mtry = 3
        #>   trees = 16
        #>   min_n = 13
        #> Main Arguments:
        #>   mtry = 3
        #>   trees = 22
        #>   min_n = 16

    See also "Time to Try purrr" (s_uryu, Speaker Deck).
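The same expansion, one specification per grid row, written with lapply(). Plain lists stand in for parsnip model specifications here (a hypothetical structure), and the grid values are the ones from the slide.

```r
# Grid values from the slide
grid <- data.frame(trees = c(28L, 16L, 22L), min_n = c(6L, 13L, 16L))

# Fixed part of the (hypothetical, list-based) model specification
base_spec <- list(mode = "classification", mtry = 3)

specs <- lapply(seq_len(nrow(grid)), function(i) {
  modifyList(base_spec, as.list(grid[i, ]))   # fill in the varying parameters
})

str(specs[[1]])
```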
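  45. Summary: {tidymodels} is pipe-friendly and provides a unified interface that simplifies the processing needed for statistical modeling and machine learning.
    1. {rsample}, resampling: initial_split(), vfold_cv()
    2. {recipes}, data processing: step_*(), bake()/juice()
    3. {parsnip}, model training: set_engine()
    4. {yardstick}, performance evaluation: metrics()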
  46. Enjoy! (Have fun with it!)