Slide 1

Slide 1 text

新しいスケーリング則と 学習理論 鈴木大慈 東京大学 大学院情報理工学系研究科 数理情報学専攻 理研AIP 2024年12月 1

Slide 2

Slide 2 text

鈴木大慈 2 所属 ➢ 東京大学大学院情報理工学系研究科数理情報学専攻・教授 ➢ 東大次世代知能科学研究センター研究部門研究者(研究知能部門) ➢ 理化学研究所 革新知能統合研究センター 深層学習理論チーム チームリーダー 専門 ➢ 機械学習の理論:数理統計学,統計的学習理論,確率的最適化 解釈可能性: 説明可能性,データの可視化,メンテナ ンスの容易化 各種テクニックの解析: アーキテクチャの解析,損失関数の設計, 最適化技法の解析 深層学習の原理解明: 「表現理論」「汎化誤差理論」「最適化 理論」 学習の本質解明: “良い”学習手法の特徴付け,統一理論, 深層学習を優越する方法論の提唱 応用 基礎 鈴木大慈 情報理工学系研究科 確率論 幾何学 関数解析 最適化理論 数学 数理統計 スパース推定 関連する機械学習理論 特徴抽出 カーネル法 深層学習の理論 主な研究内容 ➢ 深層学習を含む様々な学習機構に関する理論 ➢ 学習理論を通じた各種学習手法の汎化解析や学習アルゴリズムの収 束理論 ➢ 確率的最適化による大規模複雑な機械学習問題の効率的解法 著書/授賞 ➢『確率的最適化(機械学習プロフェッショナルシリーズ)』講談社,2015年 8月8日. ➢金森敬文,鈴木大慈,竹内一郎,佐藤一誠:『機械学習のための連続最適化 (機械学習プロフェッショナルシリーズ)』講談社,2016年12月7日. ➢文部科学大臣表彰・若手科学者賞「深層学習の原理解明に向けた統計的学習 理論の研究」.文部科学省,2020年4月7日. ➢第11回日本統計学会研究業績賞 (2017年度).2017年9月5日. ➢Satoshi Hayakawa and Taiji Suzuki:日本神経回路学会論文賞.日本神経回路学会, 2021年9月23日. ➢日本応用数理学会,ベストオーサー賞(論文部門).2019年9月4日. 主な活動場所 • 国内:IBIS, 統計連合大会 • 国外:NeurIPS, ICML, ICLR, ACML, ... (ACML steering committee member)

Slide 3

Slide 3 text

3 Transformerの理論 拡散モデルの理論

Slide 4

Slide 4 text

基盤モデルのスケーリング則とは? 新しい学習理論 今後の方向性 4

Slide 5

Slide 5 text

モデル訓練の計算量 5 [Sastry et al.: Computing Power and the Governance of Artificial Intelligence. arXiv:2402.08797] 訓練時計算量

Slide 6

Slide 6 text

6 Alex-net 2 × GTX 580 1.581 TFLOPS for FLOAT 1.5GB memory xAI 200,000 × H100/H200 ・・・ 800 TFLOPS for FP16 80GB memory 2012 2024 [参考] 産総研 ABCI 3.0: 6,128 × H200 ・・・

Slide 7

Slide 7 text

スケーリング則 7 Reducible loss [Kaplan et al.: Scaling Laws for Neural Language Models, 2020] [Henighan et al.: Scaling Laws for Autoregressive Generative Modeling, 2020] モデルサイズ固定 (基本的に訓練データサイズと思ってよい) [Brown et al.: Language Models are Few-Shot Learners, 2020] (GPT-3モデルの解析) log(予測精度)=−𝛼 log 𝑛 + log(𝐶)

Slide 8

Slide 8 text

基本的考え方 • スケーリング則は古典的な学習理論でも現れる. 8 真のモデル log 予測誤差 = − 𝑎 1+𝑎 log 𝑛 + log(𝐶) バイアス バリアンス 予測誤差 観測データ: 正則化学習法 (カーネル法) 最適なモデルサイズ 学習モデル ただし を用いて 𝑀 𝑀−𝑎 (正規直交系 in L2) バリアンス=モデルの次元/n バイアス=切り捨てた係数の二乗和

Slide 9

Slide 9 text

9 log 予測誤差 = − 𝑎 1+𝑎 log 𝑛 + log(𝐶) log(𝒏) log(予測誤差) モデルサイズM固定の予測誤差 最適なモデルサイズの予測誤差

Slide 10

Slide 10 text

深層学習 vs 浅層学習 (異なるスケーリング則) 学習すべき真の関数の形状によっては深層が有利になる 10 深 層 浅 層 縮小ランク回帰 特徴空間の次元 が低い状況は深 層学習が得意 区分滑らかな関数 不連続な関数の 推定は深層学習 が得意 Besov空間 滑らかさが非一 様な関数の推定 は深層学習が得 意 低次元データ データが低次元 部分空間上に分 布していたら深 層学習が有利 [Suzuki, 2019] [Schmidt-Hieber, 2019] [Nakada&Imaizumi, 2019][Chen et al., 2019][Suzuki&Nitanda, 2019] [Imaizumi&Fukumizu, 2019] 推 定 精 度

Slide 11

Slide 11 text

カーネル法と深層学習の違い 11 推定誤差 データサイズ 少ないデータサイズ では浅い学習が良い. 大きなデータサイズ では深層学習が良い. 深層学習 浅い学習 • スケーリング則自体は比較的古典的な理論からも導出できる. • しかし,これだけの「データ量」「モデルサイズ」「学習問題の複 雑さ」で実証されることはなかった.

Slide 12

Slide 12 text

基盤モデルのスケーリング則 12

Slide 13

Slide 13 text

(𝑌𝑡 ∼ 𝑋 𝑇−𝑡 ) 拡散モデルの統計理論 13 Stable diffusion, 2022. Forward process Backward process どちらも(ほぼ)ミニマックス最適 [Yang & Barron, 1999; Niles-Weed & Berthet, 2022]. 経験スコアマッチング推定量: (for any 𝛿 > 0). 定理 Let ෠ 𝑌 be the r.v. generated by the backward process w.r.t. Ƹ 𝑠, then (Estimator for 𝑊1 distance requires some modification) (𝑠: 密度関数の滑らかさ) [Kazusato Oko, Shunta Akiyama, Taiji Suzuki: Diffusion Models are Minimax Optimal Distribution Estimators. ICML2023, oral] (2% of all submissions)

Slide 14

Slide 14 text

Transformerの推定理論 14 定理 (推定誤差) ➢ 入力が無限次元でも多項式オーダーの収束レート. (ほぼミニマックス最適) ⋯ 𝑥−1 𝑥0 𝑥1 𝑥2 ⋯ ⋯ 𝑌−1 𝑌0 𝑌1 𝑌2 ⋯ ⋮ ⋮ ⋮ ⋮ Self-attention FNN Transformerの性質 • かなり広いトークン幅から重要な トークンを選べる. → 次元の呪い? • 入力に依存して重要なトークンを 選択できる. → 次元の呪いを回避! [Shokichi Takakura, Taiji Suzuki: Approximation and Estimation Ability of Transformers for Sequence-to-Sequence Functions with Infinite Dimensional Input. ICML2023]

Slide 15

Slide 15 text

Transformerの推定理論 15 定理 (推定誤差) ➢ 入力が無限次元でも多項式オーダーの収束レート. (ほぼミニマックス最適) ⋯ 𝑥−1 𝑥0 𝑥1 𝑥2 ⋯ ⋯ 𝑌−1 𝑌0 𝑌1 𝑌2 ⋯ ⋮ ⋮ ⋮ ⋮ Self-attention FNN Transformerの性質 • かなり広いトークン幅から重要な トークンを選べる. → 次元の呪い? • 入力に依存して重要なトークンを 選択できる. → 次元の呪いを回避! [Shokichi Takakura, Taiji Suzuki: Approximation and Estimation Ability of Transformers for Sequence-to-Sequence Functions with Infinite Dimensional Input. ICML2023] State-Space-Modelといった新しいモデル が提案されているが,少なくともトークン を取捨選択できる性質は必要

Slide 16

Slide 16 text

これからのスケーリング則 16

Slide 17

Slide 17 text

スケーリング則に乗っていれば良いのか? そう思っている人は多い (OpenAI, xAI). しかし,訓練データはインターネット上の ほぼ全てのデータを使い切っており これ以上スケールしない → スケーリング則の多様性も考慮すべき 17

Slide 18

Slide 18 text

学習レジームの多様化 18 事前学習 100%

Slide 19

Slide 19 text

学習レジームの多様化 19 事前学習 事前学習 事後学習 テスト時 推論 100% 45% 35% 20% 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought ここの重要度が上がっている 含,データの自動生成 ➢o1, AlphaProof, AlphaGeometry

Slide 20

Slide 20 text

データの重要性 • 最近のLLM開発は本質的な手法の改善は興味の対象 外 • いかに「質の高い」データを多く作れるかが性能の鍵 ➢Attentionの代替機構の研究よりもデータの質と量 → 原子力発電所の自社運用で電力を確保 20 1. 質の高いデータがあればそれを利用 (text book) 2. データを加工したりaugmentしたりして水増し 3. 強化学習で自動生成 生データ 生成データ

Slide 21

Slide 21 text

良質データの効果 • Pythonコード生成タスクで実験 • 教育的効果の高いデータをフィルタリング • GPT3.5による学習データ生成:diversityを担保した生成,演習問題も生成 (CodeExercises) ➢ StackOverflowのデータだけを用いるより,上記の方法で生成したCode Texbook並みの質のデータで学習すると精度がかなり上がる. (データの質は学習アルゴリズムの改善を簡単に上まわる) 21 [Gunasekar et al.: Textbooks Are All You Need. ICLR2024]

Slide 22

Slide 22 text

22 事前学習 事後学習 テスト時 推論 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought 事後学習によるアラインメント改善 ➢ DPO (Direct Preference Optimization) ➢ RLHF (Reinforcement Learning from Human Feedback) ➢ Supervised Fine Tuning: Instruction Tuning 自動データ生成による事後学習 ➢ 自動的なinstruction data生成 ➢ Self-Improvement (e.g., SPIN, Instruction Back-translation) ➢ RLAIF (AIフィードバックによる強化学習) ➢ Chain-of-thought generation

Slide 23

Slide 23 text

例: UltraFeedback 23 [Cui et al.: UltraFeedback: Boosting Language Models with Scaled AI Feedback. arXiv:2310.01377] • 質問に対して,複数のLLM (異なる種類,異なるモデルサイズ) か ら回答を生成. • GPT-4に回答の質をランキングさせる. • ランキングの結果は元モデルのfine-tuningに利用できる.

Slide 24

Slide 24 text

例: RLAIF 24 [Lee et al.: RLAIF vs. RLHF: Scaling Reinforcement Learning from Human Feedback with AI Feedback. ICML2024] • 別のLLMを用いてfeedback評価を生成 • 人間のfeedback (RLHF) より良好な性能 • 評価用LLMを自モデルと同じサイズにし ても性能の向上を確認 → AI自身の生成データによる自己改善の可 能性 評価の難しさ ≪ 生成の難しさ ※コーチは選手よりも優れたプレイヤーである必要はない.

Slide 25

Slide 25 text

25 事前学習 事後学習 テスト時 推論 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought

Slide 26

Slide 26 text

(𝑌𝑡 ∼ 𝑋 𝑇−𝑡 ) 拡散モデル 26 逆過程 (target distribution) (標準正規分布) 順過程 (Wasserstein勾配流) KL-divergence from 𝜇∗ : 標準正規分布に収束 N(0,I) (Wasserstein GF) 順過程: 逆過程: : 順過程のたどった軌跡を逆にたどる

Slide 27

Slide 27 text

Direct Distribution Optimization 27 • 𝜇ref : 事前学習された参照モデル 拡散モデルで与えられている. ➢ 複雑な分布も生成可能 ➢ 密度関数は評価できない ➢ サンプリングしかできない → どうやって最適化するか? 例: DPO, ベイズフィ ルタリング

Slide 28

Slide 28 text

Direct Distribution Optimization 28 • 𝜇ref : 事前学習された参照モデル 拡散モデルで与えられている. ➢ 複雑な分布も生成可能 ➢ 密度関数は評価できない ➢ サンプリングしかできない → どうやって最適化するか? 例: DPO, ベイズフィ ルタリング • これまでの手法:上界の最小化 • 本研究:直接的に拡散モデルの最適化を実現

Slide 29

Slide 29 text

Direct Preference Optimization • DPO: fine-tuning method for generative models such as LLMs. 29 Fine-tuning data: • For each prompt 𝑐 ∼ 𝑝(𝑐), generate 𝑦1 , 𝑦2 ∼ 𝑝SFT 𝑦 𝑐 (independently). • Get preference 𝑦𝑤 ≻ 𝑦𝑙 between 𝑦1 , 𝑦2 . (human feedback) 1. 2. (Bradley-Terry model) (computation of normalization constant is not required) [Rafailov et al. 2024]

Slide 30

Slide 30 text

双対平均加法 30 双対平均化法 (Dual averaging method) • 𝜇ref : 参照モデル (事前学習モデル) where For 𝑘 = 1, … , 𝐾 − 1: 𝑂(1/𝐾) convergence 最適な分布と参照分布の密度比が求まる [Kawata, Oko, Nitanda, Suzuki: Direct Distributional Optimization for Provable Alignment of Diffusion Models. 2024]

Slide 31

Slide 31 text

Doob h-transform 31 Q: ෝ 𝝁 ∝ 𝐞𝐱𝐩(−ෝ 𝒈) 𝝁𝐫𝐞𝐟 からどうサンプリングするか? Doob ℎ-Transform (Doob, 1957; Rogers & Williams, 2000): 修正項 reference (𝝁𝐫𝐞𝐟 ) Corrected process Ƹ 𝜇 𝑌0 𝑌ത 𝑇 参照モデルの逆過程 (参照モデル) (Gaussian distribution) 修正された拡散モデル (最適モデル) 修正 𝜇ref 確率論

Slide 32

Slide 32 text

Numerical comparison 32 (既存手法との比較) 理論保証あり

Slide 33

Slide 33 text

33 事前学習 事後学習 テスト時 推論 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought

Slide 34

Slide 34 text

Test-time inference 34

Slide 35

Slide 35 text

35 • 解候補をモンテカルロサンプリングして良いものだけをピックアップ • 解候補の「良さ」を測るProcess Reward Verifier (PRM) も学習 → 枝刈り・推論を高速化 テストタイムの時間を多くとった方が性能向上 (たくさんサンプリングし た方が良い出力を見つけられる) → 見つかった良い結果を用いてモデルをfine-tuningすることも可能 → 候補の生成に費やした計算量も考慮すべき:新しいスケーリング則

Slide 36

Slide 36 text

36 事前学習 事後学習 テスト時 推論 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought

Slide 37

Slide 37 text

思考連鎖 (Chain-of-Thought) 37 • 思考の連鎖を訓練データに用いる • 思考の連鎖を例示してin-context learningさせる • 思考の連鎖を出力させる(e.g., think step by step) →精度向上,解釈性向上 • 結果だけでなく思考過程も出力/入力 [Wei et al.: Chain-of-Thought Prompting Elicits Reasoning in Large Language Models. 2022]

Slide 38

Slide 38 text

OpenAI O1モデル 38

Slide 39

Slide 39 text

数学への応用 39 • AlphaGeometry (DeepMind, 2023) • AlphaProof, AlphaGeometry2 (DeepMind, 2024) • 定理証明系の言語を利用:Lean [de Moura et al., 2015], Coq [Barras et al., 1997], Isabelle [Nipkow et al., 2002] • 定理を「形式化」して,証明をプログラムとして書き下す. • 証明の真偽は自動的に判定可能(単発の回答はもちろん真偽判定可能) ➢ 思考連鎖の訓練データを収集して学習 ➢ 思考連鎖を自動生成して証明が通ったものを訓練データにして学習

Slide 40

Slide 40 text

思考連鎖の理論 40 [Kim&Suzuki: Transformers provably solve parity Efficiently with chain of thought. arXiv:2410.08633, 2024] 𝑘-パリティ問題 ➢ 𝑥 = (𝑥1 , … , 𝑥𝑑 ) ∼ Unif( −1,1 𝑑) ➢ 𝑦 = 𝑥𝑖1 𝑥𝑖2 … 𝑥𝑖𝑘 = ς𝑗∈𝑝 𝑥𝑗 𝒅次元入力のうち𝒌個のみが出力に関係ある. 𝒙𝒊, 𝒚𝒊 𝒊=𝟏 𝒏 𝑛個のデータから意味のある𝑘個の座標を特定したい. 普通に勾配法で解こうとすると 𝑛 = 𝑂(𝑑𝑘−1) くらいのデータ数が必要. Q: CoTで効率化できるか? NNによる学習法が多く研究されている: Abbe et al. (2023); Refinetti et al. (2021); Ben Arous et al. (2022); Damian et al. (2022); Suzuki, Wu, Oko, Nitanda (2023).

Slide 41

Slide 41 text

𝑘-パリティ問題の階層構造 41 : 中間結果を教師データとして与えて問題を分割 Transformerの入力の各トークンに • 入力の各要素 𝑥𝑖 (𝑖 = 1, … , 𝑑) • 中間結果 𝑥𝑗 (𝑗 ≥ 𝑑 + 1) を入れていく. ➢ 𝑗 ≥ 𝑑 + 1に対してnext token prediction ➢ 最後のトークンが𝑦の予測値 各トークンの予測問題ごとに誤差逆伝 搬でネットワークを学習

Slide 42

Slide 42 text

結果 42 思考連鎖の途中過程を含めた学習データを用いてTransformer を学習する.すると,データ数 𝑛 = 𝑂(𝑑2+𝜖) で学習できる: ො 𝑦test − 𝑦test ∞ ≤ 𝑂(𝑑−𝜖/8). 定理 w/o CoT with CoT 必要な学習データ数 𝑂(𝑑𝑘−1) 𝑂(𝑑2+𝜖) データ数の比較

Slide 43

Slide 43 text

結果 43 思考連鎖の途中過程を含めた学習データを用いてTransformer を学習する.すると,データ数 𝑛 = 𝑂(𝑑2+𝜖) で学習できる: ො 𝑦test − 𝑦test ∞ ≤ 𝑂(𝑑−𝜖/8). 定理 w/o CoT with CoT 必要な学習データ数 𝑂(𝑑𝑘−1) 𝑂(𝑑2+𝜖) データ数の比較 連鎖思考の方法 (いくつかの異なる実装法)

Slide 44

Slide 44 text

44 事前学習 事後学習 テスト時 推論 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought

Slide 45

Slide 45 text

In-context learning 45 In-Context Learning (ICL) [Brown et al., 2020]. 良く事前学習されたモデルはテスト時のin-context learning (文脈内学習) でも良い性能を示す. Question ChatGPT

Slide 46

Slide 46 text

In-context learning 46 Question ChatGPT In-Context Learning (ICL) [Brown et al., 2020]. 良く事前学習されたモデルはテスト時のin-context learning (文脈内学習) でも良い性能を示す.

Slide 47

Slide 47 text

Fine tuning method 47 通常のファインチューニングはパラメータを更新する. In-context learningでは更新しない. (e.g., RLHF) ※ 最近では,test time trainingと言って,LoRAを用いてin-context learning時に少 しファインチューニングする方法も提案されている [Akyurek et al.

Slide 48

Slide 48 text

In-Context learning 48 Pretraining Test task Example Query Query Examples ICLはモデルパラメータを更新しない → メタ学習,学習の学習 事前学習時にたくさんのタスク を観測しておく. → タスク汎化

Slide 49

Slide 49 text

In-context learningの数学的定式化 49 Pretraining (𝑻 tasks): ⋯ × 𝑇 ➢ We observe pretraining task data 𝑇 times. ➢ Each task has 𝑛 data. Test task (In-context learning): ⋯ Predict • The true functions 𝐹𝑡 ∘ are different across different tasks. • 𝐹𝑡 ∘ is generated randomly for each task. Model: 𝑡 = 1, … , 𝑇: Task index

Slide 50

Slide 50 text

Transformerの役割 • 事前学習 (pretraining): 特徴量 (表現) を学習 [𝑓∘] ➢Fourier, B-Spline ➢文脈 (𝑡) に非依存 ➢データを表現する「最も効率的」な基底を学習 → 中間層 • 文脈内学習 (in-context): 係数を学習 [𝛽𝑡 ] ➢文脈 (𝑡) に依存 ➢例示から現在の文脈𝛽𝑡 を推定 → 最終層のAttention 50 ✓ Guo et al. (2023), von Oswald et al. (2023) では,Transformerは浅い層で特徴 量を抽出して深い層で線形回帰 (or 勾配法) を行っていることを実験的に確認.

Slide 51

Slide 51 text

予測誤差の評価 51 Empirical risk minimizer: Thm. (ICL risk bound; Kim, Nakamaki, TS, NeurIPS2024) 1. 2. Feature approximation error Pretraining generalization to estimate basis functions 3. In-context generalization gap 4. 𝑓𝑗 ∘ 𝑗=1 ∞ are “near” orthonormal Assumption (informal) Covering number of DNN (関数空間の複雑さ) (基底の近似誤差) (基底は大きすぎない) (基底はほぼ正規直交)

Slide 52

Slide 52 text

予測誤差の収束レート 52 • 例 (B-spline基底; 𝑓𝑗 ∘がB-spline→Besov/Sobolev空間): 𝑻が小さい: 記憶中の状況 𝑻が大きい: 記憶が完了し汎化できる状況 → テスト時推論のスケーリング則 「文脈 (𝛽𝑡 )」の推定誤差 ミニマックス最適 「表現 (𝑓∘)」の推定誤差

Slide 53

Slide 53 text

タスク多様性と性能の関係 53 [Raventós, Paul, Chen, Ganguli: Pretraining task diversity and the emergence of non-Bayesian in-context learning for regression. 2023 ] If # of pretraining tasks is enough, ICL coincides with optimal ridge regression.

Slide 54

Slide 54 text

まとめ 54 事前学習 事後学習 テスト時 推論 事前学習データの質向上 Data augmentation アラインメント 教師有りファイン チューニング Preference optimization RLHF, RLAIF Monte-Carlo Search In-context learning (Few-shot prompting) Chain-of-thought スケーリング則の質的な変化 • 事後学習 • テスト時推論 • Fine-tuning用データの自動生成 これらの計算量も考慮したスケーリング則への移行 • 人間を超える知能を持ち始めた場合,いかにしてスケーリング則を 継続させるか? ➢ 評価方法,データの生成法 • 一般的な論理推論の獲得はスケーリング則だけで可能か?