Slide 1

Slide 1 text

AI 2024/07/26 内田 祐介 GO株式会社 Post-hoc EMA EMAの減衰パラメータの事後最適化

Slide 2

Slide 2 text

AI 2 通常のEMAの指数減衰ではランダムな初期パラメータの影響が大きす ぎるため、べき関数を用いた重み付けを定式化 複数の減衰パラメータのcheckpointを一定間隔で保存しておくこと で、減衰パラメータを事後的に最適化することを提案 Kaggleとかでも使えそう! Post-hoc EMA

Slide 3

Slide 3 text

AI 3 同一モデルのweightを平均することで汎化性能を向上 1. モデルを一度学習 2. 上記のモデルを学習率をcyclicに変化させながらfinetune、 複数weightを取得 3. 上記のweight全てを平均し 新たなモデルのweightとする 4. 学習データでforwardして BNのパラメータをアップデート Stochastic Weight Averaging (SWA) P. Izmailov, et al., "Averaging Weights Leads to Wider Optima and Better Generalization," in Proc. of UAI'18.

Slide 4

Slide 4 text

AI 4 PyTorch実装はあるが1回の学習でpretrain + annealingするので ややこしい。pytorch lightningのwrapperもあるがexperimental https://pytorch.org/docs/stable/optim.html#weight-averaging-swa-and-ema Stochastic Weight Averaging (SWA)

Slide 5

Slide 5 text

AI 5 1回の学習で複数のweightを取得しアンサンブル 等間隔にM個取得しておいて最後のm≦M個を利用 SWAと異なりweightを平均して利用するわけではない Learning rateをcyclicに高くして 異なる局所解を得ることが目的 cosine LRスケジューリング Snapshot Ensembles G. Huang, et al., "Snapshot Ensembles: Train 1, Get M for Free," in Proc. of ICLR'17.

Slide 6

Slide 6 text

AI 6 (余談) みんな大好きcosine LRスケジューリングの始祖はこちら もはや誰もrestartはしていない気がするけど SGDR: Stochastic Gradient Descent with Warm Restarts I. Loshchilov and F. Hutter, "SGDR: Stochasitc Gradient Descent with Warm Restarts," in Proc. of ICLR'17.

Slide 7

Slide 7 text

AI 7 同一モデルのweightを平均することで汎化性能を向上 1. 一定間隔 (step) 毎にEMAモデルをアップデートするだけ! EMAも torch.optim.swa_utils に実装がある timm.utils.ModelEmaV2 を使っていたが ModelEmaV3 がある… (もちろん ModelEmaもある…) V3は減衰率のwarmupができる模様 後述のEMAの問題(初期weightの影響)が軽減されそう Exponential Moving Average (EMA)

Slide 8

Slide 8 text

AI 8 Diffusionモデルに関するCVPR’24の論文の提案手法の一部 画像生成ではEMAの利用とそのパラメータ調整が重要 複数のEMA checkpointを保存しておくことで事後的に EMAの減衰パラメータを最適化することを提案 arXivの論文のほうがappendixが充実してて良い Post-hoc EMA T. Karras, et al., "Analyzing and Improving the Training Dynamics of Diffusion Models," in Proc. of CVPR'24.

Slide 9

Slide 9 text

AI 9 下記の2点の理由から、EMAを指数減衰からべき関数に基づいた平均 を行うように変更する 長時間の平均を利用したいが初期値付近の重みは0にしたい 訓練時間に対して自動的に減衰パラメータをスケーリングしたい 準備 指数減衰では初期パ ラメータのweightが 0にならない

Slide 10

Slide 10 text

AI 10 通常のEMAの更新式 べき関数に基づいたパラメータ平均の定義 べき関数に基づいたEMA更新式 べき関数に基づいたパラメータ平均 Weightが t 依存に 正規化係数 τ時のweigtht τ時のパラメータ

Slide 11

Slide 11 text

AI 11 パラメータ設定時には、γと互換性のある relative standard deviation σrel を利用 べき関数に基づいたパラメータ平均

Slide 12

Slide 12 text

AI 12 複数の γ (σrel ) で、複数タイミングでweightを保存 学習後にこれらのweightから所望の γ (σrel ) で学習した際のweightを 最小二乗法で事後的に算出 再構築アルゴリズム

Slide 13

Slide 13 text

AI 13 γ (σrel ) は2パラメータ スナップショット数は学習時間に応じて。画像生成タスクでなければ そこまで大量じゃなくても良いかも 再構築アルゴリズム

Slide 14

Slide 14 text

AI 14 Diffusionモデルの結果ではあるが モデルによって最適なパラメータがかなり違う =調整の意義あり (EMA前提のパラメータ設定だと 思われるが)EMAなし(左側)の 性能が低い 再構築による最適化結果

Slide 15

Slide 15 text

AI 15 https://github.com/NVlabs/edm2 https://github.com/mmathew23/improved_edm 実装

Slide 16

Slide 16 text

AI 16 いっぱい保存すれば良いじゃない説も D. Morales-Brotons., et al., "Exponential Moving Average of Weights in Deep Learning: Dynamics and Benefits," in TMLR'24.