diff options
author | Son Ho | 2022-12-14 18:25:49 +0100 |
---|---|---|
committer | Son HO | 2023-02-03 11:21:46 +0100 |
commit | 20332f3faa5e1205602c946f1c7abb9b6660e6f0 (patch) | |
tree | 4ac5f36f0487a53f6461885fd25c70c06b6f656c | |
parent | 1a912cbf23c31c95041526c71bbd050bb5ac4e7c (diff) |
Add a `Loop` node in the pure AST
-rw-r--r-- | compiler/Cps.ml | 10 | ||||
-rw-r--r-- | compiler/Extract.ml | 3 | ||||
-rw-r--r-- | compiler/InterpreterLoops.ml | 8 | ||||
-rw-r--r-- | compiler/PrintPure.ml | 15 | ||||
-rw-r--r-- | compiler/Pure.ml | 34 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 18 | ||||
-rw-r--r-- | compiler/PureTypeCheck.ml | 5 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 6 | ||||
-rw-r--r-- | compiler/SymbolicAst.ml | 10 | ||||
-rw-r--r-- | compiler/SynthesizeSymbolic.ml | 3 |
10 files changed, 97 insertions, 15 deletions
diff --git a/compiler/Cps.ml b/compiler/Cps.ml index 1e5c0e70..1b5164a1 100644 --- a/compiler/Cps.ml +++ b/compiler/Cps.ml @@ -17,19 +17,21 @@ type statement_eval_res = | Return | Panic | LoopReturn (** We reached a return statement *while inside a loop* *) - | EndEnterLoop of V.typed_value list + | EndEnterLoop of V.typed_value V.SymbolicValueId.Map.t (** When we enter a loop, we delegate the end of the function is synthesized with a call to the loop translation. We use this evaluation result to transmit the fact that we end evaluation because we entered a loop. - We provide the list of values for the translated loop function call. + We provide the list of values for the translated loop function call + (or to be more precise the input values instantiation). *) - | EndContinue of V.typed_value list + | EndContinue of V.typed_value V.SymbolicValueId.Map.t (** For loop translations: we end with a continue (i.e., a recursive call to the translation for the loop body). - We provide the list of values for the translated loop function call. + We provide the list of values for the translated loop function call + (or to be more precise the input values instantiation). *) [@@deriving show] diff --git a/compiler/Extract.ml b/compiler/Extract.ml index fbfcadfd..fa384de6 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1394,6 +1394,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | Let (_, _, _, _) -> extract_lets ctx fmt inside e | Switch (scrut, body) -> extract_Switch ctx fmt inside scrut body | Meta (_, e) -> extract_texpression ctx fmt inside e + | Loop _ -> + (* The loop nodes should have been eliminated in {!PureMicroPasses} *) + raise (Failure "Unreachable") (* Extract an application *or* a top-level qualif (function extraction has * to handle top-level qualifiers, so it seemed more natural to merge the diff --git a/compiler/InterpreterLoops.ml b/compiler/InterpreterLoops.ml index 48292968..29e68ca0 100644 --- a/compiler/InterpreterLoops.ml +++ b/compiler/InterpreterLoops.ml @@ -3061,12 +3061,8 @@ let match_ctx_with_target (config : C.config) (loop_id : V.LoopId.id) in let cc = InterpreterBorrows.end_borrows config new_borrows in - (* List the loop input values - when iterating over a map, we iterate - over the keys, in increasing order *) - let input_values = - List.map snd - (V.SymbolicValueId.Map.bindings tgt_to_src_maps.sid_to_value_map) - in + (* Compute the loop input values *) + let input_values = tgt_to_src_maps.sid_to_value_map in (* Continue *) cc diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 152e29c0..c83858b3 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -508,6 +508,9 @@ let rec texpression_to_string (fmt : ast_formatter) (inside : bool) | Switch (scrutinee, body) -> let e = switch_to_string fmt indent indent_incr scrutinee body in if inside then "(" ^ e ^ ")" else e + | Loop loop -> + let e = loop_to_string fmt indent indent_incr loop in + if inside then "(" ^ e ^ ")" else e | Meta (meta, e) -> ( let meta_s = meta_to_string fmt meta in let e = texpression_to_string fmt inside indent indent_incr e in @@ -613,6 +616,18 @@ and switch_to_string (fmt : ast_formatter) (indent : string) let branches = List.map branch_to_string branches in "match " ^ scrut ^ " with\n" ^ String.concat "\n" branches +and loop_to_string (fmt : ast_formatter) (indent : string) + (indent_incr : string) (loop : loop) : string = + let indent1 = indent ^ indent_incr in + let fun_end = + texpression_to_string fmt false indent1 indent_incr loop.fun_end + in + let loop_body = + texpression_to_string fmt false indent1 indent_incr loop.loop_body + in + "loop {\n" ^ indent1 ^ "fun_end: " ^ fun_end ^ "\n" ^ indent1 ^ "loop_body:" + ^ loop_body ^ "\n" ^ indent ^ "}" + and meta_to_string (fmt : ast_formatter) (meta : meta) : string = let meta = match meta with diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 9972d539..e6106eed 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -20,6 +20,8 @@ module GlobalDeclId = A.GlobalDeclId *) module LoopId = IdGen () +type loop_id = LoopId.id [@@deriving show, ord] + (** We give an identifier to every phase of the synthesis (forward, backward for group of regions 0, etc.) *) module SynthPhaseId = IdGen () @@ -365,6 +367,7 @@ class ['self] iter_expression_base = method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> () method visit_var_id : 'env -> var_id -> unit = fun _ _ -> () method visit_qualif : 'env -> qualif -> unit = fun _ _ -> () + method visit_loop_id : 'env -> loop_id -> unit = fun _ _ -> () end (** Ancestor for {!map_expression} visitor *) @@ -377,6 +380,7 @@ class ['self] map_expression_base = method visit_var_id : 'env -> var_id -> var_id = fun _ x -> x method visit_qualif : 'env -> qualif -> qualif = fun _ x -> x + method visit_loop_id : 'env -> loop_id -> loop_id = fun _ x -> x end (** Ancestor for {!reduce_expression} visitor *) @@ -389,6 +393,7 @@ class virtual ['self] reduce_expression_base = method visit_var_id : 'env -> var_id -> 'a = fun _ _ -> self#zero method visit_qualif : 'env -> qualif -> 'a = fun _ _ -> self#zero + method visit_loop_id : 'env -> loop_id -> 'a = fun _ _ -> self#zero end (** Ancestor for {!mapreduce_expression} visitor *) @@ -404,6 +409,9 @@ class virtual ['self] mapreduce_expression_base = method visit_qualif : 'env -> qualif -> qualif * 'a = fun _ x -> (x, self#zero) + + method visit_loop_id : 'env -> loop_id -> loop_id * 'a = + fun _ x -> (x, self#zero) end (** **Rk.:** here, {!expression} is not at all equivalent to the expressions @@ -464,10 +472,32 @@ type expression = ]} *) | Switch of texpression * switch_body + | Loop of loop (** See the comments for {!loop} *) | Meta of (meta[@opaque]) * texpression (** Meta-information *) and switch_body = If of texpression * texpression | Match of match_branch list and match_branch = { pat : typed_pattern; branch : texpression } + +(** In {!SymbolicToPure}, whenever we encounter a loop we insert a {!loop} + node, which contains the end of the function (i.e., the call to the + loop function) as well as the *body* of the loop translation (to be + more precise, the bodies of the loop forward and backward function). + We later split the function definition in {!PureMicroPasses}, to + remove this node. + + Note that the loop body is a forward body if the function is + a forward function, and a backward body (for the corresponding region + group) if the function is a backward function. + *) +and loop = { + fun_end : texpression; + loop_id : loop_id; + inputs : var list; + inputs_lvs : typed_pattern list; + (** The inputs seen as patterns. See {!fun_body}. *) + loop_body : texpression; +} + and texpression = { e : expression; ty : ty } (** Meta-value (converted to an expression). It is important that the content @@ -634,6 +664,10 @@ type fun_body = { type fun_decl = { def_id : FunDeclId.id; num_loops : int; + (** The number of loops in the parent forward function (basically the number + of loops appearing in the original Rust functions, unless some loops are + duplicated because we don't join the control-flow after a branching) + *) loop_id : LoopId.id option; (** [Some] if this definition was generated for a loop *) back_id : T.RegionGroupId.id option; diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index a27b9d95..87ab4609 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -376,6 +376,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = | Qualif _ -> (* nothing to do *) (ctx, e.e) | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx | Switch (scrut, body) -> update_switch_body scrut body ctx + | Loop loop -> update_loop loop ctx | Meta (meta, e) -> update_meta meta e ctx in (ctx, { e; ty }) @@ -430,6 +431,15 @@ let compute_pretty_names (def : fun_decl) : fun_decl = in (ctx, Switch (scrut, body)) (* *) + and update_loop (loop : loop) (ctx : pn_ctx) : pn_ctx * expression = + let { fun_end; loop_id; inputs; inputs_lvs; loop_body } = loop in + let ctx, fun_end = update_texpression fun_end ctx in + let ctx, loop_body = update_texpression loop_body ctx in + let inputs = List.map (fun input -> update_var ctx input None) inputs in + let inputs_lvs = List.map (update_typed_pattern ctx) inputs_lvs in + let loop = { fun_end; loop_id; inputs; inputs_lvs; loop_body } in + (ctx, Loop loop) + (* *) and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = let ctx = @@ -706,6 +716,9 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (* Note that this case includes functions without arguments *) fun () -> false | Meta (_, e) -> self#visit_texpression env e + | Loop loop -> + (* We only visit the *function end* *) + self#visit_texpression env loop.fun_end | Switch (_, body) -> self#visit_switch_body env body method! visit_switch_body env body = @@ -819,6 +832,11 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) dont_filter () else (* There are used variables: don't filter *) dont_filter () + | Loop loop -> + (* We take care to ignore the varset computed on the *loop body* *) + let fun_end, s = self#visit_texpression () loop.fun_end in + let loop_body, _ = self#visit_texpression () loop.loop_body in + (Loop { loop with fun_end; loop_body }, s) end in (* We filter only inside of transparent (i.e., non-opaque) definitions *) diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index fe4fb841..78fd077a 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -184,6 +184,11 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = check_texpression ctx br.branch in List.iter check_branch branches) + | Loop loop -> + assert (loop.fun_end.ty = e.ty); + assert (loop.loop_body.ty = e.ty); + check_texpression ctx loop.fun_end; + check_texpression ctx loop.loop_body | Meta (_, e_next) -> assert (e_next.ty = e.ty); check_texpression ctx e_next diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index da15d635..0e245f35 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -113,6 +113,9 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : (** We use this to check whether we need to add parentheses around expressions. We only look for outer monadic let-bindings. This is used when printing the branches of [if ... then ... else ...]. + + Rem.: this function will *fail* if there are {!Loop} nodes (you should call + it on an expression where those nodes have been eliminated). *) let rec let_group_requires_parentheses (e : texpression) : bool = match e.e with @@ -121,6 +124,9 @@ let rec let_group_requires_parentheses (e : texpression) : bool = if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false | Meta (_, next_e) -> let_group_requires_parentheses next_e + | Loop _ -> + (* Should have been eliminated *) + raise (Failure "Unreachable") let is_var (e : texpression) : bool = match e.e with Var _ -> true | _ -> false diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 60b45d99..79865e73 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -101,7 +101,9 @@ type expression = to prettify the generated code. *) | ForwardEnd of - V.typed_value list option * expression * expression T.RegionGroupId.Map.t + V.typed_value V.SymbolicValueId.Map.t option + * expression + * expression T.RegionGroupId.Map.t (** We use this delimiter to indicate at which point we switch to the generation of code specific to the backward function(s). This allows us in particular to factor the work out: we don't need to replay the @@ -112,9 +114,9 @@ type expression = function, the map from region group ids to expressions gives the end of the translation for the backward functions. - The optional list of input values are input values for loops: upon - entering a loop, in the translation we call the loop translation - function, which takes care of the end of the execution. + The optional map from symbolic values to input values are input values + for loops: upon entering a loop, in the translation we call the loop + translation function, which takes care of the end of the execution. *) | Loop of loop (** Loop *) | Meta of meta * expression (** Meta information *) diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index a7f84f61..4bb6529b 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -156,7 +156,8 @@ let synthesize_assertion (ctx : Contexts.eval_ctx) (v : V.typed_value) (e : expression option) = Option.map (fun e -> Assertion (ctx, v, e)) e -let synthesize_forward_end (loop_input_values : V.typed_value list option) +let synthesize_forward_end + (loop_input_values : V.typed_value V.SymbolicValueId.Map.t option) (e : expression) (el : expression T.RegionGroupId.Map.t) = Some (ForwardEnd (loop_input_values, e, el)) |