Upgrade to Pro — share decks privately, control downloads, hide ads and more …

Scaling Rectified Flow Transformers for High-Resolution Image Synthesis / Stable Diffusion 3

Scaling Rectified Flow Transformers for High-Resolution Image Synthesis / Stable Diffusion 3

Stability.AI からこれまでの拡散モデルとは少し異なるパラダイムの新たな Text-to-Image モデル Stable Diffusion 3 (SD3) の提案について紹介します。

一部 GIF アニメーションを用いた図があるため、オリジナルの Google Slide を参照していただくのをおすすめします: https://bit.ly/stable-diffusion-3-explained

Shunsuke KITADA

March 27, 2024
Tweet

More Decks by Shunsuke KITADA

Other Decks in Research

Transcript

  1. © LY Corporation Scaling Rectified Flow Transformers for High-Resolution Image

    Synthesis Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Muller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, Dustin Podell, Tim Dockhorn, Zion English, Kyle Lacey, Alex Goodwin, Yannik Marek, Robin Rombach Stability AI, UK; Stable Diffusion v3 の元論文 Image and Video Dept. / Generation team Shunsuke Kitada, Ph.D. HP: shunk031.me / 𝕏: @shunk031 ※本発表で紹介する図や数式は 対象の論文およびブログ記事から 引用しております stability.ai/news/stable-diffusion-3
  2. © LY Corporation 自己紹介: 北田俊輔 Shunsuke KITADA 経歴 • ‘23/04

    LINE ➜ ‘23/10 LINEヤフー Research Scientist • ‘23/03 法政大学大学院 彌冨研 博士 (工学) / 学振 DC2 研究分野 • 自然言語処理 (NLP) / 画像処理 (CV) ◦ 摂動に頑健で解釈可能な深層学習 [Kitada+ IEEE Access’21, Appl. Intell.’22] • 計算機広告 (Multi-modal / Vision & Language) ◦ 効果の高いデジタル広告の作成支援 [Kitada+ KDD’19] ◦ 効果の低いデジタル広告の停止支援 [Kitada+ Appl. Sci.’22] • 画像生成・レイアウト生成 ◦ 画像生成AI入門:Pythonによる拡散モデルの理論と実践 @オンライン教育サービス Coloso. これに関連して本を書いています(応援よろしくおねがいします) 2
 🏠: shunk031.me / 𝕏: @shunk031 画像生成AIにおける拡散モデルの理論と実践 リサーチサイエンティ スト 北田俊輔 – www.youtube.com/watch?v=-IPEUOcPTas
  3. © LY Corporation Stability.AI からこれまでの拡散モデルとは 少し異なるパラダイムの新たな T2I モデル Stable Diffusion

    3 (SD3) の提案 • 拡散モデルではなくFlowモデルを採用 ◦ 拡散モデルはデータにノイズを徐々に追加する 順拡散過程を逆にたどる逆拡散過程を元に生成 ◦ Flowモデルは簡単な分布を徐々に複雑な 分布を学習できるような f と f-1 を元に生成 • アーキテクチャの改善 ◦ バックボーンを UNet から DiT へ変更 ◦ さまざまな text encoder を組み合わせて 文字の描画性を向上させている? 本論文の選定理由 3
 lilianweng.github.io/posts/2021-07-11-diffusion-models/ stability.ai/news/ stable-diffusion -3-research-paper lilianweng.github.io/posts/2018-10-13-flow-models Flow-based model Text-to-Image
  4. © LY Corporation • ’21/12/20: Latent Diffusion Model (LDM) ◦

    [Rombach+ CVPR’22] DDPM を潜在空間へ拡張 • ’22/08/10: Stable Diffusion v1 ◦ LDM を LAION-5B で学習して公開 • ’22/11/24: Stable Diffusion v2 ◦ OpenCLIP の使用、深度マップ対応 ◦ ’22/12/07: v2.1 • ‘23/04/13: Stable Diffusion XL [Podell+ CoRR’23] ◦ ’22/11/24: beta, ’23/06/22: 0.9 ◦ ’23/07/26: 1.0, ’23/11/28: SDXL Turbo • ’24/02/08: Stable Cascade [Pernias+ CoRR’23] ◦ 複数ステージを経て画像を洗練化 • ’24/02/22: Stable Diffusion 3 [Esser+ CoRR’24] ◦ Flowの採用、UNetからDiTへ Stable Diffusion の進化過程 4
 SD1 SD2 SD2 with depth map SD3 Stable Cascade
  5. © LY Corporation 本論文の貢献 5
 • 拡散モデル & Rectified Flow

    (RF) [Liu+ ICLR’23] に対する大規模実験の実施 ◦ Diffusion Process とは少し異なる Flow ベースのサンプリング ◦ RF に対する新たな noise sampler の導入・従来手法よりも性能向上 • スケーラブルな新たな Text-to-Image モデルの提案 ◦ 従来の U-Net [Ronneberger+ MICCAI’15] を Diffusion Transformer (DiT) [Peebles+ CVPR’23] へ ◦ SoTA である UViT [Hoogeboom+ ICML’23] 等や vanilla DiT と比較したときの利点提示 • 提案モデルにおける予測可能なスケーリング傾向の実証 ◦ 提案モデルにおいて validation loss が下がれば下がるほど T2I-CompBench [Huang+ NeurIPS’24] / GenEval [Ghosh+ NeurIPS’24] および 人間による評価等の指標と強い相関
  6. © LY Corporation データ分布からある ガウシアンへの輸送を 考えるとその軌道は ガウシアンノイズに よってブラウン運動の ような軌道を示す ノイズからデータを生成する拡散モデル

    [Ho+ NeurIPS’20, Song+ ICLR’21] 拡散モデル: データに徐々にノイズを追加する道筋 (forward diffusion process) を 逆にたどる (reverse diffusion process) 学習を経て新しいデータを生成 • 拡散モデルは高次元のデータを効率的にモデリング可能 [Ho+ NeurIPS’20] ◦ 特に画像生成の分野で活躍 [Saharia+ NeurIPS’22, Rombach+ CVPR’22, Esser+ ICCV’23 etc] データからノイズへの道筋をどのようにたどればよいか? • ノイズを付与したデータから適切にノイズを除去できない場合 訓練と推論の分布の不一致で適切な 画像を生成できない可能性 大 ➜ グレースケール画像が生成 されてしまったり [Lin+ WACV’24] 💡 適切な道筋をたどることが重要 6
 図は cvpr2022-tutorial-diffusion-models.github.io/ の公開資料から データからノイズ への道筋
  7. © LY Corporation Rectified Flow (RF) [Li+ ICLR’23, Albergo+ ICLR’23,

    Lipman+ ICLR’23] • データ⇔ノイズの道筋を直線で繋ぐような方法 ◦ 中小規模の実験において有効性が確認されている EDM [Karras+ NeurIPS’22] • DDPMスケジューラを一般化し生成性能向上 Cosine [Nichol+ ICML’21] (Improved DDPM) • DDPMで最終盤にノイズが乗りすぎるのを軽減 7
 図の引用元 [Biroli+ CoRR’24] 直線で繋ぐのが Rectified Flow 効果的なデータ⇔ノイズの “たどり方” の研究
  8. © LY Corporation Rectified Flow (RF) [Li+ ICLR’23, Albergo+ ICLR’23,

    Lipman+ ICLR’23] 1/2 ある分布を別の分布に写像させるための効率的な方法 • できるだけ2つの分布間を直線的につなげる輸送写像を見つける 常微分微分方程式 (ordinal differential equation; ODE) 手法 [Song+ ICLR’20, Ho+ NeurIPS’20] ◦ ニューラル ODE や確率微分方程式 (stochastic differential equatnin; SDE) と関連 • 従来の ODE/SDE と RF の違い ◦ 従来: 無限に考えうる2つの分布間の輸送方法を暗黙的に学習 ◦ RF: 解の経路が直線であるような ODE を明示的に学習 • RF の直感的利点 ◦ 最適輸送理論との相性の良さ・ODE/SDE の利点を継承・シンプルなフレームワーク ◦ 数値的に解いた際に誤差があまり出ないため、少ないステップ数で推論可能 8

  9. © LY Corporation 💡 元画像 x 0 からノイズ画像 ε への変換を学習する際にそれぞれを

    等速直線運動する軌跡を目標に回帰 • このとき移動速度は で な を近似できるような モデルをニューラルネットワークで学習 ◦ ナイーブな線形補間を学習すると x 0 と ε に対する 因果性 causality が保証できず複数の軌跡が発生してしまう • 時刻 t における点 z t を通過する直線方向の 平均を計算することで因果性を保証 ◦ 生成過程が直線的になるため離散化誤差が小さく 少ないステップで高品質な画像生成が実現 Rectified Flow (RF) [Li+ ICLR’23, Albergo+ ICLR’23, Lipman+ ICLR’23] 2/2 9
 ナイーブな線形補間 (non-causalized) Rectified Flow (causalized) 👎入力に対して 複数の軌跡が発生 因果性を保証して 軌跡を一意に 👍 www.cs.utexas.edu/~lqiang/rectflow/html/intro.html x 0 ε
  10. © LY Corporation のターゲットとなる は [0, 1] の中間地点で予測するのは難しい • ノイズ分布

    から密度分布 への変化を学習するのは重み付け 損失を学習するのと同等; 中間地点の t をより多くサンプリングする ことで、中間タイムステップをより重要視するように学習させたい • Logit-Normal Sampling [Aitchison+ Biometrika’80] ◦ logit-normal 分布に基づいたサンプリング 実際は正規分布からサンプリングしてロジスティック関数で変換 • Mode Sampling with Heavy Tails ◦ logit-normal 分布は 0 と 1 の端点で消失 ◦ f mode に従う π mode を定義してサンプリングに使用 • CosMap [Nichol+ ICML’21] ◦ log-SNR がコサインに従うようなサンプリング SD3 における RF に対する sampler の検討 10

  11. © LY Corporation Backbone を U-Net から Transformer (ViT) へ

    • 入力と中間表現の設計 ◦ Latent Diffusion Model (LDM) [Rombach+ CVPR’22] を踏襲 ◦ 画像をパッチに変換して DiT block へ入力し潜在データへ • DiT ブロックの設計 ◦ 条件付は Cross-Attention や In-Context Conditioning (シンプルに concat) よりも adaptive な LayerNorm (adaLN) が 実験から有効性が明らかに Diffusion Transformer (DiT) [Peebles+ CVPR’23] 11
 OpenAI が発表した次世代 text-to-videoモデル Sora でも DiT が採用されている? (Open じゃないから本当かは分からない)
  12. © LY Corporation • DiT をベースに潜在表現を学習する LDM を踏襲 ◦ Frozen

    text-encoder からテキスト表現を得る点も類似 ▪ CLIP-G/14, L/14 のほか、T5 XXL も同時に使用 学習時にランダムに各埋め込みをゼロ埋め込みに dropout • マルチモーダル化した Multi-modal DiT (MMDiT) ◦ DiT はクラス条件のみ; text-to-image へ拡張 ◦ LDM 等とは異なりテキストの埋め込み列をモデルに入力 ▪ 従来はテキストの埋め込み列を pooling してモデルへ入力 ▪ 言語理解能力が上がって画像中に適切に文字の描画が可能に ◦ 各モダリティに対する独立したパラメータ ▪ テキストや画像はもともと大きく異なる情報を表現 ▪ それぞれ独立したTransformerでモダリティ情報を変換しつつ 最後に attention を計算するように SD3 のアーキテクチャ 12
 SD3 のアーキテクチャ [Esser+ CoRR’24] smiling cartoon dog sits at a table, coffee mug on hand, as a room goes up in flames. “This is fine,” the dog assures himself. text-encoder てんこ盛りでデカい
  13. © LY Corporation • 評価データセット ◦ さまざまな学習設定下での実験 ▪ 学習: ImageNet

    [Russakovsky+ IJCV’14] , CC12M [Changpinyo+ CVPR’21] , 評価: COCO-2014 [Lin+ ECCV’14] ◦ スケーリング則に対する実験 ▪ GenEval [Ghosh+ NeurIPS’24] , T2I-CompBench [Huang+ NeurIPS’24] • 評価指標 ◦ CLIP score [Hessel+ EMNLP+21] , FID [Heusel+ NeurIPS’17] FID 計算時に CLIP feature を使用 [Sauer+ NeurIPS’21] • 比較手法 ◦ 2 つのデータセットに対して 61 の実験設定で評価 ▪ 従来 loss と RF loss で様々なハイパラ設定で詳細に実験 ▪ LDM, EDM, ADM 等のモデルの比較 実験設定と比較結果 1/4 13
 異なるサンプリング手法の比較結果 (T = 25) RF ベースのサンプリングが少ないステップ数で 高品質な画像を生成可能
  14. © LY Corporation • 評価データセット ◦ さまざまな学習設定下での実験 ▪ 学習: ImageNet

    [Russakovsky+ IJCV’14] , CC12M [Changpinyo+ CVPR’21] , 評価: COCO-2014 [Lin+ ECCV’14] ◦ スケーリング則に対する実験 ▪ GenEval [Ghosh+ NeurIPS’24] , T2I-CompBench [Huang+ NeurIPS’24] • 評価指標 ◦ CLIP score [Hessel+ EMNLP+21] , FID [Heusel+ NeurIPS’17] FID 計算時に CLIP feature を使用 [Sauer+ NeurIPS’21] • 比較手法 ◦ 2 つのデータセットに対して 61 の実験設定で評価 ▪ 従来 loss と RF loss で様々なハイパラ設定で詳細に実験 ▪ LDM, EDM, ADM 等のモデルの比較 実験設定と比較結果 1/4 14
 異なるサンプリング手法の比較結果 (T = 25) RF ベースのサンプリングが少ないステップ数で 高品質な画像を生成可能 RF + logit-normal sampling で CLIP score が高い (= テキストを よく表した画像を生成している) RF + mode sampling で FID が低い (= 品質の高い画像を生成している)
  15. © LY Corporation • 評価データセット ◦ さまざまな学習設定下での実験 ▪ 学習: ImageNet

    [Russakovsky+ IJCV’14] , CC12M [Changpinyo+ CVPR’21] , 評価: COCO-2014 [Lin+ ECCV’14] ◦ スケーリング則に対する実験 ▪ GenEval [Ghosh+ NeurIPS’24] , T2I-CompBench [Huang+ NeurIPS’24] • 評価指標 ◦ CLIP score [Hessel+ EMNLP+21] , FID [Heusel+ NeurIPS’17] FID 計算時に CLIP feature を使用 [Sauer+ NeurIPS’21] • 比較手法 ◦ 2 つのデータセットに対して 61 の実験設定で評価 ▪ 従来 loss と RF loss で様々なハイパラ設定で詳細に実験 ▪ LDM, EDM, ADM 等のモデルの比較 実験設定と比較結果 1/4 15
 異なるサンプリング手法の比較結果 (T = 25) RF ベースのサンプリングが少ないステップ数で 高品質な画像を生成可能 RF + logit-normal sampling が 他の手法よりも低い FID を達成
  16. © LY Corporation 合成キャプションによる生成性能 • 先行研究で有効性が確認されている [Betker+ DALL-E2] 合成キャプションを追加したときの性能比較 ◦

    CogVLM [Wang+ CoRR’23] を用いてキャプションを生成 ◦ オリジナルのキャプション + 合成キャプションを 50/50 で混合すると全体のスコアが向上 ➜ 以降、混合して学習したモデルで評価 Backbone の違いによる生成性能 • DiT: text を concat してまとめる • CrossDiT: text を cross attention • UViT: U-Net と transformer 合せ技 • MM-DiT: text をモダリティごとにあつかう ➜ DiT > UViT; CrossDiT > UViT; 初期は UViT 比較結果 2/4 16
 GenEval ベンチマークを用いた生成画像に対する評価 異なる Backbone における生成性能の比較 MMDiT が一貫して他の Backbone よりも良い性能を達成 CLIP (G/14, L/14) と T5 の 2 or 3 sets
  17. © LY Corporation 比較結果 3/4 17
 MM-DiT のスケーリング効果の比較 • Parti-prompt

    benchmark [Yu+ CoRR’22] を用いた 複数の指標から画像の品質を人手で評価 ◦ 比較手法に対して提案手法がより好まれる傾向 ◦ T5を取り除くと審美性には影響がほとんどなく 特に文字の描画性が大きく損なわれる傾向 MM-DiT のスケーリング効果の比較 • 2B から 8B までスケールさせたモデル ◦ モデルサイズと学習回数を増やせば増やすほど validation loss が下がっていく傾向 ➜ Validation loss が下がれば下がるほど 画像の品質を評価する GenEval, 人手指標が向上 各指標・各SoTAモデルにおける人手評価の結果 MM-DiT のスケーリング効果の比較 どの画像が一番審美的に品質が高いですか? どの画像が一番プロンプトに忠実ですか? どの画像が一番正しく 画像中に文字を描画 できていますか?
  18. © LY Corporation T5エンコーダによる文字描画性能の比較 • T5エンコーダを使用すると複雑なプロンプトに対する生成画像の忠実度が向上 ◦ T5 を使わず zero

    vector で埋めて推論しても、簡単なプロントで破綻なく描画可能 ◦ 学習時に3つのテキストエンコーダをランダムにゼロ埋めしているため、ある程度補完可能? 比較結果 4/4 18
 “A burger patty, with the bottom bun and lettuce and tomatoes. ”COFFEE” written on it in mustard” “A monkey holding a sign reading ”Scaling transformer models is awesome!” “A mischievous ferret with a playful grin squeezes itself into a large glass jar, surrounded by colorful candy. The jar sits on a wooden table in a cozy kitchen, and warm sunlight filters through a nearby window” All text-encoders without T5
  19. © LY Corporation 議論とまとめ • Text-to-Image に対する RF のスケーリング効果を詳細に検証 ◦

    RF に対する新たな noise sampling として考えられる方法を複数調査 ◦ 拡散過程を用いるよりも少ないステップ数で高い品質の画像を生成可能 • Text-to-Image に対して DiT を拡張した MM-DiT を提案 ◦ 8B までパラメータを増やしたときの スケーリングの効果の実証 ◦ Validation loss が下がれば下がるほど 生成画像の品質が向上する性質の発見 ◦ より忠実にテキストを生成画像に描画可能 • SD3 公開されるかな…? ◦ Waiting list に登録して全力待機しましょう 19
 stability.ai/news/ stable-diffusion -3-research-paper