diff options
Diffstat (limited to 'compiler/Extract.ml')
-rw-r--r-- | compiler/Extract.ml | 101 |
1 files changed, 70 insertions, 31 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 3ea3a862..0e9a53df 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -172,8 +172,8 @@ let keywords () = "macro"; "match"; "namespace"; + "opaque"; "open"; - "return"; "run_cmd"; "set_option"; "simp"; @@ -569,6 +569,10 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) fname ^ lp_suffix ^ suffix in + let opaque_pre () = + match !Config.backend with FStar | Coq -> "" | Lean -> "opaque_defs." + in + let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty) : string = (* If there is a basename, we use it *) @@ -699,6 +703,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) fun_name; termination_measure_name; decreases_proof_name; + opaque_pre; var_basename; type_var_basename; append_index; @@ -726,6 +731,7 @@ let unit_name () = match !backend with Lean -> "Unit" | Coq | FStar -> "unit" *) let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (ty : ty) : unit = + let extract_rec = extract_ty ctx fmt in match ty with | Adt (type_id, tys) -> ( match type_id with @@ -743,15 +749,18 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) in F.pp_print_string fmt product; F.pp_print_space fmt ()) - (extract_ty ctx fmt true) tys; + (extract_rec true) tys; F.pp_print_string fmt ")") | AdtId _ | Assumed _ -> let print_paren = inside && tys <> [] in if print_paren then F.pp_print_string fmt "("; - F.pp_print_string fmt (ctx_get_type type_id ctx); + (* TODO: for now, only the opaque *functions* are extracted in the + opaque module. The opaque *types* are assumed. *) + let with_opaque_pre = false in + F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx); if tys <> [] then F.pp_print_space fmt (); - Collections.List.iter_link (F.pp_print_space fmt) - (extract_ty ctx fmt true) tys; + Collections.List.iter_link (F.pp_print_space fmt) (extract_rec true) + tys; if print_paren then F.pp_print_string fmt ")") | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx) | Bool -> F.pp_print_string fmt ctx.fmt.bool_name @@ -760,11 +769,11 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) | Str -> F.pp_print_string fmt ctx.fmt.str_name | Arrow (arg_ty, ret_ty) -> if inside then F.pp_print_string fmt "("; - extract_ty ctx fmt false arg_ty; + extract_rec false arg_ty; F.pp_print_space fmt (); F.pp_print_string fmt "->"; F.pp_print_space fmt (); - extract_ty ctx fmt false ret_ty; + extract_rec false ret_ty; if inside then F.pp_print_string fmt ")" | Array _ | Slice _ -> raise Unimplemented @@ -969,7 +978,9 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) (* If Coq: print the constructor name *) (* TODO: remove superfluous test not is_rec below *) if !backend = Coq && not is_rec then ( - F.pp_print_string fmt (ctx_get_struct (AdtId def.def_id) ctx); + let with_opaque_pre = false in + F.pp_print_string fmt + (ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx); F.pp_print_string fmt " "); if !backend <> Lean then F.pp_print_string fmt "{"; F.pp_print_break fmt 1 ctx.indent_incr; @@ -1000,8 +1011,9 @@ let extract_type_decl_struct_body (ctx : extraction_ctx) (fmt : F.formatter) a group of mutually recursive types: we extract it as an inductive type *) assert (is_rec && !backend = Coq); - let cons_name = ctx_get_struct (AdtId def.def_id) ctx in - let def_name = ctx_get_local_type def.def_id ctx in + let with_opaque_pre = false in + let cons_name = ctx_get_struct with_opaque_pre (AdtId def.def_id) ctx in + let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in extract_type_decl_variant ctx fmt def_name type_params cons_name fields) in () @@ -1043,10 +1055,12 @@ let extract_type_decl (ctx : extraction_ctx) (fmt : F.formatter) The boolean [is_opaque_coq] is used to detect this case. *) - let is_opaque_coq = !backend = Coq && type_kind = None in + let is_opaque = type_kind = None in + let is_opaque_coq = !backend = Coq && is_opaque in let use_forall = is_opaque_coq && def.type_params <> [] in (* Retrieve the definition name *) - let def_name = ctx_get_local_type def.def_id ctx in + let with_opaque_pre = false in + let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in (* Add the type params - note that we need those bindings only for the * body translation (they are not top-level) *) let ctx_body, type_params = ctx_add_type_params def.type_params ctx in @@ -1173,7 +1187,8 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter) | Struct fields -> let adt_id = AdtId decl.def_id in (* Generate the instruction for the record constructor *) - let cons_name = ctx_get_struct adt_id ctx in + let with_opaque_pre = false in + let cons_name = ctx_get_struct with_opaque_pre adt_id ctx in extract_arguments_info cons_name fields; (* Generate the instruction for the record projectors, if there are *) let is_rec = decl_is_from_rec_group kind in @@ -1215,8 +1230,11 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx) let ctx, type_params = ctx_add_type_params decl.type_params ctx in let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in let ctx, field_var = ctx_add_var "x" (VarId.of_int 1) ctx in - let def_name = ctx_get_local_type decl.def_id ctx in - let cons_name = ctx_get_struct (AdtId decl.def_id) ctx in + let with_opaque_pre = false in + let def_name = ctx_get_local_type with_opaque_pre decl.def_id ctx in + let cons_name = + ctx_get_struct with_opaque_pre (AdtId decl.def_id) ctx + in let extract_field_proj (field_id : FieldId.id) (_ : field) : unit = F.pp_print_space fmt (); (* Outer box for the projector definition *) @@ -1500,12 +1518,16 @@ let extract_adt_g_value * [{ field0=...; ...; fieldn=...; }] in case of structures. *) let cons = + (* The ADT shouldn't be opaque *) + let with_opaque_pre = false in match variant_id with | Some vid -> if !backend = Lean then - ctx_get_type adt_id ctx ^ "." ^ ctx_get_variant adt_id vid ctx + ctx_get_type with_opaque_pre adt_id ctx + ^ "." + ^ ctx_get_variant adt_id vid ctx else ctx_get_variant adt_id vid ctx - | None -> ctx_get_struct adt_id ctx + | None -> ctx_get_struct with_opaque_pre adt_id ctx in if inside && field_values <> [] then F.pp_print_string fmt "("; F.pp_print_string fmt cons; @@ -1523,7 +1545,8 @@ let extract_adt_g_value (* Extract globals in the same way as variables *) let extract_global (ctx : extraction_ctx) (fmt : F.formatter) (id : A.GlobalDeclId.id) : unit = - F.pp_print_string fmt (ctx_get_global id ctx) + let with_opaque_pre = ctx.use_opaque_pre in + F.pp_print_string fmt (ctx_get_global with_opaque_pre id ctx) (** [inside]: see {!extract_ty}. @@ -1643,11 +1666,8 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) (* Open a box for the function call *) F.pp_open_hovbox fmt ctx.indent_incr; (* Print the function name *) - let fun_name = - Option.value - ~default:(ctx_get_function fun_id ctx) - (ctx_maybe_get (DeclaredId fun_id) ctx) - in + let with_opaque_pre = ctx.use_opaque_pre in + let fun_name = ctx_get_function with_opaque_pre fun_id ctx in F.pp_print_string fmt fun_name; (* Print the type parameters *) List.iter @@ -1703,14 +1723,16 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) * applied structure constructors. *) let cons = + (* The ADT shouldn't be opaque *) + let with_opaque_pre = false in match adt_cons.variant_id with | Some vid -> if !backend = Lean then - ctx_get_type adt_cons.adt_id ctx + ctx_get_type with_opaque_pre adt_cons.adt_id ctx ^ "." ^ ctx_get_variant adt_cons.adt_id vid ctx else ctx_get_variant adt_cons.adt_id vid ctx - | None -> ctx_get_struct adt_cons.adt_id ctx + | None -> ctx_get_struct with_opaque_pre adt_cons.adt_id ctx in let is_lean_struct = !backend = Lean && adt_cons.variant_id = None in if is_lean_struct then ( @@ -2309,8 +2331,10 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) (kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit = assert (not def.is_global_decl_body); (* Retrieve the function name *) + let with_opaque_pre = false in let def_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx + ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id + ctx in (* Add a break before *) F.pp_print_break fmt 0 0; @@ -2612,9 +2636,12 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) extract_comment fmt ("[" ^ Print.global_name_to_string global.name ^ "]"); F.pp_print_space fmt (); - let decl_name = ctx_get_global global.def_id ctx in + let with_opaque_pre = false in + let decl_name = ctx_get_global with_opaque_pre global.def_id ctx in let body_name = - ctx_get_function (FromLlbc (Regular global.body_id, None, None)) ctx + ctx_get_function with_opaque_pre + (FromLlbc (Regular global.body_id, None, None)) + ctx in let decl_ty, body_ty = @@ -2685,8 +2712,12 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "assert_norm"; F.pp_print_space fmt (); F.pp_print_string fmt "("; + (* Note that if the function is opaque, the unit test will fail + because the normalizer will get stuck *) + let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx + ctx_get_local_function with_opaque_pre def.def_id def.loop_id + def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( @@ -2701,8 +2732,12 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "Check"; F.pp_print_space fmt (); F.pp_print_string fmt "("; + (* Note that if the function is opaque, the unit test will fail + because the normalizer will get stuck *) + let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx + ctx_get_local_function with_opaque_pre def.def_id def.loop_id + def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( @@ -2714,8 +2749,12 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "#assert"; F.pp_print_space fmt (); F.pp_print_string fmt "("; + (* Note that if the function is opaque, the unit test will fail + because the normalizer will get stuck *) + let with_opaque_pre = ctx.use_opaque_pre in let fun_name = - ctx_get_local_function def.def_id def.loop_id def.back_id ctx + ctx_get_local_function with_opaque_pre def.def_id def.loop_id + def.back_id ctx in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( |