Upgrade to Pro — share decks privately, control downloads, hide ads and more …

Kaggle State Farm Distracted Driver Detection

Takuya Akiba
February 04, 2017

Kaggle State Farm Distracted Driver Detection

Kaggle の画像分類コンテスト State Farm Distracted Driver Detection にて 1440 チーム中 9 位を取った際のアプローチについて。

Takuya Akiba

February 04, 2017
Tweet

Other Decks in Programming

Transcript

  1. 2017/02/04 @ Kaggle Tokyo Meetup #2
    Kaggle State Farm
    Distracted Driver Detection
    by iwiwi

    View Slide

  2. \ ヽ | / /
    \ ヽ | / /
    \ /
    _ わ た し で す _
    / ̄\
    ― |^o^|
    ― \_/  ̄

    / \
    / / | ヽ \
    / / | ヽ \

    View Slide

  3. はじめに

    View Slide

  4. このコンペは
    僕にとって

    View Slide

  5. 初めての
    Kaggle

    View Slide

  6. どころか

    View Slide

  7. 初めての
    機械学習

    View Slide

  8. 本発表の注意!
    l 当時は右も左もわからない状態
    l 今でも皆さんと違って機械学習歴半年ちょい
    l くれぐれも真に受けないでください!
    (変なところ突っ込んでもらえるとうれしいです!)

    View Slide

  9. ⽬次
    1. 問題とデータ
    2. 僕のアプローチ

    View Slide

  10. 1
    問題とデータ

    View Slide

  11. 問題
    l ⼊⼒:ドライバーの画像
    l 出⼒:運転態度を 10 クラスに分類
    l 評価:Multi-class log loss (cross entropy)

    View Slide

  12. クラス 0:安全運転

    View Slide

  13. クラス 1:右⼿でスマホ操作

    View Slide

  14. クラス 2:右⼿で通話

    View Slide

  15. クラス
    c0: normal driving
    c1: texting - right
    c2: talking on the phone - right
    c3: texting - left
    c4: talking on the phone - left
    c5: operating the radio
    c6: drinking
    c7: reaching behind
    c8: hair and makeup
    c9: talking to passenger

    View Slide

  16. データ
    l imgs.zip
    n 訓練画像:22,424
    n テスト画像:79,726
    l driver_imgs_list.csv
    n 各訓練画像のドライバーが誰か

    View Slide

  17. 余談:LB での accuracy
    l 普通に提出するとテストデータの public LB 部分
    での logloss が⼿に⼊る(当たり前)
    l 実はテストデータの public LB 部分での
    accuracy も⼿に⼊れられる
    n 予想確率を one-hot っぽくして提出
    (例:予想クラスを 0.91、それ以外を 0.01)
    n 正解時のスコアと不正解時のスコアが⼀定
    n 正解の割合が計算できる

    View Slide

  18. 2
    僕のアプローチ

    View Slide

  19. 第⼀話:leak の洗礼

    View Slide

  20. Kaggle 初挑戦!
    l Chainer 付属の AlexNet
    l とりあえず放り込んでみる
    l validation loss 0.117
    l 1 位より遥か上!?やったか!?

    View Slide

  21. 初提出スコア
    2.106
    あれ、0.117 は!?\(^o^)/
    ほぼサンプルのスコアじゃん!?

    View Slide

  22. Validation はちゃんとやる
    l Train / validation を完全ランダムに選択し
    ていた
    l 同じドライバーの画像は極めて似てる
    l テストデータと訓練データは違うドライバー
    l Train / validation もドライバーで分割しな
    いとだめ(そのための driver_imgs_list.csv!)
    l ド定番素⼈ミス

    View Slide

  23. 第⼆話:Pretrained VGG

    View Slide

  24. Pre-trained VGG16
    l このコンテストは external data 使⽤可
    l Kaggle では初? Forum では議論が⽩熱
    l Pre-trained model は絶対使う⽅が良い
    (と思った)
    l Pre-trained model を使うため、
    CNN のモデルは既存のものから選ぶだけ(楽になった?)

    View Slide

  25. Pre-trained VGG16
    [https://blog.heuritech.com/2016/02/29/a-brief-report-of-the-heuritech-deep-learning-meetup-5/]

    View Slide

  26. Pre-trained VGG16
    l 当時定番っぽかった VGG16 を選択
    l Fine tuning は⼩さめの学習率
    n 0.001 とかからスタート
    n (参考:ImageNet は普通 0.1 スタート)
    l 2,3 エポックで validation loss はサチってた
    (短い気がするが後述のデータの性質が関係?)

    View Slide

  27. 第三話:Model Averaging

    View Slide

  28. Model Averaging
    l 5 個独⽴に学習して予測結果を平均
    l 0.31661 → 0.22374
    l こっそり出てたけどこの辺でバレる

    View Slide

  29. 第四話:Data Augmentation

    View Slide

  30. Data Augmentation とは
    l 訓練時に画像を変形させたりする
    l 擬似的に訓練データを増やす
    [http://ultraist.hatenablog.com/entry/2015/03/20/121031]

    View Slide

  31. Data Augmentation
    l NSDB プランクトン分類コンペ優勝チーム
    のコードから拝借
    (MIT ライセンス https://github.com/benanne/kaggle-ndsb)
    l しかし、スコア変わらず (´ε`;)ウーン…

    View Slide

  32. 仮説1
    l Pre-trained model を使っているから?
    l ⼊⼒に近い⽅の層で、これらの操作で⽣
    まれる差は既に吸収されているのかも?

    View Slide

  33. 仮説2
    l データが既に “Augment” されている?
    l 動画切り出しでデータが作られており、
    同じドライバーの同じクラスの画像は酷似
    【極端な解釈】
    l ドライバー 26 ⼈ × 10 クラスで、
    訓練データは本質的には 260 枚の「画像」
    l 各「画像」は 100 枚弱に “Augment” されていて、
    ⼊⼒データは 22 万枚である(ように錯覚している!?)

    View Slide

  34. 仮説3
    l 当時の検証が適当だっただけ説
    l 後半では密かに効果を発揮していた説
    l ……

    View Slide

  35. 第四話:Pseudo Labeling

    View Slide

  36. 学習データが少ない!
    l 先程の「極端な解釈」によると本質的に訓練画像はたっ
    た 260 パターン
    l すぐモデルが overfit してしまうのは単純に学習データ
    が少なすぎるせいでは?
    l テストデータは 3 倍ぐらいある、利⽤できないか?
    l NDSB プランクトン分類コンペ優勝チームのブログによ
    ると Pseudo Labeling を使っている
    http://benanne.github.io/2015/03/17/plankton.html

    View Slide

  37. Pseudo Labeling とは
    l ⼀種の半教師有り学習のフレームワーク
    l 訓練データに加え、テストデータを⽤いて訓練
    l テストデータのラベルは以前のモデルによる予測
    l 訓練データに対しテストデータが相対的に少なくな
    るように⼊れるほうが良いっぽい
    (NDSB 優勝チームは 2:1 になるようにしていたので僕もそれに従った)

    View Slide

  38. Pseudo Labeling の効果
    l Pseudo Label にはアンサンブルで得たシングルモ
    デルより精度の⾼いラベルを使える
    l データが多くなり、より⼤きく汎化性能の⾼いモデ
    ルを安定して学習させられる(という説)
    Distillation に似ているが、Distillation は⼩さいモデルを作りたくて、Pseudo
    Labeling ではより⼤きいモデルを作りたい
    l 0.22374 → 0.21442

    View Slide

  39. 第五話:-NN(闇)

    View Slide

  40. データセットの性質
    https://www.kaggle.com/titericz/state-farm-distracted-driver-detection/just-relax-and-watch-some-cool-movies/code
    このカーネルが出る前から気づいていた⼈は多かったのではないかと予想してます。
    僕もこれが出るより前からこの性質を使ってました。

    View Slide

  41. データセットの性質
    l 動画切り出し
    l 時系列で前後の写真は同じクラス
    l 時系列で前後の写真は画像として酷似
    l 前後の写真を探して、結果を統合しよう

    View Slide

  42. 画像のまま k-NN
    l 計算が重いので画像は縮⼩して計算
    l 距離は L2 より L1 のほうがやや精度が良い
    0.999
    0.996 0.995
    0.992 0.991
    0.989
    0.986
    0.983
    0.979
    0.976
    0.999 0.997 0.996 0.995
    0.993
    0.991
    0.989
    0.987
    0.986
    0.984
    0.96
    0.97
    0.98
    0.99
    1
    1 2 3 4 5 6 7 8 9 10
    同じクラスの割合
    k
    L2 L1

    View Slide

  43. k-NN 結果の使い⽅
    l 予測時に k-NN への予測結果を平均する
    n k=10
    n i 番⽬の近傍は重み 0.9%
    l 0.21442 → 0.19356

    View Slide

  44. Pseudo Labeling との相乗効果
    l k-NN で予測結果を混ぜわせる
    l それを Pseudo Label にしてまた学習
    l その予測結果を k-NN で混ぜ合わせて……
    k-NN の関係をグラフだと思うと、
    ちょっとグラフベース半教師有り学習っぽい

    View Slide

  45. ちなみに
    l k-NN なんてそんな荒削りなことするの、
    ド素⼈な僕だけだよな・・・?と思ってた
    l が、フタを開けると、1 位を含むかなり多く
    の上位者が使っていた・・・!
    l あと、画像間の距離でのクラスタリングで⼈
    間がすっぱり別れないかとちょっと試したけ
    ど、そのまま使えそうな結果にはならず

    View Slide

  46. 第六話:Cross PL(仮)

    View Slide

  47. きっかけ
    l モデルをついに ResNet にした
    l これはスコア爆上げ間違い無し!( ・´ー・`)どや
    ・・・
    l 0.18072 が最⾼、0.17725 (VGG) を抜けない
    l ⼤きく悪くならないが、良くもならない

    View Slide

  48. ResNet
    [Figure 3, He+, ʼ15]

    View Slide

  49. 考えたこと
    l Pseudo Labeling が効きすぎているのでは?
    l Pseudo Label を暗記されてそのまま出⼒され
    ては困る
    l そこそこ離れた画像への Pseudo Label だけ
    が学習に影響してほしい
    l あと、完全なクラスタリングはどうしてもうまくいかないので、
    クラスタリングがそこそこうまくいくっぽいことを利⽤したい

    View Slide

  50. そこでやってみたこと
    Cross Pseudo Labeling(仮)
    1. テストデータを 個のグループに分割
    (今回はここで画像に対するクラスタリングを使う)
    2. グループ を予測するモデルは、以下で学習
    n 訓練データ
    n グループ 以外のテストデータ
    (Pseudo Label)

    View Slide

  51. Cross Pseudo Labeling 結果
    l 0.17725 → 0.14573
    l ⼤幅改善に成功
    (ただし ResNet への変更の影響も含むはず)

    View Slide

  52. 第七話:ResDrop

    View Slide

  53. ResDrop [Huang+ʼ16] とは
    l ResNet の層(ResBlock)を確率的に Drop

    View Slide

  54. ResDrop
    l 最後の晩の悪あがき
    l ResNet のネットワーク定義を公開してた⼯藤く
    んのツイッターで存在は知っていた
    l ResDrop のネットワーク定義にすり替えただけ
    n ResNet の pretrain model をそのまま読む
    n 他のパラメータ調整せず(時間なかった)

    View Slide

  55. ResDrop
    l Public LB:0.14126 → 0.1378
    l まあちょっと上がったかな?
    l ・・・
    と思っていたが!

    View Slide

  56. Private LB
    0.16484

    0.14866

    View Slide

  57. 最終順位
    19 位

    9 位

    View Slide

  58. ありがとう
    ResDrop!
    (&ありがとう⼯藤くん)

    View Slide

  59. まとめ:僕のアプローチ
    l ResDrop-{101,152}
    l Cross Pseudo Labeling
    l k-NN
    l 運が良かった気がします

    View Slide

  60. おまけ:toshi_k さんの解法
    闇っぽいテクが全く無いのに同等の精度……!
    [https://www.kaggle.com/c/state-farm-distracted-driver-detection/discussion/22666]

    View Slide

  61. おまけ2:toshi_k さんのと混ぜた
    解答ファイルを頂いて単純に平均
    スコアは 0.02 も向上!

    View Slide