An Introduction to Tensorflow/Pytorch Model Porting

rishigami
August 17, 2021

This talk covers:
1. The basics of model implementation in Tensorflow/Pytorch
2. A comparison of Tensorflow and Pytorch
3. Techniques for converting weights between Tensorflow and Pytorch
4. Tensorflow/Pytorch model porting in practice

3rd Data Analysis Competition LT Meetup (第3回分析コンペLT会)
https://kaggle-friends.connpass.com/event/220927/

Transcript

  1. An Introduction to Tensorflow/Pytorch Model Porting
    2021-08-17
    3rd Data Analysis Competition LT Meetup (第3回分析コンペLT会)
    presented by @rishigami_

  2. \This may be sudden, but/
    Do any of these problems sound familiar?

  3. Framework-related troubles
    You found a promising paper or notebook, but... the implementation only
    exists in tf/torch.
    You can write the implementation yourself, but... the weights only exist
    in tf/torch.
    (Background: github, notebooks, and the rise of large-scale pretrained
    models.)
    → The goal: to be able to solve these problems yourself!

  4. Today's agenda
    1. The basics of model implementation in Tensorflow/Pytorch
    2. A comparison of Tensorflow and Pytorch
    3. Techniques for converting weights between Tensorflow and Pytorch
    4. Tensorflow/Pytorch model porting in practice

  5. Model definition in Tensorflow
    Tensorflow (Sequential API):

    model = tf.keras.Sequential()
    model.add(layers.Flatten())
    model.add(layers.Dense(512))
    model.add(layers.ReLU())
    model.add(layers.Dense(256))
    model.add(layers.ReLU())
    model.add(layers.Dense(10))
    model.add(layers.Softmax())

    Uses the Sequential constructor; layers are added one at a time.

    Tensorflow (Functional API):

    inputs = Input(shape=(28, 28))
    x = layers.Flatten()(inputs)
    x = layers.Dense(512)(x)
    x = layers.ReLU()(x)
    x = layers.Dense(256)(x)
    x = layers.ReLU()(x)
    x = layers.Dense(10)(x)
    outputs = layers.Softmax()(x)
    model = Model(inputs=inputs, outputs=outputs)

    Uses the Model constructor; you specify the inputs and outputs.

  6. Model definition in Tensorflow/Pytorch
    Tensorflow (Subclassing API):

    class TFCustomModel(tf.keras.Model):
        def __init__(self):
            super(TFCustomModel, self).__init__()
            self.flatten = layers.Flatten()
            self.dense1 = layers.Dense(512)
            self.dense2 = layers.Dense(256)
            self.dense3 = layers.Dense(10)

        def call(self, x):
            x = self.flatten(x)
            x = tf.nn.relu(self.dense1(x))
            x = tf.nn.relu(self.dense2(x))
            x = tf.nn.softmax(self.dense3(x))
            return x

    tf_model = TFCustomModel()

    Subclasses Model; a style close to Pytorch.

    Pytorch (Subclass):

    class CustomModel(nn.Module):
        def __init__(self):
            super(CustomModel, self).__init__()
            self.dense1 = nn.Linear(784, 512)
            self.dense2 = nn.Linear(512, 256)
            self.dense3 = nn.Linear(256, 10)

        def forward(self, x):
            x = x.view(-1, 784)
            x = F.relu(self.dense1(x))
            x = F.relu(self.dense2(x))
            x = F.softmax(self.dense3(x), dim=1)
            return x

    torch_model = CustomModel()

    Subclasses Module (a Sequential style is also possible).

  7. The main differences between Tensorflow and Pytorch
    1. channel last / channel first
       Tensorflow: [B, H, W, C]   Pytorch: [B, C, H, W]
       → The channel dimension sits in a different position.
    2. Whether in_features is required
       Tensorflow: layers.Dense(10)   Pytorch: nn.Linear(256, 10)
       → Pytorch requires in_features.
    3. Hyperparameter differences
       Tensorflow: layers.BatchNormalization(momentum=0.99, epsilon=0.001)
       Pytorch:    nn.BatchNorm2d(momentum=0.1, eps=1e-05)
       → Depending on the layer, the default hyperparameters differ.
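
    As a supplement (my own illustration, not from the original slides), a
    minimal sketch of the first two differences; the shapes are example
    values I chose:

    import numpy as np
    import torch
    import torch.nn as nn
    from tensorflow.keras import layers

    # 1. channel order: the same batch of 3-channel 32x32 images
    x_tf = np.random.rand(8, 32, 32, 3).astype(np.float32)  # [B, H, W, C]
    x_pt = torch.from_numpy(x_tf.transpose(0, 3, 1, 2))     # [B, C, H, W]

    # 2. in_features: Keras infers the input size on the first call,
    #    while Pytorch needs it in the constructor
    dense_tf = layers.Dense(10)     # input size inferred later
    linear_pt = nn.Linear(256, 10)  # in_features=256 given up front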

  8. Layer-name differences between Tensorflow and Pytorch
    Tensorflow                   Pytorch
    layers.Dense                 nn.Linear
    layers.Conv1D                nn.Conv1d
    layers.Conv2D                nn.Conv2d
    layers.Conv2DTranspose       nn.ConvTranspose2d
    layers.SimpleRNN             nn.RNN
    layers.LSTM                  nn.LSTM
    layers.BatchNormalization    nn.BatchNorm1d/2d
    layers.LayerNormalization    nn.LayerNorm

  9. Accessing the weights
    Tensorflow (tf_model):

    for layer in tf_model.layers:
        for var in layer.weights:
            print(var.name, var.numpy().shape)

    Printed names and shapes:
    dense/kernel:0 (784, 512)
    dense/bias:0 (512,)
    dense_1/kernel:0 (512, 256)
    dense_1/bias:0 (256,)
    dense_2/kernel:0 (256, 10)
    dense_2/bias:0 (10,)

    Pytorch (torch_model):

    for name, param in torch_model.named_parameters():
        print(name, param.detach().numpy().shape)

    Printed names and shapes:
    dense1.weight (512, 784)
    dense1.bias (512,)
    dense2.weight (256, 512)
    dense2.bias (256,)
    dense3.weight (10, 256)
    dense3.bias (10,)

    Note that the Dense kernels are stored as (in, out) while the Linear
    weights are (out, in); this is why the transposes appear in the
    conversion code on the next slides.

  10. Example: Tensorflow to Pytorch
    Collect the Tensorflow names and weights in order, then write them into
    the Pytorch state_dict, transposing the Dense kernels. This relies on
    both models declaring their layers in the same order:

    tf_names, tf_params = [], []
    for layer in tf_model.layers:
        for var in layer.weights:
            tf_names.append(var.name)
            tf_params.append(var.numpy())

    torch_params = torch_model.state_dict()
    for idx, key in enumerate(torch_params):
        if "kernel" in tf_names[idx]:
            torch_params[key].data = torch.tensor(tf_params[idx].transpose(1, 0))
        else:
            torch_params[key].data = torch.tensor(tf_params[idx])
    torch_model.load_state_dict(torch_params)

    Tensorflow (tf_model):          Pytorch (torch_model):
    dense/kernel:0 (784, 512)       dense1.weight (512, 784)
    dense/bias:0 (512,)             dense1.bias (512,)
    dense_1/kernel:0 (512, 256)     dense2.weight (256, 512)
    dense_1/bias:0 (256,)           dense2.bias (256,)
    dense_2/kernel:0 (256, 10)      dense3.weight (10, 256)
    dense_2/bias:0 (10,)            dense3.bias (10,)

  11. Example: Pytorch to Tensorflow
    Walk the Tensorflow variables in order, popping the matching Pytorch
    keys and assigning the (transposed) values:

    torch_params = torch_model.state_dict()
    torch_keys = list(torch_params.keys())
    for layer in tf_model.layers:
        for var in layer.weights:
            torch_key = torch_keys.pop(0)
            torch_param = torch_params[torch_key].numpy()
            if "kernel" in var.name:
                var.assign(torch_param.transpose(1, 0))
            else:
                var.assign(torch_param)

    (The name/shape listings are the same as on the previous slide.)

  12. layers.Dense / nn.Linear
    Tensorflow layers.Dense()    Pytorch nn.Linear()
    kernel <-> weight   (transpose(1, 0) in either direction)
    bias   <-> bias     (copied as-is)
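
    A minimal sketch of this mapping; the layer sizes are example values I
    chose:

    import torch
    import torch.nn as nn
    from tensorflow.keras import layers

    dense = layers.Dense(10)
    dense.build((None, 256))     # kernel: (256, 10), bias: (10,)
    linear = nn.Linear(256, 10)  # weight: (10, 256), bias: (10,)

    kernel, bias = [v.numpy() for v in dense.weights]
    with torch.no_grad():
        linear.weight.copy_(torch.from_numpy(kernel.transpose(1, 0)))
        linear.bias.copy_(torch.from_numpy(bias))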

  13. layers.Conv1D / nn.Conv1d
    Tensorflow layers.Conv1D()   Pytorch nn.Conv1d()
    kernel <-> weight   (transpose(2, 1, 0) in either direction)
    bias   <-> bias     (copied as-is)
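
    The same idea for 1-D convolutions, as a sketch with example sizes: the
    Keras kernel is (kernel_size, in_channels, out_channels), while the
    Pytorch weight is (out_channels, in_channels, kernel_size):

    import torch
    import torch.nn as nn
    from tensorflow.keras import layers

    conv_tf = layers.Conv1D(16, kernel_size=3)
    conv_tf.build((None, None, 8))             # kernel: (3, 8, 16), bias: (16,)
    conv_pt = nn.Conv1d(8, 16, kernel_size=3)  # weight: (16, 8, 3),  bias: (16,)

    kernel, bias = [v.numpy() for v in conv_tf.weights]
    with torch.no_grad():
        conv_pt.weight.copy_(torch.from_numpy(kernel.transpose(2, 1, 0)))
        conv_pt.bias.copy_(torch.from_numpy(bias))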

  14. layers.Conv2D / nn.Conv2d
    Tensorflow layers.Conv2D()   Pytorch nn.Conv2d()
    kernel <-> weight   (Tensorflow → Pytorch: transpose(3, 2, 0, 1);
                         Pytorch → Tensorflow: transpose(2, 3, 1, 0))
    bias   <-> bias     (copied as-is)
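
    A sketch covering both directions for 2-D convolutions (sizes are
    example values): the Keras kernel is (kh, kw, in, out), the Pytorch
    weight is (out, in, kh, kw):

    import torch
    import torch.nn as nn
    from tensorflow.keras import layers

    conv_tf = layers.Conv2D(16, kernel_size=3)
    conv_tf.build((None, None, None, 8))       # kernel: (3, 3, 8, 16)
    conv_pt = nn.Conv2d(8, 16, kernel_size=3)  # weight: (16, 8, 3, 3)

    # Tensorflow → Pytorch
    kernel, bias = [v.numpy() for v in conv_tf.weights]
    with torch.no_grad():
        conv_pt.weight.copy_(torch.from_numpy(kernel.transpose(3, 2, 0, 1)))
        conv_pt.bias.copy_(torch.from_numpy(bias))

    # Pytorch → Tensorflow
    conv_tf.weights[0].assign(conv_pt.weight.detach().numpy().transpose(2, 3, 1, 0))
    conv_tf.weights[1].assign(conv_pt.bias.detach().numpy())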

  15. layers.Conv2DTranspose / nn.ConvTranspose2d
    Tensorflow layers.Conv2DTranspose()   Pytorch nn.ConvTranspose2d()
    kernel <-> weight   (Tensorflow → Pytorch: transpose(3, 2, 0, 1);
                         Pytorch → Tensorflow: transpose(2, 3, 1, 0))
    bias   <-> bias     (copied as-is)
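
    Transposed convolutions reuse the same permutations because both layouts
    swap in/out relative to a plain convolution: the Keras kernel is
    (kh, kw, out, in) and the Pytorch weight is (in, out, kh, kw). A sketch
    with example sizes:

    import torch
    import torch.nn as nn
    from tensorflow.keras import layers

    deconv_tf = layers.Conv2DTranspose(16, kernel_size=3)
    deconv_tf.build((None, None, None, 8))                # kernel: (3, 3, 16, 8)
    deconv_pt = nn.ConvTranspose2d(8, 16, kernel_size=3)  # weight: (8, 16, 3, 3)

    kernel, bias = [v.numpy() for v in deconv_tf.weights]
    with torch.no_grad():
        deconv_pt.weight.copy_(torch.from_numpy(kernel.transpose(3, 2, 0, 1)))
        deconv_pt.bias.copy_(torch.from_numpy(bias))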

  16. layers.SimpleRNN / nn.RNN
    Tensorflow layers.SimpleRNN()   Pytorch nn.RNN()
    kernel           <-> weight_ih_l0   (transpose(1, 0) in either direction)
    recurrent_kernel <-> weight_hh_l0   (transpose(1, 0) in either direction)
    bias             <-> bias_ih_l0, bias_hh_l0
      (Tensorflow → Pytorch: copy bias into bias_ih_l0 and fill bias_hh_l0
       with zeros_like; Pytorch → Tensorflow: sum the two biases)
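
    A sketch of the RNN case (sizes are example values). Pytorch keeps two
    bias vectors where Keras keeps one, so going Tensorflow → Pytorch the
    second bias is filled with zeros:

    import numpy as np
    import torch
    import torch.nn as nn
    from tensorflow.keras import layers

    rnn_tf = layers.SimpleRNN(32)
    rnn_tf.build((None, None, 8))
    kernel, recurrent_kernel, bias = [v.numpy() for v in rnn_tf.weights]

    rnn_pt = nn.RNN(input_size=8, hidden_size=32, batch_first=True)
    with torch.no_grad():
        rnn_pt.weight_ih_l0.copy_(torch.from_numpy(kernel.transpose(1, 0)))
        rnn_pt.weight_hh_l0.copy_(torch.from_numpy(recurrent_kernel.transpose(1, 0)))
        rnn_pt.bias_ih_l0.copy_(torch.from_numpy(bias))
        rnn_pt.bias_hh_l0.copy_(torch.from_numpy(np.zeros_like(bias)))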

  17. layers.LSTM / nn.LSTM
    Tensorflow layers.LSTM()   Pytorch nn.LSTM()
    kernel           <-> weight_ih_l0   (transpose(1, 0) in either direction)
    recurrent_kernel <-> weight_hh_l0   (transpose(1, 0) in either direction)
    bias             <-> bias_ih_l0, bias_hh_l0
      (Tensorflow → Pytorch: copy bias into bias_ih_l0 and fill bias_hh_l0
       with zeros_like; Pytorch → Tensorflow: sum the two biases)
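
    For LSTM, a sketch of the reverse direction (sizes are example values):
    the two Pytorch biases are summed into the single Keras bias.
    Conveniently, both frameworks order the gates as input, forget, cell,
    output, so no gate reshuffling should be needed:

    import torch.nn as nn
    from tensorflow.keras import layers

    lstm_pt = nn.LSTM(input_size=8, hidden_size=32, batch_first=True)
    lstm_tf = layers.LSTM(32)
    lstm_tf.build((None, None, 8))

    sd = lstm_pt.state_dict()
    lstm_tf.weights[0].assign(sd["weight_ih_l0"].numpy().transpose(1, 0))     # kernel
    lstm_tf.weights[1].assign(sd["weight_hh_l0"].numpy().transpose(1, 0))     # recurrent_kernel
    lstm_tf.weights[2].assign((sd["bias_ih_l0"] + sd["bias_hh_l0"]).numpy())  # bias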

  18. layers.BatchNormalization / nn.BatchNorm1d,2d
    Tensorflow layers.BatchNormalization()   Pytorch nn.BatchNorm1d/2d()
    gamma           <-> weight
    beta            <-> bias
    moving_mean     <-> running_mean
    moving_variance <-> running_var
    (all copied as-is; no transposes needed)
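
    One practical note (my addition): running_mean/running_var are buffers in
    Pytorch, so they appear in state_dict() but not in named_parameters();
    a conversion loop based only on named_parameters() would silently skip
    them. A sketch with example sizes:

    import torch
    import torch.nn as nn
    from tensorflow.keras import layers

    bn_tf = layers.BatchNormalization()
    bn_tf.build((None, 32, 32, 16))
    gamma, beta, moving_mean, moving_variance = [v.numpy() for v in bn_tf.weights]

    bn_pt = nn.BatchNorm2d(16)
    with torch.no_grad():
        bn_pt.weight.copy_(torch.from_numpy(gamma))
        bn_pt.bias.copy_(torch.from_numpy(beta))
        bn_pt.running_mean.copy_(torch.from_numpy(moving_mean))
        bn_pt.running_var.copy_(torch.from_numpy(moving_variance))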

  19. layers.LayerNormalization / nn.LayerNorm
    Tensorflow layers.LayerNormalization()   Pytorch nn.LayerNorm()
    gamma <-> weight   (copied as-is)
    beta  <-> bias     (copied as-is)

    Supplement on hyperparameters: the defaults differ, so align them
    explicitly.
    Tensorflow: layers.BatchNormalization(momentum=0.99, epsilon=0.001)
    Pytorch:    nn.BatchNorm2d(momentum=0.1, eps=1e-05)
    Tensorflow: layers.LayerNormalization(epsilon=0.001)
    Pytorch:    nn.LayerNorm(eps=1e-05)
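
    A sketch of aligning the hyperparameters by hand (feature sizes are
    example values). Note that the momentum conventions are opposite: Keras
    weights the old moving average while Pytorch weights the new batch
    statistics, so Keras momentum=0.99 corresponds to Pytorch momentum=0.01:

    import torch.nn as nn
    from tensorflow.keras import layers

    ln_tf = layers.LayerNormalization(epsilon=1e-5)
    ln_pt = nn.LayerNorm(256, eps=1e-5)

    bn_tf = layers.BatchNormalization(momentum=0.99, epsilon=1e-5)
    bn_pt = nn.BatchNorm2d(16, momentum=0.01, eps=1e-5)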

  20. Porting VGG16
    Approach: implement the model ourselves in Tensorflow, then port the
    weights.

    Pytorch (torchvision reference implementation,
    https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py):

    class VGG(nn.Module):
        def __init__(self, features, num_classes=1000):
            super(VGG, self).__init__()
            self.features = features
            self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes))

        def forward(self, x):
            x = self.features(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
            return x

    Tensorflow: ? (this is what we will build)

  21. About VGG16
    [Figure: the VGG16 model architecture]
    https://neurohive.io/en/popular-networks/vgg16/
    https://kgptalkie.com/image-classification-using-pre-trained-vgg-16-model/

  22. Porting make_layers()
    Pytorch (torchvision original,
    https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py#L69):

    cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
           512, 512, 512, 'M', 512, 512, 512, 'M']

    def make_layers(cfg):
        layers = []
        in_channels = 3
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
                layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    Tensorflow (our own implementation):

    cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
           512, 512, 512, 'M', 512, 512, 512, 'M']

    def make_layers(cfg):
        tflayers = []
        for v in cfg:
            if v == 'M':
                tflayers += [layers.MaxPool2D(strides=2)]
            else:
                tflayers += [layers.Conv2D(v,
                                           kernel_size=3,
                                           padding='same',
                                           activation='relu')]
        return tf.keras.Sequential(tflayers)

  23. Porting VGG16()
    Pytorch (torchvision original,
    https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py#L25):

    class VGG(nn.Module):
        def __init__(self, features, num_classes=1000):
            super(VGG, self).__init__()
            self.features = features
            self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes))

        def forward(self, x):
            x = self.features(x)  # [B,C,H,W]
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
            return x

    torch_model = torchvision.models.vgg16(pretrained=True)

    Tensorflow (our own implementation):

    class TFVGG(tf.keras.Model):
        def __init__(self, features, num_classes=1000):
            super(TFVGG, self).__init__()
            self.features = features
            self.flatten = layers.Flatten()
            self.classifier = tf.keras.Sequential([
                layers.Dense(4096, activation='relu'),
                layers.Dropout(0.5),
                layers.Dense(4096, activation='relu'),
                layers.Dropout(0.5),
                layers.Dense(num_classes)])

        def call(self, x):
            x = self.features(x)                    # [B,H,W,C]
            x = tf.transpose(x, perm=[0, 3, 1, 2])  # [B,C,H,W]
            x = self.flatten(x)
            x = self.classifier(x)
            return x

    tf_model = TFVGG(make_layers(cfg))

    The transpose to [B,C,H,W] before Flatten makes the flattened feature
    order match torch.flatten on Pytorch's channel-first tensor, so the
    classifier weights can be reused after only the usual Dense transpose.
    (The avgpool is omitted on the Tensorflow side: for 224x224 inputs the
    feature map is already 7x7, so AdaptiveAvgPool2d((7, 7)) is a no-op.)

  24. Accessing the VGG16 weights
    Printed names and shapes, side by side:

    Tensorflow TFVGG():                 Pytorch torchvision.models.vgg16():
    conv2d/kernel:0 (3, 3, 3, 64)       features.0.weight (64, 3, 3, 3)
    conv2d/bias:0 (64,)                 features.0.bias (64,)
    conv2d_1/kernel:0 (3, 3, 64, 64)    features.2.weight (64, 64, 3, 3)
    conv2d_1/bias:0 (64,)               features.2.bias (64,)
    conv2d_2/kernel:0 (3, 3, 64, 128)   features.5.weight (128, 64, 3, 3)
    conv2d_2/bias:0 (128,)              features.5.bias (128,)

    dense/kernel:0 (25088, 4096)        classifier.0.weight (4096, 25088)
    dense/bias:0 (4096,)                classifier.0.bias (4096,)
    dense_1/kernel:0 (4096, 4096)       classifier.3.weight (4096, 4096)
    dense_1/bias:0 (4096,)              classifier.3.bias (4096,)
    dense_2/kernel:0 (4096, 1000)       classifier.6.weight (1000, 4096)
    dense_2/bias:0 (1000,)              classifier.6.bias (1000,)

  25. Porting the VGG16 weights
    From TFVGG() to torchvision.models.vgg16(): the required transpose can
    be chosen from the rank of each tensor (4-D: Conv2d kernel, 2-D: Linear
    weight, 1-D: bias):

    torch_params = torch_model.state_dict()
    torch_keys = list(torch_params.keys())
    for layer in tf_model.layers:
        for var in layer.weights:
            torch_key = torch_keys.pop(0)
            torch_param = torch_params[torch_key].numpy()
            if len(torch_param.shape) == 4:    # Conv2d.weight
                var.assign(torch_param.transpose(2, 3, 1, 0))
            elif len(torch_param.shape) == 2:  # Linear.weight
                var.assign(torch_param.transpose(1, 0))
            else:                              # biases
                var.assign(torch_param)

    Tensorflow TFVGG():                 Pytorch torchvision.models.vgg16():
    conv2d/kernel:0 (3, 3, 3, 64)       features.0.weight (64, 3, 3, 3)
    conv2d/bias:0 (64,)                 features.0.bias (64,)
    conv2d_1/kernel:0 (3, 3, 64, 64)    features.2.weight (64, 64, 3, 3)
    conv2d_1/bias:0 (64,)               features.2.bias (64,)
    conv2d_2/kernel:0 (3, 3, 64, 128)   features.5.weight (128, 64, 3, 3)
    conv2d_2/bias:0 (128,)              features.5.bias (128,)

    dense/kernel:0 (25088, 4096)        classifier.0.weight (4096, 25088)
    dense/bias:0 (4096,)                classifier.0.bias (4096,)
    dense_1/kernel:0 (4096, 4096)       classifier.3.weight (4096, 4096)
    dense_1/bias:0 (4096,)              classifier.3.bias (4096,)
    dense_2/kernel:0 (4096, 1000)       classifier.6.weight (1000, 4096)
    dense_2/bias:0 (1000,)              classifier.6.bias (1000,)

  26. Sanity check
    Prediction on dummy data, TFVGG() vs torchvision.models.vgg16():

    Tensorflow:
    set_seed()
    x = np.random.rand(8, 224, 224, 3)
    out_tf = tf_model(x).numpy()

    Pytorch:
    torch_model.eval()
    set_seed()
    x = np.random.rand(8, 224, 224, 3)
    x = torch.from_numpy(x.transpose(0, 3, 1, 2)).float()
    out_torch = torch_model(x).detach().numpy()

    diff = np.mean(np.abs(out_tf - out_torch))
    print(f'Mean of Abs Diff: {diff}')
    >> Mean of Abs Diff: 4.78e-07 🎉
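
    The slides do not show set_seed(); a minimal sketch of what such a
    helper presumably does (my assumption, not the author's code). The key
    points are that both branches draw the same dummy batch, and that
    torch_model.eval() disables Dropout so the two outputs can match:

    import random
    import numpy as np

    def set_seed(seed: int = 42):
        # Hypothetical helper: seed the RNGs that generate the dummy input.
        random.seed(seed)
        np.random.seed(seed)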

  27. Summary
    1. The basics of model implementation in Tensorflow/Pytorch
    2. A comparison of Tensorflow and Pytorch
    3. Techniques for converting weights between Tensorflow and Pytorch
    4. Tensorflow/Pytorch model porting in practice
