summaryrefslogtreecommitdiff
path: root/compiler/Extract.ml
diff options
context:
space:
mode:
authorSon HO2023-12-23 01:46:58 +0100
committerGitHub2023-12-23 01:46:58 +0100
commit15a7d7b7322a1cd0ebeb328fde214060e23fa8b4 (patch)
tree6cce7d76969870f5bc18c5a7cd585e8873a1c0dc /compiler/Extract.ml
parentc3e0b90e422cbd902ee6d2b47073940c0017b7fb (diff)
parent63ccbd914d5d44aa30dee38a6fcc019310ab640b (diff)
Merge pull request #64 from AeneasVerif/son/merge_back
Merge the forward/backward functions
Diffstat (limited to 'compiler/Extract.ml')
-rw-r--r--compiler/Extract.ml157
1 files changed, 93 insertions, 64 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 20cdb20b..87dcb1fd 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -43,8 +43,12 @@ let extract_fun_decl_register_names (ctx : extraction_ctx)
}
| _ -> ctx
in
- let backs = List.map (fun f -> f.f) def.backs in
- let funs = if def.keep_fwd then def.fwd.f :: backs else backs in
+ let funs =
+ if !Config.return_back_funs then [ def.fwd.f ]
+ else
+ let backs = List.map (fun f -> f.f) def.backs in
+ if def.keep_fwd then def.fwd.f :: backs else backs
+ in
List.fold_left
(fun ctx (f : fun_decl) ->
let open ExtractBuiltin in
@@ -128,9 +132,15 @@ let extract_adt_g_value
F.pp_print_string fmt "tt";
ctx)
else
- (* If there is exactly one value, we don't print the parentheses *)
+ (* If there is exactly one value, we don't print the parentheses.
+ Also, for Coq, we need the special syntax ['(...)] if we destruct
+ a tuple pattern in a let-binding and the tuple has > 2 values.
+ *)
let lb, rb =
- if List.length field_values = 1 then ("", "") else ("(", ")")
+ if List.length field_values = 1 then ("", "")
+ else if !backend = Coq && is_single_pat && List.length field_values > 2
+ then ("'(", ")")
+ else ("(", ")")
in
F.pp_print_string fmt lb;
let ctx =
@@ -237,30 +247,60 @@ let fun_builtin_filter_types (id : FunDeclId.id) (types : 'a list)
Result.Ok types
(** [inside]: see {!extract_ty}.
+ [with_type]: do we also generate a type annotation? This is necessary for
+ backends like Coq when we write lambdas (Coq is not powerful enough to
+ infer the type).
As a pattern can introduce new variables, we return an extraction context
updated with new bindings.
*)
let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter)
- (is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx =
- match v.value with
- | PatConstant cv ->
- extract_literal fmt inside cv;
- ctx
- | PatVar (v, _) ->
- let vname = ctx_compute_var_basename ctx v.basename v.ty in
- let ctx, vname = ctx_add_var vname v.id ctx in
- F.pp_print_string fmt vname;
- ctx
- | PatDummy ->
- F.pp_print_string fmt "_";
- ctx
- | PatAdt av ->
- let extract_value ctx inside v =
- extract_typed_pattern ctx fmt is_let inside v
- in
- extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id
- av.field_values v.ty
+ (is_let : bool) (inside : bool) ?(with_type = false) (v : typed_pattern) :
+ extraction_ctx =
+ if with_type then F.pp_print_string fmt "(";
+ let inside = inside && not with_type in
+ let ctx =
+ match v.value with
+ | PatConstant cv ->
+ extract_literal fmt inside cv;
+ ctx
+ | PatVar (v, _) ->
+ let vname = ctx_compute_var_basename ctx v.basename v.ty in
+ let ctx, vname = ctx_add_var vname v.id ctx in
+ F.pp_print_string fmt vname;
+ ctx
+ | PatDummy ->
+ F.pp_print_string fmt "_";
+ ctx
+ | PatAdt av ->
+ let extract_value ctx inside v =
+ extract_typed_pattern ctx fmt is_let inside v
+ in
+ extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id
+ av.field_values v.ty
+ in
+ if with_type then (
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt TypeDeclId.Set.empty false v.ty;
+ F.pp_print_string fmt ")");
+ ctx
+
+(** Return true if we need to wrap a succession of let-bindings in a [do ...]
+ block (because some of them are monadic) *)
+let lets_require_wrap_in_do (lets : (bool * typed_pattern * texpression) list) :
+ bool =
+ match !backend with
+ | Lean ->
+ (* For Lean, we wrap in a block iff at least one of the let-bindings is monadic *)
+ List.exists (fun (m, _, _) -> m) lets
+ | HOL4 ->
+ (* HOL4 is similar to HOL4, but we add a sanity check *)
+ let wrap_in_do = List.exists (fun (m, _, _) -> m) lets in
+ if wrap_in_do then assert (List.for_all (fun (m, _, _) -> m) lets);
+ wrap_in_do
+ | FStar | Coq -> false
(** [inside]: controls the introduction of parentheses. See [extract_ty]
@@ -285,9 +325,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
| App _ ->
let app, args = destruct_apps e in
extract_App ctx fmt inside app args
- | Abs _ ->
- let xl, e = destruct_abs_list e in
- extract_Abs ctx fmt inside xl e
+ | Lambda _ ->
+ let xl, e = destruct_lambdas e in
+ extract_Lambda ctx fmt inside xl e
| Qualif _ ->
(* We use the app case *)
extract_App ctx fmt inside e []
@@ -574,7 +614,7 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter)
(* No argument: shouldn't happen *)
raise (Failure "Unreachable")
-and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
+and extract_Lambda (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(xl : typed_pattern list) (e : texpression) : unit =
(* Open a box for the abs expression *)
F.pp_open_hovbox fmt ctx.indent_incr;
@@ -583,15 +623,16 @@ and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(* Print the lambda - note that there should always be at least one variable *)
assert (xl <> []);
F.pp_print_string fmt "fun";
+ let with_type = !backend = Coq in
let ctx =
List.fold_left
(fun ctx x ->
F.pp_print_space fmt ();
- extract_typed_pattern ctx fmt true true x)
+ extract_typed_pattern ctx fmt true true ~with_type x)
ctx xl
in
F.pp_print_space fmt ();
- if !backend = Lean then F.pp_print_string fmt "=>"
+ if !backend = Lean || !backend = Coq then F.pp_print_string fmt "=>"
else F.pp_print_string fmt "->";
F.pp_print_space fmt ();
(* Print the body *)
@@ -630,15 +671,6 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
| HOL4 -> destruct_lets_no_interleave e
| FStar | Coq | Lean -> destruct_lets e
in
- (* Open a box for the whole expression.
-
- In the case of Lean, we use a vbox so that line breaks are inserted
- at the end of every let-binding: let-bindings are indeed not ended
- with an "in" keyword.
- *)
- if !Config.backend = Lean then F.pp_open_vbox fmt 0 else F.pp_open_hvbox fmt 0;
- (* Open parentheses *)
- if inside && !backend <> Lean then F.pp_print_string fmt "(";
(* Extract the let-bindings *)
let extract_let (ctx : extraction_ctx) (monadic : bool) (lv : typed_pattern)
(re : texpression) : extraction_ctx =
@@ -711,22 +743,19 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(* Return *)
ctx
in
+ (* Open a box for the whole expression.
+
+ In the case of Lean, we use a vbox so that line breaks are inserted
+ at the end of every let-binding: let-bindings are indeed not ended
+ with an "in" keyword.
+ *)
+ if !Config.backend = Lean then F.pp_open_vbox fmt 0 else F.pp_open_hvbox fmt 0;
+ (* Open parentheses *)
+ if inside && !backend <> Lean then F.pp_print_string fmt "(";
(* If Lean and HOL4, we rely on monadic blocks, so we insert a do and open a new box
immediately *)
- let wrap_in_do_od =
- match !backend with
- | Lean ->
- (* For Lean, we wrap in a block iff at least one of the let-bindings is monadic *)
- List.exists (fun (m, _, _) -> m) lets
- | HOL4 ->
- (* HOL4 is similar to HOL4, but we add a sanity check *)
- let wrap_in_do = List.exists (fun (m, _, _) -> m) lets in
- if wrap_in_do then assert (List.for_all (fun (m, _, _) -> m) lets);
- wrap_in_do
- | FStar | Coq -> false
- in
+ let wrap_in_do_od = lets_require_wrap_in_do lets in
if wrap_in_do_od then (
- F.pp_open_vbox fmt (if !backend = Lean then ctx.indent_incr else 0);
F.pp_print_string fmt "do";
F.pp_print_space fmt ());
let ctx =
@@ -742,11 +771,10 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
F.pp_close_box fmt ();
(* do-box (Lean and HOL4 only) *)
- if wrap_in_do_od then (
+ if wrap_in_do_od then
if !backend = HOL4 then (
F.pp_print_space fmt ();
F.pp_print_string fmt "od");
- F.pp_close_box fmt ());
(* Close parentheses *)
if inside && !backend <> Lean then F.pp_print_string fmt ")";
(* Close the box for the whole expression *)
@@ -1319,16 +1347,16 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter)
(Pure.FunId (FRegular def.def_id), def.loop_id, def.back_id)
ctx.fun_name_info
in
- let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]: " in
+ let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]:" in
let comment =
let loop_comment =
match def.loop_id with
| None -> ""
- | Some id -> "loop " ^ LoopId.to_string id ^ ": "
+ | Some id -> " loop " ^ LoopId.to_string id ^ ":"
in
let fwd_back_comment =
match def.back_id with
- | None -> [ "forward function" ]
+ | None -> if !Config.return_back_funs then [] else [ "forward function" ]
| Some id ->
(* Check if there is only one backward function, and no forward function *)
if (not keep_fwd) && num_backs = 1 then
@@ -1340,9 +1368,9 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter)
else [ "backward function " ^ T.RegionGroupId.to_string id ]
in
match fwd_back_comment with
- | [] -> raise (Failure "Unreachable")
- | [ s ] -> [ comment_pre ^ loop_comment ^ s ]
- | s :: sl -> (comment_pre ^ loop_comment ^ s) :: sl
+ | [] -> [ comment_pre ^ loop_comment ]
+ | [ s ] -> [ comment_pre ^ loop_comment ^ " " ^ s ]
+ | s :: sl -> (comment_pre ^ loop_comment ^ " " ^ s) :: sl
in
extract_comment_with_span fmt comment def.meta.span
@@ -1470,7 +1498,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
let inputs_lvs =
let all_inputs = (Option.get def.body).inputs_lvs in
let num_fwd_inputs =
- def.signature.info.num_fwd_inputs_with_fuel_with_state
+ def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state
in
Collections.List.prefix num_fwd_inputs all_inputs
in
@@ -1516,7 +1544,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
let def_body = Option.get def.body in
let all_vars = List.map (fun (v : var) -> v.id) def_body.inputs in
let num_fwd_inputs =
- def.signature.info.num_fwd_inputs_with_fuel_with_state
+ def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state
in
let vars = Collections.List.prefix num_fwd_inputs all_vars in
@@ -1794,7 +1822,6 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
assert body.is_global_decl_body;
assert (Option.is_none body.back_id);
assert (body.signature.inputs = []);
- assert (List.length body.signature.doutputs = 1);
assert (body.signature.generics = empty_generic_params);
(* Add a break then the name of the corresponding LLBC declaration *)
@@ -1813,7 +1840,8 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
let decl_ty, body_ty =
let ty = body.signature.output in
- if body.signature.info.effect_info.can_fail then (unwrap_result_ty ty, ty)
+ if body.signature.fwd_info.effect_info.can_fail then
+ (unwrap_result_ty ty, ty)
else (ty, mk_result_ty ty)
in
match body.body with
@@ -1984,7 +2012,8 @@ let extract_trait_decl_method_names (ctx : extraction_ctx)
(* We add one field per required forward/backward function *)
let get_funs_for_id (id : fun_decl_id) : fun_decl list =
let trans : pure_fun_translation = FunDeclId.Map.find id ctx.trans_funs in
- List.map (fun f -> f.f) (trans.fwd :: trans.backs)
+ if !Config.return_back_funs then [ trans.fwd.f ]
+ else List.map (fun f -> f.f) (trans.fwd :: trans.backs)
in
match builtin_info with
| None ->