distributed and parallel execution since the initial JAX release. jax.pmap requires specifying the number of devices upfront for sharding, which began to conflict with the direction JAX was taking.
jax.shard_map, by contrast, always works on a global array, defines the split explicitly, and extends easily to other APIs.
A dedicated page covers migrating from jax.pmap (v0.8.0+).
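The following is a minimal sketch that contrasts the two APIs on a toy input. It is not code from the deck. With jax.pmap, the leading axis of the input has to equal the device count. With shard_map, the input stays one global array, and a Mesh plus PartitionSpec state how that array is split. The mesh axis name "x", the function name local_sum, and the fallback import for older JAX releases are assumptions made for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public API in recent JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # assumption: older releases

n_dev = jax.device_count()

# pmap: the caller reshapes the data so the leading axis equals the device count.
x_pmap = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)
pmap_out = jax.pmap(lambda v: v * 2.0)(x_pmap)

# shard_map: the input is a global array; the mesh and PartitionSpec
# explicitly state how it is split across devices ("x" is a hypothetical axis name).
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

def local_sum(block):
    # block is this device's shard; psum adds the partial sums across the "x" axis
    return jax.lax.psum(jnp.sum(block), axis_name="x")

global_sum = jax.jit(
    shard_map(local_sum, mesh=mesh, in_specs=P("x"), out_specs=P())
)

x_global = jnp.arange(n_dev * 4, dtype=jnp.float32)
total = global_sum(x_global)
```

The shard_map version never mentions n_dev inside the computation. The same function therefore runs unchanged on one CPU or on many accelerators, provided the global length divides evenly by the mesh size.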
shard_map example:
- Prep: same as for pmap
- Define the kernel
- SciPy: Cholesky decomposition
- SciPy: triangular solve
- SciPy: α = K⁻¹y
Decorate the function with jax.shard_map to compute the Gaussian process. Really readable!
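The steps above can be sketched as follows. This assumes that the deck's "SciPy" steps correspond to jax.scipy.linalg. An RBF kernel is used, with hypothetical hyperparameters LENGTH_SCALE and NOISE, and the helper names rbf_kernel, fit_alpha, and predict_mean are illustrative only.

The Cholesky factor of K = LLᵀ gives α = K⁻¹y through two triangular solves, Lz = y followed by Lᵀα = z. The posterior mean at the test points is then K_*α. Only the prediction step is wrapped in shard_map: the test points are split across devices, while the training data and α are replicated. This is one possible layout, not necessarily the one the deck uses.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, solve_triangular
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public API in recent JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # assumption: older releases

# Hypothetical hyperparameters chosen for illustration
LENGTH_SCALE = 1.0
NOISE = 1e-2

def rbf_kernel(a, b, length_scale=LENGTH_SCALE):
    # Squared-exponential kernel between rows of a (n, d) and b (m, d) -> (n, m)
    sq_dist = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dist / length_scale**2)

def fit_alpha(x_train, y_train):
    # K = L L^T via Cholesky, then alpha = K^{-1} y via two triangular solves
    K = rbf_kernel(x_train, x_train) + NOISE * jnp.eye(x_train.shape[0])
    L = cholesky(K, lower=True)
    z = solve_triangular(L, y_train, lower=True)      # solve L z = y
    return solve_triangular(L.T, z, lower=False)      # solve L^T alpha = z

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

def predict_mean(x_test_block, x_train, alpha):
    # Each device computes K_* alpha for its own block of test points
    return rbf_kernel(x_test_block, x_train) @ alpha

predict = jax.jit(shard_map(
    predict_mean, mesh=mesh,
    in_specs=(P("x"), P(), P()),  # split test points; replicate train data and alpha
    out_specs=P("x"),
))

n_dev = jax.device_count()
x_train = jnp.linspace(-3.0, 3.0, 20).reshape(-1, 1)
y_train = jnp.sin(x_train[:, 0])
alpha = fit_alpha(x_train, y_train)
x_test = jnp.linspace(-2.5, 2.5, 8 * n_dev).reshape(-1, 1)
mean = predict(x_test, x_train, alpha)
```

The number of test points here is a multiple of the device count, so every shard gets an equal block.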