diff options
author | Son Ho | 2023-01-06 16:51:27 +0100 |
---|---|---|
committer | Son HO | 2023-02-03 11:21:46 +0100 |
commit | 46381652adbece2d7ccfd57fae8b5ee2365fb374 (patch) | |
tree | 80e1d1e2cf5728c76736e213c9bedab5191b8376 | |
parent | 2935706e2670a6aad0a01f4ffa29803574a687ed (diff) |
Fix some issues with the values given back by loop backward translations
Diffstat (limited to '')
-rw-r--r-- | compiler/ExtractBase.ml | 10 | ||||
-rw-r--r-- | compiler/InterpreterLoops.ml | 93 | ||||
-rw-r--r-- | compiler/PrintPure.ml | 17 | ||||
-rw-r--r-- | compiler/Pure.ml | 3 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 29 | ||||
-rw-r--r-- | compiler/PureTypeCheck.ml | 7 | ||||
-rw-r--r-- | compiler/SymbolicAst.ml | 5 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 148 | ||||
-rw-r--r-- | compiler/SynthesizeSymbolic.ml | 18 | ||||
-rw-r--r-- | compiler/Translate.ml | 1 | ||||
-rw-r--r-- | tests/coq/misc/Loops.v | 73 | ||||
-rw-r--r-- | tests/fstar/misc/Loops.Clauses.Template.fst | 10 | ||||
-rw-r--r-- | tests/fstar/misc/Loops.Clauses.fst | 10 | ||||
-rw-r--r-- | tests/fstar/misc/Loops.Funs.fst | 67 |
14 files changed, 451 insertions, 40 deletions
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index b952d555..a9b44017 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -177,14 +177,16 @@ type formatter = { indices to derive unique names for the loops for instance - if there is exactly one loop, we don't need to use indices) - loop id (if pertinent) - - number of region groups (same comment as for the number of loops) + - number of region groups - region group information in case of a backward function ([None] if forward function) - pair: - do we generate the forward function (it may have been filtered)? - - the number of extracted backward functions (not necessarily equal - to the number of region groups, because we may have filtered - some of them) + - the number of *extracted backward functions* (same comment as for + the number of loops) + The number of extracted backward functions if not necessarily + equal to the number of region groups, because we may have + filtered some of them. TODO: use the fun id for the assumed functions. *) decreases_clause_name : A.FunDeclId.id -> fun_name -> string; diff --git a/compiler/InterpreterLoops.ml b/compiler/InterpreterLoops.ml index 2d900b7d..11bc7a07 100644 --- a/compiler/InterpreterLoops.ml +++ b/compiler/InterpreterLoops.ml @@ -2785,8 +2785,11 @@ let prepare_ashared_loans_no_synth (loop_id : V.LoopId.id) (ctx : C.eval_ctx) : get_cf_ctx_no_synth (prepare_ashared_loans (Some loop_id)) ctx (** Compute a fixed-point for the context at the entry of the loop. - We also return the sets of fixed ids, and the list of symbolic values - that appear in the fixed point context. + We also return: + - the sets of fixed ids + - the map from region group id to the corresponding abstraction appearing + in the fixed point (this is useful to compute the return type of the loop + backward functions for instance). Rem.: the list of symbolic values should be computable by simply exploring the fixed point environment and listing all the symbolic values we find. @@ -2794,7 +2797,8 @@ let prepare_ashared_loans_no_synth (loop_id : V.LoopId.id) (ctx : C.eval_ctx) : the values which are read or modified (some symbolic values may be ignored). *) let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id) - (eval_loop_body : st_cm_fun) (ctx0 : C.eval_ctx) : C.eval_ctx * ids_sets = + (eval_loop_body : st_cm_fun) (ctx0 : C.eval_ctx) : + C.eval_ctx * ids_sets * V.abs T.RegionGroupId.Map.t = (* The continuation for when we exit the loop - we register the environments upon loop *reentry*, and synthesize nothing by returning [None] @@ -2963,7 +2967,7 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id) "horizontally": the symbolic values contained in the abstractions (typically the shared values) will be preserved. *) - let fp = + let fp, rg_to_abs = (* List the loop abstractions in the fixed-point *) let fp_aids, add_aid, _mem_aid = V.AbstractionId.Set.mk_stateful_set () in @@ -3066,8 +3070,10 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id) se, but if it doesn't happen it is bizarre and worth investigating... *) assert (V.AbstractionId.Set.equal !aids_union !fp_aids); - (* Merge the abstractions which need to be merged *) + (* Merge the abstractions which need to be merged, and compute the map from + region id to abstraction id *) let fp = ref fp in + let rg_to_abs = ref T.RegionGroupId.Map.empty in let _ = T.RegionGroupId.Map.iter (fun rg_id ids -> @@ -3108,9 +3114,13 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id) id0 := id0'; () with ValueMatchFailure _ -> raise (Failure "Unexpected")) - ids) + ids; + (* Register the mapping *) + let abs = C.ctx_lookup_abs !fp !id0 in + rg_to_abs := T.RegionGroupId.Map.add_strict rg_id abs !rg_to_abs) !fp_ended_aids in + let rg_to_abs = !rg_to_abs in (* Reorder the loans and borrows in the fresh abstractions in the fixed-point *) let fp = @@ -3164,12 +3174,12 @@ let compute_loop_entry_fixed_point (config : C.config) (loop_id : V.LoopId.id) in (* Return *) - fp + (fp, rg_to_abs) in let fixed_ids = compute_fixed_ids [ fp ] in (* Return *) - (fp, fixed_ids) + (fp, fixed_ids, rg_to_abs) (** Split an environment between the fixed abstractions, values, etc. and the new abstractions, values, etc. @@ -4127,7 +4137,7 @@ let eval_loop_symbolic (config : C.config) (eval_loop_body : st_cm_fun) : let loop_id = C.fresh_loop_id () in (* Compute the fixed point at the loop entrance *) - let fp_ctx, fixed_ids = + let fp_ctx, fixed_ids, rg_to_abs = compute_loop_entry_fixed_point config loop_id eval_loop_body ctx in @@ -4197,8 +4207,71 @@ let eval_loop_symbolic (config : C.config) (eval_loop_body : st_cm_fun) : ^ Print.list_to_string (symbolic_value_to_string ctx) input_svalues ^ "\n\n")); + (* For every abstraction introduced by the fixed-point, compute the + types of the given back values. + + We need to explore the abstractions, looking for the mutable borrows. + Moreover, we list the borrows in the same order as the loans (this + is important in {!SymbolicToPure}, where we expect the given back + values to have a specific order. + *) + let compute_abs_given_back_tys (abs : V.abs) : T.RegionId.Set.t * T.rty list = + let is_borrow (av : V.typed_avalue) : bool = + match av.V.value with + | ABorrow _ -> true + | ALoan _ -> false + | _ -> raise (Failure "Unreachable") + in + let borrows, loans = List.partition is_borrow abs.avalues in + + let borrows = + List.filter_map + (fun av -> + match av.V.value with + | V.ABorrow (V.AMutBorrow (bid, child_av)) -> + assert (is_aignored child_av.V.value); + Some (bid, child_av.V.ty) + | V.ABorrow (V.ASharedBorrow _) -> None + | _ -> raise (Failure "Unreachable")) + borrows + in + let borrows = ref (V.BorrowId.Map.of_list borrows) in + + let loan_ids = + List.filter_map + (fun av -> + match av.V.value with + | V.ALoan (V.AMutLoan (bid, child_av)) -> + assert (is_aignored child_av.V.value); + Some bid + | V.ALoan (V.ASharedLoan _) -> None + | _ -> raise (Failure "Unreachable")) + loans + in + + (* List the given back types, in the order given by the loans *) + let given_back_tys = + List.map + (fun lid -> + let bid = + V.BorrowId.InjSubst.find lid fp_bl_corresp.loan_to_borrow_id_map + in + let ty = V.BorrowId.Map.find bid !borrows in + borrows := V.BorrowId.Map.remove bid !borrows; + ty) + loan_ids + in + assert (V.BorrowId.Map.is_empty !borrows); + + (abs.regions, given_back_tys) + in + let rg_to_given_back = + T.RegionGroupId.Map.map compute_abs_given_back_tys rg_to_abs + in + (* Put together *) - S.synthesize_loop loop_id input_svalues fresh_sids end_expr loop_expr + S.synthesize_loop loop_id input_svalues fresh_sids rg_to_given_back end_expr + loop_expr (** Evaluate a loop *) let eval_loop (config : C.config) (eval_loop_body : st_cm_fun) : st_cm_fun = diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index c13ce238..532271c3 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -632,15 +632,26 @@ and loop_to_string (fmt : ast_formatter) (indent : string) ^ String.concat "; " (List.map (var_to_string type_fmt) loop.inputs) ^ "]" in + let back_output_tys = + let tys = + match loop.back_output_tys with + | None -> "" + | Some tys -> + String.concat "; " + (List.map (ty_to_string (ast_to_type_formatter fmt) false) tys) + in + "back_output_tys: [" ^ tys ^ "]" + in let fun_end = texpression_to_string fmt false indent2 indent_incr loop.fun_end in let loop_body = texpression_to_string fmt false indent2 indent_incr loop.loop_body in - "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ "fun_end: {\n" ^ indent2 - ^ fun_end ^ "\n" ^ indent1 ^ "}\n" ^ indent1 ^ "loop_body: {\n" ^ indent2 - ^ loop_body ^ "\n" ^ indent1 ^ "}\n" ^ indent ^ "}" + "loop {\n" ^ indent1 ^ loop_inputs ^ "\n" ^ indent1 ^ back_output_tys ^ "\n" + ^ indent1 ^ "fun_end: {\n" ^ indent2 ^ fun_end ^ "\n" ^ indent1 ^ "}\n" + ^ indent1 ^ "loop_body: {\n" ^ indent2 ^ loop_body ^ "\n" ^ indent1 ^ "}\n" + ^ indent ^ "}" and meta_to_string (fmt : ast_formatter) (meta : meta) : string = let meta = diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 1b0a6b5c..118aec50 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -498,6 +498,9 @@ and loop = { inputs : var list; inputs_lvs : typed_pattern list; (** The inputs seen as patterns. See {!fun_body}. *) + back_output_tys : ty list option; + (** The types of the given back values, if we ar esynthesizing a backward + function *) loop_body : texpression; } diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index aed5b02d..25d760fe 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -440,13 +440,14 @@ let compute_pretty_names (def : fun_decl) : fun_decl = input_state; inputs; inputs_lvs; + back_output_tys; loop_body; } = loop in let ctx, fun_end = update_texpression fun_end ctx in let ctx, loop_body = update_texpression loop_body ctx in - let inputs = List.map (fun input -> update_var ctx input None) inputs in + let inputs = List.map (fun v -> update_var ctx v None) inputs in let inputs_lvs = List.map (update_typed_pattern ctx) inputs_lvs in let loop = { @@ -457,6 +458,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = input_state; inputs; inputs_lvs; + back_output_tys; loop_body; } in @@ -1126,12 +1128,33 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = List.concat [ fuel; fwd_inputs; state; back_inputs ] in + let output, doutputs = + match loop.back_output_tys with + | None -> + (* Forward function: the return type is the same as the + parent function *) + (fun_sig.output, fun_sig.doutputs) + | Some doutputs -> + (* Backward function: custom return type *) + let output = mk_simpl_tuple_ty doutputs in + let output = + if loop_effect_info.stateful then + mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if loop_effect_info.can_fail then mk_result_ty output + else output + in + (output, doutputs) + in + let loop_sig = { type_params = fun_sig.type_params; inputs = inputs_tys; - output = fun_sig.output; - doutputs = fun_sig.doutputs; + output; + doutputs; info = loop_sig_info; } in diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 78fd077a..1871f1bc 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -186,7 +186,12 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = List.iter check_branch branches) | Loop loop -> assert (loop.fun_end.ty = e.ty); - assert (loop.loop_body.ty = e.ty); + (* If we translate forward functions, the type of the loop is the same + as the type of the parent expression - in case of backward functions, + the loop doesn't necessarily give back the same values as the parent + function + *) + assert (Option.is_some loop.back_output_tys || loop.loop_body.ty = e.ty); check_texpression ctx loop.fun_end; check_texpression ctx loop.loop_body | Meta (_, e_next) -> diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 7f682c9c..0e68d2fd 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -216,6 +216,11 @@ and loop = { input_svalues : V.symbolic_value list; (** The input symbolic values *) fresh_svalues : V.symbolic_value_id_set; (** The symbolic values introduced by the loop fixed-point *) + rg_to_given_back_tys : + ((T.RegionId.Set.t * T.rty list) T.RegionGroupId.Map.t[@opaque]); + (** The map from region group ids to the types of the values given back + by the corresponding loop abstractions. + *) end_expr : expression; (** The end of the function (upon the moment it enters the loop) *) loop_expr : expression; (** The symbolically executed loop body *) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index b024f40e..120689e5 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -163,6 +163,13 @@ type bs_ctx = { (** The variables that the backward functions will output, corresponding to the borrows they give back (don't include the backward state) *) + loop_backward_outputs : var list T.RegionGroupId.Map.t option; + (** Same as {!backward_outputs}, but for loops (if we entered a loop). + + [None] if we are not inside a loop, [Some] otherwise (and whatever + the kind of function we are translating: it will be [Some] even + though we are synthesizing a forward function). + *) calls : call_info V.FunCallId.Map.t; (** The function calls we encountered so far *) abstractions : (V.abs * texpression list) V.AbstractionId.Map.t; @@ -255,6 +262,11 @@ let ty_to_string (ctx : bs_ctx) (ty : ty) : string = let fmt = PrintPure.ast_to_type_formatter fmt in PrintPure.ty_to_string fmt false ty +let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string = + let fmt = bs_ctx_to_ctx_formatter ctx in + let fmt = Print.PC.ctx_to_rtype_formatter fmt in + Print.PT.rty_to_string fmt ty + let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string = let type_params = def.type_params in let type_decls = ctx.type_context.llbc_type_decls in @@ -829,7 +841,7 @@ let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = (* Return *) (ctx, state_pat) -let fresh_var (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) : +let fresh_var_llbc_ty (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) let id, var_counter = VarId.fresh ctx.var_counter in @@ -843,7 +855,7 @@ let fresh_var (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) : let fresh_named_var_for_symbolic_value (basename : string option) (sv : V.symbolic_value) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) - let ctx, var = fresh_var basename sv.sv_ty ctx in + let ctx, var = fresh_var_llbc_ty basename sv.sv_ty ctx in (* Insert in the map *) let sv_to_var = V.SymbolicValueId.Map.add_strict sv.sv_id var ctx.sv_to_var in (* Update the context *) @@ -981,8 +993,9 @@ let rec typed_value_to_texpression (ctx : bs_ctx) (ectx : C.eval_ctx) log#ldebug (lazy ("typed_value_to_texpression: result:" ^ "\n- input value:\n" - ^ V.show_typed_value v ^ "\n- translated expression:\n" - ^ show_texpression value)); + ^ typed_value_to_string ctx v + ^ "\n- translated expression:\n" + ^ texpression_to_string ctx value)); (* Sanity check *) type_check_texpression ctx value; (* Return *) @@ -1296,7 +1309,21 @@ and translate_panic (ctx : bs_ctx) : texpression = * but it won't be true anymore once we translate individual blocks *) (* If we use a state monad, we need to add a lambda for the state variable *) (* Note that only forward functions return a state *) - let output_ty = mk_simpl_tuple_ty ctx.sg.doutputs in + let output_ty = + if ctx.inside_loop && Option.is_some ctx.bid then + (* We are synthesizing the backward function of a loop body *) + let bid = Option.get ctx.bid in + let back_vars = + T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs) + in + let tys = List.map (fun (v : var) -> v.ty) back_vars in + mk_simpl_tuple_ty tys + else + (* Regular function, or forward function (the forward translation for + a loop has the same return type as the parent function) + *) + mk_simpl_tuple_ty ctx.sg.doutputs + in (* TODO: we should use a [Fail] function *) if ctx.sg.info.effect_info.stateful then (* Create the [Fail] value *) @@ -1373,7 +1400,14 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) (* Group the variables in which we stored the values we need to give back. * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) let backward_outputs = - T.RegionGroupId.Map.find bid ctx.backward_outputs + let map = + if ctx.inside_loop then + (* We are synthesizing a loop body *) + Option.get ctx.loop_backward_outputs + else (* Regular function *) + ctx.backward_outputs + in + T.RegionGroupId.Map.find bid map in let field_values = List.map mk_texpression_from_var backward_outputs in mk_simpl_tuple_texpression field_values @@ -1535,9 +1569,15 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) = log#ldebug (lazy - ("translate_end_abstraction_synth_input:" ^ "\n- eval_ctx:\n" - ^ IU.eval_ctx_to_string ectx ^ "\n- abs:\n" ^ IU.abs_to_string ectx abs - ^ "\n")); + ("translate_end_abstraction_synth_input:" ^ "\n- function: " + ^ Print.name_to_string ctx.fun_decl.name + ^ "\n- rg_id: " + ^ T.RegionGroupId.to_string rg_id + ^ "\n- loop_id: " + ^ Print.option_to_string Pure.LoopId.to_string ctx.loop_id + ^ "\n- eval_ctx:\n" ^ IU.eval_ctx_to_string ectx ^ "\n- abs:\n" + ^ IU.abs_to_string ectx abs ^ "\n")); + (* When we end an input abstraction, this input abstraction gets back * the borrows which it introduced in the context through the input * values: by listing those values, we get the values which are given @@ -1564,12 +1604,38 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) * (v_i) * ]} * *) - (* First, get the given back variables *) + (* First, get the given back variables. + + We don't use the same given back variables if we translate a loop or + the standard body of a function. + *) let given_back_variables = - T.RegionGroupId.Map.find bid ctx.backward_outputs + let map = + if ctx.inside_loop then + (* We are synthesizing a loop body *) + Option.get ctx.loop_backward_outputs + else (* Regular function body *) + ctx.backward_outputs + in + T.RegionGroupId.Map.find bid map in + (* Get the list of values consumed by the abstraction upon ending *) let consumed_values = abs_to_consumed ctx ectx abs in + + log#ldebug + (lazy + ("translate_end_abstraction_synth_input:" + ^ "\n\n- given back variables types:\n" + ^ Print.list_to_string + (fun (v : var) -> ty_to_string ctx v.ty) + given_back_variables + ^ "\n\n- consumed values:\n" + ^ Print.list_to_string + (fun e -> texpression_to_string ctx e ^ " : " ^ ty_to_string ctx e.ty) + consumed_values + ^ "\n")); + (* Group the two lists *) let variables_values = List.combine given_back_variables consumed_values in (* Sanity check: the two lists match (same types) *) @@ -1655,11 +1721,11 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" ^ string_of_int (List.length inputs) ^ "): " - ^ String.concat ", " (List.map show_texpression inputs) + ^ String.concat ", " (List.map (texpression_to_string ctx) inputs) ^ "\n- inst_sg.inputs (" ^ string_of_int (List.length inst_sg.inputs) ^ "): " - ^ String.concat ", " (List.map show_ty inst_sg.inputs))); + ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs))); List.iter (fun (x, ty) -> assert ((x : texpression).ty = ty)) (List.combine inputs inst_sg.inputs); @@ -2272,6 +2338,13 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = loop.input_svalues ^ "\n- filtered svl: " ^ (Print.list_to_string (symbolic_value_to_string ctx)) svl + ^ "\n- rg_to_abs\n:" + ^ T.RegionGroupId.Map.show + (fun (rids, tys) -> + "(" ^ T.RegionId.Set.show rids ^ ", " + ^ Print.list_to_string (rty_to_string ctx) tys + ^ ")") + loop.rg_to_given_back_tys ^ "\n")); let ctx, _ = fresh_vars_for_symbolic_values svl ctx in ctx @@ -2294,6 +2367,39 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = List.map (fun var -> mk_typed_pattern_from_var var None) inputs in + (* Compute the backward outputs *) + let ctx = ref ctx in + let loop_backward_outputs = + T.RegionGroupId.Map.map + (fun (_, tys) -> + (* The types shouldn't contain borrows - we can translate them as forward types *) + let vars = + List.map + (fun ty -> + assert ( + not (TypesUtils.ty_has_borrows !ctx.type_context.types_infos ty)); + (None, ctx_translate_fwd_ty !ctx ty)) + tys + in + (* Introduce fresh variables *) + let ctx', vars = fresh_vars vars !ctx in + ctx := ctx'; + vars) + loop.rg_to_given_back_tys + in + let ctx = !ctx in + + let back_output_tys = + match ctx.bid with + | None -> None + | Some rg_id -> + let back_outputs = + T.RegionGroupId.Map.find rg_id loop_backward_outputs + in + let back_output_tys = List.map (fun (v : var) -> v.ty) back_outputs in + Some back_output_tys + in + (* Add the loop information in the context *) let ctx = assert (not (LoopId.Map.mem loop_id ctx.loops)); @@ -2319,7 +2425,13 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in (* Update the context to translate the function end *) - let ctx_end = { ctx with loop_id = Some loop_id } in + let ctx_end = + { + ctx with + loop_id = Some loop_id; + loop_backward_outputs = Some loop_backward_outputs; + } + in let fun_end = translate_expression loop.end_expr ctx_end in (* Update the context for the loop body *) @@ -2339,10 +2451,13 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = input_state = (if !Config.use_state then Some ctx.state_var else None); inputs; inputs_lvs; + back_output_tys; loop_body; } in - assert (fun_end.ty = loop_body.ty); + (* If we translate forward functions: the return type of a loop body is the + same as the parent function *) + assert (Option.is_some ctx.bid || fun_end.ty = loop_body.ty); let ty = fun_end.ty in { e = loop; ty } @@ -2524,7 +2639,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = ^ "\n- back_state: " ^ String.concat ", " (List.map show_var back_state) ^ "\n- signature.inputs: " - ^ String.concat ", " (List.map show_ty signature.inputs))); + ^ String.concat ", " (List.map (ty_to_string ctx) signature.inputs) + )); assert ( List.for_all (fun (var, ty) -> (var : var).ty = ty) diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index 8c06717a..976b781d 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -163,10 +163,22 @@ let synthesize_forward_end (ctx : Contexts.eval_ctx) let synthesize_loop (loop_id : V.LoopId.id) (input_svalues : V.symbolic_value list) - (fresh_svalues : V.SymbolicValueId.Set.t) (end_expr : expression option) - (loop_expr : expression option) : expression option = + (fresh_svalues : V.SymbolicValueId.Set.t) + (rg_to_given_back_tys : + (T.RegionId.Set.t * T.rty list) T.RegionGroupId.Map.t) + (end_expr : expression option) (loop_expr : expression option) : + expression option = match (end_expr, loop_expr) with | None, None -> None | Some end_expr, Some loop_expr -> - Some (Loop { loop_id; input_svalues; fresh_svalues; end_expr; loop_expr }) + Some + (Loop + { + loop_id; + input_svalues; + fresh_svalues; + rg_to_given_back_tys; + end_expr; + loop_expr; + }) | _ -> raise (Failure "Unreachable") diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 800bac00..66280ed7 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -144,6 +144,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) backward_inputs = T.RegionGroupId.Map.empty; (* Empty for now *) backward_outputs = T.RegionGroupId.Map.empty; + loop_backward_outputs = None; (* Empty for now *) calls; abstractions; diff --git a/tests/coq/misc/Loops.v b/tests/coq/misc/Loops.v index 29f312bf..22c2fd19 100644 --- a/tests/coq/misc/Loops.v +++ b/tests/coq/misc/Loops.v @@ -184,6 +184,79 @@ Definition list_nth_shared_loop_fwd list_nth_shared_loop_loop_fwd T n ls i . +(** [loops::get_elem_mut] *) +Fixpoint get_elem_mut_loop_fwd + (n : nat) (x : usize) (ls : List_t usize) : result usize := + match n with + | O => Fail_ OutOfFuel + | S n0 => + match ls with + | ListCons y tl => + if y s= x then Return y else get_elem_mut_loop_fwd n0 x tl + | ListNil => Fail_ Failure + end + end +. + +(** [loops::get_elem_mut] *) +Definition get_elem_mut_fwd + (n : nat) (slots : vec (List_t usize)) (x : usize) : result usize := + l <- vec_index_mut_fwd (List_t usize) slots (0%usize); + get_elem_mut_loop_fwd n x l +. + +(** [loops::get_elem_mut] *) +Fixpoint get_elem_mut_loop_back + (n : nat) (x : usize) (ls : List_t usize) (ret : usize) : + result (List_t usize) + := + match n with + | O => Fail_ OutOfFuel + | S n0 => + match ls with + | ListCons y tl => + if y s= x + then Return (ListCons ret tl) + else (l <- get_elem_mut_loop_back n0 x tl ret; Return (ListCons y l)) + | ListNil => Fail_ Failure + end + end +. + +(** [loops::get_elem_mut] *) +Definition get_elem_mut_back + (n : nat) (slots : vec (List_t usize)) (x : usize) (ret : usize) : + result (vec (List_t usize)) + := + l <- vec_index_mut_fwd (List_t usize) slots (0%usize); + l0 <- get_elem_mut_loop_back n x l ret; + vec_index_mut_back (List_t usize) slots (0%usize) l0 +. + +(** [loops::get_elem_shared] *) +Fixpoint get_elem_shared_loop_fwd + (n : nat) (x : usize) (v : vec (List_t usize)) (l : List_t usize) + (ls : List_t usize) : + result usize + := + match n with + | O => Fail_ OutOfFuel + | S n0 => + match ls with + | ListCons y tl => + if y s= x then Return y else get_elem_shared_loop_fwd n0 x v l tl + | ListNil => Fail_ Failure + end + end +. + +(** [loops::get_elem_shared] *) +Definition get_elem_shared_fwd + (n : nat) (slots : vec (List_t usize)) (x : usize) : result usize := + l <- vec_index_fwd (List_t usize) slots (0%usize); + get_elem_shared_loop_fwd n x slots l l +. + (** [loops::id_mut] *) Definition id_mut_fwd (T : Type) (ls : List_t T) : result (List_t T) := Return ls diff --git a/tests/fstar/misc/Loops.Clauses.Template.fst b/tests/fstar/misc/Loops.Clauses.Template.fst index 98d0a8ad..3d475d20 100644 --- a/tests/fstar/misc/Loops.Clauses.Template.fst +++ b/tests/fstar/misc/Loops.Clauses.Template.fst @@ -36,6 +36,16 @@ let list_nth_shared_loop_decreases (t : Type0) (ls : list_t t) (i : u32) : nat = admit () +(** [loops::get_elem_mut]: decreases clause *) +unfold +let get_elem_mut_decreases (x : usize) (ls : list_t usize) : nat = admit () + +(** [loops::get_elem_shared]: decreases clause *) +unfold +let get_elem_shared_decreases (x : usize) (v : vec (list_t usize)) + (l : list_t usize) (ls : list_t usize) : nat = + admit () + (** [loops::list_nth_mut_loop_with_id]: decreases clause *) unfold let list_nth_mut_loop_with_id_decreases (t : Type0) (i : u32) (ls : list_t t) : diff --git a/tests/fstar/misc/Loops.Clauses.fst b/tests/fstar/misc/Loops.Clauses.fst index e673d4ff..57849896 100644 --- a/tests/fstar/misc/Loops.Clauses.fst +++ b/tests/fstar/misc/Loops.Clauses.fst @@ -37,6 +37,16 @@ unfold let list_nth_shared_loop_decreases (t : Type0) (ls : list_t t) (i : u32) : list_t t = ls +(** [loops::get_elem_mut]: decreases clause *) +unfold +let get_elem_mut_decreases (x : usize) (ls : list_t usize) : list_t usize = ls + +(** [loops::get_elem_shared]: decreases clause *) +unfold +let get_elem_shared_decreases (x : usize) (v : vec (list_t usize)) + (l : list_t usize) (ls : list_t usize) : list_t usize = + ls + (** [loops::list_nth_mut_loop_with_id]: decreases clause *) unfold let list_nth_mut_loop_with_id_decreases (t : Type0) (i : u32) (ls : list_t t) : diff --git a/tests/fstar/misc/Loops.Funs.fst b/tests/fstar/misc/Loops.Funs.fst index c0aca975..b7dcd045 100644 --- a/tests/fstar/misc/Loops.Funs.fst +++ b/tests/fstar/misc/Loops.Funs.fst @@ -172,6 +172,73 @@ let rec list_nth_shared_loop_loop_fwd let list_nth_shared_loop_fwd (t : Type0) (ls : list_t t) (i : u32) : result t = list_nth_shared_loop_loop_fwd t ls i +(** [loops::get_elem_mut] *) +let rec get_elem_mut_loop_fwd + (x : usize) (ls : list_t usize) : + Tot (result usize) (decreases (get_elem_mut_decreases x ls)) + = + begin match ls with + | ListCons y tl -> if y = x then Return y else get_elem_mut_loop_fwd x tl + | ListNil -> Fail Failure + end + +(** [loops::get_elem_mut] *) +let get_elem_mut_fwd (slots : vec (list_t usize)) (x : usize) : result usize = + begin match vec_index_mut_fwd (list_t usize) slots 0 with + | Fail e -> Fail e + | Return l -> get_elem_mut_loop_fwd x l + end + +(** [loops::get_elem_mut] *) +let rec get_elem_mut_loop_back + (x : usize) (ls : list_t usize) (ret : usize) : + Tot (result (list_t usize)) (decreases (get_elem_mut_decreases x ls)) + = + begin match ls with + | ListCons y tl -> + if y = x + then Return (ListCons ret tl) + else + begin match get_elem_mut_loop_back x tl ret with + | Fail e -> Fail e + | Return l -> Return (ListCons y l) + end + | ListNil -> Fail Failure + end + +(** [loops::get_elem_mut] *) +let get_elem_mut_back + (slots : vec (list_t usize)) (x : usize) (ret : usize) : + result (vec (list_t usize)) + = + begin match vec_index_mut_fwd (list_t usize) slots 0 with + | Fail e -> Fail e + | Return l -> + begin match get_elem_mut_loop_back x l ret with + | Fail e -> Fail e + | Return l0 -> vec_index_mut_back (list_t usize) slots 0 l0 + end + end + +(** [loops::get_elem_shared] *) +let rec get_elem_shared_loop_fwd + (x : usize) (v : vec (list_t usize)) (l : list_t usize) (ls : list_t usize) : + Tot (result usize) (decreases (get_elem_shared_decreases x v l ls)) + = + begin match ls with + | ListCons y tl -> + if y = x then Return y else get_elem_shared_loop_fwd x v l tl + | ListNil -> Fail Failure + end + +(** [loops::get_elem_shared] *) +let get_elem_shared_fwd + (slots : vec (list_t usize)) (x : usize) : result usize = + begin match vec_index_fwd (list_t usize) slots 0 with + | Fail e -> Fail e + | Return l -> get_elem_shared_loop_fwd x slots l l + end + (** [loops::id_mut] *) let id_mut_fwd (t : Type0) (ls : list_t t) : result (list_t t) = Return ls |