Slide 1

Slide 1 text

Standardizing on a single N-dimensional array API for Python
Ralf Gommers, 15 June 2021

Slide 2

Slide 2 text

How often do you write code for novel array/tensor APIs, versus rewriting code for another library or for higher performance?

Slide 3

Slide 3 text

Array-based computing in Python

Slide 4

Slide 4 text

Today’s Python data ecosystem
Can we make it easy to build on top of multiple array data structures?

Slide 5

Slide 5 text

Example: scikit-image, CuPy & Dask

Slide 6

Slide 6 text

Example: einops package
Einops is a popular package for array manipulation (reshaping, concatenating, stacking, etc.) that supports all major array/tensor libraries. It has:
● 700 LoC for public APIs
● 550 LoC for backends
⇒ transpose() is still relatively well-behaved; it gets worse for other functions.

Slide 7

Slide 7 text

State of compatibility today
All libraries have common concepts and functionality, but there are many small (and some large) incompatibilities. It’s very painful to translate code from one array library to another. Let’s look at some examples!
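As one concrete example of such an incompatibility, concatenation is spelled differently across libraries: NumPy, JAX and Dask use `concatenate(..., axis=...)`, PyTorch uses `torch.cat(..., dim=...)`, and TensorFlow uses `tf.concat(..., axis=...)`. The sketch below records those spellings in a table and calls the right one. The table and the `concat` helper are hypothetical, and only the NumPy entry is exercised.

```python
import importlib

import numpy as np

# Concatenation in several libraries: (module to import, function name, axis keyword).
CONCAT_SPELLINGS = {
    "numpy": ("numpy", "concatenate", "axis"),
    "torch": ("torch", "cat", "dim"),
    "tensorflow": ("tensorflow", "concat", "axis"),
    "jax": ("jax.numpy", "concatenate", "axis"),
    "dask": ("dask.array", "concatenate", "axis"),
}


def concat(library, arrays, along):
    """Concatenate `arrays` along axis `along` using the given library's spelling."""
    module_name, func_name, axis_keyword = CONCAT_SPELLINGS[library]
    module = importlib.import_module(module_name)  # imported lazily, only when used
    return getattr(module, func_name)(arrays, **{axis_keyword: along})


out = concat("numpy", [np.ones((2, 2)), np.zeros((1, 2))], along=0)
print(out.shape)  # (3, 2)
```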

Slide 8

Slide 8 text

So compatibility is poor? Fix it: create a standard!

Slide 9

Slide 9 text

No content

Slide 10

Slide 10 text

Consortium for Python Data API Standards
A new organization, with participation from maintainers of many array (a.k.a. tensor) and dataframe libraries.
Concrete goals for the first year:
1. Define a standardization methodology and the necessary tooling for it
2. Publish an RFC for an array API standard
3. Publish an RFC for a dataframe API standard (expected within a month)
4. Finalize 2021.0x API standards after community review
See data-apis.org and github.com/data-apis for more on the Consortium.

Slide 11

Slide 11 text

Goals for and scope of the array API
Goal 1: enable writing code & packages that support multiple array libraries
Goal 2: make it easy for end users to switch between array libraries
In scope:
● Syntax and semantics of functions and objects in the API
● Casting rules, broadcasting, indexing, Python operator support
● Data interchange & device support
Out of scope:
● Execution semantics (e.g. task scheduling, parallelism, lazy eval)
● Non-standard dtypes, masked arrays, I/O, subclassing array object, C API
● Error handling & behaviour for invalid inputs to functions and methods

Slide 12

Slide 12 text

Array- and array-consuming libraries
Data interchange between array libs
Uses DLPack; works for any two libraries if both support the device the data resides on:
    x = xp.from_dlpack(x_other)
Portable code in array-consuming libs
    def softmax(x):
        # grab standard namespace from
        # the passed-in array
        xp = get_array_api(x)
        x_exp = xp.exp(x)
        partition = xp.sum(x_exp, axis=1, keepdims=True)
        return x_exp / partition
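A runnable version of the softmax example, as a sketch. The slide's `get_array_api` is not a library function; here it is a hypothetical helper that calls the array's `__array_namespace__()` method when present (NumPy 2.0 and later provide it on `ndarray`) and otherwise falls back to NumPy. Unlike the slide, the sketch subtracts the row maximum before exponentiating, a standard trick to avoid overflow that leaves the result unchanged.

```python
import numpy as np


def get_array_api(x):
    # Hypothetical helper: use the array's own namespace if it exposes
    # __array_namespace__ (NumPy >= 2.0 does), otherwise fall back to NumPy.
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np


def softmax(x):
    # grab standard namespace from the passed-in array
    xp = get_array_api(x)
    # Subtracting the row maximum avoids overflow in exp without changing the result.
    x_exp = xp.exp(x - xp.max(x, axis=1, keepdims=True))
    partition = xp.sum(x_exp, axis=1, keepdims=True)
    return x_exp / partition


probs = softmax(np.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]))
print(probs.sum(axis=1))  # each row sums to 1
```

Because `softmax` only touches `xp`, the same function would work unchanged on any array whose `__array_namespace__()` returns a standard-compliant namespace.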

Slide 13

Slide 13 text

What does the full API surface look like?
● 1 array object with
  ○ 6 attributes: ndim, shape, size, dtype, device, T
  ○ dunder methods to support all Python operators
  ○ __array_api_version__, __array_namespace__, __dlpack__
● 11 dtype literals: bool, (u)int8/16/32/64, float32/64
● 1 device object
● 4 constants: inf, nan, pi, e
● ~125 functions:
  ○ Array creation & manipulation (20)
  ○ Element-wise math & logic (6)
  ○ Statistics (7)
  ○ Linear algebra (22)
  ○ Search, sort & set (7)
  ○ Utilities, dtypes, broadcasting (8)
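A hypothetical helper that checks an array object against part of the surface listed above (the six attributes plus `__array_namespace__` and `__dlpack__`). It illustrates the shape of the API and is not a conformance test. What it reports for a NumPy array depends on the NumPy version: `device` and `__array_namespace__` only appeared on `ndarray` in NumPy 2.0, and `__dlpack__` in NumPy 1.22.

```python
import numpy as np

# Array-object attributes from the list above.
ARRAY_ATTRIBUTES = ("ndim", "shape", "size", "dtype", "device", "T")
# A subset of the listed dunder methods.
ARRAY_DUNDERS = ("__array_namespace__", "__dlpack__")


def missing_array_api_members(x):
    """Return the names from the two tuples above that `x` does not provide."""
    return [name for name in ARRAY_ATTRIBUTES + ARRAY_DUNDERS if not hasattr(x, name)]


print(missing_array_api_members(np.ones(3)))  # empty on recent NumPy versions
```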

Slide 14

Slide 14 text

Latest: github.com/data-apis/array-api/

Slide 15

Slide 15 text

Mutability & copies/views
    x = xp.ones(4)
    # y may be a view on data of x
    y = x[:2]
    # modifies x if y is a view
    y += 1
Mutable operations and the concept of views are important for strided in-memory array implementations (NumPy, CuPy, PyTorch, MXNet). They are problematic for libraries based on immutable data structures or delayed evaluation (TensorFlow, JAX, Dask).
Decisions in the API standard:
1. Support in-place operators
2. Support item and slice assignment
3. Do not support the out= keyword
4. Warn users that mixing mutating operations and views may result in implementation-specific behavior
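In NumPy, the view case above is what actually happens: basic slicing returns a view, so an in-place update through `y` is visible in `x`, while an explicit `.copy()` breaks the link. The sketch uses NumPy only.

```python
import numpy as np

x = np.ones(4)
y = x[:2]    # basic slicing in NumPy returns a view on x's data
y += 1       # in-place update through the view
print(x)     # [2. 2. 1. 1.]

x2 = np.ones(4)
y2 = x2[:2].copy()  # an explicit copy owns its own data
y2 += 1
print(x2)    # [1. 1. 1. 1.]
```

For an immutable library such as JAX, `y += 1` instead rebinds `y` to a new array and leaves `x` unchanged, which is exactly the kind of divergence the last decision above warns about.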

Slide 16

Slide 16 text

Dtype casting rules
    x = xp.arange(5)  # will be integer
    y = xp.ones(5, dtype=xp.float32)
    # This may give float32, float64, or raise
    dtype = (x * y).dtype
Casting rules are straightforward to align between libraries when the dtypes are of the same kind. Mixed integer and floating-point casting is very inconsistent between libraries and hard to change, hence it will remain unspecified.
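To see the inconsistency concretely, here is NumPy's answer for the example above (int64 times float32 gives float64), alongside two same-kind cases where libraries broadly agree. PyTorch returns float32 for the equivalent tensors, which is one reason the mixed-kind case is left unspecified. The sketch pins the integer dtype explicitly so the result does not depend on the platform's default integer.

```python
import numpy as np

# Same-kind promotion: the result is the wider of the two types.
print(np.result_type(np.int8, np.int16))       # int16
print(np.result_type(np.float32, np.float64))  # float64

# Mixed kind: NumPy promotes int64 with float32 to float64.
x = np.arange(5, dtype=np.int64)
y = np.ones(5, dtype=np.float32)
print((x * y).dtype)                           # float64
```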

Slide 17

Slide 17 text

Data-dependent output shape/dtype
    # Boolean indexing, and even slicing
    # in some cases, results in shapes
    # that depend on values in `x`
    x2 = x[x > 3]
    val = somefunc(x)  # somefunc: any function returning an integer
    x3 = x[:val]
    # Functions for which output shape
    # depends on values
    unique(x)
    nonzero(x)
    # NumPy does value-based casting
    x = np.ones(3, dtype=np.float32)
    x + 1  # float32 output
    x + 100000  # float64 output
Data-dependent output shapes or dtypes are problematic because of:
● static memory allocation (TensorFlow, JAX)
● graph-based scheduling (Dask)
● JIT compilation (Numba, PyTorch, JAX, Gluon)
Value-based dtype results can be avoided. Value-based shapes can be important, so the API standard will include such functionality but clearly mark it.
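A NumPy sketch of the problem: inputs with identical shape and dtype give outputs of different shapes, so a tracing or JIT compiler cannot know the output shape from the input's metadata alone.

```python
import numpy as np

a = np.array([1, 5, 7, 2])
b = np.array([9, 5, 7, 8])

# Same shape and dtype in, different shapes out: only the values differ.
print(a[a > 3].shape)  # (2,)
print(b[b > 3].shape)  # (4,)

# unique() and nonzero() behave the same way.
print(np.unique(np.array([3, 3, 3, 1])).shape)      # (2,)
print(np.nonzero(np.array([0, 2, 0, 4]))[0].shape)  # (2,)
```

JAX works around this for some functions by letting the caller fix the output size up front, e.g. the `size=` argument of `jax.numpy.nonzero` and `jax.numpy.unique`.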

Slide 18

Slide 18 text

DLPack: device-aware, zero-copy protocol
Improved Python API:
    x_mylib = from_dlpack(x_otherlib)
Getting stream handling right was hard:
    def from_dlpack(x):
        device = x.__dlpack_device__()
        consumer_stream = _find_exchange_stream(device)
        dlpack_caps = x.__dlpack__(stream=consumer_stream)
        return _convert_to_consumer_array(dlpack_caps)
    def __dlpack__(self, /, *, stream=None):
        # stream: optional pointer to a stream, as a Python integer,
        # provided by the consumer that the producer will use to make
        # the array safe to operate on (e.g., via cudaStreamWaitEvent)
        return dlpack_capsule
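On the CPU no stream is involved, so the protocol can be tried with NumPy alone. This sketch assumes NumPy 1.22 or later, where `ndarray` implements `__dlpack__`/`__dlpack_device__` and `np.from_dlpack` exists. The imported array shares memory with the original rather than copying it.

```python
import numpy as np

x = np.arange(6, dtype=np.float64)

# Producer side: where the data lives, as (device type, device id).
# For a CPU array the device type is DLPack's kDLCPU, which is 1.
print(x.__dlpack_device__())

# Consumer side: import through the protocol without copying.
y = np.from_dlpack(x)
print(np.shares_memory(x, y))  # True
```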

Slide 19

Slide 19 text

Where are we today? (1/2)
The array API standard is >95% complete and published for community review. A mechanism for future extensions is also defined. Open discussion points include:
● unique is the only polymorphic function (its output type depends on keywords); should it be changed?
● Type promotion for reductions, and one-off promotion corner cases
● Resolving issues that come up during implementation in libraries
The NumPy Enhancement Proposal for adoption (NEP 47) has also been merged (as Draft), and the reference implementation is progressing nicely; it will be merged with experimental status in the next few weeks: https://github.com/numpy/numpy/pull/18585
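The polymorphism in question is easy to see with NumPy's `np.unique`, whose return type changes with the `return_*` keywords:

```python
import numpy as np

x = np.array([3, 1, 3, 2, 1])

# No extra keywords: a single array.
print(type(np.unique(x)).__name__)                      # ndarray
# With return_counts=True: a tuple of (values, counts).
print(type(np.unique(x, return_counts=True)).__name__)  # tuple

values, counts = np.unique(x, return_counts=True)
print(values, counts)  # [1 2 3] [2 1 2]
```

One way to remove the polymorphism, and the route later versions of the standard took, is to split the function into separately named variants (`unique_values`, `unique_counts`, `unique_inverse`, `unique_all`), each with a fixed return type.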

Slide 20

Slide 20 text

Where are we today? (2/2)
● PyTorch has decided to adopt the array API standard
● JAX and CuPy will wait until NumPy has an implementation, and then add compatibility in the same way. Dask hasn’t confirmed yet, but in general aims for a NumPy-compatible API too
● MXNet and ONNX have stated they will implement the standard
● TensorFlow will likely add support in tf.experimental (not confirmed yet)

Slide 21

Slide 21 text

What is next? — array API standard
1. Complete the library-independent test suite
2. Merge the reference implementation in NumPy
3. Prototype implementations in other array libraries & use downstream (SciPy, scikit-learn, scikit-image, domain-specific libraries)
4. Get sign-off from maintainers of each array library
⇒ array API v2021 final

Slide 22

Slide 22 text

What is next? — Data APIs roadmap

Slide 23

Slide 23 text

How can you help?
● Give feedback! Is your use case covered? See a small gap in functionality?
● Contribute! Portable test & benchmarking suites, remaining design issues
● Implement! The standard is complete enough to adopt today (draft mode)
● Spread awareness! Blog, reference in your talk, ...
● Support! Funding or engineering time: lots more to do, also for dataframes

Slide 24

Slide 24 text

To learn more
Consortium:
● Website & introductory blog posts: data-apis.org
● Array API main repo: github.com/data-apis/array-api
● Latest version of the standard: data-apis.github.io/array-api/latest
● Members: github.com/data-apis/governance
Find me at: rgommers, ralfgommers
Try this at home: install the latest version of all seven array libraries in one env to experiment:
    conda create -n many-libs python=3.7
    conda activate many-libs
    conda install cudatoolkit=10.2
    pip install numpy torch jax jaxlib tensorflow mxnet cupy-cuda102 dask toolz sparse
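After running the commands above, a quick way to see which libraries actually landed in the environment is to look them up with `importlib.util.find_spec`, which locates a module without importing it. This is a sketch with a hypothetical `installed_libraries` helper; finding a package does not guarantee it imports cleanly (CuPy, for example, also needs a working CUDA setup).

```python
import importlib.util

# Import names of the array libraries installed by the pip command above.
LIBRARIES = ("numpy", "torch", "jax", "tensorflow", "mxnet", "cupy", "dask", "sparse")


def installed_libraries(names=LIBRARIES):
    """Map each top-level module name to whether it can be found in this environment."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


for name, found in installed_libraries().items():
    print(f"{name:12s} {'yes' if found else 'no'}")
```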