
Low-Level Implementation of Finite Field Operations using LLVM/Asm

herumi
March 30, 2023


Open Source Cryptography Workshop 2023/Mar/30
https://rsvp.withgoogle.com/events/open-source-cryptography-workshop


Transcript

  1. Low-Level Implementation of
    Finite Field Operations using LLVM/Asm
    2023/3/30
    OSS Cryptography Workshop
    Cybozu Labs, Inc. Mitsunari Shigeo (光成滋生)


  2. • Researcher and software engineer at Cybozu Labs, Inc.
    • Developer of some OSS
    • mcl/bls (https://github.com/herumi/mcl)
    • Pairing/BLS signature library
    • For DFINITY (2016), for Ethereum 2.0 (2019)
    • https://www.npmjs.com/package/mcl-wasm
    WebAssembly package. 130k downloads/week
    • Xbyak/Xbyak_aarch64
    • https://github.com/herumi/xbyak
    • JIT assembler for x64/AArch64 (A64) written in C++
    • Designed for low-level optimization
    • Heavily used in oneDNN (Intel AI framework) / ZenDNN
    (AMD), Fugaku (Japanese supercomputer)
    My Introduction
    2 /32
    @herumi


  3. • Motivation
    • For security reasons, we want to write in a "safe language" as
    much as possible.
    • But high performance is also required.
    • We want to minimize the use of assembly language.
    • What are some safe ways to write such code?
    • How practical is a generic description using LLVM?
    • DSL (domain specific language) makes it easier to read and
    write.
    • The current performance and limitations of these approaches.
    Today's Topics
    3 /32


  4. • Pairing is a mathematical function used in BLS
    signature, zk-SNARK, homomorphic encryption, etc.
    • Multiplication on finite fields occupies most of the time
    for pairing calculations.
    • 256~384-bit prime of a finite field is often used.
    Pairing
    4 /32


  5. • Compare with blst (https://github.com/supranational/blst)
    • Finite field operations on MacBook Pro, M1, 2020 (times in nsec)
    • DSL is almost at blst speed.
    • https://github.com/herumi/mcl-ff (under construction)
    • See below for details.
    Benchmark of Generated Code by DSL
    381-bit prime   blst    DSL+LLVM
    add             6.40    6.50
    sub             5.30    5.60
    mont (mul)     31.40   31.50
    5 /32


  6. • 𝑝 : a prime, about 256~384 bits
    • 𝔽𝑝 = {0, 1, …, 𝑝 − 1}
    • 𝑥 ± 𝑦 ≔ (𝑥 ± 𝑦) mod 𝑝, 𝑥𝑦 ≔ 𝑥𝑦 mod 𝑝 for 𝑥, 𝑦 ∈ 𝔽𝑝
    • 𝐿 = 64 (or 32) : CPU register size, 𝑀 = 2^𝐿
    • 𝑥 = ⟨𝑥⟩₄ = [𝑥₃:𝑥₂:𝑥₁:𝑥₀] (each 𝑥ᵢ is 𝐿-bit) is a 256-bit integer:
    𝑥 = Σᵢ 𝑥ᵢ 𝑀^𝑖, 𝑥ᵢ ∈ [0, 2^𝐿)
    • uint64_t x[N]; // written in C/C++. N = 4
    Finite Field and Notation
    6 /32
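The limb decomposition above can be checked with a short Python sketch; the 256-bit value here is an arbitrary example, not from the slides.

```python
# Decompose a 256-bit integer x into four 64-bit limbs [x3:x2:x1:x0]
# and verify x = sum(x_i * M^i), as in the notation above.
M = 1 << 64  # M = 2^L with L = 64
x = 0x0123456789abcdef_fedcba9876543210_0011223344556677_8899aabbccddeeff
limbs = [(x >> (64 * i)) % M for i in range(4)]  # [x0, x1, x2, x3]
assert limbs[0] == 0x8899aabbccddeeff  # x0 is the least significant limb
assert sum(xi * M**i for i, xi in enumerate(limbs)) == x
```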


  7. • A sum of 64-bit registers is 65 bits.
    • 65bit = 64bit + CF (1bit)
    • How CF is calculated
    Review of Addition of Two Digits: 36 + 47
      3 6
    + 4 7
    -----
    1 3        ; 6 + 7 = 13 : the leading 1 is the carry
      7        ; 3 + 4 = 7, plus the carry
    -----
    8 3        ; the increase in digits is the carry
    bool hasCF(U x, U y) {
      U z = x + y;
      return z < x; // CF = 1 if z < x
    }
    7 /32
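The same trick can be modeled in Python by emulating 64-bit wraparound (a model of the C `hasCF` above, not part of the slides' code):

```python
M = 1 << 64  # a 64-bit register wraps modulo 2^64

def has_cf(x, y):
    # Mirror of the C hasCF: the wrapped sum is smaller than x iff a carry occurred.
    z = (x + y) % M
    return z < x

assert has_cf(M - 1, 1)   # wraps: CF = 1
assert not has_cf(1, 2)   # no wrap: CF = 0
```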


  8. • RISC-V has no carry flag, so this style suits it,
    but it is slow on x64 (x86-64) / A64 (AArch64).
    addT using C++
    // U = uint64_t or uint32_t
    // ⟨z⟩_N = ⟨x⟩_N + ⟨y⟩_N and return CF
    template<size_t N, class U>
    U addT(U *z, const U *x, const U *y) {
      U c = 0; // There is no CF at first.
      for (size_t i = 0; i < N; i++) {
        U xc = x[i] + c;
        c = xc < c; // CF1
        U yi = y[i];
        xc += yi;
        c += xc < yi; // CF2
        z[i] = xc;
      }
      return c;
    }
    8 /32
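The loop in addT can be cross-checked against Python big integers (a reference model; `add_n` and the helper names are illustrative, not from mcl):

```python
M = 1 << 64

def add_n(x, y):
    # Limb-by-limb addition with carry, mirroring addT.
    z, c = [], 0
    for xi, yi in zip(x, y):
        s = xi + yi + c
        z.append(s % M)
        c = s >> 64
    return z, c

def to_limbs(v, n):
    return [(v >> (64 * i)) % M for i in range(n)]

x, y = (1 << 256) - 1, (1 << 255) + 12345
zl, cf = add_n(to_limbs(x, 4), to_limbs(y, 4))
z = sum(v << (64 * i) for i, v in enumerate(zl))
assert z + (cf << 256) == x + y  # matches big-integer addition, CF included
```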


  9. • Add operations
    • Input : x, y are 64-bit integer registers
    • Output : x ← (x + y) % 𝑀 (= 2^64)
    • CF ← 1 if (x + y) ≥ 𝑀 else 0
    • Input : x, y and CF
    • Output : x ← (x + y + CF) % 𝑀
    • CF ← 1 if (x + y + CF) ≥ 𝑀 else 0
    • adds and adcs are add operations of AArch64 (A64).
    Carry Operation of x64
    add x, y
    adc x, y
    9 /32
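The add/adc semantics above can be modeled in Python (a sketch; real flags live in CPU state, and `add` here plays both roles depending on the carry-in):

```python
M = 1 << 64

def add(x, y, cf=0):
    # Return (result mod M, carry flag), like x64 add (cf=0) / adc (cf=CF).
    s = x + y + cf
    return s % M, int(s >= M)

lo, cf = add(M - 1, 1)    # add: the sum wraps, so CF = 1
hi, cf = add(0, 0, cf)    # adc: propagates the carry into the next limb
assert (lo, hi) == (0, 1)
```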


  10. • x64 Linux or Intel macOS
    Asm Code of add3 for x64 Linux
    add3: ; macOS requires an underscore prefix
    mov (%rsi), %rax ; rax ← x[0]
    add (%rdx), %rax ; rax ← rax + y[0]
    mov %rax, (%rdi) ; z[0] ← rax
    mov 8(%rsi), %rax ; rax ← x[1]
    adc 8(%rdx), %rax ; rax ← rax + y[1] + CF
    mov %rax, 8(%rdi) ; z[1] ← rax
    mov 16(%rsi), %rax ; rax ← x[2]
    adc 16(%rdx), %rax ; rax ← rax + y[2] + CF
    mov %rax, 16(%rdi) ; z[2] ← rax
    setc %al ; al ← CF
    movzx %al, %eax ; rax ← zero extension of al
    ret
    calling convention
    U addN(U *z, const U *x, const U *y);
    rax rdi rsi rdx
    10 /32


  11. • 64-bit Windows uses different registers
    Asm Code of add3 for Win64
    add3:
    mov rax, [rdx] ; rax ← x[0]
    add rax, [r8] ; rax ← rax + y[0]
    mov [rcx], rax ; z[0] ← rax
    mov rax, [rdx+8] ; rax ← x[1]
    adc rax, [r8+8] ; rax ← rax + y[1] + CF
    mov [rcx+8], rax ; z[1] ← rax
    mov rax, [rdx+16] ; rax ← x[2]
    adc rax, [r8+16] ; rax ← rax + y[2] + CF
    mov [rcx+16], rax ; z[2] ← rax
    setc al ; al ← CF
    movzx eax, al ; rax ← zero extension of al
    ret
    calling convention
    U addN(U *z, const U *x, const U *y);
    rax rcx rdx r8
    11 /32


  12. • Absorbing the differences between Linux and Windows
    • static code generator, a thin wrapper of assembler
    • gen_add(N) generates add{N}(z, x, y) for Win(masm)/Linux(gas)
    DSL for x64
    def gen_add(N):
      with FuncProc(f'add{N}'): # declare the function
        with StackFrame(3) as sf: # make a stack frame
          z = sf.p[0] # 1st argument reg
          x = sf.p[1] # 2nd argument reg
          y = sf.p[2] # 3rd argument reg
          for i in range(N):
            mov(rax, ptr(x + 8 * i)) # rax ← x[i]
            if i == 0:
              add(rax, ptr(y + 8 * i)) # rax ← add(rax, y[i])
            else:
              adc(rax, ptr(y + 8 * i)) # rax ← adc(rax, y[i])
            mov(ptr(z + 8 * i), rax) # z[i] ← rax
          setc(al)
          movzx(eax, al)
    12 /32


  13. • LLVM IR : A low-level, platform-independent language
    • Features
    • SSA (Static Single Assignment form)
    • Integer type with an arbitrary fixed bit width i{N}
    • Supports add/sub of arbitrary fixed bit width without CF
    • i{N} ← add(i{N}, i{N})
    • i{N} ← sub(i{N}, i{N})
    • Supports only the minimal multiplication
    • 128bit ← 64bit × 64bit (64-bit CPU)
    • 64bit ← 32bit × 32bit (32-bit CPU)
    • reference : https://llvm.org/docs/LangRef.html
    LLVM
    source code (LLVM IR : ll) → clang -S -O2 → x64 / A64 / RISC-V / MIPS / WASM etc.
    13 /32


  14. • bool add4(U *z, const U *x, const U *y);
    • It is annoying to always write the register size.
    add4 Using LLVM
    define i1 @add4(i256* %pz, i256* %px, i256* %py)
    {
    %x = load i256, i256* %px // x ← *(const uint256_t*)px;
    %y = load i256, i256* %py // y ← *(const uint256_t*)py;
    %x2 = zext i256 %x to i257 // x2 ← uint257_t(x);
    %y2 = zext i256 %y to i257 // y2 ← uint257_t(y);
    %z = add i257 %x2, %y2 // z ← x2 + y2
    %z2 = trunc i257 %z to i256 // z2 ← uint256_t(z);
    store i256 %z2, i256* %pz // *(uint256_t*)pz = z2;
    %z3 = lshr i257 %z, 256 // z3 ← z >> 256
    %z4 = trunc i257 %z3 to i1 // z4 ← uint1_t(z3);
    ret i1 %z4 // return z4;
    }
    14 /32


  15. • clang -S -O2 -target x86_64 add4.ll
    • The output is almost the same as the previous code.
    • clang -S -O2 -target aarch64 add4.ll outputs
    Generating Target Assembler Code
    add4:
    ldp x9, x8, [x1] # [x8:x9] ← ((u128*)px)[0];
    ldp x11, x10, [x2] # [x10:x11] ← ((u128*)py)[0];
    ldp x13, x12, [x1, #16] # [x12:x13] ← ((u128*)px)[1];
    ldp x14, x15, [x2, #16] # [x15:x14] ← ((u128*)py)[1];
    adds x9, x11, x9 # x9 ← add(x11, x9)
    adcs x8, x10, x8 # x8 ← add(x10, x8, CF);
    adcs x10, x14, x13 # x10 ← add(x14, x13, CF);
    stp x9, x8, [x0] # ((u128*)pz)[0] = [x8:x9]
    adcs x9, x15, x12 # x9 ← add(x15, x12, CF);
    adcs x8, xzr, xzr # x8 ← add(0, 0, CF);
    stp x10, x9, [x0, #16] # ((u128*)pz)[1] = [x9:x10]
    mov w0, w8 # w0 ← x8
    ret
    15 /32


  16. • ll (LLVM IR) seems good, but it is hard to write by hand.
    • I developed a simple DSL to make it easier to write.
    • Variables can be reassigned (non-SSA).
    • No need to write register size
    • ll : %x1 = add i256 %x, %y
    DSL : x = add(x, y);
    DSL for LLVM using Python (and C++)
    DSL → generator → code (ll) → clang -S -O2 → x64 / A64 / RISC-V / MIPS / WASM
    16 /32


  17. • gen_add(N) generates the previous code for each N
    Generating add.ll with DSL
    unit = 64
    def gen_add(N):
      bit = unit * N
      pz = IntPtr(unit) # define pointer to int
      px = IntPtr(unit)
      py = IntPtr(unit)
      name = f'mcl_fp_addPre{N}'
      with Function(name, Int(unit), pz, px, py):
        x = loadN(px, N) # x = ⟨*px⟩_N
        y = loadN(py, N) # y = ⟨*py⟩_N
        x = zext(x, bit + unit)
        y = zext(y, bit + unit)
        z = add(x, y)
        storeN(trunc(z, bit), pz)
        r = trunc(lshr(z, bit), unit)
        ret(r)
    17 /32


  18. • Python code of add of a finite field
    • Introduce select function
    Fp::add Using Python
    # assume 0 <= x, y < p
    def fp_add(x, y):
      z = x + y
      if z >= p:
        z -= p
      return z

    def select(cond, x, y):
      if cond:
        return x
      else:
        return y

    def fp_add(x, y):
      z = x + y
      w = z - p
      return select(w < 0, z, w)
    18 /32
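The branchless-style fp_add runs as-is once p is fixed; a runnable sketch using the BLS12-381 field modulus as an example prime (the slides leave p abstract):

```python
# 381-bit prime of BLS12-381 (an assumed example; any odd prime works here).
p = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab

def select(cond, x, y):
    return x if cond else y

def fp_add(x, y):
    # assume 0 <= x, y < p
    z = x + y
    w = z - p
    return select(w < 0, z, w)  # keep z if z < p, else the reduced value

assert fp_add(p - 1, 1) == 0
assert fp_add(p - 1, p - 1) == p - 2
```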


  19. • gen_fp_add(N) generates Fp::add for each N
    Fp::add Using DSL
    def gen_fp_add(name, N, pp):
      bit = unit * N
      pz = IntPtr(unit)
      px = IntPtr(unit)
      py = IntPtr(unit)
      with Function(name, Void, pz, px, py):
        x = zext(loadN(px, N), bit + unit)
        y = zext(loadN(py, N), bit + unit)
        x = add(x, y) # x ← x + y
        p = zext(load(pp), bit + unit)
        y = sub(x, p) # y ← x + y - p
        c = trunc(lshr(y, bit), 1) # c = MSB of y
        x = select(c, x, y) # x = (y < 0) ? x : y
        x = trunc(x, bit)
        storeN(x, pz)
    Differences with gen_add()
    19 /32


  20. • Full-Bit Prime
    • The bit size of p is a multiple of 64 (e.g. 256, 384).
    • Non-Full-Bit Prime
    • e.g. a 381-bit p
    • If p is non-full-bit, x + y cannot overflow the N limbs,
    so the value does not have to be zext-ed.
    Optimizing Non-Full-Bit Primes
    x = loadN(px, N)
    y = loadN(py, N)
    x = add(x, y)
    p = load(pp)
    y = sub(x, p)
    c = trunc(lshr(y, bit - 1), 1) # c = MSB(y) = (y < 0)
    x = select(c, x, y)
    storeN(x, pz)
    ret(Void)
    20 /32


  21. • x64 code has no conditional branch (constant time).
    • This code is nearly perfect.
    • -target=aarch64 also generates good code.
    Fp::add Generated Code by LLVM
    fp_add3_not_full:
    mov r8, [rdx]
    add r8, [rsi]
    mov r9, [rdx + 8]
    adc r9, [rsi + 8]
    mov r10, [rdx + 16]
    adc r10, [rsi + 16]
    mov rsi, r8
    sub rsi, [rcx]
    mov rax, r9
    sbb rax, [rcx + 8]
    mov rdx, r10
    sbb rdx, [rcx + 16]
    mov rcx, rdx
    sar rcx, 63
    cmovs rdx, r10
    cmovs rax, r9
    cmovs rsi, r8
    mov [rdi], rsi
    mov [rdi + 8], rax
    mov [rdi + 16], rdx
    ret
    [r10:r9:r8] ← [px]
    [r10:r9:r8] += [py]
    [rdx:rax:rsi] ← [r10:r9:r8] ; copy
    [rdx:rax:rsi] -= [pp]
    check the top bit of rdx
    [rdx:rax:rsi] ← [r10:r9:r8] if negative (SF = 1)
    [pz] ← [rdx:rax:rsi]
    21 /32


  22. • u64 mulUnit(u64 *z, const u64 *x, u64 y);
    • computes ⟨x⟩_N × y = ⟨z⟩_{N+1}.
    • LLVM IR has only u128 mul(u64, u64);
    • pack is a combination of shl + or
    Multiplication by u64
    [x3:x2:x1:x0]
    X y
    ----------------
    [H0:L0]
    [H1:L1]
    [H2:L2]
    [H3:L3]
    -----------------
    [z4:z3:z2:z1:z0]
    L=[ 0:L3:L2:L1:L0]

    H=[H3:H2:H1:H0: 0]

    z=[z4:z3:z2:z1:z0]
    with Function(f'mulUnit{N}', z, px, y):
      Ls = []
      Hs = []
      y = zext(y, unit*2)
      for i in range(N):
        x = load(getelementptr(px, i)) # x[i]
        xy = mul(zext(x, unit*2), y)
        Ls.append(trunc(xy, unit))
        Hs.append(trunc(lshr(xy, unit), unit))
      bu = bit + unit
      L = zext(pack(Ls), bu)
      H = shl(zext(pack(Hs), bu), unit)
      z = add(L, H)
    22 /32
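The column diagram translates directly into a Python reference model (illustrative; `mul_unit` is a model, not the DSL function itself):

```python
M = 1 << 64

def mul_unit(x_limbs, y):
    # <z>_{N+1} = <x>_N * y: accumulate low halves into L and
    # high halves (shifted by one limb) into H, then z = L + H.
    L = H = 0
    for i, xi in enumerate(x_limbs):
        xy = xi * y
        L |= (xy % M) << (64 * i)          # Li
        H |= (xy >> 64) << (64 * (i + 1))  # Hi
    return L + H

x = [0xffffffffffffffff] * 3               # <x>_3 = 2^192 - 1
y = 0xffffffffffffffff
assert mul_unit(x, y) == ((1 << 192) - 1) * y
```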


  23. • mulUnit3
    • The DSL code generates many shift and or operations,
    so it might look slow, but…
    Generated ll by DSL
    define i256 @mulUnit3(i64* noalias %r2, i64 %r3) {
    ...(snip)
    %r25 = zext i64 %r11 to i128
    %r26 = shl i128 %r25, 64
    %r27 = or i128 %r24, %r26
    %r28 = zext i128 %r27 to i192
    %r29 = zext i64 %r15 to i192
    %r30 = shl i192 %r29, 128
    %r31 = or i192 %r28, %r30
    %r32 = zext i192 %r23 to i256
    %r33 = zext i192 %r31 to i256
    %r34 = shl i256 %r33, 64
    %r35 = add i256 %r32, %r34
    ret i256 %r35
    }
    23 /32


  24. • It uses only mul (×3) + add (×1) + adc (×2).
    • This is near-optimal code; with -mbmi2, mulx is used.
    • mulx
    • supported by Haswell or later.
    • crashes on older systems.
    • does not change CF, so we can use it with
    add and adc.
    • reduces the use of temporary regs.
    Generated Code for x64
    mclb_mulUnit3:
    mov rcx, rdx
    mov rax, rdx
    mul qword [rsi]
    mov r8, rdx
    mov r9, rax
    mov rax, rcx
    mul qword [rsi + 8]
    mov r10, rdx
    mov r11, rax
    mov rax, rcx
    mul qword [rsi + 16]
    add r11, r8
    adc rax, r10
    adc rdx, 0
    mov [rdi], r9
    mov [rdi + 8], r11
    mov [rdi + 16], rax
    mov rax, rdx
    ret
    mclb_mulUnit3:
    mulx r10, r8, [rsi]
    mulx r9, rcx, [rsi + 8]
    mulx rax, rdx, [rsi + 16]
    add rcx, r10
    adc rdx, r9
    adc rax, 0
    mov [rdi], r8
    mov [rdi + 8], rcx
    mov [rdi + 16], rdx
    ret
    -O2 -mbmi2
    24 /32


  25. • A64 code also seems good.
    Generated Code for A64
    mulUnit3:
    ldp x9, x10, [x1]
    mov x8, x0 # pz = x8
    ldr x11, [x1, #16] # x = [x11:x10:x9]
    umulh x12, x9, x2 # H0 = x[0] * y
    umulh x13, x10, x2 # H1 = x[1] * y
    mul x10, x10, x2 # L1 = x[1] * y
    mul x14, x11, x2 # L2 = x[2] * y
    umulh x11, x11, x2 # H2 = x[2] * y
    adds x10, x12, x10 # z1 = H0 + L1
    adcs x12, x13, x14 # z2 = H1 + L2 + CF
    mul x9, x9, x2 # L0 = x[0] * y
    cinc x0, x11, hs # ret = H2 + CF
    str x12, [x8, #16] # pz[2] = z2
    stp x9, x10, [x8] # pz[0] = L0, pz[1] = z1
    ret
    A64 mul operations
    M = 1 << 64
    mul(x: u64, y: u64):
      return x * y % M
    umulh(x: u64, y: u64):
      return (x * y) >> 64
    25 /32


  26. • A method for fast modular multiplication
    • toM(x) : converts x to Montgomery form
    • fromM(x) : converts x back to normal form
    • x × y ≡ fromM(mont(toM(x), toM(y))) mod p
    • Preparation
    • L = 64 (or 32) and M = 2^L
    • p = ⟨p⟩_N : a prime
    • ip ≡ −p^(−1) mod M (the sign makes t + q·p divisible by M)
    Montgomery Multiplication
    26 /32
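The precomputed constant ip can be obtained with Python's built-in modular inverse (Python 3.8+). Note the sign: with q = ((t % M) * ip) % M and t += q * p, ip must be −p^(−1) mod M so that the low word of t + q·p vanishes. The 381-bit prime is BLS12-381's field modulus, used as an example since the slides leave p abstract:

```python
M = 1 << 64
# Example prime (BLS12-381 field modulus); any odd prime works.
p = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab
ip = (-pow(p, -1, M)) % M     # ip = -p^(-1) mod M
assert (p * ip + 1) % M == 0  # so t + q*p becomes divisible by M in mont()
```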


  27. • An example code of mont(x, y)
    Mont using Python
    L = 64
    M = 2**L
    # ip = -p^(-1) mod M
    def mont(x, y):
      t = 0
      for i in range(N):
        t += x * ((y >> (L * i)) % M) # mulUnit
        q = ((t % M) * ip) % M # u64 x u64 → u64
        t += q * p # mulUnit
        t >>= L
      if t >= p:
        t -= p
      return t
    27 /32
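The mont loop runs as-is once N, p, and ip are fixed; a runnable sketch with the BLS12-381 prime as an assumed example (N = 6 limbs, so the Montgomery radix is R = 2^384):

```python
L = 64
M = 1 << L
N = 6
# Example 381-bit prime (BLS12-381 field modulus; an assumption, not fixed by the slides).
p = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab
ip = (-pow(p, -1, M)) % M  # -p^(-1) mod M, so t + q*p is divisible by M

def mont(x, y):
    # Word-interleaved Montgomery multiplication: returns x*y*R^(-1) mod p.
    t = 0
    for i in range(N):
        t += x * ((y >> (L * i)) % M)  # add x * (i-th word of y)
        q = ((t % M) * ip) % M
        t += q * p                     # clears the low word of t
        t >>= L
    if t >= p:
        t -= p
    return t

R = pow(M, N, p)                       # R = 2^(L*N) mod p
def toM(x): return (x * R) % p
def fromM(x): return mont(x, 1)

x, y = 12345, 67890
assert fromM(mont(toM(x), toM(y))) == (x * y) % p
```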


  28. • Assume Non-Full-Bit prime
    Mont using DSL
    def gen_mont(name, mont, mulUnit):
      pz = IntPtr(unit)
      px = IntPtr(unit)
      py = IntPtr(unit)
      with Function(name, Void, pz, px, py):
        t = Imm(0, unit*N + unit)
        for i in range(N):
          y = load(getelementptr(py, i))
          t = add(t, call(mulUnit, px, y)) # ⟨z⟩_{N+1} = mulUnit(⟨x⟩_N, y)
          q = mul(trunc(t, unit), ip)
          t = add(t, call(mulUnit, pp, q))
          t = lshr(t, unit)
        t = trunc(t, unit*N)
        vc = sub(t, loadN(pp, N))
        c = trunc(lshr(vc, unit*N-1), 1)
        storeN(select(c, t, vc), pz)
        ret(Void)
    # corresponding Python code
    t += x * ((y >> (L*i)) % M)
    q = ((t % M) * ip) % M
    t += q * p
    t >>= L
    if t >= p:
      t -= p
    28 /32


  29. • Compare with blst (https://github.com/supranational/blst)
    • on MacBook Pro, M1, 2020 (times in nsec)
    • mont on Xeon Platinum 8280 (turbo boost off)
    • DSL+LLVM with -O2 is not fast.
    • slower than mcl even if the -mbmi2 (using mulx) option is added.
    • mcl (and blst) uses adox/adcx, available since Broadwell.
    Benchmark of Generated Code by DSL
    381-bit prime   blst    DSL+LLVM
    add             6.40    6.50
    sub             5.30    5.60
    mont           31.40   31.50

    381-bit prime   mcl     DSL+LLVM   DSL+LLVM -mbmi2
    mont           28.93    42.86      35.52
    29 /32


  30. • Since there is only one CF, z and the partial products x_i*y
    must be kept in registers until they are added.
    • A64 has 29 general-purpose regs, but x64 has only 15.
    • adox/adcx offer two independent carry chains, so register
    spills can be avoided.
    • clang/gcc currently do not emit these instructions.
    • https://gcc.gnu.org/legacy-ml/gcc-help/2017-08/msg00100.html
    Computing ⟨z⟩_N + ⟨x⟩_N × y with adox/adcx
    [ z ]
    [H0:L0] ← x0 * y
    [H1:L1] ← x1 * y
    [H2:L2] ← x2 * y
    [H3:L3] ← x3 * y
    ------------------
    [z4:z3:z2:z1:z0]
    (adox accumulates the Li chain, adcx the Hi chain)
    30 /32


  31. • How to make mulAdd : z += [px] * rdx
    • z is an array of registers.
    • px is a pointer to x.
    • L, H are temporary registers.
    • rdx is y.
    • make Mont
    using mulAdd
    DSL for x64
    # z[n..0] = z[n-1..0] + px[n-1..0] * rdx
    def mulAdd(z, px):
      n = len(z) - 1
      xor_(z[n], z[n])
      for i in range(n):
        mulx(H, L, ptr(px + i*8)) # [H:L] = px[i] * y
        adox(z[i], L) # z[i] += L with CF
        if i == n-1:
          break
        adcx(z[i+1], H) # z[i+1] += H with CF'
      adox(z[n], H) # z[n] += H with CF
      adc(z[n], 0) # z[n] += CF'
    31 /32
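The intended effect of mulAdd can be stated as a big-integer reference model (a check of the semantics only, not of the register allocation; `mul_add_ref` is an illustrative name):

```python
M = 1 << 64

def mul_add_ref(z_limbs, x_limbs, y):
    # z[n..0] = z[n-1..0] + <x>_n * y, returned as n+1 limbs.
    n = len(x_limbs)
    acc = sum(v << (64 * i) for i, v in enumerate(z_limbs[:n]))
    acc += sum(v << (64 * i) for i, v in enumerate(x_limbs)) * y
    return [(acc >> (64 * i)) % M for i in range(n + 1)]

z = [M - 1, M - 1, M - 1, 0]
x = [M - 1, M - 1, M - 1]
y = M - 1
out = mul_add_ref(z, x, y)
val = sum(v << (64 * i) for i, v in enumerate(out))
# (2^192 - 1) + (2^192 - 1)*(2^64 - 1) = (2^192 - 1)*2^64 still fits in 4 limbs
assert val == ((1 << 192) - 1) + ((1 << 192) - 1) * y
```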


  32. • DSL that generates LLVM IR
    • Advantages
    • Easy to read and write
    • Fast enough on M1
    • Disadvantages
    • Not fully optimized for x64
    • May be solved by compiler optimizations in the future.
    • An x64-specific DSL is still required for now.
    Conclusion
    32 /32
