Slide 18
Slide 18 text
import numpy as np
import torch
from jaxtyping import Num, Float16, Float32
def cast_fp32_to_fp16(
    x: Float32[np.ndarray, "..."],
) -> Float16[np.ndarray, "..."]:
    """Return a new float16 array holding the values of the float32 array ``x``.

    The ``"..."`` shape annotation accepts any shape, and the output has the
    same shape as the input. The conversion follows numpy's ``astype`` rules:
    precision is reduced, and large magnitudes may overflow in float16.
    """
    # copy=True is numpy's default for astype; it is spelled out here to make
    # explicit that the result never aliases the input buffer.
    return x.astype(dtype=np.float16, copy=True)
def cast_numpy_to_torch(
    x: Num[np.ndarray, "..."],
) -> Num[torch.Tensor, "..."]:
    """Expose the numeric numpy array ``x`` as a ``torch.Tensor``.

    Built with ``torch.from_numpy``, so no data is copied. The tensor shares
    its memory, dtype and shape with ``x``, and in-place edits to either one
    are visible through the other.
    """
    shared_view = torch.from_numpy(x)
    return shared_view
from jaxtyping import install_import_hook
# Import this package's modules under jaxtyping's import hook. Every function in
# a hooked module is decorated with @jaxtyped and type-checked at runtime by
# beartype.beartype, without adding decorators by hand.
#
# The hook selects modules by name, so it must receive the *package* name
# (e.g. "src"). This module's own __name__ ("src.main", or "__main__" when run
# with `python -m src.main`) does not match the sibling module "src.convert",
# which would then be imported without any runtime checks. __package__ holds
# "src" in both cases.
with install_import_hook(
    __package__,
    "beartype.beartype",
):
    from .convert import (
        cast_fp32_to_fp16,
        cast_numpy_to_torch,
    )
install_import_hook でコードベース全体に @jaxtyped を付与
src
├── main.py
└── convert.py
上記のような構造を想定
install_import_hook を適用することで @jaxtyped が自動的に付与される
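convert.py：cast_fp32_to_fp16 / cast_numpy_to_torch を定義（上のコード前半）
main.py：install_import_hook の with ブロック内で convert をインポート（上のコード後半）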