 -- S.MG, 2018
with System; use System;  -- for Bit_Order

package body SMG_Keccak is

-- public function, sponge
  -- One-shot sponge: pads Input with rule 10*1, absorbs it block by block
  -- into a fresh all-zero state, then squeezes Output'Length bits into Output.
  -- Block_Len is the number of bits handled between two scrambles of the state.
  procedure Sponge( Input      : in Bitstream;
                    Output     : out Bitstream;
                    Block_Len  : in Keccak_Rate := Default_Bitrate ) is
    Internal  : State := (others => (others => 0));
  begin
    -- absorb stage: padded input goes into the sponge one block at a time
    declare
      -- the pad needs at least 2 bits, so this rounds Input'Length + 2 up to
      -- whole blocks (the pad is between 2 and Block_Len + 1 bits long)
      Padded_Blocks : constant Positive := 1 + (Input'Length + 1) / Block_Len;
      -- everything starts as 0, so only the two 1 bits of the pad are set below
      Padded        : Bitstream ( 1 .. Padded_Blocks * Block_Len ) :=
                        ( others => 0 );
      -- index in Padded of the first bit of the block being absorbed
      Lo            : Positive := Padded'First;
    begin
      -- message first, then the opening and closing 1 bits of the pad
      Padded( Padded'First .. Padded'First + Input'Length - 1 ) := Input;
      Padded( Padded'First + Input'Length )                     := 1;
      Padded( Padded'Last )                                     := 1;

      -- Padded'Length is a multiple of Block_Len, so no stray bits remain
      while Lo <= Padded'Last loop
        AbsorbBlock( Padded( Lo .. Lo + Block_Len - 1 ), Internal );
        -- scramble state with Keccak function after every absorbed block
        Internal := Keccak_Function( Internal );
        Lo       := Lo + Block_Len;
      end loop;
    end; -- end absorb stage

    -- squeeze stage: produce as many bits as Output can hold
    declare
      Full_Blocks : constant Natural := Output'Length / Block_Len;
      Tail_Bits   : constant Natural := Output'Length mod Block_Len;
      Block       : Bitstream( 1 .. Block_Len );
    begin
      -- whole blocks first (none at all if Output is shorter than a block)
      for N in 1 .. Full_Blocks loop
        SqueezeBlock( Block, Internal );
        Output( Output'First + (N - 1) * Block_Len ..
                Output'First + N * Block_Len - 1 ) := Block;

        -- scramble state after every full block squeezed
        Internal := Keccak_Function( Internal );
      end loop;

      -- remaining bits come from the start of one more squeezed block;
      -- the state is not scrambled again afterwards
      if Tail_Bits > 0 then
        SqueezeBlock( Block, Internal );
        Output( Output'Last - Tail_Bits + 1 .. Output'Last ) :=
                Block( Block'First .. Block'First + Tail_Bits - 1 );
      end if;
    end; -- end squeeze stage
  end Sponge;

  -- public interface, state based Sponge
  -- Reset Ctx before hashing a new message: all-zero sponge state, all-zero
  -- block buffer, and write position at the first bit of the buffer.
  procedure KeccakBegin(Ctx : in out Keccak_Context) is
  begin
     Ctx.Pos      := Ctx.Block'First;
     Ctx.Block    := (others => 0);
     Ctx.Internal := (others => (others => 0));
  end KeccakBegin;

  -- Feed Input into the context. Bits are copied into Ctx.Block starting at
  -- Ctx.Pos; every time the buffer fills up it is absorbed into Ctx.Internal,
  -- the state is scrambled and the buffer position restarts at Block'First.
  -- Bits that do not complete a block stay buffered in Ctx for a later call.
  procedure KeccakHash(Ctx : in out Keccak_Context;
                       Input : Bitstream) is
     Src_First : Natural := Input'First;  -- first input bit not yet consumed
     Src_Last  : Natural;                 -- last input bit taken this pass
     Dst_First : Natural;                 -- where this pass starts in the buffer
     Dst_Last  : Natural;                 -- where this pass ends in the buffer
  begin
     loop
        -- try to take everything that is left of the input ...
        Src_Last  := Input'Last;
        Dst_First := Ctx.Pos;
        Dst_Last  := Dst_First + (Src_Last - Src_First);

        -- ... but never more than what still fits in the buffer
        if Dst_Last > Ctx.Block'Last then
           Dst_Last := Ctx.Block'Last;
           Src_Last := Src_First + (Dst_Last - Dst_First);
        end if;

        Ctx.Block(Dst_First..Dst_Last) := Input(Src_First..Src_Last);
        Ctx.Pos := Dst_Last + 1;

        -- buffer is full: absorb it, scramble, start a fresh block
        if Ctx.Pos > Ctx.Block'Last then
           AbsorbBlock(Ctx.Block, Ctx.Internal);
           Ctx.Internal := Keccak_Function(Ctx.Internal);
           Ctx.Pos := Ctx.Block'First;
        end if;

        -- stop once the last input bit has been consumed
        exit when Src_Last >= Input'Last;
        Src_First := Src_Last + 1;
     end loop;
  end KeccakHash;

  -- Finish a state based hash: pad the buffered input with rule 10*1 so that
  -- it ends exactly on a block boundary, absorb it through KeccakHash, then
  -- squeeze Output'Length bits out of the sponge into Output.
  procedure KeccakEnd(Ctx : in out Keccak_Context;
                      Output : out Bitstream) is
     BlocksPerOutput : constant Natural := Output'Length / Ctx.Block_Len;
     StrayPerOutput : constant Natural := Output'Length mod Ctx.Block_Len;
     Block : Bitstream(1 .. Ctx.Block_Len);
  begin
     if Ctx.Pos /= 0 then -- needs padding
        declare
           -- bits still free in the current block, between 1 and Block'Length
           Free    : constant Positive := Ctx.Block'Last - Ctx.Pos + 1;
           Pad_Len : Positive := Free;
        begin
           -- rule 10*1 needs room for two distinct 1 bits: with a single
           -- free bit the pad must spill over into one more full block,
           -- exactly as Sponge does for the same input length
           if Pad_Len < 2 then
              Pad_Len := Pad_Len + Ctx.Block'Length;
           end if;

           declare
              Pad : Bitstream(1 .. Pad_Len) := (others => 0);
           begin
              Pad(Pad'First) := 1;
              Pad(Pad'Last)  := 1;
              -- the pad ends on a block boundary, so KeccakHash absorbs
              -- everything and leaves the buffer empty
              KeccakHash(Ctx, Pad);
           end;
        end;
     end if;

     -- squeeze bits: full blocks first, scrambling the state after each one
     for I in 0 .. BlocksPerOutput - 1 loop
        SqueezeBlock(Block, Ctx.Internal);
        Output(Output'First + I * Ctx.Block_Len ..
                 Output'First + (I + 1) * Ctx.Block_Len -1) := Block;
        Ctx.Internal := Keccak_Function(Ctx.Internal);
     end loop;
     -- then any stray bits, taken from the start of one more squeezed block
     if StrayPerOutput > 0 then
        SqueezeBlock(Block, Ctx.Internal);
        Output(Output'Last - StrayPerOutput + 1 .. Output'Last) :=
          Block(Block'First .. Block'First + StrayPerOutput - 1);
     end if;
  end KeccakEnd;

  -- Build a ZWord out of a bitstream of ZWord size.
  -- The bit at Bitword'First ends up as the least significant bit of the
  -- result (LSB bit order inside octets, as per Keccak spec).
  function BitsToWord( BWord: in Bitword ) return ZWord is
    Ordered : Bitword := BWord;
    Result  : ZWord   := 0;
  begin
    -- octets are used as given when Default_Bit_Order is Low_Order_First
    -- (treated here as a little endian machine), flipped otherwise
    if Default_Bit_Order /= Low_Order_First then
      Ordered := FlipOctets( BWord );
    end if;

    -- walk from the last bit down so that the first bit is shifted in last
    for Pos in reverse Bitword'Range loop
      Result := Shift_Left( Result, 1 ) + ZWord( Ordered( Pos ) );
    end loop;
    return Result;
  end BitsToWord;

  -- Spell out a ZWord (lane of state) as a bitstream of ZWord size.
  -- The least significant bit of Word goes to Bitword'First.
  function WordToBits( Word: in ZWord ) return Bitword is
    Result    : Bitword := (others => 0);
    Remaining : ZWord   := Word;
  begin
    -- peel bits off the low end of the word, one per position
    for Pos in Bitword'Range loop
      Result( Pos ) := Bit( Remaining and 1 );
      Remaining     := Shift_Right( Remaining, 1 );
    end loop;

    -- flip octets when Default_Bit_Order is High_Order_First
    -- (treated here as a big endian machine)
    if Default_Bit_Order = High_Order_First then
      return FlipOctets( Result );
    end if;

    return Result;
  end WordToBits;

  -- Reverse the order of the octets (groups of 8 bits) of BWord.
  -- The order of the bits inside each octet is left untouched.
  function FlipOctets( BWord : in Bitword ) return Bitword is
    Octets  : constant Natural := Bitword'Length / 8;
    Flipped : Bitword;
    Dst     : Integer;  -- first bit of the octet being written
    Src     : Integer;  -- first bit of the octet being read
  begin
    -- copy groups of 8 bits, changing their order in the array,
    -- i.e. last octet of BWord becomes 1st octet of the result and so on
    for N in 1 .. Octets loop
      Dst := Flipped'First + (N - 1) * 8;
      Src := BWord'Last    - (N - 1) * 8 - 7;
      Flipped( Dst .. Dst + 7 ) := BWord( Src .. Src + 7 );
    end loop;
    return Flipped;
  end FlipOctets;

-- helper procedures for sponge absorb/squeeze

  -- Xor ALL bits of Block into the first Block'Length bits of state S.
  -- NO scramble here and no size check either: make sure the block fits!
  -- Lanes are filled in the order X first, then Y.
  procedure AbsorbBlock( Block: in Bitstream; S: in out State ) is
    Full_Words : constant Natural := Block'Length / Z_Length; -- whole words
    Tail_Bits  : constant Natural := Block'Length mod Z_Length; -- stray bits
    Lo         : Natural := Block'First;  -- first bit of the current word
    X, Y       : XYCoord := 0;            -- lane currently being written
    Partial    : Bitword := (others => 0);
  begin
    -- a block can consist of more than one word: absorb them one by one
    for N in 1 .. Full_Words loop
      S( X, Y ) := S( X, Y ) xor BitsToWord( Block( Lo .. Lo + Z_Length - 1 ) );
      Lo := Lo + Z_Length;

      -- move on to next word in state; X wrapping to 0 means next row
      X := X + 1;
      if X = 0 then
        Y := Y + 1;
      end if;
    end loop;

    -- absorb also any remaining bits: they fill the low end of a word
    -- whose other bits are 0, so the rest of that lane stays unchanged
    if Tail_Bits > 0 then
      Partial( Bitword'First .. Bitword'First + Tail_Bits - 1 ) :=
        Block( Block'Last - Tail_Bits + 1 .. Block'Last );
      S( X, Y ) := S( X, Y ) xor BitsToWord( Partial );
    end if;
  end AbsorbBlock;

  -- Fill Block with the first Block'Length bits of state S.
  -- NO scramble here: S is only read, so repeated calls on the *same* state
  -- give the same bits. Lanes are read in the order X first, then Y.
  procedure SqueezeBlock( Block: out Bitstream; S: in State) is
    X, Y : XYCoord := 0;            -- lane currently being read
    Lane : Bitword;                 -- bits of that lane
    Lo   : Natural := Block'First;  -- next position to fill in Block
    Take : Natural;                 -- number of bits copied from the lane
  begin
    while Lo <= Block'Last loop
      Lane := WordToBits( S( X, Y ) );

      -- move on to next word in state; X wrapping to 0 means next row
      X := X + 1;
      if X = 0 then
        Y := Y + 1;
      end if;

      -- a full word if it fits, otherwise only the bits still missing
      Take := Natural'Min( Z_Length, Block'Last - Lo + 1 );

      Block( Lo .. Lo + Take - 1 ) := Lane( Lane'First .. Lane'First + Take - 1 );
      Lo := Lo + Take;
    end loop;
  end SqueezeBlock;


-- private, internal transformations
  -- Theta step: every lane is xored with a value derived from the xor-sums
  -- of the two neighbouring columns (X-1, and X+1 rotated left by 1).
  function Theta(Input : in State) return State is
    Result : State;
    Parity : Plane;  -- xor of all lanes in each column
    Mix    : ZWord;
  begin
    -- xor together the lanes of each column
    for Col in XYCoord loop
      Mix := Input(Col, 0);
      for Row in 1..XYCoord'Last loop
        Mix := Mix xor Input(Col, Row);
      end loop;
      Parity(Col) := Mix;
    end loop;

    -- same correction word for all lanes of a column
    for Col in XYCoord loop
      Mix := Parity(Col-1) xor Rotate_Left(Parity(Col+1), 1);
      for Row in XYCoord loop
        Result(Col, Row) := Input(Col, Row) xor Mix;
      end loop;
    end loop;

    return Result;
  end Theta;

  -- Rho step: rotate each lane left by its own offset. Lane (0,0) is copied
  -- as is; the other 24 lanes are visited starting from (1,0), each next
  -- position being (Y, 2*X + 3*Y), with offset (T+1)*(T+2)/2 mod Z_Length.
  function Rho(Input : in State) return State is
    Result : State;
    X      : XYCoord := 1;
    Y      : XYCoord := 0;
    Next_Y : XYCoord;
    Offset : Natural;
  begin
    Result(0,0) := Input(0,0);

    for T in 0..23 loop
      Offset       := ((T+1)*(T+2)/2) mod Z_Length;
      Result(X, Y) := Rotate_Left(Input(X,Y), Offset);

      -- step to the next lane; new Y needs the old values of both X and Y
      Next_Y := 2*X + 3*Y;
      X      := Y;
      Y      := Next_Y;
    end loop;
    return Result;
  end Rho;

  -- Pi step: move lanes around without changing them; the lane found at
  -- (X, Y) in Input is placed at (Y, 2*X + 3*Y) in the result.
  function Pi(Input : in State) return State is
    Result : State;
    New_Y  : XYCoord;
  begin
    for Y in XYCoord loop
      for X in XYCoord loop
        New_Y := 2*X + 3*Y;
        Result(Y, New_Y) := Input(X, Y);
      end loop;
    end loop;
    return Result;
  end Pi;

  -- Chi step: within each row, every lane is xored with
  -- (not next lane) and (lane after next), X coordinates wrapping around.
  function Chi(Input : in State) return State is
    Result : State;
    Mask   : ZWord;
  begin
    for Row in XYCoord loop
      for Col in XYCoord loop
        Mask := (not Input(Col + 1, Row)) and Input(Col + 2, Row);
        Result(Col, Row) := Input(Col, Row) xor Mask;
      end loop;
    end loop;
    return Result;
  end Chi;

  -- Iota step: xor Round_Const into lane (0,0); all other lanes unchanged.
  function Iota(Round_Const : in ZWord; Input : in State) return State is
    Result : State := Input;
  begin
    Result(0,0) := Result(0,0) xor Round_Const;
    return Result;
  end Iota;

  -- Scramble a state: for every round in Round_Index apply, in this order,
  -- Theta, Rho, Pi, Chi and finally Iota with that round's constant RC.
  function Keccak_Function(Input: in State) return State is
    Current : State := Input;
  begin
    for Round in Round_Index loop
      Current := Theta( Current );
      Current := Rho( Current );
      Current := Pi( Current );
      Current := Chi( Current );
      Current := Iota( RC(Round), Current );
    end loop;

    return Current;
  end Keccak_Function;

end SMG_Keccak;
