summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-02-03 20:05:35 +0100
committerSon Ho2022-02-03 20:05:35 +0100
commitf674791b2c89f3ed0def6c9cf543bb48410c7229 (patch)
tree8eebc79edb7230d3379df086db8a706a812aa05e
parent5eacfc7cdbe99f401d6cf925cbb50d63c3a780c3 (diff)
Implement extraction of switch int and make extract_texpression return
unit instead of [extraction_ctx]
-rw-r--r--src/ExtractToFStar.ml174
-rw-r--r--src/Pure.ml3
-rw-r--r--src/PureToExtract.ml14
3 files changed, 110 insertions, 81 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 689bd797..e323d156 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -74,10 +74,9 @@ let fstar_names_map_init =
assumed_functions = fstar_assumed_functions;
}
-let fstar_extract_unop (ctx : extraction_ctx) (fmt : F.formatter)
- (extract_expr :
- extraction_ctx -> F.formatter -> bool -> texpression -> extraction_ctx)
- (inside : bool) (unop : unop) (arg : texpression) : extraction_ctx =
+let fstar_extract_unop (extract_expr : bool -> texpression -> unit)
+ (fmt : F.formatter) (inside : bool) (unop : unop) (arg : texpression) : unit
+ =
let unop =
match unop with
| Not -> "not"
@@ -86,50 +85,42 @@ let fstar_extract_unop (ctx : extraction_ctx) (fmt : F.formatter)
if inside then F.pp_print_string fmt "(";
F.pp_print_string fmt unop;
F.pp_print_space fmt ();
- let ctx = extract_expr ctx fmt true arg in
- if inside then F.pp_print_string fmt ")";
- ctx
+ extract_expr true arg;
+ if inside then F.pp_print_string fmt ")"
-let fstar_extract_binop (ctx : extraction_ctx) (fmt : F.formatter)
- (extract_expr :
- extraction_ctx -> F.formatter -> bool -> texpression -> extraction_ctx)
- (inside : bool) (binop : E.binop) (int_ty : integer_type)
- (arg0 : texpression) (arg1 : texpression) : extraction_ctx =
+let fstar_extract_binop (extract_expr : bool -> texpression -> unit)
+ (fmt : F.formatter) (inside : bool) (binop : E.binop)
+ (int_ty : integer_type) (arg0 : texpression) (arg1 : texpression) : unit =
if inside then F.pp_print_string fmt "(";
- let ctx =
- match binop with
- | Eq ->
- let ctx = extract_expr ctx fmt false arg0 in
- F.pp_print_space fmt ();
- F.pp_print_string fmt "=";
- F.pp_print_space fmt ();
- let ctx = extract_expr ctx fmt false arg1 in
- ctx
- | _ ->
- let binop =
- match binop with
- | Eq -> failwith "Unreachable"
- | Lt -> "lt"
- | Le -> "le"
- | Ne -> "ne"
- | Ge -> "ge"
- | Gt -> "gt"
- | Div -> "div"
- | Rem -> "rem"
- | Add -> "add"
- | Sub -> "sub"
- | Mul -> "mul"
- | BitXor | BitAnd | BitOr | Shl | Shr -> raise Unimplemented
- in
- F.pp_print_string fmt (fstar_int_name int_ty ^ "_" ^ binop);
- F.pp_print_space fmt ();
- let ctx = extract_expr ctx fmt false arg0 in
- F.pp_print_space fmt ();
- let ctx = extract_expr ctx fmt false arg1 in
- ctx
- in
- if inside then F.pp_print_string fmt ")";
- ctx
+ (match binop with
+ | Eq ->
+ extract_expr false arg0;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "=";
+ F.pp_print_space fmt ();
+ extract_expr false arg1
+ | _ ->
+ let binop =
+ match binop with
+ | Eq -> failwith "Unreachable"
+ | Lt -> "lt"
+ | Le -> "le"
+ | Ne -> "ne"
+ | Ge -> "ge"
+ | Gt -> "gt"
+ | Div -> "div"
+ | Rem -> "rem"
+ | Add -> "add"
+ | Sub -> "sub"
+ | Mul -> "mul"
+ | BitXor | BitAnd | BitOr | Shl | Shr -> raise Unimplemented
+ in
+ F.pp_print_string fmt (fstar_int_name int_ty ^ "_" ^ binop);
+ F.pp_print_space fmt ();
+ extract_expr false arg0;
+ F.pp_print_space fmt ();
+ extract_expr false arg1);
+ if inside then F.pp_print_string fmt ")"
(**
* [ctx]: we use the context to lookup type definitions, to retrieve type names.
@@ -694,16 +685,19 @@ let rec extract_typed_rvalue (ctx : extraction_ctx) (fmt : F.formatter)
(** [inside]: see [extract_ty] *)
let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
- (inside : bool) (e : texpression) : extraction_ctx =
+ (inside : bool) (e : texpression) : unit =
match e.e with
- | Value (rv, _) -> extract_typed_rvalue ctx fmt inside rv
+ | Value (rv, _) ->
+ let _ = extract_typed_rvalue ctx fmt inside rv in
+ ()
| Call call -> (
match (call.func, call.args) with
| Unop unop, [ arg ] ->
- ctx.fmt.extract_unop ctx fmt extract_texpression inside unop arg
+ ctx.fmt.extract_unop (extract_texpression ctx fmt) fmt inside unop arg
| Binop (binop, int_ty), [ arg0; arg1 ] ->
- ctx.fmt.extract_binop ctx fmt extract_texpression inside binop int_ty
- arg0 arg1
+ ctx.fmt.extract_binop
+ (extract_texpression ctx fmt)
+ fmt inside binop int_ty arg0 arg1
| Regular (fun_id, rg_id), _ ->
if inside then F.pp_print_string fmt "(";
(* Open a box for the function call *)
@@ -718,16 +712,11 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
extract_ty ctx fmt true ty)
call.type_params;
(* Print the input values *)
- let ctx =
- List.fold_left
- (fun ctx ve -> extract_texpression ctx fmt true ve)
- ctx call.args
- in
+ List.iter (fun ve -> extract_texpression ctx fmt true ve) call.args;
(* Close the box for the function call *)
F.pp_close_box fmt ();
(* Return *)
- if inside then F.pp_print_string fmt ")";
- ctx
+ if inside then F.pp_print_string fmt ")"
| _ -> failwith "Unreachable")
| Let (monadic, lv, re, next_e) ->
(* Open a box for the let-binding *)
@@ -740,7 +729,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ();
F.pp_print_string fmt "<--";
F.pp_print_space fmt ();
- let ctx = extract_texpression ctx fmt false re in
+ extract_texpression ctx fmt false re;
F.pp_print_string fmt ";";
ctx)
else (
@@ -750,7 +739,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ();
F.pp_print_string fmt "=";
F.pp_print_space fmt ();
- let ctx = extract_texpression ctx fmt false re in
+ extract_texpression ctx fmt false re;
F.pp_print_space fmt ();
F.pp_print_string fmt "in";
ctx)
@@ -768,11 +757,11 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_open_hovbox fmt ctx.indent_incr;
F.pp_print_string fmt "if";
F.pp_print_space fmt ();
- let ctx = extract_texpression ctx fmt false scrut in
+ extract_texpression ctx fmt false scrut;
(* Close the box for the `if` *)
F.pp_close_box fmt ();
- let extract_branch (is_then : bool) (e_branch : texpression) :
- extraction_ctx =
+ (* Extract the branches *)
+ let extract_branch (is_then : bool) (e_branch : texpression) : unit =
F.pp_print_space fmt ();
(* Open a box for the branch *)
F.pp_open_hovbox fmt ctx.indent_incr;
@@ -781,19 +770,62 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ();
let parenth = PureUtils.expression_requires_parentheses e_branch in
if parenth then F.pp_print_string fmt "(";
- let ctx = extract_texpression ctx fmt false e_branch in
+ extract_texpression ctx fmt false e_branch;
if parenth then F.pp_print_string fmt ")";
(* Close the box for the branch *)
- F.pp_close_box fmt ();
- ctx
+ F.pp_close_box fmt ()
in
- let ctx = extract_branch false e_then in
- let ctx = extract_branch false e_else in
+ extract_branch false e_then;
+ extract_branch false e_else;
(* Close the box for the whole `if ... then ... else ...` *)
+ F.pp_close_box fmt ()
+ | SwitchInt (_, branches, otherwise) ->
+ (* Open a box for the whole match *)
+ F.pp_open_hvbox fmt 0;
+ (* Open a box for the `match ... with` *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ (* Print the `match ... with` *)
+ F.pp_print_string fmt "begin match";
+ extract_texpression ctx fmt false scrut;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "with";
+ (* Close the box for the `match ... with` *)
F.pp_close_box fmt ();
- (* Return *)
- ctx
- | SwitchInt (_, branches, otherwise) -> raise Unimplemented
+
+ (* Extract the branches *)
+ let extract_branch (ctx : extraction_ctx) (sv : scalar_value option)
+ (e_branch : texpression) : unit =
+ F.pp_print_space fmt ();
+ (* Open a box for the branch *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ F.pp_print_string fmt "|";
+ (* Print the pattern *)
+ F.pp_print_space fmt ();
+ (match sv with
+ | Some sv -> ctx.fmt.extract_constant_value fmt false (V.Scalar sv)
+ | None -> F.pp_print_string fmt "_");
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "->";
+ (* Print the branch itself *)
+ F.pp_print_space fmt ();
+ extract_texpression ctx fmt false e_branch;
+ (* Close the box for the branch *)
+ F.pp_close_box fmt ()
+ in
+
+ let all_branches =
+ List.map (fun (sv, br) -> (Some sv, br)) branches
+ in
+ let all_branches = List.append all_branches [ (None, otherwise) ] in
+ let ctx =
+ List.iter (fun (sv, br) -> extract_branch ctx sv br) all_branches
+ in
+
+ (* End the match *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "end";
+ (* Close the box for the whole match *)
+ F.pp_close_box fmt ()
| Match branches -> raise Unimplemented)
| Meta (_, e) -> extract_texpression ctx fmt inside e
diff --git a/src/Pure.ml b/src/Pure.ml
index fd3eb03f..c1dbaa13 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -505,7 +505,8 @@ and switch_body =
| If of texpression * texpression
| SwitchInt of integer_type * (scalar_value * texpression) list * texpression
| Match of match_branch list
-(* TODO: we could (should?) merge SwitchInt and Match *)
+(* TODO: merge SwitchInt and Match. In order to do that,
+ * we need to add constants to lvalues. *)
and match_branch = { pat : typed_lvalue; branch : texpression }
diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml
index 54b59141..47097d7d 100644
--- a/src/PureToExtract.ml
+++ b/src/PureToExtract.ml
@@ -31,7 +31,7 @@ module StringMap = Collections.MakeMap (Collections.OrderedString)
type name = Identifiers.name
-type 'ctx g_formatter = {
+type formatter = {
bool_name : string;
char_name : string;
int_name : integer_type -> string;
@@ -106,13 +106,12 @@ type 'ctx g_formatter = {
if it is made of an application (ex.: `U32 3`)
*)
extract_unop :
- 'ctx ->
+ (bool -> texpression -> unit) ->
F.formatter ->
- ('ctx -> F.formatter -> bool -> texpression -> 'ctx) ->
bool ->
unop ->
texpression ->
- 'ctx;
+ unit;
(** Format a unary operation
Inputs:
@@ -124,15 +123,14 @@ type 'ctx g_formatter = {
- argument
*)
extract_binop :
- 'ctx ->
+ (bool -> texpression -> unit) ->
F.formatter ->
- ('ctx -> F.formatter -> bool -> texpression -> 'ctx) ->
bool ->
E.binop ->
integer_type ->
texpression ->
texpression ->
- 'ctx;
+ unit;
(** Format a binary operation
Inputs:
@@ -298,8 +296,6 @@ type extraction_ctx = {
functions, etc.
*)
-and formatter = extraction_ctx g_formatter
-
let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx =
(* TODO : nice debugging message if collision *)
let names_map = names_map_add id name ctx.names_map in