Slide 20
Slide 20 text
MNISTのコードをPyTorchと比較する
● モデルの定義(Burnの場合)
○ PyTorchの書き方にかなり近い
○ モデルとなる構造体にMuduleトレイトを継承させる
#[derive(Module, Debug)]
pub struct Model {
conv1: Param>,
conv2: Param>,
dropout1: Dropout,
dropout2: Dropout,
linear1: Param>,
linear2: Param>,
max_pool: MaxPool2d,
}
pub fn new() -> Self {
Self {
conv1: Param::new(Conv2d::new(
&
Conv2dConfig
::new([1, 32], [3, 3]),
)),
conv2: Param::new(Conv2d::new(
&
Conv2dConfig
::new([32, 64], [3, 3]),
)),
dropout1: Dropout::new(&DropoutConfig
::new(0.25)),
dropout2: Dropout::new(&DropoutConfig
::new(0.5)),
linear1: Param::new(Linear::new(&LinearConfig
::new(9216, 128))),
linear2: Param::new(Linear::new(&LinearConfig
::new(128, 10))),
max_pool: MaxPool2d::new(
&
MaxPool2dConfig
::new(64, [2, 2]).with_strides
([2, 2]
)),
}
}