Upgrade to Pro
— share decks privately, control downloads, hide ads and more …
Speaker Deck
Features
Speaker Deck
PRO
Sign in
Sign up for free
Search
Search
JAXとFlaxを使って、ナウい機械学習をしたい
Search
Moriyama Naoto
February 27, 2021
Technology
8
9.7k
JAXとFlaxを使って、ナウい機械学習をしたい
JAXとFlaxの基本と、深層学習フレームワークの流れなど
Moriyama Naoto
February 27, 2021
Tweet
Share
Other Decks in Technology
See All in Technology
大幅アップデートされたRagas v0.2をキャッチアップ
os1ma
2
530
あの日俺達が夢見たサーバレスアーキテクチャ/the-serverless-architecture-we-dreamed-of
tomoki10
0
450
サービスでLLMを採用したばっかりに振り回され続けたこの一年のあれやこれや
segavvy
2
410
2024年にチャレンジしたことを振り返るぞ
mitchan
0
140
Fanstaの1年を大解剖! 一人SREはどこまでできるのか!?
syossan27
2
170
10個のフィルタをAXI4-Streamでつなげてみた
marsee101
0
170
Amazon VPC Lattice 最新アップデート紹介 - PrivateLink も似たようなアップデートあったけど違いとは
bigmuramura
0
190
なぜCodeceptJSを選んだか
goataka
0
160
How to be an AWS Community Builder | 君もAWS Community Builderになろう!〜2024 冬 CB募集直前対策編?!〜
coosuke
PRO
2
2.8k
PHPからGoへのマイグレーション for DMMアフィリエイト
yabakokobayashi
1
170
5分でわかるDuckDB
chanyou0311
10
3.2k
生成AIをより賢く エンジニアのための RAG入門 - Oracle AI Jam Session #20
kutsushitaneko
4
220
Featured
See All Featured
RailsConf & Balkan Ruby 2019: The Past, Present, and Future of Rails at GitHub
eileencodes
132
33k
Exploring the Power of Turbo Streams & Action Cable | RailsConf2023
kevinliebholz
28
4.4k
Java REST API Framework Comparison - PWX 2021
mraible
PRO
28
8.3k
Practical Tips for Bootstrapping Information Extraction Pipelines
honnibal
PRO
10
810
Reflections from 52 weeks, 52 projects
jeffersonlam
347
20k
Dealing with People You Can't Stand - Big Design 2015
cassininazir
365
25k
Code Reviewing Like a Champion
maltzj
520
39k
Testing 201, or: Great Expectations
jmmastey
40
7.1k
"I'm Feeling Lucky" - Building Great Search Experiences for Today's Users (#IAC19)
danielanewman
226
22k
ピンチをチャンスに:未来をつくるプロダクトロードマップ #pmconf2020
aki_iinuma
111
49k
個人開発の失敗を避けるイケてる考え方 / tips for indie hackers
panda_program
95
17k
Agile that works and the tools we love
rasmusluckow
328
21k
Transcript
JAXとFlaxを使って、 ナウい機械学習をしたい
自己紹介 - 森山直人(Twitter: @vimmode) - みらい翻訳株式会社でリサーチエンジニア - 日本語<->中国語の言語間の機械翻訳がメイン -
pythonと自然言語処理が好き - 深層学習はPyTorchを使うことが多いです - allennlp, flairあたりが好き - 発表内容は組織を代表するものではありません
今日のお話 - 深層学習フレームワークの複雑化は止まらない - 最も知見があるGoogleが一から設計した物が出来たらしい? - これは試せずにいられない! - (試した結果、とても素晴らしいものでした👍) -
本題の前に少し温故知新な話をしたい - 時間の都合で、網羅性よりも自分の感想が中心です
深層学習フレームワークの機能 - 必要最低限の機能 - 学習データからミニバッチを作成 - ニューラルネットワークの定義 - 予測値を計算し、誤差から自動微分でパラメータを更新
- 学習済みのモデルのシリアライズ - GPU, TPUなどのハードウェアアクセラレータ対応 - ニューラルネットワークの記述は主に2つのパラダイムがある - 以降のページで説明していきます
Define and run - Caffe, TensorFlow1などが該当 - 静的な計算グラフを作ってから、データを流し込む - 内部構造は直感的であり、理解しやすい
- pythonで計算グラフを定義に使うが、実行時はpythonは必要ない - 定義されたネットワークは実行時に変わることはないので、 デプロイと運用は安心・安全 - モバイルやエッジコンピューティングにも強い! - コーディングは深層学習への深い理解がないと直感的には書けない
Define by run - Chainer, PyTorchなどが該当 - 変数に計算元の情報を保持させ、それを辿っていくとネットワークが出来る(計算グ ラフの概念を意識させない) -
これにより、記述のしやすさが格段に向上 - ネットワークは入力が来て初めて作られる、永続化はパラメータの辞書 - (初期は)製品化でもpythonのruntimeが必要なため、言語由来の制約は多い - ネットワークが動的に変わることがあり、実運用で問題が発生し得る - 動作環境やメモリ、互換性など、デプロイはデータを流してみないとわからない
Define by run VS Define and run - 研究ではDefine by
runが支持され、製品運用ではDefine and runが支持 される構図に - Define and runであるTensorFlowは研究者から避けられることが多いが、 実務運用ではきわめて優秀 - Define by runであるPyTorchは、モデルをdefine and run スタイルである Caffe2に変換する機能を早期に採用したことで製品運用の課題を一定カ バー - Chainerとの強い差別化 - とはいえ、TensorFlowほど簡単ではない
異なるフレームワークの規格を統一したい - 記述が得意なフレームワークと、実装に優れたフレームワークの相互運用の ために、学習済みモデルの規格を統一させる => ONNX - PyTorch -> MXNetなど
- 一方で、フレームワーク間で数値表現に違いが存在する場合があり、ONNX を交えた変換で計算結果が同じにならない事がある! - 平均や分散などの統計計算は注意が必要 - ONNX専用のruntimeを利用する話もあるが、時間の都合でここでは割愛し ます
現在の二大勢力の課題(個人感) Tensorflow - TensorFlow2ではdefine by run形式でコーディングできるようになったものの、TensorFlow1の 基本設計を考えると、かなり無理な拡張をしたと察する - kerasやeagerなど、抽象化機能が多くて書き方が多様すぎる PyTorch
- 初期からCaffeに変換する設計だったこともあり、内部は複雑に - 細かいところはC++なので、内部実装把握はそこで力づきる - モデルとパラメータが密接に紐付いており、かつネットワークは計算時に確立されるため、量子化と いったパラメータ操作や、モデルの確実なシリアライズが複雑
JAX Googleが開発した行列演算+自動微分+XLAのライブラリ (もともとはautogradというライブラリを拡張して設計されたもの) - 行列演算 - NumPyのAPIと完全互換(ただし非同期処理) - 自動微分 -
自動微分をサポートすることで、JAXだけで簡単なニューラルネットワークが書ける - XLA - pythonで記載された線形代数関連の命令郡をまとめてハードウェアアクセラレータ向け にJITコンパイルし、一度で実行できるようにする。
JAXの好きなところ(個人感) - pure python! - デバックや内部実装の把握がしやすい - とにかく早い - ミニバッチ内の処理など関数をすべて
JITコンパイルすることで、全体の処理が高速化 - データのCPU -> GPU(TPU)間の移動がシームレスに出来る - 設計は関数型指向 - 行列のデータは基本的に変更不可 - インデックス/スライス経由の値変更やインプレース演算ができない設計 - 乱数生成はグローバルの乱数状態を参考にするのではなく、都度状態を生成
JAXサンプル- 行列の不変性 - 行列は基本不変であり、変更するには update関数経由で新たに生成する必要がある - (深層学習で相当な量のバグを防げる)
JAXサンプル- 乱数生成 JAXでは(関数型の性格から)共通のグローバル空間の状態を参照するのではなく、 都度乱数状態を作り出し、そこを参照するスタイル
JAXサンプル- 関数ベクトル化 - vmapを使うことで関数のベクトル化が簡単に行うことが出来る - 深層学習のミニバッチ構築で強い恩恵を受ける
JAXサンプル- JIT - 関数をJITコンパイルすることで高速化 - NumbaなどのJITコンパイルと違い、テンソル計算を主眼に設計されており、 機械学習関連の用途では (JITのために)関数を書き換える必要はほぼない
Flax GoogleによるJAXをベースに実装された深層学習フレームワーク - JAX開発者と近い距離で開発されており、一枚岩感がある - JAX以上に、強い関数型指向の性格を持つ 🌟 - 各種深層学習フレームワークの負債を研究しており、設計思想がアツい 設計思想(抜粋&意訳)
- 悪い抽象化や関数のオプションを増やすよりも、コードの複製を - ドキュメンテーションやテストが難しい部分は、設計を見直そう - 関数型スタイルは一部のユーザーを混乱させるが、高い利益をもたらす - 役に立たないエラーメッセージはバグ同然
Flaxの好きなところ(個人感) - 自動微分はJAXの機能を使うため、設計は大変見通しがよい - モデルとパラメータを明確に分離 - PyTorchのような、モデルとパラメータが一体になる構造ではない - モデルは初期化時に内部構造が確定したら、その後なにも変更されない -
パラメータの更新はoptimizerが管理し、シリアライズ時はoptimizerに対し て行う - モデルはpythonのdataclassと同じ構造 - 余計なものがないのは心理的にとても楽
サンプルコード-モデル定義 - setup経由でモデルの構造を確定。 - @nn.compactでcallをデコレートすればsetupは省略可能 - __call__を定義しているものの、学習時は直接モデルに入力を渡すことはしない(後述)
サンプルコード-初期化
サンプルコード-学習 - ネットワークで流れるデータは型はjax.DeviceArray - (JAX及びFlaxのコードは関数渡しやクロージャなどの関数型指向で書かれたコードが多 く、それに慣れておくと良い) - optimizer.targetに更新パラメータが格納される
JAXとFlaxの所感- 良い点 - JAXとFlaxの役割がそれぞれ明確に分割されているおかげで、両方の設計と APIはスッキリしている - 関数型指向な設計を取り入れたことで、既存のフレームワークのような過度な 抽象化はなく、透明性が高い - イミュータブルな設計も大変良い
- Flaxではモデルとパラメータの分離の考え方は素晴らしく、実装側としては納 得感が高い - なにより、書いていて楽しい!
JAXとFlaxの所感- 悩むところ - 関数型指向な設計により、フレームワーク設計としての美しさは十分だが、入 門者にとっての学習コストは高い - とはいえ、慣れれば可読性と生産性はかなり高い - 既存の資産は簡単には転用できない -
PyTorchとTensorFlow2間はある程度簡単だが、Flaxは少し複雑 - コミュニティがどれだけ大きくなるかは読めない - 世間一般ではPyTorchとTensorFlowはさほど強い不満は持たれていない - 実務観点で、既存のフレームワークからリプレイスするROIは難しいと思う
まとめ - JAXとFlaxについて、基本的な操作と個人的な感想を紹介しました - まだ日が浅いプロダクトであるものの、明確な思想で面白い - ここで紹介しきれなかった優れた要素はまだまだあります - 非同期実行やGPU/TPU上での処理など -
気になる方はぜひ公式ドキュメントをご一読してください - 個人的にはもう少し使いこなせるようになりたいと思う