Upgrade to Pro — share decks privately, control downloads, hide ads and more …

JAXとFlaxを使って、ナウい機械学習をしたい

Moriyama Naoto
February 27, 2021

 JAXとFlaxを使って、ナウい機械学習をしたい

JAXとFlaxの基本と、深層学習フレームワークの流れなど

Moriyama Naoto

February 27, 2021
Tweet

Other Decks in Technology

Transcript

  1. JAXとFlaxを使って、

    ナウい機械学習をしたい


    View full-size slide

  2. 自己紹介 

    - 森山直人(Twitter: @vimmode)
    - みらい翻訳株式会社でリサーチエンジニア
    - 日本語<->中国語の言語間の機械翻訳がメイン
    - pythonと自然言語処理が好き
    - 深層学習はPyTorchを使うことが多いです
    - allennlp, flairあたりが好き
    - 発表内容は組織を代表するものではありません

    View full-size slide

  3. 今日のお話

    - 深層学習フレームワークの複雑化は止まらない
    - 最も知見があるGoogleが一から設計した物が出来たらしい?
    - これは試せずにいられない!
    - (試した結果、とても素晴らしいものでした👍)
    - 本題の前に少し温故知新な話をしたい
    - 時間の都合で、網羅性よりも自分の感想が中心です

    View full-size slide

  4. 深層学習フレームワークの機能 

    - 必要最低限の機能
    - 学習データからミニバッチを作成
    - ニューラルネットワークの定義
    - 予測値を計算し、誤差から自動微分でパラメータを更新
    - 学習済みのモデルのシリアライズ
    - GPU, TPUなどのハードウェアアクセラレータ対応
    - ニューラルネットワークの記述は主に2つのパラダイムがある
    - 以降のページで説明していきます

    View full-size slide

  5. Define and run

    - Caffe, TensorFlow1などが該当
    - 静的な計算グラフを作ってから、データを流し込む
    - 内部構造は直感的であり、理解しやすい
    - pythonで計算グラフを定義に使うが、実行時はpythonは必要ない
    - 定義されたネットワークは実行時に変わることはないので、
    デプロイと運用は安心・安全
    - モバイルやエッジコンピューティングにも強い!
    - コーディングは深層学習への深い理解がないと直感的には書けない

    View full-size slide

  6. Define by run

    - Chainer, PyTorchなどが該当
    - 変数に計算元の情報を保持させ、それを辿っていくとネットワークが出来る(計算グ
    ラフの概念を意識させない)
    - これにより、記述のしやすさが格段に向上
    - ネットワークは入力が来て初めて作られる、永続化はパラメータの辞書
    - (初期は)製品化でもpythonのruntimeが必要なため、言語由来の制約は多い
    - ネットワークが動的に変わることがあり、実運用で問題が発生し得る
    - 動作環境やメモリ、互換性など、デプロイはデータを流してみないとわからない

    View full-size slide

  7. Define by run VS Define and run

    - 研究ではDefine by runが支持され、製品運用ではDefine and runが支持
    される構図に
    - Define and runであるTensorFlowは研究者から避けられることが多いが、
    実務運用ではきわめて優秀
    - Define by runであるPyTorchは、モデルをdefine and run スタイルである
    Caffe2に変換する機能を早期に採用したことで製品運用の課題を一定カ
    バー
    - Chainerとの強い差別化
    - とはいえ、TensorFlowほど簡単ではない

    View full-size slide

  8. 異なるフレームワークの規格を統一したい

    - 記述が得意なフレームワークと、実装に優れたフレームワークの相互運用の
    ために、学習済みモデルの規格を統一させる => ONNX
    - PyTorch -> MXNetなど
    - 一方で、フレームワーク間で数値表現に違いが存在する場合があり、ONNX
    を交えた変換で計算結果が同じにならない事がある!
    - 平均や分散などの統計計算は注意が必要
    - ONNX専用のruntimeを利用する話もあるが、時間の都合でここでは割愛し
    ます

    View full-size slide

  9. 現在の二大勢力の課題(個人感)

    Tensorflow
    - TensorFlow2ではdefine by run形式でコーディングできるようになったものの、TensorFlow1の
    基本設計を考えると、かなり無理な拡張をしたと察する
    - kerasやeagerなど、抽象化機能が多くて書き方が多様すぎる
    PyTorch
    - 初期からCaffeに変換する設計だったこともあり、内部は複雑に
    - 細かいところはC++なので、内部実装把握はそこで力づきる
    - モデルとパラメータが密接に紐付いており、かつネットワークは計算時に確立されるため、量子化と
    いったパラメータ操作や、モデルの確実なシリアライズが複雑

    View full-size slide

  10. JAX

    Googleが開発した行列演算+自動微分+XLAのライブラリ
    (もともとはautogradというライブラリを拡張して設計されたもの)
    - 行列演算
    - NumPyのAPIと完全互換(ただし非同期処理)
    - 自動微分
    - 自動微分をサポートすることで、JAXだけで簡単なニューラルネットワークが書ける
    - XLA
    - pythonで記載された線形代数関連の命令郡をまとめてハードウェアアクセラレータ向け
    にJITコンパイルし、一度で実行できるようにする。

    View full-size slide

  11. JAXの好きなところ(個人感)

    - pure python!
    - デバックや内部実装の把握がしやすい
    - とにかく早い
    - ミニバッチ内の処理など関数をすべて
    JITコンパイルすることで、全体の処理が高速化
    - データのCPU -> GPU(TPU)間の移動がシームレスに出来る
    - 設計は関数型指向
    - 行列のデータは基本的に変更不可
    - インデックス/スライス経由の値変更やインプレース演算ができない設計
    - 乱数生成はグローバルの乱数状態を参考にするのではなく、都度状態を生成

    View full-size slide

  12. JAXサンプル- 行列の不変性

    - 行列は基本不変であり、変更するには update関数経由で新たに生成する必要がある
    - (深層学習で相当な量のバグを防げる)

    View full-size slide

  13. JAXサンプル- 乱数生成

    JAXでは(関数型の性格から)共通のグローバル空間の状態を参照するのではなく、
    都度乱数状態を作り出し、そこを参照するスタイル

    View full-size slide

  14. JAXサンプル- 関数ベクトル化

    - vmapを使うことで関数のベクトル化が簡単に行うことが出来る
    - 深層学習のミニバッチ構築で強い恩恵を受ける

    View full-size slide

  15. JAXサンプル- JIT

    - 関数をJITコンパイルすることで高速化
    - NumbaなどのJITコンパイルと違い、テンソル計算を主眼に設計されており、
    機械学習関連の用途では (JITのために)関数を書き換える必要はほぼない

    View full-size slide

  16. Flax

    GoogleによるJAXをベースに実装された深層学習フレームワーク
    - JAX開発者と近い距離で開発されており、一枚岩感がある
    - JAX以上に、強い関数型指向の性格を持つ 🌟
    - 各種深層学習フレームワークの負債を研究しており、設計思想がアツい
    設計思想(抜粋&意訳)
    - 悪い抽象化や関数のオプションを増やすよりも、コードの複製を
    - ドキュメンテーションやテストが難しい部分は、設計を見直そう
    - 関数型スタイルは一部のユーザーを混乱させるが、高い利益をもたらす
    - 役に立たないエラーメッセージはバグ同然

    View full-size slide

  17. Flaxの好きなところ(個人感)

    - 自動微分はJAXの機能を使うため、設計は大変見通しがよい
    - モデルとパラメータを明確に分離
    - PyTorchのような、モデルとパラメータが一体になる構造ではない
    - モデルは初期化時に内部構造が確定したら、その後なにも変更されない
    - パラメータの更新はoptimizerが管理し、シリアライズ時はoptimizerに対し
    て行う
    - モデルはpythonのdataclassと同じ構造
    - 余計なものがないのは心理的にとても楽

    View full-size slide

  18. サンプルコード-モデル定義

    - setup経由でモデルの構造を確定。
    - @nn.compactでcallをデコレートすればsetupは省略可能
    - __call__を定義しているものの、学習時は直接モデルに入力を渡すことはしない(後述)

    View full-size slide

  19. サンプルコード-初期化


    View full-size slide

  20. サンプルコード-学習

    - ネットワークで流れるデータは型はjax.DeviceArray
    - (JAX及びFlaxのコードは関数渡しやクロージャなどの関数型指向で書かれたコードが多
    く、それに慣れておくと良い)
    - optimizer.targetに更新パラメータが格納される

    View full-size slide

  21. JAXとFlaxの所感- 良い点

    - JAXとFlaxの役割がそれぞれ明確に分割されているおかげで、両方の設計と
    APIはスッキリしている
    - 関数型指向な設計を取り入れたことで、既存のフレームワークのような過度な
    抽象化はなく、透明性が高い
    - イミュータブルな設計も大変良い
    - Flaxではモデルとパラメータの分離の考え方は素晴らしく、実装側としては納
    得感が高い
    - なにより、書いていて楽しい!

    View full-size slide

  22. JAXとFlaxの所感- 悩むところ

    - 関数型指向な設計により、フレームワーク設計としての美しさは十分だが、入
    門者にとっての学習コストは高い
    - とはいえ、慣れれば可読性と生産性はかなり高い
    - 既存の資産は簡単には転用できない
    - PyTorchとTensorFlow2間はある程度簡単だが、Flaxは少し複雑
    - コミュニティがどれだけ大きくなるかは読めない
    - 世間一般ではPyTorchとTensorFlowはさほど強い不満は持たれていない
    - 実務観点で、既存のフレームワークからリプレイスするROIは難しいと思う

    View full-size slide

  23. まとめ

    - JAXとFlaxについて、基本的な操作と個人的な感想を紹介しました
    - まだ日が浅いプロダクトであるものの、明確な思想で面白い
    - ここで紹介しきれなかった優れた要素はまだまだあります
    - 非同期実行やGPU/TPU上での処理など
    - 気になる方はぜひ公式ドキュメントをご一読してください
    - 個人的にはもう少し使いこなせるようになりたいと思う

    View full-size slide