Slide 14
Slide 14 text
© LY Corporation
実験条件を fixture に明示
● seed, dataset, model ...
pytest により以下が容易に
● 実験環境準備
○ fix_seed によるシード値固定
● 依存注入
○ 実験したいパラメータの比較
● 実験再実行・並列実行
○ test_ 関数を起点とした実行
○ parametrize + pytest-xdist で
テストの並列実行が可能
14
テストを書くことが
研究を再現可能にする
@pytest.fixture
def seed() -> int:
return 19950815 # シード値は小倉唯さんの誕生日
@pytest.fixture(autouse=True)
def fix_seed(seed: int): # シード値を受け取って固定
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
@pytest.fixture
def dataset(transform: transforms.Compose) -> Dataset:
return datasets.MNIST(train=True, transform=transform)
@pytest.fixture
def model(device: torch.device) -> nn.Module:
return nn.Sequential(nn.Flatten(), nn.Linear(28*28, 10))
@pytest.mark.parametrize("lr", [1e-2, 1e-3, 1e-4])
@pytest.mark.parametrize("batch_size", [32, 64])
def test_train_smoke(dataset, model, lr, bs, device): # 適宜型付けして
ね
data_loader = DataLoader(dataset, batch_size=bs, shuffle=True)
optim = torch.optim.Adam(model.parameters(), lr=lr)
for x, y in loader:
x, y = x.to(device), y.to(device)
opt.zero_grad() … # 以降 train loop