diff options
author | Son HO | 2023-12-23 01:46:58 +0100 |
---|---|---|
committer | GitHub | 2023-12-23 01:46:58 +0100 |
commit | 15a7d7b7322a1cd0ebeb328fde214060e23fa8b4 (patch) | |
tree | 6cce7d76969870f5bc18c5a7cd585e8873a1c0dc /compiler/Extract.ml | |
parent | c3e0b90e422cbd902ee6d2b47073940c0017b7fb (diff) | |
parent | 63ccbd914d5d44aa30dee38a6fcc019310ab640b (diff) |
Merge pull request #64 from AeneasVerif/son/merge_back
Merge the forward/backward functions
Diffstat (limited to '')
-rw-r--r-- | compiler/Extract.ml | 157 |
1 files changed, 93 insertions, 64 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 20cdb20b..87dcb1fd 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -43,8 +43,12 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) } | _ -> ctx in - let backs = List.map (fun f -> f.f) def.backs in - let funs = if def.keep_fwd then def.fwd.f :: backs else backs in + let funs = + if !Config.return_back_funs then [ def.fwd.f ] + else + let backs = List.map (fun f -> f.f) def.backs in + if def.keep_fwd then def.fwd.f :: backs else backs + in List.fold_left (fun ctx (f : fun_decl) -> let open ExtractBuiltin in @@ -128,9 +132,15 @@ let extract_adt_g_value F.pp_print_string fmt "tt"; ctx) else - (* If there is exactly one value, we don't print the parentheses *) + (* If there is exactly one value, we don't print the parentheses. + Also, for Coq, we need the special syntax ['(...)] if we destruct + a tuple pattern in a let-binding and the tuple has > 2 values. + *) let lb, rb = - if List.length field_values = 1 then ("", "") else ("(", ")") + if List.length field_values = 1 then ("", "") + else if !backend = Coq && is_single_pat && List.length field_values > 2 + then ("'(", ")") + else ("(", ")") in F.pp_print_string fmt lb; let ctx = @@ -237,30 +247,60 @@ let fun_builtin_filter_types (id : FunDeclId.id) (types : 'a list) Result.Ok types (** [inside]: see {!extract_ty}. + [with_type]: do we also generate a type annotation? This is necessary for + backends like Coq when we write lambdas (Coq is not powerful enough to + infer the type). As a pattern can introduce new variables, we return an extraction context updated with new bindings. *) let rec extract_typed_pattern (ctx : extraction_ctx) (fmt : F.formatter) - (is_let : bool) (inside : bool) (v : typed_pattern) : extraction_ctx = - match v.value with - | PatConstant cv -> - extract_literal fmt inside cv; - ctx - | PatVar (v, _) -> - let vname = ctx_compute_var_basename ctx v.basename v.ty in - let ctx, vname = ctx_add_var vname v.id ctx in - F.pp_print_string fmt vname; - ctx - | PatDummy -> - F.pp_print_string fmt "_"; - ctx - | PatAdt av -> - let extract_value ctx inside v = - extract_typed_pattern ctx fmt is_let inside v - in - extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id - av.field_values v.ty + (is_let : bool) (inside : bool) ?(with_type = false) (v : typed_pattern) : + extraction_ctx = + if with_type then F.pp_print_string fmt "("; + let inside = inside && not with_type in + let ctx = + match v.value with + | PatConstant cv -> + extract_literal fmt inside cv; + ctx + | PatVar (v, _) -> + let vname = ctx_compute_var_basename ctx v.basename v.ty in + let ctx, vname = ctx_add_var vname v.id ctx in + F.pp_print_string fmt vname; + ctx + | PatDummy -> + F.pp_print_string fmt "_"; + ctx + | PatAdt av -> + let extract_value ctx inside v = + extract_typed_pattern ctx fmt is_let inside v + in + extract_adt_g_value extract_value fmt ctx is_let inside av.variant_id + av.field_values v.ty + in + if with_type then ( + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_ty ctx fmt TypeDeclId.Set.empty false v.ty; + F.pp_print_string fmt ")"); + ctx + +(** Return true if we need to wrap a succession of let-bindings in a [do ...] + block (because some of them are monadic) *) +let lets_require_wrap_in_do (lets : (bool * typed_pattern * texpression) list) : + bool = + match !backend with + | Lean -> + (* For Lean, we wrap in a block iff at least one of the let-bindings is monadic *) + List.exists (fun (m, _, _) -> m) lets + | HOL4 -> + (* HOL4 is similar to HOL4, but we add a sanity check *) + let wrap_in_do = List.exists (fun (m, _, _) -> m) lets in + if wrap_in_do then assert (List.for_all (fun (m, _, _) -> m) lets); + wrap_in_do + | FStar | Coq -> false (** [inside]: controls the introduction of parentheses. See [extract_ty] @@ -285,9 +325,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | App _ -> let app, args = destruct_apps e in extract_App ctx fmt inside app args - | Abs _ -> - let xl, e = destruct_abs_list e in - extract_Abs ctx fmt inside xl e + | Lambda _ -> + let xl, e = destruct_lambdas e in + extract_Lambda ctx fmt inside xl e | Qualif _ -> (* We use the app case *) extract_App ctx fmt inside e [] @@ -574,7 +614,7 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) (* No argument: shouldn't happen *) raise (Failure "Unreachable") -and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) +and extract_Lambda (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (xl : typed_pattern list) (e : texpression) : unit = (* Open a box for the abs expression *) F.pp_open_hovbox fmt ctx.indent_incr; @@ -583,15 +623,16 @@ and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (* Print the lambda - note that there should always be at least one variable *) assert (xl <> []); F.pp_print_string fmt "fun"; + let with_type = !backend = Coq in let ctx = List.fold_left (fun ctx x -> F.pp_print_space fmt (); - extract_typed_pattern ctx fmt true true x) + extract_typed_pattern ctx fmt true true ~with_type x) ctx xl in F.pp_print_space fmt (); - if !backend = Lean then F.pp_print_string fmt "=>" + if !backend = Lean || !backend = Coq then F.pp_print_string fmt "=>" else F.pp_print_string fmt "->"; F.pp_print_space fmt (); (* Print the body *) @@ -630,15 +671,6 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) | HOL4 -> destruct_lets_no_interleave e | FStar | Coq | Lean -> destruct_lets e in - (* Open a box for the whole expression. - - In the case of Lean, we use a vbox so that line breaks are inserted - at the end of every let-binding: let-bindings are indeed not ended - with an "in" keyword. - *) - if !Config.backend = Lean then F.pp_open_vbox fmt 0 else F.pp_open_hvbox fmt 0; - (* Open parentheses *) - if inside && !backend <> Lean then F.pp_print_string fmt "("; (* Extract the let-bindings *) let extract_let (ctx : extraction_ctx) (monadic : bool) (lv : typed_pattern) (re : texpression) : extraction_ctx = @@ -711,22 +743,19 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (* Return *) ctx in + (* Open a box for the whole expression. + + In the case of Lean, we use a vbox so that line breaks are inserted + at the end of every let-binding: let-bindings are indeed not ended + with an "in" keyword. + *) + if !Config.backend = Lean then F.pp_open_vbox fmt 0 else F.pp_open_hvbox fmt 0; + (* Open parentheses *) + if inside && !backend <> Lean then F.pp_print_string fmt "("; (* If Lean and HOL4, we rely on monadic blocks, so we insert a do and open a new box immediately *) - let wrap_in_do_od = - match !backend with - | Lean -> - (* For Lean, we wrap in a block iff at least one of the let-bindings is monadic *) - List.exists (fun (m, _, _) -> m) lets - | HOL4 -> - (* HOL4 is similar to HOL4, but we add a sanity check *) - let wrap_in_do = List.exists (fun (m, _, _) -> m) lets in - if wrap_in_do then assert (List.for_all (fun (m, _, _) -> m) lets); - wrap_in_do - | FStar | Coq -> false - in + let wrap_in_do_od = lets_require_wrap_in_do lets in if wrap_in_do_od then ( - F.pp_open_vbox fmt (if !backend = Lean then ctx.indent_incr else 0); F.pp_print_string fmt "do"; F.pp_print_space fmt ()); let ctx = @@ -742,11 +771,10 @@ and extract_lets (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) F.pp_close_box fmt (); (* do-box (Lean and HOL4 only) *) - if wrap_in_do_od then ( + if wrap_in_do_od then if !backend = HOL4 then ( F.pp_print_space fmt (); F.pp_print_string fmt "od"); - F.pp_close_box fmt ()); (* Close parentheses *) if inside && !backend <> Lean then F.pp_print_string fmt ")"; (* Close the box for the whole expression *) @@ -1319,16 +1347,16 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) (Pure.FunId (FRegular def.def_id), def.loop_id, def.back_id) ctx.fun_name_info in - let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]: " in + let comment_pre = "[" ^ name_to_string ctx def.llbc_name ^ "]:" in let comment = let loop_comment = match def.loop_id with | None -> "" - | Some id -> "loop " ^ LoopId.to_string id ^ ": " + | Some id -> " loop " ^ LoopId.to_string id ^ ":" in let fwd_back_comment = match def.back_id with - | None -> [ "forward function" ] + | None -> if !Config.return_back_funs then [] else [ "forward function" ] | Some id -> (* Check if there is only one backward function, and no forward function *) if (not keep_fwd) && num_backs = 1 then @@ -1340,9 +1368,9 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) else [ "backward function " ^ T.RegionGroupId.to_string id ] in match fwd_back_comment with - | [] -> raise (Failure "Unreachable") - | [ s ] -> [ comment_pre ^ loop_comment ^ s ] - | s :: sl -> (comment_pre ^ loop_comment ^ s) :: sl + | [] -> [ comment_pre ^ loop_comment ] + | [ s ] -> [ comment_pre ^ loop_comment ^ " " ^ s ] + | s :: sl -> (comment_pre ^ loop_comment ^ " " ^ s) :: sl in extract_comment_with_span fmt comment def.meta.span @@ -1470,7 +1498,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let inputs_lvs = let all_inputs = (Option.get def.body).inputs_lvs in let num_fwd_inputs = - def.signature.info.num_fwd_inputs_with_fuel_with_state + def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state in Collections.List.prefix num_fwd_inputs all_inputs in @@ -1516,7 +1544,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) let def_body = Option.get def.body in let all_vars = List.map (fun (v : var) -> v.id) def_body.inputs in let num_fwd_inputs = - def.signature.info.num_fwd_inputs_with_fuel_with_state + def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state in let vars = Collections.List.prefix num_fwd_inputs all_vars in @@ -1794,7 +1822,6 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) assert body.is_global_decl_body; assert (Option.is_none body.back_id); assert (body.signature.inputs = []); - assert (List.length body.signature.doutputs = 1); assert (body.signature.generics = empty_generic_params); (* Add a break then the name of the corresponding LLBC declaration *) @@ -1813,7 +1840,8 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) let decl_ty, body_ty = let ty = body.signature.output in - if body.signature.info.effect_info.can_fail then (unwrap_result_ty ty, ty) + if body.signature.fwd_info.effect_info.can_fail then + (unwrap_result_ty ty, ty) else (ty, mk_result_ty ty) in match body.body with @@ -1984,7 +2012,8 @@ let extract_trait_decl_method_names (ctx : extraction_ctx) (* We add one field per required forward/backward function *) let get_funs_for_id (id : fun_decl_id) : fun_decl list = let trans : pure_fun_translation = FunDeclId.Map.find id ctx.trans_funs in - List.map (fun f -> f.f) (trans.fwd :: trans.backs) + if !Config.return_back_funs then [ trans.fwd.f ] + else List.map (fun f -> f.f) (trans.fwd :: trans.backs) in match builtin_info with | None -> |