Slide 1

A Guide to Porting Models Between Tensorflow and Pytorch
2021-08-17, the 3rd Analysis Competition LT Meetup
presented by @rishigami_

Slide 2

Out of the blue, but: do you have problems like these?

Slide 3

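Framework-related headaches
- You found a paper or notebook (GitHub, notebook) that looks useful, but... the implementation exists only in tf (or only in torch).
- With the rise of large pretrained models, you can implement the model yourself, but... the weights exist only in tf (or only in torch).
→ This talk aims to let you solve these problems yourself!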

Slide 4

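Today's agenda
1. Basics of model implementation in Tensorflow/Pytorch
2. Comparing Tensorflow and Pytorch
3. Techniques for converting weights between Tensorflow and Pytorch
4. Practice: porting a model between Tensorflow and Pytorch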

Slide 5

Part 1: Basics of model implementation in Tensorflow/Pytorch

Model definition in Tensorflow

Tensorflow (Sequential API): call the Sequential constructor, then add layers one at a time.

    import tensorflow as tf
    from tensorflow.keras import Input, Model, layers

    model = tf.keras.Sequential()
    model.add(layers.Flatten())
    model.add(layers.Dense(512))
    model.add(layers.ReLU())
    model.add(layers.Dense(256))
    model.add(layers.ReLU())
    model.add(layers.Dense(10))
    model.add(layers.Softmax())

Tensorflow (Functional API): call the Model constructor with the inputs and the outputs.

    inputs = Input(shape=(28, 28))
    x = layers.Flatten()(inputs)
    x = layers.Dense(512)(x)
    x = layers.ReLU()(x)
    x = layers.Dense(256)(x)
    x = layers.ReLU()(x)
    x = layers.Dense(10)(x)
    outputs = layers.Softmax()(x)
    model = Model(inputs=inputs, outputs=outputs)

Slide 6

Model definition in Tensorflow/Pytorch

Tensorflow (Subclassing API): subclass Model, a style close to Pytorch.

    import tensorflow as tf
    from tensorflow.keras import layers

    class TFCustomModel(tf.keras.Model):
        def __init__(self):
            super(TFCustomModel, self).__init__()
            self.flatten = layers.Flatten()
            self.dense1 = layers.Dense(512)
            self.dense2 = layers.Dense(256)
            self.dense3 = layers.Dense(10)

        def call(self, x):
            x = self.flatten(x)
            x = tf.nn.relu(self.dense1(x))
            x = tf.nn.relu(self.dense2(x))
            x = tf.nn.softmax(self.dense3(x))
            return x

    tf_model = TFCustomModel()

Pytorch (Subclass): subclass Module (the same model can also be written with nn.Sequential).

    import torch.nn as nn
    import torch.nn.functional as F

    class CustomModel(nn.Module):
        def __init__(self):
            super(CustomModel, self).__init__()
            self.dense1 = nn.Linear(784, 512)
            self.dense2 = nn.Linear(512, 256)
            self.dense3 = nn.Linear(256, 10)

        def forward(self, x):
            x = x.view(-1, 784)
            x = F.relu(self.dense1(x))
            x = F.relu(self.dense2(x))
            x = F.softmax(self.dense3(x), dim=1)  # explicit dim avoids the implicit-dim warning
            return x

    torch_model = CustomModel()

Slide 7

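Part 2: Comparing Tensorflow and Pytorch

Main differences between Tensorflow and Pytorch

1. Channel last vs. channel first
   Tensorflow: [B, H, W, C] / Pytorch: [B, C, H, W]
   → The channel axis is in a different position.
2. Whether in_features is required
   Tensorflow: layers.Dense(10) / Pytorch: nn.Linear(256, 10)
   → Pytorch requires in_features; Tensorflow infers the input size when the layer is first called.
3. Different hyperparameters
   Tensorflow: layers.BatchNormalization(momentum=0.99, epsilon=0.001) / Pytorch: nn.BatchNorm2d(momentum=0.1, eps=1e-05)
   → Some layers have different default hyperparameters. For BatchNorm, momentum is also defined in the opposite direction: Keras momentum=0.99 corresponds to Pytorch momentum=0.01.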
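A minimal NumPy sketch of difference 1 (array sizes are arbitrary, chosen only for illustration): converting an input batch between the two layouts is a single transpose, and the two permutations below undo each other.

```python
import numpy as np

def nhwc_to_nchw(x: np.ndarray) -> np.ndarray:
    """Tensorflow layout [B, H, W, C] -> Pytorch layout [B, C, H, W]."""
    return x.transpose(0, 3, 1, 2)

def nchw_to_nhwc(x: np.ndarray) -> np.ndarray:
    """Pytorch layout [B, C, H, W] -> Tensorflow layout [B, H, W, C]."""
    return x.transpose(0, 2, 3, 1)

x_tf = np.random.rand(2, 32, 32, 3)  # dummy image batch, channel last
x_torch = nhwc_to_nchw(x_tf)         # same data, channel first
print(x_torch.shape)                 # (2, 3, 32, 32)

# The same element is addressed as [b, h, w, c] in one layout and [b, c, h, w] in the other
assert x_tf[1, 10, 20, 2] == x_torch[1, 2, 10, 20]
assert np.array_equal(nchw_to_nhwc(x_torch), x_tf)
```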

Slide 8

Layer names in Tensorflow vs. Pytorch (Tensorflow / Pytorch)

layers.Dense / nn.Linear
layers.Conv1D / nn.Conv1d
layers.Conv2D / nn.Conv2d
layers.Conv2DTranspose / nn.ConvTranspose2d
layers.SimpleRNN / nn.RNN
layers.LSTM / nn.LSTM
layers.BatchNormalization / nn.BatchNorm1d, nn.BatchNorm2d
layers.LayerNormalization / nn.LayerNorm

Slide 9

Part 3: Techniques for converting weights between Tensorflow and Pytorch

Accessing the weights

Tensorflow (tf_model):

    tf_model(tf.zeros((1, 28, 28)))  # call once so the subclassed model creates its weights
    for layer in tf_model.layers:
        for var in layer.weights:
            print(var.name, var.numpy().shape)

Printed names and shapes:
    dense/kernel:0 (784, 512)
    dense/bias:0 (512,)
    dense_1/kernel:0 (512, 256)
    dense_1/bias:0 (256,)
    dense_2/kernel:0 (256, 10)
    dense_2/bias:0 (10,)

Pytorch (torch_model):

    for name, param in torch_model.named_parameters():
        print(name, param.detach().numpy().shape)

Printed names and shapes:
    dense1.weight (512, 784)
    dense1.bias (512,)
    dense2.weight (256, 512)
    dense2.bias (256,)
    dense3.weight (10, 256)
    dense3.bias (10,)

Both frameworks list the layers in the same order, but the 2-D weights are stored transposed: a Tensorflow kernel is [in, out], a Pytorch weight is [out, in].
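Because both listings come out in the same layer order, one way to pair them is purely by position. The sketch below is a hypothetical helper (pair_by_order is not part of either framework) that takes the names and shapes printed above and, before any copying, confirms that each Tensorflow shape is the reverse of the matching Pytorch shape. That check is only valid for Dense/Linear kernels and biases; conv kernels need the permutations discussed on later slides.

```python
# Names and shapes as printed on this slide
tf_weights = [
    ("dense/kernel:0", (784, 512)), ("dense/bias:0", (512,)),
    ("dense_1/kernel:0", (512, 256)), ("dense_1/bias:0", (256,)),
    ("dense_2/kernel:0", (256, 10)), ("dense_2/bias:0", (10,)),
]
torch_weights = [
    ("dense1.weight", (512, 784)), ("dense1.bias", (512,)),
    ("dense2.weight", (256, 512)), ("dense2.bias", (256,)),
    ("dense3.weight", (10, 256)), ("dense3.bias", (10,)),
]

def pair_by_order(tf_list, torch_list):
    """Pair weights by position and check that shapes agree up to axis reversal.

    For Dense/Linear kernels (2-D) and biases (1-D), reversing the axes is
    exactly the transpose that converts one layout into the other.
    """
    if len(tf_list) != len(torch_list):
        raise ValueError(f"{len(tf_list)} TF weights vs {len(torch_list)} torch weights")
    pairs = []
    for (tf_name, tf_shape), (pt_name, pt_shape) in zip(tf_list, torch_list):
        if tuple(reversed(tf_shape)) != tuple(pt_shape):
            raise ValueError(f"shape mismatch: {tf_name} {tf_shape} vs {pt_name} {pt_shape}")
        pairs.append((tf_name, pt_name))
    return pairs

for tf_name, pt_name in pair_by_order(tf_weights, torch_weights):
    print(f"{tf_name:18s} <-> {pt_name}")
```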

Slide 10

Example: Tensorflow → Pytorch (weight names and shapes as listed on the previous slide)

    import torch

    # Collect the Tensorflow weights in layer order
    tf_names, tf_params = [], []
    for layer in tf_model.layers:
        for var in layer.weights:
            tf_names.append(var.name)
            tf_params.append(var.numpy())

    # Assumes state_dict() lists the matching tensors in the same order
    torch_params = torch_model.state_dict()
    for idx, key in enumerate(list(torch_params)):
        if "kernel" in tf_names[idx]:
            # Dense kernel [in, out] -> Linear weight [out, in]
            torch_params[key] = torch.tensor(tf_params[idx].transpose(1, 0))
        else:
            torch_params[key] = torch.tensor(tf_params[idx])
    torch_model.load_state_dict(torch_params)
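The loop above decides whether to transpose by looking for "kernel" in the name, which works for Dense but would apply the wrong permutation to conv kernels. A sketch of an alternative, assuming the weight lists line up one-to-one, is to choose the permutation from the number of dimensions; the permutations are the Dense, Conv1D and Conv2D ones given on the following slides. tf_to_torch_array is a hypothetical helper name, and it does not handle special cases such as the RNN bias split.

```python
import numpy as np

# Hypothetical helper: Tensorflow layout -> Pytorch layout, chosen by ndim.
# 2-D: Dense kernel [in, out]              -> Linear weight [out, in]
# 3-D: Conv1D kernel [k, in, out]          -> Conv1d weight [out, in, k]
# 4-D: Conv2D kernel [kH, kW, in, out]     -> Conv2d weight [out, in, kH, kW]
TF_TO_TORCH_PERM = {2: (1, 0), 3: (2, 1, 0), 4: (3, 2, 0, 1)}

def tf_to_torch_array(arr: np.ndarray) -> np.ndarray:
    perm = TF_TO_TORCH_PERM.get(arr.ndim)
    out = arr if perm is None else arr.transpose(perm)  # 1-D biases pass through unchanged
    return np.ascontiguousarray(out)  # C-contiguous result, convenient for torch.from_numpy

kernel = np.random.rand(784, 512)
print(tf_to_torch_array(kernel).shape)  # (512, 784)
```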

Slide 11

Example: Pytorch → Tensorflow (weight names and shapes as listed on the previous slides)

    torch_params = torch_model.state_dict()
    torch_keys = list(torch_params.keys())
    for layer in tf_model.layers:
        for var in layer.weights:
            # Assumes the Pytorch tensors come in the same order as the Tensorflow variables
            torch_key = torch_keys.pop(0)
            torch_param = torch_params[torch_key].numpy()
            if "kernel" in var.name:
                # Linear weight [out, in] -> Dense kernel [in, out]
                var.assign(torch_param.transpose(1, 0))
            else:
                var.assign(torch_param)
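Going the other way only needs the inverse permutations. A small NumPy sketch (the convert helper and the shapes are illustrative assumptions) that checks a Conv2d weight survives a Pytorch → Tensorflow → Pytorch round trip:

```python
import numpy as np

# Hypothetical permutation tables: Pytorch -> Tensorflow is the inverse of
# Tensorflow -> Pytorch (the 2-D and 3-D permutations are their own inverse).
TF_TO_TORCH_PERM = {2: (1, 0), 3: (2, 1, 0), 4: (3, 2, 0, 1)}
TORCH_TO_TF_PERM = {2: (1, 0), 3: (2, 1, 0), 4: (2, 3, 1, 0)}

def convert(arr: np.ndarray, perms: dict) -> np.ndarray:
    perm = perms.get(arr.ndim)
    return arr if perm is None else arr.transpose(perm)

weight = np.random.rand(64, 16, 3, 5)       # Conv2d weight [out, in, kH, kW]
kernel = convert(weight, TORCH_TO_TF_PERM)  # Conv2D kernel [kH, kW, in, out]
print(kernel.shape)                         # (3, 5, 16, 64)
assert np.array_equal(convert(kernel, TF_TO_TORCH_PERM), weight)  # round trip
```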

Slide 12

layers.Dense / nn.Linear

Tensorflow layers.Dense(): kernel [in, out], bias
Pytorch nn.Linear(): weight [out, in], bias

kernel ↔ weight: transpose(1,0) (the same call in both directions)
bias ↔ bias: copied as-is
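Why a single transpose suffices: Keras Dense computes x·kernel + bias, while nn.Linear computes x·weight^T + bias, so setting weight = kernel^T yields the same output. A NumPy sketch with random values (no frameworks involved) that checks this:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 784)).astype(np.float32)
kernel = rng.standard_normal((784, 512)).astype(np.float32)  # TF layout [in, out]
bias = rng.standard_normal(512).astype(np.float32)

y_tf = x @ kernel + bias          # what layers.Dense computes (no activation)
weight = kernel.transpose(1, 0)   # Pytorch layout [out, in]
y_torch = x @ weight.T + bias     # what nn.Linear computes: x W^T + b
assert np.allclose(y_tf, y_torch)
```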

Slide 13

layers.Conv1D / nn.Conv1d

Tensorflow layers.Conv1D(): kernel [kernel_size, in, out], bias
Pytorch nn.Conv1d(): weight [out, in, kernel_size], bias

kernel ↔ weight: transpose(2,1,0) (the same call in both directions)
bias ↔ bias: copied as-is
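To see that transpose(2,1,0) really gives the same convolution, here is a hand-written NumPy sketch of a 1-D cross-correlation in each framework's layout. Stride 1 and no padding are assumptions that keep the code short; conv1d_tf and conv1d_torch are illustrative helpers, not framework functions.

```python
import numpy as np

def conv1d_tf(x, kernel, bias):
    """x: [L, C_in], kernel: [K, C_in, C_out] (Keras layout). Stride 1, 'valid' padding."""
    k = kernel.shape[0]
    steps = x.shape[0] - k + 1
    out = [np.einsum("kc,kco->o", x[t:t + k], kernel) for t in range(steps)]
    return np.stack(out) + bias  # [L_out, C_out]

def conv1d_torch(x, weight, bias):
    """x: [C_in, L], weight: [C_out, C_in, K] (nn.Conv1d layout). Stride 1, no padding."""
    k = weight.shape[2]
    steps = x.shape[1] - k + 1
    out = [np.einsum("ck,ock->o", x[:, t:t + k], weight) for t in range(steps)]
    return np.stack(out, axis=1) + bias[:, None]  # [C_out, L_out]

rng = np.random.default_rng(0)
x = rng.standard_normal((10, 4))         # length 10, 4 input channels (channel last)
kernel = rng.standard_normal((3, 4, 6))  # kernel_size 3, 4 -> 6 channels
bias = rng.standard_normal(6)

weight = kernel.transpose(2, 1, 0)       # [6, 4, 3]
out_tf = conv1d_tf(x, kernel, bias)
out_torch = conv1d_torch(x.T, weight, bias)
assert np.allclose(out_tf, out_torch.T)
```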

Slide 14

layers.Conv2D / nn.Conv2d

Tensorflow layers.Conv2D(): kernel [kH, kW, in, out], bias
Pytorch nn.Conv2d(): weight [out, in, kH, kW], bias

Tensorflow → Pytorch: kernel.transpose(3,2,0,1)
Pytorch → Tensorflow: weight.transpose(2,3,1,0)
bias ↔ bias: copied as-is
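A NumPy sketch (shapes are arbitrary) checking that the two permutations are inverses and that one output pixel of a convolution, i.e. an input patch contracted with the kernel, comes out the same in both layouts:

```python
import numpy as np

rng = np.random.default_rng(0)
kernel = rng.standard_normal((3, 3, 16, 64))  # TF Conv2D kernel [kH, kW, in, out]
weight = kernel.transpose(3, 2, 0, 1)         # Pytorch Conv2d weight [out, in, kH, kW]
assert weight.shape == (64, 16, 3, 3)
assert np.array_equal(weight.transpose(2, 3, 1, 0), kernel)  # back to the TF layout

# One output pixel = input patch contracted with the kernel over (kH, kW, in)
patch_hwc = rng.standard_normal((3, 3, 16))   # patch in channel-last order
patch_chw = patch_hwc.transpose(2, 0, 1)      # same patch in channel-first order
out_tf = np.einsum("hwc,hwco->o", patch_hwc, kernel)
out_torch = np.einsum("chw,ochw->o", patch_chw, weight)
assert np.allclose(out_tf, out_torch)
```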

Slide 15

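layers.Conv2DTranspose / nn.ConvTranspose2d

Tensorflow layers.Conv2DTranspose(): kernel [kH, kW, out, in], bias (the channel axes are swapped compared with Conv2D)
Pytorch nn.ConvTranspose2d(): weight [in, out, kH, kW], bias

Tensorflow → Pytorch: kernel.transpose(3,2,0,1)
Pytorch → Tensorflow: weight.transpose(2,3,1,0)
(the same permutations as Conv2D)
bias ↔ bias: copied as-is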
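A NumPy sketch of the Conv2DTranspose case. It models only the simplest situation by hand (one input pixel, stride 1, no padding), where a transposed convolution scatters that pixel's channels through the kernel into a kH × kW patch; it checks the weight layout, not either framework's implementation. Sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
kH, kW, c_in, c_out = 3, 3, 16, 8
kernel = rng.standard_normal((kH, kW, c_out, c_in))  # Conv2DTranspose kernel [kH, kW, out, in]
weight = kernel.transpose(3, 2, 0, 1)                # ConvTranspose2d weight [in, out, kH, kW]
assert weight.shape == (c_in, c_out, kH, kW)
assert np.array_equal(weight.transpose(2, 3, 1, 0), kernel)

# Single input pixel: each input channel i contributes pixel[i] * kernel slice to every output channel
pixel = rng.standard_normal(c_in)
out_tf = np.einsum("i,hwoi->hwo", pixel, kernel)     # channel last  [kH, kW, out]
out_torch = np.einsum("i,iohw->ohw", pixel, weight)  # channel first [out, kH, kW]
assert np.allclose(out_tf, out_torch.transpose(1, 2, 0))
```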

Slide 16

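layers.SimpleRNN / nn.RNN

Tensorflow layers.SimpleRNN(): kernel, recurrent_kernel, bias
Pytorch nn.RNN(): weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0

kernel ↔ weight_ih_l0: transpose(1,0)
recurrent_kernel ↔ weight_hh_l0: transpose(1,0)
Tensorflow → Pytorch: bias → bias_ih_l0, and bias_hh_l0 = zeros_like(bias)
Pytorch → Tensorflow: bias = bias_ih_l0 + bias_hh_l0
(Pytorch has two bias vectors that are simply added together; Tensorflow has one.)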
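A NumPy sketch of one SimpleRNN time step under both conventions (tanh activation, the default in both frameworks; sizes are arbitrary). It shows why bias_hh_l0 can be zero in the Tensorflow → Pytorch direction, and why the two Pytorch biases are added in the reverse direction:

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hid = 5, 4
kernel = rng.standard_normal((n_in, n_hid))             # TF kernel [in, units]
recurrent_kernel = rng.standard_normal((n_hid, n_hid))  # TF recurrent_kernel [units, units]
bias = rng.standard_normal(n_hid)                       # TF bias [units]

# Tensorflow -> Pytorch
weight_ih_l0 = kernel.transpose(1, 0)                   # [units, in]
weight_hh_l0 = recurrent_kernel.transpose(1, 0)         # [units, units]
bias_ih_l0, bias_hh_l0 = bias, np.zeros_like(bias)

x = rng.standard_normal((2, n_in))   # one time step, batch of 2
h = rng.standard_normal((2, n_hid))  # previous hidden state

h_tf = np.tanh(x @ kernel + h @ recurrent_kernel + bias)
h_torch = np.tanh(x @ weight_ih_l0.T + bias_ih_l0 + h @ weight_hh_l0.T + bias_hh_l0)
assert np.allclose(h_tf, h_torch)

# Pytorch -> Tensorflow: two nonzero Pytorch biases are summed into the single Keras bias
b_ih, b_hh = rng.standard_normal(n_hid), rng.standard_normal(n_hid)
h_torch2 = np.tanh(x @ weight_ih_l0.T + b_ih + h @ weight_hh_l0.T + b_hh)
h_tf2 = np.tanh(x @ kernel + h @ recurrent_kernel + (b_ih + b_hh))
assert np.allclose(h_tf2, h_torch2)
```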

Slide 17

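layers.LSTM / nn.LSTM

Tensorflow layers.LSTM(): kernel [in, 4*units], recurrent_kernel [units, 4*units], bias [4*units]
Pytorch nn.LSTM(): weight_ih_l0 [4*hidden, in], weight_hh_l0 [4*hidden, hidden], bias_ih_l0, bias_hh_l0

kernel ↔ weight_ih_l0: transpose(1,0)
recurrent_kernel ↔ weight_hh_l0: transpose(1,0)
Tensorflow → Pytorch: bias → bias_ih_l0, and bias_hh_l0 = zeros_like(bias)
Pytorch → Tensorflow: bias = bias_ih_l0 + bias_hh_l0
Both frameworks stack the four gates in the same order (input, forget, cell, output), so no reordering is needed.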
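The same idea for one LSTM step, as a NumPy sketch assuming the default activations (sigmoid gates, tanh cell). Because both frameworks stack the gates as input, forget, cell (g), output, a plain transpose is enough; lstm_step is an illustrative helper and the sizes are arbitrary.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(z, c):
    """Shared gate math; z holds the 4 gate pre-activations in order i, f, g, o."""
    i, f, g, o = np.split(z, 4, axis=-1)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new

rng = np.random.default_rng(0)
n_in, units = 5, 3
kernel = rng.standard_normal((n_in, 4 * units))             # TF [in, 4*units]
recurrent_kernel = rng.standard_normal((units, 4 * units))  # TF [units, 4*units]
bias = rng.standard_normal(4 * units)

# Tensorflow -> Pytorch: transpose only, the gate order already matches
weight_ih_l0 = kernel.transpose(1, 0)            # [4*units, in]
weight_hh_l0 = recurrent_kernel.transpose(1, 0)  # [4*units, units]
bias_ih_l0, bias_hh_l0 = bias, np.zeros_like(bias)

x = rng.standard_normal((2, n_in))
h = rng.standard_normal((2, units))
c = rng.standard_normal((2, units))

h_tf, c_tf = lstm_step(x @ kernel + h @ recurrent_kernel + bias, c)
h_pt, c_pt = lstm_step(x @ weight_ih_l0.T + bias_ih_l0 + h @ weight_hh_l0.T + bias_hh_l0, c)
assert np.allclose(h_tf, h_pt) and np.allclose(c_tf, c_pt)
```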

Slide 18

layers.BatchNormalization / nn.BatchNorm1d, nn.BatchNorm2d

gamma ↔ weight
beta ↔ bias
moving_mean ↔ running_mean
moving_variance ↔ running_var
All four are 1-D and copied as-is. Pytorch's state_dict also contains a num_batches_tracked buffer with no Tensorflow counterpart; skip it when matching weights by order.
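When weights are matched purely by order, as in the earlier conversion loops, the extra num_batches_tracked buffer would shift every entry after it. A sketch of a hypothetical helper that maps nn.BatchNorm state_dict keys to the Keras names from this slide and skips that buffer (the "bn." prefix in the example is made up):

```python
# Pytorch BatchNorm key (last component) -> Keras BatchNormalization weight name
TORCH_TO_TF_BN = {
    "weight": "gamma",
    "bias": "beta",
    "running_mean": "moving_mean",
    "running_var": "moving_variance",
}

def map_bn_keys(torch_keys):
    """Map nn.BatchNorm state_dict keys to Keras names, skipping num_batches_tracked."""
    mapped = []
    for key in torch_keys:
        _, _, leaf = key.rpartition(".")
        if leaf == "num_batches_tracked":
            continue  # no Keras counterpart
        mapped.append((key, TORCH_TO_TF_BN[leaf]))
    return mapped

keys = ["bn.weight", "bn.bias", "bn.running_mean", "bn.running_var", "bn.num_batches_tracked"]
for torch_key, tf_name in map_bn_keys(keys):
    print(torch_key, "->", tf_name)
```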

Slide 19

layers.LayerNormalization / nn.LayerNorm

gamma ↔ weight
beta ↔ bias
Both are copied as-is.

Supplement: default hyperparameters
Tensorflow layers.BatchNormalization(momentum=0.99, epsilon=0.001) / Pytorch nn.BatchNorm2d(momentum=0.1, eps=1e-05)
Tensorflow layers.LayerNormalization(epsilon=0.001) / Pytorch nn.LayerNorm(eps=1e-05)
The defaults differ, so set them explicitly on the target side (for example nn.LayerNorm(normalized_shape, eps=1e-3) when porting from Tensorflow) to make the ported model match.
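The momentum values are not directly comparable, because the two frameworks weight the running statistic in opposite directions. A NumPy sketch of both update rules, written out by hand from the frameworks' documented formulas, showing that Keras momentum m corresponds to Pytorch momentum 1 - m:

```python
import numpy as np

def keras_update(moving, batch_stat, momentum=0.99):
    # Keras: moving = moving * momentum + batch * (1 - momentum)
    return moving * momentum + batch_stat * (1.0 - momentum)

def torch_update(running, batch_stat, momentum=0.1):
    # Pytorch: running = (1 - momentum) * running + momentum * batch
    return (1.0 - momentum) * running + momentum * batch_stat

rng = np.random.default_rng(0)
stat_k = np.zeros(4)
stat_t = np.zeros(4)
for _ in range(10):
    batch_mean = rng.standard_normal(4)
    stat_k = keras_update(stat_k, batch_mean, momentum=0.99)
    stat_t = torch_update(stat_t, batch_mean, momentum=1 - 0.99)  # i.e. 0.01
assert np.allclose(stat_k, stat_t)
```

So a Keras layer with the default momentum=0.99 matches nn.BatchNorm2d(momentum=0.01), not the Pytorch default of 0.1.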

Slide 20

Part 4: Practice: porting a model between Tensorflow and Pytorch

Porting VGG16
Tensorflow: ? ← Pytorch: torchvision's VGG
Plan: implement the model ourselves in Tensorflow, then port the weights.

    import torch
    import torch.nn as nn

    class VGG(nn.Module):
        def __init__(self, features, num_classes=1000):
            super(VGG, self).__init__()
            self.features = features
            self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes))

        def forward(self, x):
            x = self.features(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
            return x

https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py

Slide 21

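About VGG16
VGG16 model architecture (figure): 13 convolutional layers with 3×3 kernels in five blocks, each block followed by max pooling, then three fully connected layers.
Figure sources: https://neurohive.io/en/popular-networks/vgg16/ and https://kgptalkie.com/image-classification-using-pre-trained-vgg-16-model/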

Slide 22

Porting make_layers() (own implementation)

Tensorflow:

    import tensorflow as tf
    from tensorflow.keras import layers

    cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

    def make_layers(cfg):
        tflayers = []
        for v in cfg:
            if v == 'M':
                tflayers += [layers.MaxPool2D(strides=2)]  # pool_size defaults to (2, 2)
            else:
                tflayers += [layers.Conv2D(v, kernel_size=3, padding='same', activation='relu')]
        return tf.keras.Sequential(tflayers)

Pytorch:

    import torch.nn as nn

    cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

    def make_layers(cfg):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py#L69

For a 3×3 kernel with stride 1, padding='same' in Tensorflow matches padding=1 in Pytorch. The Tensorflow version needs no in_channels bookkeeping because Conv2D infers its input channels.
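A quick pure-Python check of what make_layers(cfg) produces for a 224 × 224 input, under the assumptions stated in the docstring: 13 conv layers (plus the 3 Linear/Dense layers, 16 weight layers in total, hence VGG16) and a final 7 × 7 × 512 feature map, which is where the 25088 inputs of the first classifier layer come from. trace_features is an illustrative helper.

```python
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

def trace_features(cfg, size=224, in_channels=3):
    """Return (H, W, C, number of conv layers) after make_layers(cfg).

    Assumes what both versions above do: 3x3 convs with 'same'/padding=1 keep
    H and W, and each 'M' (2x2 max pool, stride 2) halves them (rounding down).
    """
    h = w = size
    c = in_channels
    n_conv = 0
    for v in cfg:
        if v == 'M':
            h, w = h // 2, w // 2
        else:
            c = v
            n_conv += 1
    return h, w, c, n_conv

h, w, c, n_conv = trace_features(cfg)
print(h, w, c, n_conv, h * w * c)  # 7 7 512 13 25088
```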

Slide 23

Porting VGG16()

Tensorflow (own implementation):

    import tensorflow as tf
    from tensorflow.keras import layers

    class TFVGG(tf.keras.Model):
        def __init__(self, features, num_classes=1000):
            super(TFVGG, self).__init__()
            self.features = features
            self.flatten = layers.Flatten()
            self.classifier = tf.keras.Sequential([
                layers.Dense(4096, activation='relu'),
                layers.Dropout(0.5),
                layers.Dense(4096, activation='relu'),
                layers.Dropout(0.5),
                layers.Dense(num_classes)])

        def call(self, x):
            x = self.features(x)                    # [B,H,W,C]
            x = tf.transpose(x, perm=[0, 3, 1, 2])  # [B,C,H,W]
            x = self.flatten(x)
            x = self.classifier(x)
            return x

    tf_model = TFVGG(make_layers(cfg))

Pytorch (torchvision):

    import torch
    import torch.nn as nn
    import torchvision

    class VGG(nn.Module):
        def __init__(self, features, num_classes=1000):
            super(VGG, self).__init__()
            self.features = features
            self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes))

        def forward(self, x):
            x = self.features(x)  # [B,C,H,W]
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
            return x

    torch_model = torchvision.models.vgg16(pretrained=True)
    # torchvision >= 0.13: torchvision.models.vgg16(weights="IMAGENET1K_V1")

https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py#L25

The Tensorflow version transposes the features to [B,C,H,W] before Flatten so that the flattened vector has the same element order as torch.flatten in Pytorch; the first Dense kernel can then be copied with a plain transpose(1,0). It omits AdaptiveAvgPool2d((7, 7)) because for 224×224 inputs the feature map is already 7×7, where that pooling is an identity; the port therefore assumes 224×224 inputs.
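Why the transpose before Flatten matters: Flatten and torch.flatten both flatten in row-major order, so the 25088 values come out in (H, W, C) order in Tensorflow but in (C, H, W) order in Pytorch. A NumPy sketch with random feature maps (shapes match VGG16 at 224 × 224; the batch size is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, W, C = 2, 7, 7, 512
feat_nchw = rng.standard_normal((B, C, H, W))   # what the Pytorch features(x) produces
feat_nhwc = feat_nchw.transpose(0, 2, 3, 1)     # what the Tensorflow features(x) produces

flat_torch = feat_nchw.reshape(B, -1)           # torch.flatten(x, 1)
flat_tf_naive = feat_nhwc.reshape(B, -1)        # Flatten() directly on [B,H,W,C]
flat_tf_fixed = feat_nhwc.transpose(0, 3, 1, 2).reshape(B, -1)  # transpose first, as in TFVGG.call

print(np.array_equal(flat_tf_naive, flat_torch))  # False: same values, different order
print(np.array_equal(flat_tf_fixed, flat_torch))  # True
```

An alternative, not used in this deck, would be to keep the features channel last and permute the rows of the first Dense kernel instead.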

Slide 24

Accessing the VGG16 weights

Tensorflow TFVGG(), printed names and shapes:
    conv2d/kernel:0 (3, 3, 3, 64)
    conv2d/bias:0 (64,)
    conv2d_1/kernel:0 (3, 3, 64, 64)
    conv2d_1/bias:0 (64,)
    conv2d_2/kernel:0 (3, 3, 64, 128)
    conv2d_2/bias:0 (128,)
    …
    dense/kernel:0 (25088, 4096)
    dense/bias:0 (4096,)
    dense_1/kernel:0 (4096, 4096)
    dense_1/bias:0 (4096,)
    dense_2/kernel:0 (4096, 1000)
    dense_2/bias:0 (1000,)

Pytorch torchvision.models.vgg16(), printed names and shapes:
    features.0.weight (64, 3, 3, 3)
    features.0.bias (64,)
    features.2.weight (64, 64, 3, 3)
    features.2.bias (64,)
    features.5.weight (128, 64, 3, 3)
    features.5.bias (128,)
    …
    classifier.0.weight (4096, 25088)
    classifier.0.bias (4096,)
    classifier.3.weight (4096, 4096)
    classifier.3.bias (4096,)
    classifier.6.weight (1000, 4096)
    classifier.6.bias (1000,)

The Pytorch indices skip numbers (features.0, features.2, features.5, ...) because ReLU, MaxPool2d and Dropout occupy positions in nn.Sequential but have no parameters. The order of the weights is still the same in both frameworks.

Slide 25

Porting the VGG16 weights: Pytorch torchvision.models.vgg16() → Tensorflow TFVGG() (weight names and shapes as listed on the previous slide)

    import numpy as np

    tf_model(np.zeros((1, 224, 224, 3), dtype=np.float32))  # build the weights first
    torch_params = torch_model.state_dict()
    torch_keys = list(torch_params.keys())
    for layer in tf_model.layers:
        for var in layer.weights:
            torch_key = torch_keys.pop(0)
            torch_param = torch_params[torch_key].numpy()
            if torch_param.ndim == 4:    # Conv2d.weight [out, in, kH, kW] -> [kH, kW, in, out]
                var.assign(torch_param.transpose(2, 3, 1, 0))
            elif torch_param.ndim == 2:  # Linear.weight [out, in] -> [in, out]
                var.assign(torch_param.transpose(1, 0))
            else:                        # biases
                var.assign(torch_param)

Slide 26

Sanity check

Predict on dummy data with both models and compare the outputs.

Tensorflow TFVGG():

    import numpy as np
    import torch

    set_seed()  # user-defined helper (not shown) that fixes the random seeds
    x = np.random.rand(8, 224, 224, 3).astype(np.float32)  # [B,H,W,C]
    out_tf = tf_model(x).numpy()

Pytorch torchvision.models.vgg16():

    torch_model.eval()  # disable Dropout
    set_seed()          # same seed, so the same dummy input
    x = np.random.rand(8, 224, 224, 3).astype(np.float32)
    x = torch.from_numpy(x.transpose(0, 3, 1, 2))  # [B,C,H,W]
    with torch.no_grad():
        out_torch = torch_model(x).numpy()

    diff = np.mean(np.abs(out_tf - out_torch))
    print(f'Mean of Abs Diff: {diff}')
    >> Mean of Abs Diff: 4.78e-07 🎉
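The mean absolute difference can hide a few large mismatches, so a slightly stricter comparison may be useful. A sketch: compare_outputs is a hypothetical helper, and atol=1e-5 is an arbitrary tolerance that may need adjusting for float32 models.

```python
import numpy as np

def compare_outputs(out_a, out_b, atol=1e-5):
    """Summarize how far two model outputs are apart (mean and max abs diff, top-1 agreement)."""
    a = np.asarray(out_a, dtype=np.float64)
    b = np.asarray(out_b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    diff = np.abs(a - b)
    return {
        "mean_abs": float(diff.mean()),
        "max_abs": float(diff.max()),
        "within_atol": bool(np.all(diff <= atol)),
        "same_argmax": bool(np.array_equal(a.argmax(axis=-1), b.argmax(axis=-1))),
    }

# Usage with the arrays from this slide: compare_outputs(out_tf, out_torch)
```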

Slide 27

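Summary
- Basics of model implementation in Tensorflow/Pytorch
- Comparing Tensorflow and Pytorch
- Techniques for converting weights between Tensorflow and Pytorch
- Practice: porting a model between Tensorflow and Pytorch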