Slide 1

Slide 1 text

BurnでDeep Learning やってみる 2023/2/24 Shirokuma @Rust LT ハイブリッド #1

Slide 2

Slide 2 text

自己紹介 独立系ロボットエンジニア Shirokuma@neka-nat https://twitter.com/neka_nat ● フリーでロボティクス・画像処理関連のソフトウェアの お仕事やってます!

Slide 3

Slide 3 text

自己紹介 独立系ロボットエンジニア Shirokuma@neka-nat https://twitter.com/neka_nat ● フリーでロボティクス・画像処理関連のソフトウェアの お仕事やってます! ● OSS開発・ブログとか ○ myCobotをRustで動かす ○ Rustで書いたCUDAカーネルで画像処理してみる ○ Rustで点群処理

Slide 4

Slide 4 text

Burnとは? ● Rustで書かれた深層学習フレームワーク https://github.com/burn-rs/burn

Slide 5

Slide 5 text

Burnとは? ● Rustで書かれた深層学習フレームワーク ● Pythonで言うところのPyTorch/Tensorflow的なもの https://github.com/burn-rs/burn

Slide 6

Slide 6 text

Burnとは? ● Rustで書かれた深層学習フレームワーク ● Pythonで言うところのPyTorch/Tensorflow的なもの ● テンソル計算のバックエンドをいくつか選べる ○ Tch - LibTorch(C++)のRustラッパー(CPU/GPU) ○ NdArray+AutoDiff(CPUのみ) https://github.com/burn-rs/burn

Slide 7

Slide 7 text

なぜRustなのか?深層学習といえばPythonなのでは? ● Burnのブログで熱弁 ○ https://burn-rs.github.io/blog/a-case-for-ru st-in-deep-learning

Slide 8

Slide 8 text

なぜRustなのか?深層学習といえばPythonなのでは? ● Burnのブログで熱弁 ○ https://burn-rs.github.io/blog/a-case-for-ru st-in-deep-learning ● Pythonが使われてきた理由 ○ シンプルで学びやすい ○ 研究のサイクルを回しやすい

Slide 9

Slide 9 text

なぜRustなのか?深層学習といえばPythonなのでは? ● Burnのブログで熱弁 ○ https://burn-rs.github.io/blog/a-case-for-ru st-in-deep-learning ● Pythonが使われてきた理由 ○ シンプルで学びやすい ○ 研究のサイクルを回しやすい ● Pythonが使われていることの課題 ○ フレームワーク(PyTorchなど)内部はC++ ○ フレームワーク開発者よりのエンジニアと研究 者の間の技術の隔たり

Slide 10

Slide 10 text

なぜRustなのか?深層学習といえばPythonなのでは? ● Burnのブログで熱弁 ○ https://burn-rs.github.io/blog/a-case-for-ru st-in-deep-learning ● Pythonが使われてきた理由 ○ シンプルで学びやすい ○ 研究のサイクルを回しやすい ● Pythonが使われていることの課題 ○ フレームワーク(PyTorchなど)内部はC++ ○ フレームワーク開発者よりのエンジニアと研究 者の間の技術の隔たり ● Rustで解決できること ○ 1つの言語で低レベルから抽象レイヤまで扱 える ○ エンジニアと研究者の隔たりを無くす

Slide 11

Slide 11 text

一般的な深層学習フレームワークに含まれる機能

Slide 12

Slide 12 text

一般的な深層学習フレームワークに含まれる機能 ● テンソル計算 ○ 一般的なテンソルを用いた計算 ○ 自動微分 ○ CPU・GPUなどの使用するリソースの切り替え

Slide 13

Slide 13 text

一般的な深層学習フレームワークに含まれる機能 ● テンソル計算 ○ 一般的なテンソルを用いた計算 ○ 自動微分 ○ CPU・GPUなどの使用するリソースの切り替え ● データ準備 ○ 保存されている画像データなどを読み取ってテンソルに変換 ○ 一般的なデータセットはダウンロードで取ってくる

Slide 14

Slide 14 text

一般的な深層学習フレームワークに含まれる機能 ● テンソル計算 ○ 一般的なテンソルを用いた計算 ○ 自動微分 ○ CPU・GPUなどの使用するリソースの切り替え ● データ準備 ○ 保存されている画像データなどを読み取ってテンソルに変換 ○ 一般的なデータセットはダウンロードで取ってくる ● モデル作成 ○ 深層学習で使用するモデルのネットワーク構造を作る ○ Lossの設定

Slide 15

Slide 15 text

一般的な深層学習フレームワークに含まれる機能 ● テンソル計算 ○ 一般的なテンソルを用いた計算 ○ 自動微分 ○ CPU・GPUなどの使用するリソースの切り替え ● データ準備 ○ 保存されている画像データなどを読み取ってテンソルに変換 ○ 一般的なデータセットはダウンロードで取ってくる ● モデル作成 ○ 深層学習で使用するモデルのネットワーク構造を作る ○ Lossの設定 ● 学習 ○ パラメタ設定 ○ どのような手法で最適化するか選択 ○ イテレーション回数やモデルの保存方法などを決める

Slide 16

Slide 16 text

Burnのモジュール一覧 ● テンソル計算(Backend) ○ テンソルのバックエンドトレイト(Tch, NdArray) ● データ準備(Dataset) ○ PyTorchのDataLoaderに近い ○ mnistやhuggingface hubなどのデータセットを使用できる ● モデル作成(Module) ○ NNのレイヤー(Linear, Convolution, Pooling, Activation, …) ○ Loss関数(CrossEntropyLossのみ) ● 学習(Config) ○ デフォルト値設定やJSONシリアライズできる機械学習用パラメタ ● 学習(Learner) ○ 最適化ソルバの選択(SGD, Adam) ○ 学習時のメトリクスやプロット、モデル保存の設定

Slide 17

Slide 17 text

MNISTのコードをPyTorchと比較する ● MNISTとは? ○ 機械学習のサンプルでよく使用される分類のための機械学習データ ○ 手書きのグレースケール画像に対して0~9のどの文字かを判定する ○ データ数60000 https://upload.wikimedia.org/wikipedia/commons/thumb/2/27/MnistExamples.png/3 20px-MnistExamples.png

Slide 18

Slide 18 text

MNISTのコードをPyTorchと比較する ● モデルの定義(PyTorchの場合) ○ モデルとなる構造体を定義 ○ コンストラクタでモデルに含まれる各層の初期化 ■ 主にモデルパラメタを含むものが初期化される class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.dropout1 = nn.Dropout(0.25) self.dropout2 = nn.Dropout(0.5) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) Conv2d Conv2d MaxPool2d ReLU ReLU Dropout Flatten Linear Linear ReLU Dropout Log Softmax

Slide 19

Slide 19 text

MNISTのコードをPyTorchと比較する ● モデルの定義(PyTorchの場合) ○ forward関数に実際の計算を記述 ○ backward(微分計算)は計算グラフを用いて自動微分よってに行われる def forward(self, x): x = self.conv1(x) x = F.relu(x) x = self.conv2(x) x = F.relu(x) x = F.max_pool2d(x, 2) x = self.dropout1(x) x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout2(x) x = self.fc2(x) output = F.log_softmax(x, dim=1) return output Conv2d Conv2d MaxPool2d ReLU ReLU Dropout Flatten Linear Linear ReLU Dropout Log Softmax

Slide 20

Slide 20 text

MNISTのコードをPyTorchと比較する ● モデルの定義(Burnの場合) ○ PyTorchの書き方にかなり近い ○ モデルとなる構造体にMuduleトレイトを継承させる #[derive(Module, Debug)] pub struct Model { conv1: Param>, conv2: Param>, dropout1: Dropout, dropout2: Dropout, linear1: Param>, linear2: Param>, max_pool: MaxPool2d, } pub fn new() -> Self { Self { conv1: Param::new(Conv2d::new( & Conv2dConfig ::new([1, 32], [3, 3]), )), conv2: Param::new(Conv2d::new( & Conv2dConfig ::new([32, 64], [3, 3]), )), dropout1: Dropout::new(&DropoutConfig ::new(0.25)), dropout2: Dropout::new(&DropoutConfig ::new(0.5)), linear1: Param::new(Linear::new(&LinearConfig ::new(9216, 128))), linear2: Param::new(Linear::new(&LinearConfig ::new(128, 10))), max_pool: MaxPool2d::new( & MaxPool2dConfig ::new(64, [2, 2]).with_strides ([2, 2] )), } }

Slide 21

Slide 21 text

MNISTのコードをPyTorchと比較する ● モデルの定義(Burnの場合) ○ PyTorchの書き方にかなり近い ○ In/Outのテンソルの次元を定義しながら書けるのが嬉しい ■ rust-analyzerで途中の出力の次元も分かる pub fn forward(&self, input: Tensor) -> Tensor { let [batch_size, heigth, width] = input.dims(); let x = input.reshape([batch_size, 1, heigth, width]).detach(); let x = self.conv1.forward(x); let x = relu(&x); let x = self.conv2.forward(x); let x = relu(&x); let x = self.max_pool.forward(x); let x = self.dropout1.forward(x); let x = x.reshape([batch_size, 9216]); let x = self.linear1.forward(x); let x = relu(&x); let x = self.dropout2.forward(x); let out = self.linear2.forward(x); out }

Slide 22

Slide 22 text

MNISTのコードをPyTorchと比較する ● パラメタ設定 ○ Configトレイトを使うことで、デフォルト値の設定が可能 ● 学習 ○ 最適化はAdamを使用 ○ プロットやモデルのチェックポイント保存設定を行ってfitを実行 #[derive(Config)] pub struct MnistConfig { #[config(default = 2)] pub num_epochs: usize, #[config(default = 64)] pub batch_size: usize, #[config(default = 8)] pub num_workers: usize, #[config(default = 42)] pub seed: u64, pub optimizer: AdamConfig, } let learner = LearnerBuilder::new(ARTIFACT_DIR) .metric_train_plot(AccuracyMetric::new()) .metric_valid_plot(AccuracyMetric::new()) .metric_train_plot(LossMetric::new()) .metric_valid_plot(LossMetric::new()) .with_file_checkpointer::(2) .devices(vec![device]) .num_epochs(config.num_epochs) .build(model, optim); let _model_trained = learner.fit(dataloader_train, dataloader_test);

Slide 23

Slide 23 text

実行してみる ● 最初にMNISTデータ取得が走る ○ 内部でpythonを使ってデータダウンロードを行っているため、初回実行時にpipと かが実行される ● 進捗とLos/Accuracyのグラフがターミナル上に表示される ● 今回はバックエンドをNdArrayにした ● CPUだと結構時間かかる(1エポック数日)

Slide 24

Slide 24 text

その他の機能や対応状況 ● 自然言語処理のブレイクスルーとなったTransformerが既に実装されている ○ text-classificationのexample ○ https://github.com/burn-rs/burn/tree/main/examples/text-classification

Slide 25

Slide 25 text

その他の機能や対応状況 ● 自然言語処理のブレイクスルーとなったTransformerが既に実装されている ○ text-classificationのexample ○ https://github.com/burn-rs/burn/tree/main/examples/text-classification ● 今後追加予定のまだ無い機能もいろいろ(2023/2/24時点) ○ Sigmoid関数 ○ Conv2dにおけるStride(現状1固定)、Dilationなど ○ Learning rate scheduler ○ …

Slide 26

Slide 26 text

まとめ ● Rustの深層学習ライブラリBurnの紹介 ● 書き方がPyTorchに似ていて、PyTorch使ってた人は使いやすそう ● 今回やらなかったけどTch使ってGPUの学習もしてみたい ● 今回使用したコード ○ https://github.com/neka-nat/burn-tutorial