summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2022-12-14 18:25:49 +0100
committerSon HO2023-02-03 11:21:46 +0100
commit20332f3faa5e1205602c946f1c7abb9b6660e6f0 (patch)
tree4ac5f36f0487a53f6461885fd25c70c06b6f656c /compiler
parent1a912cbf23c31c95041526c71bbd050bb5ac4e7c (diff)
Add a `Loop` node in the pure AST
Diffstat (limited to 'compiler')
-rw-r--r--compiler/Cps.ml10
-rw-r--r--compiler/Extract.ml3
-rw-r--r--compiler/InterpreterLoops.ml8
-rw-r--r--compiler/PrintPure.ml15
-rw-r--r--compiler/Pure.ml34
-rw-r--r--compiler/PureMicroPasses.ml18
-rw-r--r--compiler/PureTypeCheck.ml5
-rw-r--r--compiler/PureUtils.ml6
-rw-r--r--compiler/SymbolicAst.ml10
-rw-r--r--compiler/SynthesizeSymbolic.ml3
10 files changed, 97 insertions, 15 deletions
diff --git a/compiler/Cps.ml b/compiler/Cps.ml
index 1e5c0e70..1b5164a1 100644
--- a/compiler/Cps.ml
+++ b/compiler/Cps.ml
@@ -17,19 +17,21 @@ type statement_eval_res =
| Return
| Panic
| LoopReturn (** We reached a return statement *while inside a loop* *)
- | EndEnterLoop of V.typed_value list
+ | EndEnterLoop of V.typed_value V.SymbolicValueId.Map.t
(** When we enter a loop, we delegate the end of the function is
synthesized with a call to the loop translation. We use this
evaluation result to transmit the fact that we end evaluation
because we entered a loop.
- We provide the list of values for the translated loop function call.
+ We provide the list of values for the translated loop function call
+ (or to be more precise the input values instantiation).
*)
- | EndContinue of V.typed_value list
+ | EndContinue of V.typed_value V.SymbolicValueId.Map.t
(** For loop translations: we end with a continue (i.e., a recursive call
to the translation for the loop body).
- We provide the list of values for the translated loop function call.
+ We provide the list of values for the translated loop function call
+ (or to be more precise the input values instantiation).
*)
[@@deriving show]
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index fbfcadfd..fa384de6 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -1394,6 +1394,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
| Let (_, _, _, _) -> extract_lets ctx fmt inside e
| Switch (scrut, body) -> extract_Switch ctx fmt inside scrut body
| Meta (_, e) -> extract_texpression ctx fmt inside e
+ | Loop _ ->
+ (* The loop nodes should have been eliminated in {!PureMicroPasses} *)
+ raise (Failure "Unreachable")
(* Extract an application *or* a top-level qualif (function extraction has
* to handle top-level qualifiers, so it seemed more natural to merge the
diff --git a/compiler/InterpreterLoops.ml b/compiler/InterpreterLoops.ml
index 48292968..29e68ca0 100644
--- a/compiler/InterpreterLoops.ml
+++ b/compiler/InterpreterLoops.ml
@@ -3061,12 +3061,8 @@ let match_ctx_with_target (config : C.config) (loop_id : V.LoopId.id)
in
let cc = InterpreterBorrows.end_borrows config new_borrows in
- (* List the loop input values - when iterating over a map, we iterate
- over the keys, in increasing order *)
- let input_values =
- List.map snd
- (V.SymbolicValueId.Map.bindings tgt_to_src_maps.sid_to_value_map)
- in
+ (* Compute the loop input values *)
+ let input_values = tgt_to_src_maps.sid_to_value_map in
(* Continue *)
cc
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 152e29c0..c83858b3 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -508,6 +508,9 @@ let rec texpression_to_string (fmt : ast_formatter) (inside : bool)
| Switch (scrutinee, body) ->
let e = switch_to_string fmt indent indent_incr scrutinee body in
if inside then "(" ^ e ^ ")" else e
+ | Loop loop ->
+ let e = loop_to_string fmt indent indent_incr loop in
+ if inside then "(" ^ e ^ ")" else e
| Meta (meta, e) -> (
let meta_s = meta_to_string fmt meta in
let e = texpression_to_string fmt inside indent indent_incr e in
@@ -613,6 +616,18 @@ and switch_to_string (fmt : ast_formatter) (indent : string)
let branches = List.map branch_to_string branches in
"match " ^ scrut ^ " with\n" ^ String.concat "\n" branches
+and loop_to_string (fmt : ast_formatter) (indent : string)
+ (indent_incr : string) (loop : loop) : string =
+ let indent1 = indent ^ indent_incr in
+ let fun_end =
+ texpression_to_string fmt false indent1 indent_incr loop.fun_end
+ in
+ let loop_body =
+ texpression_to_string fmt false indent1 indent_incr loop.loop_body
+ in
+ "loop {\n" ^ indent1 ^ "fun_end: " ^ fun_end ^ "\n" ^ indent1 ^ "loop_body:"
+ ^ loop_body ^ "\n" ^ indent ^ "}"
+
and meta_to_string (fmt : ast_formatter) (meta : meta) : string =
let meta =
match meta with
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 9972d539..e6106eed 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -20,6 +20,8 @@ module GlobalDeclId = A.GlobalDeclId
*)
module LoopId = IdGen ()
+type loop_id = LoopId.id [@@deriving show, ord]
+
(** We give an identifier to every phase of the synthesis (forward, backward
for group of regions 0, etc.) *)
module SynthPhaseId = IdGen ()
@@ -365,6 +367,7 @@ class ['self] iter_expression_base =
method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> ()
method visit_var_id : 'env -> var_id -> unit = fun _ _ -> ()
method visit_qualif : 'env -> qualif -> unit = fun _ _ -> ()
+ method visit_loop_id : 'env -> loop_id -> unit = fun _ _ -> ()
end
(** Ancestor for {!map_expression} visitor *)
@@ -377,6 +380,7 @@ class ['self] map_expression_base =
method visit_var_id : 'env -> var_id -> var_id = fun _ x -> x
method visit_qualif : 'env -> qualif -> qualif = fun _ x -> x
+ method visit_loop_id : 'env -> loop_id -> loop_id = fun _ x -> x
end
(** Ancestor for {!reduce_expression} visitor *)
@@ -389,6 +393,7 @@ class virtual ['self] reduce_expression_base =
method visit_var_id : 'env -> var_id -> 'a = fun _ _ -> self#zero
method visit_qualif : 'env -> qualif -> 'a = fun _ _ -> self#zero
+ method visit_loop_id : 'env -> loop_id -> 'a = fun _ _ -> self#zero
end
(** Ancestor for {!mapreduce_expression} visitor *)
@@ -404,6 +409,9 @@ class virtual ['self] mapreduce_expression_base =
method visit_qualif : 'env -> qualif -> qualif * 'a =
fun _ x -> (x, self#zero)
+
+ method visit_loop_id : 'env -> loop_id -> loop_id * 'a =
+ fun _ x -> (x, self#zero)
end
(** **Rk.:** here, {!expression} is not at all equivalent to the expressions
@@ -464,10 +472,32 @@ type expression =
]}
*)
| Switch of texpression * switch_body
+ | Loop of loop (** See the comments for {!loop} *)
| Meta of (meta[@opaque]) * texpression (** Meta-information *)
and switch_body = If of texpression * texpression | Match of match_branch list
and match_branch = { pat : typed_pattern; branch : texpression }
+
+(** In {!SymbolicToPure}, whenever we encounter a loop we insert a {!loop}
+ node, which contains the end of the function (i.e., the call to the
+ loop function) as well as the *body* of the loop translation (to be
+ more precise, the bodies of the loop forward and backward function).
+ We later split the function definition in {!PureMicroPasses}, to
+ remove this node.
+
+ Note that the loop body is a forward body if the function is
+ a forward function, and a backward body (for the corresponding region
+ group) if the function is a backward function.
+ *)
+and loop = {
+ fun_end : texpression;
+ loop_id : loop_id;
+ inputs : var list;
+ inputs_lvs : typed_pattern list;
+ (** The inputs seen as patterns. See {!fun_body}. *)
+ loop_body : texpression;
+}
+
and texpression = { e : expression; ty : ty }
(** Meta-value (converted to an expression). It is important that the content
@@ -634,6 +664,10 @@ type fun_body = {
type fun_decl = {
def_id : FunDeclId.id;
num_loops : int;
+ (** The number of loops in the parent forward function (basically the number
+ of loops appearing in the original Rust functions, unless some loops are
+ duplicated because we don't join the control-flow after a branching)
+ *)
loop_id : LoopId.id option;
(** [Some] if this definition was generated for a loop *)
back_id : T.RegionGroupId.id option;
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index a27b9d95..87ab4609 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -376,6 +376,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
| Qualif _ -> (* nothing to do *) (ctx, e.e)
| Let (monadic, lb, re, e) -> update_let monadic lb re e ctx
| Switch (scrut, body) -> update_switch_body scrut body ctx
+ | Loop loop -> update_loop loop ctx
| Meta (meta, e) -> update_meta meta e ctx
in
(ctx, { e; ty })
@@ -430,6 +431,15 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
in
(ctx, Switch (scrut, body))
(* *)
+ and update_loop (loop : loop) (ctx : pn_ctx) : pn_ctx * expression =
+ let { fun_end; loop_id; inputs; inputs_lvs; loop_body } = loop in
+ let ctx, fun_end = update_texpression fun_end ctx in
+ let ctx, loop_body = update_texpression loop_body ctx in
+ let inputs = List.map (fun input -> update_var ctx input None) inputs in
+ let inputs_lvs = List.map (update_typed_pattern ctx) inputs_lvs in
+ let loop = { fun_end; loop_id; inputs; inputs_lvs; loop_body } in
+ (ctx, Loop loop)
+ (* *)
and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) :
pn_ctx * expression =
let ctx =
@@ -706,6 +716,9 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
(* Note that this case includes functions without arguments *)
fun () -> false
| Meta (_, e) -> self#visit_texpression env e
+ | Loop loop ->
+ (* We only visit the *function end* *)
+ self#visit_texpression env loop.fun_end
| Switch (_, body) -> self#visit_switch_body env body
method! visit_switch_body env body =
@@ -819,6 +832,11 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
dont_filter ()
else (* There are used variables: don't filter *)
dont_filter ()
+ | Loop loop ->
+ (* We take care to ignore the varset computed on the *loop body* *)
+ let fun_end, s = self#visit_texpression () loop.fun_end in
+ let loop_body, _ = self#visit_texpression () loop.loop_body in
+ (Loop { loop with fun_end; loop_body }, s)
end
in
(* We filter only inside of transparent (i.e., non-opaque) definitions *)
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index fe4fb841..78fd077a 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -184,6 +184,11 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
check_texpression ctx br.branch
in
List.iter check_branch branches)
+ | Loop loop ->
+ assert (loop.fun_end.ty = e.ty);
+ assert (loop.loop_body.ty = e.ty);
+ check_texpression ctx loop.fun_end;
+ check_texpression ctx loop.loop_body
| Meta (_, e_next) ->
assert (e_next.ty = e.ty);
check_texpression ctx e_next
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index da15d635..0e245f35 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -113,6 +113,9 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
(** We use this to check whether we need to add parentheses around expressions.
We only look for outer monadic let-bindings.
This is used when printing the branches of [if ... then ... else ...].
+
+ Rem.: this function will *fail* if there are {!Loop} nodes (you should call
+ it on an expression where those nodes have been eliminated).
*)
let rec let_group_requires_parentheses (e : texpression) : bool =
match e.e with
@@ -121,6 +124,9 @@ let rec let_group_requires_parentheses (e : texpression) : bool =
if monadic then true else let_group_requires_parentheses next_e
| Switch (_, _) -> false
| Meta (_, next_e) -> let_group_requires_parentheses next_e
+ | Loop _ ->
+ (* Should have been eliminated *)
+ raise (Failure "Unreachable")
let is_var (e : texpression) : bool =
match e.e with Var _ -> true | _ -> false
diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml
index 60b45d99..79865e73 100644
--- a/compiler/SymbolicAst.ml
+++ b/compiler/SymbolicAst.ml
@@ -101,7 +101,9 @@ type expression =
to prettify the generated code.
*)
| ForwardEnd of
- V.typed_value list option * expression * expression T.RegionGroupId.Map.t
+ V.typed_value V.SymbolicValueId.Map.t option
+ * expression
+ * expression T.RegionGroupId.Map.t
(** We use this delimiter to indicate at which point we switch to the
generation of code specific to the backward function(s). This allows
us in particular to factor the work out: we don't need to replay the
@@ -112,9 +114,9 @@ type expression =
function, the map from region group ids to expressions gives the end
of the translation for the backward functions.
- The optional list of input values are input values for loops: upon
- entering a loop, in the translation we call the loop translation
- function, which takes care of the end of the execution.
+ The optional map from symbolic values to input values are input values
+ for loops: upon entering a loop, in the translation we call the loop
+ translation function, which takes care of the end of the execution.
*)
| Loop of loop (** Loop *)
| Meta of meta * expression (** Meta information *)
diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml
index a7f84f61..4bb6529b 100644
--- a/compiler/SynthesizeSymbolic.ml
+++ b/compiler/SynthesizeSymbolic.ml
@@ -156,7 +156,8 @@ let synthesize_assertion (ctx : Contexts.eval_ctx) (v : V.typed_value)
(e : expression option) =
Option.map (fun e -> Assertion (ctx, v, e)) e
-let synthesize_forward_end (loop_input_values : V.typed_value list option)
+let synthesize_forward_end
+ (loop_input_values : V.typed_value V.SymbolicValueId.Map.t option)
(e : expression) (el : expression T.RegionGroupId.Map.t) =
Some (ForwardEnd (loop_input_values, e, el))