diff options
Diffstat (limited to '')
-rw-r--r-- | compiler/ConstStrings.ml | 3 | ||||
-rw-r--r-- | compiler/Extract.ml | 13 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 9 | ||||
-rw-r--r-- | compiler/FunsAnalysis.ml | 2 | ||||
-rw-r--r-- | compiler/InterpreterExpressions.ml | 8 | ||||
-rw-r--r-- | compiler/InterpreterStatements.ml | 4 | ||||
-rw-r--r-- | compiler/PrintPure.ml | 21 | ||||
-rw-r--r-- | compiler/Pure.ml | 7 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 21 | ||||
-rw-r--r-- | compiler/PureTypeCheck.ml | 8 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 32 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 8 |
12 files changed, 107 insertions, 29 deletions
diff --git a/compiler/ConstStrings.ml b/compiler/ConstStrings.ml index ae169a2e..6cf57fe4 100644 --- a/compiler/ConstStrings.ml +++ b/compiler/ConstStrings.ml @@ -5,3 +5,6 @@ let state_basename = "st" (** ADT constructor prefix (used when pretty-printing) *) let constructor_prefix = "Mk" + +(** Basename for error variables *) +let error_basename = "e" diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 13c02bca..17b6aa54 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -126,7 +126,13 @@ let keywords () = List.concat [ named_unops; named_binops; misc ] let assumed_adts : (assumed_ty * string) list = - [ (State, "state"); (Result, "result"); (Option, "option"); (Vec, "vec") ] + [ + (State, "state"); + (Result, "result"); + (Error, "error"); + (Option, "option"); + (Vec, "vec"); + ] let assumed_structs : (assumed_ty * string) list = [] @@ -136,6 +142,8 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = [ (Result, result_return_id, "Return"); (Result, result_fail_id, "Fail"); + (Error, error_failure_id, "Failure"); + (Error, error_out_of_fuel_id, "OutOfFuel"); (Option, option_some_id, "Some"); (Option, option_none_id, "None"); ] @@ -143,6 +151,8 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = [ (Result, result_return_id, "Return"); (Result, result_fail_id, "Fail_"); + (Error, error_failure_id, "Failure"); + (Error, error_out_of_fuel_id, "OutOfFuel"); (Option, option_some_id, "Some"); (Option, option_none_id, "None"); ] @@ -429,6 +439,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) (* The "pair" case is frequent enough to have its special treatment *) if List.length tys = 2 then "p" else "t" | Assumed Result -> "r" + | Assumed Error -> "e" | Assumed Option -> "opt" | Assumed Vec -> "v" | Assumed State -> "st" diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 33939e6a..9690d9fc 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -413,7 +413,7 @@ type extraction_ctx = { (** The indent increment we insert whenever we need to indent more *) } -(** Debugging function *) +(** Debugging function, used when communicating name collisions to the user *) let id_to_string (id : id) (ctx : extraction_ctx) : string = let global_decls = ctx.trans_ctx.global_context.global_decls in let fun_decls = ctx.trans_ctx.fun_context.fun_decls in @@ -467,6 +467,10 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = if variant_id = result_return_id then "@result::Return" else if variant_id = result_fail_id then "@result::Fail" else raise (Failure "Unreachable") + | Assumed Error -> + if variant_id = error_failure_id then "@error::Failure" + else if variant_id = error_out_of_fuel_id then "@error::OutOfFuel" + else raise (Failure "Unreachable") | Assumed Option -> if variant_id = option_some_id then "@option::Some" else if variant_id = option_none_id then "@option::None" @@ -485,7 +489,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = let field_name = match id with | Tuple -> raise (Failure "Unreachable") - | Assumed (State | Result | Option) -> raise (Failure "Unreachable") + | Assumed (State | Result | Error | Option) -> + raise (Failure "Unreachable") | Assumed Vec -> (* We can't directly have access to the fields of a vector *) raise (Failure "Unreachable") diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml index 4d33056b..75a6c0ce 100644 --- a/compiler/FunsAnalysis.ml +++ b/compiler/FunsAnalysis.ml @@ -103,7 +103,7 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) (* We need to know if the declaration group contains a global - note that * groups containing globals contain exactly one declaration *) let is_global_decl_body = List.exists (fun f -> f.is_global_decl_body) d in - assert ((not is_global_decl_body) || List.length d == 1); + assert ((not is_global_decl_body) || List.length d = 1); (* We ignore on purpose functions that cannot fail and consider they *can* * fail: the result of the analysis is not used yet to adjust the translation * so that the functions which syntactically can't fail don't use an error monad. diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index 5bc440e7..5d1a3cfe 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -357,7 +357,7 @@ let eval_unary_op_concrete (config : C.config) (unop : E.unop) (op : E.operand) | Error _ -> cf (Error EPanic) | Ok sv -> cf (Ok { v with V.value = V.Primitive (PV.Scalar sv) })) | E.Cast (src_ty, tgt_ty), V.Primitive (PV.Scalar sv) -> ( - assert (src_ty == sv.int_ty); + assert (src_ty = sv.int_ty); let i = sv.PV.value in match mk_scalar tgt_ty i with | Error _ -> cf (Error EPanic) @@ -637,9 +637,9 @@ let eval_rvalue_aggregate (config : C.config) cf aggregated ctx | E.AggregatedOption (variant_id, ty) -> (* Sanity check *) - if variant_id == T.option_none_id then assert (values == []) - else if variant_id == T.option_some_id then - assert (List.length values == 1) + if variant_id = T.option_none_id then assert (values = []) + else if variant_id = T.option_some_id then + assert (List.length values = 1) else raise (Failure "Unreachable"); (* Construt the value *) let aty = T.Adt (T.Assumed T.Option, [], [ ty ]) in diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index 14dd59b1..3bf7b723 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -469,8 +469,8 @@ let eval_box_deref_mut_or_shared_concrete (config : C.config) :: Var (_ret_var, _) :: C.Frame :: _ ) -> (* Required type checking. We must have: - - input_value.ty == & (mut) Box<ty> - - boxed_ty == ty + - input_value.ty = & (mut) Box<ty> + - boxed_ty = ty for some ty *) (let _, input_ty, ref_kind = ty_get_ref input_value.V.ty in diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index b4ab26b8..0879f553 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -128,6 +128,7 @@ let type_id_to_string (fmt : type_formatter) (id : type_id) : string = match aty with | State -> "State" | Result -> "Result" + | Error -> "Error" | Option -> "Option" | Vec -> "Vec") @@ -247,6 +248,11 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id) else if variant_id = result_fail_id then "@Result::Fail" else raise (Failure "Unreachable: improper variant id for result type") + | Error -> + let variant_id = Option.get variant_id in + if variant_id = error_failure_id then "@Error::Failure" + else if variant_id = error_out_of_fuel_id then "@Error::OutOfFuel" + else raise (Failure "Unreachable: improper variant id for error type") | Option -> let variant_id = Option.get variant_id in if variant_id = option_some_id then "@Option::Some " @@ -275,7 +281,7 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) | State | Vec -> (* Opaque types: we can't get there *) raise (Failure "Unreachable") - | Result | Option -> + | Result | Error | Option -> (* Enumerations: we can't get there *) raise (Failure "Unreachable")) @@ -324,11 +330,18 @@ let adt_g_value_to_string (fmt : value_formatter) match field_values with | [ v ] -> "@Result::Return " ^ v | _ -> raise (Failure "Result::Return takes exactly one value") - else if variant_id = result_fail_id then ( - assert (field_values = []); - "@Result::Fail") + else if variant_id = result_fail_id then + match field_values with + | [ v ] -> "@Result::Fail " ^ v + | _ -> raise (Failure "Result::Fail takes exactly one value") else raise (Failure "Unreachable: improper variant id for result type") + | Error -> + assert (field_values = []); + let variant_id = Option.get variant_id in + if variant_id = error_failure_id then "@Error::Failure" + else if variant_id = error_out_of_fuel_id then "@Error::OutOfFuel" + else raise (Failure "Unreachable: improper variant id for error type") | Option -> let variant_id = Option.get variant_id in if variant_id = option_some_id then diff --git a/compiler/Pure.ml b/compiler/Pure.ml index b0114baa..6cc73bef 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -26,16 +26,17 @@ type integer_type = T.integer_type [@@deriving show, ord] (** The assumed types for the pure AST. In comparison with LLBC: - - we removed [Box] (because it is translated as the identity: [Box T == T]) + - we removed [Box] (because it is translated as the identity: [Box T = T]) - we added: - [Result]: the type used in the error monad. This allows us to have a unified treatment of expressions (especially when we have to unfold the monadic binds) + - [Error]: the kind of error, in case of failure (used by [Result]) - [State]: the type of the state, when using state-error monads. Note that this state is opaque to Aeneas (the user can define it, or leave it as assumed) *) -type assumed_ty = State | Result | Vec | Option [@@deriving show, ord] +type assumed_ty = State | Result | Error | Vec | Option [@@deriving show, ord] (* TODO: we should never directly manipulate [Return] and [Fail], but rather * the monadic functions [return] and [fail] (makes treatment of error and @@ -44,6 +45,8 @@ let result_return_id = VariantId.of_int 0 let result_fail_id = VariantId.of_int 1 let option_some_id = T.option_some_id let option_none_id = T.option_none_id +let error_failure_id = VariantId.of_int 0 +let error_out_of_fuel_id = VariantId.of_int 1 type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty [@@deriving show, ord] diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 1cb35613..c5eb3c64 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -123,7 +123,7 @@ type pn_ctx = { {[ let py = id(&mut x); *py = 2; - assert!(x == 2); + assert!(x = 2); ]} After desugaring, we get the following MIR: @@ -131,7 +131,7 @@ type pn_ctx = { ^0 = &mut x; // anonymous variable py = id(move ^0); *py += 2; - assert!(x == 2); + assert!(x = 2); ]} We want this to be translated as: @@ -1228,6 +1228,9 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = match def.body with | None -> def | Some body -> + let cnt = get_body_min_var_counter body in + let _, fresh_id = VarId.mk_stateful_generator cnt in + (* It is a very simple map *) let obj = object (_self) @@ -1257,8 +1260,18 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = * store in an enum ("monadic" should be an enum, not a bool). *) let re_ty = Option.get (opt_destruct_result re.ty) in assert (lv.ty = re_ty); - let fail_pat = mk_result_fail_pattern lv.ty in - let fail_value = mk_result_fail_texpression e.ty in + let err_vid = fresh_id () in + let err_var : var = + { + id = err_vid; + basename = Some ConstStrings.error_basename; + ty = mk_error_ty; + } + in + let err_pat = mk_typed_pattern_from_var err_var None in + let fail_pat = mk_result_fail_pattern err_pat.value lv.ty in + let err_v = mk_texpression_from_var err_var in + let fail_value = mk_result_fail_texpression err_v e.ty in let fail_branch = { pat = fail_pat; branch = fail_value } in let success_pat = mk_result_return_pattern lv in let success_branch = { pat = success_pat; branch = e } in diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 6b6a82ad..a1e4e834 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -26,9 +26,15 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) let ty = Collections.List.to_cons_nil tys in let variant_id = Option.get variant_id in if variant_id = result_return_id then [ ty ] - else if variant_id = result_fail_id then [] + else if variant_id = result_fail_id then [ mk_error_ty ] else raise (Failure "Unreachable: improper variant id for result type") + | Error -> + assert (tys = []); + let variant_id = Option.get variant_id in + assert ( + variant_id = error_failure_id || variant_id = error_out_of_fuel_id); + [] | Option -> let ty = Collections.List.to_cons_nil tys in let variant_id = Option.get variant_id in diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 728a4fe6..f5c280fb 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -421,13 +421,21 @@ let type_decl_is_enum (def : T.type_decl) : bool = let mk_state_ty : ty = Adt (Assumed State, []) let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ]) +let mk_error_ty : ty = Adt (Assumed Error, []) + +let mk_error (error : VariantId.id) : texpression = + let ty = mk_error_ty in + let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in + let qualif = { id; type_args = [] } in + let e = Qualif qualif in + { e; ty } let unwrap_result_ty (ty : ty) : ty = match ty with | Adt (Assumed Result, [ ty ]) -> ty | _ -> raise (Failure "not a result type") -let mk_result_fail_texpression (ty : ty) : texpression = +let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression = let type_args = [ ty ] in let ty = Adt (Assumed Result, type_args) in let id = @@ -435,9 +443,14 @@ let mk_result_fail_texpression (ty : ty) : texpression = in let qualif = { id; type_args } in let cons_e = Qualif qualif in - let cons_ty = ty in + let cons_ty = mk_arrow error.ty ty in let cons = { e = cons_e; ty = cons_ty } in - cons + mk_app cons error + +let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) : + texpression = + let error = mk_error error in + mk_result_fail_texpression error ty let mk_result_return_texpression (v : texpression) : texpression = let type_args = [ v.ty ] in @@ -451,11 +464,20 @@ let mk_result_return_texpression (v : texpression) : texpression = let cons = { e = cons_e; ty = cons_ty } in mk_app cons v -let mk_result_fail_pattern (ty : ty) : typed_pattern = +(** Create a [Fail err] pattern which captures the error *) +let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern = + let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in let ty = Adt (Assumed Result, [ ty ]) in - let value = PatAdt { variant_id = Some result_fail_id; field_values = [] } in + let value = + PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] } + in { value; ty } +(** Create a [Fail _] pattern (we ignore the error) *) +let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern = + let error_pat : pattern = PatDummy in + mk_result_fail_pattern error_pat ty + let mk_result_return_pattern (v : typed_pattern) : typed_pattern = let ty = Adt (Assumed Result, [ v.ty ]) in let value = diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 62be5efd..8fa66f93 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1134,9 +1134,11 @@ and translate_panic (ctx : bs_ctx) : texpression = if ctx.sg.info.effect_info.stateful then (* Create the [Fail] value *) let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in - let ret_v = mk_result_fail_texpression ret_ty in + let ret_v = + mk_result_fail_texpression_with_error_id error_failure_id ret_ty + in ret_v - else mk_result_fail_texpression output_ty + else mk_result_fail_texpression_with_error_id error_failure_id output_ty (** [opt_v]: the value to return, in case we translate a forward function *) and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression @@ -1661,7 +1663,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) | _ -> raise (Failure "Unreachable") in (* We simply introduce an assignment - the box type is the - * identity when extracted ([box a == a]) *) + * identity when extracted ([box a = a]) *) let monadic = false in mk_let monadic (mk_typed_pattern_from_var var None) |