
JIT compilation for CPython

Delimitry
December 17, 2019


The presentation from SPbPython meetup about simple self-made just-in-time compiler of Python code.

Transcript

  1. Outline
     - JIT compilation and JIT history
     - My experience with JIT in CPython
     - Python projects that use JIT and projects for JIT
  2. JIT
     Just-in-time compilation (aka dynamic translation, run-time compilation)
     - The earliest JIT compiler was for LISP, by John McCarthy, in 1960
     - Ken Thompson used JIT compilation in 1968 for regexes in the text editor QED
     - LC2
     - Smalltalk
     - Self
     - Popularized by Java; James Gosling used the term from 1993
     - The term comes from just-in-time manufacturing, also known as just-in-time production or the Toyota Production System (TPS)
  8. Example
     def fibonacci(n):
         """Returns n-th Fibonacci number"""
         a = 0
         b = 1
         if n < 1:
             return a
         i = 0
         while i < n:
             temp = a
             a = b
             b = temp + b
             i += 1
         return a

     Fibonacci Sequence: 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, ...
  9. Let’s JIT it
     1) Convert the function to machine code at run-time
     2) Execute this machine code
  10. Let’s JIT it
      @jit
      def fibonacci(n):
          """Returns n-th Fibonacci number"""
          a = 0
          b = 1
          if n < 1:
              return a
          i = 0
          while i < n:
              temp = a
              a = b
              b = temp + b
              i += 1
          return a
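The @jit decorator itself can be a small wrapper that compiles on the first call and caches the result. A minimal structural sketch: `compile_func` here is a placeholder standing in for the source → AST → IL ASM → machine-code pipeline of the following slides, not the talk's actual implementation.

```python
import functools

def jit(func):
    """Sketch of a lazy JIT decorator: compile on first call, cache, dispatch."""
    cache = {}

    def compile_func(f):
        # Placeholder: in the talk this is where source -> AST -> IL ASM ->
        # machine code happens; here it just returns f unchanged.
        return f

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if 'native' not in cache:
            cache['native'] = compile_func(func)  # compile once, lazily
        return cache['native'](*args, **kwargs)
    return wrapper

@jit
def fibonacci(n):
    """Returns n-th Fibonacci number"""
    a = 0
    b = 1
    if n < 1:
        return a
    i = 0
    while i < n:
        temp = a
        a = b
        b = temp + b
        i += 1
    return a

print(fibonacci(10))  # 55
```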
  11. Convert function to AST
      import ast
      import inspect

      lines = inspect.getsource(func)
      node = ast.parse(lines)
      visitor = Visitor()
      visitor.visit(node)
  12. AST
      Module(body=[
        FunctionDef(name='fibonacci',
          args=arguments(args=[Name(id='n', ctx=Param())], vararg=None, kwarg=None, defaults=[]),
          body=[
            Expr(value=Str(s='Returns n-th Fibonacci number')),
            Assign(targets=[Name(id='a', ctx=Store())], value=Num(n=0)),
            Assign(targets=[Name(id='b', ctx=Store())], value=Num(n=1)),
            If(test=Compare(left=Name(id='n', ctx=Load()), ops=[Lt()], comparators=[Num(n=1)]),
               body=[Return(value=Name(id='a', ctx=Load()))], orelse=[]),
            Assign(targets=[Name(id='i', ctx=Store())], value=Num(n=0)),
            While(test=Compare(left=Name(id='i', ctx=Load()), ops=[Lt()], comparators=[Name(id='n', ctx=Load())]),
                  body=[
                    Assign(targets=[Name(id='temp', ctx=Store())], value=Name(id='a', ctx=Load())),
                    Assign(targets=[Name(id='a', ctx=Store())], value=Name(id='b', ctx=Load())),
                    Assign(targets=[Name(id='b', ctx=Store())],
                           value=BinOp(left=Name(id='temp', ctx=Load()), op=Add(), right=Name(id='b', ctx=Load()))),
                    AugAssign(target=Name(id='i', ctx=Store()), op=Add(), value=Num(n=1))
                  ], orelse=[]),
            Return(value=Name(id='a', ctx=Load()))
          ], decorator_list=[Name(id='jit', ctx=Load())])
      ])
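The dump above is from Python 2.7's ast module (newer CPython renders Num and Str nodes as Constant), but the same tree shape can be inspected on any version with a few lines:

```python
import ast

source = '''
def fibonacci(n):
    """Returns n-th Fibonacci number"""
    a = 0
    b = 1
    if n < 1:
        return a
    i = 0
    while i < n:
        temp = a
        a = b
        b = temp + b
        i += 1
    return a
'''

tree = ast.parse(source)
fn = tree.body[0]   # the FunctionDef node
print(fn.name)      # fibonacci
# Top-level statement kinds, matching the dump on the slide:
print([type(node).__name__ for node in fn.body])
```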
  13. AST to IL ASM
      class Visitor(ast.NodeVisitor):
          def __init__(self):
              self.ops = []
          ...
          def visit_Assign(self, node):
              if isinstance(node.value, ast.Num):
                  self.ops.append('MOV <{}>, {}'.format(node.targets[0].id, node.value.n))
              elif isinstance(node.value, ast.Name):
                  self.ops.append('MOV <{}>, <{}>'.format(node.targets[0].id, node.value.id))
              elif isinstance(node.value, ast.BinOp):
                  self.ops.extend(self.visit_BinOp(node.value))
                  self.ops.append('MOV <{}>, <{}>'.format(node.targets[0].id, node.value.left.id))
          ...

      Assign(targets=[Name(id='i', ctx=Store())], value=Num(n=0))                   ->  MOV <i>, 0
      Assign(targets=[Name(id='a', ctx=Store())], value=Name(id='b', ctx=Load()))   ->  MOV <a>, <b>
  16. IL ASM to ASM
      MOV <a>, 0
      MOV <b>, 1
      CMP <n>, 1
      JNL label0
      RET
      label0:
      MOV <i>, 0
      loop0:
      MOV <temp>, <a>
      MOV <a>, <b>
      ADD <temp>, <b>
      MOV <b>, <temp>
      INC <i>
      CMP <i>, <n>
      JL loop0
      RET

      # for x64 system
      args_registers = ['rdi', 'rsi', 'rdx', ...]
      registers = ['rax', 'rbx', 'rcx', ...]
      # return register: rax
      def fibonacci(n):   n ⇔ rdi
          ...
          return a        a ⇔ rax
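Replacing the <var> placeholders with registers amounts to a (very naive) register allocator: arguments and the returned variable get fixed registers, everything else gets the next free one on first use. A toy sketch; the `allocate` function and its register lists are illustrative, not the talk's actual code, assuming the slide's convention that arguments arrive in rdi/rsi/rdx and the return value lives in rax:

```python
import re

def allocate(il_lines, arg_vars, ret_var):
    """Map each <var> placeholder in the IL to an x86-64 register."""
    arg_registers = iter(['rdi', 'rsi', 'rdx'])       # System V argument order
    free_registers = iter(['rbx', 'rcx', 'r8', 'r9'])  # scratch pool (toy-sized)
    mapping = {var: next(arg_registers) for var in arg_vars}
    mapping[ret_var] = 'rax'                           # return value goes in rax
    out = []
    for line in il_lines:
        for var in re.findall(r'<(\w+)>', line):
            if var not in mapping:                     # first use: grab a register
                mapping[var] = next(free_registers)
        out.append(re.sub(r'<(\w+)>', lambda m: mapping[m.group(1)], line))
    return out

print(allocate(['MOV <a>, 0', 'MOV <b>, 1', 'CMP <n>, 1'], ['n'], 'a'))
# ['MOV rax, 0', 'MOV rbx, 1', 'CMP rdi, 1']
```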
  18. IL ASM to ASM
      MOV <a>, 0        ->  MOV rax, 0
      MOV <b>, 1        ->  MOV rbx, 1
      CMP <n>, 1        ->  CMP rdi, 1
      JNL label0        ->  JNL label0
      RET               ->  RET
      label0:           ->  label0:
      MOV <i>, 0        ->  MOV rcx, 0
      loop0:            ->  loop0:
      MOV <temp>, <a>   ->  MOV rdx, rax
      MOV <a>, <b>      ->  MOV rax, rbx
      ADD <temp>, <b>   ->  ADD rdx, rbx
      MOV <b>, <temp>   ->  MOV rbx, rdx
      INC <i>           ->  INC rcx
      CMP <i>, <n>      ->  CMP rcx, rdi
      JL loop0          ->  JL loop0
      RET               ->  RET
  19. ASM to machine code
      from pwnlib.asm import asm

      code = asm(asm_code, arch='amd64')

      MOV rax, 0
      MOV rbx, 1
      CMP rdi, 1
      JNL label0
      RET
      label0:
      MOV rcx, 0
      loop0:
      MOV rdx, rax
      MOV rax, rbx
      ADD rdx, rbx
      MOV rbx, rdx
      INC rcx
      CMP rcx, rdi
      JL loop0
      RET
  21. ASM to machine code
      MOV rax, 0
      MOV rbx, 1
      CMP rdi, 1
      JNL label0
      RET
      label0:
      MOV rcx, 0
      loop0:
      MOV rdx, rax
      MOV rax, rbx
      ADD rdx, rbx
      MOV rbx, rdx
      INC rcx
      CMP rcx, rdi
      JL loop0
      RET

      \x48\xc7\xc0\x00\x00\x00\x00
      \x48\xc7\xc3\x01\x00\x00\x00
      \x48\x83\xff\x01\x7d\x01\xc3
      \x48\xc7\xc1\x00\x00\x00\x00
      \x48\x89\xc2\x48\x89\xd8\x48
      \x01\xda\x48\x89\xd3\x48\xff
      \xc1\x48\x39\xf9\x7c\xec\xc3
  22. Create function in memory
      1) Allocate memory
      2) Copy the machine code to the allocated memory
      3) Mark the memory as executable
      Linux: mmap, mprotect
      Windows: VirtualAlloc, VirtualProtect
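The three steps can be tried end-to-end in a few lines. A Linux/x86-64 sketch using Python's mmap module rather than calling libc directly (as the later slides do); the six bytes encode `mov eax, 42; ret`. On hardened systems that enforce W^X, a simultaneously writable and executable mapping may be refused, in which case you would map writable first and mprotect to executable afterwards.

```python
import ctypes
import mmap

# Machine code for:  mov eax, 42 ; ret   (x86-64)
code = b'\xb8\x2a\x00\x00\x00\xc3'

# 1) Allocate a page that is readable, writable and executable
buf = mmap.mmap(-1, mmap.PAGESIZE,
                prot=mmap.PROT_READ | mmap.PROT_WRITE | mmap.PROT_EXEC)
# 2) Copy the machine code into it
buf.write(code)
# 3) Call it through a ctypes function pointer
addr = ctypes.addressof(ctypes.c_char.from_buffer(buf))
func = ctypes.CFUNCTYPE(ctypes.c_uint64)(addr)
print(func())  # 42
```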
  24. Signatures in C/C++
      Linux:
      void *mmap(void *addr, size_t length, int prot, int flags, int fd, off_t offset);
      int mprotect(void *addr, size_t len, int prot);
      void *memcpy(void *dest, const void *src, size_t n);
      int munmap(void *addr, size_t length);

      Windows:
      LPVOID VirtualAlloc(LPVOID lpAddress, SIZE_T dwSize, DWORD flAllocationType, DWORD flProtect);
      BOOL VirtualProtect(LPVOID lpAddress, SIZE_T dwSize, DWORD flNewProtect, PDWORD lpflOldProtect);
      void *memcpy(void *dest, const void *src, size_t count);
      BOOL VirtualFree(LPVOID lpAddress, SIZE_T dwSize, DWORD dwFreeType);
  25. Create function in memory
      import ctypes

      # Linux
      libc = ctypes.CDLL('libc.so.6')
      libc.mmap
      libc.mprotect
      libc.memcpy
      libc.munmap

      # Windows
      ctypes.windll.kernel32.VirtualAlloc
      ctypes.windll.kernel32.VirtualProtect
      ctypes.cdll.msvcrt.memcpy
      ctypes.windll.kernel32.VirtualFree
  26. Create function in memory
      mmap_func = libc.mmap
      mmap_func.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int,
                            ctypes.c_int, ctypes.c_int, ctypes.c_size_t]
      mmap_func.restype = ctypes.c_void_p

      memcpy_func = libc.memcpy
      memcpy_func.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t]
      memcpy_func.restype = ctypes.c_char_p
  27. Create function in memory
      machine_code = ('\x48\xc7\xc0\x00\x00\x00\x00\x48\xc7\xc3\x01\x00\x00\x00\x48'
                      '\x83\xff\x01\x7d\x01\xc3\x48\xc7\xc1\x00\x00\x00\x00\x48\x89\xc2\x48\x89\xd8'
                      '\x48\x01\xda\x48\x89\xd3\x48\xff\xc1\x48\x39\xf9\x7c\xec\xc3')
      machine_code_size = len(machine_code)
      addr = mmap_func(None, machine_code_size,
                       PROT_READ | PROT_WRITE | PROT_EXEC,
                       MAP_ANONYMOUS | MAP_PRIVATE, -1, 0)
      memcpy_func(addr, machine_code, machine_code_size)
      func = ctypes.CFUNCTYPE(ctypes.c_uint64)(addr)
      func.argtypes = [ctypes.c_uint32]
  28. Python 2.7: for _ in range(1000000): fibonacci(n)

      n    No JIT (s)  JIT (s)
      0    0.153       0.882
      10   1.001       0.878
      20   1.805       0.942
      30   2.658       0.955
      60   4.800       0.928
      90   7.117       0.922
      500  50.611      1.251
  29. Python 3.7: for _ in range(1000000): fibonacci(n)

      n    No JIT (s)  JIT (s)
      0    0.150       1.079
      10   1.093       0.971
      20   2.206       1.135
      30   3.313       1.204
      60   6.815       1.198
      90   10.458      1.270
      500  63.949      1.652
  30. Python 2.7 vs 3.7
      fibonacci(n=93):
      Python 3.7: No JIT: 10.524 s, JIT: 1.185 s, JIT ~8.5 times faster, JIT compilation time: ~0.08 s
      Python 2.7: No JIT: 7.942 s, JIT: 0.887 s, JIT ~8.5 times faster, JIT compilation time: ~0.07 s
      * fibonacci(n=92) = 0x68a3dd8e61eccfbd, fibonacci(n=93) = 0xa94fad42221f2702
  31. No JIT vs JIT
      CPython bytecode (33 VM opcodes):
       0 LOAD_CONST    1 (0)
       3 STORE_FAST    1 (a)
       6 LOAD_CONST    2 (1)
       9 STORE_FAST    2 (b)
      12 LOAD_FAST     0 (n)
      15 LOAD_CONST    2 (1)
      18 COMPARE_OP    0 (<)
      21 POP_JUMP_IF_FALSE 28
      24 LOAD_FAST     1 (a)
      27 RETURN_VALUE
   >> 28 LOAD_CONST    1 (0)
      31 STORE_FAST    3 (i)
      34 SETUP_LOOP    48 (to 85)
   >> 37 LOAD_FAST     3 (i)
      40 LOAD_FAST     0 (n)
      43 COMPARE_OP    0 (<)
      46 POP_JUMP_IF_FALSE 84
      49 LOAD_FAST     1 (a)
      52 STORE_FAST    4 (temp)
      55 LOAD_FAST     2 (b)
      58 STORE_FAST    1 (a)
      61 LOAD_FAST     4 (temp)
      64 LOAD_FAST     2 (b)
      67 BINARY_ADD
      68 STORE_FAST    2 (b)
      71 LOAD_FAST     3 (i)
      74 LOAD_CONST    2 (1)
      77 INPLACE_ADD
      78 STORE_FAST    3 (i)
      81 JUMP_ABSOLUTE 37
   >> 84 POP_BLOCK
   >> 85 LOAD_FAST     1 (a)
      88 RETURN_VALUE

      vs machine code (14 real machine instructions):
      MOV rax, 0
      MOV rbx, 1
      CMP rdi, 1
      JNL label0
      RET
      label0:
      MOV rcx, 0
      loop0:
      MOV rdx, rax
      MOV rax, rbx
      ADD rdx, rbx
      MOV rbx, rdx
      INC rcx
      CMP rcx, rdi
      JL loop0
      RET
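The bytecode side of the comparison is easy to reproduce with the dis module; the exact opcode count is version-dependent (the 33 on the slide is from Python 2.7, and on Python 3.11+ BINARY_ADD/INPLACE_ADD become BINARY_OP):

```python
import dis

def fibonacci(n):
    """Returns n-th Fibonacci number"""
    a = 0
    b = 1
    if n < 1:
        return a
    i = 0
    while i < n:
        temp = a
        a = b
        b = temp + b
        i += 1
    return a

instructions = list(dis.get_instructions(fibonacci))
print(len(instructions))  # opcode count; varies by CPython version
print(sorted({ins.opname for ins in instructions}))
```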
  32. Numba
      Numba makes Python code fast: an open-source JIT compiler that translates a subset of Python and NumPy code into fast machine code.
      - Parallelization
      - SIMD Vectorization
      - GPU Acceleration
  33. from numba import jit
      import numpy as np

      @jit(nopython=True)  # Set "nopython" mode for best performance, equivalent to @njit
      def go_fast(a):  # Function is compiled to machine code when called the first time
          trace = 0
          for i in range(a.shape[0]):    # Numba likes loops
              trace += np.tanh(a[i, i])  # Numba likes NumPy functions
          return a + trace               # Numba likes NumPy broadcasting

      @cuda.jit
      def matmul(A, B, C):
          """Perform square matrix multiplication of C = A * B"""
          i, j = cuda.grid(2)
          if i < C.shape[0] and j < C.shape[1]:
              tmp = 0.
              for k in range(A.shape[1]):
                  tmp += A[i, k] * B[k, j]
              C[i, j] = tmp
  34. LLVM
      LLVM — compiler infrastructure project
      Tutorial “Building a JIT: Starting out with KaleidoscopeJIT”
      llvmpy — Python bindings for LLVM
      llvmlite, a project by the Numba team — a lightweight LLVM Python binding for writing JIT compilers
  35. PeachPy
      Portable Efficient Assembly Code-generator in Higher-level Python: an x86-64 assembler embedded in Python

      from peachpy.x86_64 import *

      ADD(eax, 5).encode()
      # bytearray(b'\x83\xc0\x05')

      MOVAPS(xmm0, xmm1).encode_options()
      # [bytearray(b'\x0f(\xc1'), bytearray(b'\x0f)\xc8')]

      VPSLLVD(ymm0, ymm1, [rsi + 8]).encode_length_options()
      # {6: bytearray(b'\xc4\xe2uGF\x08'),
      #  7: bytearray(b'\xc4\xe2uGD&\x08'),
      #  9: bytearray(b'\xc4\xe2uG\x86\x08\x00\x00\x00')}
  36. PyPy
      PyPy is a fast, compliant alternative implementation of the Python language.
      Python programs often run faster on PyPy thanks to its just-in-time compiler.
      PyPy works best when executing long-running programs where a significant fraction of the time is spent executing Python code.
      “If you want your code to run faster, you should probably just use PyPy.” — Guido van Rossum (creator of Python)
  37. Other projects
      Pyjion — a JIT for Python based upon CoreCLR
      Pyston — a Python implementation built using LLVM and modern JIT techniques
      Psyco — an extension module which can greatly speed up the execution of Python code; the first just-in-time compiler for Python, now unmaintained and dead
      Unladen Swallow — an attempt to make LLVM a JIT compiler for CPython
  38. References
      1. https://en.wikipedia.org/wiki/Just-in-time_compilation
      2. John Aycock: A Brief History of Just-In-Time. ACM Computing Surveys, volume 35, issue 2, pages 97-113, June 2003, DOI: 10.1145/857076.857077
      3. https://eli.thegreenplace.net/2013/11/05/how-to-jit-an-introduction
      4. https://medium.com/starschema-blog/jit-fast-supercharge-tensor-processing-in-python-with-jit-compilation-47598de6ee96
      5. https://github.com/Gallopsled/pwntools
      6. https://numba.pydata.org
      7. https://llvm.org/docs/tutorial/BuildingAJIT1.html
      8. https://llvmlite.readthedocs.io/en/latest/
      9. http://www.llvmpy.org
      10. https://github.com/Maratyszcza/PeachPy
      11. https://github.com/microsoft/Pyjion
      12. https://blog.pyston.org