
Unrolled x86_64 Assembly Multiplication Updated to FFA Ch. 14

Thursday, July 18th, 2019

In this post, I'll update the x86_64 assembly multiplication routines to FFA Chapter 14.

But first, seals for FFA Chapters 13 and 14:

curl 'http://bvt-trace.net/vpatches/ffa_ch13_measure_and_qshifts.kv.vpatch.bvt.sig' > ffa_ch13_measure_and_qshifts.kv.vpatch.bvt.sig
curl 'http://bvt-trace.net/vpatches/ffa_ch14_barrett.kv.vpatch.bvt.sig' > ffa_ch14_barrett.kv.vpatch.bvt.sig

Some comments on Chapter 13: in FZ_Quiet_Shift*_SubW_Soft, I do not see how the carry is slid in at fz_qshft.adb:65 and fz_qshft.adb:105; I would say that comment belongs at fz_qshft.adb:46 and fz_qshft.adb:86. In the end, to understand the code better, I rewrote it in a slightly different form, leaving the algorithm itself untouched:

   -- Constant-time subword shift, for where there is no barrel shifter
   procedure FZ_Quiet_ShiftRight_SubW_Soft(N        : in FZ;
                                           ShiftedN : in out FZ;
                                           Count    : in WBit_Index) is
      Nw  : constant Word  := Word(Count);
      nC  : constant WBool := W_ZeroP(Nw); -- 'no carry' for Count == 0 case
      Ni  : Word := 0; -- Current word
      C   : Word := 0; -- Current carry
      S   : Positive;  -- Current shiftness level
      B   : Word;      -- Quantity of shift (bitwalked over)
      CB  : Word;      -- Quantity of carry counter-shift (bitwalked over)
   begin
      for i in reverse N'Range loop
         -- Read N(i) first, as N and ShiftedN can be the same array
         Ni          := N(i);

         -- Write down carry from previous iteration
         ShiftedN(i) := C;

         -- Carry starts as the whole word (0 when Count = 0, as nothing slides out)
         C           := W_Mux(Ni, 0, nC);
         S           := 1;
         B           := Word(Count);
         CB          := Word(Bitness) - B;
         -- For each shift level (of the subword shiftvalue width) :
         for j in 1 .. BitnessLog2 loop
            -- Shift and mux the current word
            Ni := Shift_Right_Gated(Ni, S, B and 1);
            B  := Shift_Right(B,  1);
            -- Shift and mux the current carry
            C  := Shift_Left_Gated(C, S, CB and 1);
            CB := Shift_Right(CB, 1);
            -- Go to the next shiftness level
            S  := S * 2;
         end loop;
         -- Slide in the leftovers of the current Word N(i)
         ShiftedN(i) := ShiftedN(i) or Ni;
      end loop;
   end FZ_Quiet_ShiftRight_SubW_Soft;

   -- Constant-time subword shift, for where there is no barrel shifter
   procedure FZ_Quiet_ShiftLeft_SubW_Soft(N        : in FZ;
                                          ShiftedN : in out FZ;
                                          Count    : in WBit_Index) is
      Nw  : constant Word  := Word(Count);
      nC  : constant WBool := W_ZeroP(Nw); -- 'no carry' for Count == 0 case
      Ni  : Word := 0; -- Current word
      C   : Word := 0; -- Current carry
      S   : Positive;  -- Current shiftness level
      B   : Word;      -- Quantity of shift (bitwalked over)
      CB  : Word;      -- Quantity of carry counter-shift (bitwalked over)
   begin
      for i in N'Range loop
         -- Need to set Ni here as N and ShiftedN can be the same array
         Ni          := N(i);

         -- Write down carry from previous iteration
         ShiftedN(i) := C;

         -- Carry starts as the whole word (0 when Count = 0, as nothing slides out)
         C           := W_Mux(Ni, 0, nC);
         S           := 1;
         B           := Word(Count);
         CB          := Word(Bitness) - B;
         -- For each shift level (of the subword shiftvalue width) :
         for j in 1 .. BitnessLog2 loop
            -- Shift and mux the current word
            Ni := Shift_Left_Gated(Ni, S, B and 1);
            B  := Shift_Right(B,  1);
            -- Shift and mux the current carry
            C  := Shift_Right_Gated(C, S, CB and 1);
            CB := Shift_Right(CB, 1);
            -- Go to the next shiftness level
            S  := S * 2;
         end loop;
         -- Slide in the leftovers of the current Word N(i)
         ShiftedN(i) := ShiftedN(i) or Ni;
      end loop;
   end FZ_Quiet_ShiftLeft_SubW_Soft;

where Shift_Left_Gated and Shift_Right_Gated perform a constant-time shift that takes effect only when the third argument, Gate, is 1:

package body W_Shifts is
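   -- Returns Shift_Left(W, Amount) if Gate = 1, W otherwise. The shift is
   -- always computed and W_Mux selects the result, so there is no branch.
   function Shift_Left_Gated (W: Word;
                              Amount: Positive;
                              Gate: WBool)
                             return Word is
      Temp : constant Word := Shift_Left(W, Amount);
   begin
      return W_Mux(W, Temp, Gate);
   end Shift_Left_Gated;

   -- Returns Shift_Right(W, Amount) if Gate = 1, W otherwise.
   function Shift_Right_Gated (W: Word;
                               Amount: Positive;
                               Gate: WBool)
                              return Word is
      Temp : constant Word := Shift_Right(W, Amount);
   begin
      return W_Mux(W, Temp, Gate);
   end Shift_Right_Gated;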
end W_Shifts;
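For readers more at home in C, the gated shift can be sketched as follows. This is an illustrative translation, not FFA code. It assumes 64-bit words and that W_Mux(A, B, Sel) returns A when Sel = 0 and B when Sel = 1, which is how it is used above. The names w_mux, shift_left_gated and shift_right_gated are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

typedef uint64_t word;

/* Constant-time select: a when sel == 0, b when sel == 1 (sel must be 0 or 1).
   The mask is either all zeros or all ones, so no branch depends on sel. */
static word w_mux(word a, word b, word sel)
{
    word mask = (word)0 - sel;
    return (a & ~mask) | (b & mask);
}

/* Both helpers always compute the shifted value; the gate only selects
   whether it or the unshifted word is returned. */
static word shift_left_gated(word w, unsigned amount, word gate)
{
    assert(amount >= 1 && amount < 64);   /* a shift by 64 is undefined in C */
    return w_mux(w, w << amount, gate);
}

static word shift_right_gated(word w, unsigned amount, word gate)
{
    assert(amount >= 1 && amount < 64);
    return w_mux(w, w >> amount, gate);
}
```

The only assertion is on the shift amount, which is public. The gate is never branched on.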

These changes are available in the following vpatch. I am not suggesting that the code be changed this way; this is merely the form in which I finally grokked it:

curl 'http://bvt-trace.net/vpatches/ch13-more-abstractions.vpatch' > ch13-more-abstractions.vpatch
curl 'http://bvt-trace.net/vpatches/ch13-more-abstractions.vpatch.bvt.sig' > ch13-more-abstractions.vpatch.bvt.sig
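The bit-walking idea behind the right shift can be sketched in C. This is an illustration under stated assumptions, not FFA code. It assumes 64-bit words, word 0 being the least significant (as the reverse loop above implies for a right shift), and a Count of 0..63. The name quiet_shift_right_subw is hypothetical. As in the Ada version, each input word is read before the output word is written, so the shift can be done in place.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

typedef uint64_t word;

/* Constant-time select: a when sel == 0, b when sel == 1. */
static word w_mux(word a, word b, word sel)
{
    word mask = (word)0 - sel;
    return (a & ~mask) | (b & mask);
}

/* Right-shift the n-word number in[] by count bits (0..63) into out[].
   Word 0 is least significant; in and out may be the same array. */
static void quiet_shift_right_subw(const word *in, word *out, size_t n,
                                   unsigned count)
{
    assert(count < 64);
    /* 1 when count == 0 (nothing slides out), else 0, without a branch */
    word no_carry = ((word)count - 1) >> 63;
    word carry = 0;

    for (size_t k = n; k-- > 0; ) {
        word ni = in[k];             /* read before out[k] is written */
        out[k] = carry;              /* bits slid down from word k + 1 */

        carry = w_mux(ni, 0, no_carry);
        word b  = count;             /* shift amount, walked bit by bit */
        word cb = 64 - (word)count;  /* carry counter-shift, likewise */
        for (unsigned s = 1; s < 64; s *= 2) {
            ni    = w_mux(ni, ni >> s, b & 1);
            carry = w_mux(carry, carry << s, cb & 1);
            b  >>= 1;
            cb >>= 1;
        }
        out[k] |= ni;                /* add what is left of this word */
    }
}
```

When count is 0, the counter-shift is 64. Its low six bits are all zero, so the carry would stay unshifted. That is why it is muxed to zero up front, exactly as nC does in the Ada code.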

For Chapter 14, I have no comments on the code, only on the test tapes: the 100_shots_4096bit_unif_rng.*.tape files contain a mixture of all sorts of bitnesses, and definitely more than 100 modexp shots. For my own testing, I regenerated them using unif_testgen.tape.


Now for the updated assembly code. The first thing I had to add was an unrolled version of FZ_Low_Mul_Comba. The changes follow the approach used in previous posts, that is, a fast path for FZ_Low_Mul_Comba:

   -- Comba's low multiplier fastpath. (CAUTION: UNBUFFERED)
   procedure FZ_Low_Mul_Comba_Fast(X     : in  FZ;
                                   Y     : in  FZ;
                                   XY    : out FZ)
   is
      procedure Asm_Comba(X      : in  FZ;
                          Y      : in  FZ;
                          XY     : out FZ;
                          L      : in  Word_Index);
      pragma Import (C, Asm_Comba, "x86_64_comba_lomul_unrolled");
   begin
      pragma Assert(X'Length = Low_Mul_Thresh and
                      Y'Length = Low_Mul_Thresh and
                      XY'Length = Low_Mul_Thresh);
      Asm_Comba(X, Y, XY, X'Length);
   end FZ_Low_Mul_Comba_Fast;

   ----------------------

   -- Low-Only Multiplier. (CAUTION: UNBUFFERED)
   procedure FZ_Low_Multiply_Unbuffered(X     : in  FZ;
                                        Y     : in  FZ;
                                        XY    : out FZ) is

      -- The length of either multiplicand
      L : constant Word_Count := X'Length;

   begin

      if L = Low_Mul_Thresh then

         -- Optimized case:
         FZ_Low_Mul_Comba_Fast(X, Y, XY);

      elsif L < Low_Mul_Thresh then

         -- Base case:
         FZ_Low_Mul_Comba(X, Y, XY);

      else

         -- Recursive case:
         Low_Mul(X, Y, XY);

      end if;

   end FZ_Low_Multiply_Unbuffered;
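What the low multiplier computes can be sketched in C. It produces only the lower L words of the product, column by column, using the same kind of triple-word (A0, A1, A2) column accumulator that the assembly routine uses. This is a hypothetical, unoptimized reference (low_mul_comba is not an FFA name). It assumes 64-bit words stored least significant first and a compiler that provides unsigned __int128 (GCC or Clang). It is the kind of thing one might test an unrolled routine against, not the routine itself.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

typedef uint64_t word;
typedef unsigned __int128 dword;     /* GCC/Clang extension */

/* xy := (x * y) mod 2^(64*l): only columns 0 .. l-1 of the product.
   xy must be a separate buffer, because later columns still read x and y. */
static void low_mul_comba(const word *x, const word *y, word *xy, size_t l)
{
    assert(xy != x && xy != y);          /* catches full aliasing only */
    word a0 = 0, a1 = 0, a2 = 0;         /* triple-word column accumulator */

    for (size_t n = 0; n < l; n++) {
        /* column n collects every x[j] * y[n - j] with j = 0 .. n */
        for (size_t j = 0; j <= n; j++) {
            dword p = (dword)x[j] * y[n - j];
            dword s = (dword)a0 + (word)p;                      /* low word */
            a0 = (word)s;
            s = (dword)a1 + (word)(p >> 64) + (word)(s >> 64);  /* high + carry */
            a1 = (word)s;
            a2 += (word)(s >> 64);
        }
        xy[n] = a0;                      /* column n is final */
        a0 = a1;                         /* shift accumulator down one word */
        a1 = a2;
        a2 = 0;
    }
}
```

Skipping columns L .. 2L-1 is what makes the low multiplier cheaper than a full one. It performs roughly half of the L*L word multiplications.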

The assembly code also reuses macros introduced in previous posts:

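.global x86_64_comba_lomul_unrolled
x86_64_comba_lomul_unrolled:
# RDI = X, RSI = Y, RDX = XY, RCX = L (System V AMD64 calling convention)
push rbx                       # RBX is callee-saved

cmp rcx, Low_Karatsuba_Thresh  # only the unrolled size is supported
jne size_fail_lomul

mov rcx, rdx   # RCX := XY
xor r8,  r8    # A0  := 0
xor r9,  r9    # A1  := 0
xor r10, r10   # A2  := 0
xor rbx, rbx   # N   := 0

gen_loop_low Low_Karatsuba_Thresh   # unrolled columns of the lower half only

pop rbx
ret

size_fail_lomul:
ud2                            # trap on an unsupported length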

In general, there is nothing new or exciting here: the only differences from x86_64_comba_mul_unrolled are that register r12 is not used in this version of the routine and that the loop calculating the upper half of the result is not generated. One change that affects both x86_64_comba_mul_unrolled and x86_64_comba_lomul_unrolled is a small optimization to the gen_col_inner macro, which reduces the number of instructions executed but does not noticeably affect execution speed:

.macro gen_col_inner I NIter
.if \NIter > \I
gen_col_inner "(\I + 1)" \NIter   # emit the remaining NIter - I - 1 copies first
mov rax, rbx             # rax := N
sub rax, r11             # rax := N - j
mov rdx, [rsi + 8*rax]   # rdx := *(Y'Address + (N - j)*8)
mov rax, [rdi + 8*r11]   # rax := X(j) := *(X'Address + j*8)
mul rdx                  # rdx:rax := rax*rdx
accumulate               # (A0, A1, A2) := (A0, A1, A2) + (rax, rdx, 0)
inc r11                  # j := j + 1
.endif
.endm
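Each expansion of gen_col_inner adds one term X(j) * Y(N - j) of column N into the accumulator. The accumulate macro is defined in an earlier post and is not shown here. The C below only follows its comment, (A0, A1, A2) := (A0, A1, A2) + (rax, rdx, 0), using a portable carry chain where the assembly would presumably use add/adc/adc. The names accumulate and col_term are hypothetical, and unsigned __int128 assumes GCC or Clang.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

typedef uint64_t word;

/* acc = (A0, A1, A2); adds (lo, hi, 0) with carry propagation. */
static void accumulate(word acc[3], word lo, word hi)
{
    word s0 = acc[0] + lo;
    word c0 = s0 < lo;               /* carry out of A0 */
    word t  = acc[1] + hi;
    word c1 = t < hi;                /* carry out of A1 + hi */
    word s1 = t + c0;
    c1 += s1 < c0;                   /* carry out of adding c0 (never both) */
    acc[0] = s0;
    acc[1] = s1;
    acc[2] += c1;
}

/* One copy of the macro body: acc += X(j) * Y(n - j). */
static void col_term(const word *x, const word *y, size_t n, size_t j,
                     word acc[3])
{
    assert(j <= n);
    unsigned __int128 p = (unsigned __int128)x[j] * y[n - j];  /* the mul */
    accumulate(acc, (word)p, (word)(p >> 64));
}
```

The two carries out of A1 can never both be 1. If A1 + hi wrapped, the sum is at most 2^64 - 2, so adding c0 cannot wrap again.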

Functions introduced in previous chapters required no changes.

As for performance, the table below gives the total time, in seconds, to execute 100 modexps at each bitness:

Variant       2048 bits (s)   4096 bits (s)   8192 bits (s)
Orig. Ch.14   14.623825906    89.870492208    547.992436110
Asm-8         5.520542900     34.096842345    210.922537664
Asm-16        4.238617218     26.818359452    166.722059786
Asm-32        3.839936113     24.636377576    153.749986906

So with maximum unrolling (Asm-32), the speedup over the original Ch. 14 code is roughly 3.6x to 3.8x depending on bitness. That is quite good, but short of the expected 5x.
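The speedup of each variant can be recomputed directly from the table. The arrays below simply restate its numbers, and speedup and print_speedups are hypothetical helpers. For Asm-32, the ratios come out at about 3.81x (2048 bits), 3.65x (4096 bits) and 3.56x (8192 bits).

```c
#include <assert.h>
#include <stdio.h>

/* Total runtime (s) of 100 modexps, copied from the table above. */
static const int    bitness[3] = {2048, 4096, 8192};
static const double orig[3]    = {14.623825906, 89.870492208, 547.992436110};
static const double asm8[3]    = {5.520542900, 34.096842345, 210.922537664};
static const double asm16[3]   = {4.238617218, 26.818359452, 166.722059786};
static const double asm32[3]   = {3.839936113, 24.636377576, 153.749986906};

/* How many times faster a variant is than the baseline. */
static double speedup(double baseline, double variant)
{
    assert(variant > 0.0);
    return baseline / variant;
}

static void print_speedups(void)
{
    for (int i = 0; i < 3; i++)
        printf("%d bits: Asm-8 %.2fx, Asm-16 %.2fx, Asm-32 %.2fx\n",
               bitness[i],
               speedup(orig[i], asm8[i]),
               speedup(orig[i], asm16[i]),
               speedup(orig[i], asm32[i]));
}
```

The gain shrinks slightly as bitness grows, for every unrolling factor.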

Where are the cycles spent? Profiling the Asm-32 version with perf showed the following distribution:

  • 55% - unrolled Comba multiplication
  • 17% - Karatsuba multiplication
  • 15% - modexp with its inline routines (FZ_Quiet_ShiftRight stands out as the biggest cycle consumer).
  • 4% - unrolled squaring
  • 4% - unrolled lower-part Comba multiplication.

This leaves two main directions for further performance optimization:

  1. The biggest improvement could come from making the unrolled Comba multiplication faster still. I currently see no way to achieve that.
  2. A more limited gain could come from rewriting more of the codebase in assembly (FZ_Quiet_ShiftRight and Karatsuba multiplication).
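Amdahl's law, a general rule rather than anything specific to FFA, gives a rough sense of proportion for these two directions. Treat the perf shares above as fractions of runtime, which is an approximation. On that basis, doubling the speed of the unrolled Comba multiplication (55%) would give at most about 1.38x overall, and even making it free would cap out near 2.2x. The same bound for the Karatsuba share (17%) is about 1.2x. The functions amdahl and amdahl_limit are hypothetical helpers.

```c
#include <assert.h>

/* Amdahl's law: overall speedup when a fraction f of the runtime is made
   k times faster and the remaining 1 - f is unchanged. */
static double amdahl(double f, double k)
{
    assert(f >= 0.0 && f <= 1.0 && k > 0.0);
    return 1.0 / ((1.0 - f) + f / k);
}

/* Upper bound as k grows without limit: the untouched part dominates. */
static double amdahl_limit(double f)
{
    assert(f >= 0.0 && f < 1.0);
    return 1.0 / (1.0 - f);
}
```

These are ceilings. Real gains would be lower, since no routine can be made free.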

The vpatches are here:

curl 'http://bvt-trace.net/vpatches/ch14_asm_comba.vpatch' > ch14_asm_comba.vpatch
curl 'http://bvt-trace.net/vpatches/ch14_asm_comba.vpatch.bvt.sig' > ch14_asm_comba.vpatch.bvt.sig