Upgrade to Pro — share decks privately, control downloads, hide ads and more …

Wasserstein GANからSpectral Normalizationへ

koshian
September 13, 2019

Wasserstein GANからSpectral Normalizationへ

#7【画像処理 & 機械学習】論文LT会 @LPIXEL

koshian

September 13, 2019
Tweet

More Decks by koshian

Other Decks in Technology

Transcript

  1. 関連する論文  WGAN(無印、Weight Clipping) M. Arjovsky, S. Chintala, L. Bottou.

    Wasserstein GAN. 2017 https://arxiv.org/abs/1701.07875 →非常に理論的な内容で難しい。この実装の改良がWGAN-GP。理論部分に 価値がある  WGAN-GP(Gradient Penalty) I. Gulrajani, F. Ahmed, M. Arjovsky, V. Dumoulin, A. Courville. Improved Training of Wasserstein GANs. 2017 https://arxiv.org/abs/1704.00028 →WGANの改良。WGANよりは実践的な内容で読みやすい。WGANの実験部 分を補完する内容。  Spectral Normalization T. Miyato, T. Kataoka, M. Koyama, Y. Yoshida. Spectral Normalization for Generative Adversarial Networks. 2018. https://arxiv.org/abs/1802.05957 →おすすめ。WGANの前提であるDの損失関数のリプシッツ連続性を Batch Normzalitionの置き換えで実践。著者は日本人(うち2人がPFN)
  2. Wasserstein GAN(WGAN)  GANの訓練を成功させるためには、 → マッピングが滑らかになる必 要がある=連続性の概念が重要  交差エントロピーのようなKL/JSダイバージェンスは非連続な点がある。 これがよくない(弱い距離が必要)

     Dの損失にリプシッツ(Lipschitz)連続という制約を置く。この制約 (Lipschitz constraint)は、GANの安定化に対する重要なキーワード  WGANではDの距離にEM(Earth Mover)距離であるWasserstein距離を採用 JS距離 ←θ=0で非連続 EM距離 連続なのが 望ましい
  3. PyTorchで書くと  通常のGAN:Dの出力=本物/偽物の確率やロジット   WGAN:Dの出力=本物/偽物のWasserstein距離 import torch # 通常のGAN

    loss_func = torch.nn.BCEWithLogitsLoss() real_out = model_D(real_img) # 本物の画像のDの出力(ロジット) fake_out = model_D(fake_img) # 偽物の画像のDの出力(ロジット) ones = torch.ones(batch_size, 1) # すべて1の行列 zeros = torch.zeros(batch_size, 1) # すべて0の行列 loss_G = loss_func(fake_out, ones) # Gの損失 loss_D = loss_func(fake_out, zeros) + loss_func(real_out, ones) # Dの損失 ←数式で書くと… (コードではmaxとminで符号が入れ替わる) maxの下についてるのがリプシッツ連続の制約 # WGAN loss_D = torch.mean(fake_out) - torch.mean(real_out) # 偽物と本物の平均の差 loss_G = -torch.mean(fake_out) # Gの場合は、ただ偽物の出力を符号反転して損失にする
  4. WGANの問題点とWGAN-GP  WGANではリプシッツ連続を満たすために、Weight Clippingをしている →Dの係数を[-c, c](c=0.01のような定数)に頭打ち  実はこれはよくない。著者もひどい(terrible)方法と言っている。他にも学 習率の非常に低いRMSPropを使っていたり学習方法が極端。収束も遅い 

    これを改善したのがWGAN-GP(Gradient Penalty)  モデル構造へのロバスト性がかなり上がった(ResNetでも訓練できた) 一様乱数のεを用意し ↑ 偽物と本物を線形補間した画像を作る ↑ 線形補間した画像の周りのDの勾配 →この勾配ノルムが1になるような ペナルティー(正則化項をつける) WGANのDのことを 通常のGANと区別して 「critic」ともいう
  5. Spectral Normalization  Spectral Normalization:実はWGAN-GPもまだ改良できる →Gradient Penaltyと損失で計2回微分計算するから遅い  リプシッツ連続を満たすために、GPをつけるのではなく、 DのBatch

    Normalizationの置き換えをする。リプシッツ制約を満たすよう なNormalizationを定義する(リプシッツ定数のコントロール)  これにより、WGANの損失関数をやめても安定性は取れる →通常のGANの交差エントロピーやHinge lossで良い  Spectral Normのやっていること:係数の特異値分解(SVD)  ただし、愚直にSVDすると遅いので、Power Iterationというアルゴリズム で特異値の近似値を高速に計算
  6.  Spectral Normの実装。BigGANより抜粋 def l2normalize(v, eps=1e-4): return v / (v.norm()

    + eps) class SpectralNorm(nn.Module): # アバウトなイメージ(抜粋なので正常に動作するコードではない) def _update_u_v(self): height = w.data.shape[0] _w = w.view(height, -1) # 行列にする (height, -1) for _ in range(self.power_iterations): # power iteration v = l2normalize(torch.matmul(_w.t(), u)) # (-1, ) u = l2normalize(torch.matmul(_w, v)) # (height, ) sigma = u.dot((_w).mv(v)) # スカラー setattr(self.module, self.name, w / sigma.expand_as(w)) # normalize def _make_params(self): u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) v = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) u.data = l2normalize(u.data) v.data = l2normalize(v.data)
  7. まとめ  GANの安定性を考える上で、(特に)Dのリプシッツ連続が 重要  WGANやWGAN-GPはリプシッツ連続の制約をおいたGAN →リプシッツ連続が重要であって、Dの損失関数を Wasserstein距離にするのが唯一の手法ではない  DのBatch

    NormをSpectral Normに置き換えれば、リプシッツ 連続の制約はおける  結論:Spectral Normを使おう (torch.nn.utils.spectral_normやBigGANの実装にある)  ※宣伝:PyTorchで実装したよ https://github.com/koshian2/SNGAN 記事1(自分のブログ) 記事2(Qiita)