# Model note (translated from the original Japanese):
# Judging from https://github.com/TKQXX/BVSA, this appears to be the model used
# there (unconfirmed). Pipeline: flatten the input signal -> linear layer ->
# residual connection -> layer normalization. (The original note said
# "smoothing", but the code flattens/reshapes the input; it does not smooth it.)
# [22] There was no time to draw a diagram for this model.


class ResidualAdd(nn.Module):
    """Wrap a module ``f`` so that ``forward(x)`` returns ``x + f(x)``.

    ``f`` must return a tensor that can be added to its input
    (same or broadcast-compatible shape).
    """

    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        return x + self.f(x)


def _build_proj_block(in_dim, out_dim, drop_rate):
    """Build one projection head.

    Layout: Linear(in_dim -> out_dim)
            -> x + Dropout(Linear(GELU(x)))   (residual branch, width out_dim)
            -> LayerNorm(out_dim)

    Args:
        in_dim: Number of input features per sample.
        out_dim: Number of output features (also the residual-branch width).
        drop_rate: Dropout probability used inside the residual branch.

    Returns:
        An ``nn.Sequential`` mapping ``(batch, in_dim)`` to ``(batch, out_dim)``.
    """
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        ResidualAdd(nn.Sequential(
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
            nn.Dropout(drop_rate),
        )),
        nn.LayerNorm(out_dim),
    )


class EEGProject(nn.Module):
    """Project a flattened signal into two embeddings ("txt" and "img" heads).

    Each sample is flattened to ``c_num * (timesteps[1] - timesteps[0])``
    features. Two independent projection heads built by
    ``_build_proj_block`` then map it to ``z_dim`` features each.

    Args:
        z_dim: Output embedding size of each head.
        c_num: Number of channels per sample.
        timesteps: Indexable pair; ``timesteps[1] - timesteps[0]`` is the
            number of time steps per channel.
        drop_proj: Dropout probability inside each head's residual branch.
    """

    def __init__(self, z_dim, c_num, timesteps, drop_proj=0.3):
        super().__init__()
        self.input_dim = c_num * (timesteps[1] - timesteps[0])
        self.model_txt = _build_proj_block(self.input_dim, z_dim, drop_proj)
        self.model_img = _build_proj_block(self.input_dim, z_dim, drop_proj)
        # Learnable scalar initialised to log(1 / 0.07). It is not used in
        # forward(); presumably it is consumed by an external contrastive loss.
        # TODO confirm.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # Not used inside this class; kept because external code may read it.
        self.softplus = nn.Softplus()

    def forward(self, x, training=None):
        """Embed ``x`` with both heads.

        Args:
            x: Tensor whose first dim is the batch; the remaining dims are
                flattened and must total ``self.input_dim`` elements
                (presumably ``(batch, c_num, T)`` -- TODO confirm).
                Non-contiguous tensors are accepted.
            training: If truthy, return the raw ``(batch, z_dim)`` embeddings.
                If falsy, tile each embedding 3x along dim 1 to
                ``(batch, 3 * z_dim)``. ``None`` (the default) uses
                ``self.training``.

        Returns:
            Tuple ``(x_txt, x_img)``.
        """
        if training is None:
            training = self.training
        # reshape (unlike view) also works on non-contiguous inputs and still
        # returns a view when the memory layout allows it.
        x = x.reshape(x.shape[0], -1)
        x_txt = self.model_txt(x)
        x_img = self.model_img(x)
        if training:
            return x_txt, x_img
        # Tiles each embedding three times; the reason for 3 is not visible
        # here (presumably it matches a downstream feature size).
        return x_txt.repeat(1, 3), x_img.repeat(1, 3)