summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-12-15 18:14:12 +0100
committerSon Ho2023-12-15 18:14:12 +0100
commit955fdab55304979ba2d61432ea654241f20abaa4 (patch)
treefcf1cd7dc3257e0c9242f5ec2eb79ee3bf2f49fb
parent5fa83883b4d573cfd252478f7937c8bde0ec01f6 (diff)
Make progress on propagating the changes
-rw-r--r--compiler/Extract.ml8
-rw-r--r--compiler/PrintPure.ml16
-rw-r--r--compiler/Pure.ml3
-rw-r--r--compiler/PureMicroPasses.ml110
-rw-r--r--compiler/PureTypeCheck.ml8
-rw-r--r--compiler/PureUtils.ml24
6 files changed, 72 insertions, 97 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 1ea26d79..7e2efd8a 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -285,9 +285,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
| App _ ->
let app, args = destruct_apps e in
extract_App ctx fmt inside app args
- | Abs _ ->
- let xl, e = destruct_abs_list e in
- extract_Abs ctx fmt inside xl e
+ | Lambda _ ->
+ let xl, e = destruct_lambdas e in
+ extract_Lambda ctx fmt inside xl e
| Qualif _ ->
(* We use the app case *)
extract_App ctx fmt inside e []
@@ -574,7 +574,7 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter)
(* No argument: shouldn't happen *)
raise (Failure "Unreachable")
-and extract_Abs (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
+and extract_Lambda (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(xl : typed_pattern list) (e : texpression) : unit =
(* Open a box for the abs expression *)
F.pp_open_hovbox fmt ctx.indent_incr;
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 3a5ce513..79506c04 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -543,9 +543,9 @@ let rec texpression_to_string (env : fmt_env) (inside : bool) (indent : string)
let app, args = destruct_apps e in
(* Convert to string *)
app_to_string env inside indent indent_incr app args
- | Abs _ ->
- let xl, e = destruct_abs_list e in
- let e = abs_to_string env indent indent_incr xl e in
+ | Lambda _ ->
+ let xl, e = destruct_lambdas e in
+ let e = lambda_to_string env indent indent_incr xl e in
if inside then "(" ^ e ^ ")" else e
| Qualif _ ->
(* Qualifier without arguments *)
@@ -592,14 +592,6 @@ let rec texpression_to_string (env : fmt_env) (inside : bool) (indent : string)
in
"[ " ^ String.concat ", " fields ^ " ]"
| _ -> raise (Failure "Unexpected"))
- | Lambda _ ->
- let pats, e = destruct_lambdas e in
- let vars =
- String.concat " " (List.map (typed_pattern_to_string env) pats)
- in
- let e = texpression_to_string env false indent indent_incr e in
- let s = "λ " ^ vars ^ " => " ^ e in
- if inside then "(" ^ s ^ ")" else s
| Meta (meta, e) -> (
let meta_s = emeta_to_string env meta in
let e = texpression_to_string env inside indent indent_incr e in
@@ -668,7 +660,7 @@ and app_to_string (env : fmt_env) (inside : bool) (indent : string)
(* Add parentheses *)
if all_args <> [] && inside then "(" ^ e ^ ")" else e
-and abs_to_string (env : fmt_env) (indent : string) (indent_incr : string)
+and lambda_to_string (env : fmt_env) (indent : string) (indent_incr : string)
(xl : typed_pattern list) (e : texpression) : string =
let xl = List.map (typed_pattern_to_string env) xl in
let e = texpression_to_string env false indent indent_incr e in
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index eb6b00c8..ddacf0c4 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -684,7 +684,7 @@ type expression =
field accesses with calls to projectors over fields (when there
are clashes of field names, some provers like F* get pretty bad...)
*)
- | Abs of typed_pattern * texpression (** Lambda abstraction: [fun x -> e] *)
+ | Lambda of typed_pattern * texpression (** Lambda abstraction: [λ x => e] *)
| Qualif of qualif (** A top-level qualifier *)
| Let of bool * typed_pattern * texpression * texpression
(** Let binding.
@@ -728,7 +728,6 @@ type expression =
| Switch of texpression * switch_body
| Loop of loop (** See the comments for {!loop} *)
| StructUpdate of struct_update (** See the comments for {!struct_update} *)
- | Lambda of typed_pattern * texpression (** [λ x => e] *)
| Meta of (emeta[@opaque]) * texpression (** Meta-information *)
and switch_body = If of texpression * texpression | Match of match_branch list
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index d92b3de0..0102b13e 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -385,17 +385,17 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let ctx, arg = update_texpression arg ctx in
let e = App (app, arg) in
(ctx, e)
- | Abs (x, e) -> update_abs x e ctx
| Qualif _ -> (* nothing to do *) (ctx, e.e)
| Let (monadic, lb, re, e) -> update_let monadic lb re e ctx
| Switch (scrut, body) -> update_switch_body scrut body ctx
| Loop loop -> update_loop loop ctx
| StructUpdate supd -> update_struct_update supd ctx
+ | Lambda (lb, e) -> update_lambda lb e ctx
| Meta (meta, e) -> update_emeta meta e ctx
in
(ctx, { e; ty })
(* *)
- and update_abs (x : typed_pattern) (e : texpression) (ctx : pn_ctx) :
+ and update_lambda (x : typed_pattern) (e : texpression) (ctx : pn_ctx) :
pn_ctx * expression =
(* We first add the left-constraint *)
let ctx = add_left_constraint x ctx in
@@ -404,7 +404,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
(* Update the abstracted value *)
let x = update_typed_pattern ctx x in
(* Put together *)
- (ctx, Abs (x, e))
+ (ctx, Lambda (x, e))
(* *)
and update_let (monadic : bool) (lv : typed_pattern) (re : texpression)
(e : texpression) (ctx : pn_ctx) : pn_ctx * expression =
@@ -890,12 +890,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
let call_is_child = check_call func1 generics1 args1 in
if call_is_child then fun () -> true
else fun () -> self#visit_texpression env e ())
+ | Lambda (_, e) -> self#visit_texpression env e
| App _ -> (
fun () ->
match opt_destruct_function_call e with
| Some (func1, tys1, args1) -> check_call func1 tys1 args1
| None -> false)
- | Abs (_, e) -> self#visit_texpression env e
| Qualif _ ->
(* Note that this case includes functions without arguments *)
fun () -> false
@@ -975,7 +975,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
| Var _ | CVar _ | Const _ | App _ | Qualif _
| Switch (_, _)
| Meta (_, _)
- | StructUpdate _ | Abs _ ->
+ | StructUpdate _ | Lambda _ ->
super#visit_expression env e
| Let (monadic, lv, re, e) ->
(* Compute the set of values used in the next expression *)
@@ -1323,28 +1323,20 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
method! visit_Loop env loop =
let fun_sig = def.signature in
- let fun_sig_info = fun_sig.info in
- let fun_effect_info = fun_sig_info.effect_info in
+ let fwd_info = fun_sig.fwd_info in
+ let fwd_effect_info = fwd_info.effect_info in
(* TODO: *)
assert (not !Config.return_back_funs);
(* Generate the loop definition *)
- let loop_effect_info =
- {
- stateful_group = fun_effect_info.stateful_group;
- stateful = fun_effect_info.stateful;
- can_fail = fun_effect_info.can_fail;
- can_diverge = fun_effect_info.can_diverge;
- is_rec = fun_effect_info.is_rec;
- }
- in
+ let loop_fwd_effect_info = fwd_effect_info in
- let loop_sig_info =
+ let loop_fwd_sig_info : fun_sig_info =
let fuel = if !Config.use_fuel then 1 else 0 in
let num_inputs = List.length loop.inputs in
let fwd_info : inputs_info =
- let info = fun_sig_info.fwd_info in
+ let info = fwd_info.fwd_info in
let fwd_state =
info.num_inputs_with_fuel_with_state
- info.num_inputs_with_fuel_no_state
@@ -1358,48 +1350,48 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
}
in
- { fwd_info; effect_info = loop_effect_info }
+ { fwd_info; effect_info = loop_fwd_effect_info }
in
- assert (fun_sig_info_is_wf loop_sig_info);
+ assert (fun_sig_info_is_wf loop_fwd_sig_info);
let inputs_tys =
- (* TODO: *)
- assert (not !Config.return_back_funs);
-
let fuel = if !Config.use_fuel then [ mk_fuel_ty ] else [] in
let fwd_inputs = List.map (fun (v : var) -> v.ty) loop.inputs in
- let info = fun_sig_info.fwd_info in
- let state =
+ let info = fwd_info.fwd_info in
+ let fwd_state =
Collections.List.subslice fun_sig.inputs
info.num_inputs_with_fuel_no_state
info.num_inputs_with_fuel_with_state
in
- let _, back_inputs =
- Collections.List.split_at fun_sig.inputs
- info.num_inputs_with_fuel_with_state
+ let back_inputs =
+ if !Config.return_back_funs then []
+ else
+ snd
+ (Collections.List.split_at fun_sig.inputs
+ info.num_inputs_with_fuel_with_state)
in
- List.concat [ fuel; fwd_inputs; state; back_inputs ]
+ List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ]
in
- let output, doutputs =
+ let output =
match loop.back_output_tys with
| None ->
(* Forward function: the return type is the same as the
parent function *)
- (fun_sig.output, fun_sig.doutputs)
+ fun_sig.output
| Some doutputs ->
(* Backward function: custom return type *)
let output = mk_simpl_tuple_ty doutputs in
let output =
- if loop_effect_info.stateful then
+ if loop_fwd_effect_info.stateful then
mk_simpl_tuple_ty [ mk_state_ty; output ]
else output
in
let output =
- if loop_effect_info.can_fail then mk_result_ty output
+ if loop_fwd_effect_info.can_fail then mk_result_ty output
else output
in
- (output, doutputs)
+ output
in
let loop_sig =
@@ -1409,8 +1401,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
preds = fun_sig.preds;
inputs = inputs_tys;
output;
- doutputs;
- info = loop_sig_info;
+ fwd_info = loop_fwd_sig_info;
+ back_effect_info = fun_sig.back_effect_info;
}
in
@@ -1427,7 +1419,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
(* Introduce the forward input state *)
let fwd_state_var, fwd_state_lvs =
assert (
- loop_effect_info.stateful = Option.is_some loop.input_state);
+ loop_fwd_effect_info.stateful
+ = Option.is_some loop.input_state);
match loop.input_state with
| None -> ([], [])
| Some input_state ->
@@ -1436,11 +1429,9 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
([ state_var ], [ state_lvs ])
in
- (* Introduce the additional backward inputs *)
- (* TODO: *)
- assert (not !Config.return_back_funs);
+ (* Introduce the additional backward inputs, if necessary *)
let fun_body = Option.get def.body in
- let info = fun_sig_info.fwd_info in
+ let info = fwd_info.fwd_info in
let _, back_inputs =
Collections.List.split_at fun_body.inputs
info.num_inputs_with_fuel_with_state
@@ -2063,14 +2054,12 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
(* We start by computing the filtering information, for each function *)
let compute_one_filter_info (decl : fun_decl) =
- (* TODO: *)
- assert (not !Config.return_back_funs);
(* There should be a body *)
let body = Option.get decl.body in
(* We only look at the forward inputs, without the state *)
let inputs_prefix, _ =
Collections.List.split_at body.inputs
- decl.signature.info.fwd_info.num_inputs_with_fuel_no_state
+ decl.signature.fwd_info.fwd_info.num_inputs_with_fuel_no_state
in
let used = ref (List.map (fun v -> (var_get_id v, false)) inputs_prefix) in
let inputs_prefix_length = List.length inputs_prefix in
@@ -2089,9 +2078,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
in
(* Set the fuel as used *)
- let sg_info = decl.signature.info in
- (* TODO: *)
- assert (not !Config.return_back_funs);
+ let sg_info = decl.signature.fwd_info in
if sg_info.fwd_info.has_fuel then
set_used (fst (Collections.List.nth inputs 0));
@@ -2177,13 +2164,18 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
let num_filtered =
List.length (List.filter (fun b -> not b) used_info)
in
- let { generics; llbc_generics; preds; inputs; output; doutputs; info }
- =
+ let {
+ generics;
+ llbc_generics;
+ preds;
+ inputs;
+ output;
+ fwd_info;
+ back_effect_info;
+ } =
decl.signature
in
- (* TODO: *)
- assert (not !Config.return_back_funs);
- let { fwd_info; effect_info } = info in
+ let { fwd_info; effect_info } = fwd_info in
let {
has_fuel;
@@ -2208,10 +2200,18 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
}
in
- let info = { fwd_info; effect_info } in
- assert (fun_sig_info_is_wf info);
+ let fwd_info = { fwd_info; effect_info } in
+ assert (fun_sig_info_is_wf fwd_info);
let signature =
- { generics; llbc_generics; preds; inputs; output; doutputs; info }
+ {
+ generics;
+ llbc_generics;
+ preds;
+ inputs;
+ output;
+ fwd_info;
+ back_effect_info;
+ }
in
{ decl with signature }
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index 3c1800a8..d60d6a05 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -120,7 +120,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
assert (output_ty = e.ty);
check_texpression ctx app;
check_texpression ctx arg
- | Abs (pat, body) ->
+ | Lambda (pat, body) ->
let pat_ty, body_ty = destruct_arrow e.ty in
assert (pat.ty = pat_ty);
assert (body.ty = body_ty);
@@ -229,12 +229,6 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
check_texpression ctx fe)
supd.updates
| _ -> raise (Failure "Unexpected"))
- | Lambda (pat, e_next) ->
- assert (e.ty = e_next.ty);
- (* Check the pattern and register the introduced variables at the same time *)
- let ctx = check_typed_pattern ctx pat in
- (* Check the next expression *)
- check_texpression ctx e_next
| Meta (_, e_next) ->
assert (e_next.ty = e.ty);
check_texpression ctx e_next
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 80b25641..6e86578c 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -215,8 +215,7 @@ let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig =
*)
let rec let_group_requires_parentheses (e : texpression) : bool =
match e.e with
- | Var _ | CVar _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ ->
- false
+ | Var _ | CVar _ | Const _ | App _ | Qualif _ | StructUpdate _ -> false
| Let (monadic, _, _, next_e) ->
if monadic then true else let_group_requires_parentheses next_e
| Switch (_, _) -> false
@@ -374,18 +373,6 @@ let opt_destruct_tuple (ty : ty) : ty list option =
Some generics.types
| _ -> None
-let mk_abs (x : typed_pattern) (e : texpression) : texpression =
- let ty = TArrow (x.ty, e.ty) in
- let e = Abs (x, e) in
- { e; ty }
-
-let rec destruct_abs_list (e : texpression) : typed_pattern list * texpression =
- match e.e with
- | Abs (x, e') ->
- let xl, e'' = destruct_abs_list e' in
- (x :: xl, e'')
- | _ -> ([], e)
-
let destruct_arrow (ty : ty) : ty * ty =
match ty with
| TArrow (ty0, ty1) -> (ty0, ty1)
@@ -717,13 +704,16 @@ let type_decl_from_type_id_is_tuple_struct (ctx : TypesAnalysis.type_infos)
info.is_tuple_struct
| TAssumed _ -> false
+let mk_lambda (x : typed_pattern) (e : texpression) : texpression =
+ let ty = TArrow (x.ty, e.ty) in
+ let e = Lambda (x, e) in
+ { e; ty }
+
let mk_lambda_from_var (var : var) (mp : mplace option) (e : texpression) :
texpression =
- let ty = TArrow (var.ty, e.ty) in
let pat = PatVar (var, mp) in
let pat = { value = pat; ty = var.ty } in
- let e = Lambda (pat, e) in
- { e; ty }
+ mk_lambda pat e
let mk_lambdas_from_vars (vars : var list) (mps : mplace option list)
(e : texpression) : texpression =