 -- Message reader & writers for SMG Communication Protocol
 -- S.MG, 2018

with Interfaces; use Interfaces;
with Serpent;
with System; use System;
with Ada.Assertions; use Ada.Assertions;

package body Messages is

  ----------------------
  -- Serpent Messages --
  ----------------------

  procedure Write_SKeys_SMsg( Keyset  : in Serpent_Keyset;
                              Counter : in Interfaces.Unsigned_16;
                              Msg     : out Raw_Types.Serpent_Msg) is
  begin
    -- call internal write on Octets with correct type id
    Write_SKeys( Keyset, Counter, SKeys_S_Type, Msg );
  end Write_SKeys_SMsg;


  -- Reads a Serpent keyset from given Serpent Message
  procedure Read_SKeys_SMsg( Msg     : in Raw_Types.Serpent_Msg;
                             Counter : out Interfaces.Unsigned_16;
                             Keyset  : out Serpent_Keyset) is
  begin
    -- check type id and call internal Read_SKeys if correct
    if Msg(Msg'First) /= SKeys_S_Type then
      raise Invalid_Msg;
    else
      Read_SKeys( Msg, Counter, Keyset );
    end if;
  end Read_SKeys_SMsg;

  -- writes given key mgm structure into a Serpent message
  procedure Write_KMgm_SMsg( KMgm    : in Keys_Mgm;
                             Counter : in Interfaces.Unsigned_16;
                             Msg     : out Raw_Types.Serpent_Msg) is
  begin
    -- call internal write of key mgm with correct type ID
    Write_KMgm( KMgm, Counter, Key_Mgm_S_Type, Msg );
  end Write_KMgm_SMsg;

  -- reads a key mgm structure from the given Serpent message
  procedure Read_KMgm_SMsg( Msg     : in Raw_Types.Serpent_Msg;
                            Counter : out Interfaces.Unsigned_16;
                            KMgm    : out Keys_Mgm) is
  begin
    -- check type id and call internal Read_KMgm if correct
    if Msg(Msg'First) /= Key_Mgm_S_Type then
      raise Invalid_Msg;
    else
      Read_KMgm( Msg, Counter, KMgm );
    end if;
  end Read_KMgm_SMsg;

  ------ File Transfer ------
  procedure Write_File_Transfer( Chunk   : in File_Chunk;
                                 Msg     : out Raw_Types.Serpent_Msg) is
    Pos: Integer := Msg'First;
    U16: Interfaces.Unsigned_16;
  begin
    -- write type ID
    Msg(Pos) := File_Transfer_S_Type;
    Pos := Pos + 1;

    -- write filename as text field (size+2, text)
    -- check against overflows
    if Chunk.Name_Len > Text_Len'Last - 2 or
       Pos + Integer(Chunk.Name_Len) + 2 > Msg'Last then
      raise Invalid_Msg;
    end if;

    -- write total size: filename size + 2 
    U16 := Interfaces.Unsigned_16( Chunk.Name_Len + 2 );
    Write_U16( Msg, Pos, U16 );

    -- write filename
    String_To_Octets( Chunk.Filename, 
                      Msg(Pos..Pos+Integer(Chunk.Name_Len)-1) );
    Pos := Pos + Integer(Chunk.Name_Len);

    --write content
    -- check against overflow, including the 2 octets for counter at the end
    if Chunk.Len > Text_Len'Last - 2 or
       Pos + Integer(Chunk.Len) + 4 > Msg'Last then
      raise Invalid_Msg;
    end if;

    -- write total size for this text field
    U16 := Interfaces.Unsigned_16( Chunk.Len + 2 );
    Write_U16( Msg, Pos, U16 );

    -- write actual content
    Msg(Pos..Pos+Chunk.Content'Length-1) := Chunk.Content;
    Pos := Pos + Chunk.Content'Length;

    -- write counter
    Write_U16( Msg, Pos, Chunk.Count );

    -- write padding if needed
    if Pos <= Msg'Last then
      RNG.Get_Octets( Msg(Pos..Msg'Last) );
    end if;

  end Write_File_Transfer;

  -- The opposite of Write_File_Transfer method above.
  -- Counter will contain the message counter
  -- Chunk will contain the chunk counter, filename and content
  procedure Read_File_Transfer( Msg     : in Raw_Types.Serpent_Msg;
                                Chunk   : out File_Chunk) is
    Pos: Integer := Msg'First;
    U16: Interfaces.Unsigned_16;
    S_Name, E_Name: Integer; --start/end for filename in Msg
    S_Len: Text_Len; -- length of filename (needed as Text_Len anyway)
    S_Content, E_Content: Integer; --start/end for content in Msg
    Content_Len: text_Len; -- length of content (needed as Text_Len anyway)
  begin
    -- read and check type ID
    if Msg(Pos) /= File_Transfer_S_Type then
      raise Invalid_Msg;
    end if;
    Pos := Pos + 1;

    -- read filename size
    Read_U16( Msg, Pos, U16 );

    -- check for overflow and underflow; filename size >= 1
    if Pos + Integer(U16) - 2 > Msg'Last or
       U16 < 3 then
      raise Invalid_Msg;
    end if;
    U16 := U16 - 2;
    S_Len := Text_Len(U16);

    -- set start + end for reading filename later, when ready
    S_Name := Pos;
    E_Name := Pos + Integer(U16)-1;
    Pos := Pos + S_Len;

    -- read size of content
    Read_U16( Msg, Pos, U16 );
    -- check for overflow and underflow; content >=1; counter =2 octets
    if Pos + Integer(U16) - 1 > Msg'Last or
       U16 < 3 then
      raise Invalid_msg;
    end if;
    U16 := U16 - 2;
    Content_Len := Text_Len(U16);
    -- set start and end for reading content later, when ready
    S_Content := Pos;
    E_Content := Pos + Integer(U16) - 1;
    Pos := Pos + Content_Len;

    -- read counter
    Read_U16( Msg, Pos, U16 );
    -- check chunking validity i.e. if counter>0 then no padding
    if U16 /= 0 and Pos /= Msg'Last then
      raise Invalid_Msg;
    end if;

    -- create File_Chunk structure and fill it with data from Msg
    declare
      FC : File_Chunk( Len      => Content_Len, 
                       Count    => U16,
                       Name_Len => S_Len);
    begin
      -- read from Msg
      FC.Content  := Msg( S_Content..E_Content );
      Octets_To_String( Msg( S_Name..E_Name ), FC.Filename);
      -- copy to output var
      Chunk := FC;
    end;

  end Read_File_Transfer;

  ---- File Requests ----
  procedure Write_File_Request( FR      : in Filenames;
                                Counter : in Interfaces.Unsigned_16;
                                Msg     : out Raw_Types.Serpent_Msg;
                                Written : out Natural) is
    Pos    : Integer := Msg'First;
    Max_Pos: Integer := Msg'Last - 2; -- 2 octets at end for counter
    Text_Sz: Integer;
    Max_Sz : Integer;
  begin
    -- write ID for File Request type
    Msg( Pos ) := File_Req_S_Type;
    Pos := Pos + 1;

    -- write Text size: filenames + separators
    -- consider fewer filenames if they don't ALL fit
    -- 2 octets are taken by size itself
    Max_Sz := Max_Pos - Pos - 1;
    Text_Sz := FR.Sz + FR.F_No - 1;
    if Text_Sz > Max_Sz then
      -- walk the array of filenames backwards and stop when they fit
      Written := FR.F_No - 1;
      -- calculate actual size written based on start of first discarded
        -- filename and (Written -1) octets for needed separators
      Text_Sz := Integer(FR.Starts(Written+1)) - FR.Starts'First + 
                   (Written - 1);
  
      -- loop until either fits or nothing left
      while Written > 0 and Text_Sz > Max_Sz loop
        Written := Written - 1;
        Text_Sz := Integer(FR.Starts(Written+1))- FR.Starts'First + 
                     (Written - 1);
      end loop;
      -- check that there is what to write, since nothing -> invalid message
      if Written = 0 then
        raise Invalid_Msg;
      end if;

    else --from if Text_Sz > Max_Sz
      -- ALL are written
      Written := FR.F_No;  
    end if;

    -- write Text_Sz + 2 (i.e. TOTAL size)
    if Text_Sz + 2 > Integer(Interfaces.Unsigned_16'Last) then
      raise Invalid_Msg;
    end if;

    Write_U16( Msg, Pos, Interfaces.Unsigned_16(Text_Sz+2) );

    -- write filenames separated by Sep
    for I in 1..Written loop
      declare
        Start_Pos : Positive;
        End_Pos   : Positive;
        Len       : Positive;
      begin
        -- current start pos in FR.S
        Start_Pos := Positive( FR.Starts( FR.Starts'First + I - 1));

        -- calculate end based on start of next name or last
        if I < FR.F_No then
          End_Pos := Positive( FR.Starts( FR.Starts'First + I)) - 1;
        else
          End_Pos := FR.S'Last;
        end if;

        -- NB: this WILL fail if starting positions are not in order!
        Len := End_Pos - Start_Pos + 1;
        if Len <= 0 then
          raise Invalid_Msg;
        end if;
  
        --write the actual filename
        String_To_Octets( FR.S( Start_Pos..End_Pos ), Msg(Pos..Pos+Len-1) );
        Pos := Pos + Len;

        --if it's not the last one, write a separator
        if I < Written then
          Msg(Pos) := Sep;
          Pos := Pos + 1;
        end if;
      end;
    end loop;

    -- write the message counter in little endian at all times
    Write_U16( Msg, Pos, Counter );

    -- write padding if needed
    if Pos <= Msg'Last then
      Rng.Get_Octets( Msg(Pos..Msg'Last) );
    end if;
  end Write_File_Request;
  
  -- Reads a request for files; the opposite of Write_File_Request above
  procedure Read_File_Request( Msg      : in Raw_Types.Serpent_Msg;
                               Counter  : out Interfaces.Unsigned_16;
                               FR       : out Filenames) is
    Pos       : Integer := Msg'First;
    Max_Pos   : Integer := Msg'Last - 2; --at least 2 reserved for counter
    Text_Sz   : Integer;
    Max_Sz    : Integer := Max_Pos - Pos - 1; --text only i.e. w.o. size itself
    F_No      : Integer;
    U16       : Interfaces.Unsigned_16;
  begin
    -- read type ID and check
    if Msg(Pos) /= File_Req_S_Type then
      raise Invalid_Msg;
    end if;
    Pos := Pos + 1;

    -- read total size of filenames+separators
    Read_U16( Msg, Pos, U16 );
    Text_Sz := Integer(U16);
    -- take away the 2 octets for size itself
    Text_Sz := Text_Sz - 2;
    
    -- check that Text_Sz is not overflowing/underflowing
    if Text_Sz < 1 or Text_Sz > Max_Sz then
      raise Invalid_Msg;
    end if;

    -- count first the separators to know how many filenames
    -- NB: there is always at least 1 filename as Text_Sz > 0
    F_No := 1;
    for I in Pos .. Pos + Text_Sz - 1 loop
      if Msg(I) = Sep then
        F_No := F_No + 1;
      end if;
    end loop;

    -- create the output structure and discard separators
    -- text without separators should be Text_Sz - F_No + 1 
    -- (because ONLY one separator between 2 filenames allowed)
    -- if it's not that => Invalid_Msg
    -- F_No and Text_Sz are not overflow (earlier check + calc)
    declare
      F     : Filenames(Text_Len(F_No), Text_Len(Text_Sz-F_No+1)); 
      S_Pos : Positive;
      Index : Positive;
    begin
      S_Pos := F.S'First;
      Index := F.Starts'First;
      F.Starts(Index) := Interfaces.Unsigned_16(S_Pos);

      for I in Pos .. Pos + Text_Sz - 1 loop
        -- copy over to F.S anything that is not separator
        if Msg(I) /= Sep then
          F.S( S_Pos ) := Character'Val(Msg(I));
          S_Pos := S_Pos + 1;
        else
          -- if it's separator, check and if ok, add next as start
          if I = Pos + Text_Sz or -- separator as last character is error
               Msg(I+1) = Sep or  -- 2 consecutive separators is error
               Index >= F.Starts'Last then -- too many separators is error
            raise Invalid_Msg;
          else
            Index := Index + 1;
            F.Starts( Index ) := Interfaces.Unsigned_16(S_Pos);
          end if;
        end if;
      end loop;

      -- copy the whole structure to output variable
      FR := F;      
    end;

    -- read message counter now
    Pos := Pos + Text_Sz;
    Read_U16( Msg, Pos, Counter );
    
  end Read_File_Request;

  ------------------
  -- RSA Messages --
  ------------------

  procedure Write_SKeys_RMsg( Keyset  : in Serpent_Keyset;
                              Counter : in Interfaces.Unsigned_16;
                              Msg     : out Raw_Types.RSA_Msg) is
  begin
    -- call internal write of Serpent keys with correct type ID
    Write_SKeys( Keyset, Counter, SKeys_R_Type, Msg );
  end Write_SKeys_RMsg;

  procedure Read_SKeys_RMsg( Msg     : in Raw_Types.RSA_Msg;
                             Counter : out Interfaces.Unsigned_16;
                             Keyset  : out Serpent_Keyset) is
  begin
    -- check type id and call internal Read_SKeys if correct
    if Msg(Msg'First) /= SKeys_R_Type then
      raise Invalid_Msg;
    else
      Read_SKeys( Msg, Counter, Keyset );
    end if;
  end Read_SKeys_RMsg;

  procedure Write_KMgm_RMsg( KMgm    : in Keys_Mgm;
                             Counter : in Interfaces.Unsigned_16;
                             Msg     : out Raw_Types.RSA_Msg) is
  begin
    -- call internal write of key mgm with correct type ID
    Write_KMgm( KMgm, Counter, Key_Mgm_R_Type, Msg );
  end Write_KMgm_RMsg;

  procedure Read_KMgm_RMsg( Msg     : in Raw_Types.RSA_Msg;
                            Counter : out Interfaces.Unsigned_16;
                            KMgm    : out Keys_Mgm) is
  begin
    -- check type id and call internal Read_KMgm if correct
    if Msg(Msg'First) /= Key_Mgm_R_Type then
      raise Invalid_Msg;
    else
      Read_KMgm( Msg, Counter, KMgm );
    end if;
  end Read_KMgm_RMsg;


  ----------Utilities ----------
  -- String to Octets conversion
  procedure String_To_Octets(Str: in String; O: out Raw_Types.Octets) is
  begin
    Assert( Str'Length = O'Length );
    for I in 1..Str'Length loop
      O( O'First+I-1 ) := Character'Pos(Str(Str'First + I - 1 ));
    end loop;
  end String_To_Octets;

  -- Octets to string conversion
  -- NB: Str'Length has to be EQUAL to Octets'Length!
  procedure Octets_To_String(O: in Raw_Types.Octets; Str: out String) is
  begin
    Assert( O'Length = Str'Length );
    for I in 1..O'Length loop
      Str( Str'First+I-1 ) := Character'Val(O(O'First + I - 1 ));
    end loop;
  end Octets_To_String;

  ------------------
  -- private part --
  ------------------
  procedure Cast_LE( LE: in out Raw_Types.Octets ) is
  begin
    -- flip octets ONLY if native is big endian.
    if System.Default_Bit_Order = System.High_Order_First then
      declare
        BE: constant Raw_Types.Octets := LE;
      begin
        for I in 1..LE'Length loop
          LE(LE'First+I-1) := BE(BE'Last-I+1);
        end loop;
      end;
    end if;
    -- NOTHING to do for native little endian
  end Cast_LE;

  procedure Write_SKeys( Keyset  : in Serpent_Keyset;
                         Counter : in Interfaces.Unsigned_16;
                         Type_ID : in Interfaces.Unsigned_8;
                         Msg     : out Raw_Types.Octets) is
    Pos   : Integer := Msg'First;
    Check : CRC32.CRC32;
    K     : Serpent.Key;
  begin
    -- write Type ID
    Msg(Pos) := Type_ID;
    Pos := Pos + 1;

    -- write count of keys (NB: this IS 8 bits by definition)
    Msg(Pos) := Keyset.Keys'Length;
    Pos := Pos + 1;

    -- write keys
    for I in Keyset.Keys'Range loop
      -- retrieve Key to write
      K := Keyset.Keys( I );

      -- write key itself
      Msg(Pos..Pos+K'Length-1) := K;
      -- ensure little endian order in message
      Cast_LE(Msg(Pos..Pos+K'Length-1));
      Pos := Pos + K'Length;

      -- write CRC of key
      Check := CRC32.CRC( K );
      Msg(Pos..Pos+3) := Raw_Types.Cast(Check);
      Cast_LE(Msg(Pos..Pos+3));
      Pos := Pos + 4;
    end loop;

    -- write flag
    Msg(Pos) := Keyset.Flag;
    Pos := Pos + 1;

    -- write message counter
    Write_U16( Msg, Pos, Counter );

    -- write padding as needed; endianness is irrelevant here
    if Pos <= Msg'Last then
      RNG.Get_Octets( Msg(Pos..Msg'Last) );
    end if;

  end Write_SKeys;

  procedure Read_SKeys( Msg     : in Raw_Types.Octets;
                        Counter : out Interfaces.Unsigned_16;
                        Keyset  : out Serpent_Keyset) is
    Pos: Integer := Msg'First;
  begin
    -- read type and check
    if Msg(Pos) = SKeys_S_Type or 
       Msg(Pos) = SKeys_R_Type then
      Pos := Pos + 1;
    else
      raise Invalid_Msg;
    end if;

    -- read count of keys and check
    if Msg(Pos) in Keys_Count'Range then
      declare
        N     : Keys_Count := Keys_Count(Msg(Pos));
        KS    : Serpent_Keyset(N);
        K     : Serpent.Key;
        Check : CRC32.CRC32;
        O4    : Raw_Types.Octets_4; 
      begin
        Pos := Pos + 1;
        --read keys and check crc for each
        for I in 1 .. N loop
          -- read key and advance pos
          K := Msg(Pos..Pos+K'Length-1);
          Cast_LE(K);
          Pos := Pos + K'Length;
          -- read crc and compare to crc32(key)
          O4 := Msg(Pos..Pos+3);
          Cast_LE(O4);
          Check   := Raw_Types.Cast(O4);
          Pos := Pos + 4;
          if Check /= CRC32.CRC(K) then
            raise Invalid_Msg;
          end if;
          -- if it got here, key is fine so add to set
          KS.Keys(KS.Keys'First + I -1) := K;
        end loop;
        -- read and set flag
        KS.Flag := Msg(Pos);
        Pos := Pos + 1;
        -- read and set message counter
        Read_U16( Msg, Pos, Counter );
        -- rest of message is padding so it's ignored
        -- copy keyset to output variable
        Keyset := KS;
      end;
    else
      raise Invalid_Msg;
    end if;
  end Read_SKeys;

  -- writes given key management structure to the given octets array
  procedure Write_KMgm( KMgm    : in Keys_Mgm;
                        Counter : in Interfaces.Unsigned_16;
                        Type_ID : in Interfaces.Unsigned_8;
                        Msg     : out Raw_Types.Octets) is
    Pos   : Integer := Msg'First;
  begin
    -- write given type id
    Msg(Pos) := Type_ID;
    Pos := Pos + 1;

    -- write count of server keys requested
    Msg(Pos) := KMgm.N_Server;
    Pos := Pos + 1;

    -- write count of client keys requested
    Msg(Pos) := KMgm.N_Client;
    Pos := Pos + 1;

    -- write id of key preferred for further inbound Serpent messages
    Msg(Pos) := KMgm.Key_ID;
    Pos := Pos + 1;

    -- write count of burnt keys in this message
    Msg(Pos..Pos) := Cast( KMgm.N_Burnt );
    Pos := Pos + 1;

    -- if there are any burnt keys, write their ids
    if KMgm.N_Burnt > 0 then
      Msg( Pos .. Pos + KMgm.Burnt'Length - 1 ) := KMgm.Burnt;
      Pos := Pos + KMgm.Burnt'Length;
    end if;

    -- write the message count
    Write_U16( Msg, Pos, Counter );

    -- pad with random octets until the end of Msg
    if Pos <= Msg'Last then
      RNG.Get_Octets( Msg(Pos..Msg'Last) );
    end if;

  end Write_KMgm;

  -- attempts to read from the given array of octets a key management structure
  procedure Read_KMgm( Msg     : in Raw_Types.Octets;
                       Counter : out Interfaces.Unsigned_16;
                       KMgm    : out Keys_Mgm) is
    Pos       : Integer := Msg'First;
    Burnt_Pos : Integer := Msg'First + 4;
  begin
    -- read type and check
    if Msg(Pos) = Key_Mgm_S_Type or 
       Msg(Pos) = Key_Mgm_R_Type then
      Pos := Pos + 1;
    else
      raise Invalid_Msg;
    end if;

    -- read the count of burnt keys and check
    -- NB: Burnt_Pos IS in range of Counter_8bits since it's an octet
    declare
      N_Burnt : Counter_8bits := Counter_8bits(Msg(Burnt_Pos));
      Mgm     : Keys_Mgm(N_Burnt);
    begin
      -- read count of server keys requested
      Mgm.N_Server := Msg(Pos);
      Pos := Pos + 1;

      -- read count of client keys requested
      Mgm.N_Client := Msg(Pos);
      Pos := Pos + 1;

      -- read ID of Serpent key preferred for further inbound messages
      Mgm.Key_ID   := Msg(Pos);
      Pos := Pos + 2; --skip the count of burnt keys as it's read already

      -- read ids of burnt keys, if any
      if N_Burnt > 0 then
        Mgm.Burnt := Msg(Pos..Pos+N_Burnt-1);
        Pos := Pos + N_Burnt;
      end if;

      -- read and set message counter
      Read_U16( Msg, Pos, Counter );
      -- rest of message is padding so it's ignored
      -- copy the keys mgm structure to output param
      KMgm := Mgm;
    end;
  end Read_KMgm;

  -- Write a 16 bits value to Octets at Pos; Pos increases by 2.
  procedure Write_U16( Msg: in out Raw_Types.Octets;
                       Pos: in out Natural;
                       U16: in Interfaces.Unsigned_16) is
  begin
    Msg(Pos..Pos+1) := Raw_Types.Cast(U16);
    Cast_LE(Msg(Pos..Pos+1));
    Pos := Pos + 2;
  end Write_U16;

  -- Read a 16-bits values from Octets from Pos; Pos increases by 2.
  procedure Read_U16( Msg: in Raw_Types.Octets;
                      Pos: in out Natural;
                      U16: out Interfaces.Unsigned_16) is
    O2  : Raw_Types.Octets_2;
  begin
    O2  := Msg(Pos..Pos+1);
    Cast_LE(O2);
    U16 := Raw_Types.Cast(O2);
    Pos := Pos + 2;
  end Read_U16;

end Messages;
