Upgrade to Pro — share decks privately, control downloads, hide ads and more …

This Looks Like That: Deep Learning for Interpretable Image Recognitionをまとめてみた

This Looks Like That: Deep Learning for Interpretable Image Recognitionをまとめてみた

This Looks Like That: Deep Learning for Interpretable Image Recognition
Chaofan, C. Oscar, L. Daniel, T. Alina, B. Cynthia, R. Jonathan, K. In Advances in Neural Information Processing Systems 32 (NeurIPS 2019)

Tomohiro Munemasa

November 30, 2019
Tweet

Other Decks in Research

Transcript

  1. This Looks Like That: Deep Learning for Interpretable Image Recognition

    Chaofan, C. Oscar, L. Daniel, T. Alina, B. Cynthia, R. Jonathan, K. In Advances in Neural Information Processing Systems 32 (NeurIPS 2019) 筑波⼤学院 宗政 友洋 作成者 :
  2. Ø ⼈間が識別するプロセスと近い説明を与えるネットワー クPrototypical Part Network (ProtoPNet) を提案 Ø ProtoPNetは明確な推論プロセスを持つ点で解釈可能で ある

    (interpretable) Ø 既存のDeep leaning model (non-interpretable)と同等な 学習精度を達成していることを実験的に確認 Ø 他のinterpretableなモデルには無い解釈性を提供 尻尾が⼤きい どんぐり ほっぺがふっくら Introduction finding prototypical parts combines evidence from the prototypes make a classification 3
  3. Introduction What is prototype ? ⼀般的な意味として, 各prototypeは訓練データの⼀部を代表する抽象的な表現 つまり, 全てのprototypeは全訓練データを表現する抽象的な表現となっている 予め各クラスに対して複数のprototypeが割り当てられており,

    クラス ∈ {1, … , }に割り当てられる各prototypeは, クラスであると表現される 局所的な表現となっている 分類時には, 学習されたprototypeと⼊⼒画像の類似度を⽐較して推論を⾏う 本論⽂では input four prototypes of sparrow class 類似度を計算 分類 4
  4. Introduction ProtoPNetʼ Classfication Prosess prototypes test image finding prototypical parts

    combines evidence from the prototypes make a classification 5
  5. Introduction Related work post-hoc interpretability analysis (ex. GradCAM, SmoothGrad, ...

    ) ü 訓練されたモデルに対して, 説明性を与える⼿法 ü ネットワークが実際にどのように決定をするかの推論プロセスを説明するものではない attention-based interpretability (Ex. Part R-CNN, PS-CNN, 2-level attn, Neural const, RA-CNN, ...) ü 決定するときに注⽬している箇所を強調する⼿法 ü ⼊⼒のどの部分を注⽬しているのかを⽰すのみ other prototype classification techniques prototype (≒訓練データの⼀部を代表する表現)を⽤いた学習 ü 本論⽂は[24]と最も近い ü prototypeを可視化するためにdecoderを必要 ü ⾃然画像の場合, 可視化の際に現実的なprototypeの作成に失敗 ü decoder不要 ü 全てのprototypeはある訓練データの潜在表現のため, 忠実に可視化可能 本論⽂では, 推論プロセスを含めた説明性を与える 注⽬しているのかを⽰すことに加えて, それらの部分に類似した典型的なケースを⽰している f x [24] O. Li, H. Liu, C. Chen, and C. Rudin. Deep Learning for Case-Based Reasoning through Prototypes: A Neural Network that Explains Its Predictions. In Proceedings of the Thirty-Second AAAI Conference on Artificial Intelligence (AAAI), 2018. 6
  6. Case study 1: bird species identification Convolution layers : parameters

    : ,-./ activation function (excent last layer) : ReLU activation function (last layer) : sigmoid Prototype layers : 1 Fully connected layers : ℎ parameters : 3 no bias 7
  7. Case study 1: bird species identification input : 224×224×3 ∶

    = f(x) 7×7×D ※ D is chosen from three possible values: 128, 256, 512, using cross validation) Convolution layers : parameters : ,-./ activation function (excent last layer) : ReLU activation function (last layer) : sigmoid Prototype layers : 1 Fully connected layers : ℎ parameters : 3 no bias ⼊⼒画像を潜在空間に射影 8
  8. Case study 1: bird species identification input : = f(x)

    H×W×D (= 7×7×D) ü 個prototype : = Q QRS T ,S×S×D (= 1×1×D) ü S ≤ , S ≤ → それぞれのprototype Q は,元の空間で, 画像のあるpatch (≒⼀部分)を表す ü とj番⽬のprototype Q のZ 距離に基づくスコア1[ を計算 ( = f(x)とする) 1\ = max ̃ `∈1abc3de ` log ̃ − Q Z Z + 1 ̃ − Q Z + 1[ が⼤きいほど, 番⽬のprototypeと似ている (近い)patchが⼊⼒画像内にあるということ 1[ に対して, max poolingを⾏うことで類似度 (Similarity score)を算出できる ü それぞれのクラス ∈ {1, … , }に対してm 個のprototypeが割り当てられている Convolution layers : parameters : ,-./ activation function (excent last layer) : ReLU activation function (last layer) : sigmoid Prototype layers : 1 Fully connected layers : ℎ parameters : 3 no bias ∶ 1 1[ QRS T ∈ 1 () 潜在空間内の⼊⼒画像とPrototypeの距離を計算 9
  9. Case study 1: bird species identification Convolution layers : parameters

    : ,-./ activation function (excent last layer) : ReLU activation function (last layer) : sigmoid Prototype layers : 1 Fully connected layers : ℎ parameters : 3 no bias input : 1 7×7×D ∶ ℎ ü ℎ = 3 1 を求める ü あるクラスのprototypeとの類似度に関する重みが学習される 分類におけるそれぞれのprototypeの重要度を求めることができる ※ ProtoPNetの推論メカニズムはいくつかの合理的な仮定のもとで, より⼀般的な確率的推論 (probabilistic inference)とみなすことができる. これについては補⾜資料のS2を参照 潜在空間内の⼊⼒画像とPrototypeの距離に重みを掛けて分類する softmax classification 10
  10. Case study 1: bird species identification Training algorithm (1) stochastic

    gradient descent (SGD) of layers before the last layer (2) projection of prototypes (3) convex optimization of last layer 11
  11. Case study 1: bird species identification Training algorithm (1) stochastic

    gradient descent (SGD) of layers before the last layer Ø ある訓練データに対して, 最も重要なpatchと⾃⾝のクラスのprototypeが潜在空間内の距離が近くなり 他のクラスのprototypeが⼗分に分離されるようなを学習する notation description m ⊆ クラスkのprototype集合 = , = r , r rRS t 訓練データ cross entropy loss(分類誤差)を ⼩さくすることを要請 各訓練データのpatchは他のクラス のprototypeから離れることを要請 各訓練データのpatchは少なく とも1つの⾃⾝のクラスの prototypeに近いことを要請 object function 12
  12. Case study 1: bird species identification Training algorithm (1) stochastic

    gradient descent (SGD) of layers before the last layer Ø 全結合層の重みを以下のようにして, 学習する 3 m,Q = 1 ℎ Q ∈ m , m,Q 3 = −0.5 ℎ Q ∉ m ≠ であるに対して, ()とQ ∈ m の距離が近い場合に分類精度が悪化 互いに異なるクラスのprototype同⼠が離れるように学習される notation description (, ) 訓練データ, 正解データ m ⊆ クラスkのprototype集合 3 m,Q 番⽬のprototype unitとクラスkのロジットの重み 13
  13. Case study 1: bird species identification Training algorithm Ø prototype

    Q を可視化するために Q と同クラスで潜在空間内で最も近い訓練データのpatchに置き換える (2) projection of prototypes prototypeの射影がprototypeをあまり動かさないならば, 予測結果は変化しない (Theorem 2.1.) ü prototypeの射影が分類精度にどのような影響を及ぼすか理論的に述べる 14
  14. Case study 1: bird species identification Theorem 2.1. Let Suppose

    射影後, 正しいクラスのロジットは最⼤でΔ€•‚ = ƒ log((1 + )(2 − ))減少し, 全ての異なるクラス ≠ のロジットは最⼤でΔ€•‚ 増加する 上位2クラスのロジット が少なくとも2Δ€•‚ 離れているならば, prototypeの射影はの予測を変更しない notation description † m 射影前のクラスの番⽬のprototype † m 射影後のクラスの番⽬のprototype ⼊⼒画像 の正解ラベル † m † mに最も近い()のpatch (argmin ̃ `∈ˆ•‰,Š‹Œ(•(Ž)) ̃ − † m Z ) i. 以下を満たす0 < < 1が存在する a. 全ての異なるクラス ≠ のprototypeにおいて, † m − † m Z ≤ † m − † m Z − , ℎ = min 1 + − 1,1 − S Z’“ b. 同⼀クラスのprototypeにおいて, † c − † c Z ≤ 1 + − 1 † c − † c Z † c − † c Z ≤ 1 − ii. それぞれのクラスのprototypeの数は等しい (ƒ) iii. 最終層の重みは以下となる 3 (m,Q) = 1 ℎ ℎ Q ∈ m, 3 (m,Q) = 0 ℎ ℎ Q ∉ m Then 15
  15. Case study 1: bird species identification Training algorithm (3) convex

    optimization of last layer Ø Q ∉ m でのとの接続 3 (m,Q)を 3 (m,Q) ≈ 0にしたい 否定的な推論 (ex. この⿃はクラスのprototypeに近いpatchを持っていないからクラスである) を⾏いたくない object function 16
  16. Case study 1: bird species identification Prototype visualization Ø デコーダーを⽤いずに,

    潜在空間上のprototype Q を元の画像空間に戻して可視化する ü prototype Q は()に射影されているため, Q によるのactivation mapを利⽤してQ を可視化する (1) activatiom mapは7×7なので, Q によるのactivation mapを元の次元224×224にupsamplingを⾏う (2) activation mapの合計ピクセル値に対して95%以上の割合を囲める最⼩な四⾓形を求める (3) 得られた四⾓形を切り取る input 1\ (()) upsampling and pile up ※ 論⽂中の別々の画像を使⽤しているため, 画像の対応が取れていない 17
  17. Case study 1: bird species identification Reasoning process of our

    network Ø テストデータに対するシマセゲラ (red-belied woodpecker)の推論プロセス ü prototype r と()のpatchを⽐較し, similarity scoreを求め, シマセゲラに対応する最終層の重みをかけることでprototype r を求めることができる (points contributed) points contributedの和を取ることで, テストデータがシマセゲラに属するスコアを求めることができる 18
  18. Case study 1: bird species identification Comparison with baseline models

    and attention-based interpretable deep models 19
  19. Case study 1: bird species identification Comparison with baseline models

    and attention-based interpretable deep models Ø ProtoPNetとBaseline (without the prototype layer)との精度⽐較 Ø ProtoPNetと既存⼿法との精度⽐較 full : model was trained on full images bb : model was trained on images cropped usinng bounding boxes (or the model used bounding boxes in other ways) anno : model was trained with keypoint annotations of bird parts 20
  20. Case study 1: bird species identification Analysis of latent space

    and prototype pruning test image three nearest prototype Ø ある検証データのpatchに対して近傍にあるprototypeの上位3件 Ø あるprototypeに対して近傍にあるpatchの上位3件 ü 検証データの近傍にあるprototypeは同⼀のクラスからなる ü あるprototypeのpatchは同質の意味概念を持つ prototype three nearest training patches three nearest test patches 21