From 955fdab55304979ba2d61432ea654241f20abaa4 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 18:14:12 +0100 Subject: Make progress on propagating the changes --- compiler/Extract.ml | 8 ++-- compiler/PrintPure.ml | 16 ++----- compiler/Pure.ml | 3 +- compiler/PureMicroPasses.ml | 110 ++++++++++++++++++++++---------------------- compiler/PureTypeCheck.ml | 8 +--- compiler/PureUtils.ml | 24 +++------- 6 files changed, 72 insertions(+), 97 deletions(-) diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 1ea26d79..7e2efd8a 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -285,9 +285,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | App _ -> let app, args = destruct_apps e in extract_App ctx fmt inside app args - | Abs _ -> - let xl, e = destruct_abs_list e in - extract_Abs ctx fmt inside xl e + | Lambda _ -> + let xl, e = destruct_lambdas e in + extract_Lambda ctx fmt inside xl e | Qualif _ -> (* We use the app case *) extract_App ctx fmt inside e [] @@ -574,7 +574,7 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) (* No argument: shouldn't happen *) raise (Failure "Unreachable") -and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) +and extract_Lambda (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (xl : typed_pattern list) (e : texpression) : unit = (* Open a box for the abs expression *) F.pp_open_hovbox fmt ctx.indent_incr; diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 3a5ce513..79506c04 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -543,9 +543,9 @@ let rec texpression_to_string (env : fmt_env) (inside : bool) (indent : string) let app, args = destruct_apps e in (* Convert to string *) app_to_string env inside indent indent_incr app args - | Abs _ -> - let xl, e = destruct_abs_list e in - let e = abs_to_string env indent indent_incr xl e in + | Lambda _ -> + let xl, e = destruct_lambdas e in + let e = lambda_to_string env indent indent_incr xl e in if inside then "(" ^ e ^ ")" else e | Qualif _ -> (* Qualifier without arguments *) @@ -592,14 +592,6 @@ let rec texpression_to_string (env : fmt_env) (inside : bool) (indent : string) in "[ " ^ String.concat ", " fields ^ " ]" | _ -> raise (Failure "Unexpected")) - | Lambda _ -> - let pats, e = destruct_lambdas e in - let vars = - String.concat " " (List.map (typed_pattern_to_string env) pats) - in - let e = texpression_to_string env false indent indent_incr e in - let s = "λ " ^ vars ^ " => " ^ e in - if inside then "(" ^ s ^ ")" else s | Meta (meta, e) -> ( let meta_s = emeta_to_string env meta in let e = texpression_to_string env inside indent indent_incr e in @@ -668,7 +660,7 @@ and app_to_string (env : fmt_env) (inside : bool) (indent : string) (* Add parentheses *) if all_args <> [] && inside then "(" ^ e ^ ")" else e -and abs_to_string (env : fmt_env) (indent : string) (indent_incr : string) +and lambda_to_string (env : fmt_env) (indent : string) (indent_incr : string) (xl : typed_pattern list) (e : texpression) : string = let xl = List.map (typed_pattern_to_string env) xl in let e = texpression_to_string env false indent indent_incr e in diff --git a/compiler/Pure.ml b/compiler/Pure.ml index eb6b00c8..ddacf0c4 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -684,7 +684,7 @@ type expression = field accesses with calls to projectors over fields (when there are clashes of field names, some provers like F* get pretty bad...) *) - | Abs of typed_pattern * texpression (** Lambda abstraction: [fun x -> e] *) + | Lambda of typed_pattern * texpression (** Lambda abstraction: [λ x => e] *) | Qualif of qualif (** A top-level qualifier *) | Let of bool * typed_pattern * texpression * texpression (** Let binding. @@ -728,7 +728,6 @@ type expression = | Switch of texpression * switch_body | Loop of loop (** See the comments for {!loop} *) | StructUpdate of struct_update (** See the comments for {!struct_update} *) - | Lambda of typed_pattern * texpression (** [λ x => e] *) | Meta of (emeta[@opaque]) * texpression (** Meta-information *) and switch_body = If of texpression * texpression | Match of match_branch list diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index d92b3de0..0102b13e 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -385,17 +385,17 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ctx, arg = update_texpression arg ctx in let e = App (app, arg) in (ctx, e) - | Abs (x, e) -> update_abs x e ctx | Qualif _ -> (* nothing to do *) (ctx, e.e) | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx | Switch (scrut, body) -> update_switch_body scrut body ctx | Loop loop -> update_loop loop ctx | StructUpdate supd -> update_struct_update supd ctx + | Lambda (lb, e) -> update_lambda lb e ctx | Meta (meta, e) -> update_emeta meta e ctx in (ctx, { e; ty }) (* *) - and update_abs (x : typed_pattern) (e : texpression) (ctx : pn_ctx) : + and update_lambda (x : typed_pattern) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = (* We first add the left-constraint *) let ctx = add_left_constraint x ctx in @@ -404,7 +404,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = (* Update the abstracted value *) let x = update_typed_pattern ctx x in (* Put together *) - (ctx, Abs (x, e)) + (ctx, Lambda (x, e)) (* *) and update_let (monadic : bool) (lv : typed_pattern) (re : texpression) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = @@ -890,12 +890,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) let call_is_child = check_call func1 generics1 args1 in if call_is_child then fun () -> true else fun () -> self#visit_texpression env e ()) + | Lambda (_, e) -> self#visit_texpression env e | App _ -> ( fun () -> match opt_destruct_function_call e with | Some (func1, tys1, args1) -> check_call func1 tys1 args1 | None -> false) - | Abs (_, e) -> self#visit_texpression env e | Qualif _ -> (* Note that this case includes functions without arguments *) fun () -> false @@ -975,7 +975,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) | Var _ | CVar _ | Const _ | App _ | Qualif _ | Switch (_, _) | Meta (_, _) - | StructUpdate _ | Abs _ -> + | StructUpdate _ | Lambda _ -> super#visit_expression env e | Let (monadic, lv, re, e) -> (* Compute the set of values used in the next expression *) @@ -1323,28 +1323,20 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = method! visit_Loop env loop = let fun_sig = def.signature in - let fun_sig_info = fun_sig.info in - let fun_effect_info = fun_sig_info.effect_info in + let fwd_info = fun_sig.fwd_info in + let fwd_effect_info = fwd_info.effect_info in (* TODO: *) assert (not !Config.return_back_funs); (* Generate the loop definition *) - let loop_effect_info = - { - stateful_group = fun_effect_info.stateful_group; - stateful = fun_effect_info.stateful; - can_fail = fun_effect_info.can_fail; - can_diverge = fun_effect_info.can_diverge; - is_rec = fun_effect_info.is_rec; - } - in + let loop_fwd_effect_info = fwd_effect_info in - let loop_sig_info = + let loop_fwd_sig_info : fun_sig_info = let fuel = if !Config.use_fuel then 1 else 0 in let num_inputs = List.length loop.inputs in let fwd_info : inputs_info = - let info = fun_sig_info.fwd_info in + let info = fwd_info.fwd_info in let fwd_state = info.num_inputs_with_fuel_with_state - info.num_inputs_with_fuel_no_state @@ -1358,48 +1350,48 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = } in - { fwd_info; effect_info = loop_effect_info } + { fwd_info; effect_info = loop_fwd_effect_info } in - assert (fun_sig_info_is_wf loop_sig_info); + assert (fun_sig_info_is_wf loop_fwd_sig_info); let inputs_tys = - (* TODO: *) - assert (not !Config.return_back_funs); - let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in - let info = fun_sig_info.fwd_info in - let state = + let info = fwd_info.fwd_info in + let fwd_state = Collections.List.subslice fun_sig.inputs info.num_inputs_with_fuel_no_state info.num_inputs_with_fuel_with_state in - let _, back_inputs = - Collections.List.split_at fun_sig.inputs - info.num_inputs_with_fuel_with_state + let back_inputs = + if !Config.return_back_funs then [] + else + snd + (Collections.List.split_at fun_sig.inputs + info.num_inputs_with_fuel_with_state) in - List.concat [ fuel; fwd_inputs; state; back_inputs ] + List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ] in - let output, doutputs = + let output = match loop.back_output_tys with | None -> (* Forward function: the return type is the same as the parent function *) - (fun_sig.output, fun_sig.doutputs) + fun_sig.output | Some doutputs -> (* Backward function: custom return type *) let output = mk_simpl_tuple_ty doutputs in let output = - if loop_effect_info.stateful then + if loop_fwd_effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] else output in let output = - if loop_effect_info.can_fail then mk_result_ty output + if loop_fwd_effect_info.can_fail then mk_result_ty output else output in - (output, doutputs) + output in let loop_sig = @@ -1409,8 +1401,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = preds = fun_sig.preds; inputs = inputs_tys; output; - doutputs; - info = loop_sig_info; + fwd_info = loop_fwd_sig_info; + back_effect_info = fun_sig.back_effect_info; } in @@ -1427,7 +1419,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = (* Introduce the forward input state *) let fwd_state_var, fwd_state_lvs = assert ( - loop_effect_info.stateful = Option.is_some loop.input_state); + loop_fwd_effect_info.stateful + = Option.is_some loop.input_state); match loop.input_state with | None -> ([], []) | Some input_state -> @@ -1436,11 +1429,9 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = ([ state_var ], [ state_lvs ]) in - (* Introduce the additional backward inputs *) - (* TODO: *) - assert (not !Config.return_back_funs); + (* Introduce the additional backward inputs, if necessary *) let fun_body = Option.get def.body in - let info = fun_sig_info.fwd_info in + let info = fwd_info.fwd_info in let _, back_inputs = Collections.List.split_at fun_body.inputs info.num_inputs_with_fuel_with_state @@ -2063,14 +2054,12 @@ let filter_loop_inputs (transl : pure_fun_translation list) : (* We start by computing the filtering information, for each function *) let compute_one_filter_info (decl : fun_decl) = - (* TODO: *) - assert (not !Config.return_back_funs); (* There should be a body *) let body = Option.get decl.body in (* We only look at the forward inputs, without the state *) let inputs_prefix, _ = Collections.List.split_at body.inputs - decl.signature.info.fwd_info.num_inputs_with_fuel_no_state + decl.signature.fwd_info.fwd_info.num_inputs_with_fuel_no_state in let used = ref (List.map (fun v -> (var_get_id v, false)) inputs_prefix) in let inputs_prefix_length = List.length inputs_prefix in @@ -2089,9 +2078,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : in (* Set the fuel as used *) - let sg_info = decl.signature.info in - (* TODO: *) - assert (not !Config.return_back_funs); + let sg_info = decl.signature.fwd_info in if sg_info.fwd_info.has_fuel then set_used (fst (Collections.List.nth inputs 0)); @@ -2177,13 +2164,18 @@ let filter_loop_inputs (transl : pure_fun_translation list) : let num_filtered = List.length (List.filter (fun b -> not b) used_info) in - let { generics; llbc_generics; preds; inputs; output; doutputs; info } - = + let { + generics; + llbc_generics; + preds; + inputs; + output; + fwd_info; + back_effect_info; + } = decl.signature in - (* TODO: *) - assert (not !Config.return_back_funs); - let { fwd_info; effect_info } = info in + let { fwd_info; effect_info } = fwd_info in let { has_fuel; @@ -2208,10 +2200,18 @@ let filter_loop_inputs (transl : pure_fun_translation list) : } in - let info = { fwd_info; effect_info } in - assert (fun_sig_info_is_wf info); + let fwd_info = { fwd_info; effect_info } in + assert (fun_sig_info_is_wf fwd_info); let signature = - { generics; llbc_generics; preds; inputs; output; doutputs; info } + { + generics; + llbc_generics; + preds; + inputs; + output; + fwd_info; + back_effect_info; + } in { decl with signature } diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 3c1800a8..d60d6a05 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -120,7 +120,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = assert (output_ty = e.ty); check_texpression ctx app; check_texpression ctx arg - | Abs (pat, body) -> + | Lambda (pat, body) -> let pat_ty, body_ty = destruct_arrow e.ty in assert (pat.ty = pat_ty); assert (body.ty = body_ty); @@ -229,12 +229,6 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = check_texpression ctx fe) supd.updates | _ -> raise (Failure "Unexpected")) - | Lambda (pat, e_next) -> - assert (e.ty = e_next.ty); - (* Check the pattern and register the introduced variables at the same time *) - let ctx = check_typed_pattern ctx pat in - (* Check the next expression *) - check_texpression ctx e_next | Meta (_, e_next) -> assert (e_next.ty = e.ty); check_texpression ctx e_next diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 80b25641..6e86578c 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -215,8 +215,7 @@ let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig = *) let rec let_group_requires_parentheses (e : texpression) : bool = match e.e with - | Var _ | CVar _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> - false + | Var _ | CVar _ | Const _ | App _ | Qualif _ | StructUpdate _ -> false | Let (monadic, _, _, next_e) -> if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false @@ -374,18 +373,6 @@ let opt_destruct_tuple (ty : ty) : ty list option = Some generics.types | _ -> None -let mk_abs (x : typed_pattern) (e : texpression) : texpression = - let ty = TArrow (x.ty, e.ty) in - let e = Abs (x, e) in - { e; ty } - -let rec destruct_abs_list (e : texpression) : typed_pattern list * texpression = - match e.e with - | Abs (x, e') -> - let xl, e'' = destruct_abs_list e' in - (x :: xl, e'') - | _ -> ([], e) - let destruct_arrow (ty : ty) : ty * ty = match ty with | TArrow (ty0, ty1) -> (ty0, ty1) @@ -717,13 +704,16 @@ let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos) info.is_tuple_struct | TAssumed _ -> false +let mk_lambda (x : typed_pattern) (e : texpression) : texpression = + let ty = TArrow (x.ty, e.ty) in + let e = Lambda (x, e) in + { e; ty } + let mk_lambda_from_var (var : var) (mp : mplace option) (e : texpression) : texpression = - let ty = TArrow (var.ty, e.ty) in let pat = PatVar (var, mp) in let pat = { value = pat; ty = var.ty } in - let e = Lambda (pat, e) in - { e; ty } + mk_lambda pat e let mk_lambdas_from_vars (vars : var list) (mps : mplace option list) (e : texpression) : texpression = -- cgit v1.2.3