Low-Level Implementation of Finite Field Operations using LLVM/Asm

herumi
March 30, 2023

Open Source Cryptography Workshop 2023/Mar/30
https://rsvp.withgoogle.com/events/open-source-cryptography-workshop

Transcript

  1. Low-Level Implementation of Finite Field Operations using LLVM/Asm

    2023/3/30 OSS Cryptography Workshop
    Cybozu Labs, Inc. Mitsunari Shigeo (光成滋生)

  2. My Introduction (@herumi)

    • Researcher and Software Engineer at Cybozu Labs, Inc.
    • Developer of some OSS
      • mcl/bls (https://github.com/herumi/mcl)
        • Pairing/BLS signature library
        • For DFINITY (2016), for Ethereum 2.0 (2019)
        • https://www.npmjs.com/package/mcl-wasm : WebAssembly package, 130k downloads/week
      • Xbyak/Xbyak_aarch64 (https://github.com/herumi/xbyak)
        • JIT assembler for x64/AArch64 (A64) in C++
        • Designed for low-level optimization
        • Heavily used in oneDNN (Intel AI framework) / ZenDNN (AMD), Fugaku (Japanese supercomputer)

  3. Today's Topics

    • Motivation
      • For security reasons, we want to write code in a "safe language" as much as possible.
      • But high performance is also required.
      • We want to minimize the use of assembly language.
    • What are some safe ways to write such code?
      • How practical is a generic description using LLVM?
      • A DSL (domain-specific language) makes the code easier to read and write.
      • The current performance and limitations of these approaches

  4. Pairing

    • Pairing is a mathematical function used in BLS signatures, zk-SNARKs, homomorphic encryption, etc.
    • Multiplication in a finite field takes up most of the time in pairing computations.
    • Finite fields over 256- to 384-bit primes are commonly used.

  5. Benchmark of Generated Code by DSL

    • Compared with blst (https://github.com/supranational/blst)
    • Finite field operations on MacBook Pro (M1, 2020), in nsec

      381-bit prime    blst    DSL+LLVM
      add              6.40      6.50
      sub              5.30      5.60
      mont (mul)      31.40     31.50

    • The DSL is almost at blst speed.
    • https://github.com/herumi/mcl-ff (under construction)
    • See below for details.

  6. Finite Field and Notation

    • $p$: a prime of about 256 to 384 bits
    • $\mathbb{F}_p = \{0, 1, \dots, p - 1\}$
    • For $x, y \in \mathbb{F}_p$: $x \pm y := (x \pm y) \bmod p$ and $xy := xy \bmod p$
    • $L = 64$ (or 32): CPU register size, $M = 2^L$
    • $x = \langle x \rangle_4 = [x_3 : x_2 : x_1 : x_0]$ (each $x_i$ is $L$ bits) is a 256-bit integer: $x = \sum_i x_i M^i$ with $x_i \in [0, 2^L)$
    • uint64_t x[N]; // written in C/C++, N = 4

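A minimal Python sketch of this notation, assuming little-endian limb order (x[0] is the least significant limb, as in uint64_t x[N]); to_limbs and from_limbs are hypothetical helper names, not part of mcl:

```python
L = 64          # register size in bits
M = 1 << L      # M = 2^L

def to_limbs(x, n):
    """Split x into n L-bit limbs, least significant first (x = sum x_i * M**i)."""
    assert 0 <= x < M ** n
    return [(x >> (L * i)) & (M - 1) for i in range(n)]

def from_limbs(limbs):
    """Rebuild the integer from its limbs."""
    return sum(v * M ** i for i, v in enumerate(limbs))

x = 2**255 - 19                 # example 255-bit value; N = 4 limbs suffice
xs = to_limbs(x, 4)
assert from_limbs(xs) == x
```
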
  7. Review of Addition of Two Digits

    • A sum of two 64-bit registers is 65 bits: 65 bits = 64 bits + CF (1 bit).
    • Decimal analogy, 36 + 47: the ones digits give 6 + 7 = 13, an increase in digits (carry 1);
      the tens digits give 3 + 4 + 1 = 8, so the result is 83.
    • How CF is calculated (U = uint64_t or uint32_t):

      bool hasCF(U x, U y)
      {
          U z = x + y;
          return z < x; // CF = 1 if z < x
      }

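As a check of the hasCF rule, the following Python sketch models U as a 64-bit unsigned integer by masking the sum. If the sum overflows, z = x + y - 2^64, which is smaller than x because y < 2^64; otherwise z = x + y >= x. has_cf and exact_cf are illustrative names.

```python
import random

L = 64
MASK = (1 << L) - 1

def has_cf(x, y):
    """Model of hasCF: the sum wraps modulo 2^L; a carry occurred iff z < x."""
    z = (x + y) & MASK
    return z < x

def exact_cf(x, y):
    """Reference: the carry is bit L of the exact (unbounded) sum."""
    return (x + y) >> L == 1

for _ in range(1000):
    x, y = random.getrandbits(L), random.getrandbits(L)
    assert has_cf(x, y) == exact_cf(x, y)
```
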
  8. addT using C++

    • RISC-V has no carry flag, so this portable code suits it well, but it is slow on x64 (x86-64) / A64 (AArch64), which have dedicated carry instructions.

      // U = uint64_t or uint32_t
      // ⟨z⟩_N = ⟨x⟩_N + ⟨y⟩_N and return CF
      template<size_t N>
      U addT(U *z, const U *x, const U *y)
      {
          U c = 0; // There is no CF at first.
          for (size_t i = 0; i < N; i++) {
              U xc = x[i] + c;
              c = xc < c; // CF1
              U yi = y[i];
              xc += yi;
              c += xc < yi; // CF2
              z[i] = xc;
          }
          return c;
      }

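A Python model of addT, as a sketch for checking the two carry tests against exact big-integer addition. Masking with MASK emulates the wraparound of U = uint64_t; add_t and to_int are hypothetical names.

```python
import random

L = 64
MASK = (1 << L) - 1

def add_t(x, y):
    """Model of addT: returns (z, c) with <z>_N + c * 2^(L*N) == <x>_N + <y>_N."""
    z = []
    c = 0                          # there is no CF at first
    for xi, yi in zip(x, y):
        xc = (xi + c) & MASK
        c = int(xc < c)            # CF1: x[i] + c wrapped around
        xc = (xc + yi) & MASK
        c += int(xc < yi)          # CF2: adding y[i] wrapped around
        z.append(xc)
    return z, c

def to_int(v):
    return sum(d << (L * i) for i, d in enumerate(v))

N = 4
for _ in range(1000):
    x = [random.getrandbits(L) for _ in range(N)]
    y = [random.getrandbits(L) for _ in range(N)]
    z, c = add_t(x, y)
    assert to_int(z) + (c << (L * N)) == to_int(x) + to_int(y)
```

CF1 and CF2 are never both 1: CF1 requires x[i] = 2^64 - 1 and c = 1, which makes xc = 0, and then adding y[i] cannot wrap. So c stays 0 or 1.
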
  9. Carry Operation of x64

    • add x, y
      • Input: x, y are 64-bit integer registers
      • Output: $x \leftarrow (x + y) \bmod M$ ($M = 2^{64}$); $CF \leftarrow 1$ if $x + y \ge M$ else 0
    • adc x, y
      • Input: x, y and CF
      • Output: $x \leftarrow (x + y + CF) \bmod M$; $CF \leftarrow 1$ if $x + y + CF \ge M$ else 0
    • adds and adcs are the corresponding add operations of AArch64 (A64).

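The add/adc semantics above can be modelled in Python as functions returning (result, new CF); chaining them the way add3 does (add for limb 0, adc afterwards, then setc) gives a quick way to check the carry handling. This is an illustrative sketch, not an emulator of the real flags register.

```python
M = 1 << 64

def add(x, y):
    """x64 `add x, y`: returns ((x + y) % M, new CF)."""
    s = x + y
    return s % M, int(s >= M)

def adc(x, y, cf):
    """x64 `adc x, y`: returns ((x + y + CF) % M, new CF)."""
    s = x + y + cf
    return s % M, int(s >= M)

def add3(x, y):
    """Mirror of add3: add on limb 0, adc on limbs 1 and 2, then setc."""
    z = [0, 0, 0]
    z[0], cf = add(x[0], y[0])
    z[1], cf = adc(x[1], y[1], cf)
    z[2], cf = adc(x[2], y[2], cf)
    return z, cf   # setc al; movzx eax, al: the return value is CF
```
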
  10. Asm Code of add3 for x64 Linux

    • x64 Linux or Intel macOS (gas, AT&T syntax)

      add3:                       # macOS requires an underscore prefix (_add3)
          mov   (%rsi), %rax      # rax ← x[0]
          add   (%rdx), %rax      # rax ← rax + y[0]
          mov   %rax, (%rdi)      # z[0] ← rax
          mov   8(%rsi), %rax     # rax ← x[1]
          adc   8(%rdx), %rax     # rax ← rax + y[1] + CF
          mov   %rax, 8(%rdi)     # z[1] ← rax
          mov   16(%rsi), %rax    # rax ← x[2]
          adc   16(%rdx), %rax    # rax ← rax + y[2] + CF
          mov   %rax, 16(%rdi)    # z[2] ← rax
          setc  %al               # al ← CF
          movzx %al, %eax         # rax ← zero extension of al
          ret

    • Calling convention: U addN(U *z, const U *x, const U *y);
      return value in rax; z, x, y in rdi, rsi, rdx

  11. Asm Code of add3 for Win64

    • 64-bit Windows uses different registers (masm, Intel syntax).

      add3:
          mov   rax, [rdx]        ; rax ← x[0]
          add   rax, [r8]         ; rax ← rax + y[0]
          mov   [rcx], rax        ; z[0] ← rax
          mov   rax, [rdx+8]      ; rax ← x[1]
          adc   rax, [r8+8]       ; rax ← rax + y[1] + CF
          mov   [rcx+8], rax      ; z[1] ← rax
          mov   rax, [rdx+16]     ; rax ← x[2]
          adc   rax, [r8+16]      ; rax ← rax + y[2] + CF
          mov   [rcx+16], rax     ; z[2] ← rax
          setc  al                ; al ← CF
          movzx eax, al           ; rax ← zero extension of al
          ret

    • Calling convention: U addN(U *z, const U *x, const U *y);
      return value in rax; z, x, y in rcx, rdx, r8

  12. DSL for x64

    • Absorbing the differences between Linux and Windows
    • Static code generator, a thin wrapper of the assembler
    • gen_add(N) generates add{N}(z, x, y) for Windows (masm) / Linux (gas)

      def gen_add(N):
          with FuncProc(f'add{N}'):                  # declare the function
              with StackFrame(3) as sf:              # make a stack frame
                  z = sf.p[0]                        # 1st argument reg
                  x = sf.p[1]                        # 2nd argument reg
                  y = sf.p[2]                        # 3rd argument reg
                  for i in range(N):
                      mov(rax, ptr(x + 8 * i))       # rax ← x[i]
                      if i == 0:
                          add(rax, ptr(y + 8 * i))   # rax ← add(rax, y[i])
                      else:
                          adc(rax, ptr(y + 8 * i))   # rax ← adc(rax, y[i])
                      mov(ptr(z + 8 * i), rax)       # z[i] ← rax
                  setc(al)
                  movzx(eax, al)

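The real generator uses the FuncProc/StackFrame helpers above, which handle the prologue, epilogue and register assignment. The following is only a hypothetical, self-contained sketch of the core idea: the same loop emits Intel-syntax text, and only the argument registers change with the ABI (rdi/rsi/rdx on Linux, rcx/rdx/r8 on Win64). gen_add_text and ARG_REGS are illustrative names; directives such as `.intel_syntax noprefix` (gas) or PROC/ENDP (masm) are omitted.

```python
# Argument registers for (z, x, y) under each calling convention.
ARG_REGS = {
    "linux": ("rdi", "rsi", "rdx"),   # System V AMD64 ABI
    "win64": ("rcx", "rdx", "r8"),    # Microsoft x64 ABI
}

def gen_add_text(n, abi):
    """Emit Intel-syntax lines for add{n}(z, x, y) (hypothetical helper)."""
    z, x, y = ARG_REGS[abi]
    lines = [f"add{n}:"]
    for i in range(n):
        lines.append(f"    mov rax, [{x}+{8 * i}]")   # rax <- x[i]
        op = "add" if i == 0 else "adc"                # no carry-in for limb 0
        lines.append(f"    {op} rax, [{y}+{8 * i}]")
        lines.append(f"    mov [{z}+{8 * i}], rax")   # z[i] <- rax
    lines += ["    setc al", "    movzx eax, al", "    ret"]
    return "\n".join(lines)

print(gen_add_text(3, "win64"))
```
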
  13. LLVM IR

    • LLVM IR: a low-level, platform-independent language
    • Features
      • SSA (Static Single Assignment form)
      • Integer types of arbitrary fixed bit width: i{N}
      • Supports add/sub of arbitrary fixed bit width without CF
        • i{N} ← add(i{N}, i{N})
        • i{N} ← sub(i{N}, i{N})
      • Supports minimal multiplication
        • 128 bit ← 64 bit × 64 bit (64-bit CPU)
        • 64 bit ← 32 bit × 32 bit (32-bit CPU)
    • Reference: https://llvm.org/docs/LangRef.html
    • LLVM source code (LLVM IR: .ll) → clang -S -O2 → x64, A64, RISC-V, MIPS, WASM, etc.

  14. add4 Using LLVM

    • bool add4(U *z, const U *x, const U *y);
    • It is tedious to write the bit width on every instruction.

      define i1 @add4(i256* %pz, i256* %px, i256* %py) {
        %x = load i256, i256* %px       ; x ← *(const uint256_t*)px;
        %y = load i256, i256* %py       ; y ← *(const uint256_t*)py;
        %x2 = zext i256 %x to i257      ; x2 ← uint257_t(x);
        %y2 = zext i256 %y to i257      ; y2 ← uint257_t(y);
        %z = add i257 %x2, %y2          ; z ← x2 + y2
        %z2 = trunc i257 %z to i256     ; z2 ← uint256_t(z);
        store i256 %z2, i256* %pz       ; *(uint256_t*)pz = z2;
        %z3 = lshr i257 %z, 256         ; z3 ← z >> 256
        %z4 = trunc i257 %z3 to i1      ; z4 ← uint1_t(z3);
        ret i1 %z4                      ; return z4;
      }

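To see what add4 computes, the sketch below replays it on Python integers, modelling iN arithmetic by reducing modulo 2^N (zext leaves the value unchanged, trunc keeps the low bits). Because two 256-bit values sum to less than 2^257, add i257 never wraps, so bit 256 of the sum is exactly the carry. The function names are illustrative.

```python
def trunc(v, w):
    """LLVM `trunc` to iw: keep the low w bits."""
    return v & ((1 << w) - 1)

def add_iw(a, b, w):
    """LLVM `add iw`: wraps modulo 2^w."""
    return trunc(a + b, w)

def add4(x, y):
    """Replay of @add4 on 256-bit Python ints; returns (stored z2, returned z4)."""
    x2, y2 = x, y                  # zext i256 -> i257 keeps the value
    z = add_iw(x2, y2, 257)        # add i257 (cannot wrap here)
    z2 = trunc(z, 256)             # the value stored to *pz
    z4 = trunc(z >> 256, 1)        # lshr by 256, trunc to i1: the carry
    return z2, z4
```
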
  15. Generating Target Assembler Code

    • clang -S -O2 -target x86_64 add4.ll
      • The output is almost the same as the previous code.
    • clang -S -O2 -target aarch64 add4.ll outputs:

      add4:
          ldp  x9, x8, [x1]          // [x8:x9] ← ((u128*)px)[0]
          ldp  x11, x10, [x2]        // [x10:x11] ← ((u128*)py)[0]
          ldp  x13, x12, [x1, #16]   // [x12:x13] ← ((u128*)px)[1]
          ldp  x14, x15, [x2, #16]   // [x15:x14] ← ((u128*)py)[1]
          adds x9, x11, x9           // x9 ← add(x11, x9)
          adcs x8, x10, x8           // x8 ← add(x10, x8, CF)
          adcs x10, x14, x13         // x10 ← add(x14, x13, CF)
          stp  x9, x8, [x0]          // ((u128*)pz)[0] = [x8:x9]
          adcs x9, x15, x12          // x9 ← add(x15, x12, CF)
          adcs x8, xzr, xzr          // x8 ← add(0, 0, CF)
          stp  x10, x9, [x0, #16]    // ((u128*)pz)[1] = [x9:x10]
          mov  w0, w8                // w0 ← x8
          ret

  16. DSL for LLVM using Python (and C++)

    • ll (LLVM IR) seems good, but it is hard to write by hand.
    • I developed a simple DSL to make it easier to write.
      • Variables can be reassigned (non-SSA).
      • No need to write the register size.
      • ll : %x1 = add i256 %x, %y
        DSL: x = add(x, y)
    • Flow: DSL → generator → code (ll) → clang -S -O2 → x64, A64, RISC-V, MIPS, WASM

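One way such a DSL can accept reassignment while still emitting SSA is to give every operation result a fresh register name. The class below is a hypothetical minimal sketch of that idea, not the actual generator behind mcl-ff:

```python
class Gen:
    """Hypothetical mini-generator: every operation gets a fresh SSA name."""

    def __init__(self):
        self.count = 0
        self.lines = []

    def _new(self, bits, text):
        # Emit `%rK = text` and return the new value as (name, bit width).
        self.count += 1
        name = f"%r{self.count}"
        self.lines.append(f"{name} = {text}")
        return name, bits

    def arg(self, name, bits):
        return f"%{name}", bits

    def zext(self, v, bits):
        return self._new(bits, f"zext i{v[1]} {v[0]} to i{bits}")

    def add(self, a, b):
        assert a[1] == b[1], "operands must have the same width"
        return self._new(a[1], f"add i{a[1]} {a[0]}, {b[0]}")

g = Gen()
x = g.arg("x", 256)
y = g.arg("y", 256)
x = g.zext(x, 257)   # the Python variable x is reassigned (non-SSA) ...
y = g.zext(y, 257)
x = g.add(x, y)      # ... but each emitted value gets a new name (SSA)
print("\n".join(g.lines))
```

The widths travel with each value, which is what lets the DSL omit the register size that LLVM IR requires on every instruction.
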
  17. Generating add.ll with DSL

    • gen_add(N) generates the previous code for each N.

      unit = 64

      def gen_add(N):
          bit = unit * N
          pz = IntPtr(unit)            # define pointer to int
          px = IntPtr(unit)
          py = IntPtr(unit)
          name = f'mcl_fp_addPre{N}'
          with Function(name, Int(unit), pz, px, py):
              x = loadN(px, N)         # x = ⟨*px⟩_N
              y = loadN(py, N)         # y = ⟨*py⟩_N
              x = zext(x, bit + unit)
              y = zext(y, bit + unit)
              z = add(x, y)
              storeN(trunc(z, bit), pz)
              r = trunc(lshr(z, bit), unit)
              ret(r)

  18. Fp::add Using Python

    • Python code for addition in a finite field

      # assume 0 <= x, y < p
      def fp_add(x, y):
          z = x + y
          if z >= p:
              z -= p
          return z

    • Introduce a select function

      def select(cond, x, y):
          if cond:
              return x
          else:
              return y

      def fp_add(x, y):
          z = x + y
          w = z - p
          return select(w < 0, z, w)

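A quick exhaustive check, with the small illustrative prime p = 101, that the branching and select-based versions agree. In Python, select itself still branches; the point of the rewrite is that the DSL can lower select to a branch-free instruction such as cmov.

```python
p = 101  # small illustrative prime; real code uses 256- to 384-bit primes

def fp_add_branch(x, y):
    z = x + y
    if z >= p:
        z -= p
    return z

def select(cond, x, y):
    return x if cond else y

def fp_add_select(x, y):
    z = x + y
    w = z - p
    return select(w < 0, z, w)   # keep z when z - p went negative

for x in range(p):
    for y in range(p):
        assert fp_add_branch(x, y) == fp_add_select(x, y) == (x + y) % p
```
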
  19. Fp::add Using DSL

    • gen_fp_add(name, N, pp) generates Fp::add for each N.
    • Differences from gen_add(): p is loaded and subtracted, and the result is chosen by the sign bit with select.

      def gen_fp_add(name, N, pp):
          bit = unit * N
          pz = IntPtr(unit)
          px = IntPtr(unit)
          py = IntPtr(unit)
          with Function(name, Void, pz, px, py):
              x = zext(loadN(px, N), bit + unit)
              y = zext(loadN(py, N), bit + unit)
              x = add(x, y)                    # x ← x + y
              p = zext(load(pp), bit + unit)
              y = sub(x, p)                    # y ← x + y - p
              c = trunc(lshr(y, bit), 1)       # c = MSB of y
              x = select(c, x, y)              # x = (y < 0) ? x : y
              x = trunc(x, bit)
              storeN(x, pz)
              ret(Void)

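The following Python sketch models gen_fp_add with every value as a (bit + unit)-bit two's-complement integer. If x + y >= p the difference is below p < 2^bit, so bit `bit` is 0; if x + y < p the difference wraps and all bits from `bit` upward are 1. The modulus 2^64 - 59 with N = 1 is just a full-bit example (primality does not matter for this check), and fp_add_full is an illustrative name.

```python
import random

unit = 64

def fp_add_full(x, y, p, N):
    """Model of gen_fp_add: compute in (bit + unit)-bit two's complement."""
    bit = unit * N
    w = bit + unit
    mask = (1 << w) - 1
    s = (x + y) & mask               # x = add(x, y) in i{w}
    d = (s - p) & mask               # y = sub(x, p); wraps if negative
    c = (d >> bit) & 1               # c = trunc(lshr(y, bit), 1)
    r = s if c else d                # select(c, x, y)
    return r & ((1 << bit) - 1)      # trunc(x, bit)

p = 2**64 - 59                       # full-bit modulus for N = 1
for _ in range(1000):
    x, y = random.randrange(p), random.randrange(p)
    assert fp_add_full(x, y, p, 1) == (x + y) % p
```
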
  20. Optimizing Non-Full-Bit Primes

    • Full-bit prime: the bit length of p is a multiple of 64, e.g. 256, 384
    • Non-full-bit prime: e.g. 381 bits
    • If p is non-full-bit ($p < 2^{\mathit{bit}-1}$), then $x + y < 2p < 2^{\mathit{bit}}$, so x + y does not have to be zero-extended, and the sign of $x + y - p$ is its MSB (bit $\mathit{bit} - 1$).

      x = loadN(px, N)
      y = loadN(py, N)
      x = add(x, y)
      p = load(pp)
      y = sub(x, p)
      c = trunc(lshr(y, bit - 1), 1)   # c = MSB(y) = (y < 0)
      x = select(c, x, y)
      storeN(x, pz)
      ret(Void)

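A corresponding sketch for the non-full-bit case, keeping everything in `bit` bits. It relies on p < 2^(bit-1): then x + y < 2p cannot overflow, and x + y - p lies in (-p, p), so its sign is the MSB. The example uses p = 2^255 - 19 with N = 4; fp_add_not_full is an illustrative name.

```python
import random

unit = 64

def fp_add_not_full(x, y, p, N):
    """Model of the non-full-bit version: all values stay in `bit` bits."""
    bit = unit * N
    assert p < 1 << (bit - 1)        # non-full-bit precondition
    mask = (1 << bit) - 1
    s = (x + y) & mask               # add in i{bit}; cannot overflow here
    d = (s - p) & mask               # wraps when x + y < p
    c = (d >> (bit - 1)) & 1         # c = MSB(d) = (x + y - p < 0)
    return s if c else d             # select(c, x, y)

p = 2**255 - 19                      # 255-bit prime: non-full-bit for N = 4
for _ in range(1000):
    x, y = random.randrange(p), random.randrange(p)
    assert fp_add_not_full(x, y, p, 4) == (x + y) % p
```
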
  21. Fp::add Generated Code by LLVM

    • The x64 code has no conditional branch (constant time).
      • This code is nearly perfect.
      • -target aarch64 also generates good code.

      fp_add3_not_full:
          mov   r8, [rdx]
          add   r8, [rsi]
          mov   r9, [rdx + 8]
          adc   r9, [rsi + 8]
          mov   r10, [rdx + 16]
          adc   r10, [rsi + 16]     ; [r10:r9:r8] ← [py] + [px]
          mov   rsi, r8
          sub   rsi, [rcx]
          mov   rax, r9
          sbb   rax, [rcx + 8]
          mov   rdx, r10
          sbb   rdx, [rcx + 16]     ; [rdx:rax:rsi] ← copy of [r10:r9:r8], then -= [pp]
          mov   rcx, rdx
          sar   rcx, 63             ; check the top bit of rdx (sets SF)
          cmovs rdx, r10
          cmovs rax, r9
          cmovs rsi, r8             ; [rdx:rax:rsi] ← [r10:r9:r8] if negative (SF = 1)
          mov   [rdi], rsi
          mov   [rdi + 8], rax
          mov   [rdi + 16], rdx     ; [pz] ← [rdx:rax:rsi]
          ret

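The branch-free tail of this listing can be emulated in Python: `sar rcx, 63` turns the sign of the top limb into an all-ones or all-zeros mask, and cmovs then picks each limb of either the sum or the difference. The sketch does that selection with bitwise operations on the mask; the modulus (1 << 190) + 12345 is an arbitrary non-full-bit value chosen for testing, and the helper names are hypothetical.

```python
import random

M = 1 << 64
MASK64 = M - 1

def to_limbs(v, n=3):
    return [(v >> (64 * i)) & MASK64 for i in range(n)]

def from_limbs(v):
    return sum(d << (64 * i) for i, d in enumerate(v))

def fp_add3_not_full(x, y, p):
    """Model of fp_add3_not_full on 3-limb lists (least significant first)."""
    s = from_limbs(x) + from_limbs(y)                # add/adc -> [r10:r9:r8]
    d = to_limbs((s - from_limbs(p)) % (1 << 192))   # sub/sbb -> [rdx:rax:rsi]
    s = to_limbs(s)
    mask = -(d[2] >> 63) & MASK64                    # sar rcx, 63: all ones iff negative
    # cmovs: take s where the mask is all ones, otherwise keep d
    return [(si & mask) | (di & ~mask & MASK64) for si, di in zip(s, d)]

p = (1 << 190) + 12345            # illustrative non-full-bit modulus (< 2**191)
for _ in range(1000):
    a, b = random.randrange(p), random.randrange(p)
    r = fp_add3_not_full(to_limbs(a), to_limbs(b), to_limbs(p))
    assert from_limbs(r) == (a + b) % p
```
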
  22. Multiplication by u64

    • u64 mulUnit(u64 *z, const u64 *x, u64 y);
      • computes $\langle x \rangle_N \times y = \langle z \rangle_{N+1}$ (z[0..N-1] are stored, z[N] is returned)
    • The only multiplication used from LLVM IR is u128 mul(u64, u64).
    • pack is a combination of shl + or.

               [x3:x2:x1:x0]
            ×             y
            ----------------
                     [H0:L0]
                  [H1:L1]
               [H2:L2]
            [H3:L3]
            ----------------
            [z4:z3:z2:z1:z0]

        L = [ 0:L3:L2:L1:L0]
      + H = [H3:H2:H1:H0: 0]
        z = [z4:z3:z2:z1:z0]

      with Function(f'mulUnit{N}', z, px, y):
          Ls = []
          Hs = []
          y = zext(y, unit*2)
          for i in range(N):
              x = load(getelementptr(px, i))   # x[i]
              xy = mul(zext(x, unit*2), y)
              Ls.append(trunc(xy, unit))
              Hs.append(trunc(lshr(xy, unit), unit))
          bu = bit + unit
          L = zext(pack(Ls), bu)
          H = shl(zext(pack(Hs), bu), unit)
          z = add(L, H)

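A Python model of mulUnit following the L/H decomposition above, checked against a direct big-integer product. pack here stands in for the DSL's pack (limbs concatenated, least significant first); the names are illustrative.

```python
import random

unit = 64
M = 1 << unit

def pack(limbs):
    """Concatenate unit-bit limbs, least significant first."""
    return sum(d << (unit * i) for i, d in enumerate(limbs))

def mul_unit(x, y):
    """Model of mulUnit: returns the N+1 limbs of <x>_N * y as L + (H << unit)."""
    N = len(x)
    Ls, Hs = [], []
    for xi in x:
        xy = xi * y                  # mul(zext(x[i]), zext(y)): 128-bit product
        Ls.append(xy % M)            # trunc(xy, unit)
        Hs.append(xy >> unit)        # trunc(lshr(xy, unit), unit)
    L = pack(Ls)                     # [ 0:L3:L2:L1:L0]
    H = pack(Hs) << unit             # [H3:H2:H1:H0: 0]
    z = L + H
    return [(z >> (unit * i)) % M for i in range(N + 1)]

for _ in range(1000):
    x = [random.getrandbits(unit) for _ in range(4)]
    y = random.getrandbits(unit)
    assert pack(mul_unit(x, y)) == pack(x) * y
```
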
  23. Generated ll by DSL

    • mulUnit3
    • The DSL code generates many shift and or operations, so it might be slow, but the generated assembly turns out to be compact (next slide).

      define i256 @mulUnit3(i64* noalias %r2, i64 %r3) {
        ...(snip)
        %r25 = zext i64 %r11 to i128
        %r26 = shl i128 %r25, 64
        %r27 = or i128 %r24, %r26
        %r28 = zext i128 %r27 to i192
        %r29 = zext i64 %r15 to i192
        %r30 = shl i192 %r29, 128
        %r31 = or i192 %r28, %r30
        %r32 = zext i192 %r23 to i256
        %r33 = zext i192 %r31 to i256
        %r34 = shl i256 %r33, 64
        %r35 = add i256 %r32, %r34
        ret i256 %r35
      }

  24. Generated Code for x64

    • It uses only mul ×3 + add ×1 + adc ×2.
    • This is the best code; with -mbmi2, mulx is used.
      • mulx is supported by Haswell or later.
      • It crashes (illegal instruction) on older systems.
      • It does not change CF, so it can be mixed with add and adc.
      • It reduces the use of temporary registers.

      ; -O2
      mclb_mulUnit3:
          mov  rcx, rdx
          mov  rax, rdx
          mul  qword [rsi]
          mov  r8, rdx
          mov  r9, rax
          mov  rax, rcx
          mul  qword [rsi + 8]
          mov  r10, rdx
          mov  r11, rax
          mov  rax, rcx
          mul  qword [rsi + 16]
          add  r11, r8
          adc  rax, r10
          adc  rdx, 0
          mov  [rdi], r9
          mov  [rdi + 8], r11
          mov  [rdi + 16], rax
          mov  rax, rdx
          ret

      ; -O2 -mbmi2
      mclb_mulUnit3:
          mulx r10, r8, [rsi]
          mulx r9, rcx, [rsi + 8]
          mulx rax, rdx, [rsi + 16]
          add  rcx, r10
          adc  rdx, r9
          adc  rax, 0
          mov  [rdi], r8
          mov  [rdi + 8], rcx
          mov  [rdi + 16], rdx
          ret

  25. Generated Code for A64

    • The A64 code also seems good.

      mulUnit3:
          ldp   x9, x10, [x1]
          mov   x8, x0             // pz = x8
          ldr   x11, [x1, #16]     // x = [x11:x10:x9]
          umulh x12, x9, x2        // H0 = hi(x[0] * y)
          umulh x13, x10, x2       // H1 = hi(x[1] * y)
          mul   x10, x10, x2       // L1 = lo(x[1] * y)
          mul   x14, x11, x2       // L2 = lo(x[2] * y)
          umulh x11, x11, x2       // H2 = hi(x[2] * y)
          adds  x10, x12, x10      // z1 = H0 + L1
          adcs  x12, x13, x14      // z2 = H1 + L2 + CF
          mul   x9, x9, x2         // L0 = lo(x[0] * y)
          cinc  x0, x11, hs        // ret = H2 + CF
          str   x12, [x8, #16]     // pz[2] = z2
          stp   x9, x10, [x8]      // pz[0] = L0, pz[1] = z1
          ret

    • A64 multiplication instructions (M = 1 << 64)

      mul(x: u64, y: u64):   return x * y % M
      umulh(x: u64, y: u64): return (x * y) >> 64

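The mul/umulh definitions above can be used to replay the A64 listing in Python and confirm that [ret:pz[2]:pz[1]:pz[0]] equals x · y. This sketch follows the instruction order of the listing, with the carry flag modelled as an explicit cf variable; the function names are illustrative.

```python
import random

M = 1 << 64

def mul(x, y):      # A64 mul: low 64 bits of x * y
    return x * y % M

def umulh(x, y):    # A64 umulh: high 64 bits of x * y
    return (x * y) >> 64

def mul_unit3_a64(x, y):
    """Replay of the A64 mulUnit3 listing: returns ([pz0, pz1, pz2], ret)."""
    H0, H1, H2 = umulh(x[0], y), umulh(x[1], y), umulh(x[2], y)
    L0, L1, L2 = mul(x[0], y), mul(x[1], y), mul(x[2], y)
    s = H0 + L1                 # adds x10, x12, x10
    z1, cf = s % M, s >> 64
    s = H1 + L2 + cf            # adcs x12, x13, x14
    z2, cf = s % M, s >> 64
    ret = H2 + cf               # cinc x0, x11, hs (H2 <= M - 2, so no overflow)
    return [L0, z1, z2], ret

for _ in range(1000):
    x = [random.getrandbits(64) for _ in range(3)]
    y = random.getrandbits(64)
    z, top = mul_unit3_a64(x, y)
    X = x[0] + (x[1] << 64) + (x[2] << 128)
    assert z[0] + (z[1] << 64) + (z[2] << 128) + (top << 192) == X * y
```
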
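  26. Montgomery Multiplication

    • A method for fast modular multiplication
      • toM(x): converts x to Montgomery form
      • fromM(x): converts x back to normal form
      • $x \times y \equiv \mathrm{fromM}(\mathrm{mont}(\mathrm{toM}(x), \mathrm{toM}(y))) \bmod p$
    • Preparation
      • $L = 64$ or $32$ and $M = 2^L$
      • $p = \langle p \rangle_N$: a prime
      • $ip \equiv -p^{-1} \bmod M$ (the minus sign makes $t + qp \equiv 0 \bmod M$ in mont, so the shift by $L$ is exact)
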
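The identity can be checked numerically. This sketch assumes the usual definitions R = M^N, toM(x) = xR mod p and mont(x, y) = xyR^(-1) mod p (so fromM(x) = mont(x, 1)), and computes mont directly with a modular inverse rather than with the fast algorithm. p = 2^255 - 19 with N = 4 is only an example; pow(a, -1, m) requires Python 3.8 or later.

```python
L, N = 64, 4
M = 1 << L
R = M ** N                      # R = M^N
p = 2**255 - 19                 # example prime (any odd p < R works here)

R_inv = pow(R, -1, p)           # R^(-1) mod p
ip = -pow(p, -1, M) % M         # ip = -p^(-1) mod M
assert p * ip % M == M - 1      # i.e. p * ip ≡ -1 (mod M)

def mont(x, y):
    """Definition-level Montgomery product: x * y * R^(-1) mod p."""
    return x * y * R_inv % p

def toM(x):
    return x * R % p

def fromM(x):
    return mont(x, 1)            # x * R^(-1) mod p

x, y = 12345, 67890
assert fromM(mont(toM(x), toM(y))) == x * y % p
```
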
  27. Mont using Python

    • An example code of mont(x, y) (N, p and ip are globals)

      L = 64
      M = 2**L
      # ip = -p^(-1) mod M
      def mont(x, y):
          t = 0
          for i in range(N):
              t += x * ((y >> (L * i)) % M)   # mulUnit
              q = ((t % M) * ip) % M          # u64 x u64 → u64
              t += q * p                      # mulUnit
              t >>= L
          if t >= p:
              t -= p
          return t

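With concrete parameters filled in, the loop can be tested against the definition xyR^(-1) mod p. Again p = 2^255 - 19 and N = 4 are example choices. ip is -p^(-1) mod M so that t + q·p is divisible by M at each step, which the assert inside the loop checks.

```python
import random

L, N = 64, 4
M = 1 << L
p = 2**255 - 19
ip = -pow(p, -1, M) % M          # ip = -p^(-1) mod M

def mont(x, y):
    t = 0
    for i in range(N):
        t += x * ((y >> (L * i)) % M)   # mulUnit: t += x * y_i
        q = ((t % M) * ip) % M          # u64 x u64 -> u64
        t += q * p                      # mulUnit: now t % M == 0
        assert t % M == 0
        t >>= L
    if t >= p:
        t -= p
    return t

R_inv = pow(M ** N, -1, p)
for _ in range(100):
    x, y = random.randrange(p), random.randrange(p)
    assert mont(x, y) == x * y * R_inv % p
```
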
  28. Mont using DSL

    • Assumes a non-full-bit prime

      def gen_mont(name, mont, mulUnit):
          pz = IntPtr(unit)
          px = IntPtr(unit)
          py = IntPtr(unit)
          with Function(name, Void, pz, px, py):
              t = Imm(0, unit*N+unit)
              for i in range(N):
                  y = load(getelementptr(py, i))
                  t = add(t, call(mulUnit, px, y))
                  q = mul(trunc(t, unit), ip)
                  t = add(t, call(mulUnit, pp, q))
                  t = lshr(t, unit)
              t = trunc(t, unit*N)
              vc = sub(t, loadN(pp, N))
              c = trunc(lshr(vc, unit*N-1), 1)
              storeN(select(c, t, vc), pz)
              ret(Void)

    • Corresponding Python code

      t += x * ((y >> (L*i)) % M)
      q = ((t % M) * ip) % M
      t += q * p
      t >>= L
      if t >= p: t -= p

    • $\langle z \rangle_{N+1} = \mathrm{mulUnit}(\langle x \rangle_N, y)$

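A sketch of what the DSL version computes, modelling t as an (N·unit + unit)-bit value. For a non-full-bit p (p < 2^(N·unit-1)) the invariant t < 2p gives t + x·y_i + q·p < 2pM <= 2^(N·unit+unit), so the fixed width never overflows, and the final sign test can read bit N·unit - 1. The asserts check the width bound; gen_mont_model is an illustrative name.

```python
import random

unit, N = 64, 4
M = 1 << unit
p = 2**255 - 19                       # non-full-bit: p < 2**(unit*N - 1)
ip = -pow(p, -1, M) % M
W = unit * N + unit                   # width of t in the DSL

def mul_unit(x, y):                   # <z>_{N+1} = mulUnit(<x>_N, y)
    return x * y

def gen_mont_model(x, y):
    t = 0                             # t = Imm(0, unit*N+unit)
    for i in range(N):
        yi = (y >> (unit * i)) % M    # load(getelementptr(py, i))
        t = t + mul_unit(x, yi)
        assert t < 1 << W             # fits in unit*N+unit bits
        q = (t % M) * ip % M          # mul(trunc(t, unit), ip)
        t = t + mul_unit(p, q)
        assert t < 1 << W
        t >>= unit                    # lshr(t, unit)
    t %= 1 << (unit * N)              # trunc(t, unit*N)
    vc = (t - p) % (1 << (unit * N))  # sub(t, p) wraps when t < p
    c = vc >> (unit * N - 1)          # MSB of vc
    return t if c else vc             # select(c, t, vc)

R_inv = pow(M ** N, -1, p)
for _ in range(100):
    x, y = random.randrange(p), random.randrange(p)
    assert gen_mont_model(x, y) == x * y * R_inv % p
```
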
  29. Benchmark of Generated Code by DSL

    • Compared with blst (https://github.com/supranational/blst) on MacBook Pro (M1, 2020), in nsec

      381-bit prime    blst    DSL+LLVM
      add              6.40      6.50
      sub              5.30      5.60
      mont            31.40     31.50

    • mont on Xeon Platinum 8280 (turbo boost off)

      381-bit prime    mcl     DSL+LLVM    DSL+LLVM -mbmi2
      mont            28.93     42.86        35.52

    • DSL+LLVM with -O2 is not fast.
      • It is slower than mcl even if the -mbmi2 option (using mulx) is added.
      • mcl (and blst) use adox/adcx, which are available on Broadwell and later.

  30. Computing $\langle z \rangle_N + \langle x \rangle_N \times y$ with adox/adcx

    • We have only one CF, so z and xy must be kept in registers until they are added.
      • A64 has 29 registers, but x64 has only 15.
    • adox/adcx offer two carry flags (OF and CF), so register spills can be avoided.
    • clang/gcc do not use these instructions now.
      • https://gcc.gnu.org/legacy-ml/gcc-help/2017-08/msg00100.html

         [z3:z2:z1:z0]
               [H0 L0]  ← x0 y
            [H1 L1]     ← x1 y
         [H2 L2]        ← x2 y
      [H3 L3]           ← x3 y
      ----------------
      [z4:z3:z2:z1:z0]

      (the L terms are added with adox, the H terms with adcx)

  31. DSL for x64

    • How to make mulAdd: z += [px] * rdx
      • z is an array of registers.
      • px is a pointer to x.
      • L, H are temporary registers.
      • rdx holds y.
    • Build mont using mulAdd.

      # z[n..0] = z[n-1..0] + px[n-1..0] * rdx
      def mulAdd(z, px):
          n = len(z) - 1
          xor_(z[n], z[n])                 # z[n] = 0; also clears CF and OF
          for i in range(n):
              mulx(H, L, ptr(px + i*8))    # [H:L] = px[i] * y (flags unchanged)
              adox(z[i], L)                # z[i] += L + OF
              if i == n - 1:
                  break
              adcx(z[i + 1], H)            # z[i+1] += H + CF
          adox(z[n], H)                    # z[n] += H + OF
          adc(z[n], 0)                     # z[n] += CF

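The two carry chains can be simulated in Python to check that mulAdd computes z + x·y. The sketch keeps OF (used by adox) and CF (used by adcx and adc) as separate variables; mulx is modelled as a plain 128-bit product that leaves both flags untouched. mul_add and from_limbs are illustrative names.

```python
import random

M = 1 << 64

def mul_add(z, x, y):
    """Simulate mulAdd: z[n..0] = z[n-1..0] + x[n-1..0] * y with two carry chains."""
    z = list(z)
    n = len(z) - 1
    z[n], cf, of = 0, 0, 0           # xor z[n], z[n]: also clears CF and OF
    hi = 0
    for i in range(n):
        prod = x[i] * y              # mulx H, L, [px + i*8]
        hi, lo = prod >> 64, prod % M
        s = z[i] + lo + of           # adox z[i], L
        z[i], of = s % M, s >> 64
        if i == n - 1:
            break
        s = z[i + 1] + hi + cf       # adcx z[i+1], H
        z[i + 1], cf = s % M, s >> 64
    s = z[n] + hi + of               # adox z[n], H
    z[n] = s % M
    z[n] = (z[n] + cf) % M           # adc z[n], 0 (cannot overflow)
    return z

def from_limbs(v):
    return sum(d << (64 * i) for i, d in enumerate(v))

for _ in range(1000):
    n = 4
    z = [random.getrandbits(64) for _ in range(n + 1)]   # z[n] is output-only
    x = [random.getrandbits(64) for _ in range(n)]
    y = random.getrandbits(64)
    assert from_limbs(mul_add(z, x, y)) == from_limbs(z[:n]) + from_limbs(x) * y
```

Because each chain carries only into the next limb of its own chain, the L additions and H additions can be interleaved without saving intermediate results to memory, which is the register-pressure benefit described above.
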
  32. Conclusion

    • DSL that generates LLVM IR
    • Advantages
      • Easy to read and write
      • Fast enough on M1
    • Disadvantages
      • Not fully optimized for x64
        • May be solved by compiler optimizations in the future.
        • Requires an x64-specific DSL for now.