Upgrade to Pro — share decks privately, control downloads, hide ads and more …

Long-tail learning via logit adjustment

Sho Yokoi
September 16, 2021

Long-tail learning via logit adjustment

2021-09-16, 第13回最先端NLP勉強会
https://sites.google.com/view/snlp-jp/home/2021

Menon et al., Long-tail learning via logit adjustment (ICLR 2021) の論文紹介です
https://openreview.net/forum?id=37nvvqkCo5

Sho Yokoi

September 16, 2021
Tweet

More Decks by Sho Yokoi

Other Decks in Research

Transcript

  1. Long-tail learning via logit adjustment Aditya Krishna Menon, Sadeep Jayasumana,

    Ankit Singh Rawat, Himanshu Jain, Andreas Veit, Sanjiv Kumar (Google Research) ICLR 2021 https://openreview.net/forum?id=37nvvqkCo5 読む⼈︓横井 祥 (東北⼤学/理研AIP) 2021-09-16, 第13回最先端NLP勉強会 (2021-09-30, 少し更新)
  2. “Long-tail learning via logit adjustment” ざっくりこういう話です 2 • 問題設定︓Softmax を使った分類

    − NLP でも頻出 e.g. ⾔語モデルの単語予測 • 課題︓クラス分布が不均衡な場合,⾼頻度クラスばかり予 測され低頻度クラスでは汎化されない − クラス不均衡は NLP でも頻出 e.g. 単語頻度分布は long tail • やったこと︓均衡誤差 (balanced error) の最⼩化という観 点で,新しい訓練戦略・予測戦略を提案 − Softmax への⼊⼒を少しいじるだけ • 結果︓「⾼頻度クラスばかり〜」問題が緩和 ※ もう少し正確なまとめは最後のスライドに
  3. “Long-tail learning via logit adjustment” ざっくりこういう話です 3 • 問題設定︓Softmax を使った分類

    − NLP でも頻出 e.g. ⾔語モデルの単語予測 • 課題︓クラス分布が不均衡な場合,⾼頻度クラスばかり予 測され低頻度クラスでは汎化されない − クラス不均衡は NLP でも頻出 e.g. 単語頻度分布は long tail • やったこと︓均衡誤差 (balanced error) の最⼩化という観 点で,新しい訓練戦略・予測戦略を提案 − Softmax への⼊⼒を少しいじるだけ • 結果︓「⾼頻度クラスばかり〜」問題が緩和 ※ もう少し正確なまとめは最後のスライドに long-tail learning via logit adjustment
  4. “Long-tail learning via logit adjustment” ざっくりこういう話です 4 • 問題設定︓Softmax を使った分類

    − NLP でも頻出 e.g. ⾔語モデルの単語予測 • 課題︓クラス分布が不均衡な場合,⾼頻度クラスばかり予 測され低頻度クラスでは汎化されない − クラス不均衡は NLP でも頻出 e.g. 単語頻度分布は long tail • やったこと︓均衡誤差 (balanced error) の最⼩化という 観点で,新しい訓練戦略・予測戦略を提案 − Softmax への⼊⼒を少しいじるだけ • 結果︓「⾼頻度クラスばかり〜」問題が緩和 ※ もう少し正確なまとめは最後のスライドに long-tail learning via logit adjustment これがキーワード あとから説明します
  5. 注 5 • NLP の話ではありません.が,NLP に効きそうな話です. − 機械学習・表現学習の会議 (ICLR) の論⽂で,実験でも画像データを

    ⽤いています. − ただ考え⽅は NLP の研究・開発に⽰唆を与えるものでした. • 論⽂の⼀部のみを紹介します. − 論⽂のコンテンツは盛りだくさん,かつ時間も少ないので. − 復習しやすいようスライドはできるだけ self-contained にしておき ます.⾶ばすコンテンツは をつけておきます. • 数式も基本的に全部⾶ばします. − ⼤事な式だけ,その読み⽅ (お気持ち) の説明をします. − 復習しやすいように式も貼っておきます. − 「softmax を使う分類器」だけ知っていれば⼗分わかる内容です. SKIP
  6. 問題設定︓ 親の顔より⾒た softmax → cross entropy 7 • タスク︓多クラス分類 •

    モデル︓特徴抽出 → softmax − スコア関数 − 予測分布︓ • 学習︓cross entropy 損失 • 予測︓スコアの argmax exp に⼊れる直前のアレです. ⼊⼒を NN (等) でベクトルにして (Φ(x)) 予測の⾏列の各列 (クラスベクトル w_y) と内積をとったもの.
  7. 課題︓Class imbalance 8 • 実⽤上多くの分類問題のクラス分布は不均衡 − たとえば⾔語モデルの単語予測 – 単語頻度分布 (クラス分布)

    は経験的にべき分布によく従うことが知られて いる (Zipf 則). − 注︓タイトルに “long-tail” とあるが tail の具体的な形状は扱わな い.べき分布云々という話ではない.先⾏研究と同様,単に class imbalance (non-uniform) の意で標語的に “long-tail” という語が⽤ いられている. • → 問題︓⾼頻度クラスへの過剰適合 − ⾼頻度クラスばかり選ばれるようになる − 低頻度クラスで汎化されない • これまでもたくさん掘られてきた.キーワード︓ − class imbalance − cost-sensitive learning
  8. クラス不均衡問題への既存の対策 9 • いろいろある − 既存⽅針0︓訓練データの over-, under-sampling − 既存⽅針1︓予測時に⾼頻度クラスをディスカウント

    – 1-1︓スコアを ||w_y|| で割る – 1-2︓スコアを p(y) で割る − 既存⽅針2︓訓練時に低頻度クラスを重み付け – 2-1︓損失を p(y) で割る – 2-2︓rare positive と negative の margin を広げる – 2-3︓positive と rare negative の margin を広げない • このあと2つだけ紹介します. − 提案法によって既存法を “改善” するので,⽐較⽤に. − 「予測時にがんばる」「訓練時にがんばる」の⼤きく2流儀あること だけ抑えてもらえれば⼗分です. SKIP
  9. クラス不均衡への対策 最近の流れ1︓予測時に⾼頻度クラスをディスカウント 10 • 例1-1︓ノルムを⽤いた weight normalization − 予測時にスコア関数を ||w_y||

    で割る − お気持ち︓p(y) と ||w_y|| は単調 なので − 頻度に応じて予測をディスカウントする • 著者らの突っ込み − 頻度とノルムは実際は単調ではない − optimizer による ⾼頻度クラス Momentum は確かに p(y) と ||w_y|| が単調 Adam だと全然ダメ ||w_y|| ⼤きい ||w_y||
  10. クラス不均衡への対策 最近の流れ2︓訓練時に低頻度クラスに重み付け 11 • 例2-2︓低頻度クラスを持つインスタンスでの訓練時に 他クラスとのマージンを強く広げる • 著者らの突っ込み − こうした損失は

    Balanced error (後述) と⾮整合的 − “本来⽬指すべきゴール” とギャップがある 低頻度クラスの場合に (p(y) が⼩さい場合に) 損失を重く⾒積もる,たくさん更新する 正例スコア f_y(x) と 負例スコア f_yʼ(x) の差を 強く広げさせる
  11. クラス不均衡問題への既存の対策 (再) 12 • 簡単サマリ − 既存⽅針0︓訓練データの over-, under-sampling –

    ➡ 既存⽅針1,2と組合せられる話なので今回は除外 − 既存⽅針1︓予測時に⾼頻度クラスをディスカウント – 1-1︓スコアを ||w_y|| で割る ➡ ノルムと頻度は相関しない – 1-2︓スコアを p(y) で割る ➡スコアが負の場合に状況が悪化する − 既存⽅針2︓訓練時に低頻度クラスを重み付け – 2-1︓損失を p(y) で割る ➡ 割る前と割った後で最適解が動かない – 2-2︓rare positive と negative の margin を広げる ➡ balanced error と⼀貫的でない – 2-3︓positive と rare negative の margin を広げない ➡ 〃 • 他の既存法についても個々に丁寧に反駁されています. SKIP
  12. 誤分類率から均衡誤差 (balanced error) へ 14 • 復習︓我々が使っていた cross entropy 損失

    の気持ち − 狙っていたゴール︓誤分類率 (misclassification error) の最⼩化 – 0-1損失は微分できないので代理損失の cross entropy 損失を使う. − データ全体での平均的な誤分類率を最⼩化しようとしていた. – 訓練データ {(x,y)} ~ P_x,y に対して y を打率良く当てるよう訓練
  13. 誤分類率から均衡誤差 (balanced error) へ 15 • 復習︓我々が使っていた cross entropy 損失

    の気持ち − 狙っていたゴール︓誤分類率 (misclassification error) の最⼩化 – 0-1損失は微分できないので代理損失の cross entropy 損失を使う. − データ全体での平均的な誤分類率を最⼩化しようとしていた. – 訓練データ {(x,y)} ~ P_x,y に対して y を打率良く当てるよう訓練 − Cross entropy で訓練されるモデルの気持ち o0( ⾼頻度 y を予測し まくれば正解率上がるじゃん… 勝ったか…?) – これこそがいま起きている問題
  14. 誤分類率から均衡誤差 (balanced error) へ 16 • 誤分類率 (misclassification error) の最⼩化

    − cross entropy 損失はこれ (0-1損失) の代替 − データ全体での平均的な誤分類率を最⼩化しようとしていた • 均衡誤差 (balanced error) の最⼩化が狙うべきゴールでは︖ − お気持ち︓各クラスでバランスよく正解するとえらい – ⾼頻度クラスばかりを「これでしょ…︖」と予測するモデルはダメ – 低頻度クラスも打率⾼く当てないとダメ の平均 を最⼩化 各クラスでの不正解率 “均衡誤差” は紹介者に よる適当な和訳です
  15. 誤分類率から均衡誤差 (balanced error) へ 17 誤差 スコア関数 f: X→Y に満たしてほ

    しい性質.学習のゴール. 損失 訓練時にデータ (x,y) 毎に与える損失 標準的 な学習 誤分類率 (misclassification error)↓ Softmax cross-entropy loss 提案 均衡誤差 (balanced error)↓ ? 全体での誤分類率を下げたい. 正解率を上げたい. クラス毎の誤分類率の平均を下げた い.バランスよく正解してほしい. Q: 損失はどう⽤意すれば良い? Q: 先⾏研究のように 予測時にだけ補正する⽅法は?
  16. 提案法の準備︓ 均衡誤差の最⼩化と誤分類率の最⼩化の関係 18 • 準備︓均衡誤差の最⼩化と誤分類率の最⼩化の関係 − Balanced error を最⼩化するスコア関数 f*

    は argmax_y p(x|y) でクラスを予測する (Eq.7) − → p(y|x) を当てられるスコア関数 s* を以下のように補正すれば balanced loss に従って y を当てられる (Eq.8) • …は今⽇は⾶ばして,次ページから天下り式に結果だけ述べます SKIP cross entropy で最適化したスコア関数 s* から log p(y) を引いて 予測すれば良い Balanced error に基づいて 予測したければ logit adjustment
  17. 均衡誤差の最⼩化を実現するために 提案法1 (予測時に補正) 19 • 提案法1︓Post-hoc logit adjustment − 訓練時は通常通り

    softmax / cross-entropy を利⽤ − 予測時にスコア関数をクラス頻度で補正 − Softmax cross-entropy からモデルを変更しなくて良い. − ※ 温度パラメータτについて︓予測器が p(y|x) にフィットしている かは分からないので (cf. calibration) logit adjustment ここだけ変更すれば良い
  18. 均衡誤差の最⼩化を実現するために 提案法2 (訓練時に補正) 20 • 提案法2︓Logit adjusted loss − 訓練時にスコア関数をクラス頻度で補正

    − 予測時は学習されたスコア関数 f をそのまま利⽤ − Softmax cross-entropy からモデルをほとんど変更しなくて良い. − ※ 学習の結果 balanced error を最⼩化する global optimum (の近 く) までたどり着けるかはまた別の話.実験セクションで検証. logit adjustment ここだけ変更すれば良い
  19. 誤分類率から均衡誤差 (balanced error) へ (再) 21 誤差 スコア関数 f: X→Y

    に満たしてほ しい性質.学習のゴール. 損失 訓練時にデータ (x,y) 毎に与える損失 標準的 な学習 誤分類率 (misclassification error)↓ Softmax cross-entropy loss↓ 提案 均衡誤差 (balanced error)↓ Logit adjusted loss↓ 全体での誤分類率を下げたい. 正解率を上げたい. クラス毎の誤分類率の平均を下げたい. バランスよく正解してほしい. 正例と負例のスコアのマージン を広げたい 低頻度な正例と⾼頻度な負例の マージンを広げる スコア関数を + log p(y) で補正 ※ 予測時は略
  20. 注︓新規性︖ 22 • 個々の観点に関しては既存研究がある − Balanced error (Chan & Stolfo,

    1998; Brodersen et al., 2010; Menon et al., 2013) − Logit adjustment (Fawcett & Provost, 1996; Provost, 2000; Maloof, 2003; Zhou & Liu, 2006; Collell et al., 2016) • ⼿法群を均衡誤差の観点でまとめあげたのが偉い印象 − 既存の {理論, ⼿法} 研究の limitation を個々に丁寧に反駁 − 既存法・提案法で学習されるスコア関数 f が Balanced error を最適 化する f* と Fisher ⼀致性を持つか (推定量として良い性質を持つか ) を横断的に評価 (後述) − 既存法・提案法が実際に Balanced error を下げることができるかど うかを経験的に評価
  21. 理論評価︓各⼿法は Balanced error のための 損失として “良い” か 24 • 各⼿法を

    pairwise margin loss として⼀般化 • 適当な δ: Y→R+ を⽤いて次の形でパラメータα, Δを表せ れば,その損失は均衡誤差と Fisher consistent (Thm. 1) • ➡ 提案法 (δy = p(y)) と古典的 loss modification (δy = 1; Eq. 4) のみ Fisher consistent SKIP 既存⼿法のほとんどは 均衡誤差を最⼩化するような損失になっていない
  22. ⼈⼯データによる実験 26 • Q︓提案法を使うと均衡誤差は最⼩化されるのか? • 実験設定 − タスク︓2値分類 − データ︓2D

    Gaussian mixture からサンプリング – 真の分布をこちらが⽤意 − ⼿法群︓線形分類器 – 学習法 (損失) だけ既存法・提案法で変える − (特に) オラクルベースライン︓均衡誤差にベイズ最適な線形分類器 – 真の分布のパラメータを⾒ながら解析的に作れる • Q (再)︓提案法はオラクルベースラインに迫れるのか?
  23. “Long-tail learning via logit adjustment” まとめ 32 • 問題︓softmax cross-entropy

    で分類すると,クラス分布が不均 衡な場合に問題が起きる − ⾼頻度クラスばかり予測され,低頻度クラスでは汎化されない • 提案︓均衡誤差 (balanced error) の最⼩化という観点での新し い訓練戦略 or 予測戦略の提案 − ゴールを変える – Before: 誤分類率 (正解率) – After: 均衡誤差 (クラス毎の正解率の平均) − Softmax への⼊⼒を少しいじるだけ – 予測時︓スコア関数から log p(y) を引く – 訓練時︓スコア関数に log p(y) を⾜す • 理論的・経験的に提案法の優位性を確認 − 理論︓提案法だけ (※ほぼ) は均衡誤差と Fisher ⼀致性を持つ − 実験︓データを使った実験でも提案法は均衡誤差の意味でかなり良い long-tail learning via logit adjustment
  24. 感想・コメント 33 • よかった点 − 既存法たちとの接続が丁寧で respectful. − とても読みやすい. •

    気になる点 − 均衡誤差が天から降ってくる. – 不均衡ラベルを持つタスク毎に求められるゴール・評価尺度を検討すべき? – 誤差 (評価尺度) ←→ 損失の話の⽂脈をどれくらい抑えているかもわからない. − Logit adjustment を⼊れてもモデルの表現⼒はほとんど変わらない? – softmax に bias 項を⼊れて良いなら bias 項が log p(y) を吸収できる. – 内積のままでも1次元分に log p(y) 分の情報を持たせることができる. – 損失に明⽰的に log p(y) を⼊れることがなぜ / どのように効いてくるのか? − 古典的な loss modification (損失を 1/p(y) 倍, Eq.4) との⽐較が⽢い. – 均衡誤差とFisher⼀致性を持つ,しかも実験での⽐較がない.結構強い可能性…? – ほとんど upsampling/downsampling? − Negative sampling との関係? – SGNS も p(y|x) ∝ p(y)exp(<x,y>) をモデルに訓練している.提案法と⼀貫 的.