2021-09-16, 第13回最先端NLP勉強会 https://sites.google.com/view/snlp-jp/home/2021
Menon et al., Long-tail learning via logit adjustment (ICLR 2021) の論文紹介です https://openreview.net/forum?id=37nvvqkCo5
Long-tail learning via logit adjustmentAditya Krishna Menon, Sadeep Jayasumana, Ankit Singh Rawat,Himanshu Jain, Andreas Veit, Sanjiv Kumar (Google Research)ICLR 2021https://openreview.net/forum?id=37nvvqkCo5読む⼈︓横井 祥 (東北⼤学/理研AIP)2021-09-16, 第13回最先端NLP勉強会(2021-09-30, 少し更新)
View Slide
“Long-tail learning via logit adjustment”ざっくりこういう話です2• 問題設定︓Softmax を使った分類− NLP でも頻出 e.g. ⾔語モデルの単語予測• 課題︓クラス分布が不均衡な場合,⾼頻度クラスばかり予測され低頻度クラスでは汎化されない− クラス不均衡は NLP でも頻出 e.g. 単語頻度分布は long tail• やったこと︓均衡誤差 (balanced error) の最⼩化という観点で,新しい訓練戦略・予測戦略を提案− Softmax への⼊⼒を少しいじるだけ• 結果︓「⾼頻度クラスばかり〜」問題が緩和※ もう少し正確なまとめは最後のスライドに
“Long-tail learning via logit adjustment”ざっくりこういう話です3• 問題設定︓Softmax を使った分類− NLP でも頻出 e.g. ⾔語モデルの単語予測• 課題︓クラス分布が不均衡な場合,⾼頻度クラスばかり予測され低頻度クラスでは汎化されない− クラス不均衡は NLP でも頻出 e.g. 単語頻度分布は long tail• やったこと︓均衡誤差 (balanced error) の最⼩化という観点で,新しい訓練戦略・予測戦略を提案− Softmax への⼊⼒を少しいじるだけ• 結果︓「⾼頻度クラスばかり〜」問題が緩和※ もう少し正確なまとめは最後のスライドにlong-tail learningvia logit adjustment
“Long-tail learning via logit adjustment”ざっくりこういう話です4• 問題設定︓Softmax を使った分類− NLP でも頻出 e.g. ⾔語モデルの単語予測• 課題︓クラス分布が不均衡な場合,⾼頻度クラスばかり予測され低頻度クラスでは汎化されない− クラス不均衡は NLP でも頻出 e.g. 単語頻度分布は long tail• やったこと︓均衡誤差 (balanced error) の最⼩化という観点で,新しい訓練戦略・予測戦略を提案− Softmax への⼊⼒を少しいじるだけ• 結果︓「⾼頻度クラスばかり〜」問題が緩和※ もう少し正確なまとめは最後のスライドにlong-tail learningvia logit adjustmentこれがキーワードあとから説明します
注5• NLP の話ではありません.が,NLP に効きそうな話です.− 機械学習・表現学習の会議 (ICLR) の論⽂で,実験でも画像データを⽤いています.− ただ考え⽅は NLP の研究・開発に⽰唆を与えるものでした.• 論⽂の⼀部のみを紹介します.− 論⽂のコンテンツは盛りだくさん,かつ時間も少ないので.− 復習しやすいようスライドはできるだけ self-contained にしておきます.⾶ばすコンテンツは をつけておきます.• 数式も基本的に全部⾶ばします.− ⼤事な式だけ,その読み⽅ (お気持ち) の説明をします.− 復習しやすいように式も貼っておきます.− 「softmax を使う分類器」だけ知っていれば⼗分わかる内容です.SKIP
背景・既存法6
問題設定︓親の顔より⾒た softmax → cross entropy7• タスク︓多クラス分類• モデル︓特徴抽出 → softmax− スコア関数− 予測分布︓• 学習︓cross entropy 損失• 予測︓スコアの argmaxexp に⼊れる直前のアレです.⼊⼒を NN (等) でベクトルにして (Φ(x))予測の⾏列の各列 (クラスベクトル w_y)と内積をとったもの.
課題︓Class imbalance8• 実⽤上多くの分類問題のクラス分布は不均衡− たとえば⾔語モデルの単語予測– 単語頻度分布 (クラス分布) は経験的にべき分布によく従うことが知られている (Zipf 則).− 注︓タイトルに “long-tail” とあるが tail の具体的な形状は扱わない.べき分布云々という話ではない.先⾏研究と同様,単に classimbalance (non-uniform) の意で標語的に “long-tail” という語が⽤いられている.• → 問題︓⾼頻度クラスへの過剰適合− ⾼頻度クラスばかり選ばれるようになる− 低頻度クラスで汎化されない• これまでもたくさん掘られてきた.キーワード︓− class imbalance− cost-sensitive learning
クラス不均衡問題への既存の対策9• いろいろある− 既存⽅針0︓訓練データの over-, under-sampling− 既存⽅針1︓予測時に⾼頻度クラスをディスカウント– 1-1︓スコアを ||w_y|| で割る– 1-2︓スコアを p(y) で割る− 既存⽅針2︓訓練時に低頻度クラスを重み付け– 2-1︓損失を p(y) で割る– 2-2︓rare positive と negative の margin を広げる– 2-3︓positive と rare negative の margin を広げない• このあと2つだけ紹介します.− 提案法によって既存法を “改善” するので,⽐較⽤に.− 「予測時にがんばる」「訓練時にがんばる」の⼤きく2流儀あることだけ抑えてもらえれば⼗分です.SKIP
クラス不均衡への対策最近の流れ1︓予測時に⾼頻度クラスをディスカウント10• 例1-1︓ノルムを⽤いた weight normalization− 予測時にスコア関数を ||w_y|| で割る− お気持ち︓p(y) と ||w_y|| は単調 なので− 頻度に応じて予測をディスカウントする• 著者らの突っ込み− 頻度とノルムは実際は単調ではない− optimizer による⾼頻度クラスMomentum は確かにp(y) と ||w_y|| が単調Adam だと全然ダメ||w_y|| ⼤きい||w_y||
クラス不均衡への対策最近の流れ2︓訓練時に低頻度クラスに重み付け11• 例2-2︓低頻度クラスを持つインスタンスでの訓練時に他クラスとのマージンを強く広げる• 著者らの突っ込み− こうした損失は Balanced error (後述) と⾮整合的− “本来⽬指すべきゴール” とギャップがある低頻度クラスの場合に(p(y) が⼩さい場合に)損失を重く⾒積もる,たくさん更新する正例スコア f_y(x) と負例スコア f_yʼ(x) の差を強く広げさせる
クラス不均衡問題への既存の対策 (再)12• 簡単サマリ− 既存⽅針0︓訓練データの over-, under-sampling– ➡ 既存⽅針1,2と組合せられる話なので今回は除外− 既存⽅針1︓予測時に⾼頻度クラスをディスカウント– 1-1︓スコアを ||w_y|| で割る ➡ ノルムと頻度は相関しない– 1-2︓スコアを p(y) で割る ➡スコアが負の場合に状況が悪化する− 既存⽅針2︓訓練時に低頻度クラスを重み付け– 2-1︓損失を p(y) で割る ➡ 割る前と割った後で最適解が動かない– 2-2︓rare positive と negative の margin を広げる ➡ balanced errorと⼀貫的でない– 2-3︓positive と rare negative の margin を広げない ➡ 〃• 他の既存法についても個々に丁寧に反駁されています.SKIP
提案法13
誤分類率から均衡誤差 (balanced error) へ14• 復習︓我々が使っていた cross entropy 損失 の気持ち− 狙っていたゴール︓誤分類率 (misclassification error) の最⼩化– 0-1損失は微分できないので代理損失の cross entropy 損失を使う.− データ全体での平均的な誤分類率を最⼩化しようとしていた.– 訓練データ {(x,y)} ~ P_x,y に対して y を打率良く当てるよう訓練
誤分類率から均衡誤差 (balanced error) へ15• 復習︓我々が使っていた cross entropy 損失 の気持ち− 狙っていたゴール︓誤分類率 (misclassification error) の最⼩化– 0-1損失は微分できないので代理損失の cross entropy 損失を使う.− データ全体での平均的な誤分類率を最⼩化しようとしていた.– 訓練データ {(x,y)} ~ P_x,y に対して y を打率良く当てるよう訓練− Cross entropy で訓練されるモデルの気持ち o0( ⾼頻度 y を予測しまくれば正解率上がるじゃん… 勝ったか…?)– これこそがいま起きている問題
誤分類率から均衡誤差 (balanced error) へ16• 誤分類率 (misclassification error) の最⼩化− cross entropy 損失はこれ (0-1損失) の代替− データ全体での平均的な誤分類率を最⼩化しようとしていた• 均衡誤差 (balanced error) の最⼩化が狙うべきゴールでは︖− お気持ち︓各クラスでバランスよく正解するとえらい– ⾼頻度クラスばかりを「これでしょ…︖」と予測するモデルはダメ– 低頻度クラスも打率⾼く当てないとダメの平均 を最⼩化各クラスでの不正解率“均衡誤差” は紹介者による適当な和訳です
誤分類率から均衡誤差 (balanced error) へ17誤差スコア関数 f: X→Y に満たしてほしい性質.学習のゴール.損失訓練時にデータ (x,y) 毎に与える損失標準的な学習誤分類率 (misclassificationerror)↓Softmax cross-entropy loss提案均衡誤差 (balanced error)↓ ?全体での誤分類率を下げたい.正解率を上げたい.クラス毎の誤分類率の平均を下げたい.バランスよく正解してほしい.Q: 損失はどう⽤意すれば良い?Q: 先⾏研究のように予測時にだけ補正する⽅法は?
提案法の準備︓均衡誤差の最⼩化と誤分類率の最⼩化の関係18• 準備︓均衡誤差の最⼩化と誤分類率の最⼩化の関係− Balanced error を最⼩化するスコア関数 f* はargmax_y p(x|y) でクラスを予測する (Eq.7)− → p(y|x) を当てられるスコア関数 s* を以下のように補正すればbalanced loss に従って y を当てられる (Eq.8)• …は今⽇は⾶ばして,次ページから天下り式に結果だけ述べますSKIPcross entropy で最適化したスコア関数s* から log p(y) を引いて 予測すれば良いBalanced error に基づいて予測したければlogit adjustment
均衡誤差の最⼩化を実現するために提案法1 (予測時に補正)19• 提案法1︓Post-hoc logit adjustment− 訓練時は通常通り softmax / cross-entropy を利⽤− 予測時にスコア関数をクラス頻度で補正− Softmax cross-entropy からモデルを変更しなくて良い.− ※ 温度パラメータτについて︓予測器が p(y|x) にフィットしているかは分からないので (cf. calibration)logit adjustmentここだけ変更すれば良い
均衡誤差の最⼩化を実現するために提案法2 (訓練時に補正)20• 提案法2︓Logit adjusted loss− 訓練時にスコア関数をクラス頻度で補正− 予測時は学習されたスコア関数 f をそのまま利⽤− Softmax cross-entropy からモデルをほとんど変更しなくて良い.− ※ 学習の結果 balanced error を最⼩化する global optimum (の近く) までたどり着けるかはまた別の話.実験セクションで検証.logit adjustmentここだけ変更すれば良い
誤分類率から均衡誤差 (balanced error) へ(再)21誤差スコア関数 f: X→Y に満たしてほしい性質.学習のゴール.損失訓練時にデータ (x,y) 毎に与える損失標準的な学習誤分類率 (misclassificationerror)↓Softmax cross-entropy loss↓提案均衡誤差 (balanced error)↓ Logit adjusted loss↓全体での誤分類率を下げたい.正解率を上げたい.クラス毎の誤分類率の平均を下げたい.バランスよく正解してほしい.正例と負例のスコアのマージンを広げたい低頻度な正例と⾼頻度な負例のマージンを広げるスコア関数を+ log p(y) で補正※ 予測時は略
注︓新規性︖22• 個々の観点に関しては既存研究がある− Balanced error (Chan & Stolfo, 1998; Brodersen et al., 2010;Menon et al., 2013)− Logit adjustment (Fawcett & Provost, 1996; Provost, 2000;Maloof, 2003; Zhou & Liu, 2006; Collell et al., 2016)• ⼿法群を均衡誤差の観点でまとめあげたのが偉い印象− 既存の {理論, ⼿法} 研究の limitation を個々に丁寧に反駁− 既存法・提案法で学習されるスコア関数 f が Balanced error を最適化する f* と Fisher ⼀致性を持つか (推定量として良い性質を持つか) を横断的に評価 (後述)− 既存法・提案法が実際に Balanced error を下げることができるかどうかを経験的に評価
理論評価23
理論評価︓各⼿法は Balanced error のための損失として “良い” か24• 各⼿法を pairwise margin loss として⼀般化• 適当な δ: Y→R+ を⽤いて次の形でパラメータα, Δを表せれば,その損失は均衡誤差と Fisher consistent (Thm. 1)• ➡ 提案法 (δy = p(y)) と古典的 loss modification (δy =1; Eq. 4) のみ Fisher consistentSKIP既存⼿法のほとんどは均衡誤差を最⼩化するような損失になっていない
実験評価25
⼈⼯データによる実験26• Q︓提案法を使うと均衡誤差は最⼩化されるのか?• 実験設定− タスク︓2値分類− データ︓2D Gaussian mixture からサンプリング– 真の分布をこちらが⽤意− ⼿法群︓線形分類器– 学習法 (損失) だけ既存法・提案法で変える− (特に) オラクルベースライン︓均衡誤差にベイズ最適な線形分類器– 真の分布のパラメータを⾒ながら解析的に作れる• Q (再)︓提案法はオラクルベースラインに迫れるのか?
⼈⼯データによる実験の結果提案法1︓予測時に補正提案法は温度パラメータ次第でオラクルベースラインに迫れるベースラインはパラメータをどういじっても迫れない27betterSKIP
⼈⼯データによる実験の結果提案法2︓訓練時の損失変更提案法はオラクルベースラインに迫る=提案した損失は (実際の最適化を伴っても) 均衡誤差をよく最⼩化してくれる28better提案法 オラクル既存法
実データによる実験29• Q︓実際のクラス不均衡なデータ (画像分類) で,提案法はどれくらいよく均衡誤差を最⼩化してくれる?• A︓提案法は既存法に⽐べてなかなか良さそう提案法既存法cross entropy Lower is better既存法と混ぜる.スコア上げの頑張りを感じる
実データによる実験30• Q︓0-1損失 (正解率) をある意味で捨てたわけだけど,正解率はどれくらい下がるの?• A︓提案法は⾼頻度クラスでやや悪化,低頻度クラスで好転− 期待通り− ※ overall accuracy も報告してほしかったbetter⾼頻度クラス提案法
まとめ31
“Long-tail learning via logit adjustment”まとめ32• 問題︓softmax cross-entropy で分類すると,クラス分布が不均衡な場合に問題が起きる− ⾼頻度クラスばかり予測され,低頻度クラスでは汎化されない• 提案︓均衡誤差 (balanced error) の最⼩化という観点での新しい訓練戦略 or 予測戦略の提案− ゴールを変える– Before: 誤分類率 (正解率)– After: 均衡誤差 (クラス毎の正解率の平均)− Softmax への⼊⼒を少しいじるだけ– 予測時︓スコア関数から log p(y) を引く– 訓練時︓スコア関数に log p(y) を⾜す• 理論的・経験的に提案法の優位性を確認− 理論︓提案法だけ (※ほぼ) は均衡誤差と Fisher ⼀致性を持つ− 実験︓データを使った実験でも提案法は均衡誤差の意味でかなり良いlong-tail learningvia logit adjustment
感想・コメント33• よかった点− 既存法たちとの接続が丁寧で respectful.− とても読みやすい.• 気になる点− 均衡誤差が天から降ってくる.– 不均衡ラベルを持つタスク毎に求められるゴール・評価尺度を検討すべき?– 誤差 (評価尺度) ←→ 損失の話の⽂脈をどれくらい抑えているかもわからない.− Logit adjustment を⼊れてもモデルの表現⼒はほとんど変わらない?– softmax に bias 項を⼊れて良いなら bias 項が log p(y) を吸収できる.– 内積のままでも1次元分に log p(y) 分の情報を持たせることができる.– 損失に明⽰的に log p(y) を⼊れることがなぜ / どのように効いてくるのか?− 古典的な loss modification (損失を 1/p(y) 倍, Eq.4) との⽐較が⽢い.– 均衡誤差とFisher⼀致性を持つ,しかも実験での⽐較がない.結構強い可能性…?– ほとんど upsampling/downsampling?− Negative sampling との関係?– SGNS も p(y|x) ∝ p(y)exp() をモデルに訓練している.提案法と⼀貫的.