open InterpreterStatements
open Interpreter
module L = Logging
module T = Types
module A = CfimAst
module M = Modules
module SA = SymbolicAst
module Micro = PureMicroPasses
open PureUtils
open TranslateCore

(** The local logger: we reuse the logger defined in [TranslateCore] *)
let log = TranslateCore.log

(** The top-level configuration used by [translate_module] *)
type config = {
  eval_config : Contexts.partial_config;
      (** The configuration forwarded to the symbolic interpreter when
          evaluating the function bodies.
       *)
  mp_config : Micro.config;
      (** The configuration of the micro-passes applied to the pure
          translations (also read to configure the symbolic to pure
          translation).
       *)
  split_files : bool;
      (** Controls whether we split the generated definitions between different
          files for the types, clauses and functions, or if we group them in
          one file.
       *)
  test_unit_functions : bool;
      (** If true, insert tests in the generated files to check that the
          unit functions normalize to `Success _`.

          For instance, in F* it generates code like this:
          ```
          let _ = assert_norm (FUNCTION () = Success ())
          ```
       *)
  extract_decreases_clauses : bool;
      (** If true, insert `decreases` clauses for all the recursive definitions.

          The body of such clauses must be defined by the user.
       *)
  extract_template_decreases_clauses : bool;
      (** In order to help the user, we can generate "template" decreases
          clauses (i.e., definitions with proper signatures but dummy bodies)
          in a dedicated file.

          Note that when [split_files] is true, the templates are only
          generated if [extract_decreases_clauses] is also true.
       *)
}

type symbolic_fun_translation = V.symbolic_value list * SA.expression
(** The result of running the symbolic interpreter on a function, as a pair:
    - first component: the list of symbolic values used for the input values
    - second component: the generated symbolic AST
*)

(** Run the symbolic interpreter on [fdef] to generate its symbolic ASTs.

    Returns a pair made of:
    - the symbolic translation of the forward function
    - the symbolic translations of the backward functions, one per entry of
      [fdef.signature.regions_hierarchy]

    The interpreter is run in synthesis mode: we expect it to always return
    a symbolic AST (we use [Option.get] on its result).
*)
let translate_function_to_symbolics (config : C.partial_config)
    (trans_ctx : trans_ctx) (fdef : A.fun_def) :
    symbolic_fun_translation * symbolic_fun_translation list =
  let () =
    log#ldebug
      (lazy
        ("translate_function_to_symbolics: " ^ Print.name_to_string fdef.A.name))
  in

  let { type_context; fun_context } = trans_ctx in

  (* Run the interpreter for one function: [None] for the forward function,
   * [Some gid] for the backward function of region group [gid] *)
  let run_interpreter back_id =
    let synthesize = true in
    let input_svs, ast_opt =
      evaluate_function_symbolic config synthesize type_context fun_context fdef
        back_id
    in
    (input_svs, Option.get ast_opt)
  in

  (* The forward function comes first, then the backward functions *)
  let forward = run_interpreter None in
  let backwards =
    fdef.signature.regions_hierarchy
    |> T.RegionGroupId.mapi (fun gid _ -> run_interpreter (Some gid))
  in
  (forward, backwards)

(** Translate a function to the pure AST, by generating its forward
    translation and one backward translation per region group.

    [fun_sigs]: maps the forward/backward functions to their signatures. In case
    of backward functions, it also provides names for the outputs.
    [pure_type_defs]: the pure type definitions, indexed by identifier.
    TODO: maybe we should introduce a record for this.
*)
let translate_function_to_pure (config : C.partial_config)
    (mp_config : Micro.config) (trans_ctx : trans_ctx)
    (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t)
    (pure_type_defs : Pure.type_def Pure.TypeDefId.Map.t) (fdef : A.fun_def) :
    pure_fun_translation =
  let () =
    log#ldebug
      (lazy ("translate_function_to_pure: " ^ Print.name_to_string fdef.A.name))
  in

  let { type_context; fun_context } = trans_ctx in
  let def_id = fdef.def_id in

  (* Step 1: run the symbolic interpreter *)
  let (fwd_input_svs, fwd_symbolic), symbolic_backwards =
    translate_function_to_symbolics config trans_ctx fdef
  in

  (* Step 2: build the initial context for the symbolic to pure translation *)
  let fwd_sig = RegularFunIdMap.find (A.Local def_id, None) fun_sigs in
  let forward_ret_ty =
    (* The forward function has exactly one output *)
    match fwd_sig.sg.outputs with [ ty ] -> ty | _ -> failwith "Unreachable"
  in
  let pure_type_ctx =
    {
      SymbolicToPure.types_infos = type_context.type_infos;
      cfim_type_defs = type_context.type_defs;
      type_defs = pure_type_defs;
    }
  in
  let pure_fun_ctx =
    { SymbolicToPure.cfim_fun_defs = fun_context.fun_defs; fun_sigs }
  in
  let init_ctx =
    {
      (* Dummy for now *)
      SymbolicToPure.bid = None;
      (* Updated below for the backward functions *)
      ret_ty = forward_ret_ty;
      sv_to_var = V.SymbolicValueId.Map.empty;
      var_counter = Pure.VarId.generator_zero;
      type_context = pure_type_ctx;
      fun_context = pure_fun_ctx;
      fun_def = fdef;
      (* The inputs and outputs are empty for now: they are filled below *)
      forward_inputs = [];
      backward_inputs = T.RegionGroupId.Map.empty;
      backward_outputs = T.RegionGroupId.Map.empty;
      calls = V.FunCallId.Map.empty;
      abstractions = V.AbstractionId.Map.empty;
    }
  in

  (* Step 3: prepare the initialization of the forward input variables, which
   * reuse the names of the input variables of the original function *)
  let forward_input_varnames =
    CfimAstUtils.fun_def_get_input_vars fdef
    |> List.map (fun (v : A.var) -> v.name)
  in
  let num_forward_inputs = fdef.arg_count in
  let with_forward_inputs input_svs ctx =
    let named_svs = List.combine forward_input_varnames input_svs in
    let ctx, forward_inputs =
      SymbolicToPure.fresh_named_vars_for_symbolic_values named_svs ctx
    in
    { ctx with forward_inputs }
  in

  (* The configuration of the symbolic to pure translation *)
  let sp_config =
    {
      SymbolicToPure.filter_useless_back_calls =
        mp_config.filter_useless_monadic_calls;
    }
  in

  (* Step 4: the forward function *)
  let pure_forward =
    let fwd_ctx = with_forward_inputs fwd_input_svs init_ctx in
    SymbolicToPure.translate_fun_def sp_config fwd_ctx fwd_symbolic
  in

  (* Step 5: the backward functions (one per region group) *)
  let translate_backward (rg : T.region_var_group) : Pure.fun_def =
    (* We rely on the fact that there are no nested borrows for now: the
     * region groups can't have parents *)
    assert (rg.parents = []);
    let back_id = rg.id in
    let input_svs, symbolic = T.RegionGroupId.nth symbolic_backwards back_id in
    let ctx = with_forward_inputs input_svs init_ctx in
    let backward_sg =
      RegularFunIdMap.find (A.Local def_id, Some back_id) fun_sigs
    in
    (* The inputs of a backward function are the forward inputs followed by
     * additional inputs: we drop the forward ones.
     * TODO: this computation is a bit awkward... *)
    let _, extra_input_tys =
      Collections.List.split_at backward_sg.sg.inputs num_forward_inputs
    in
    (* As nested borrows are forbidden, the additional inputs come from the
     * borrows in the return value of the rust function: we name them "ret" *)
    let ctx, back_inputs =
      SymbolicToPure.fresh_vars
        (List.map (fun ty -> (Some "ret", ty)) extra_input_tys)
        ctx
    in
    (* The outputs come from the borrows present in the input values of the
     * rust function: we reuse the names of those input values *)
    let ctx, back_outputs =
      SymbolicToPure.fresh_vars
        (List.combine backward_sg.output_names backward_sg.sg.outputs)
        ctx
    in
    let backward_ret_ty =
      back_outputs
      |> List.map (fun (v : Pure.var) -> v.ty)
      |> mk_simpl_tuple_ty
    in
    let backward_inputs = T.RegionGroupId.Map.singleton back_id back_inputs in
    let backward_outputs =
      T.RegionGroupId.Map.singleton back_id back_outputs
    in
    let back_ctx =
      {
        ctx with
        bid = Some back_id;
        ret_ty = backward_ret_ty;
        backward_inputs;
        backward_outputs;
      }
    in
    SymbolicToPure.translate_fun_def sp_config back_ctx symbolic
  in
  let pure_backwards =
    fdef.signature.regions_hierarchy |> List.map translate_backward
  in

  (pure_forward, pure_backwards)

(** Translate a whole module to the pure AST.

    Returns a triple made of:
    - the translation context (type and function contexts) computed from [m]
    - the translated type definitions
    - the translated functions, after application of the micro-passes
 *)
let translate_module_to_pure (config : C.partial_config)
    (mp_config : Micro.config) (m : M.cfim_module) :
    trans_ctx * Pure.type_def list * (bool * pure_fun_translation) list =
  let () = log#ldebug (lazy "translate_module_to_pure") in

  (* The type and function contexts *)
  let type_context, fun_context = compute_type_fun_contexts m in
  let trans_ctx = { type_context; fun_context } in

  (* The type definitions, as a list and as a map indexed by identifier *)
  let type_defs = SymbolicToPure.translate_type_defs m.types in
  let type_defs_map =
    type_defs
    |> List.map (fun (def : Pure.type_def) -> (def.def_id, def))
    |> Pure.TypeDefId.Map.of_list
  in

  (* The function *signatures*: we pair every signature with the (optional)
   * names of its inputs. The assumed functions have no input names. *)
  let sig_of_assumed (id, sg, _, _) =
    (A.Assumed id, List.map (fun _ -> None) (sg : A.fun_sig).inputs, sg)
  in
  let sig_of_local (fdef : A.fun_def) =
    let input_names =
      CfimAstUtils.fun_def_get_input_vars fdef
      |> List.map (fun (v : A.var) -> v.name)
    in
    (A.Local fdef.def_id, input_names, fdef.signature)
  in
  let assumed_sigs = List.map sig_of_assumed Assumed.assumed_infos in
  let local_sigs = List.map sig_of_local m.functions in
  let fun_sigs =
    SymbolicToPure.translate_fun_signatures type_context.type_infos
      (assumed_sigs @ local_sigs)
  in

  (* Translate all the functions, then apply the micro-passes *)
  let pure_translations =
    m.functions
    |> List.map
         (translate_function_to_pure config mp_config trans_ctx fun_sigs
            type_defs_map)
    |> List.map (Micro.apply_passes_to_pure_fun_translation mp_config trans_ctx)
  in

  (trans_ctx, type_defs, pure_translations)

type gen_ctx = {
  m : M.cfim_module;
      (** The module being extracted: we iterate over its declaration groups *)
  extract_ctx : PureToExtract.extraction_ctx;
      (** The context given to the [ExtractToFStar] extraction functions *)
  trans_types : Pure.type_def Pure.TypeDefId.Map.t;
      (** The translated type definitions, indexed by identifier *)
  trans_funs : (bool * pure_fun_translation) Pure.FunDefId.Map.t;
      (** The translated functions, indexed by identifier. The boolean controls
          whether the forward function is kept when extracting the
          definitions.
       *)
  functions_with_decreases_clause : Pure.FunDefId.Set.t;
      (** The identifiers of the functions which have a decreases clause *)
}
(** Extraction context *)

(** Controls what [extract_definitions] prints *)
type gen_config = {
  extract_types : bool;  (** If true, extract the type definitions *)
  extract_decreases_clauses : bool;
      (** If true, the functions registered in
          [gen_ctx.functions_with_decreases_clause] are extracted with a
          decreases clause.
       *)
  extract_template_decreases_clauses : bool;
      (** If true, extract the template decreases clauses for the functions
          registered in [gen_ctx.functions_with_decreases_clause].
       *)
  extract_fun_defs : bool;  (** If true, extract the function definitions *)
  test_unit_functions : bool;
      (** If true, call [ExtractToFStar.extract_unit_test_if_unit_fun] on the
          forward functions which are kept.
       *)
}

(** A generic utility to generate the extracted definitions: as we may want to
    split the definitions between different files (or not), [config] controls
    what is precisely extracted.

    We visit the declaration groups of [ctx.m] in order and, depending on
    [config], print to [fmt] the type definitions, the template decreases
    clauses, the function definitions and the unit tests.

    The definitions are looked up in [ctx.trans_types] and [ctx.trans_funs]
    with [Map.find]: the lookup fails if a declaration refers to a definition
    which is missing from those maps.
 *)
let extract_definitions (fmt : Format.formatter) (config : gen_config)
    (ctx : gen_ctx) : unit =
  (* Export a type definition, introduced with the given qualifier *)
  let export_type (qualif : ExtractToFStar.type_def_qualif)
      (id : Pure.TypeDefId.id) : unit =
    let def = Pure.TypeDefId.Map.find id ctx.trans_types in
    ExtractToFStar.extract_type_def ctx.extract_ctx fmt qualif def
  in

  (* Utility to check a function has a decrease clause *)
  let has_decreases_clause (def : Pure.fun_def) : bool =
    Pure.FunDefId.Set.mem def.def_id ctx.functions_with_decreases_clause
  in

  (* Export a group of functions.
   *
   * [is_rec]: true if the group comes from a recursive declaration.
   * [pure_ls]: for every function of the group, a boolean telling whether we
   * keep the forward function, together with the forward and backward
   * translations.
   *
   * In case of (non-mutually) recursive functions, we use a simple procedure to
   * check if the forward and backward functions are mutually recursive.
   *)
  let export_functions (is_rec : bool)
      (pure_ls : (bool * pure_fun_translation) list) : unit =
    (* Concatenate the function definitions, filtering the useless forward
     * functions. We also make pairs: (forward function, definition to extract)
     * (the forward function contains useful information that we want to keep) *)
    let fls =
      List.concat
        (List.map
           (fun (keep_fwd, (fwd, back_ls)) ->
             let back_ls = List.map (fun back -> (fwd, back)) back_ls in
             if keep_fwd then (fwd, fwd) :: back_ls else back_ls)
           pure_ls)
    in
    (* Extract the decrease clauses template bodies *)
    if config.extract_template_decreases_clauses then
      List.iter
        (fun (_, (fwd, _)) ->
          if has_decreases_clause fwd then
            ExtractToFStar.extract_template_decreases_clause ctx.extract_ctx fmt
              fwd)
        pure_ls;
    (* Extract the function definitions *)
    (if config.extract_fun_defs then
       (* Check if the functions are mutually recursive. A group with several
        * functions is always considered as mutually recursive. For a single
        * recursive function, we check if its forward and backward
        * translations are mutually recursive.
        *
        * Note that the check must be performed on the definitions we actually
        * extract, that is the *second* components of the pairs in [fls]: the
        * first components are only the forward function, repeated once per
        * extracted definition. *)
       let is_mut_rec =
         is_rec
         &&
         match pure_ls with
         | [] | [ _ ] ->
             not
               (PureUtils.functions_not_mutually_recursive (List.map snd fls))
         | _ -> true
       in
       List.iteri
         (fun i (fwd_def, def) ->
           (* In a mutually recursive group, only the first definition is
            * introduced with [let rec]: the others use [and] *)
           let qualif =
             if not is_rec then ExtractToFStar.Let
             else if is_mut_rec then
               if i = 0 then ExtractToFStar.LetRec else ExtractToFStar.And
             else ExtractToFStar.LetRec
           in
           let has_decr_clause =
             has_decreases_clause def && config.extract_decreases_clauses
           in
           ExtractToFStar.extract_fun_def ctx.extract_ctx fmt qualif
             has_decr_clause fwd_def def)
         fls);
    (* Insert unit tests if necessary: only for the forward functions we keep *)
    if config.test_unit_functions then
      List.iter
        (fun (keep_fwd, (fwd, _)) ->
          if keep_fwd then
            ExtractToFStar.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd)
        pure_ls
  in

  (* Export a declaration group *)
  let export_decl (decl : M.declaration_group) : unit =
    match decl with
    | Type (NonRec id) ->
        if config.extract_types then export_type ExtractToFStar.Type id
    | Type (Rec ids) ->
        (* The first type of the group is introduced with [type], the others
         * with [and] *)
        if config.extract_types then
          List.iteri
            (fun i id ->
              let qualif =
                if i = 0 then ExtractToFStar.Type else ExtractToFStar.And
              in
              export_type qualif id)
            ids
    | Fun (NonRec id) ->
        (* Lookup *)
        let pure_fun = Pure.FunDefId.Map.find id ctx.trans_funs in
        (* Translate *)
        export_functions false [ pure_fun ]
    | Fun (Rec ids) ->
        (* General case of mutually recursive functions *)
        (* Lookup *)
        let pure_funs =
          List.map (fun id -> Pure.FunDefId.Map.find id ctx.trans_funs) ids
        in
        (* Translate *)
        export_functions true pure_funs
  in

  (* Export the definition groups to the file, in the proper order *)
  List.iter export_decl ctx.m.declarations

(** Generate one F* file.

    [config], [ctx]: control what is extracted (see [extract_definitions]).
    [filename]: the path of the file to generate (created or truncated).
    [rust_module_name], [custom_msg]: printed in the header comment.
    [module_name]: the name of the generated F* module.
    [custom_imports]: the modules to [open] (in addition to [Primitives]).
    [custom_includes]: the modules to [include].

    The output channel is always closed before we return, even if the
    extraction raises an exception.
    Raises [Sys_error] if the file can't be opened or written.
 *)
let extract_file (config : gen_config) (ctx : gen_ctx) (filename : string)
    (rust_module_name : string) (module_name : string) (custom_msg : string)
    (custom_imports : string list) (custom_includes : string list) : unit =
  (* Open the file *)
  let out = open_out filename in
  Fun.protect
    (* Don't leak the file descriptor if something goes wrong. If the channel
     * was already closed below, this is a no-op. *)
    ~finally:(fun () -> close_out_noerr out)
    (fun () ->
      (* Create the formatter *)
      let fmt = Format.formatter_of_out_channel out in

      (* Print the headers.
       * Note that we don't use the OCaml formatter on purpose: we want to
       * control line insertion (we have to make sure that some instructions
       * like `open MODULE` are printed on one line!).
       * This is ok as long as we end up with a line break, so that the
       * formatter's internal count is consistent with the state of the file.
       *)
      (* Create the header *)
      Printf.fprintf out
        "(** THIS FILE WAS AUTOMATICALLY GENERATED BY AENEAS *)\n";
      Printf.fprintf out "(** [%s]%s *)\n" rust_module_name custom_msg;
      Printf.fprintf out "module %s\n" module_name;
      Printf.fprintf out "open Primitives\n";
      (* Add the custom imports *)
      List.iter (fun m -> Printf.fprintf out "open %s\n" m) custom_imports;
      (* Add the custom includes *)
      List.iter (fun m -> Printf.fprintf out "include %s\n" m) custom_includes;
      (* Z3 options *)
      Printf.fprintf out
        "\n#set-options \"--z3rlimit 50 --fuel 0 --ifuel 1\"\n";

      (* From now onwards, we use the formatter *)
      (* Set the margin *)
      Format.pp_set_margin fmt 80;

      (* Create a vertical box *)
      Format.pp_open_vbox fmt 0;

      (* Extract the definitions *)
      extract_definitions fmt config ctx;

      (* Close the box and end the formatting (this flushes the formatter) *)
      Format.pp_close_box fmt ();
      Format.pp_print_newline fmt ();

      (* Close the file: we do it here (and not only in [finally]) so that
       * write errors are reported instead of being silently ignored *)
      close_out out);

  (* Some logging *)
  log#linfo (lazy ("Generated: " ^ filename))

(** Translate a module and write the synthesized code to one or several
    output files (for now we only extract to F* ).

    [filename]: the name of the input file. It must end with ".cfim": we use
    its basename (converted to camel case) as the name of the generated module.
    [dest_dir]: the directory in which to generate the file(s).
    [config]: the translation and extraction options.
    [m]: the module to translate.

    If [config.split_files] is true we generate `MODULE.Types.fst`,
    `MODULE.Funs.fst` and, if the options request it,
    `MODULE.Clauses.Template.fst`. Otherwise we generate a single `MODULE.fst`.
 *)
let translate_module (filename : string) (dest_dir : string) (config : config)
    (m : M.cfim_module) : unit =
  (* Translate the module to the pure AST *)
  let trans_ctx, trans_types, trans_funs =
    translate_module_to_pure config.eval_config config.mp_config m
  in

  (* Initialize the extraction context - for now we extract only to F* *)
  let names_map =
    PureToExtract.initialize_names_map ExtractToFStar.fstar_names_map_init
  in
  let variant_concatenate_type_name = true in
  let fstar_fmt =
    ExtractToFStar.mk_formatter trans_ctx variant_concatenate_type_name
  in
  let ctx =
    { PureToExtract.trans_ctx; names_map; fmt = fstar_fmt; indent_incr = 2 }
  in

  (* We need to compute which functions are recursive, in order to know
   * whether we should generate a decrease clause or not. *)
  let rec_functions =
    Pure.FunDefId.Set.of_list
      (List.concat
         (List.map
            (fun decl -> match decl with M.Fun (Rec ids) -> ids | _ -> [])
            m.declarations))
  in

  (* Register unique names for all the top-level types and functions.
   * Note that the order in which we generate the names doesn't matter:
   * we just need to generate a mapping from identifier to name, and make
   * sure there are no name clashes. *)
  let ctx =
    List.fold_left
      (fun ctx def -> ExtractToFStar.extract_type_def_register_names ctx def)
      ctx trans_types
  in

  let ctx =
    List.fold_left
      (fun ctx (keep_fwd, def) ->
        (* Note that we register the names of the decrease clauses for all the
         * recursive functions, independently of the configuration *)
        let gen_decr_clause =
          Pure.FunDefId.Set.mem (fst def).Pure.def_id rec_functions
        in
        ExtractToFStar.extract_fun_def_register_names ctx keep_fwd
          gen_decr_clause def)
      ctx trans_funs
  in

  (* Compute the module name and the basename of the output file(s), by
   * removing the extension and converting the case (rust module names are
   * snake case) *)
  let module_name, extract_filebasename =
    match Filename.chop_suffix_opt ~suffix:".cfim" filename with
    | None ->
        (* Note that we already checked the suffix upon opening the file *)
        failwith "Unreachable"
    | Some filename ->
        (* Retrieve the file basename *)
        let basename = Filename.basename filename in
        (* Convert the case *)
        let module_name = StringUtils.to_camel_case basename in
        (* Concatenate *)
        (module_name, Filename.concat dest_dir module_name)
  in

  (* Put the translated definitions in maps *)
  let trans_types =
    Pure.TypeDefId.Map.of_list
      (List.map (fun (d : Pure.type_def) -> (d.def_id, d)) trans_types)
  in
  let trans_funs =
    Pure.FunDefId.Map.of_list
      (List.map
         (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) ->
           (fd.def_id, (keep_fwd, (fd, bdl))))
         trans_funs)
  in

  (* The context shared by all the generated files *)
  let gen_ctx =
    {
      m;
      extract_ctx = ctx;
      trans_types;
      trans_funs;
      functions_with_decreases_clause = rec_functions;
    }
  in

  (* Extract one or several files, depending on the configuration *)
  if config.split_files then (
    (* By default we extract nothing: every file enables what it needs *)
    let base_gen_config =
      {
        extract_types = false;
        extract_decreases_clauses = config.extract_decreases_clauses;
        extract_template_decreases_clauses = false;
        extract_fun_defs = false;
        test_unit_functions = false;
      }
    in

    (* Extract the types *)
    let types_filename = extract_filebasename ^ ".Types.fst" in
    let types_module = module_name ^ ".Types" in
    let types_config = { base_gen_config with extract_types = true } in
    extract_file types_config gen_ctx types_filename m.M.name types_module
      ": type definitions" [] [];

    (* Extract the template clauses *)
    (if
       config.extract_decreases_clauses
       && config.extract_template_decreases_clauses
     then
       let clauses_filename = extract_filebasename ^ ".Clauses.Template.fst" in
       let clauses_module = module_name ^ ".Clauses.Template" in
       let clauses_config =
         { base_gen_config with extract_template_decreases_clauses = true }
       in
       extract_file clauses_config gen_ctx clauses_filename m.M.name
         clauses_module ": templates for the decreases clauses"
         [ types_module ] []);

    (* Extract the functions *)
    let fun_filename = extract_filebasename ^ ".Funs.fst" in
    let fun_module = module_name ^ ".Funs" in
    let fun_config =
      {
        base_gen_config with
        extract_fun_defs = true;
        test_unit_functions = config.test_unit_functions;
      }
    in
    (* We include the clauses module only if we extract decreases clauses:
     * otherwise the function definitions don't refer to it, and this module
     * (which must be written by the user) may not exist. *)
    let clauses_modules =
      if config.extract_decreases_clauses then [ module_name ^ ".Clauses" ]
      else []
    in
    extract_file fun_config gen_ctx fun_filename m.M.name fun_module
      ": function definitions" []
      (types_module :: clauses_modules))
  else
    let gen_config =
      {
        extract_types = true;
        extract_decreases_clauses = config.extract_decreases_clauses;
        extract_template_decreases_clauses =
          config.extract_template_decreases_clauses;
        extract_fun_defs = true;
        test_unit_functions = config.test_unit_functions;
      }
    in
    (* Add the extension for F* *)
    let extract_filename = extract_filebasename ^ ".fst" in
    extract_file gen_config gen_ctx extract_filename m.M.name module_name "" []
      []