Upgrade to Pro — share decks privately, control downloads, hide ads and more …

numpyやPyTorchの配列にdtypeとshapeをアノテーションするjaxtyping...

Avatar for Kohei Iwamasa Kohei Iwamasa
November 20, 2025
840

 numpyやPyTorchの配列にdtypeとshapeをアノテーションするjaxtypingのススメ

numpyやPyTorchの配列にdtypeとshapeをアノテーションするjaxtypingのススメ
サンプルコードはこちら
https://colab.research.google.com/drive/1wdJXEPO3PpKLF09xJq3H5AFee6H1vy5_?usp=sharing

Avatar for Kohei Iwamasa

Kohei Iwamasa

November 20, 2025
Tweet

More Decks by Kohei Iwamasa

Transcript

  1. Pythonの型付け (Type Hints) • Python は動的型付け⾔語であり、変数の型が保証されない • Python 3.5 から型ヒントサポートとして

    typing が追加 (PEP 484) • 以下のようにアノテーションすることが可能: • ⼀⽅で string型 以外の変数を渡してもエラーや警告は出⼒されない😢 • 静的型チェッカーや実⾏時型検証のデータ構造を使おう! def fn(a: str) -> str: return a + "1" 引数が string 型 返り値も string 型であることがわかる!
  2. Pythonの静的型チェッカー: mypy • 静的型チェッカーとして1つとして mypy が採⽤される • uvであれば uv add

    mypy --dev または uvx mypy でインストールおすすめ • mypy <file名 | ディレクトリ> で実⾏ def fn(a: str) -> str: return a + "1" fn(1) error: Argument 1 to "fn" has incompatible type "int"; expected "str" [arg-type] Found 1 error in 1 file (checked 1 source file) これをmypyでチェックすると怒ってくれる👏 Success: no issues found in 1 source file 正しいと以下のように出⼒:
  3. 型付けを⾏うデータ構造 TypedDict dataclass from typing import TypedDict class PoseDict(TypedDict): translation:

    np.ndarray rotation: np.ndarray from dataclasses import dataclass @dataclass class Pose: translation: np.ndarray rotation: np.ndarray • 値の型を決めた辞書型ヒント • 実⾏時は誤った型でも、存在しない keyのデータでも代⼊可能 • 静的型チェックで定義と合わない 代⼊やアクセスを検知する • 属性の型を決めたクラスインスタンス • 実⾏時に誤った型でも代⼊可能 (存在しない属性は初期化時にエラー) • 静的型チェックで型が正しいか検知する
  4. 実⾏時型検証を⾏うデータクラス: pydantic どれを使うべきか? → ⽤途に合わせて使う • dict型 のデータ構造に型を付けたい → TypedDict

    • 軽量なドメインモデル‧値オブジェクトとして → dataclasses • ⼊出⼒を保証して安全なオブジェクトに → Pydantic みたいな (ざっくりですが...) Pydantic from pydantic import BaseModel class Pose(BaseModel): translation: list[float] rotation: list[float] • 属性の型と制約を決めた実⾏時型検証の クラスインスタンス • 実⾏時に誤った型や不正な値を検知する • 静的型チェックでも通常のクラス同様に検知する
  5. こんなお悩みありませんか?(1) 3次元空間での Pose (姿勢) というデータクラスがあったとき @dataclass class Pose: translation: np.ndarray

    rotation: np.ndarray rotation (回転) は何を表している • オイラー⾓なら3次元のベクトル? • クオータニオンなら4次元のベクトル? • 回転⾏列なら 3 x 3 の⾏列? translation は何を表している • x, y, zの3次元のベクトル? そもそも時間のような軸は⼊っていない?
  6. こんなお悩みありませんか?(2) def forward(self, image): x = self.conv1(image) x = self.conv2(x)

    x = x.view(x.size(0), -1) x = self.fc(x) return x 以下のようなニューラルネットワークの forward メソッドがあったとき image のshapeは何?dtypeはfloat型? バッチサイズの軸が最初? .view 以降の x のshapeは何? 返り値の shape は何?
  7. jaxtypingなら解決できます!(1) from jaxtyping import Float32 @dataclass class Pose: """Attributes: translation

    (np.ndarray): ENU coordinates. [x, y, z]. rotation (np.ndarray): quaternion. [w, x, y, z]. """ translation: Float32[np.ndarray, "3"] rotation: Float32[np.ndarray, "4"]
  8. jaxtypingなら解決できます!(2) def forward( self, image: Float[torch.Tensor, "batch 3 height width"]

    ) -> Float[torch.Tensor, "batch num_classes"]: x = self.conv1(image) x = self.conv2(x) x = x.view(x.size(0), -1) x = self.fc(x) return x from jaxtyping import Float
  9. jaxtypingの書き⽅ dtype[array, shape] 形式で書く dtype array shape ... jaxtypingのデータ型 (e.g.,

    Int, Float, Float32) ... 配列型 (e.g., jax.Array, np.ndarray, torch.tensor) ... 記号をスペースで区切った⽂字列
  10. dtypeについて • Float64 • Float32 • Float16 • BFloat8 •

    INT2 • INT4 • INT8 • INT16 • INT32 • INT64 • UINT2 • UINT4 • UINT8 • UINT16 • UINT32 • UINT64 • Complex64 • Complex128 Bool Float Complex INT UINT Inexact Integer Num 以下のような階層構造で、精度ごとのデータ型もある 他にも存在する https://docs.kidger.site/jaxtyping/api/array/#dtype
  11. arrayについて JAX 以外にも NumPy, PyTorch, TensorFlow, MLX に対応 • jax.Array

    • np.ndarray • torch.Tensor • tf.Tensor • mx.array
  12. shapeについて • スペース区切りした記号で各軸の情報を表す ◦ str: 可変な値の軸. 変数として, 同じ記号名は同じ次元数であることを表す Float[torch.Tensor, "batch

    channels height width"] ◦ int: 固定の値の軸. 定数として, その軸の次元数を表す Float[torch.Tensor, "1 3 244 244"] • スカラーの場合のshapeは "" とする • * をつけると連続した 0 個以上の軸列を表す 例えば Float[torch.Tensor, "*B 3 H W"] とすると以下もマッチする (3, H, W) (*B の軸が0つ) (N1, N2, 3, H, W) (*B の軸が2つ) • "..." の場合は変数名がない連続した0個以上の軸列を表す
  13. 実⾏時型検証 • jaxtyping の型アノテーションは基本的には実⾏時に検証しない 😢 • 静的型チェックも array に対して検証 (shape

    や dtype 対応していない 😢) 主に3つの⽅法で実⾏時型検証機能をを付与 • jaxtyping.jaxtyped を使⽤した単⼀の関数/データクラスに付与 • jaxtyping.install_import_hook を使⽤してコードベース全体に付与 • pytest --jaxtyping-packages=... を使⽤したpytest時での⾃動付与
  14. @jaxtyped で単⼀の関数/データクラスに付与 関数に @jaxtyped としてデコレータを設定することで実⾏時に検証 from beartype import beartype from

    jaxtyping import Float32, jaxtyped import numpy as np @jaxtyped(typechecker=beartype) def f(x: Float32[np.ndarray, "3 H W"]) -> Float32[np.ndarray, "3 H W"]: return x f(np.random.randn(3, 224, 224).astype(np.float32)) *typecheckerは beartype を使⽤
  15. @jaxtyped で単⼀の関数/データクラスに付与 shape が期待したものと異なるとエラー👍 (dtype も array も同様) TypeCheckError ...

    Expected type: <class 'Float32[ndarray, '3 H W']'>. ---------------------- Called with parameters: {'x': f32[4,224,224](numpy)} Parameter annotations: (x: Float32[ndarray, '3 H W']) -> Any. f(np.random.randn(4, 224, 224).astype(np.float32))
  16. import numpy as np import torch from jaxtyping import Num,

    Float16, Float32 def cast_fp32_to_fp16( x: Float32[np.ndarray, "..."], ) -> Float16[np.ndarray, "..."]: return x.astype(np.float16) def cast_numpy_to_torch( x: Num[np.ndarray, "..."], ) -> Num[torch.Tensor, "..."]: return torch.from_numpy(x) from jaxtyping import install_import_hook with install_import_hook( __name__, "beartype.beartype", ): from .convert import ( cast_fp32_to_fp16, cast_numpy_to_torch, ) install_import_hook でコードベース全体に付与 src ├── main.py └── convert.py のような構造 install_import_hook を適⽤することで @jaxtyped が⾃動的に付与 convert.py main.py
  17. from jaxtyping import install_import_hook import numpy as np with install_import_hook("src",

    "beartype.beartype"): from src.convert import cast_fp32_to_fp16 def test_cast_fp32_to_fp16(): x = np.random.randn(3, 224, 224).astype(np.float32) y = cast_fp32_to_fp16(x) install_import_hook でコードベース全体に付与 テストの実⾏時のみに型チェックするのもいい 👍 tests/test_convert.py
  18. pytest --jaxtyping-packages=...で⾃動付与 pytest --jaxtyping-packages=src,beartype.beartype tests でsrc.*はimport時に install_import_hook を適⽤したことと同等になる • ,で分割した⽂字列に適⽤したいモジュール名の

    prefix を指定する • 最後には使いたい typecheker を表す⽂字列にする: beartype.beartype import numpy as np from src.convert import cast_fp32_to_fp16 def test_cast_fp32_to_fp16(): x = np.random.randn(3, 224, 224).astype(np.float32) y = cast_fp32_to_fp16(x) tests/test_convert.py install_import_hookが不要に🙌
  19. pyproject.toml に --jaxtyping-packages を記載することも可能 • 以下の設定で pytest tests だけでも同様の引数を渡したことになる pytest

    --jaxtyping-packages=...で⾃動付与 [tool.pytest.ini_options] pythonpath = "./" testpaths = ["tests"] addopts = "--jaxtyping-packages=src,beartype.beartype" pyproject.toml
  20. より⾼度な型アノテーションの書き⽅ • 複数の array をUnion型で指定することが可能 • shape で _ を設定することでその軸の実⾏時チェックを無効にする

    • shape の⽂字列変数は演算することができる • shape で # をつけるとブロードキャストが許容となる (任意のサイズか1か) Num[np.ndarray | torch.Tensor, "..."] Float32[np.ndarray, "_ H W"] Float32[np.ndarray, "B*3"] Float32[np.ndarray, "#B 3 H W"] 1 か 64 は 󰢏 32 か 64 は󰢃
  21. • jaxtyping はテンソルのdtype, array, shapeをアノテーションできる! • jaxtyping.jaxtyped を使⽤することで実⾏時型チェックができる! • 学習ループなどで実⾏時型検証は速度遅延の可能性がある

    ◦ テストやデプロイ時などに使うといいかも • 個⼈的には実⾏時型検証の仕組みを使わなくてもコメントとして jaxtyping を使うのもおすすめ👍 ◦ 他のメンバーに共有する時や未来の⾃分が⾒返した時にわかりやすい • どんどん流⾏ってほしい まとめ