 -- Message reader & writers for SMG Communication Protocol
 -- S.MG, 2018

with Interfaces; use Interfaces;
with Serpent;
with System; use System;

package body Messages is

  ----------------------
  -- Serpent Messages --
  ----------------------

  procedure Write_SKeys_SMsg( Keyset  : in Serpent_Keyset;
                              Counter : in Interfaces.Unsigned_16;
                              Msg     : out Raw_Types.Serpent_Msg) is
  begin
    -- call internal write on Octets with correct type id
    Write_SKeys( Keyset, Counter, SKeys_S_Type, Msg );
  end Write_SKeys_SMsg;


  -- Reads a Serpent keyset from given Serpent Message
  procedure Read_SKeys_SMsg( Msg     : in Raw_Types.Serpent_Msg;
                             Counter : out Interfaces.Unsigned_16;
                             Keyset  : out Serpent_Keyset) is
  begin
    -- check type id and call internal Read_SKeys if correct
    if Msg(Msg'First) /= SKeys_S_Type then
      raise Invalid_Msg;
    else
      Read_SKeys( Msg, Counter, Keyset );
    end if;
  end Read_SKeys_SMsg;

  -- writes given key mgm structure into a Serpent message
  procedure Write_KMgm_SMsg( KMgm    : in Keys_Mgm;
                             Counter : in Interfaces.Unsigned_16;
                             Msg     : out Raw_Types.Serpent_Msg) is
  begin
    -- call internal write of key mgm with correct type ID
    Write_KMgm( KMgm, Counter, Key_Mgm_S_Type, Msg );
  end Write_KMgm_SMsg;

  -- reads a key mgm structure from the given Serpent message
  procedure Read_KMgm_SMsg( Msg     : in Raw_Types.Serpent_Msg;
                            Counter : out Interfaces.Unsigned_16;
                            KMgm    : out Keys_Mgm) is
  begin
    -- check type id and call internal Read_KMgm if correct
    if Msg(Msg'First) /= Key_Mgm_S_Type then
      raise Invalid_Msg;
    else
      Read_KMgm( Msg, Counter, KMgm );
    end if;
  end Read_KMgm_SMsg;


  ------------------
  -- RSA Messages --
  ------------------

  procedure Write_SKeys_RMsg( Keyset  : in Serpent_Keyset;
                              Counter : in Interfaces.Unsigned_16;
                              Msg     : out Raw_Types.RSA_Msg) is
  begin
    -- call internal write of Serpent keys with correct type ID
    Write_SKeys( Keyset, Counter, SKeys_R_Type, Msg );
  end Write_SKeys_RMsg;

  procedure Read_SKeys_RMsg( Msg     : in Raw_Types.RSA_Msg;
                             Counter : out Interfaces.Unsigned_16;
                             Keyset  : out Serpent_Keyset) is
  begin
    -- check type id and call internal Read_SKeys if correct
    if Msg(Msg'First) /= SKeys_R_Type then
      raise Invalid_Msg;
    else
      Read_SKeys( Msg, Counter, Keyset );
    end if;
  end Read_SKeys_RMsg;

  procedure Write_KMgm_RMsg( KMgm    : in Keys_Mgm;
                             Counter : in Interfaces.Unsigned_16;
                             Msg     : out Raw_Types.RSA_Msg) is
  begin
    -- call internal write of key mgm with correct type ID
    Write_KMgm( KMgm, Counter, Key_Mgm_R_Type, Msg );
  end Write_KMgm_RMsg;

  procedure Read_KMgm_RMsg( Msg     : in Raw_Types.RSA_Msg;
                            Counter : out Interfaces.Unsigned_16;
                            KMgm    : out Keys_Mgm) is
  begin
    -- check type id and call internal Read_KMgm if correct
    if Msg(Msg'First) /= Key_Mgm_R_Type then
      raise Invalid_Msg;
    else
      Read_KMgm( Msg, Counter, KMgm );
    end if;
  end Read_KMgm_RMsg;

  ------------------
  -- private part --
  ------------------
  procedure Cast_LE( LE: in out Raw_Types.Octets ) is
  begin
    -- flip octets ONLY if native is big endian.
    if System.Default_Bit_Order = System.High_Order_First then
      declare
        BE: constant Raw_Types.Octets := LE;
      begin
        for I in 1..LE'Length loop
          LE(LE'First+I-1) := BE(BE'Last-I+1);
        end loop;
      end;
    end if;
    -- NOTHING to do for native little endian
  end Cast_LE;

  procedure Write_SKeys( Keyset  : in Serpent_Keyset;
                         Counter : in Interfaces.Unsigned_16;
                         Type_ID : in Interfaces.Unsigned_8;
                         Msg     : out Raw_Types.Octets) is
    Pos   : Integer := Msg'First;
    Check : CRC32.CRC32;
    PadLen: Integer;
    K     : Serpent.Key;
  begin
    -- write Type ID
    Msg(Pos) := Type_ID;
    Pos := Pos + 1;

    -- write count of keys (NB: this IS 8 bits by definition)
    Msg(Pos) := Keyset.Keys'Length;
    Pos := Pos + 1;

    -- write keys
    for I in Keyset.Keys'Range loop
      -- retrieve Key to write
      K := Keyset.Keys( I );

      -- write key itself
      Msg(Pos..Pos+K'Length-1) := K;
      -- ensure little endian order in message
      Cast_LE(Msg(Pos..Pos+K'Length-1));
      Pos := Pos + K'Length;

      -- write CRC of key
      Check := CRC32.CRC( K );
      Msg(Pos..Pos+3) := Raw_Types.Cast(Check);
      Cast_LE(Msg(Pos..Pos+3));
      Pos := Pos + 4;
    end loop;

    -- write flag
    Msg(Pos) := Keyset.Flag;
    Pos := Pos + 1;

    -- write message counter
    Msg(Pos..Pos+1) := Raw_Types.Cast(Counter);
    Cast_LE(Msg(Pos..Pos+1));
    Pos := Pos + 2;

    -- write padding as needed; endianness is irrelevant here
    PadLen := Msg'Last - Pos + 1;
    if PadLen > 0 then
      declare
        Pad : Raw_Types.Octets(1..PadLen);
      begin
        RNG.Get_Octets( Pad );
        Msg(Pos..Pos+PadLen-1) := Pad;
      end;
    end if;

  end Write_SKeys;

  procedure Read_SKeys( Msg     : in Raw_Types.Octets;
                        Counter : out Interfaces.Unsigned_16;
                        Keyset  : out Serpent_Keyset) is
    Pos: Integer := Msg'First;
  begin
    -- read type and check
    if Msg(Pos) = SKeys_S_Type or 
       Msg(Pos) = SKeys_R_Type then
      Pos := Pos + 1;
    else
      raise Invalid_Msg;
    end if;

    -- read count of keys and check
    if Msg(Pos) in Keys_Count'Range then
      declare
        N     : Keys_Count := Keys_Count(Msg(Pos));
        KS    : Serpent_Keyset(N);
        K     : Serpent.Key;
        Check : CRC32.CRC32;
        O4    : Raw_Types.Octets_4; 
        O2    : Raw_Types.Octets_2;
      begin
        Pos := Pos + 1;
        --read keys and check crc for each
        for I in 1 .. N loop
          -- read key and advance pos
          K := Msg(Pos..Pos+K'Length-1);
          Cast_LE(K);
          Pos := Pos + K'Length;
          -- read crc and compare to crc32(key)
          O4 := Msg(Pos..Pos+3);
          Cast_LE(O4);
          Check   := Raw_Types.Cast(O4);
          Pos := Pos + 4;
          if Check /= CRC32.CRC(K) then
            raise Invalid_Msg;
          end if;
          -- if it got here, key is fine so add to set
          KS.Keys(KS.Keys'First + I -1) := K;
        end loop;
        -- read and set flag
        KS.Flag := Msg(Pos);
        Pos := Pos + 1;
        -- read and set message counter
        O2 := Msg(Pos..Pos+1);
        Cast_LE(O2);
        Counter := Raw_Types.Cast(O2);
        -- rest of message is padding so it's ignored
        -- copy keyset to output variable
        Keyset := KS;
      end;
    else
      raise Invalid_Msg;
    end if;
  end Read_SKeys;

  -- writes given key management structure to the given octets array
  procedure Write_KMgm( KMgm    : in Keys_Mgm;
                        Counter : in Interfaces.Unsigned_16;
                        Type_ID : in Interfaces.Unsigned_8;
                        Msg     : out Raw_Types.Octets) is
    Pos   : Integer := Msg'First;
  begin
    -- write given type id
    Msg(Pos) := Type_ID;
    Pos := Pos + 1;

    -- write count of server keys requested
    Msg(Pos) := KMgm.N_Server;
    Pos := Pos + 1;

    -- write count of client keys requested
    Msg(Pos) := KMgm.N_Client;
    Pos := Pos + 1;

    -- write id of key preferred for further inbound Serpent messages
    Msg(Pos) := KMgm.Key_ID;
    Pos := Pos + 1;

    -- write count of burnt keys in this message
    Msg(Pos..Pos) := Cast( KMgm.N_Burnt );
    Pos := Pos + 1;

    -- if there are any burnt keys, write their ids
    if KMgm.N_Burnt > 0 then
      Msg( Pos .. Pos + KMgm.Burnt'Length - 1 ) := KMgm.Burnt;
      Pos := Pos + KMgm.Burnt'Length;
    end if;

    -- write the message count
    Msg(Pos..Pos+1) := Raw_Types.Cast( Counter );
    Cast_LE( Msg(Pos..Pos+1) );
    Pos := Pos + 2;

    -- pad with random octets until the end of Msg
    RNG.Get_Octets( Msg(Pos..Msg'Last) );

  end Write_KMgm;

  -- attempts to read from the given array of octets a key management structure
  procedure Read_KMgm( Msg     : in Raw_Types.Octets;
                       Counter : out Interfaces.Unsigned_16;
                       KMgm    : out Keys_Mgm) is
    Pos       : Integer := Msg'First;
    Burnt_Pos : Integer := Msg'First + 4;
  begin
    -- read type and check
    if Msg(Pos) = Key_Mgm_S_Type or 
       Msg(Pos) = Key_Mgm_R_Type then
      Pos := Pos + 1;
    else
      raise Invalid_Msg;
    end if;

    -- read the count of burnt keys and check
    -- NB: Burnt_Pos IS in range of Counter_8bits since it's an octet
    declare
      N_Burnt : Counter_8bits := Counter_8bits(Msg(Burnt_Pos));
      Mgm     : Keys_Mgm(N_Burnt);
      O2      : Raw_Types.Octets_2;
    begin
      -- read count of server keys requested
      Mgm.N_Server := Msg(Pos);
      Pos := Pos + 1;

      -- read count of client keys requested
      Mgm.N_Client := Msg(Pos);
      Pos := Pos + 1;

      -- read ID of Serpent key preferred for further inbound messages
      Mgm.Key_ID   := Msg(Pos);
      Pos := Pos + 2; --skip the count of burnt keys as it's read already

      -- read ids of burnt keys, if any
      if N_Burnt > 0 then
        Mgm.Burnt := Msg(Pos..Pos+N_Burnt-1);
        Pos := Pos + N_Burnt;
      end if;

      -- read and set message counter
      O2 := Msg(Pos..Pos+1);
      Cast_LE(O2);
      Counter := Raw_Types.Cast(O2);
      -- rest of message is padding so it's ignored
      -- copy the keys mgm structure to output param
      KMgm := Mgm;
    end;
  end Read_KMgm;


end Messages;
