Upgrade to Pro — share decks privately, control downloads, hide ads and more …

CZII - CryoET Object Identification 参加振り返り・解法共有

tattaka
February 14, 2025

CZII - CryoET Object Identification 参加振り返り・解法共有

第3回関東Kagger会で発表したmiddle LT資料です(後で資料内リンク追加します)

tattaka

February 14, 2025
Tweet

More Decks by tattaka

Other Decks in Technology

Transcript

  1. 2025/02/15 第3回 関東Kaggler会 3 Agenda 1. CZII - CryoET Object

    Identification 概要 2. 公開notebookでのアプローチ 3. 自チーム (4th) 解法紹介 4. 他チーム解法紹介
  2. 2025/02/15 第3回 関東Kaggler会 4 CZII - CryoET Object Identification 概要

    CryoET(クライオ電子トモグラフィー)技術を用いて取得された 3D細胞画像内の5種類のタンパク質複合体の中心座標を検出するコンペ CryoET:細胞内のタンパク質を自然な環境下で近原子レベルの詳細で 3D画像化する技術 https://www.biorxiv.org/content/10.1101/2024.11.04.621686v1 予測モデル
  3. 2025/02/15 第3回 関東Kaggler会 5 検出対象となるタンパク質複合体 • apo-ferritin • beta-amylase (検出対象外)

    • beta-galactosidase • ribosome • thyroglobulin • virus-like-particle https://www.rcsb.org/
  4. 2025/02/15 第3回 関東Kaggler会 7 ラベルのための座標変換 1pix = 約10.01244Å (angstrom) (正確にはz:

    10.012444196428572Å, x: 10.012444196428572Å, y: 10.012444537618887Å)なので、 ラベルをピクセルグリッドに変換するためには i = int(x/10.012444 + 1) のように変換する必要がある
  5. 2025/02/15 第3回 関東Kaggler会 8 評価指標 F4 score: PrecisionとRecallのバランスをとるが、 よりRecallが強調される評価指標 →Recallが強調される=見逃しに厳しい

    推定した座標値が各class粒子の半径の1/2内にあればTP • apo-ferritin: 60Å, weight=1 • beta-galactosidase: 90Å, weight=2 • ribosome: 150Å, weight=1 • thyroglobulin: 130Å, weight=2 • virus-like-particle: 135Å, weight=1 TP FP
  6. 2025/02/15 第3回 関東Kaggler会 9 公開notebookでのアプローチ (1/2) • 2D-3DUNet ◦ 座標値をヒートマップ

    or セグメンテーションマスクに変換して学習 ◦ backboneであるresnet内のstage毎にreshape・3D方向のavg poolingを繰り返 し、3DCNNで構成されたUNet headに供給する ◦ 一度に1サンプル全体(184x630x630)を処理することはできないため、 決められたwindowをスライドさせて推論させ出力を統合 ◦ 後処理はconnected-components-3dを用いて 連結成分を特定、重心を推論座標値とする 3dcnn + upsample 3dcnn + upsample 3dcnn + upsample 3dcnn + upsample 3dcnn + upsample reshape & depth pooling 32x64x64 input: (64x128x128) reshape & depth pooling 16x32x32 reshape & depth pooling 8x16x16 reshape & depth pooling 4x8x8 reshape & depth pooling 2x4x4 https://www.kaggle.com/competitions/czii-cryo-et-object-identification/di scussion/545221 cc3d
  7. 2025/02/15 第3回 関東Kaggler会 10 公開notebookでのアプローチ (2/2) • YOLOを用いた物体検出アプローチ ◦ 物体検出の枠組みを使って直接座標値を予測する

    ◦ XY方向のスライス毎に推論し、推論した座標を DFS(深さ優先探索)を用いて近いもの同士で集約 ◦ モデルによってはライセンスが怪しい (ホストに質問しても返事がない)
  8. 2025/02/15 第3回 関東Kaggler会 11 team yu4u & tattakaの解法 (4th) 1/6

    ヒートマップを教師とする2D-3DUNetを用いたアプローチ 中盤以降は推論の高速化・public LBでひたすらチューニング 最終的に7 modelをensemble yu4u’s model tattaka’s model
  9. 2025/02/15 第3回 関東Kaggler会 12 team yu4u & tattakaの解法 (4th) 2/6

    • 教師用ヒートマップの作成 1. 粒子中心が1になるように3d gaussian patchを生成 (yu4uさんはsigma=6、tattakaは粒子の種類毎に変更) 2. 粒子座標を物理座標系を画像座標系に変換 3. それぞれの粒子座標に配置
  10. 2025/02/15 第3回 関東Kaggler会 13 team yu4u & tattakaの解法 (4th) 3/6

    • yu4u’s model ◦ ConvNeXt-nanoをbackboneとして、 stage間でdepth pooling・UNetへ渡すときに3DCNNを経由する ◦ 計算量の削減のため、1/4解像度より上は PixelShuffleを使ってUpsample ◦ MSEベースの損失関数を使用 ▪ 教師ヒートマップの値に 応じた重みを加える
  11. 2025/02/15 第3回 関東Kaggler会 14 team yu4u & tattakaの解法 (4th) 4/6

    • tattaka’s model ◦ backboneはResNetRS-50を使用し、stage1まではdepthを1/2に avg pooling、以降はkernel_size=3, padding=1でpooling ◦ stem・stage1のfeature mapは解像度が大きく計算コストが 高いため使用しない ◦ 中間特徴からaux lossを 計算し正則化 ◦ 正例と負例のバランスをとった MSE Lossを使って学習
  12. 2025/02/15 第3回 関東Kaggler会 15 Tips: timmを用いたdepth pooling modelの効率的な実装 • 公開notebookでは、resnetにのみ適用できる実装になっている

    • 実はtimm.create_model(feature_only=True)で作成されたモデルは FeatureListNetをラップして作成されるため、FeatureListNetを継承した クラスを作るとstage間に任意の処理を実装することができる backbone内のlayerを順番に処理 stage毎にdictに出力を保存 https://github.com/huggingface/pytorch-ima ge-models/blob/main/timm/models/_features .py#L230
  13. 2025/02/15 第3回 関東Kaggler会 16 team yu4u & tattakaの解法 (4th) 5/6

    • 推論の高速化 ◦ 初期の推論notebook: P100 notebookでfp32推論 ▪ 後処理(後述)をGPU上で行うことで速度1.5倍 ▪ T4x2 notebook・fp16推論に変更して速度1.5倍 ▪ multi threading(tattaka)・DataParallel(yu4u)に変更して2倍 ▪ multi processing化・modelをtensorrt変換して1.4倍 ◦ 最終的には初期notebookより約6倍高速化
  14. 2025/02/15 第3回 関東Kaggler会 17 team yu4u & tattakaの解法 (4th) 6/6

    • 後処理 1. 16x128x128(yu4u)・32x128x128(tattaka)の入力解像度で推論、 xy方向は約0.4のoverlap、z方向は0.5のoverlapでスライド・ flipを使ったTTA ◦ 出力の中心に近い予測値を信用するようにweightを調整 2. MaxPooling(kernel_size=7)を用いてピーク検出 3. 粒子の種類によって異なる閾値を適用してフィルタリング • cc3dを使用するよりも高速 4. 画像座標系から物理座標系へ変換
  15. 2025/02/15 第3回 関東Kaggler会 18 team yu4u & tattakaの進め方 • 2024.12

    第1週: チームマージ・yu4uさんがヒートマップベースモデルのbaseline作成 • 2024.12 第2週: tattakaモデルのbaseline完成 • 2024.12 第3週: yu4uさんモデルのfirst submit(LB: 0.71, 11h) • 2025.1 第1週: tattakaモデルのfirst submit (LB: 0.63)、metricを勘違いしていたことに気づく • 2025.1 第2週: ◦ あまりにも推論が遅いので高速化・ensembleが可能に ◦ yu4uさんモデルのみのensembleで金圏へ ◦ tattakaモデルを加えるもスコア改善ならず (閾値の傾向がかなり異なることが原因だと考え、tattakaもMSEベースのlossに切り替え) • 2025.1 第4週: ◦ tensorrtを使った高速化 ◦ 良いbackboneがないか探索 ◦ CVが当てにならないのでひたすらLBでチューニング ▪ 閾値・TTAや推論箇所の刻み幅など、推論時のパラメータを調整 ◦ yu4uさんモデル+tattakaモデルでほぼ最終スコアに
  16. 2025/02/15 第3回 関東Kaggler会 19 他チーム解法紹介 • 1st place solution ◦

    物体検出モデルとセグメンテーションモデルのensemble ◦ 1/2解像度の出力を使うことで高速化 ◦ 傾向の違うモデルをensembleするためのhistgram matching ▪ モデルAとBがある時、AとBの予測値をsortして 同じindexのBの予測値をAの予測値に置き換える • 2nd place solution ◦ 軽量なセグメンテーションモデルのensemble ◦ 正規化層にInstanceNorm3d、活性化関数にPReLUを用いることで学習が安定 • 3rd place solution ◦ resnet101をbackboneとした3D UNet (segmentation_models_pytorch_3d)を使用 ◦ 他のsolutionに比べて大きめのモデル (foldが異なる同じモデルを4つensembleしているのみ) ◦ 学習時よりも推論時の入力サイズを大きくする
  17. 2025/02/15 第3回 関東Kaggler会 20 まとめ • 推論しなければいけないサンプル数が多いため、効率的なモデル作成・ 推論の高速化を行いensembleをうまく作るのが上位入賞の鍵? • train

    set 7件でTrust CVするよりも約120件のpublic LBを 信じた方がよかった • 1回目のsubmitは早めに作りましょう...... ◦ 実行環境とKaggle環境の推論時間ギャップやLB/CVギャップ、バグや 勘違いなど色々な気づきがある