The Time I Wanted to Do Survival Analysis with Spark


Slides from my talk at the 1st Machine Learning 15minutes! ( http://machine-learning15minutes.connpass.com/event/32889/ ).


Atsushi Hayakawa

June 27, 2016

Transcript

  1. 2.

Self-introduction

• Atsushi Hayakawa
• I do machine learning and data analysis
• Second year at Recruit Communications, which I joined as a new graduate

• Community activity: organizer of Japan.R
• Writing: contributed the "Machine Learning with Python" chapter to the Data Scientist Training Guidebook (データサイエンティスト養成読本)
• Hobbies:
  • Marathon: my full-marathon time was 5 hours 4 minutes orz
  • Cycling: a lap of Taiwan, Tokyo to Hakodate, and over the mountains of Hakone
  • Futsal: occasionally on days off; please invite me!!
  • Electronics: built a gadget whose rotating warning light turns on when a batch job fails
  • Fireworks: a once-a-year launch with my university club
  2. 3.

What this talk covers, and what it doesn't

• Covered:
  • How I tried survival analysis in Spark and stepped on a land mine
  • A brief introduction to the model
• Not covered:
  • Detailed explanations of Spark or Scala
    • (heckling about the code that appears is welcome!)
  • How to defuse the land mine
  • Details of the model
  3. 4.

What is Spark?

• A distributed-processing framework for large-scale data
• Usable from Java, Scala, Python, and R

Figure 1: active development (taken from https://en.wikipedia.org/wiki/Apache_Spark)
  4. 7.

Survival-analysis models available in Spark

• The AFT (Accelerated Failure Time) model
  • A parametric model for censored data
  • Spark ships a Weibull regression, i.e. AFT assuming a Weibull distribution
  • The shape parameter m is shared; the scale parameter η(X) varies with the covariates
  • F(t|X) = 1 − exp( −(t / η(X))^m )
  • η(X) = exp(βX)
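To make the formulas concrete, here is a small Python sketch (mine, not from the talk). The values β = −1, β0 = 10, m = 2 anticipate the sample data generated on the next slide, with the intercept β0 folded into η:

```python
import math

def eta(x, beta=-1.0, beta0=10.0):
    """Scale parameter eta(X) = exp(beta0 + beta * x)."""
    return math.exp(beta0 + beta * x)

def weibull_cdf(t, x, m=2.0):
    """F(t|X) = 1 - exp(-(t / eta(X))**m): probability that the
    event has occurred by time t for a subject with covariate x."""
    return 1.0 - math.exp(-((t / eta(x)) ** m))

# The median event time solves F(t|X) = 0.5:
#   t_med = eta(X) * (log 2)**(1/m)
x = 0.0
t_med = eta(x) * math.log(2.0) ** (1.0 / 2.0)
print(weibull_cdf(t_med, x))  # = 0.5 by construction
```

With β negative, larger covariate values shrink η(X) and accelerate failure, which is the "accelerated failure time" reading of the model.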
  5. 8.

Creating the sample data

R code:

> set.seed(71)
> n <- 1000
> m <- 2
> beta <- -1
> beta0 <- 10
> x <- rnorm(n)
> data.df <- data.frame(
+   tt = rweibull(n, shape = m,
+                 scale = exp(beta * x + beta0)),
+   x = x, status = rep(1, n))
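For reference, a NumPy sketch of the same generator (my translation, not from the talk; NumPy's `weibull(a)` draws from the standard Weibull, so the scale is applied by multiplication, and the seed will not reproduce R's exact draws):

```python
import numpy as np

rng = np.random.default_rng(71)   # echoes set.seed(71); draws differ from R's
n, m, beta, beta0 = 1000, 2.0, -1.0, 10.0

x = rng.standard_normal(n)
scale = np.exp(beta * x + beta0)       # eta(X), as in the R code
tt = scale * rng.weibull(m, size=n)    # rweibull(n, shape=m, scale=eta)
status = np.ones(n)                    # 1 = event observed (no censoring)

print(tt.shape, float(tt.min()))
```

Note that all `status` values are 1, so the data contain no censoring at all; the AFT machinery is being exercised on fully observed event times.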
  6. 9.

Weibull regression in R

> library(survival)
> survreg(Surv(tt, status) ~ x, data = data.df, dist = "weibull")

Call:
survreg(formula = Surv(tt, status) ~ x, data = data.df, dist = "weibull")

Coefficients:
(Intercept)           x
 10.0162268  -0.9916226

Scale= 0.491728

Loglik(model)= -10618.8   Loglik(intercept only)= -11366.3
        Chisq= 1494.89 on 1 degrees of freedom, p= 0
n= 1000
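The same fit can be sanity-checked outside `survreg` by direct maximum likelihood. A Python sketch (mine, not from the talk), using the fact that with no censoring a Weibull AFT model is linear regression of log t with Gumbel-of-minimum errors and σ = 1/m:

```python
import numpy as np
from scipy.optimize import minimize

# Regenerate data along the lines of the earlier R snippet.
rng = np.random.default_rng(71)
n, m, beta, beta0 = 1000, 2.0, -1.0, 10.0
x = rng.standard_normal(n)
tt = np.exp(beta * x + beta0) * rng.weibull(m, size=n)

# y = log t = b0 + b1 * x + sigma * W, where W is standard Gumbel(min).
y = np.log(tt)

def negloglik(params):
    b1, b0, log_sigma = params
    sigma = np.exp(log_sigma)          # optimize log(sigma) to keep sigma > 0
    z = (y - b0 - b1 * x) / sigma
    # Gumbel(min) log-density: z - exp(z) - log(sigma)
    return np.sum(np.log(sigma) - z + np.exp(z))

fit = minimize(negloglik, x0=[0.0, y.mean(), 0.0],
               method="Nelder-Mead", options={"maxiter": 2000})
b1_hat, b0_hat, sigma_hat = fit.x[0], fit.x[1], np.exp(fit.x[2])
print(b1_hat, b0_hat, sigma_hat)  # should land near -1, 10, and 0.5
```

The recovered coefficient, intercept, and scale should match `survreg`'s estimates (≈ −0.99, ≈ 10.02, ≈ 0.49) up to sampling noise from the different random draws.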
  7. 10.

The Spark environment

Running on Docker:

sudo docker run --name spark-notebook -p 9000:9000 \
  andypetrella/spark-notebook:0.6.3-scala-2.10.5-spark-1.6.1-hadoop-2.7.2-with-parquet

This launches the Spark Notebook I use for the examples below.
  8. 11.

Loading the data in Spark, part 1

Scala code:

import org.apache.spark.sql.functions._
import org.apache.spark.mllib.linalg.{Vector, Vectors}

val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.implicits._
val df = sqlContext.read.json("sampledata.json")
  9. 12.

Loading the data in Spark, part 2

Scala code (trying with 100 rows):

case class MyDF(label: Double, censor: Double,
                features: org.apache.spark.mllib.linalg.Vector)

val training = df.
  withColumn("x", df("x").cast("Double")).
  withColumn("tt", df("tt").cast("Double")).
  withColumn("status", df("status").cast("Double")).
  select("tt", "status", "x").
  map(row => MyDF(row.getDouble(0), row.getDouble(1),
                  Vectors.dense(row.getDouble(2)))).
  toDF("label", "censor", "features").
  limit(100)
  10. 13.

Building the model in Spark

Scala code:

import org.apache.spark.ml.regression.AFTSurvivalRegression
import org.apache.spark.mllib.linalg.Vectors

val quantileProbabilities = Array(0.3, 0.6)
val aft = new AFTSurvivalRegression()
  .setQuantileProbabilities(quantileProbabilities)
  .setQuantilesCol("quantiles")
val model = aft.fit(training)
println(s"Coefficients: ${model.coefficients} Intercept: " +
  s"${model.intercept} Scale: ${model.scale}")
model.transform(training).show(false)
  11. 14.

Results

The parameters are estimated and event times are predicted:

Coefficients: [-0.94] Intercept: 9.94 Scale: 0.49
+---------+------+--------+-----------+--------------------+
|label    |censor|features|prediction |quantiles           |
+---------+------+--------+-----------+--------------------+
|42227.15 |1.0   |[-0.43] |31156.57   |[18784.58,29848.00] |
|31827.69 |1.0   |[-0.45] |31608.28   |[19056.92,30280.74] |
|18740.14 |1.0   |[-0.48] |32552.67   |[19626.31,31185.47] |
|9580.26  |1.0   |[0.42]  |14050.37   |[8471.10,13460.26]  |
|9225.93  |1.0   |[-0.42] |30751.77   |[18540.52,29460.20] |
+---------+------+--------+-----------+--------------------+
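The table can be roughly reproduced by hand from the printed coefficients. A Python sketch (mine, not from the talk), assuming Spark's AFT parameterization, where `prediction` is η(X) = exp(intercept + coefficients · features) and the p-quantile is prediction × (−log(1 − p))^scale; the small discrepancies come from the rounded coefficients:

```python
import math

intercept, coef, scale = 9.94, -0.94, 0.49  # as printed by the model

def predict(x):
    """eta(X): the value Spark puts in the 'prediction' column."""
    return math.exp(intercept + coef * x)

def quantile(x, p):
    """p-quantile of the fitted Weibull: eta(X) * (-log(1 - p))**scale."""
    return predict(x) * (-math.log(1.0 - p)) ** scale

x = -0.43                           # first row of the table
print(round(predict(x), 2))         # table shows 31156.57
print(round(quantile(x, 0.3), 2))   # table shows 18784.58
print(round(quantile(x, 0.6), 2))   # table shows 29848.00
```

All three values come out within a fraction of a percent of the table, which supports this reading of the `prediction` and `quantiles` columns.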
  12. 15.

But... feeding in all 1000 rows gives:

org.apache.spark.SparkException: Job aborted due to stage failure:
Task 0 in stage 235.0 failed 1 times, most recent failure:
Lost task 0.0 in stage 235.0 (TID 127, localhost):
java.lang.AssertionError: assertion failed:
AFTAggregator loss sum is infinity. Error for unknown reason.

The ticket https://issues.apache.org/jira/browse/SPARK-13322 says this happens when the covariates are not standardized, but in this case they already are (x is drawn from a standard normal).
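The assertion message leaves the cause open. One plausible mechanism (my speculation; the ticket does not confirm it) is floating-point overflow: the AFT loss contains an exp() of a scaled residual, and with labels around e^10 ≈ 22000 a poorly-scaled intermediate iterate can push that exp() to infinity. A minimal Python illustration of the failure mode:

```python
import math

# The Gumbel-form AFT log-likelihood contains a term exp(z) with
# z = (log(t) - coefficients.features - intercept) / sigma.  If the
# optimizer's current iterate is far from the data's scale, z grows
# large and exp(z) overflows, so the summed loss becomes infinite.
log_t = math.log(22000.0)            # a typical label from the sample data
intercept, sigma = 0.0, 1.0          # a benign iterate
z = (log_t - intercept) / sigma
print(math.exp(z))                   # still finite here ...

z_bad = (log_t - intercept) / 0.01   # ... but a tiny sigma iterate
try:
    val = math.exp(z_bad)
except OverflowError:
    val = float("inf")
print(val)                           # inf: "loss sum is infinity"
```

This is only an illustration of how such a loss can turn infinite, not a diagnosis of the actual Spark 1.6.1 bug.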
  13. 16.

Summary

• I generated my own data and ran survival analysis in Spark
• A small data set like the 100-row one works fine, but at 1000 rows it throws the error above
• I am currently asking about this on the Spark mailing list
• For survival analysis, the R language remains the better-equipped option
  14. 17.

References

• https://en.wikipedia.org/wiki/Apache_Spark
• https://www1.doshisha.ac.jp/~mjin/R/36/36.html
• https://spark.apache.org/docs/1.6.0/ml-classification-regression.html
• http://spark.apache.org/docs/latest/sql-programming-guide.html
• http://aaaaushisan.blogspot.jp/2011/12/survreg-weibreg.html
• http://blog.gepuro.net/archives/102