Upgrade to Pro
— share decks privately, control downloads, hide ads and more …
Speaker Deck
Features
Speaker Deck
PRO
Sign in
Sign up for free
Search
Search
JAXとFlaxを使って、ナウい機械学習をしたい
Search
Moriyama Naoto
February 27, 2021
Technology
8
10k
JAXとFlaxを使って、ナウい機械学習をしたい
JAXとFlaxの基本と、深層学習フレームワークの流れなど
Moriyama Naoto
February 27, 2021
Tweet
Share
Other Decks in Technology
See All in Technology
Data Hubグループ 紹介資料
sansan33
PRO
0
2.7k
コスト削減から「セキュリティと利便性」を担うプラットフォームへ
sansantech
PRO
3
1.2k
AI時代、1年目エンジニアの悩み
jin4
1
160
【インシデント入門】サイバー攻撃を受けた現場って何してるの?
shumei_ito
0
1.5k
ZOZOにおけるAI活用の現在 ~開発組織全体での取り組みと試行錯誤~
zozotech
PRO
4
4.8k
システムのアラート調査をサポートするAI Agentの紹介/Introduction to an AI Agent for System Alert Investigation
taddy_919
2
1.7k
なぜ今、コスト最適化(倹約)が必要なのか? ~AWSでのコスト最適化の進め方「目的編」~
htan
1
110
茨城の思い出を振り返る ~CDKのセキュリティを添えて~ / 20260201 Mitsutoshi Matsuo
shift_evolve
PRO
1
170
Mosaic AI Gatewayでコーディングエージェントを配るための運用Tips / JEDAI 2026 新春 Meetup! AIコーディング特集
genda
0
150
Claude_CodeでSEOを最適化する_AI_Ops_Community_Vol.2__マーケティングx_AIはここまで進化した.pdf
riku_423
1
410
Agile Leadership Summit Keynote 2026
m_seki
1
290
Amazon S3 Vectorsを使って資格勉強用AIエージェントを構築してみた
usanchuu
3
430
Featured
See All Featured
From Legacy to Launchpad: Building Startup-Ready Communities
dugsong
0
140
HDC tutorial
michielstock
1
350
Facilitating Awesome Meetings
lara
57
6.7k
Un-Boring Meetings
codingconduct
0
200
The SEO identity crisis: Don't let AI make you average
varn
0
64
The Mindset for Success: Future Career Progression
greggifford
PRO
0
230
Bioeconomy Workshop: Dr. Julius Ecuru, Opportunities for a Bioeconomy in West Africa
akademiya2063
PRO
1
54
Fireside Chat
paigeccino
41
3.8k
Breaking role norms: Why Content Design is so much more than writing copy - Taylor Woolridge
uxyall
0
160
エンジニアに許された特別な時間の終わり
watany
106
230k
No one is an island. Learnings from fostering a developers community.
thoeni
21
3.6k
HU Berlin: Industrial-Strength Natural Language Processing with spaCy and Prodigy
inesmontani
PRO
0
200
Transcript
JAXとFlaxを使って、 ナウい機械学習をしたい
自己紹介 - 森山直人(Twitter: @vimmode) - みらい翻訳株式会社でリサーチエンジニア - 日本語<->中国語の言語間の機械翻訳がメイン -
pythonと自然言語処理が好き - 深層学習はPyTorchを使うことが多いです - allennlp, flairあたりが好き - 発表内容は組織を代表するものではありません
今日のお話 - 深層学習フレームワークの複雑化は止まらない - 最も知見があるGoogleが一から設計した物が出来たらしい? - これは試せずにいられない! - (試した結果、とても素晴らしいものでした👍) -
本題の前に少し温故知新な話をしたい - 時間の都合で、網羅性よりも自分の感想が中心です
深層学習フレームワークの機能 - 必要最低限の機能 - 学習データからミニバッチを作成 - ニューラルネットワークの定義 - 予測値を計算し、誤差から自動微分でパラメータを更新
- 学習済みのモデルのシリアライズ - GPU, TPUなどのハードウェアアクセラレータ対応 - ニューラルネットワークの記述は主に2つのパラダイムがある - 以降のページで説明していきます
Define and run - Caffe, TensorFlow1などが該当 - 静的な計算グラフを作ってから、データを流し込む - 内部構造は直感的であり、理解しやすい
- pythonで計算グラフを定義に使うが、実行時はpythonは必要ない - 定義されたネットワークは実行時に変わることはないので、 デプロイと運用は安心・安全 - モバイルやエッジコンピューティングにも強い! - コーディングは深層学習への深い理解がないと直感的には書けない
Define by run - Chainer, PyTorchなどが該当 - 変数に計算元の情報を保持させ、それを辿っていくとネットワークが出来る(計算グ ラフの概念を意識させない) -
これにより、記述のしやすさが格段に向上 - ネットワークは入力が来て初めて作られる、永続化はパラメータの辞書 - (初期は)製品化でもpythonのruntimeが必要なため、言語由来の制約は多い - ネットワークが動的に変わることがあり、実運用で問題が発生し得る - 動作環境やメモリ、互換性など、デプロイはデータを流してみないとわからない
Define by run VS Define and run - 研究ではDefine by
runが支持され、製品運用ではDefine and runが支持 される構図に - Define and runであるTensorFlowは研究者から避けられることが多いが、 実務運用ではきわめて優秀 - Define by runであるPyTorchは、モデルをdefine and run スタイルである Caffe2に変換する機能を早期に採用したことで製品運用の課題を一定カ バー - Chainerとの強い差別化 - とはいえ、TensorFlowほど簡単ではない
異なるフレームワークの規格を統一したい - 記述が得意なフレームワークと、実装に優れたフレームワークの相互運用の ために、学習済みモデルの規格を統一させる => ONNX - PyTorch -> MXNetなど
- 一方で、フレームワーク間で数値表現に違いが存在する場合があり、ONNX を交えた変換で計算結果が同じにならない事がある! - 平均や分散などの統計計算は注意が必要 - ONNX専用のruntimeを利用する話もあるが、時間の都合でここでは割愛し ます
現在の二大勢力の課題(個人感) Tensorflow - TensorFlow2ではdefine by run形式でコーディングできるようになったものの、TensorFlow1の 基本設計を考えると、かなり無理な拡張をしたと察する - kerasやeagerなど、抽象化機能が多くて書き方が多様すぎる PyTorch
- 初期からCaffeに変換する設計だったこともあり、内部は複雑に - 細かいところはC++なので、内部実装把握はそこで力づきる - モデルとパラメータが密接に紐付いており、かつネットワークは計算時に確立されるため、量子化と いったパラメータ操作や、モデルの確実なシリアライズが複雑
JAX Googleが開発した行列演算+自動微分+XLAのライブラリ (もともとはautogradというライブラリを拡張して設計されたもの) - 行列演算 - NumPyのAPIと完全互換(ただし非同期処理) - 自動微分 -
自動微分をサポートすることで、JAXだけで簡単なニューラルネットワークが書ける - XLA - pythonで記載された線形代数関連の命令郡をまとめてハードウェアアクセラレータ向け にJITコンパイルし、一度で実行できるようにする。
JAXの好きなところ(個人感) - pure python! - デバックや内部実装の把握がしやすい - とにかく早い - ミニバッチ内の処理など関数をすべて
JITコンパイルすることで、全体の処理が高速化 - データのCPU -> GPU(TPU)間の移動がシームレスに出来る - 設計は関数型指向 - 行列のデータは基本的に変更不可 - インデックス/スライス経由の値変更やインプレース演算ができない設計 - 乱数生成はグローバルの乱数状態を参考にするのではなく、都度状態を生成
JAXサンプル- 行列の不変性 - 行列は基本不変であり、変更するには update関数経由で新たに生成する必要がある - (深層学習で相当な量のバグを防げる)
JAXサンプル- 乱数生成 JAXでは(関数型の性格から)共通のグローバル空間の状態を参照するのではなく、 都度乱数状態を作り出し、そこを参照するスタイル
JAXサンプル- 関数ベクトル化 - vmapを使うことで関数のベクトル化が簡単に行うことが出来る - 深層学習のミニバッチ構築で強い恩恵を受ける
JAXサンプル- JIT - 関数をJITコンパイルすることで高速化 - NumbaなどのJITコンパイルと違い、テンソル計算を主眼に設計されており、 機械学習関連の用途では (JITのために)関数を書き換える必要はほぼない
Flax GoogleによるJAXをベースに実装された深層学習フレームワーク - JAX開発者と近い距離で開発されており、一枚岩感がある - JAX以上に、強い関数型指向の性格を持つ 🌟 - 各種深層学習フレームワークの負債を研究しており、設計思想がアツい 設計思想(抜粋&意訳)
- 悪い抽象化や関数のオプションを増やすよりも、コードの複製を - ドキュメンテーションやテストが難しい部分は、設計を見直そう - 関数型スタイルは一部のユーザーを混乱させるが、高い利益をもたらす - 役に立たないエラーメッセージはバグ同然
Flaxの好きなところ(個人感) - 自動微分はJAXの機能を使うため、設計は大変見通しがよい - モデルとパラメータを明確に分離 - PyTorchのような、モデルとパラメータが一体になる構造ではない - モデルは初期化時に内部構造が確定したら、その後なにも変更されない -
パラメータの更新はoptimizerが管理し、シリアライズ時はoptimizerに対し て行う - モデルはpythonのdataclassと同じ構造 - 余計なものがないのは心理的にとても楽
サンプルコード-モデル定義 - setup経由でモデルの構造を確定。 - @nn.compactでcallをデコレートすればsetupは省略可能 - __call__を定義しているものの、学習時は直接モデルに入力を渡すことはしない(後述)
サンプルコード-初期化
サンプルコード-学習 - ネットワークで流れるデータは型はjax.DeviceArray - (JAX及びFlaxのコードは関数渡しやクロージャなどの関数型指向で書かれたコードが多 く、それに慣れておくと良い) - optimizer.targetに更新パラメータが格納される
JAXとFlaxの所感- 良い点 - JAXとFlaxの役割がそれぞれ明確に分割されているおかげで、両方の設計と APIはスッキリしている - 関数型指向な設計を取り入れたことで、既存のフレームワークのような過度な 抽象化はなく、透明性が高い - イミュータブルな設計も大変良い
- Flaxではモデルとパラメータの分離の考え方は素晴らしく、実装側としては納 得感が高い - なにより、書いていて楽しい!
JAXとFlaxの所感- 悩むところ - 関数型指向な設計により、フレームワーク設計としての美しさは十分だが、入 門者にとっての学習コストは高い - とはいえ、慣れれば可読性と生産性はかなり高い - 既存の資産は簡単には転用できない -
PyTorchとTensorFlow2間はある程度簡単だが、Flaxは少し複雑 - コミュニティがどれだけ大きくなるかは読めない - 世間一般ではPyTorchとTensorFlowはさほど強い不満は持たれていない - 実務観点で、既存のフレームワークからリプレイスするROIは難しいと思う
まとめ - JAXとFlaxについて、基本的な操作と個人的な感想を紹介しました - まだ日が浅いプロダクトであるものの、明確な思想で面白い - ここで紹介しきれなかった優れた要素はまだまだあります - 非同期実行やGPU/TPU上での処理など -
気になる方はぜひ公式ドキュメントをご一読してください - 個人的にはもう少し使いこなせるようになりたいと思う