Slide 9
Slide 9 text
Spectral Normの実装。BigGANより抜粋
def l2normalize(v, eps=1e-4):
    """Return ``v`` rescaled to (approximately) unit norm.

    Args:
        v: Object exposing ``.norm()`` and supporting division, e.g. a
            ``torch.Tensor`` (for which ``.norm()`` defaults to the 2-norm).
        eps: Small constant added to the norm so that an all-zero ``v``
            does not cause a division by zero.

    Returns:
        ``v / (v.norm() + eps)``.
    """
    return v / (v.norm() + eps)
class SpectralNorm(nn.Module):
    """Spectral-normalization wrapper around a module's weight (abridged).

    The parameter ``self.name`` of ``self.module`` is divided by an estimate
    of its largest singular value, obtained by power iteration.

    State is kept on the wrapped module under three parameters:
    ``<name>_bar`` (the raw, trainable weight), ``<name>_u`` and
    ``<name>_v`` (the running left/right singular-vector estimates).

    NOTE: this is an excerpt. The constructor (which has to set
    ``self.module``, ``self.name`` and ``self.power_iterations`` and call
    ``_make_params`` once) and ``forward`` (which has to call
    ``_update_u_v`` before delegating to ``self.module``) are not shown.
    """

    def _update_u_v(self):
        """Refine ``u``/``v`` by power iteration and publish the scaled weight.

        Writes ``w_bar / sigma`` to ``self.module.<name>``, where ``sigma``
        is the current estimate of the largest singular value of ``w_bar``
        viewed as a ``(height, width)`` matrix.
        """
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        _w = w.view(height, -1)  # flatten to a (height, width) matrix

        # Update through .data so the estimates persist across calls and the
        # iteration itself is not recorded by autograd.
        for _ in range(self.power_iterations):  # power iteration
            v.data = l2normalize(torch.matmul(_w.data.t(), u.data))  # (width,)
            u.data = l2normalize(torch.matmul(_w.data, v.data))  # (height,)

        sigma = u.dot(_w.mv(v))  # scalar; gradient flows through w only
        setattr(self.module, self.name, w / sigma.expand_as(w))  # normalize

    def _make_params(self):
        """Create ``u``, ``v`` and ``w_bar`` and register them on the module.

        The original parameter ``<name>`` is removed from the module's
        parameter table so that ``_update_u_v`` can assign a plain
        (non-``Parameter``) tensor under that attribute name.
        """
        w = getattr(self.module, self.name)
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        # u lives in the row space (height), v in the column space (width).
        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)