Slide 1

Slide 1 text

中井 悦司 Google Cloud, Solutions Architect JAX / Flax 入門 ※ コミュニティイベント「 [第1回] ML女子部 JAX入門」で使用した資料です。

Slide 2

Slide 2 text

中井悦司 / Etsuji Nakai Solutions Architect, Google Cloud $ who am i 新発売!

Slide 3

Slide 3 text

Google の研究部門で活用が広がる JAX 3 https://cloud.google.com/blog/ja/topics/developers-practitioners/evojax-bringing-power-neuroevolution-solve-your-problems

Slide 4

Slide 4 text

JAX とは? 4 ● JAX 独自の機能の例 ○ JIT コンパイラ:Python で定義した関数を GPU / TPU での計算に最適化されたバイナリーに事前コ ンパイル ○ 自動微分:Python で定義した関数の微分(勾配ベクトル)を計算 ● Google の AI 研究チーム(Google Brain)が開発した数値 計算ライブラリー ● 機械学習のベースとなる計算処理を GPU / TPU で高速に 実行可能 ● NumPy とほぼ同じ関数(メソッド)を用意しており、 NumPy を知っていればすぐに使える

Slide 5

Slide 5 text

JAX は上位ライブラリーと組み合わせて使用 5 ● JAX はあくまでも数値計算ライブラリーなので、それだけでは 機械学習の処理(ニューラルネットワークの定義、勾配降下法 による学習処理など)はできない ● さまざまな上位ライブラリーと組み合わせて使用する ○ Flax:ニューラルネットワークの構築と学習プロセスの管 理機能を提供 ○ Optax:誤差関数や勾配降下法のアルゴリズムをモ ジュールとして提供 ○ などなど ● 今回は、JAX / Flax / Optax の組み合わせ例を紹介

Slide 6

Slide 6 text

6 GPU を用いた高速な数値計算処理 微分計算機能 ① モデルの定義 ② 誤差関数の定義 誤差関数の 勾配ベクトルを計算 パラメーターの値 ③ 学習アルゴリズム 勾配ベクトル 勾配ベクトルの値を用いて パラメーターを更新 Optax Flax JAX パラメーターの初期値を ランダムに設定 JAX / Flax / Optax の役割分担

Slide 7

Slide 7 text

JAX 入門

Slide 8

Slide 8 text

JAX で勾配ベクトルを計算! 8

Slide 9

Slide 9 text

普通の Python の関数を微分可能 9 import jax from jax import numpy as jnp @jax.jit def h(x1, x2): z = 2 * jnp.sin(x1) * jnp.sin(x2) return z nabla_h = jax.jit(jax.grad(h, (0, 1))) plot3d((-jnp.pi, jnp.pi), (-jnp.pi, jnp.pi), h, nabla_h) 2変数関数 を Python の関数として定義 勾配ベクトル       を計算

Slide 10

Slide 10 text

JIT コンパイラによる事前コンパイル 10 @jax.jit def my_function(x1, x2, ...): return z … 1010111001010… Python で定義した関数 入力データ (x1, x2, ...) 計算結果 z 高速実行可能な バイナリーコードに まとめて変換 1 つの処理ごとに対応する バイナリーコードを呼び出す インタープリターで実行する場合 事前コンパイル機能を用いる場合

Slide 11

Slide 11 text

11 import jax from jax import numpy as jnp @jax.jit def h(x1, x2): y = 2 * jnp.sin(x1) * jnp.sin(x2) return y nabla_h = jax.jit(jax.grad(h, (0, 1))) plot3d((-jnp.pi, jnp.pi), (-jnp.pi, jnp.pi), h, nabla_h) JIT コンパイラによる事前コンパイル 直後の関数に事前 コンパイルを適用 引数で指定した関数 事前コンパイルを適用

Slide 12

Slide 12 text

Flax / Optax 入門

Slide 13

Slide 13 text

Flax によるニューラルネットワークの定義 13 class SingleLayerCNN(nn.Module): @nn.compact def __call__(self, x, get_logits=False): x = x.reshape([-1, 28, 28, 1]) x = nn.Conv(features=16, kernel_size=(5, 5)(x) x = nn.relu(x) x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) x = x.reshape([x.shape[0], -1]) # Flatten x = nn.Dense(features=1024)(x) x = nn.relu(x) x = nn.Dense(features=10)(x) if get_logits: return x x = nn.softmax(x) return x 条件分岐を記述可能

Slide 14

Slide 14 text

誤差関数と学習ステップを個別に関数として実装 14 @jax.jit def loss_fn(params, state, inputs, labels): logits = state.apply_fn({'params': params}, inputs, get_logits=True) loss = optax.softmax_cross_entropy(logits, labels).mean() return loss @jax.jit def train_step(state, inputs, labels): loss, grads = jax.value_and_grad(loss_fn)( state.params, state, inputs, labels) new_state = state.apply_gradients(grads=grads) return new_state, loss モデルの予測結果(ロジットの値)を取得 誤差関数(クロスエントロピーを計算) 勾配降下法のアルゴリズムで パラメーターを(1回だけ)修正 勾配ベクトルを計算 誤差関数 1 回の学習ステップ

Slide 15

Slide 15 text

15 GPU を用いた高速な数値計算処理 微分計算機能 ① モデルの定義 ② 誤差関数の定義 誤差関数の 勾配ベクトルを計算 パラメーターの値 ③ 学習アルゴリズム 勾配ベクトル 勾配ベクトルの値を用いて パラメーターを更新 Optax Flax JAX パラメーターの初期値を ランダムに設定 JAX / Flax / Optax の役割分担

Slide 16

Slide 16 text

JAX / Flax のメリット(個人の感想) 16 ● 誤差関数や学習ステップを自分で関数として実装する必要がある ○ Keras のように「fit() メソッドで一発!」というわけではない ● 誤差関数や学習ステップの中身がブラックボックスにならないので、細かなチューニングや 「ちょっと凝った独自の実装」が手軽にできる ○ 定番のモデルを使った実務用途よりは、さまざまなチューニングや実装上の工夫を試す研 究・開発用途に適している ● 独自の fit() メソッドをモジュール化して使い回すことももちろん可能 ○ 学習状態を管理する TrainState オブジェクトなど、Flax が提供する機能を活用

Slide 17

Slide 17 text

JAX / Flax のその他の特徴 17 ● モデルのオブジェクトは学習中のパラメーター値を含まない ○ パラメーター値は、ディクショナリーで管理 ○ パラメーターの一部を書き換えたり、学習済みのパラメーター値を取り出して流用するなど が容易にできる @jax.jit def loss_fn(params, state, inputs, labels): logits = state.apply_fn({'params': params}, inputs, get_logits=True) loss = optax.softmax_cross_entropy(logits, labels).mean() return loss モデルを呼び出す際に パラメーター値を入力

Slide 18

Slide 18 text

JAX / Flax による転移学習の例 18

Slide 19

Slide 19 text

ここまでのまとめ(復習)

Slide 20

Slide 20 text

JAX とは? 20 ● JAX 独自の機能の例 ○ JIT コンパイラ:Python で定義した関数を GPU / TPU での計算に最適化されたバイナリーに事前コ ンパイル ○ 自動微分:Python で定義した関数の微分(勾配ベクトル)を計算 ● Google の AI 研究チーム(Google Brain)が開発した数値 計算ライブラリー ● 機械学習のベースとなる計算処理を GPU / TPU で高速に 実行可能 ● NumPy とほぼ同じ関数(メソッド)を用意しており、 NumPy を知っていればすぐに使える

Slide 21

Slide 21 text

JAX は上位ライブラリーと組み合わせて使用 21 ● JAX はあくまでも数値計算ライブラリーなので、それだけでは 機械学習の処理(ニューラルネットワークの定義、勾配降下法 による学習処理など)はできない ● さまざまな上位ライブラリーと組み合わせて使用する ○ Flax:ニューラルネットワークの構築と学習プロセスの管 理機能を提供 ○ Optax:誤差関数や勾配降下法のアルゴリズムをモ ジュールとして提供 ○ などなど ● 今回は、JAX / Flax / Optax の組み合わせ例を紹介

Slide 22

Slide 22 text

22 GPU を用いた高速な数値計算処理 微分計算機能 ① モデルの定義 ② 誤差関数の定義 誤差関数の 勾配ベクトルを計算 パラメーターの値 ③ 学習アルゴリズム 勾配ベクトル 勾配ベクトルの値を用いて パラメーターを更新 Optax Flax JAX パラメーターの初期値を ランダムに設定 JAX / Flax / Optax の役割分担

Slide 23

Slide 23 text

実際のコードを見てみましょう!

Slide 24

Slide 24 text

ノートブックの開き方 24 ● GitHub にアクセス ○ https://github.com/enakai00/JAX_workshop ● ノートブックのリンクをクリック ● [Open in Colab] のボタンをクリック

Slide 25

Slide 25 text

1_Introduction_to_JAX.ipynb:JAX の基本機能 25 ● 以下の内容を学びます ○ DeviceArray オブジェクトと事前コンパイル機能の使い方 ○ JAX における乱数の扱い方 ○ JAX の微分機能

Slide 26

Slide 26 text

2_MNIST_Softmax_Estimation.ipynb:  多項分類器によるMNISTデータセットの分類 26 ● 以下の手順を学びます ○ Flax を用いたモデルの定義方法 ○ パラメーターの初期値の生成と TrainState オブジェクトの作成 ○ 誤差関数の定義 ○ 勾配降下法によるパラメーターの修正を 1 回だけ行う関数を定義 ○ パラメーターの修正を 1 エポック分繰り返す関数を定義 ○ 指定回数のエポック分の学習を行う関数 fit() を定義 ○ 関数 fit() を実行して、学習を実施

Slide 27

Slide 27 text

ハンズオン! 27 ● 質問は、チャットに入れてください。 ● 時々、QA タイムを設けるので、その時にマイクをオンにして質問してもらっても大丈夫です。 ● 機械学習モデルや誤差関数の詳細は、今回は説明しません。まずは、 JAX / Flax を使ったコードの雰囲気 に慣れてください。

Slide 28

Slide 28 text

Thank you.