summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2022-11-14 11:57:53 +0100
committerSon HO2022-11-14 14:21:04 +0100
commit868fa924a37a3af6e701bbc0a2d51fefc2dc7c33 (patch)
treee770fe4d89baf7b1017d2c88d9f866eb54a56ce3 /compiler
parent019a9e34e6375a5e015e4978aad89aa8febc237c (diff)
Make [Result::Failure] type an [Error] parameter
Diffstat (limited to '')
-rw-r--r--compiler/ConstStrings.ml3
-rw-r--r--compiler/Extract.ml13
-rw-r--r--compiler/ExtractBase.ml9
-rw-r--r--compiler/FunsAnalysis.ml2
-rw-r--r--compiler/InterpreterExpressions.ml8
-rw-r--r--compiler/InterpreterStatements.ml4
-rw-r--r--compiler/PrintPure.ml21
-rw-r--r--compiler/Pure.ml7
-rw-r--r--compiler/PureMicroPasses.ml21
-rw-r--r--compiler/PureTypeCheck.ml8
-rw-r--r--compiler/PureUtils.ml32
-rw-r--r--compiler/SymbolicToPure.ml8
12 files changed, 107 insertions, 29 deletions
diff --git a/compiler/ConstStrings.ml b/compiler/ConstStrings.ml
index ae169a2e..6cf57fe4 100644
--- a/compiler/ConstStrings.ml
+++ b/compiler/ConstStrings.ml
@@ -5,3 +5,6 @@ let state_basename = "st"
(** ADT constructor prefix (used when pretty-printing) *)
let constructor_prefix = "Mk"
+
+(** Basename for error variables *)
+let error_basename = "e"
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 13c02bca..17b6aa54 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -126,7 +126,13 @@ let keywords () =
List.concat [ named_unops; named_binops; misc ]
let assumed_adts : (assumed_ty * string) list =
- [ (State, "state"); (Result, "result"); (Option, "option"); (Vec, "vec") ]
+ [
+ (State, "state");
+ (Result, "result");
+ (Error, "error");
+ (Option, "option");
+ (Vec, "vec");
+ ]
let assumed_structs : (assumed_ty * string) list = []
@@ -136,6 +142,8 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list =
[
(Result, result_return_id, "Return");
(Result, result_fail_id, "Fail");
+ (Error, error_failure_id, "Failure");
+ (Error, error_out_of_fuel_id, "OutOfFuel");
(Option, option_some_id, "Some");
(Option, option_none_id, "None");
]
@@ -143,6 +151,8 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list =
[
(Result, result_return_id, "Return");
(Result, result_fail_id, "Fail_");
+ (Error, error_failure_id, "Failure");
+ (Error, error_out_of_fuel_id, "OutOfFuel");
(Option, option_some_id, "Some");
(Option, option_none_id, "None");
]
@@ -429,6 +439,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
(* The "pair" case is frequent enough to have its special treatment *)
if List.length tys = 2 then "p" else "t"
| Assumed Result -> "r"
+ | Assumed Error -> "e"
| Assumed Option -> "opt"
| Assumed Vec -> "v"
| Assumed State -> "st"
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 33939e6a..9690d9fc 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -413,7 +413,7 @@ type extraction_ctx = {
(** The indent increment we insert whenever we need to indent more *)
}
-(** Debugging function *)
+(** Debugging function, used when communicating name collisions to the user *)
let id_to_string (id : id) (ctx : extraction_ctx) : string =
let global_decls = ctx.trans_ctx.global_context.global_decls in
let fun_decls = ctx.trans_ctx.fun_context.fun_decls in
@@ -467,6 +467,10 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
if variant_id = result_return_id then "@result::Return"
else if variant_id = result_fail_id then "@result::Fail"
else raise (Failure "Unreachable")
+ | Assumed Error ->
+ if variant_id = error_failure_id then "@error::Failure"
+ else if variant_id = error_out_of_fuel_id then "@error::OutOfFuel"
+ else raise (Failure "Unreachable")
| Assumed Option ->
if variant_id = option_some_id then "@option::Some"
else if variant_id = option_none_id then "@option::None"
@@ -485,7 +489,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
let field_name =
match id with
| Tuple -> raise (Failure "Unreachable")
- | Assumed (State | Result | Option) -> raise (Failure "Unreachable")
+ | Assumed (State | Result | Error | Option) ->
+ raise (Failure "Unreachable")
| Assumed Vec ->
(* We can't directly have access to the fields of a vector *)
raise (Failure "Unreachable")
diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml
index 4d33056b..75a6c0ce 100644
--- a/compiler/FunsAnalysis.ml
+++ b/compiler/FunsAnalysis.ml
@@ -103,7 +103,7 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
(* We need to know if the declaration group contains a global - note that
* groups containing globals contain exactly one declaration *)
let is_global_decl_body = List.exists (fun f -> f.is_global_decl_body) d in
- assert ((not is_global_decl_body) || List.length d == 1);
+ assert ((not is_global_decl_body) || List.length d = 1);
(* We ignore on purpose functions that cannot fail and consider they *can*
* fail: the result of the analysis is not used yet to adjust the translation
* so that the functions which syntactically can't fail don't use an error monad.
diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml
index 5bc440e7..5d1a3cfe 100644
--- a/compiler/InterpreterExpressions.ml
+++ b/compiler/InterpreterExpressions.ml
@@ -357,7 +357,7 @@ let eval_unary_op_concrete (config : C.config) (unop : E.unop) (op : E.operand)
| Error _ -> cf (Error EPanic)
| Ok sv -> cf (Ok { v with V.value = V.Primitive (PV.Scalar sv) }))
| E.Cast (src_ty, tgt_ty), V.Primitive (PV.Scalar sv) -> (
- assert (src_ty == sv.int_ty);
+ assert (src_ty = sv.int_ty);
let i = sv.PV.value in
match mk_scalar tgt_ty i with
| Error _ -> cf (Error EPanic)
@@ -637,9 +637,9 @@ let eval_rvalue_aggregate (config : C.config)
cf aggregated ctx
| E.AggregatedOption (variant_id, ty) ->
(* Sanity check *)
- if variant_id == T.option_none_id then assert (values == [])
- else if variant_id == T.option_some_id then
- assert (List.length values == 1)
+ if variant_id = T.option_none_id then assert (values = [])
+ else if variant_id = T.option_some_id then
+ assert (List.length values = 1)
else raise (Failure "Unreachable");
(* Construt the value *)
let aty = T.Adt (T.Assumed T.Option, [], [ ty ]) in
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index 14dd59b1..3bf7b723 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -469,8 +469,8 @@ let eval_box_deref_mut_or_shared_concrete (config : C.config)
:: Var (_ret_var, _)
:: C.Frame :: _ ) ->
(* Required type checking. We must have:
- - input_value.ty == & (mut) Box<ty>
- - boxed_ty == ty
+ - input_value.ty = & (mut) Box<ty>
+ - boxed_ty = ty
for some ty
*)
(let _, input_ty, ref_kind = ty_get_ref input_value.V.ty in
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index b4ab26b8..0879f553 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -128,6 +128,7 @@ let type_id_to_string (fmt : type_formatter) (id : type_id) : string =
match aty with
| State -> "State"
| Result -> "Result"
+ | Error -> "Error"
| Option -> "Option"
| Vec -> "Vec")
@@ -247,6 +248,11 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id)
else if variant_id = result_fail_id then "@Result::Fail"
else
raise (Failure "Unreachable: improper variant id for result type")
+ | Error ->
+ let variant_id = Option.get variant_id in
+ if variant_id = error_failure_id then "@Error::Failure"
+ else if variant_id = error_out_of_fuel_id then "@Error::OutOfFuel"
+ else raise (Failure "Unreachable: improper variant id for error type")
| Option ->
let variant_id = Option.get variant_id in
if variant_id = option_some_id then "@Option::Some "
@@ -275,7 +281,7 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id)
| State | Vec ->
(* Opaque types: we can't get there *)
raise (Failure "Unreachable")
- | Result | Option ->
+ | Result | Error | Option ->
(* Enumerations: we can't get there *)
raise (Failure "Unreachable"))
@@ -324,11 +330,18 @@ let adt_g_value_to_string (fmt : value_formatter)
match field_values with
| [ v ] -> "@Result::Return " ^ v
| _ -> raise (Failure "Result::Return takes exactly one value")
- else if variant_id = result_fail_id then (
- assert (field_values = []);
- "@Result::Fail")
+ else if variant_id = result_fail_id then
+ match field_values with
+ | [ v ] -> "@Result::Fail " ^ v
+ | _ -> raise (Failure "Result::Fail takes exactly one value")
else
raise (Failure "Unreachable: improper variant id for result type")
+ | Error ->
+ assert (field_values = []);
+ let variant_id = Option.get variant_id in
+ if variant_id = error_failure_id then "@Error::Failure"
+ else if variant_id = error_out_of_fuel_id then "@Error::OutOfFuel"
+ else raise (Failure "Unreachable: improper variant id for error type")
| Option ->
let variant_id = Option.get variant_id in
if variant_id = option_some_id then
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index b0114baa..6cc73bef 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -26,16 +26,17 @@ type integer_type = T.integer_type [@@deriving show, ord]
(** The assumed types for the pure AST.
In comparison with LLBC:
- - we removed [Box] (because it is translated as the identity: [Box T == T])
+ - we removed [Box] (because it is translated as the identity: [Box T = T])
- we added:
- [Result]: the type used in the error monad. This allows us to have a
unified treatment of expressions (especially when we have to unfold the
monadic binds)
+ - [Error]: the kind of error, in case of failure (used by [Result])
- [State]: the type of the state, when using state-error monads. Note that
this state is opaque to Aeneas (the user can define it, or leave it as
assumed)
*)
-type assumed_ty = State | Result | Vec | Option [@@deriving show, ord]
+type assumed_ty = State | Result | Error | Vec | Option [@@deriving show, ord]
(* TODO: we should never directly manipulate [Return] and [Fail], but rather
* the monadic functions [return] and [fail] (makes treatment of error and
@@ -44,6 +45,8 @@ let result_return_id = VariantId.of_int 0
let result_fail_id = VariantId.of_int 1
let option_some_id = T.option_some_id
let option_none_id = T.option_none_id
+let error_failure_id = VariantId.of_int 0
+let error_out_of_fuel_id = VariantId.of_int 1
type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty
[@@deriving show, ord]
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 1cb35613..c5eb3c64 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -123,7 +123,7 @@ type pn_ctx = {
{[
let py = id(&mut x);
*py = 2;
- assert!(x == 2);
+ assert!(x = 2);
]}
After desugaring, we get the following MIR:
@@ -131,7 +131,7 @@ type pn_ctx = {
^0 = &mut x; // anonymous variable
py = id(move ^0);
*py += 2;
- assert!(x == 2);
+ assert!(x = 2);
]}
We want this to be translated as:
@@ -1228,6 +1228,9 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
match def.body with
| None -> def
| Some body ->
+ let cnt = get_body_min_var_counter body in
+ let _, fresh_id = VarId.mk_stateful_generator cnt in
+
(* It is a very simple map *)
let obj =
object (_self)
@@ -1257,8 +1260,18 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
* store in an enum ("monadic" should be an enum, not a bool). *)
let re_ty = Option.get (opt_destruct_result re.ty) in
assert (lv.ty = re_ty);
- let fail_pat = mk_result_fail_pattern lv.ty in
- let fail_value = mk_result_fail_texpression e.ty in
+ let err_vid = fresh_id () in
+ let err_var : var =
+ {
+ id = err_vid;
+ basename = Some ConstStrings.error_basename;
+ ty = mk_error_ty;
+ }
+ in
+ let err_pat = mk_typed_pattern_from_var err_var None in
+ let fail_pat = mk_result_fail_pattern err_pat.value lv.ty in
+ let err_v = mk_texpression_from_var err_var in
+ let fail_value = mk_result_fail_texpression err_v e.ty in
let fail_branch = { pat = fail_pat; branch = fail_value } in
let success_pat = mk_result_return_pattern lv in
let success_branch = { pat = success_pat; branch = e } in
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index 6b6a82ad..a1e4e834 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -26,9 +26,15 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
let ty = Collections.List.to_cons_nil tys in
let variant_id = Option.get variant_id in
if variant_id = result_return_id then [ ty ]
- else if variant_id = result_fail_id then []
+ else if variant_id = result_fail_id then [ mk_error_ty ]
else
raise (Failure "Unreachable: improper variant id for result type")
+ | Error ->
+ assert (tys = []);
+ let variant_id = Option.get variant_id in
+ assert (
+ variant_id = error_failure_id || variant_id = error_out_of_fuel_id);
+ []
| Option ->
let ty = Collections.List.to_cons_nil tys in
let variant_id = Option.get variant_id in
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 728a4fe6..f5c280fb 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -421,13 +421,21 @@ let type_decl_is_enum (def : T.type_decl) : bool =
let mk_state_ty : ty = Adt (Assumed State, [])
let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ])
+let mk_error_ty : ty = Adt (Assumed Error, [])
+
+let mk_error (error : VariantId.id) : texpression =
+ let ty = mk_error_ty in
+ let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in
+ let qualif = { id; type_args = [] } in
+ let e = Qualif qualif in
+ { e; ty }
let unwrap_result_ty (ty : ty) : ty =
match ty with
| Adt (Assumed Result, [ ty ]) -> ty
| _ -> raise (Failure "not a result type")
-let mk_result_fail_texpression (ty : ty) : texpression =
+let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression =
let type_args = [ ty ] in
let ty = Adt (Assumed Result, type_args) in
let id =
@@ -435,9 +443,14 @@ let mk_result_fail_texpression (ty : ty) : texpression =
in
let qualif = { id; type_args } in
let cons_e = Qualif qualif in
- let cons_ty = ty in
+ let cons_ty = mk_arrow error.ty ty in
let cons = { e = cons_e; ty = cons_ty } in
- cons
+ mk_app cons error
+
+let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) :
+ texpression =
+ let error = mk_error error in
+ mk_result_fail_texpression error ty
let mk_result_return_texpression (v : texpression) : texpression =
let type_args = [ v.ty ] in
@@ -451,11 +464,20 @@ let mk_result_return_texpression (v : texpression) : texpression =
let cons = { e = cons_e; ty = cons_ty } in
mk_app cons v
-let mk_result_fail_pattern (ty : ty) : typed_pattern =
+(** Create a [Fail err] pattern which captures the error *)
+let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern =
+ let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in
let ty = Adt (Assumed Result, [ ty ]) in
- let value = PatAdt { variant_id = Some result_fail_id; field_values = [] } in
+ let value =
+ PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] }
+ in
{ value; ty }
+(** Create a [Fail _] pattern (we ignore the error) *)
+let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern =
+ let error_pat : pattern = PatDummy in
+ mk_result_fail_pattern error_pat ty
+
let mk_result_return_pattern (v : typed_pattern) : typed_pattern =
let ty = Adt (Assumed Result, [ v.ty ]) in
let value =
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 62be5efd..8fa66f93 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -1134,9 +1134,11 @@ and translate_panic (ctx : bs_ctx) : texpression =
if ctx.sg.info.effect_info.stateful then
(* Create the [Fail] value *)
let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
- let ret_v = mk_result_fail_texpression ret_ty in
+ let ret_v =
+ mk_result_fail_texpression_with_error_id error_failure_id ret_ty
+ in
ret_v
- else mk_result_fail_texpression output_ty
+ else mk_result_fail_texpression_with_error_id error_failure_id output_ty
(** [opt_v]: the value to return, in case we translate a forward function *)
and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression
@@ -1661,7 +1663,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
| _ -> raise (Failure "Unreachable")
in
(* We simply introduce an assignment - the box type is the
- * identity when extracted ([box a == a]) *)
+ * identity when extracted ([box a = a]) *)
let monadic = false in
mk_let monadic
(mk_typed_pattern_from_var var None)