Float16, Float32  # NOTE(review): tail of a `from jaxtyping import ...` that starts before this chunk; `Num`, `np` and `torch` must also be in scope — confirm.

# --- convert.py ------------------------------------------------------------
# Typed conversion helpers. The shape specs use the named variadic dimension
# "*shape" (instead of the anonymous "...") so that, when runtime checking is
# enabled (see the main.py snippet below), jaxtyping also verifies that each
# output has exactly the same shape as its input.


def cast_fp32_to_fp16(
    x: Float32[np.ndarray, "*shape"],
) -> Float16[np.ndarray, "*shape"]:
    """Return a float16 copy of a float32 array.

    Args:
        x: float32 array of any shape.

    Returns:
        A new float16 array with the same shape as ``x``. Values outside
        float16's finite range (about +/-65504) overflow to +/-inf, and
        precision drops to roughly three significant decimal digits.
    """
    return x.astype(np.float16)


def cast_numpy_to_torch(
    x: Num[np.ndarray, "*shape"],
) -> Num[torch.Tensor, "*shape"]:
    """Wrap a numeric NumPy array as a ``torch.Tensor``.

    Args:
        x: numeric (non-bool) array of any shape.

    Returns:
        A tensor with the same shape and a corresponding dtype. It shares
        memory with ``x`` (``torch.from_numpy`` does not copy), so in-place
        modifications to either one are visible through the other.
    """
    return torch.from_numpy(x)


# --- main.py ---------------------------------------------------------------
# Apply runtime shape/dtype checking to the whole codebase with
# install_import_hook. For a project laid out like
#
#     src
#     ├── main.py
#     └── convert.py
#
# importing inside install_import_hook makes jaxtyping decorate the imported
# functions with @jaxtyped automatically (using beartype as the checker):
#
#     from jaxtyping import install_import_hook
#
#     with install_import_hook(
#         __name__,
#         "beartype.beartype",
#     ):
#         from .convert import (
#             cast_fp32_to_fp16,
#             cast_numpy_to_torch,
#         )
#
# NOTE(review): install_import_hook only instruments modules whose name
# equals the given name or is a submodule of it. Inside src/main.py,
# __name__ is "src.main" (or "__main__" when run with -m), which presumably
# does not match "src.convert"; the snippet seems to work as written only
# from src/__init__.py, or with the package name "src" passed explicitly.
# Confirm against the jaxtyping install_import_hook documentation.