Upgrade to Pro — share decks privately, control downloads, hide ads and more …

新しいスケーリング則と学習理論

Taiji Suzuki
January 07, 2025

 新しいスケーリング則と学習理論

Taiji Suzuki

January 07, 2025
Tweet

More Decks by Taiji Suzuki

Other Decks in Technology

Transcript

  1. 鈴木大慈 2 所属 ➢ 東京大学大学院情報理工学系研究科数理情報学専攻・教授 ➢ 東大次世代知能科学研究センター研究部門研究者(研究知能部門) ➢ 理化学研究所 革新知能統合研究センター

    深層学習理論チーム チームリーダー 専門 ➢ 機械学習の理論:数理統計学,統計的学習理論,確率的最適化 解釈可能性: 説明可能性,データの可視化,メンテナ ンスの容易化 各種テクニックの解析: アーキテクチャの解析,損失関数の設計, 最適化技法の解析 深層学習の原理解明: 「表現理論」「汎化誤差理論」「最適化 理論」 学習の本質解明: “良い”学習手法の特徴付け,統一理論, 深層学習を優越する方法論の提唱 応用 基礎 鈴木大慈 情報理工学系研究科 確率論 幾何学 関数解析 最適化理論 数学 数理統計 スパース推定 関連する機械学習理論 特徴抽出 カーネル法 深層学習の理論 主な研究内容 ➢ 深層学習を含む様々な学習機構に関する理論 ➢ 学習理論を通じた各種学習手法の汎化解析や学習アルゴリズムの収 束理論 ➢ 確率的最適化による大規模複雑な機械学習問題の効率的解法 著書/授賞 ➢『確率的最適化(機械学習プロフェッショナルシリーズ)』講談社,2015年 8月8日. ➢金森敬文,鈴木大慈,竹内一郎,佐藤一誠:『機械学習のための連続最適化 (機械学習プロフェッショナルシリーズ)』講談社,2016年12月7日. ➢文部科学大臣表彰・若手科学者賞「深層学習の原理解明に向けた統計的学習 理論の研究」.文部科学省,2020年4月7日. ➢第11回日本統計学会研究業績賞 (2017年度).2017年9月5日. ➢Satoshi Hayakawa and Taiji Suzuki:日本神経回路学会論文賞.日本神経回路学会, 2021年9月23日. ➢日本応用数理学会,ベストオーサー賞(論文部門).2019年9月4日. 主な活動場所 • 国内:IBIS, 統計連合大会 • 国外:NeurIPS, ICML, ICLR, ACML, ... (ACML steering committee member)
  2. モデル訓練の計算量 5 [Sastry et al.: Computing Power and the Governance

    of Artificial Intelligence. arXiv:2402.08797] 訓練時計算量
  3. 6 Alex-net 2 × GTX 580 1.581 TFLOPS for FLOAT

    1.5GB memory xAI 200,000 × H100/H200 ・・・ 800 TFLOPS for FP16 80GB memory 2012 2024 [参考] 産総研 ABCI 3.0: 6,128 × H200 ・・・
  4. スケーリング則 7 Reducible loss [Kaplan et al.: Scaling Laws for

    Neural Language Models, 2020] [Henighan et al.: Scaling Laws for Autoregressive Generative Modeling, 2020] モデルサイズ固定 (基本的に訓練データサイズと思ってよい) [Brown et al.: Language Models are Few-Shot Learners, 2020] (GPT-3モデルの解析) log(予測精度)=−𝛼 log 𝑛 + log(𝐶)
  5. 基本的考え方 • スケーリング則は古典的な学習理論でも現れる. 8 真のモデル log 予測誤差 = − 𝑎

    1+𝑎 log 𝑛 + log(𝐶) バイアス バリアンス 予測誤差 観測データ: 正則化学習法 (カーネル法) 最適なモデルサイズ 学習モデル ただし を用いて 𝑀 𝑀−𝑎 (正規直交系 in L2) バリアンス=モデルの次元/n バイアス=切り捨てた係数の二乗和
  6. 9 log 予測誤差 = − 𝑎 1+𝑎 log 𝑛 +

    log(𝐶) log(𝒏) log(予測誤差) モデルサイズM固定の予測誤差 最適なモデルサイズの予測誤差
  7. 深層学習 vs 浅層学習 (異なるスケーリング則) 学習すべき真の関数の形状によっては深層が有利になる 10 深 層 浅 層

    縮小ランク回帰 特徴空間の次元 が低い状況は深 層学習が得意 区分滑らかな関数 不連続な関数の 推定は深層学習 が得意 Besov空間 滑らかさが非一 様な関数の推定 は深層学習が得 意 低次元データ データが低次元 部分空間上に分 布していたら深 層学習が有利 [Suzuki, 2019] [Schmidt-Hieber, 2019] [Nakada&Imaizumi, 2019][Chen et al., 2019][Suzuki&Nitanda, 2019] [Imaizumi&Fukumizu, 2019] 推 定 精 度
  8. カーネル法と深層学習の違い 11 推定誤差 データサイズ 少ないデータサイズ では浅い学習が良い. 大きなデータサイズ では深層学習が良い. 深層学習 浅い学習

    • スケーリング則自体は比較的古典的な理論からも導出できる. • しかし,これだけの「データ量」「モデルサイズ」「学習問題の複 雑さ」で実証されることはなかった.
  9. (𝑌𝑡 ∼ 𝑋 𝑇−𝑡 ) 拡散モデルの統計理論 13 Stable diffusion, 2022.

    Forward process Backward process どちらも(ほぼ)ミニマックス最適 [Yang & Barron, 1999; Niles-Weed & Berthet, 2022]. 経験スコアマッチング推定量: (for any 𝛿 > 0). 定理 Let ෠ 𝑌 be the r.v. generated by the backward process w.r.t. Ƹ 𝑠, then (Estimator for 𝑊1 distance requires some modification) (𝑠: 密度関数の滑らかさ) [Kazusato Oko, Shunta Akiyama, Taiji Suzuki: Diffusion Models are Minimax Optimal Distribution Estimators. ICML2023, oral] (2% of all submissions)
  10. Transformerの推定理論 14 定理 (推定誤差) ➢ 入力が無限次元でも多項式オーダーの収束レート. (ほぼミニマックス最適) ⋯ 𝑥−1 𝑥0

    𝑥1 𝑥2 ⋯ ⋯ 𝑌−1 𝑌0 𝑌1 𝑌2 ⋯ ⋮ ⋮ ⋮ ⋮ Self-attention FNN Transformerの性質 • かなり広いトークン幅から重要な トークンを選べる. → 次元の呪い? • 入力に依存して重要なトークンを 選択できる. → 次元の呪いを回避! [Shokichi Takakura, Taiji Suzuki: Approximation and Estimation Ability of Transformers for Sequence-to-Sequence Functions with Infinite Dimensional Input. ICML2023]
  11. Transformerの推定理論 15 定理 (推定誤差) ➢ 入力が無限次元でも多項式オーダーの収束レート. (ほぼミニマックス最適) ⋯ 𝑥−1 𝑥0

    𝑥1 𝑥2 ⋯ ⋯ 𝑌−1 𝑌0 𝑌1 𝑌2 ⋯ ⋮ ⋮ ⋮ ⋮ Self-attention FNN Transformerの性質 • かなり広いトークン幅から重要な トークンを選べる. → 次元の呪い? • 入力に依存して重要なトークンを 選択できる. → 次元の呪いを回避! [Shokichi Takakura, Taiji Suzuki: Approximation and Estimation Ability of Transformers for Sequence-to-Sequence Functions with Infinite Dimensional Input. ICML2023] State-Space-Modelといった新しいモデル が提案されているが,少なくともトークン を取捨選択できる性質は必要
  12. 学習レジームの多様化 19 事前学習 事前学習 事後学習 テスト時 推論 100% 45% 35%

    20% 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought ここの重要度が上がっている 含,データの自動生成 ➢o1, AlphaProof, AlphaGeometry
  13. 良質データの効果 • Pythonコード生成タスクで実験 • 教育的効果の高いデータをフィルタリング • GPT3.5による学習データ生成:diversityを担保した生成,演習問題も生成 (CodeExercises) ➢ StackOverflowのデータだけを用いるより,上記の方法で生成したCode

    Texbook並みの質のデータで学習すると精度がかなり上がる. (データの質は学習アルゴリズムの改善を簡単に上まわる) 21 [Gunasekar et al.: Textbooks Are All You Need. ICLR2024]
  14. 22 事前学習 事後学習 テスト時 推論 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン

    チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought 事後学習によるアラインメント改善 ➢ DPO (Direct Preference Optimization) ➢ RLHF (Reinforcement Learning from Human Feedback) ➢ Supervised Fine Tuning: Instruction Tuning 自動データ生成による事後学習 ➢ 自動的なinstruction data生成 ➢ Self-Improvement (e.g., SPIN, Instruction Back-translation) ➢ RLAIF (AIフィードバックによる強化学習) ➢ Chain-of-thought generation
  15. 例: UltraFeedback 23 [Cui et al.: UltraFeedback: Boosting Language Models

    with Scaled AI Feedback. arXiv:2310.01377] • 質問に対して,複数のLLM (異なる種類,異なるモデルサイズ) か ら回答を生成. • GPT-4に回答の質をランキングさせる. • ランキングの結果は元モデルのfine-tuningに利用できる.
  16. 例: RLAIF 24 [Lee et al.: RLAIF vs. RLHF: Scaling

    Reinforcement Learning from Human Feedback with AI Feedback. ICML2024] • 別のLLMを用いてfeedback評価を生成 • 人間のfeedback (RLHF) より良好な性能 • 評価用LLMを自モデルと同じサイズにし ても性能の向上を確認 → AI自身の生成データによる自己改善の可 能性 評価の難しさ ≪ 生成の難しさ ※コーチは選手よりも優れたプレイヤーである必要はない.
  17. 25 事前学習 事後学習 テスト時 推論 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン

    チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought
  18. (𝑌𝑡 ∼ 𝑋 𝑇−𝑡 ) 拡散モデル 26 逆過程 (target distribution)

    (標準正規分布) 順過程 (Wasserstein勾配流) KL-divergence from 𝜇∗ : 標準正規分布に収束 N(0,I) (Wasserstein GF) 順過程: 逆過程: : 順過程のたどった軌跡を逆にたどる
  19. Direct Distribution Optimization 27 • 𝜇ref : 事前学習された参照モデル 拡散モデルで与えられている. ➢

    複雑な分布も生成可能 ➢ 密度関数は評価できない ➢ サンプリングしかできない → どうやって最適化するか? 例: DPO, ベイズフィ ルタリング
  20. Direct Distribution Optimization 28 • 𝜇ref : 事前学習された参照モデル 拡散モデルで与えられている. ➢

    複雑な分布も生成可能 ➢ 密度関数は評価できない ➢ サンプリングしかできない → どうやって最適化するか? 例: DPO, ベイズフィ ルタリング • これまでの手法:上界の最小化 • 本研究:直接的に拡散モデルの最適化を実現
  21. Direct Preference Optimization • DPO: fine-tuning method for generative models

    such as LLMs. 29 Fine-tuning data: • For each prompt 𝑐 ∼ 𝑝(𝑐), generate 𝑦1 , 𝑦2 ∼ 𝑝SFT 𝑦 𝑐 (independently). • Get preference 𝑦𝑤 ≻ 𝑦𝑙 between 𝑦1 , 𝑦2 . (human feedback) 1. 2. (Bradley-Terry model) (computation of normalization constant is not required) [Rafailov et al. 2024]
  22. 双対平均加法 30 双対平均化法 (Dual averaging method) • 𝜇ref : 参照モデル

    (事前学習モデル) where For 𝑘 = 1, … , 𝐾 − 1: 𝑂(1/𝐾) convergence 最適な分布と参照分布の密度比が求まる [Kawata, Oko, Nitanda, Suzuki: Direct Distributional Optimization for Provable Alignment of Diffusion Models. 2024]
  23. Doob h-transform 31 Q: ෝ 𝝁 ∝ 𝐞𝐱𝐩(−ෝ 𝒈) 𝝁𝐫𝐞𝐟

    からどうサンプリングするか? Doob ℎ-Transform (Doob, 1957; Rogers & Williams, 2000): 修正項 reference (𝝁𝐫𝐞𝐟 ) Corrected process Ƹ 𝜇 𝑌0 𝑌ത 𝑇 参照モデルの逆過程 (参照モデル) (Gaussian distribution) 修正された拡散モデル (最適モデル) 修正 𝜇ref 確率論
  24. 33 事前学習 事後学習 テスト時 推論 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン

    チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought
  25. 35 • 解候補をモンテカルロサンプリングして良いものだけをピックアップ • 解候補の「良さ」を測るProcess Reward Verifier (PRM) も学習 →

    枝刈り・推論を高速化 テストタイムの時間を多くとった方が性能向上 (たくさんサンプリングし た方が良い出力を見つけられる) → 見つかった良い結果を用いてモデルをfine-tuningすることも可能 → 候補の生成に費やした計算量も考慮すべき:新しいスケーリング則
  26. 36 事前学習 事後学習 テスト時 推論 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン

    チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought
  27. 思考連鎖 (Chain-of-Thought) 37 • 思考の連鎖を訓練データに用いる • 思考の連鎖を例示してin-context learningさせる • 思考の連鎖を出力させる(e.g.,

    think step by step) →精度向上,解釈性向上 • 結果だけでなく思考過程も出力/入力 [Wei et al.: Chain-of-Thought Prompting Elicits Reasoning in Large Language Models. 2022]
  28. 数学への応用 39 • AlphaGeometry (DeepMind, 2023) • AlphaProof, AlphaGeometry2 (DeepMind,

    2024) • 定理証明系の言語を利用:Lean [de Moura et al., 2015], Coq [Barras et al., 1997], Isabelle [Nipkow et al., 2002] • 定理を「形式化」して,証明をプログラムとして書き下す. • 証明の真偽は自動的に判定可能(単発の回答はもちろん真偽判定可能) ➢ 思考連鎖の訓練データを収集して学習 ➢ 思考連鎖を自動生成して証明が通ったものを訓練データにして学習
  29. 思考連鎖の理論 40 [Kim&Suzuki: Transformers provably solve parity Efficiently with chain

    of thought. arXiv:2410.08633, 2024] 𝑘-パリティ問題 ➢ 𝑥 = (𝑥1 , … , 𝑥𝑑 ) ∼ Unif( −1,1 𝑑) ➢ 𝑦 = 𝑥𝑖1 𝑥𝑖2 … 𝑥𝑖𝑘 = ς𝑗∈𝑝 𝑥𝑗 𝒅次元入力のうち𝒌個のみが出力に関係ある. 𝒙𝒊, 𝒚𝒊 𝒊=𝟏 𝒏 𝑛個のデータから意味のある𝑘個の座標を特定したい. 普通に勾配法で解こうとすると 𝑛 = 𝑂(𝑑𝑘−1) くらいのデータ数が必要. Q: CoTで効率化できるか? NNによる学習法が多く研究されている: Abbe et al. (2023); Refinetti et al. (2021); Ben Arous et al. (2022); Damian et al. (2022); Suzuki, Wu, Oko, Nitanda (2023).
  30. 𝑘-パリティ問題の階層構造 41 : 中間結果を教師データとして与えて問題を分割 Transformerの入力の各トークンに • 入力の各要素 𝑥𝑖 (𝑖 =

    1, … , 𝑑) • 中間結果 𝑥𝑗 (𝑗 ≥ 𝑑 + 1) を入れていく. ➢ 𝑗 ≥ 𝑑 + 1に対してnext token prediction ➢ 最後のトークンが𝑦の予測値 各トークンの予測問題ごとに誤差逆伝 搬でネットワークを学習
  31. 結果 42 思考連鎖の途中過程を含めた学習データを用いてTransformer を学習する.すると,データ数 𝑛 = 𝑂(𝑑2+𝜖) で学習できる: ො 𝑦test

    − 𝑦test ∞ ≤ 𝑂(𝑑−𝜖/8). 定理 w/o CoT with CoT 必要な学習データ数 𝑂(𝑑𝑘−1) 𝑂(𝑑2+𝜖) データ数の比較
  32. 結果 43 思考連鎖の途中過程を含めた学習データを用いてTransformer を学習する.すると,データ数 𝑛 = 𝑂(𝑑2+𝜖) で学習できる: ො 𝑦test

    − 𝑦test ∞ ≤ 𝑂(𝑑−𝜖/8). 定理 w/o CoT with CoT 必要な学習データ数 𝑂(𝑑𝑘−1) 𝑂(𝑑2+𝜖) データ数の比較 連鎖思考の方法 (いくつかの異なる実装法)
  33. 44 事前学習 事後学習 テスト時 推論 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン

    チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought
  34. In-context learning 45 In-Context Learning (ICL) [Brown et al., 2020].

    良く事前学習されたモデルはテスト時のin-context learning (文脈内学習) でも良い性能を示す. Question ChatGPT
  35. In-context learning 46 Question ChatGPT In-Context Learning (ICL) [Brown et

    al., 2020]. 良く事前学習されたモデルはテスト時のin-context learning (文脈内学習) でも良い性能を示す.
  36. Fine tuning method 47 通常のファインチューニングはパラメータを更新する. In-context learningでは更新しない. (e.g., RLHF) ※

    最近では,test time trainingと言って,LoRAを用いてin-context learning時に少 しファインチューニングする方法も提案されている [Akyurek et al.
  37. In-Context learning 48 Pretraining Test task Example Query Query Examples

    ICLはモデルパラメータを更新しない → メタ学習,学習の学習 事前学習時にたくさんのタスク を観測しておく. → タスク汎化
  38. In-context learningの数学的定式化 49 Pretraining (𝑻 tasks): ⋯ × 𝑇 ➢

    We observe pretraining task data 𝑇 times. ➢ Each task has 𝑛 data. Test task (In-context learning): ⋯ Predict • The true functions 𝐹𝑡 ∘ are different across different tasks. • 𝐹𝑡 ∘ is generated randomly for each task. Model: 𝑡 = 1, … , 𝑇: Task index
  39. Transformerの役割 • 事前学習 (pretraining): 特徴量 (表現) を学習 [𝑓∘] ➢Fourier, B-Spline

    ➢文脈 (𝑡) に非依存 ➢データを表現する「最も効率的」な基底を学習 → 中間層 • 文脈内学習 (in-context): 係数を学習 [𝛽𝑡 ] ➢文脈 (𝑡) に依存 ➢例示から現在の文脈𝛽𝑡 を推定 → 最終層のAttention 50 ✓ Guo et al. (2023), von Oswald et al. (2023) では,Transformerは浅い層で特徴 量を抽出して深い層で線形回帰 (or 勾配法) を行っていることを実験的に確認.
  40. 予測誤差の評価 51 Empirical risk minimizer: Thm. (ICL risk bound; Kim,

    Nakamaki, TS, NeurIPS2024) 1. 2. Feature approximation error Pretraining generalization to estimate basis functions 3. In-context generalization gap 4. 𝑓𝑗 ∘ 𝑗=1 ∞ are “near” orthonormal Assumption (informal) Covering number of DNN (関数空間の複雑さ) (基底の近似誤差) (基底は大きすぎない) (基底はほぼ正規直交)
  41. 予測誤差の収束レート 52 • 例 (B-spline基底; 𝑓𝑗 ∘がB-spline→Besov/Sobolev空間): 𝑻が小さい: 記憶中の状況 𝑻が大きい:

    記憶が完了し汎化できる状況 → テスト時推論のスケーリング則 「文脈 (𝛽𝑡 )」の推定誤差 ミニマックス最適 「表現 (𝑓∘)」の推定誤差
  42. タスク多様性と性能の関係 53 [Raventós, Paul, Chen, Ganguli: Pretraining task diversity and

    the emergence of non-Bayesian in-context learning for regression. 2023 ] If # of pretraining tasks is enough, ICL coincides with optimal ridge regression.
  43. まとめ 54 事前学習 事後学習 テスト時 推論 事前学習データの質向上 Data augmentation アラインメント

    教師有りファイン チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought スケーリング則の質的な変化 • 事後学習 • テスト時推論 • Fine-tuning用データの自動生成 これらの計算量も考慮したスケーリング則への移行 • 人間を超える知能を持ち始めた場合,いかにしてスケーリング則を 継続させるか? ➢ 評価方法,データの生成法 • 一般的な論理推論の獲得はスケーリング則だけで可能か?