(*
 * This file is part of Barista.
 * Copyright (C) 2007-2014 Xavier Clerc.
 *
 * Barista is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Barista is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *)


(* Infix shorthand for [UTF8.(++)]; the lexer uses it to put a character
   literal between double quotes before unescaping it. *)
let (++) left right = UTF8.(++) left right


(* Tokens *)

(* The kinds of token produced by [tokens_of_line]. *)
type token =
  (* Word following a leading dot; the dot itself is not stored. *)
  | Directive of string
  (* Word following a leading at-sign; the at-sign itself is not stored. *)
  | Attribute of string
  (* Label text, stored with its trailing colon. *)
  | Label of UTF8.t
  (* Integer literal, or the code of a single-quoted character literal. *)
  | Int of int64
  (* Numeric literal containing a dot. *)
  | Float of float
  (* Double-quoted string literal, after unescaping. *)
  | String of UTF8.t
  (* Dotted name that matched none of the more specific forms. *)
  | Class_name of Name.for_class
  (* Raw text of a token ending with a closing square bracket. *)
  | Array_type of UTF8.t
  (* Primitive type, or void. *)
  | Primitive_type of Descriptor.java_type
  (* Field reference, written [class.field:type]. *)
  | Field of Name.for_class * Name.for_field * Descriptor.for_field
  (* Method without class part, written [name(params):return]. *)
  | Dynamic_method of Name.for_method * Descriptor.for_method
  (* Method reference, written [class.name(params):return]. *)
  | Method of Name.for_class * Name.for_method * Descriptor.for_method
  (* Same as [Method] when the class part contains an opening square
     bracket, i.e. the receiver is an array type. *)
  | Array_method of Descriptor.array_type * Name.for_method * Descriptor.for_method
  (* Method name and parameter types with no return type, written
     [name(params)]. *)
  | Method_signature of Name.for_method * (Descriptor.for_parameter list)
  (* Parameter and return types with no name, written [(params):return]. *)
  | Method_type of Descriptor.for_method
  (* Method handle, written [kind%target] where [target] is a field or a
     method reference. *)
  | Method_handle of Bootstrap.method_handle
  (* Any other word. *)
  | Identifier of UTF8.t
  (* The two-character sequence [=>]. *)
  | Arrow
  (* The single character [~]. *)
  | Tilde


(* Exception *)

(* Errors reported by the lexer, each paired with the expression building
   its message.
   NOTE(review): [BARISTA_ERROR] is a project syntax extension; judging from
   how this file uses it, it provides the [fail] function and the
   [Exception] exception carrying one of the errors listed here - confirm
   against the definition of the extension. *)
BARISTA_ERROR =
  | Invalid_label of (s : UTF8.t) ->
      Printf.sprintf "invalid label %S"
        (UTF8.to_string_noerr s)
  | Invalid_directive of (s : UTF8.t) ->
      Printf.sprintf "invalid directive %S"
        (UTF8.to_string_noerr s)
  | Invalid_attribute of (s : UTF8.t) ->
      Printf.sprintf "invalid attribute %S"
        (UTF8.to_string_noerr s)
  | Invalid_string of (s : UTF8.t) ->
      Printf.sprintf "invalid string %S"
        (UTF8.to_string_noerr s)
  | Invalid_character of (s : UChar.t) ->
      Printf.sprintf "invalid character %C"
        (UChar.to_char_noerr s)
  | Invalid_float of (s : string) ->
      Printf.sprintf "invalid float constant %S" s
  | Invalid_integer of (s : string) ->
      Printf.sprintf "invalid integer constant %S" s
  | Invalid_method_handle of (s : UTF8.t) ->
      Printf.sprintf "invalid method handle %S"
        (UTF8.to_string_noerr s)
  | Invalid_token ->
      "invalid token"
  | Name_error of (e : Name.error) ->
      Printf.sprintf "invalid name (%s)"
        (Name.string_of_error e)
  | Descriptor_error of (e : Descriptor.error) ->
       Printf.sprintf "invalid descriptor (%s)"
        (Descriptor.string_of_error e)
  | UChar_error of (e : UChar.error) ->
       Printf.sprintf "invalid character (%s)"
        (UChar.string_of_error e)
  | UTF8_error of (e : UTF8.error) ->
       Printf.sprintf "invalid UTF8 (%s)"
        (UTF8.string_of_error e)


(* Lexing functions *)

(* [analyze_token s] classifies the raw text [s] of one token and builds
   the matching [token]. The rules are tried in this order:
   - [s] ends with a colon: label;
   - [s] contains a percent sign: method handle, written [kind%target],
     where [target] is analyzed recursively and must be a field or a
     method compatible with [kind];
   - [s] contains both parentheses: method-like token, whose exact kind
     depends on the presence of a return type and of a class part;
   - [s] contains a colon: field, written [class.field:type];
   - [s] ends with a closing square bracket: array type;
   - [s] contains a dot: class name;
   - otherwise: primitive type (or void) when [s] parses as one, plain
     identifier if not.
   [s] is expected to be non-empty, as its last character is read
   unconditionally.
   Fails with [Invalid_label] or [Invalid_method_handle]. Exceptions raised
   by the [Name], [Descriptor] and [UTF8] helpers are not caught here,
   except during the final primitive-type attempt. *)
let rec analyze_token s =
  let len = UTF8.length s in
  let last = pred len in
  if UChar.equal @':' (UTF8.get s last) then begin
    (* Label: a letter, then identifier characters up to the final colon. *)
    let i = ref 1 in
    while !i < last && UChar.is_identifier_part (UTF8.get s !i) do
      incr i
    done;
    if !i = last && len > 1 && UChar.is_letter (UTF8.get s 0) then
      Label s
    else
      fail (Invalid_label s)
  end else if UTF8.contains @'%' s then begin
    (* Method handle: the text before the first percent sign names the kind
       of handle, the text after it designates the field or method. *)
    let index = UTF8.index_from s 0 @'%' in
    let prefix = UTF8.substring s 0 (pred index) in
    let suffix = UTF8.substring s (succ index) last in
    let sub_token = analyze_token suffix in
    let handle =
      match (UTF8.to_string_noerr prefix), sub_token with
      | "getField", Field (x, y, z) -> `getField (x, y, z)
      | "getStatic", Field (x, y, z) -> `getStatic (x, y, z)
      | "putField", Field (x, y, z) -> `putField (x, y, z)
      | "putStatic", Field (x, y, z) -> `putStatic (x, y, z)
      | "invokeVirtual", Method (x, y, z) -> `invokeVirtual (x, y, z)
      | "invokeStatic", Method (x, y, z) -> `invokeStatic (x, y, z)
      | "invokeSpecial", Method (x, y, z) -> `invokeSpecial (x, y, z)
      | "newInvokeSpecial", Method (x, y, (z, t)) ->
          (* Only accepted for a method named <init> whose written return
             type is the class itself; the handle keeps the class and the
             parameter types only. *)
          if (UTF8.equal (Name.utf8_for_method y) @"<init>")
              && (Descriptor.equal_java_type t (`Class x)) then
            `newInvokeSpecial (x, z)
          else
            fail (Invalid_method_handle s)
      | "invokeInterface", Method (x, y, z) -> `invokeInterface (x, y, z)
      | _ -> fail (Invalid_method_handle s) in
    Method_handle handle
  end else if UTF8.contains @'(' s && UTF8.contains @')' s then begin
    (* Method-like token: [prefix(params)], optionally followed by a colon
       and a return type. *)
    let opening_index = UTF8.index_from s 0 @'(' in
    let closing_index = UTF8.rindex_from s last @')' in
    let prefix = UTF8.substring s 0 (pred opening_index) in
    let params = UTF8.substring s (succ opening_index) (pred closing_index) in
    if (succ closing_index) < len
        && UChar.equal @':' (UTF8.get s (succ closing_index)) then
      let desc = (List.map Descriptor.java_type_of_external_utf8_no_void (UTF8.split @',' params)),
        (Descriptor.java_type_of_external_utf8 (UTF8.substring s (succ (succ closing_index)) last)) in
      if UTF8.contains @'.' prefix then
        (* The last dot of the prefix separates the class (or array type)
           from the method name. *)
        let dot_idx = UTF8.rindex_from prefix (pred (UTF8.length prefix)) @'.' in
        let class_name = UTF8.substring prefix 0 (pred dot_idx) in
        let meth_name = UTF8.substring prefix (succ dot_idx) (pred (UTF8.length prefix)) in
        if UTF8.contains @'[' class_name then
          Array_method ((Descriptor.filter_non_array Descriptor.Invalid_array_element_type (Descriptor.java_type_of_external_utf8 class_name)),
                        (Name.make_for_method meth_name),
                        desc)
        else
          Method ((Name.make_for_class_from_external class_name),
                  (Name.make_for_method meth_name),
                  desc)
      else if (UTF8.length prefix) = 0 then
        Method_type desc
      else
        Dynamic_method ((Name.make_for_method prefix), desc)
    else
      Method_signature ((Name.make_for_method prefix),
                        (List.map Descriptor.java_type_of_external_utf8_no_void (UTF8.split @',' params)))
  end else begin
    if UTF8.contains @':' s then
      (* Field: the first colon separates the qualified field name from its
         type, and the last dot before it separates class and field. *)
      let colon_idx = UTF8.index_from s 0 @':' in
      let prefix = UTF8.substring s 0 (pred colon_idx) in
      let dot_idx = UTF8.rindex_from prefix (pred (UTF8.length prefix)) @'.' in
      Field ((Name.make_for_class_from_external (UTF8.substring prefix 0 (pred dot_idx))),
             (Name.make_for_field (UTF8.substring prefix (succ dot_idx) (pred (UTF8.length prefix)))),
             (Descriptor.java_type_of_external_utf8_no_void (UTF8.substring s (succ colon_idx) last)))
    else if UChar.equal @']' (UTF8.get s last) then
      Array_type s
    else if UTF8.contains @'.' s then
      Class_name (Name.make_for_class_from_external s)
    else
      (* Any failure to parse [s] as a type simply means that it is a plain
         identifier, hence the catch-all handler. *)
      try
        let t = Descriptor.java_type_of_external_utf8 s in
        if (Descriptor.is_primitive t) || (Descriptor.equal_java_type t `Void) then
          Primitive_type t
        else
          Identifier s
      with _ ->
        Identifier s
  end

(* [tokens_of_line l] splits the line [l] into its tokens, returned in
   source order. Whatever [consume_whitespace] skips separates tokens, and
   a [#] found where a token would start discards the rest of the line.
   Every exception raised while reading a token is converted into an error
   of this module: exceptions of [Name], [Descriptor], [UChar] and [UTF8]
   are wrapped into the corresponding [..._error], errors of this module
   are raised unchanged, and anything else becomes [Invalid_token]. *)
let tokens_of_line l =
  let state = new UTF8LexerState.t l in
  (* Appends characters to [buf] until the next space, tab or [#] (which is
     left unread), or until the end of the line. *)
  let add_until_separator buf =
    while state#is_available && not (state#look_ahead_string @" \t#") do
      UTF8Buffer.add_char buf state#consume_char
    done in
  (* Appends characters to [buf] until an unescaped [quote] is the next
     character; that closing quote is left unread. The opening quote must
     already have been consumed. After an escaped backslash, [prev] is reset
     to a neutral character so that a quote following it closes the
     literal. *)
  let add_until_closing buf quote =
    let prev = ref quote in
    while not (state#look_ahead quote && not (UChar.equal !prev @'\\')) do
      let curr = state#consume_char in
      UTF8Buffer.add_char buf curr;
      prev :=
        if (UChar.equal !prev @'\\') && (UChar.equal curr @'\\') then
          @'"'
        else
          curr
    done in
  let read_token () =
    let buf = UTF8Buffer.make () in
    if state#look_ahead @'.' then begin
      (* Directive: a dot followed by a non-empty word. *)
      state#consume;
      add_until_separator buf;
      let dir = UTF8.to_string (UTF8Buffer.contents buf) in
      if (String.length dir) = 0 then
        fail (Invalid_directive (UTF8Buffer.contents buf))
      else
        Directive dir
    end else if state#look_ahead @'@' then begin
      (* Attribute: an at-sign followed by a non-empty word. *)
      state#consume;
      add_until_separator buf;
      let attr = UTF8.to_string (UTF8Buffer.contents buf) in
      if (String.length attr) = 0 then
        fail (Invalid_attribute (UTF8Buffer.contents buf))
      else
        Attribute attr
    end else if state#look_ahead @'"' then begin
      (* String literal: both quotes are kept in [buf], as the text is
         passed with them to [UTF8.unescape]. *)
      UTF8Buffer.add_char buf state#consume_char;
      add_until_closing buf @'"';
      UTF8Buffer.add_char buf state#consume_char;
      if state#is_available && not (state#look_ahead_string @" \t#") then
        fail (Invalid_string (UTF8Buffer.contents buf))
      else
        String (UTF8.unescape (UTF8Buffer.contents buf))
    end else if state#look_ahead @'\'' then begin
      (* Character literal: unescaped like a string, then required to hold
         exactly one character, whose code becomes an [Int] token. *)
      state#consume;
      add_until_closing buf @'\'';
      state#consume;
      if state#is_available && not (state#look_ahead_string @" \t#") then
        fail (Invalid_character state#peek);
      let s = UTF8.unescape ( @"\"" ++ (UTF8Buffer.contents buf) ++ @"\"" ) in
      let nb_chars = UTF8.length s in
      if nb_chars = 1 then
        Int (Int64.of_int (UChar.to_code (UTF8.get s 0)))
      else if nb_chars = 0 then
        (* Empty literal: there is no character to report, so the quote is
           reported instead of reading past the end of [s]. *)
        fail (Invalid_character @'\'')
      else
        fail (Invalid_character (UTF8.get s 0))
    end else if state#look_ahead_string @"+-0123456789" then begin
      (* Number: a float if it contains a dot, an integer otherwise; a
         leading plus sign is dropped before conversion. *)
      add_until_separator buf;
      let n = UTF8.to_string (UTF8Buffer.contents buf) in
      let number = if (String.get n 0) = '+' then String.sub n 1 (pred (String.length n)) else n in
      if String.contains number '.' then
        try
          Float (float_of_string number)
        with Failure _ -> fail (Invalid_float number)
      else
        try
          Int (Int64.of_string number)
        with Failure _ -> fail (Invalid_integer number)
    end else if state#look_ahead @'=' then begin
      state#consume;
      state#consume_only @'>';
      Arrow
    end else if state#look_ahead @'~' then begin
      state#consume;
      Tilde
    end else begin
      (* Any other word: gathered up to a separator or a colon, then handed
         to [analyze_token]. [typed] records that a dot or an opening
         parenthesis was seen, in which case the text following the colon
         (a field type or a return type) belongs to the same token. *)
      let typed = ref false in
      while state#is_available && not (state#look_ahead_string @" \t:#") do
        let ch = state#consume_char in
        if UChar.equal @'.' ch then typed := true;
        UTF8Buffer.add_char buf ch;
        if UChar.equal @'(' ch then begin
          typed := true;
          (* Parameter list: copied up to the closing parenthesis, with
             whitespace removed. *)
          while not (state#look_ahead @')') do
            let ch = state#consume_char in
            if not (UChar.is_whitespace ch) then
              UTF8Buffer.add_char buf ch
          done;
          UTF8Buffer.add_char buf state#consume_char
        end
      done;
      state#consume_whitespace;
      if state#is_available && (state#look_ahead @':') then begin
        UTF8Buffer.add_char buf state#consume_char;
        if !typed then begin
          state#consume_whitespace;
          add_until_separator buf
        end
      end;
      analyze_token (UTF8Buffer.contents buf)
    end in
  let tokens = ref [] in
  while state#is_available do
    state#consume_whitespace;
    if not state#is_available || state#look_ahead @'#' then
      (* End of line or start of a comment: drop whatever remains. *)
      while state#is_available do state#consume done
    else
      let tok =
        try
          read_token ()
        with
        | Name.Exception e -> fail (Name_error e)
        | Descriptor.Exception e -> fail (Descriptor_error e)
        | UChar.Exception e -> fail (UChar_error e)
        | UTF8.Exception e -> fail (UTF8_error e)
        | Exception e -> fail e
        | _ -> fail Invalid_token in
      tokens := tok :: !tokens;
  done;
  List.rev !tokens


(* Miscellaneous *)

(* [equal tok1 tok2] tests whether the two tokens are built with the same
   constructor and carry equal payloads. Payloads are compared with the
   equality of their own module where one is used, and with structural
   equality for strings, integers and floats. Tokens built with different
   constructors are never equal. *)
let equal tok1 tok2 =
  match (tok1, tok2) with
  (* Plain string payloads. *)
  | (Directive a, Directive b)
  | (Attribute a, Attribute b) ->
      a = b
  (* UTF8 payloads. *)
  | (Label a, Label b)
  | (String a, String b)
  | (Array_type a, Array_type b)
  | (Identifier a, Identifier b) ->
      UTF8.equal a b
  | (Int a, Int b) ->
      a = b
  | (Float a, Float b) ->
      a = b
  | (Class_name a, Class_name b) ->
      Name.equal_for_class a b
  | (Primitive_type a, Primitive_type b) ->
      Descriptor.equal_java_type a b
  | (Field (cls_a, fld_a, typ_a), Field (cls_b, fld_b, typ_b)) ->
      Name.equal_for_class cls_a cls_b
        && Name.equal_for_field fld_a fld_b
        && Descriptor.equal_java_type
             (typ_a :> Descriptor.java_type)
             (typ_b :> Descriptor.java_type)
  | (Dynamic_method (name_a, desc_a), Dynamic_method (name_b, desc_b)) ->
      Name.equal_for_method name_a name_b
        && Descriptor.equal_for_method desc_a desc_b
  | (Method (cls_a, name_a, desc_a), Method (cls_b, name_b, desc_b)) ->
      Name.equal_for_class cls_a cls_b
        && Name.equal_for_method name_a name_b
        && Descriptor.equal_for_method desc_a desc_b
  | (Array_method (arr_a, name_a, desc_a), Array_method (arr_b, name_b, desc_b)) ->
      Descriptor.equal_java_type
        (arr_a :> Descriptor.java_type)
        (arr_b :> Descriptor.java_type)
        && Name.equal_for_method name_a name_b
        && Descriptor.equal_for_method desc_a desc_b
  | (Method_signature (name_a, params_a), Method_signature (name_b, params_b)) ->
      Name.equal_for_method name_a name_b
        && Utils.list_equal
             Descriptor.equal_java_type
             (params_a :> Descriptor.java_type list)
             (params_b :> Descriptor.java_type list)
  | (Method_type a, Method_type b) ->
      Descriptor.equal_for_method a b
  | (Method_handle a, Method_handle b) ->
      Bootstrap.equal_method_handle a b
  (* Constant constructors. *)
  | (Arrow, Arrow)
  | (Tilde, Tilde) ->
      true
  | _ ->
      false
