Standardizing on a single N-dimensional array API for Python

Numerical computing and deep learning libraries for Python all offer array (or tensor) data structures and associated compute functionality with similar APIs. However, there are many subtle differences, which make it hard for users to migrate from one library to another, or for library authors to write code that supports multiple array libraries. To address these issues, the Consortium for Python Data API Standards (https://data-apis.org/) recently released a first version of its array API standard for community review.

In this talk, we will start with an overview of array API standard goals, benefits and API surface, and then focus on some of the key technical issues, such as reconciling in-place operations with immutable/mutable array data models, dtype casting rules, and zero-copy exchange protocols. Finally we will look at initial implementations in NumPy and PyTorch, and plans for use in downstream libraries like SciPy and scikit-learn.

Toulouse Data Science

June 17, 2021

Transcript

  1. How often do you write code for novel array/tensor APIs?

    vs. rewriting for another library or for higher performance?
  2. Today’s Python data ecosystem

    Can we make it easy to build on top of multiple array data structures?
  3. Example: einops package

    Einops is a popular package for array manipulation (reshaping, concatenating, stacking, etc.). It supports all major array/tensor libraries. It has:
    • 700 LoC for public APIs
    • 550 LoC for backends
    ⇒ transpose() is still relatively well-behaved; it gets worse for other functions
  4. State of compatibility today

    All libraries have common concepts and functionality. But there are many small (and some large) incompatibilities. It’s very painful to translate code from one array library to another. Let’s look at some examples!
  5. Consortium for Python Data API Standards

    A new organization, with participation from maintainers of many array (a.k.a. tensor) and dataframe libraries. Concrete goals for first year:
    1. Define a standardization methodology and necessary tooling for it
    2. Publish an RFC for an array API standard
    3. Publish an RFC for a dataframe API standard
    4. Finalize 2021.0x API standards after community review
    See data-apis.org and github.com/data-apis for more on the Consortium
    (expected within a month)
  6. Goals for and scope of the array API

    Goal 1: enable writing code & packages that support multiple array libraries
    Goal 2: make it easy for end users to switch between array libraries

    In Scope:
    • Syntax and semantics of functions and objects in the API
    • Casting rules, broadcasting, indexing, Python operator support
    • Data interchange & device support

    Out of Scope:
    • Execution semantics (e.g. task scheduling, parallelism, lazy eval)
    • Non-standard dtypes, masked arrays, I/O, subclassing array object, C API
    • Error handling & behaviour for invalid inputs to functions and methods
  7. Array- and array-consuming libraries

    Data interchange between array libs. Using DLPack, this will work for any two libraries if they support the device the data resides on:

        x = xp.from_dlpack(x_other)

    Portable code in array-consuming libs:

        def softmax(x):
            # grab standard namespace from
            # the passed-in array
            xp = get_array_api(x)
            x_exp = xp.exp(x)
            partition = xp.sum(x_exp, axis=1, keepdims=True)
            return x_exp / partition
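
A minimal runnable sketch of the portable softmax pattern, using NumPy as the example array library. The helper `get_array_api` is hypothetical (it is only named on the slide, not defined); here it is assumed to look up the `__array_namespace__` method listed among the array object's attributes, and the fallback to the `numpy` module is an assumption made so the sketch also runs on NumPy versions that do not expose that method on `ndarray`.

```python
import numpy as np


def get_array_api(x):
    """Hypothetical helper: return the array namespace that `x` belongs to."""
    if hasattr(x, "__array_namespace__"):
        # Arrays that implement the standard advertise their namespace.
        return x.__array_namespace__()
    # Assumption for this sketch only: fall back to the NumPy module.
    return np


def softmax(x):
    # Only functions from the retrieved namespace are used, so the same
    # code works for any array type that provides exp() and sum().
    xp = get_array_api(x)
    x_exp = xp.exp(x)
    partition = xp.sum(x_exp, axis=1, keepdims=True)
    return x_exp / partition


probs = softmax(np.asarray([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]))
```

Each row of `probs` sums to 1; passing an array from another library that implements the same namespace protocol would dispatch to that library's `exp` and `sum` instead.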
  8. What does the full API surface look like?

    • 1 array object with
      ◦ 6 attributes: ndim, shape, size, dtype, device, T
      ◦ dunder methods to support all Python operators
      ◦ __array_api_version__, __array_namespace__, __dlpack__
    • 11 dtype literals: bool, (u)int8/16/32/64, float32/64
    • 1 device object
    • 4 constants: inf, nan, pi, e
    • ~125 functions:
      ◦ Array creation & manipulation (20)
      ◦ Element-wise math & logic (6)
      ◦ Statistics (7)
      ◦ Linear algebra (22)
      ◦ Search, sort & set (7)
      ◦ Utilities, dtypes, broadcasting (8)
  9. Mutability & copies/views

        x = ones(4)
        # y may be a view on data of x
        y = x[:2]
        # modifies x if y is a view
        y += 1

    Mutable operations and the concept of views are important for strided in-memory array implementations (NumPy, CuPy, PyTorch, MXNet). They are problematic for libraries based on immutable data structures or delayed evaluation (TensorFlow, JAX, Dask).

    Decisions in API standard:
    1. Support in-place operators
    2. Support item and slice assignment
    3. Do not support out= keyword
    4. Warn users that mixing mutating operations and views may result in implementation-specific behavior
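
The view behavior described above can be observed directly in NumPy, one of the strided in-memory libraries named on the slide. This is a small illustration, not part of the talk: the first half mutates through a view, the second half shows an out-of-place alternative that does not depend on whether slicing returns a view.

```python
import numpy as np

# In NumPy, basic slicing returns a view, so the in-place add
# also changes the first two elements of x.
x = np.ones(4)
y = x[:2]
y += 1
shares = np.shares_memory(x, y)

# Out-of-place arithmetic creates a new array and leaves the
# source untouched, regardless of view semantics.
x2 = np.ones(4)
y2 = x2[:2] + 1
```

Code that must behave identically across libraries can prefer the second form, at the cost of an extra allocation.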
  10. Dtype casting rules

        x = xp.arange(5)  # will be integer
        y = xp.ones(5, dtype=xp.float32)
        # This may give float32, float64, or raise
        dtype = (x * y).dtype

    Casting rules are straightforward to align between libraries when the dtypes are of the same kind. Mixed integer and floating-point casting is very inconsistent between libraries, and hard to change. Hence this will remain unspecified.
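
Since mixed integer and floating-point promotion is left unspecified, one way to keep the result dtype predictable is to cast explicitly before combining arrays. The sketch below uses NumPy's `ndarray.astype` method as the example; other libraries spell the cast differently, so treat the exact call as an assumption about the library in use.

```python
import numpy as np

x = np.arange(5)                    # integer dtype
y = np.ones(5, dtype=np.float32)

# Cast the integer array to y's dtype first, so both operands have
# the same kind and the result dtype does not depend on
# library-specific mixed-kind promotion.
product = x.astype(y.dtype) * y
```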
  11. Data-dependent output shape/dtype

        # Boolean indexing, and even slicing
        # in some cases, results in shapes
        # that depend on values in `x`
        x2 = x[:, x > 3]
        val = somefunc(x)
        x3 = x[:val]

        # Functions for which output shape
        # depends on value
        unique(x)
        nonzero(x)

        # NumPy does value-based casting
        x = np.ones(3, dtype=np.float32)
        x + 1       # float32 output
        x + 100000  # float64 output

    Data-dependent output shapes or dtypes are problematic, because of:
    • static memory allocation (TensorFlow, JAX)
    • graph-based scheduling (Dask)
    • JIT compilation (Numba, PyTorch, JAX, Gluon)

    Value-based dtype results can be avoided. Value-based shapes can be important - the API standard will include but clearly mark such functionality.
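
To make the shape issue concrete, the following NumPy sketch contrasts boolean indexing, whose output length depends on the data, with a masked formulation that always returns the input shape. The `where`-based rewrite is an illustrative alternative, not something the talk prescribes.

```python
import numpy as np

x = np.asarray([1, 5, 2, 7])

# Boolean indexing: the output length equals the number of True
# entries, so it is only known once the values are known.
selected = x[x > 3]

# Masking with where(): the output always has x's shape, which suits
# static allocation, graph scheduling, and JIT compilation.
masked = np.where(x > 3, x, 0)
```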
  12. DLPack - device-aware zero copy protocol

    Improved Python API:

        x_mylib = from_dlpack(x_otherlib)

    Getting stream handling right was hard:

        def from_dlpack(x):
            device = x.__dlpack_device__()
            consumer_stream = _find_exchange_stream(device)
            # stream is keyword-only in __dlpack__
            dlpack_caps = x.__dlpack__(stream=consumer_stream)
            return _convert_to_consumer_array(dlpack_caps)

        def __dlpack__(self, /, *, stream=None):
            # stream: optional pointer to a stream, as a Python integer,
            # provided by the consumer that the producer will use to make
            # the array safe to operate on (e.g., via cudaStreamWaitEvent)
            return dlpack_capsule
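
A CPU-only illustration of the zero-copy exchange, assuming a NumPy version that provides `np.from_dlpack` (added in NumPy 1.22). Producer and consumer are both NumPy arrays here purely to keep the sketch self-contained; in practice the two sides would be different libraries, and no stream is involved on CPU.

```python
import numpy as np

producer = np.arange(4, dtype=np.float64)

# from_dlpack() calls producer.__dlpack__() under the hood and wraps
# the exported buffer without copying it.
consumer = np.from_dlpack(producer)

# A write through the producer is visible in the consumer,
# because both arrays refer to the same memory.
producer[0] = 99.0
shares = np.shares_memory(producer, consumer)
```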
  13. Where are we today? (1/2)

    The array API standard is >95% complete and published for community review. A mechanism for future extensions is also defined. Open discussion points include:
    • unique is the only polymorphic function (output type depends on keywords) - should it be changed?
    • Type promotion for reductions, and one-off promotion corner cases
    • Resolving issues that come up during implementation in libraries
    The NumPy Enhancement Proposal (NEP 47) for adoption is also merged (Draft), and the reference implementation is progressing nicely; it will be merged with experimental status in the next few weeks: https://github.com/numpy/numpy/pull/18585
  14. Where are we today? (2/2)

    • PyTorch has decided that the array API standard will be adopted.
    • JAX and CuPy will wait till NumPy has an implementation, and then add compatibility in the same way. Dask hasn’t confirmed yet, but in general aims for a NumPy-compatible API too.
    • MXNet and ONNX have stated they will implement the standard.
    • TensorFlow will likely add support in tf.experimental (not confirmed yet).
  15. What is next? (array API standard)

    1. Complete the library-independent test suite
    2. Merge reference implementation in NumPy
    3. Prototype implementations in other array libraries & use downstream (SciPy, scikit-learn, scikit-image, domain-specific libraries)
    4. Get sign-off from maintainers of each array library
    ⇒ array API v2021 final
  16. How can you help?

    • Give feedback! Is your use case covered? See a small gap in functionality?
    • Contribute! Portable test & benchmarking suites, remaining design issues
    • Implement! The standard is complete enough to adopt today (draft mode)
    • Spread awareness! Blog, reference in your talk, ...
    • Support! Funding or engineering time; lots more to do, also for dataframes
  17. To learn more

    Consortium:
    • Website & introductory blog posts: data-apis.org
    • Array API main repo: github.com/data-apis/array-api
    • Latest version of the standard: data-apis.github.io/array-api/latest
    • Members: github.com/data-apis/governance

    Find me at: [email protected], rgommers, ralfgommers

    Try this at home - installing the latest version of all seven array libraries in one env to experiment:

        conda create -n many-libs python=3.7
        conda activate many-libs
        conda install cudatoolkit=10.2
        pip install numpy torch jax jaxlib tensorflow mxnet cupy-cuda102 dask toolz sparse