Upgrade to Pro — share decks privately, control downloads, hide ads and more …

Long-tail learning via logit adjustment

Sho Yokoi
PRO
September 16, 2021

Long-tail learning via logit adjustment

2021-09-16, 第13回最先端NLP勉強会
https://sites.google.com/view/snlp-jp/home/2021

Menon et al., Long-tail learning via logit adjustment (ICLR 2021) の論文紹介です
https://openreview.net/forum?id=37nvvqkCo5

Sho Yokoi
PRO

September 16, 2021
Tweet

More Decks by Sho Yokoi

Other Decks in Research

Transcript

  1. Long-tail learning via logit adjustment
    Aditya Krishna Menon, Sadeep Jayasumana, Ankit Singh Rawat,
    Himanshu Jain, Andreas Veit, Sanjiv Kumar (Google Research)
    ICLR 2021
    https://openreview.net/forum?id=37nvvqkCo5
    読む⼈︓横井 祥 (東北⼤学/理研AIP)
    2021-09-16, 第13回最先端NLP勉強会
    (2021-09-30, 少し更新)

    View Slide

  2. “Long-tail learning via logit adjustment”
    ざっくりこういう話です
    2
    • 問題設定︓Softmax を使った分類
    − NLP でも頻出 e.g. ⾔語モデルの単語予測
    • 課題︓クラス分布が不均衡な場合,⾼頻度クラスばかり予
    測され低頻度クラスでは汎化されない
    − クラス不均衡は NLP でも頻出 e.g. 単語頻度分布は long tail
    • やったこと︓均衡誤差 (balanced error) の最⼩化という観
    点で,新しい訓練戦略・予測戦略を提案
    − Softmax への⼊⼒を少しいじるだけ
    • 結果︓「⾼頻度クラスばかり〜」問題が緩和
    ※ もう少し正確なまとめは最後のスライドに

    View Slide

  3. “Long-tail learning via logit adjustment”
    ざっくりこういう話です
    3
    • 問題設定︓Softmax を使った分類
    − NLP でも頻出 e.g. ⾔語モデルの単語予測
    • 課題︓クラス分布が不均衡な場合,⾼頻度クラスばかり予
    測され低頻度クラスでは汎化されない
    − クラス不均衡は NLP でも頻出 e.g. 単語頻度分布は long tail
    • やったこと︓均衡誤差 (balanced error) の最⼩化という観
    点で,新しい訓練戦略・予測戦略を提案
    − Softmax への⼊⼒を少しいじるだけ
    • 結果︓「⾼頻度クラスばかり〜」問題が緩和
    ※ もう少し正確なまとめは最後のスライドに
    long-tail learning
    via logit adjustment

    View Slide

  4. “Long-tail learning via logit adjustment”
    ざっくりこういう話です
    4
    • 問題設定︓Softmax を使った分類
    − NLP でも頻出 e.g. ⾔語モデルの単語予測
    • 課題︓クラス分布が不均衡な場合,⾼頻度クラスばかり予
    測され低頻度クラスでは汎化されない
    − クラス不均衡は NLP でも頻出 e.g. 単語頻度分布は long tail
    • やったこと︓均衡誤差 (balanced error) の最⼩化という
    観点で,新しい訓練戦略・予測戦略を提案
    − Softmax への⼊⼒を少しいじるだけ
    • 結果︓「⾼頻度クラスばかり〜」問題が緩和
    ※ もう少し正確なまとめは最後のスライドに
    long-tail learning
    via logit adjustment
    これがキーワード
    あとから説明します

    View Slide


  5. 5
    • NLP の話ではありません.が,NLP に効きそうな話です.
    − 機械学習・表現学習の会議 (ICLR) の論⽂で,実験でも画像データを
    ⽤いています.
    − ただ考え⽅は NLP の研究・開発に⽰唆を与えるものでした.
    • 論⽂の⼀部のみを紹介します.
    − 論⽂のコンテンツは盛りだくさん,かつ時間も少ないので.
    − 復習しやすいようスライドはできるだけ self-contained にしておき
    ます.⾶ばすコンテンツは をつけておきます.
    • 数式も基本的に全部⾶ばします.
    − ⼤事な式だけ,その読み⽅ (お気持ち) の説明をします.
    − 復習しやすいように式も貼っておきます.
    − 「softmax を使う分類器」だけ知っていれば⼗分わかる内容です.
    SKIP

    View Slide

  6. 背景・既存法
    6

    View Slide

  7. 問題設定︓
    親の顔より⾒た softmax → cross entropy
    7
    • タスク︓多クラス分類
    • モデル︓特徴抽出 → softmax
    − スコア関数
    − 予測分布︓
    • 学習︓cross entropy 損失
    • 予測︓スコアの argmax
    exp に⼊れる直前のアレです.
    ⼊⼒を NN (等) でベクトルにして (Φ(x))
    予測の⾏列の各列 (クラスベクトル w_y)
    と内積をとったもの.

    View Slide

  8. 課題︓Class imbalance
    8
    • 実⽤上多くの分類問題のクラス分布は不均衡
    − たとえば⾔語モデルの単語予測
    – 単語頻度分布 (クラス分布) は経験的にべき分布によく従うことが知られて
    いる (Zipf 則).
    − 注︓タイトルに “long-tail” とあるが tail の具体的な形状は扱わな
    い.べき分布云々という話ではない.先⾏研究と同様,単に class
    imbalance (non-uniform) の意で標語的に “long-tail” という語が⽤
    いられている.
    • → 問題︓⾼頻度クラスへの過剰適合
    − ⾼頻度クラスばかり選ばれるようになる
    − 低頻度クラスで汎化されない
    • これまでもたくさん掘られてきた.キーワード︓
    − class imbalance
    − cost-sensitive learning

    View Slide

  9. クラス不均衡問題への既存の対策
    9
    • いろいろある
    − 既存⽅針0︓訓練データの over-, under-sampling
    − 既存⽅針1︓予測時に⾼頻度クラスをディスカウント
    – 1-1︓スコアを ||w_y|| で割る
    – 1-2︓スコアを p(y) で割る
    − 既存⽅針2︓訓練時に低頻度クラスを重み付け
    – 2-1︓損失を p(y) で割る
    – 2-2︓rare positive と negative の margin を広げる
    – 2-3︓positive と rare negative の margin を広げない
    • このあと2つだけ紹介します.
    − 提案法によって既存法を “改善” するので,⽐較⽤に.
    − 「予測時にがんばる」「訓練時にがんばる」の⼤きく2流儀あること
    だけ抑えてもらえれば⼗分です.
    SKIP

    View Slide

  10. クラス不均衡への対策
    最近の流れ1︓予測時に⾼頻度クラスをディスカウント
    10
    • 例1-1︓ノルムを⽤いた weight normalization
    − 予測時にスコア関数を ||w_y|| で割る
    − お気持ち︓p(y) と ||w_y|| は単調 なので
    − 頻度に応じて予測をディスカウントする
    • 著者らの突っ込み
    − 頻度とノルムは実際は単調ではない
    − optimizer による
    ⾼頻度クラス
    Momentum は確かに
    p(y) と ||w_y|| が単調
    Adam だと全然ダメ
    ||w_y|| ⼤きい
    ||w_y||

    View Slide

  11. クラス不均衡への対策
    最近の流れ2︓訓練時に低頻度クラスに重み付け
    11
    • 例2-2︓低頻度クラスを持つインスタンスでの訓練時に
    他クラスとのマージンを強く広げる
    • 著者らの突っ込み
    − こうした損失は Balanced error (後述) と⾮整合的
    − “本来⽬指すべきゴール” とギャップがある
    低頻度クラスの場合に
    (p(y) が⼩さい場合に)
    損失を重く⾒積もる,たくさん更新する
    正例スコア f_y(x) と
    負例スコア f_yʼ(x) の差を
    強く広げさせる

    View Slide

  12. クラス不均衡問題への既存の対策 (再)
    12
    • 簡単サマリ
    − 既存⽅針0︓訓練データの over-, under-sampling
    – ➡ 既存⽅針1,2と組合せられる話なので今回は除外
    − 既存⽅針1︓予測時に⾼頻度クラスをディスカウント
    – 1-1︓スコアを ||w_y|| で割る ➡ ノルムと頻度は相関しない
    – 1-2︓スコアを p(y) で割る ➡スコアが負の場合に状況が悪化する
    − 既存⽅針2︓訓練時に低頻度クラスを重み付け
    – 2-1︓損失を p(y) で割る ➡ 割る前と割った後で最適解が動かない
    – 2-2︓rare positive と negative の margin を広げる ➡ balanced error
    と⼀貫的でない
    – 2-3︓positive と rare negative の margin を広げない ➡ 〃
    • 他の既存法についても個々に丁寧に反駁されています.
    SKIP

    View Slide

  13. 提案法
    13

    View Slide

  14. 誤分類率から均衡誤差 (balanced error) へ
    14
    • 復習︓我々が使っていた cross entropy 損失 の気持ち
    − 狙っていたゴール︓誤分類率 (misclassification error) の最⼩化
    – 0-1損失は微分できないので代理損失の cross entropy 損失を使う.
    − データ全体での平均的な誤分類率を最⼩化しようとしていた.
    – 訓練データ {(x,y)} ~ P_x,y に対して y を打率良く当てるよう訓練

    View Slide

  15. 誤分類率から均衡誤差 (balanced error) へ
    15
    • 復習︓我々が使っていた cross entropy 損失 の気持ち
    − 狙っていたゴール︓誤分類率 (misclassification error) の最⼩化
    – 0-1損失は微分できないので代理損失の cross entropy 損失を使う.
    − データ全体での平均的な誤分類率を最⼩化しようとしていた.
    – 訓練データ {(x,y)} ~ P_x,y に対して y を打率良く当てるよう訓練
    − Cross entropy で訓練されるモデルの気持ち o0( ⾼頻度 y を予測し
    まくれば正解率上がるじゃん… 勝ったか…?)
    – これこそがいま起きている問題

    View Slide

  16. 誤分類率から均衡誤差 (balanced error) へ
    16
    • 誤分類率 (misclassification error) の最⼩化
    − cross entropy 損失はこれ (0-1損失) の代替
    − データ全体での平均的な誤分類率を最⼩化しようとしていた
    • 均衡誤差 (balanced error) の最⼩化が狙うべきゴールでは︖
    − お気持ち︓各クラスでバランスよく正解するとえらい
    – ⾼頻度クラスばかりを「これでしょ…︖」と予測するモデルはダメ
    – 低頻度クラスも打率⾼く当てないとダメ
    の平均 を最⼩化
    各クラスでの不正解率
    “均衡誤差” は紹介者に
    よる適当な和訳です

    View Slide

  17. 誤分類率から均衡誤差 (balanced error) へ
    17
    誤差
    スコア関数 f: X→Y に満たしてほ
    しい性質.学習のゴール.
    損失
    訓練時にデータ (x,y) 毎に与える損失
    標準的
    な学習
    誤分類率 (misclassification
    error)↓
    Softmax cross-entropy loss
    提案
    均衡誤差 (balanced error)↓ ?
    全体での誤分類率を下げたい.
    正解率を上げたい.
    クラス毎の誤分類率の平均を下げた
    い.バランスよく正解してほしい.
    Q: 損失はどう⽤意すれば良い?
    Q: 先⾏研究のように
    予測時にだけ補正する⽅法は?

    View Slide

  18. 提案法の準備︓
    均衡誤差の最⼩化と誤分類率の最⼩化の関係
    18
    • 準備︓均衡誤差の最⼩化と誤分類率の最⼩化の関係
    − Balanced error を最⼩化するスコア関数 f* は
    argmax_y p(x|y) でクラスを予測する (Eq.7)
    − → p(y|x) を当てられるスコア関数 s* を以下のように補正すれば
    balanced loss に従って y を当てられる (Eq.8)
    • …は今⽇は⾶ばして,次ページから天下り式に結果だけ述べます
    SKIP
    cross entropy で最適化したスコア関数
    s* から log p(y) を引いて 予測すれば良い
    Balanced error に基づいて
    予測したければ
    logit adjustment

    View Slide

  19. 均衡誤差の最⼩化を実現するために
    提案法1 (予測時に補正)
    19
    • 提案法1︓Post-hoc logit adjustment
    − 訓練時は通常通り softmax / cross-entropy を利⽤
    − 予測時にスコア関数をクラス頻度で補正
    − Softmax cross-entropy からモデルを変更しなくて良い.
    − ※ 温度パラメータτについて︓予測器が p(y|x) にフィットしている
    かは分からないので (cf. calibration)
    logit adjustment
    ここだけ変更すれば良い

    View Slide

  20. 均衡誤差の最⼩化を実現するために
    提案法2 (訓練時に補正)
    20
    • 提案法2︓Logit adjusted loss
    − 訓練時にスコア関数をクラス頻度で補正
    − 予測時は学習されたスコア関数 f をそのまま利⽤
    − Softmax cross-entropy からモデルをほとんど変更しなくて良い.
    − ※ 学習の結果 balanced error を最⼩化する global optimum (の近
    く) までたどり着けるかはまた別の話.実験セクションで検証.
    logit adjustment
    ここだけ変更すれば良い

    View Slide

  21. 誤分類率から均衡誤差 (balanced error) へ
    (再)
    21
    誤差
    スコア関数 f: X→Y に満たしてほ
    しい性質.学習のゴール.
    損失
    訓練時にデータ (x,y) 毎に与える損失
    標準的
    な学習
    誤分類率 (misclassification
    error)↓
    Softmax cross-entropy loss↓
    提案
    均衡誤差 (balanced error)↓ Logit adjusted loss↓
    全体での誤分類率を下げたい.
    正解率を上げたい.
    クラス毎の誤分類率の平均を下げたい.
    バランスよく正解してほしい.
    正例と負例のスコアのマージン
    を広げたい
    低頻度な正例と⾼頻度な負例の
    マージンを広げる
    スコア関数を
    + log p(y) で補正
    ※ 予測時は略

    View Slide

  22. 注︓新規性︖
    22
    • 個々の観点に関しては既存研究がある
    − Balanced error (Chan & Stolfo, 1998; Brodersen et al., 2010;
    Menon et al., 2013)
    − Logit adjustment (Fawcett & Provost, 1996; Provost, 2000;
    Maloof, 2003; Zhou & Liu, 2006; Collell et al., 2016)
    • ⼿法群を均衡誤差の観点でまとめあげたのが偉い印象
    − 既存の {理論, ⼿法} 研究の limitation を個々に丁寧に反駁
    − 既存法・提案法で学習されるスコア関数 f が Balanced error を最適
    化する f* と Fisher ⼀致性を持つか (推定量として良い性質を持つか
    ) を横断的に評価 (後述)
    − 既存法・提案法が実際に Balanced error を下げることができるかど
    うかを経験的に評価

    View Slide

  23. 理論評価
    23

    View Slide

  24. 理論評価︓各⼿法は Balanced error のための
    損失として “良い” か
    24
    • 各⼿法を pairwise margin loss として⼀般化
    • 適当な δ: Y→R+ を⽤いて次の形でパラメータα, Δを表せ
    れば,その損失は均衡誤差と Fisher consistent (Thm. 1)
    • ➡ 提案法 (δy = p(y)) と古典的 loss modification (δy =
    1; Eq. 4) のみ Fisher consistent
    SKIP
    既存⼿法のほとんどは
    均衡誤差を最⼩化するような損失になっていない

    View Slide

  25. 実験評価
    25

    View Slide

  26. ⼈⼯データによる実験
    26
    • Q︓提案法を使うと均衡誤差は最⼩化されるのか?
    • 実験設定
    − タスク︓2値分類
    − データ︓2D Gaussian mixture からサンプリング
    – 真の分布をこちらが⽤意
    − ⼿法群︓線形分類器
    – 学習法 (損失) だけ既存法・提案法で変える
    − (特に) オラクルベースライン︓均衡誤差にベイズ最適な線形分類器
    – 真の分布のパラメータを⾒ながら解析的に作れる
    • Q (再)︓提案法はオラクルベースラインに迫れるのか?

    View Slide

  27. ⼈⼯データによる実験の結果
    提案法1︓予測時に補正
    提案法は温度パラメータ次第でオラクルベースラインに迫れる
    ベースラインはパラメータをどういじっても迫れない
    27
    better
    SKIP

    View Slide

  28. ⼈⼯データによる実験の結果
    提案法2︓訓練時の損失変更
    提案法はオラクルベースラインに迫る
    =提案した損失は (実際の最適化を伴っても) 均衡誤差をよく最⼩化してくれる28
    better
    提案法 オラクル
    既存法

    View Slide

  29. 実データによる実験
    29
    • Q︓実際のクラス不均衡なデータ (画像分類) で,
    提案法はどれくらいよく均衡誤差を最⼩化してくれる?
    • A︓提案法は既存法に⽐べてなかなか良さそう
    提案法
    既存法
    cross entropy Lower is better
    既存法と混ぜる.
    スコア上げの頑張りを感じる

    View Slide

  30. 実データによる実験
    30
    • Q︓0-1損失 (正解率) をある意味で捨てたわけだけど,正
    解率はどれくらい下がるの?
    • A︓提案法は⾼頻度クラスでやや悪化,低頻度クラスで好転
    − 期待通り
    − ※ overall accuracy も報告してほしかった
    better
    ⾼頻度クラス
    提案法

    View Slide

  31. まとめ
    31

    View Slide

  32. “Long-tail learning via logit adjustment”
    まとめ
    32
    • 問題︓softmax cross-entropy で分類すると,クラス分布が不均
    衡な場合に問題が起きる
    − ⾼頻度クラスばかり予測され,低頻度クラスでは汎化されない
    • 提案︓均衡誤差 (balanced error) の最⼩化という観点での新し
    い訓練戦略 or 予測戦略の提案
    − ゴールを変える
    – Before: 誤分類率 (正解率)
    – After: 均衡誤差 (クラス毎の正解率の平均)
    − Softmax への⼊⼒を少しいじるだけ
    – 予測時︓スコア関数から log p(y) を引く
    – 訓練時︓スコア関数に log p(y) を⾜す
    • 理論的・経験的に提案法の優位性を確認
    − 理論︓提案法だけ (※ほぼ) は均衡誤差と Fisher ⼀致性を持つ
    − 実験︓データを使った実験でも提案法は均衡誤差の意味でかなり良い
    long-tail learning
    via logit adjustment

    View Slide

  33. 感想・コメント
    33
    • よかった点
    − 既存法たちとの接続が丁寧で respectful.
    − とても読みやすい.
    • 気になる点
    − 均衡誤差が天から降ってくる.
    – 不均衡ラベルを持つタスク毎に求められるゴール・評価尺度を検討すべき?
    – 誤差 (評価尺度) ←→ 損失の話の⽂脈をどれくらい抑えているかもわからない.
    − Logit adjustment を⼊れてもモデルの表現⼒はほとんど変わらない?
    – softmax に bias 項を⼊れて良いなら bias 項が log p(y) を吸収できる.
    – 内積のままでも1次元分に log p(y) 分の情報を持たせることができる.
    – 損失に明⽰的に log p(y) を⼊れることがなぜ / どのように効いてくるのか?
    − 古典的な loss modification (損失を 1/p(y) 倍, Eq.4) との⽐較が⽢い.
    – 均衡誤差とFisher⼀致性を持つ,しかも実験での⽐較がない.結構強い可能性…?
    – ほとんど upsampling/downsampling?
    − Negative sampling との関係?
    – SGNS も p(y|x) ∝ p(y)exp() をモデルに訓練している.提案法と⼀貫
    的.

    View Slide