I want to use XGBoost through the spark.ml API! #shokaispark

Slides from the talk given at the launch event for the book 『詳解 Apache Spark』.

http://connpass.com/event/30375/

KOMIYA Atsushi

May 11, 2016

Transcript

  1. I want to use XGBoost through the spark.ml API! 2016-05-11, 『詳解 Apache Spark』 book launch event

    KOMIYA Atsushi (@komiya_atsushi)
  2. Who am I?

  3. KOMIYA Atsushi @komiya_atsushi

  4. Today’s topic

  5. on

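  6. XGBoost • One implementation of gradient boosting • Gradient boosting over decision trees is also implemented in MLlib, as GBTClassifier / GBTRegressor • Thanks to its high predictive accuracy, among other things, it seems to be popular, especially among Kagglers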
  7. I want to use XGBoost on Spark through the spark.ml API!

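  8. If we can use it through the spark.ml API… • We can make full use of the features spark.ml provides: • Feature extraction, transformation, and selection • Grid search over parameters • Pipelines • Cross-validation, and more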
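
To make the list above concrete, here is a minimal sketch of what plugging a learner into spark.ml buys you. It assumes Spark 2.x or 3.x and uses MLlib's existing GBTClassifier as a stand-in for a spark.ml-compatible XGBoost learner; the column names, toy data and grid values are made up for illustration.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.{DataFrame, SparkSession}

object PipelineCvSketch {
  // Synthetic toy data (illustrative only): label is 1.0 when x1 + x2 > 1.
  def toyData(spark: SparkSession): DataFrame = {
    import spark.implicits._
    (0 until 60).map { i =>
      val x1 = (i % 10) / 10.0
      val x2 = (i % 7) / 7.0
      (x1, x2, if (x1 + x2 > 1.0) 1.0 else 0.0)
    }.toDF("x1", "x2", "label")
  }

  // Feature assembly + learner in one Pipeline, tuned by grid search with cross-validation.
  def tune(df: DataFrame): CrossValidatorModel = {
    val assembler = new VectorAssembler()
      .setInputCols(Array("x1", "x2"))
      .setOutputCol("features")
    val gbt = new GBTClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
    val pipeline = new Pipeline().setStages(Array(assembler, gbt))

    // Every combination is evaluated: 2 depths x 2 iteration counts = 4 settings.
    val grid = new ParamGridBuilder()
      .addGrid(gbt.maxDepth, Array(2, 4))
      .addGrid(gbt.maxIter, Array(5, 10))
      .build()

    new CrossValidator()
      .setEstimator(pipeline)
      .setEstimatorParamMaps(grid)
      .setEvaluator(new MulticlassClassificationEvaluator().setMetricName("accuracy"))
      .setNumFolds(3)
      .fit(df)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("pipeline-cv-sketch").getOrCreate()
    val cvModel = tune(toyData(spark))
    // One averaged metric per grid point, in the order of the grid.
    println(cvModel.avgMetrics.mkString(", "))
    spark.stop()
  }
}
```

Because the learner is just one pipeline stage, swapping GBTClassifier for another Estimator that follows the same column conventions would, in principle, leave the grid search and cross-validation code unchanged.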
  9. What this talk covers • The current state of XGBoost on Spark • Key points when implementing a machine learning algorithm with the spark.ml API • With a particular focus on the interfaces
  10. XGBoost & Spark

  11. XGBoost on Spark • If you want to use XGBoost on Spark, there are currently two options • SparkXGBoost • xgboost4j-spark
  12. SparkXGBoost • https://github.com/rotationsymmetry/sparkxgboost • Reimplements XGBoost's gradient-boosted trees in pure Scala, targeting Spark • Registered on Spark Packages • Unclear how faithful it is to the original XGBoost • Has a roadmap up to ver 0.6, but development is not active • Last commit: November of last year (2015), ver 0.2
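  13. xgboost4j-spark • The official Spark integration provided by DMLC • However, it does not support DataFrame • Maintained in the main XGBoost git repository • The actual training and prediction work is delegated to the C++ implementation via JNI • Uses Rabit for communication between workers during training • Not published on Maven Central • You have to build it yourself to use it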
  14. This time… • Reimplementing XGBoost, learner included, in pure Scala as SparkXGBoost does is a high bar • So instead, take xgboost4j (which xgboost4j-spark builds on) as the base and wrap it in the spark.ml API
  15. spark.ml internals (the gentle version)

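  16. Reading the spark.ml source • How do you learn the conventions for implementing machine learning algorithms in spark.ml? • The quickest route is to read the implementations of the algorithms MLlib already provides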

  17. Reading the spark.ml source • Algorithms whose implementations are worth reading: • Logistic regression • LogisticRegression / LogisticRegressionModel • Decision tree (classification) • DecisionTreeClassifier / DecisionTreeClassificationModel • Decision tree (regression) • DecisionTreeRegressor / DecisionTreeRegressionModel
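  18. How machine learning is implemented in spark.ml • Follow the parent classes of an algorithm's learner and you arrive at the Estimator class • Follow the parent classes of the prediction model the learner produces and you arrive at the Transformer class • See pp. 217-218 of the book • Note, however, that neither necessarily extends Estimator or Transformer directly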
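
One way to follow the parent classes yourself is to walk the superclass chain with Java reflection. The sketch below assumes Spark's MLlib jar is on the classpath; the helper name `ancestors` is made up for this example, and the exact chain printed may differ between Spark versions (intermediate classes can be added over time).

```scala
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}

object HierarchySketch {
  // Walks getSuperclass from the direct parent of `c` up to java.lang.Object
  // and returns the simple class names, nearest parent first.
  def ancestors(c: Class[_]): List[String] =
    Iterator.iterate[Class[_]](c.getSuperclass)(_.getSuperclass)
      .takeWhile(_ != null)
      .map(_.getSimpleName)
      .toList

  def main(args: Array[String]): Unit = {
    Seq(
      classOf[LogisticRegression],
      classOf[LogisticRegressionModel],
      classOf[DecisionTreeRegressor],
      classOf[DecisionTreeRegressionModel]
    ).foreach(c => println(s"${c.getSimpleName}: ${ancestors(c).mkString(" -> ")}"))
  }
}
```

Printing the chain rather than testing for a single parent makes the point of the slide visible: a learner reaches Estimator and a model reaches Transformer, but usually through several intermediate abstract classes.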
  19. Learner class hierarchy • Estimator → Predictor → Classifier → ProbabilisticClassifier • Most regression algorithms extend Predictor • Most classification algorithms extend ProbabilisticClassifier

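  20. Prediction model class hierarchy • Transformer → Model → PredictionModel → ClassificationModel → ProbabilisticClassificationModel • PredictionModel is the parent class of the models that correspond to Predictor • ProbabilisticClassificationModel is the parent class of the models that correspond to ProbabilisticClassifier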

  21. Implementing learners and prediction models

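  22. The Predictor class • Columns • label: the column holding the ground-truth label • features: the column holding the feature vector • prediction: the column that receives the predicted label • Methods • train (abstract): implements the training logic • extractLabeledPoints: builds an RDD[LabeledPoint] from a DataFrame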
  23. The Classifier class • Columns • rawPrediction: the column that receives the raw values produced by the prediction model • The predicted label is derived from these values

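  24. The ProbabilisticClassifier class • Columns • probability: the column that receives the predicted class probabilities (for binary classification, the probability that the label is 1) • Parameters • threshold: the cutoff used to assign 0 or 1 based on the predicted probability (the probability column)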
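
The columns on the last three slides are ordinary params with getters and setters, so they can be inspected and renamed. A small sketch, assuming Spark 2.x or 3.x and using MLlib's LogisticRegression as a concrete ProbabilisticClassifier; `setThreshold` here is LogisticRegression's own setter, and the renamed columns are hypothetical.

```scala
import org.apache.spark.ml.classification.LogisticRegression

object ColumnParamsSketch {
  // Default output/input column names, read through the inherited getters.
  def defaults(lr: LogisticRegression): Seq[String] =
    Seq(lr.getLabelCol, lr.getFeaturesCol, lr.getPredictionCol,
      lr.getRawPredictionCol, lr.getProbabilityCol)

  // Pointing the learner at differently named columns (names are illustrative).
  def renamed(): LogisticRegression =
    new LogisticRegression()
      .setLabelCol("clicked")
      .setFeaturesCol("featureVec")
      .setPredictionCol("clickedHat")
      .setThreshold(0.7) // binary cutoff applied to the probability of label 1

  def main(args: Array[String]): Unit =
    println(defaults(new LogisticRegression()).mkString(", "))
    // prints: label, features, prediction, rawPrediction, probability
}
```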
  25. The PredictionModel class • Methods • transform: simply calls transformImpl • transformImpl: calls predict for each row of the given DataFrame • predict (abstract): implements producing a prediction from a given feature vector
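
The per-row behaviour of transformImpl can be imitated outside PredictionModel with a UDF. This is a simplified sketch for illustration under Spark 2.x or 3.x, not Spark's source: `predict` here is a hypothetical stand-in model that just sums the features, and `transformLike` is a made-up helper name.

```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}

object PerRowPredictSketch {
  // Hypothetical stand-in for a trained model's predict(features): the sum of the features.
  def predict(features: Vector): Double = features.toArray.sum

  // Wrap the per-row predict in a UDF and append its result as the prediction column,
  // the general shape of what a PredictionModel does for each row.
  def transformLike(df: DataFrame, featuresCol: String, predictionCol: String): DataFrame = {
    val predictUdf = udf((features: Vector) => predict(features))
    df.withColumn(predictionCol, predictUdf(col(featuresCol)))
  }
}
```

The design consequence is the one the slide relies on: a subclass only has to say how to predict one feature vector, and the base class takes care of applying that to every row of a DataFrame.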
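  26. The ClassificationModel class • Methods • transform: computes the predictions by calling predict, or predictRaw and raw2prediction • predict: passes the result of predictRaw to raw2prediction and returns the predicted label • predictRaw (abstract): implements returning raw prediction values from the model • raw2prediction: derives the predicted label from the raw prediction values produced by the model (by default, the index of the largest value)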
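  27. The ProbabilisticClassificationModel class • Methods • predictRaw (abstract): same as in ClassificationModel • raw2probabilityInPlace (abstract): implements converting raw prediction values into predicted probabilities • predictProbability: passes the result of predictRaw to raw2probabilityInPlace to obtain the predicted probabilities • probability2prediction: returns the predicted label from the predicted probabilities • raw2prediction: returns the predicted label from the raw prediction values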
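
The division of labour on the last three slides is a template-method pattern: the base class fixes the order of the calls and subclasses fill in the abstract hooks. The plain-Scala sketch below imitates that flow without Spark so it can run on its own. `MiniProbabilisticModel` and `MiniLinearModel` are hypothetical names, arrays stand in for Spark vectors, and the threshold handling (argmax of probability divided by threshold) is an assumption modelled loosely on spark.ml's thresholds, not a copy of its code.

```scala
// Plain Scala, no Spark dependency: a hypothetical, simplified imitation of the
// predictRaw -> raw2probabilityInPlace -> probability2prediction call chain.
abstract class MiniProbabilisticModel {
  // Hook 1 (cf. predictRaw): one raw score per class.
  def predictRaw(features: Array[Double]): Array[Double]

  // Hook 2 (cf. raw2probabilityInPlace): overwrite raw scores with probabilities.
  protected def raw2probabilityInPlace(raw: Array[Double]): Array[Double]

  // Optional per-class thresholds; None means a plain argmax over probabilities.
  def thresholds: Option[Array[Double]] = None

  // Fixed by the base class: copy the raw scores, then convert them in place.
  final def predictProbability(features: Array[Double]): Array[Double] =
    raw2probabilityInPlace(predictRaw(features).clone())

  // Fixed by the base class: pick the class with the largest (threshold-scaled) probability.
  final def probability2prediction(probability: Array[Double]): Double = {
    val scores: IndexedSeq[Double] = thresholds match {
      case Some(t) => probability.indices.map(i => probability(i) / t(i))
      case None    => probability.toIndexedSeq
    }
    scores.indices.maxBy(i => scores(i)).toDouble
  }

  final def predict(features: Array[Double]): Double =
    probability2prediction(predictProbability(features))
}

// Hypothetical binary linear model: raw = (-m, m) with margin m = w.x + b.
class MiniLinearModel(w: Array[Double], b: Double) extends MiniProbabilisticModel {
  def predictRaw(features: Array[Double]): Array[Double] = {
    val m = features.indices.map(i => features(i) * w(i)).sum + b
    Array(-m, m)
  }

  // Numerically stable softmax, written back into the same array.
  protected def raw2probabilityInPlace(raw: Array[Double]): Array[Double] = {
    val max = raw.max
    for (i <- raw.indices) raw(i) = math.exp(raw(i) - max)
    val sum = raw.sum
    for (i <- raw.indices) raw(i) /= sum
    raw
  }
}
```

With `new MiniLinearModel(Array(1.0), 0.0)`, the input `Array(0.5)` gives probabilities of roughly (0.269, 0.731) and prediction 1.0; overriding `thresholds` to `Some(Array(0.2, 0.8))` flips it to 0.0, because 0.269 / 0.2 is larger than 0.731 / 0.8.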
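  28. Key points for implementing learners and prediction models (1) • Use separate implementation classes for classification and for regression • In MLlib, algorithms that work for both classification and regression, such as random forests and gradient-boosted trees, have a separate implementation class for each • e.g. GBTClassifier and GBTRegressor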
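  29. Key points for implementing learners and prediction models (2) • Implementing a classification algorithm • Make the learner class extend ProbabilisticClassifier • Make the prediction model class extend ProbabilisticClassificationModel • Apart from boilerplate methods, you only need to implement predictRaw and raw2probabilityInPlace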
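
A minimal sketch of that recipe, assuming Spark 2.x or 3.x (where `train` takes a `Dataset[_]`; Spark 1.6 used `DataFrame`). To keep it short, the learner is a hypothetical nearest-centroid binary classifier rather than an XGBoost wrapper, and `CentroidClassifier` / `CentroidClassificationModel` are made-up names. Apart from `train`, `copy` and `numClasses`, the model only implements `predictRaw` and `raw2probabilityInPlace`.

```scala
import org.apache.spark.ml.classification.{ProbabilisticClassificationModel, ProbabilisticClassifier}
import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.col

// Hypothetical learner: nearest-centroid binary classifier (labels 0.0 / 1.0).
class CentroidClassifier(override val uid: String)
  extends ProbabilisticClassifier[Vector, CentroidClassifier, CentroidClassificationModel] {

  def this() = this(Identifiable.randomUID("centroidClf"))

  override def copy(extra: ParamMap): CentroidClassifier = defaultCopy(extra)

  override protected def train(dataset: Dataset[_]): CentroidClassificationModel = {
    // Per-class (sum of feature vectors, row count), aggregated on the executors.
    val stats = dataset.select(col($(labelCol)).cast("double"), col($(featuresCol))).rdd
      .map { case Row(label: Double, features: Vector) => (label, (features.toArray, 1L)) }
      .reduceByKey((x, y) => (x._1.zip(y._1).map { case (a, b) => a + b }, x._2 + y._2))
      .collectAsMap()
    require(stats.contains(0.0) && stats.contains(1.0), "both labels 0.0 and 1.0 are required")
    def mean(label: Double): Array[Double] = { val (s, n) = stats(label); s.map(_ / n) }
    val (mu0, mu1) = (mean(0.0), mean(1.0))
    // margin(x) = w.x + b is positive exactly when x is closer to mu1 than to mu0.
    val w = mu1.zip(mu0).map { case (a, b) => a - b }
    val b = -0.5 * w.indices.map(i => w(i) * (mu0(i) + mu1(i))).sum
    new CentroidClassificationModel(uid, Vectors.dense(w), b)
  }
}

class CentroidClassificationModel(override val uid: String, val weights: Vector, val bias: Double)
  extends ProbabilisticClassificationModel[Vector, CentroidClassificationModel] {

  override def numClasses: Int = 2

  // Raw scores (-margin, margin): the larger one marks the predicted class.
  override def predictRaw(features: Vector): Vector = {
    val margin = (0 until weights.size).map(i => features(i) * weights(i)).sum + bias
    Vectors.dense(-margin, margin)
  }

  // Softmax over the raw scores, written back into the same (dense) vector.
  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = rawPrediction match {
    case dv: DenseVector =>
      val v = dv.values
      val max = v.max
      for (i <- v.indices) v(i) = math.exp(v(i) - max)
      val sum = v.sum
      for (i <- v.indices) v(i) /= sum
      dv
    case other =>
      throw new IllegalArgumentException(s"expected a dense vector, got ${other.getClass.getSimpleName}")
  }

  override def copy(extra: ParamMap): CentroidClassificationModel =
    copyValues(new CentroidClassificationModel(uid, weights, bias), extra).setParent(parent)
}
```

After `new CentroidClassifier().fit(df)`, calling `transform` on the model should add rawPrediction, probability and prediction columns without any further code, since those come from the base classes.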
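  30. Key points for implementing learners and prediction models (3) • Implementing a regression algorithm • Make the learner class extend Predictor • Make the prediction model class extend PredictionModel • You only need to implement the predict method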
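
The regression counterpart is shorter still. A hedged sketch under the same Spark 2.x/3.x assumption, with a hypothetical baseline learner that always predicts the mean training label; `MeanRegressor` and `MeanRegressionModel` are illustrative names.

```scala
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.{avg, col}

// Hypothetical baseline learner: always predicts the mean training label.
class MeanRegressor(override val uid: String)
  extends Predictor[Vector, MeanRegressor, MeanRegressionModel] {

  def this() = this(Identifiable.randomUID("meanReg"))

  override def copy(extra: ParamMap): MeanRegressor = defaultCopy(extra)

  override protected def train(dataset: Dataset[_]): MeanRegressionModel = {
    val mean = dataset.select(avg(col($(labelCol)).cast("double"))).head().getDouble(0)
    new MeanRegressionModel(uid, mean)
  }
}

class MeanRegressionModel(override val uid: String, val mean: Double)
  extends PredictionModel[Vector, MeanRegressionModel] {

  // The only model-specific logic: map one feature vector to one prediction.
  override def predict(features: Vector): Double = mean

  override def copy(extra: ParamMap): MeanRegressionModel =
    copyValues(new MeanRegressionModel(uid, mean), extra).setParent(parent)
}
```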
  31. Parameters

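  32. Parameters in spark.ml • Hyperparameter tuning is part and parcel of machine learning • spark.ml provides grid search • When implementing a machine learning algorithm with spark.ml, you need to design it so that its parameters can be tuned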
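
What "tunable" means in practice: any parameter declared as a spark.ml `Param` can be fed to ParamGridBuilder, exactly like the built-in ones. A minimal sketch assuming Spark 2.x or 3.x; `BoosterParamsHolder` is hypothetical, and its `maxDepth` / `eta` only mimic XGBoost parameter names.

```scala
import org.apache.spark.ml.param.{DoubleParam, IntParam, ParamMap, Params}
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.ml.util.Identifiable

// Hypothetical params holder standing in for an XGBoost learner's parameters.
class BoosterParamsHolder(override val uid: String) extends Params {
  def this() = this(Identifiable.randomUID("boosterParams"))

  final val maxDepth: IntParam = new IntParam(this, "maxDepth", "maximum tree depth")
  final val eta: DoubleParam = new DoubleParam(this, "eta", "learning rate")

  override def copy(extra: ParamMap): BoosterParamsHolder = defaultCopy(extra)
}

object GridSketch {
  // The grid is the Cartesian product of the value lists: 2 x 2 = 4 ParamMaps.
  def grid(p: BoosterParamsHolder): Array[ParamMap] =
    new ParamGridBuilder()
      .addGrid(p.maxDepth, Array(3, 6))
      .addGrid(p.eta, Array(0.1, 0.3))
      .build()

  def main(args: Array[String]): Unit =
    grid(new BoosterParamsHolder()).foreach(m => println(m))
}
```

In a real learner these params would live on the Estimator itself, so that CrossValidator or TrainValidationSplit can apply each ParamMap when fitting.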
  33. Example parameter implementation
    import org.apache.spark.ml.param.{Param, ParamValidators, Params}

    trait XGBoostGeneralParams extends Params {
      final val booster: Param[String] = new Param[String](
        this,
        "booster",                                             // parameter name
        "which booster to use, can be gbtree or gblinear.",    // description
        ParamValidators.inArray(Array("gbtree", "gblinear")))  // validation rule for the parameter

      // Provide a setter and a getter
      def setBooster(value: String): this.type = set(booster, value)
      def getBooster: String = $(booster)

      // Set the default value
      setDefault(booster, "gbtree")
    }
  34. Key points for implementing parameters (1) • Define the parameters • Type • Param, DoubleParam, IntParam, FloatParam, LongParam… • Parameter name • Description • Validation • Use the factory methods provided by ParamValidators
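
A short sketch of the checklist above, assuming Spark 2.x or 3.x: one parameter per typed Param class, each with a name, a description and, where it makes sense, a ParamValidators rule. The class name, parameter names and allowed values are hypothetical and only loosely modelled on XGBoost's options.

```scala
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable

// Hypothetical XGBoost-flavoured params, one per typed Param class from the slide.
class TypedParamsSketch(override val uid: String) extends Params {
  def this() = this(Identifiable.randomUID("typedParams"))

  final val objective: Param[String] = new Param[String](this, "objective",
    "learning objective", ParamValidators.inArray(Array("binary:logistic", "multi:softprob")))
  final val numRound: IntParam = new IntParam(this, "numRound",
    "number of boosting rounds (> 0)", ParamValidators.gt(0))
  final val subsample: DoubleParam = new DoubleParam(this, "subsample",
    "row subsampling ratio in (0, 1]",
    ParamValidators.inRange(0.0, 1.0, lowerInclusive = false, upperInclusive = true))
  final val missing: FloatParam = new FloatParam(this, "missing",
    "value treated as missing") // no validator: every float is accepted
  final val seed: LongParam = new LongParam(this, "seed", "random seed")

  override def copy(extra: ParamMap): TypedParamsSketch = defaultCopy(extra)
}
```

The validator is stored on the Param itself (`isValid`), so an out-of-range value such as `numRound = 0` should be rejected when it is set rather than surfacing later during training.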
  35. Key points for implementing parameters (2) • Provide a getter / setter • Set a default value • This part is essentially boilerplate
  36. spark.ml-friendly XGBoost

  37. xgboost-dataframe-prototype • https://github.com/komiya-atsushi/xgboost-dataframe-prototype • As the repo name says, it is a prototype • Please be careful if you use it • Training is not distributed • That would require getting to grips with Rabit's API…
  38. Summary

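  39. Summary • Using XGBoost as the example, I talked about key points for implementing a machine learning algorithm with the spark.ml API • The parent classes of learners and prediction models • Parameters • I hope this helps with your own machine learning implementations on Spark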
  40. Thank you!