Slide 1

Slide 1 text

numpyやPyTorchの配列に dtypeとshapeをアノテーションする jaxtypingのススメ 堅牢.py #1 @colum2131

Slide 2

Slide 2 text

想定読者 ● Pythonを使⽤した開発を⾏なったことがある ● 特に機械学習‧深層学習で安全で理解しやすいコードが書きたい ● テンソルのdtypeとshapeが何を表しているのか悩んだことがある 以下は知らなくても⼤丈夫です! ● Pythonの型ヒントを書いたことがない ● 静的型チェッカーを使ったことがない  \⼿元で試せます!∕ https://colab.research.google.com/drive/1wdJXEPO3PpKLF09xJq3H5AFee6H1vy5_?usp=sharing

Slide 3

Slide 3 text

Pythonの型付け (Type Hints) ● Python は動的型付け⾔語であり、変数の型が保証されない ● Python 3.5 から型ヒントサポートとして typing が追加 (PEP 484) ● 以下のようにアノテーションすることが可能: ● ⼀⽅で string型 以外の変数を渡してもエラーや警告は出⼒されない😢 ● 静的型チェッカーや実⾏時型検証のデータ構造を使おう! def fn(a: str) -> str: return a + "1" 引数が string 型 返り値も string 型であることがわかる!

Slide 4

Slide 4 text

Pythonの静的型チェッカー: mypy ● 静的型チェッカーとして1つとして mypy が採⽤される ● uvであれば uv add mypy --dev または uvx mypy でインストールおすすめ ● mypy <file名 | ディレクトリ> で実⾏ def fn(a: str) -> str: return a + "1" fn(1) error: Argument 1 to "fn" has incompatible type "int"; expected "str" [arg-type] Found 1 error in 1 file (checked 1 source file) これをmypyでチェックすると怒ってくれる👏 Success: no issues found in 1 source file 正しいと以下のように出⼒:

Slide 5

Slide 5 text

型付けを⾏うデータ構造 TypedDict dataclass from typing import TypedDict class PoseDict(TypedDict): translation: np.ndarray rotation: np.ndarray from dataclasses import dataclass @dataclass class Pose: translation: np.ndarray rotation: np.ndarray ● 値の型を決めた辞書型ヒント ● 実⾏時は誤った型でも、存在しない keyのデータでも代⼊可能 ● 静的型チェックで定義と合わない 代⼊やアクセスを検知する ● 属性の型を決めたクラスインスタンス ● 実⾏時に誤った型でも代⼊可能 (存在しない属性は初期化時にエラー) ● 静的型チェックで型が正しいか検知する

Slide 6

Slide 6 text

実⾏時型検証を⾏うデータクラス: pydantic どれを使うべきか? → ⽤途に合わせて使う ● dict型 のデータ構造に型を付けたい → TypedDict ● 軽量なドメインモデル‧値オブジェクトとして → dataclasses ● ⼊出⼒を保証して安全なオブジェクトに → Pydantic みたいな (ざっくりですが...) Pydantic from pydantic import BaseModel class Pose(BaseModel): translation: list[float] rotation: list[float] ● 属性の型と制約を決めた実⾏時型検証の クラスインスタンス ● 実⾏時に誤った型や不正な値を検知する ● 静的型チェックでも通常のクラス同様に検知する

Slide 7

Slide 7 text

こんなお悩みありませんか?(1) 3次元空間での Pose (姿勢) というデータクラスがあったとき @dataclass class Pose: translation: np.ndarray rotation: np.ndarray rotation (回転) は何を表している ● オイラー⾓なら3次元のベクトル? ● クオータニオンなら4次元のベクトル? ● 回転⾏列なら 3 x 3 の⾏列? translation は何を表している ● x, y, zの3次元のベクトル? そもそも時間のような軸は⼊っていない?

Slide 8

Slide 8 text

こんなお悩みありませんか?(2) def forward(self, image): x = self.conv1(image) x = self.conv2(x) x = x.view(x.size(0), -1) x = self.fc(x) return x 以下のようなニューラルネットワークの forward メソッドがあったとき image のshapeは何?dtypeはfloat型? バッチサイズの軸が最初? .view 以降の x のshapeは何? 返り値の shape は何?

Slide 9

Slide 9 text

jaxtypingなら解決できます!(1) from jaxtyping import Float32 @dataclass class Pose: """Attributes: translation (np.ndarray): ENU coordinates. [x, y, z]. rotation (np.ndarray): quaternion. [w, x, y, z]. """ translation: Float32[np.ndarray, "3"] rotation: Float32[np.ndarray, "4"]

Slide 10

Slide 10 text

jaxtypingなら解決できます!(2) def forward( self, image: Float[torch.Tensor, "batch 3 height width"] ) -> Float[torch.Tensor, "batch num_classes"]: x = self.conv1(image) x = self.conv2(x) x = x.view(x.size(0), -1) x = self.fc(x) return x from jaxtyping import Float

Slide 11

Slide 11 text

jaxtypingの書き⽅ dtype[array, shape] 形式で書く dtype array shape ... jaxtypingのデータ型 (e.g., Int, Float, Float32) ... 配列型 (e.g., jax.Array, np.ndarray, torch.tensor) ... 記号をスペースで区切った⽂字列

Slide 12

Slide 12 text

dtypeについて ● Float64 ● Float32 ● Float16 ● BFloat8 ● INT2 ● INT4 ● INT8 ● INT16 ● INT32 ● INT64 ● UINT2 ● UINT4 ● UINT8 ● UINT16 ● UINT32 ● UINT64 ● Complex64 ● Complex128 Bool Float Complex INT UINT Inexact Integer Num 以下のような階層構造で、精度ごとのデータ型もある 他にも存在する https://docs.kidger.site/jaxtyping/api/array/#dtype

Slide 13

Slide 13 text

arrayについて JAX 以外にも NumPy, PyTorch, TensorFlow, MLX に対応 ● jax.Array ● np.ndarray ● torch.Tensor ● tf.Tensor ● mx.array

Slide 14

Slide 14 text

shapeについて ● スペース区切りした記号で各軸の情報を表す ○ str: 可変な値の軸. 変数として, 同じ記号名は同じ次元数であることを表す Float[torch.Tensor, "batch channels height width"] ○ int: 固定の値の軸. 定数として, その軸の次元数を表す Float[torch.Tensor, "1 3 244 244"] ● スカラーの場合のshapeは "" とする ● * をつけると連続した 0 個以上の軸列を表す 例えば Float[torch.Tensor, "*B 3 H W"] とすると以下もマッチする (3, H, W) (*B の軸が0つ) (N1, N2, 3, H, W) (*B の軸が2つ) ● "..." の場合は変数名がない連続した0個以上の軸列を表す

Slide 15

Slide 15 text

実⾏時型検証 ● jaxtyping の型アノテーションは基本的には実⾏時に検証しない 😢 ● 静的型チェックも array に対して検証 (shape や dtype 対応していない 😢) 主に3つの⽅法で実⾏時型検証機能をを付与 ● jaxtyping.jaxtyped を使⽤した単⼀の関数/データクラスに付与 ● jaxtyping.install_import_hook を使⽤してコードベース全体に付与 ● pytest --jaxtyping-packages=... を使⽤したpytest時での⾃動付与

Slide 16

Slide 16 text

@jaxtyped で単⼀の関数/データクラスに付与 関数に @jaxtyped としてデコレータを設定することで実⾏時に検証 from beartype import beartype from jaxtyping import Float32, jaxtyped import numpy as np @jaxtyped(typechecker=beartype) def f(x: Float32[np.ndarray, "3 H W"]) -> Float32[np.ndarray, "3 H W"]: return x f(np.random.randn(3, 224, 224).astype(np.float32)) *typecheckerは beartype を使⽤

Slide 17

Slide 17 text

@jaxtyped で単⼀の関数/データクラスに付与 shape が期待したものと異なるとエラー👍 (dtype も array も同様) TypeCheckError ... Expected type: . ---------------------- Called with parameters: {'x': f32[4,224,224](numpy)} Parameter annotations: (x: Float32[ndarray, '3 H W']) -> Any. f(np.random.randn(4, 224, 224).astype(np.float32))

Slide 18

Slide 18 text

import numpy as np import torch from jaxtyping import Num, Float16, Float32 def cast_fp32_to_fp16( x: Float32[np.ndarray, "..."], ) -> Float16[np.ndarray, "..."]: return x.astype(np.float16) def cast_numpy_to_torch( x: Num[np.ndarray, "..."], ) -> Num[torch.Tensor, "..."]: return torch.from_numpy(x) from jaxtyping import install_import_hook with install_import_hook( __name__, "beartype.beartype", ): from .convert import ( cast_fp32_to_fp16, cast_numpy_to_torch, ) install_import_hook でコードベース全体に付与 src ├── main.py └── convert.py のような構造 install_import_hook を適⽤することで @jaxtyped が⾃動的に付与 convert.py main.py

Slide 19

Slide 19 text

from jaxtyping import install_import_hook import numpy as np with install_import_hook("src", "beartype.beartype"): from src.convert import cast_fp32_to_fp16 def test_cast_fp32_to_fp16(): x = np.random.randn(3, 224, 224).astype(np.float32) y = cast_fp32_to_fp16(x) install_import_hook でコードベース全体に付与 テストの実⾏時のみに型チェックするのもいい 👍 tests/test_convert.py

Slide 20

Slide 20 text

pytest --jaxtyping-packages=...で⾃動付与 pytest --jaxtyping-packages=src,beartype.beartype tests でsrc.*はimport時に install_import_hook を適⽤したことと同等になる ● ,で分割した⽂字列に適⽤したいモジュール名の prefix を指定する ● 最後には使いたい typecheker を表す⽂字列にする: beartype.beartype import numpy as np from src.convert import cast_fp32_to_fp16 def test_cast_fp32_to_fp16(): x = np.random.randn(3, 224, 224).astype(np.float32) y = cast_fp32_to_fp16(x) tests/test_convert.py install_import_hookが不要に🙌

Slide 21

Slide 21 text

pyproject.toml に --jaxtyping-packages を記載することも可能 ● 以下の設定で pytest tests だけでも同様の引数を渡したことになる pytest --jaxtyping-packages=...で⾃動付与 [tool.pytest.ini_options] pythonpath = "./" testpaths = ["tests"] addopts = "--jaxtyping-packages=src,beartype.beartype" pyproject.toml

Slide 22

Slide 22 text

より⾼度な型アノテーションの書き⽅ ● 複数の array をUnion型で指定することが可能 ● shape で _ を設定することでその軸の実⾏時チェックを無効にする ● shape の⽂字列変数は演算することができる ● shape で # をつけるとブロードキャストが許容となる (任意のサイズか1か) Num[np.ndarray | torch.Tensor, "..."] Float32[np.ndarray, "_ H W"] Float32[np.ndarray, "B*3"] Float32[np.ndarray, "#B 3 H W"] 1 か 64 は 󰢏 32 か 64 は󰢃

Slide 23

Slide 23 text

● jaxtyping はテンソルのdtype, array, shapeをアノテーションできる! ● jaxtyping.jaxtyped を使⽤することで実⾏時型チェックができる! ● 学習ループなどで実⾏時型検証は速度遅延の可能性がある ○ テストやデプロイ時などに使うといいかも ● 個⼈的には実⾏時型検証の仕組みを使わなくてもコメントとして jaxtyping を使うのもおすすめ👍 ○ 他のメンバーに共有する時や未来の⾃分が⾒返した時にわかりやすい ● どんどん流⾏ってほしい まとめ