Upgrade to PRO for Only $50/Year—Limited-Time Offer! 🔥

実践で使えるtorchのテンソル演算

 実践で使えるtorchのテンソル演算

Transcript

  1. 2 テンソル演算を習得するときのポイント • ⼊⼒のshapeと出⼒のshapeを把握する ◦ mvであれば,[n, m] x [m] →

    [n]など • Notebookなどで適当に遊んでみると感覚的に理解できる ◦ 感覚的に理解すると,数値計算以外の開発においても, 「ここテンソルにしたら効率良くね?」って気づけるようになる • 今回はPyTorchを使うが,Numpyなども8割⽅同じ考えで習得できる
  2. 3 ⽬次 • 積 ◦ dot, mv, mm, bmm, matmul

    • ⽣成 ◦ zeros, ones, randn, arange • 次元操作および形状変更 ◦ t, transpose, reshape, squeeze, unsqueeze, cat, stack, split • index操作 ◦ gather, masked_select • 集約関数 ◦ sum, mean, max, min, std ※ 他にも色々な関数があるので, Pytorchのdocを是非ご覧ください https://docs.pytorch.org/docs/stable/t orch.html
  3. 4 ⽤語の定義 • スカラー ◦ 0次元(単⼀)の値 ex. 1.0 • ベクトル

    ◦ 1次元のテンソル ex. [1.0, 2.0, 3.0] • ⾏列 ◦ 2次元のテンソル ex. [[1.0, 2.0], [3.0, 4.0]]
  4. 5 dot • ベクトルとベクトルの積 ◦ [n] x [n] → []

    (スカラー) • 2次元([1, n] x [1, n])にするとエラー
  5. 9 matmul • dot, mv, mm, bmmを⼀般化したAPI ◦ つまり⼊⼒テンソルの次元数によって操作の意味が異なる •

    以下の例では,10x10がバッチ次元として解釈され, 各バッチに対する[3, 5] x [5, 2]のmmが計算されている
  6. 14 t • ⾏列の転置 • 3次元以上の テンソルだとエラー 転置してもstride(indexの進み⽅)が変わるだけ。 コピーしたい場合は,contiguousやcloneを⾏う。 ※

    torchの転置の具体的な挙動に関して解説してくれているサイトです https://aisinkakura-datascientist.hatenablog.com/entry/2024/06/19/113008
  7. 25 集約関数 - sum, mean, max, min, std • 指定した次元で集約する

    ◦ keepdim=True/Falseで出⼒の形状が変化 x: n次元, dim: [k]の場合, keepdim=Falseなら出力はn-k次元, keepdim=Trueなら出力はn次元