Unrolled x86_64 Assembly Squaring for Ch. 12 FFA

June 23rd, 2019

FFA Chapter 12 introduces a Comba squaring function as an optimization for modular exponentiation code. In this post, we will rewrite it in x86_64 assembly.

But first, I'd like to make a small update to the Comba multiplication code. In the previous post, I used an iteration scheme that was too convoluted: gen_col_inner generated multiplication code even when zero iterations were requested, and gen_loop_low and gen_loop_high_inner compensated for this by reducing the iteration count by one. I have NFI how I managed to write that code like that, but in this vpatch I bring it back to sanity [1]:

.macro accumulate
add r8,  rax             # A0, C := A0 + rax
adc r9,  rdx             # A1, C := A1 + rdx + C
adc r10, 0               # A2, [C=0] := A2 + 0 + C
.endm

.macro gen_col_inner I NIter
.if \NIter > \I
gen_col_inner "(\I + 1)" \NIter
lea rdx, [rsi + 8*rbx]   # rdx := Y'Address + N*8
lea rax, [8*r11]         # rax := 8*j
sub rdx, rax             # rdx := rdx - j*8
mov rdx, [rdx]           # rdx := *(rdx)
mov rax, [rdi + 8*r11]   # rax := X(j) := *(X'Address + j*8)
mul rdx                  # rdx:rax := rax*rdx
accumulate               # (A0, A1, A2) := (A0, A1, A2) + (rax, rdx, 0)
inc r11                  # J := J + 1
.endif
.endm

.macro gen_loop_low L
.if \L
gen_loop_low "(\L-1)"
xor r11, r11            # U := 0
gen_col \L
.endif
.endm

.macro gen_loop_high_inner I L
.if \L-\I
inc r12                 # I := I + 1
mov r11, r12            # U := I (U in col)
gen_col "(\L-\I)"
gen_loop_high_inner "(\I+1)" \L
.endif
.endm

One can also spot a new macro, accumulate, which will also be used in the inner loops of Comba squaring. It generates the code that adds the result of a multiplication (rdx:rax) into the three-word accumulator (A0, A1, A2 in r8, r9, r10).

I decided to keep the register allocation of Comba squaring as close as possible to that of Comba multiplication. However, the squaring function, called x86_64_comba_sqr_unrolled, receives one of its arguments in a different register (Input FZ in rdi, Output FZ in rsi, FZ length in rdx), so I move the Output FZ address from rsi into rcx to reuse the multiplication code. Once again, there is a check of the FZ size to catch potential mismatches between the Comba-Karatsuba threshold in the Ada code and the FZ size used to generate the unrolled routine.

## COMBA SQUARING
# Arguments
# RDI: X
# RSI: XY
# RDX: L
.global x86_64_comba_sqr_unrolled
x86_64_comba_sqr_unrolled:
push rbx
push r12

cmp rdx, Sqr_Karatsuba_Thresh
jne size_fail_sqr

mov rcx, rsi   # RCX := XY
xor r12, r12   # TMP := 0
xor r8,  r8    # A0  := 0
xor r9,  r9    # A1  := 0
xor r10, r10   # A2  := 0
xor rbx, rbx   # N   := 0

gen_sqr_loop_low Sqr_Karatsuba_Thresh
gen_sqr_loop_high Sqr_Karatsuba_Thresh
col_finish

pop r12
pop rbx
ret

size_fail_sqr:
ud2

The actual multiplication loops are generated by gen_sqr_loop_low and gen_sqr_loop_high, which follow asciilifeform's code in Chapter 12. The auxiliary registers (r11, r12) are used the same way as in the previous post:

.macro gen_sqr_loop_low L
.if \L
gen_sqr_loop_low \L-1
xor r11, r11             # U := 0
gen_sqr_col \L ( (\L)/2 )
.endif
.endm

.macro gen_sqr_loop_high_inner Col I L
.if \L > \I
inc r12                 # I := I + 1
mov r11, r12            # U := I (U in col)
gen_sqr_col \Col ( \L-(\I) )/2
gen_sqr_loop_high_inner \Col+1 \I+1 \L
.endif
.endm

.macro gen_sqr_loop_high L
gen_sqr_loop_high_inner \L+1 1 \L
.endm

In the previous vpatch, the number of the column being processed was maintained only at runtime, in register rbx. To generate the squaring code, however, we also need to keep track of it at code generation time, and this is the purpose of the Col variable.

The crazy bracket nesting inside gen_sqr_loop_* is necessary to work around the primitive and brittle macro system of the GNU assembler, which does sed-like textual expansion of arguments and thus ignores operator precedence.

.macro gen_sqr_col Col L
.if (\Col) & 1 == 0
gen_sqr_mul \L
.else
gen_sqr_mul \L
gen_sqr_square
.endif
col_finish
.endm

Macro gen_sqr_col takes care of generating the multiplication code for each column, choosing the correct code for odd columns (multiplications only) and even columns (multiplications and a squaring), counting columns from zero; Col itself is one-based, which is why the macro skips the squaring when Col is even. The Col variable is used to resolve the branch at code generation time; the branch does not exist at runtime, so it cannot possibly leak information via a timing or cache side channel.

.macro gen_sqr_mul_inner I NIter
.if \NIter > \I
gen_sqr_mul_inner "(\I+1)" \NIter
lea rdx, [rdi + 8*rbx]   # rdx := X'Address + N*8
lea rax, [8*r11]         # rax := 8*j
sub rdx, rax             # rdx := rdx - j*8
mov rdx, [rdx]           # rdx := *(rdx)
mov rax, [rdi + 8*r11]   # rax := X(j) := *(X'Address + j*8)
mul rdx                  # rdx:rax := rax*rdx
accumulate
accumulate
inc r11                  # J := J + 1
.endif
.endm

.macro gen_sqr_square
mov rdx, [rdi + 8*r11]   # rdx := X(j) := *(X'Address + j*8)
mov rax, rdx             # rax := rdx
mul rdx                  # rdx:rax := rax*rdx
accumulate
.endm

.macro gen_sqr_mul NIter
gen_sqr_mul_inner 0 \NIter
.endm

The innermost multiplication macro for the squaring routine is a slightly edited version of the one used for Comba multiplication. The differences are the double accumulation (each cross product X(j)*X(N-j) with j < N-j occurs twice in column N) and that all loads are done via register rdi (because squaring works with only one FZ).
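
The reason for the double accumulation follows from splitting each column sum by symmetry. In the notation below (mine, not FFA's), B = 2^64 is the word base, X_j are the words of X, and [N even] is 1 for even N and 0 otherwise:

```latex
X = \sum_{j=0}^{L-1} X_j B^j, \qquad B = 2^{64}

X^2 = \sum_{N=0}^{2L-2} B^N \sum_{\substack{j+k=N \\ 0 \le j,k \le L-1}} X_j X_k
    = \sum_{N=0}^{2L-2} B^N \Bigl( 2 \sum_{\substack{j+k=N \\ j<k}} X_j X_k
      \;+\; [N \text{ even}]\, X_{N/2}^2 \Bigr)
```

Since 2·X_j·X_k can need 129 bits, adding the 128-bit product twice through accumulate lets the extra carry land in A2 without a separate doubling step.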

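The gen_sqr_square macro computes the actual squaring term X(j)*X(j), using the value of j (r11) left by the preceding cross-product loop, which by then points at the middle of the column.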

The integration of the assembly code into Ada is done the same way as before, using the 'fastpath' approach:

   -- Squaring. (CAUTION: UNBUFFERED)
   procedure FZ_Square_Unbuffered(X     : in  FZ;
                                  XX    : out FZ) is
   begin

      if X'Length = Sqr_Karatsuba_Thresh then

         -- Optimized case:
         FZ_Sqr_Comba_Fast(X, XX);

      elsif X'Length < Sqr_Karatsuba_Thresh then

         -- Base case:
         FZ_Sqr_Comba(X, XX);

      else

         -- Recursive case:
         Sqr_Karatsuba(X, XX);

      end if;

   end FZ_Square_Unbuffered;

-------------------------------------------------------

   -- Multiplier. (CAUTION: UNBUFFERED)
   procedure FZ_Multiply_Unbuffered(X     : in  FZ;
                                    Y     : in  FZ;
                                    XY    : out FZ) is

      -- The length of either multiplicand
      L : constant Word_Count := X'Length;

   begin

      if L = Karatsuba_Thresh then

         -- Optimized case:
         FZ_Mul_Comba_Fast(X, Y, XY);

      elsif L < Karatsuba_Thresh then

         -- Base case:
         FZ_Mul_Comba(X, Y, XY);

      else

         -- Recursive case:
         Mul_Karatsuba(X, Y, XY);

      end if;

   end FZ_Multiply_Unbuffered;

I have measured the performance of the resulting code using a squaring microbenchmark; squaring comes out around 20% faster than multiplication, both for the original Ada code and for the assembler code. The runtimes:

Runtime (s) by bitness:

Variant        2048     4096     8192
Orig. Ch.12    0.476    1.458    4.498
Asm-8          0.167    0.558    1.783
Asm-16         0.120    0.415    1.345
Asm-32         0.101    0.362    1.195

My seal for FFA Chapter 12B is available here:

curl 'http://bvt-trace.net/vpatches/ffa_ch12_karatsuba_redux.kv.vpatch.bvt.sig' > ffa_ch12_karatsuba_redux.kv.vpatch.bvt.sig

And the vpatch and seal for the code in this post:

curl 'http://bvt-trace.net/vpatches/ch12_unrolled_asm_comba_sqr.vpatch' > ch12_unrolled_asm_comba_sqr.vpatch
curl 'http://bvt-trace.net/vpatches/ch12_unrolled_asm_comba_sqr.vpatch.bvt.sig' > ch12_unrolled_asm_comba_sqr.vpatch.bvt.sig

Please write if you have any questions.

  1. gen_col_inner now does nothing when zero iterations are requested, so the -1 offsets are removed from the gen_col calls.