From f674791b2c89f3ed0def6c9cf543bb48410c7229 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 3 Feb 2022 20:05:35 +0100 Subject: Implement extraction of switch int and make extract_texpression return unit instead of [extraction_ctx] --- src/ExtractToFStar.ml | 174 ++++++++++++++++++++++++++++++-------------------- src/Pure.ml | 3 +- src/PureToExtract.ml | 14 ++-- 3 files changed, 110 insertions(+), 81 deletions(-) (limited to 'src') diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 689bd797..e323d156 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -74,10 +74,9 @@ let fstar_names_map_init = assumed_functions = fstar_assumed_functions; } -let fstar_extract_unop (ctx : extraction_ctx) (fmt : F.formatter) - (extract_expr : - extraction_ctx -> F.formatter -> bool -> texpression -> extraction_ctx) - (inside : bool) (unop : unop) (arg : texpression) : extraction_ctx = +let fstar_extract_unop (extract_expr : bool -> texpression -> unit) + (fmt : F.formatter) (inside : bool) (unop : unop) (arg : texpression) : unit + = let unop = match unop with | Not -> "not" @@ -86,50 +85,42 @@ let fstar_extract_unop (ctx : extraction_ctx) (fmt : F.formatter) if inside then F.pp_print_string fmt "("; F.pp_print_string fmt unop; F.pp_print_space fmt (); - let ctx = extract_expr ctx fmt true arg in - if inside then F.pp_print_string fmt ")"; - ctx + extract_expr true arg; + if inside then F.pp_print_string fmt ")" -let fstar_extract_binop (ctx : extraction_ctx) (fmt : F.formatter) - (extract_expr : - extraction_ctx -> F.formatter -> bool -> texpression -> extraction_ctx) - (inside : bool) (binop : E.binop) (int_ty : integer_type) - (arg0 : texpression) (arg1 : texpression) : extraction_ctx = +let fstar_extract_binop (extract_expr : bool -> texpression -> unit) + (fmt : F.formatter) (inside : bool) (binop : E.binop) + (int_ty : integer_type) (arg0 : texpression) (arg1 : texpression) : unit = if inside then F.pp_print_string fmt "("; - let ctx = - match binop with - | Eq -> - let ctx = extract_expr ctx fmt false arg0 in - F.pp_print_space fmt (); - F.pp_print_string fmt "="; - F.pp_print_space fmt (); - let ctx = extract_expr ctx fmt false arg1 in - ctx - | _ -> - let binop = - match binop with - | Eq -> failwith "Unreachable" - | Lt -> "lt" - | Le -> "le" - | Ne -> "ne" - | Ge -> "ge" - | Gt -> "gt" - | Div -> "div" - | Rem -> "rem" - | Add -> "add" - | Sub -> "sub" - | Mul -> "mul" - | BitXor | BitAnd | BitOr | Shl | Shr -> raise Unimplemented - in - F.pp_print_string fmt (fstar_int_name int_ty ^ "_" ^ binop); - F.pp_print_space fmt (); - let ctx = extract_expr ctx fmt false arg0 in - F.pp_print_space fmt (); - let ctx = extract_expr ctx fmt false arg1 in - ctx - in - if inside then F.pp_print_string fmt ")"; - ctx + (match binop with + | Eq -> + extract_expr false arg0; + F.pp_print_space fmt (); + F.pp_print_string fmt "="; + F.pp_print_space fmt (); + extract_expr false arg1 + | _ -> + let binop = + match binop with + | Eq -> failwith "Unreachable" + | Lt -> "lt" + | Le -> "le" + | Ne -> "ne" + | Ge -> "ge" + | Gt -> "gt" + | Div -> "div" + | Rem -> "rem" + | Add -> "add" + | Sub -> "sub" + | Mul -> "mul" + | BitXor | BitAnd | BitOr | Shl | Shr -> raise Unimplemented + in + F.pp_print_string fmt (fstar_int_name int_ty ^ "_" ^ binop); + F.pp_print_space fmt (); + extract_expr false arg0; + F.pp_print_space fmt (); + extract_expr false arg1); + if inside then F.pp_print_string fmt ")" (** * [ctx]: we use the context to lookup type definitions, to retrieve type names. @@ -694,16 +685,19 @@ let rec extract_typed_rvalue (ctx : extraction_ctx) (fmt : F.formatter) (** [inside]: see [extract_ty] *) let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (e : texpression) : extraction_ctx = + (inside : bool) (e : texpression) : unit = match e.e with - | Value (rv, _) -> extract_typed_rvalue ctx fmt inside rv + | Value (rv, _) -> + let _ = extract_typed_rvalue ctx fmt inside rv in + () | Call call -> ( match (call.func, call.args) with | Unop unop, [ arg ] -> - ctx.fmt.extract_unop ctx fmt extract_texpression inside unop arg + ctx.fmt.extract_unop (extract_texpression ctx fmt) fmt inside unop arg | Binop (binop, int_ty), [ arg0; arg1 ] -> - ctx.fmt.extract_binop ctx fmt extract_texpression inside binop int_ty - arg0 arg1 + ctx.fmt.extract_binop + (extract_texpression ctx fmt) + fmt inside binop int_ty arg0 arg1 | Regular (fun_id, rg_id), _ -> if inside then F.pp_print_string fmt "("; (* Open a box for the function call *) @@ -718,16 +712,11 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) extract_ty ctx fmt true ty) call.type_params; (* Print the input values *) - let ctx = - List.fold_left - (fun ctx ve -> extract_texpression ctx fmt true ve) - ctx call.args - in + List.iter (fun ve -> extract_texpression ctx fmt true ve) call.args; (* Close the box for the function call *) F.pp_close_box fmt (); (* Return *) - if inside then F.pp_print_string fmt ")"; - ctx + if inside then F.pp_print_string fmt ")" | _ -> failwith "Unreachable") | Let (monadic, lv, re, next_e) -> (* Open a box for the let-binding *) @@ -740,7 +729,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); F.pp_print_string fmt "<--"; F.pp_print_space fmt (); - let ctx = extract_texpression ctx fmt false re in + extract_texpression ctx fmt false re; F.pp_print_string fmt ";"; ctx) else ( @@ -750,7 +739,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); F.pp_print_string fmt "="; F.pp_print_space fmt (); - let ctx = extract_texpression ctx fmt false re in + extract_texpression ctx fmt false re; F.pp_print_space fmt (); F.pp_print_string fmt "in"; ctx) @@ -768,11 +757,11 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) F.pp_open_hovbox fmt ctx.indent_incr; F.pp_print_string fmt "if"; F.pp_print_space fmt (); - let ctx = extract_texpression ctx fmt false scrut in + extract_texpression ctx fmt false scrut; (* Close the box for the `if` *) F.pp_close_box fmt (); - let extract_branch (is_then : bool) (e_branch : texpression) : - extraction_ctx = + (* Extract the branches *) + let extract_branch (is_then : bool) (e_branch : texpression) : unit = F.pp_print_space fmt (); (* Open a box for the branch *) F.pp_open_hovbox fmt ctx.indent_incr; @@ -781,19 +770,62 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); let parenth = PureUtils.expression_requires_parentheses e_branch in if parenth then F.pp_print_string fmt "("; - let ctx = extract_texpression ctx fmt false e_branch in + extract_texpression ctx fmt false e_branch; if parenth then F.pp_print_string fmt ")"; (* Close the box for the branch *) - F.pp_close_box fmt (); - ctx + F.pp_close_box fmt () in - let ctx = extract_branch false e_then in - let ctx = extract_branch false e_else in + extract_branch false e_then; + extract_branch false e_else; (* Close the box for the whole `if ... then ... else ...` *) + F.pp_close_box fmt () + | SwitchInt (_, branches, otherwise) -> + (* Open a box for the whole match *) + F.pp_open_hvbox fmt 0; + (* Open a box for the `match ... with` *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print the `match ... with` *) + F.pp_print_string fmt "begin match"; + extract_texpression ctx fmt false scrut; + F.pp_print_space fmt (); + F.pp_print_string fmt "with"; + (* Close the box for the `match ... with` *) F.pp_close_box fmt (); - (* Return *) - ctx - | SwitchInt (_, branches, otherwise) -> raise Unimplemented + + (* Extract the branches *) + let extract_branch (ctx : extraction_ctx) (sv : scalar_value option) + (e_branch : texpression) : unit = + F.pp_print_space fmt (); + (* Open a box for the branch *) + F.pp_open_hovbox fmt ctx.indent_incr; + F.pp_print_string fmt "|"; + (* Print the pattern *) + F.pp_print_space fmt (); + (match sv with + | Some sv -> ctx.fmt.extract_constant_value fmt false (V.Scalar sv) + | None -> F.pp_print_string fmt "_"); + F.pp_print_space fmt (); + F.pp_print_string fmt "->"; + (* Print the branch itself *) + F.pp_print_space fmt (); + extract_texpression ctx fmt false e_branch; + (* Close the box for the branch *) + F.pp_close_box fmt () + in + + let all_branches = + List.map (fun (sv, br) -> (Some sv, br)) branches + in + let all_branches = List.append all_branches [ (None, otherwise) ] in + let ctx = + List.iter (fun (sv, br) -> extract_branch ctx sv br) all_branches + in + + (* End the match *) + F.pp_print_space fmt (); + F.pp_print_string fmt "end"; + (* Close the box for the whole match *) + F.pp_close_box fmt () | Match branches -> raise Unimplemented) | Meta (_, e) -> extract_texpression ctx fmt inside e diff --git a/src/Pure.ml b/src/Pure.ml index fd3eb03f..c1dbaa13 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -505,7 +505,8 @@ and switch_body = | If of texpression * texpression | SwitchInt of integer_type * (scalar_value * texpression) list * texpression | Match of match_branch list -(* TODO: we could (should?) merge SwitchInt and Match *) +(* TODO: merge SwitchInt and Match. In order to do that, + * we need to add constants to lvalues. *) and match_branch = { pat : typed_lvalue; branch : texpression } diff --git a/src/PureToExtract.ml b/src/PureToExtract.ml index 54b59141..47097d7d 100644 --- a/src/PureToExtract.ml +++ b/src/PureToExtract.ml @@ -31,7 +31,7 @@ module StringMap = Collections.MakeMap (Collections.OrderedString) type name = Identifiers.name -type 'ctx g_formatter = { +type formatter = { bool_name : string; char_name : string; int_name : integer_type -> string; @@ -106,13 +106,12 @@ type 'ctx g_formatter = { if it is made of an application (ex.: `U32 3`) *) extract_unop : - 'ctx -> + (bool -> texpression -> unit) -> F.formatter -> - ('ctx -> F.formatter -> bool -> texpression -> 'ctx) -> bool -> unop -> texpression -> - 'ctx; + unit; (** Format a unary operation Inputs: @@ -124,15 +123,14 @@ type 'ctx g_formatter = { - argument *) extract_binop : - 'ctx -> + (bool -> texpression -> unit) -> F.formatter -> - ('ctx -> F.formatter -> bool -> texpression -> 'ctx) -> bool -> E.binop -> integer_type -> texpression -> texpression -> - 'ctx; + unit; (** Format a binary operation Inputs: @@ -298,8 +296,6 @@ type extraction_ctx = { functions, etc. *) -and formatter = extraction_ctx g_formatter - let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx = (* TODO : nice debugging message if collision *) let names_map = names_map_add id name ctx.names_map in -- cgit v1.2.3