diff options
Diffstat (limited to '')
-rw-r--r-- | backends/coq/Primitives.v | 42 | ||||
-rw-r--r-- | backends/fstar/Makefile | 47 | ||||
-rw-r--r-- | backends/fstar/Primitives.fst | 32 | ||||
-rw-r--r-- | compiler/ConstStrings.ml | 3 | ||||
-rw-r--r-- | compiler/Extract.ml | 13 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 9 | ||||
-rw-r--r-- | compiler/FunsAnalysis.ml | 2 | ||||
-rw-r--r-- | compiler/InterpreterExpressions.ml | 8 | ||||
-rw-r--r-- | compiler/InterpreterStatements.ml | 4 | ||||
-rw-r--r-- | compiler/PrintPure.ml | 21 | ||||
-rw-r--r-- | compiler/Pure.ml | 7 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 21 | ||||
-rw-r--r-- | compiler/PureTypeCheck.ml | 8 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 32 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 8 |
15 files changed, 195 insertions, 62 deletions
diff --git a/backends/coq/Primitives.v b/backends/coq/Primitives.v index c27b8aed..9a97d6c7 100644 --- a/backends/coq/Primitives.v +++ b/backends/coq/Primitives.v @@ -13,40 +13,44 @@ Module Primitives. Declare Scope Primitives_scope. (*** Result *) - + +Inductive error := + | Failure + | OutOfFuel. + Inductive result A := | Return : A -> result A - | Fail_ : result A. + | Fail_ : error -> result A. Arguments Return {_} a. Arguments Fail_ {_}. Definition bind {A B} (m: result A) (f: A -> result B) : result B := match m with - | Fail_ => Fail_ + | Fail_ e => Fail_ e | Return x => f x end. -Definition return_ {A: Type} (x: A) : result A := Return x . -Definition fail_ {A: Type} : result A := Fail_ . +Definition return_ {A: Type} (x: A) : result A := Return x. +Definition fail_ {A: Type} (e: error) : result A := Fail_ e. Notation "x <- c1 ; c2" := (bind c1 (fun x => c2)) (at level 61, c1 at next level, right associativity). (** Monadic assert *) Definition massert (b: bool) : result unit := - if b then Return tt else Fail_. + if b then Return tt else Fail_ Failure. (** Normalize and unwrap a successful result (used for globals) *) Definition eval_result_refl {A} {x} (a: result A) (p: a = Return x) : A := match a as r return (r = Return x -> A) with | Return a' => fun _ => a' - | Fail_ => fun p' => - False_rect _ (eq_ind Fail_ + | Fail_ e => fun p' => + False_rect _ (eq_ind (Fail_ e) (fun e : result A => match e with | Return _ => False - | Fail_ => True + | Fail_ e => True end) I (Return x) p') end p. @@ -55,7 +59,7 @@ Notation "x %global" := (eval_result_refl x eq_refl) (at level 40). Notation "x %return" := (eval_result_refl x eq_refl) (at level 40). (* Sanity check *) -Check (if true then Return (1 + 2) else Fail_)%global = 3. +Check (if true then Return (1 + 2) else Fail_ Failure)%global = 3. (*** Misc *) @@ -232,7 +236,7 @@ Import Sumbool. Definition mk_scalar (ty: scalar_ty) (x: Z) : result (scalar ty) := match sumbool_of_bool (scalar_in_bounds ty x) with | left H => Return (exist _ x (scalar_in_bounds_valid _ _ H)) - | right _ => Fail_ + | right _ => Fail_ Failure end. Definition scalar_add {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (to_Z x + to_Z y). @@ -242,7 +246,7 @@ Definition scalar_sub {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty Definition scalar_mul {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (to_Z x * to_Z y). Definition scalar_div {ty} (x y: scalar ty) : result (scalar ty) := - if to_Z y =? 0 then Fail_ else + if to_Z y =? 0 then Fail_ Failure else mk_scalar ty (to_Z x / to_Z y). Definition scalar_rem {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (Z.rem (to_Z x) (to_Z y)). @@ -433,7 +437,7 @@ Definition vec_bind {A B} (v: vec A) (f: list A -> result (list B)) : result (ve l <- f (vec_to_list v) ; match sumbool_of_bool (scalar_le_max Usize (Z.of_nat (length l))) with | left H => Return (exist _ l (scalar_le_max_valid _ _ H)) - | right _ => Fail_ + | right _ => Fail_ Failure end. (* The **forward** function shouldn't be used *) @@ -444,35 +448,35 @@ Definition vec_push_back (T: Type) (v: vec T) (x: T) : result (vec T) := (* The **forward** function shouldn't be used *) Definition vec_insert_fwd (T: Type) (v: vec T) (i: usize) (x: T) : result unit := - if to_Z i <? vec_length v then Return tt else Fail_. + if to_Z i <? vec_length v then Return tt else Fail_ Failure. Definition vec_insert_back (T: Type) (v: vec T) (i: usize) (x: T) : result (vec T) := vec_bind v (fun l => if to_Z i <? Z.of_nat (length l) then Return (list_update l (usize_to_nat i) x) - else Fail_). + else Fail_ Failure). (* The **backward** function shouldn't be used *) Definition vec_index_fwd (T: Type) (v: vec T) (i: usize) : result T := match nth_error (vec_to_list v) (usize_to_nat i) with | Some n => Return n - | None => Fail_ + | None => Fail_ Failure end. Definition vec_index_back (T: Type) (v: vec T) (i: usize) (x: T) : result unit := - if to_Z i <? vec_length v then Return tt else Fail_. + if to_Z i <? vec_length v then Return tt else Fail_ Failure. (* The **backward** function shouldn't be used *) Definition vec_index_mut_fwd (T: Type) (v: vec T) (i: usize) : result T := match nth_error (vec_to_list v) (usize_to_nat i) with | Some n => Return n - | None => Fail_ + | None => Fail_ Failure end. Definition vec_index_mut_back (T: Type) (v: vec T) (i: usize) (x: T) : result (vec T) := vec_bind v (fun l => if to_Z i <? Z.of_nat (length l) then Return (list_update l (usize_to_nat i) x) - else Fail_). + else Fail_ Failure). End Primitives. diff --git a/backends/fstar/Makefile b/backends/fstar/Makefile new file mode 100644 index 00000000..a16b0edb --- /dev/null +++ b/backends/fstar/Makefile @@ -0,0 +1,47 @@ +INCLUDE_DIRS = . + +FSTAR_INCLUDES = $(addprefix --include ,$(INCLUDE_DIRS)) + +FSTAR_HINTS ?= --use_hints --use_hint_hashes --record_hints + +FSTAR_OPTIONS = $(FSTAR_HINTS) \ + --cache_checked_modules $(FSTAR_INCLUDES) --cmi \ + --warn_error '+241@247+285-274' \ + +FSTAR_NO_FLAGS = fstar.exe --already_cached 'Prims FStar LowStar Steel' --odir obj --cache_dir obj + +FSTAR = $(FSTAR_NO_FLAGS) $(FSTAR_OPTIONS) + +# The F* roots are used to compute the dependency graph, and generate the .depend file +FSTAR_ROOTS ?= $(wildcard *.fst *.fsti) + +# Build all the files +all: $(addprefix obj/,$(addsuffix .checked,$(FSTAR_ROOTS))) + +# This is the right way to ensure the .depend file always gets re-built. +ifeq (,$(filter %-in,$(MAKECMDGOALS))) +ifndef NODEPEND +ifndef MAKE_RESTARTS +.depend: .FORCE + $(FSTAR_NO_FLAGS) --dep full $(notdir $(FSTAR_ROOTS)) > $@ + +.PHONY: .FORCE +.FORCE: +endif +endif + +include .depend +endif + +# For the interactive mode +%.fst-in %.fsti-in: + @echo $(FSTAR_OPTIONS) + +# Generete the .checked files in batch mode +%.checked: + $(FSTAR) $(FSTAR_OPTIONS) $< && \ + touch -c $@ + +.PHONY: clean +clean: + rm -f obj/* diff --git a/backends/fstar/Primitives.fst b/backends/fstar/Primitives.fst index 96138e46..82622656 100644 --- a/backends/fstar/Primitives.fst +++ b/backends/fstar/Primitives.fst @@ -18,9 +18,13 @@ let rec list_update #a ls i x = #pop-options (*** Result *) +type error : Type0 = +| Failure +| OutOfFuel + type result (a : Type0) : Type0 = | Return : v:a -> result a -| Fail : result a +| Fail : e:error -> result a // Monadic bind and return. // Re-definining those allows us to customize the result of the monadic notations @@ -29,10 +33,10 @@ let return (#a : Type0) (x:a) : result a = Return x let bind (#a #b : Type0) (m : result a) (f : a -> result b) : result b = match m with | Return x -> f x - | Fail -> Fail + | Fail e -> Fail e // Monadic assert(...) -let massert (b:bool) : result unit = if b then Return () else Fail +let massert (b:bool) : result unit = if b then Return () else Fail Failure // Normalize and unwrap a successful result (used for globals). let eval_global (#a : Type0) (x : result a{Return? (normalize_term x)}) : a = Return?.v x @@ -119,12 +123,12 @@ let scalar_max (ty : scalar_ty) : int = type scalar (ty : scalar_ty) : eqtype = x:int{scalar_min ty <= x && x <= scalar_max ty} let mk_scalar (ty : scalar_ty) (x : int) : result (scalar ty) = - if scalar_min ty <= x && scalar_max ty >= x then Return x else Fail + if scalar_min ty <= x && scalar_max ty >= x then Return x else Fail Failure let scalar_neg (#ty : scalar_ty) (x : scalar ty) : result (scalar ty) = mk_scalar ty (-x) let scalar_div (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) = - if y <> 0 then mk_scalar ty (x / y) else Fail + if y <> 0 then mk_scalar ty (x / y) else Fail Failure /// The remainder operation let int_rem (x : int) (y : int{y <> 0}) : int = @@ -137,7 +141,7 @@ let _ = assert_norm(int_rem 1 (-2) = 1) let _ = assert_norm(int_rem (-1) (-2) = -1) let scalar_rem (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) = - if y <> 0 then mk_scalar ty (int_rem x y) else Fail + if y <> 0 then mk_scalar ty (int_rem x y) else Fail Failure let scalar_add (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) = mk_scalar ty (x + y) @@ -258,7 +262,7 @@ let vec_push_back (a : Type0) (v : vec a) (x : a) : (requires True) (ensures (fun res -> match res with - | Fail -> True + | Fail e -> e == Failure | Return v' -> length v' = length v + 1)) = if length v < usize_max then begin (**) assert_norm(length [x] == 1); @@ -266,22 +270,22 @@ let vec_push_back (a : Type0) (v : vec a) (x : a) : (**) assert(length (append v [x]) = length v + 1); Return (append v [x]) end - else Fail + else Fail Failure // The **forward** function shouldn't be used let vec_insert_fwd (a : Type0) (v : vec a) (i : usize) (x : a) : result unit = - if i < length v then Return () else Fail + if i < length v then Return () else Fail Failure let vec_insert_back (a : Type0) (v : vec a) (i : usize) (x : a) : result (vec a) = - if i < length v then Return (list_update v i x) else Fail + if i < length v then Return (list_update v i x) else Fail Failure // The **backward** function shouldn't be used let vec_index_fwd (a : Type0) (v : vec a) (i : usize) : result a = - if i < length v then Return (index v i) else Fail + if i < length v then Return (index v i) else Fail Failure let vec_index_back (a : Type0) (v : vec a) (i : usize) (x : a) : result unit = - if i < length v then Return () else Fail + if i < length v then Return () else Fail Failure let vec_index_mut_fwd (a : Type0) (v : vec a) (i : usize) : result a = - if i < length v then Return (index v i) else Fail + if i < length v then Return (index v i) else Fail Failure let vec_index_mut_back (a : Type0) (v : vec a) (i : usize) (nx : a) : result (vec a) = - if i < length v then Return (list_update v i nx) else Fail + if i < length v then Return (list_update v i nx) else Fail Failure diff --git a/compiler/ConstStrings.ml b/compiler/ConstStrings.ml index ae169a2e..6cf57fe4 100644 --- a/compiler/ConstStrings.ml +++ b/compiler/ConstStrings.ml @@ -5,3 +5,6 @@ let state_basename = "st" (** ADT constructor prefix (used when pretty-printing) *) let constructor_prefix = "Mk" + +(** Basename for error variables *) +let error_basename = "e" diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 13c02bca..17b6aa54 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -126,7 +126,13 @@ let keywords () = List.concat [ named_unops; named_binops; misc ] let assumed_adts : (assumed_ty * string) list = - [ (State, "state"); (Result, "result"); (Option, "option"); (Vec, "vec") ] + [ + (State, "state"); + (Result, "result"); + (Error, "error"); + (Option, "option"); + (Vec, "vec"); + ] let assumed_structs : (assumed_ty * string) list = [] @@ -136,6 +142,8 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = [ (Result, result_return_id, "Return"); (Result, result_fail_id, "Fail"); + (Error, error_failure_id, "Failure"); + (Error, error_out_of_fuel_id, "OutOfFuel"); (Option, option_some_id, "Some"); (Option, option_none_id, "None"); ] @@ -143,6 +151,8 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = [ (Result, result_return_id, "Return"); (Result, result_fail_id, "Fail_"); + (Error, error_failure_id, "Failure"); + (Error, error_out_of_fuel_id, "OutOfFuel"); (Option, option_some_id, "Some"); (Option, option_none_id, "None"); ] @@ -429,6 +439,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) (* The "pair" case is frequent enough to have its special treatment *) if List.length tys = 2 then "p" else "t" | Assumed Result -> "r" + | Assumed Error -> "e" | Assumed Option -> "opt" | Assumed Vec -> "v" | Assumed State -> "st" diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 33939e6a..9690d9fc 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -413,7 +413,7 @@ type extraction_ctx = { (** The indent increment we insert whenever we need to indent more *) } -(** Debugging function *) +(** Debugging function, used when communicating name collisions to the user *) let id_to_string (id : id) (ctx : extraction_ctx) : string = let global_decls = ctx.trans_ctx.global_context.global_decls in let fun_decls = ctx.trans_ctx.fun_context.fun_decls in @@ -467,6 +467,10 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = if variant_id = result_return_id then "@result::Return" else if variant_id = result_fail_id then "@result::Fail" else raise (Failure "Unreachable") + | Assumed Error -> + if variant_id = error_failure_id then "@error::Failure" + else if variant_id = error_out_of_fuel_id then "@error::OutOfFuel" + else raise (Failure "Unreachable") | Assumed Option -> if variant_id = option_some_id then "@option::Some" else if variant_id = option_none_id then "@option::None" @@ -485,7 +489,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = let field_name = match id with | Tuple -> raise (Failure "Unreachable") - | Assumed (State | Result | Option) -> raise (Failure "Unreachable") + | Assumed (State | Result | Error | Option) -> + raise (Failure "Unreachable") | Assumed Vec -> (* We can't directly have access to the fields of a vector *) raise (Failure "Unreachable") diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml index 4d33056b..75a6c0ce 100644 --- a/compiler/FunsAnalysis.ml +++ b/compiler/FunsAnalysis.ml @@ -103,7 +103,7 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) (* We need to know if the declaration group contains a global - note that * groups containing globals contain exactly one declaration *) let is_global_decl_body = List.exists (fun f -> f.is_global_decl_body) d in - assert ((not is_global_decl_body) || List.length d == 1); + assert ((not is_global_decl_body) || List.length d = 1); (* We ignore on purpose functions that cannot fail and consider they *can* * fail: the result of the analysis is not used yet to adjust the translation * so that the functions which syntactically can't fail don't use an error monad. diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index 5bc440e7..5d1a3cfe 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -357,7 +357,7 @@ let eval_unary_op_concrete (config : C.config) (unop : E.unop) (op : E.operand) | Error _ -> cf (Error EPanic) | Ok sv -> cf (Ok { v with V.value = V.Primitive (PV.Scalar sv) })) | E.Cast (src_ty, tgt_ty), V.Primitive (PV.Scalar sv) -> ( - assert (src_ty == sv.int_ty); + assert (src_ty = sv.int_ty); let i = sv.PV.value in match mk_scalar tgt_ty i with | Error _ -> cf (Error EPanic) @@ -637,9 +637,9 @@ let eval_rvalue_aggregate (config : C.config) cf aggregated ctx | E.AggregatedOption (variant_id, ty) -> (* Sanity check *) - if variant_id == T.option_none_id then assert (values == []) - else if variant_id == T.option_some_id then - assert (List.length values == 1) + if variant_id = T.option_none_id then assert (values = []) + else if variant_id = T.option_some_id then + assert (List.length values = 1) else raise (Failure "Unreachable"); (* Construt the value *) let aty = T.Adt (T.Assumed T.Option, [], [ ty ]) in diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index 14dd59b1..3bf7b723 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -469,8 +469,8 @@ let eval_box_deref_mut_or_shared_concrete (config : C.config) :: Var (_ret_var, _) :: C.Frame :: _ ) -> (* Required type checking. We must have: - - input_value.ty == & (mut) Box<ty> - - boxed_ty == ty + - input_value.ty = & (mut) Box<ty> + - boxed_ty = ty for some ty *) (let _, input_ty, ref_kind = ty_get_ref input_value.V.ty in diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index b4ab26b8..0879f553 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -128,6 +128,7 @@ let type_id_to_string (fmt : type_formatter) (id : type_id) : string = match aty with | State -> "State" | Result -> "Result" + | Error -> "Error" | Option -> "Option" | Vec -> "Vec") @@ -247,6 +248,11 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id) else if variant_id = result_fail_id then "@Result::Fail" else raise (Failure "Unreachable: improper variant id for result type") + | Error -> + let variant_id = Option.get variant_id in + if variant_id = error_failure_id then "@Error::Failure" + else if variant_id = error_out_of_fuel_id then "@Error::OutOfFuel" + else raise (Failure "Unreachable: improper variant id for error type") | Option -> let variant_id = Option.get variant_id in if variant_id = option_some_id then "@Option::Some " @@ -275,7 +281,7 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) | State | Vec -> (* Opaque types: we can't get there *) raise (Failure "Unreachable") - | Result | Option -> + | Result | Error | Option -> (* Enumerations: we can't get there *) raise (Failure "Unreachable")) @@ -324,11 +330,18 @@ let adt_g_value_to_string (fmt : value_formatter) match field_values with | [ v ] -> "@Result::Return " ^ v | _ -> raise (Failure "Result::Return takes exactly one value") - else if variant_id = result_fail_id then ( - assert (field_values = []); - "@Result::Fail") + else if variant_id = result_fail_id then + match field_values with + | [ v ] -> "@Result::Fail " ^ v + | _ -> raise (Failure "Result::Fail takes exactly one value") else raise (Failure "Unreachable: improper variant id for result type") + | Error -> + assert (field_values = []); + let variant_id = Option.get variant_id in + if variant_id = error_failure_id then "@Error::Failure" + else if variant_id = error_out_of_fuel_id then "@Error::OutOfFuel" + else raise (Failure "Unreachable: improper variant id for error type") | Option -> let variant_id = Option.get variant_id in if variant_id = option_some_id then diff --git a/compiler/Pure.ml b/compiler/Pure.ml index b0114baa..6cc73bef 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -26,16 +26,17 @@ type integer_type = T.integer_type [@@deriving show, ord] (** The assumed types for the pure AST. In comparison with LLBC: - - we removed [Box] (because it is translated as the identity: [Box T == T]) + - we removed [Box] (because it is translated as the identity: [Box T = T]) - we added: - [Result]: the type used in the error monad. This allows us to have a unified treatment of expressions (especially when we have to unfold the monadic binds) + - [Error]: the kind of error, in case of failure (used by [Result]) - [State]: the type of the state, when using state-error monads. Note that this state is opaque to Aeneas (the user can define it, or leave it as assumed) *) -type assumed_ty = State | Result | Vec | Option [@@deriving show, ord] +type assumed_ty = State | Result | Error | Vec | Option [@@deriving show, ord] (* TODO: we should never directly manipulate [Return] and [Fail], but rather * the monadic functions [return] and [fail] (makes treatment of error and @@ -44,6 +45,8 @@ let result_return_id = VariantId.of_int 0 let result_fail_id = VariantId.of_int 1 let option_some_id = T.option_some_id let option_none_id = T.option_none_id +let error_failure_id = VariantId.of_int 0 +let error_out_of_fuel_id = VariantId.of_int 1 type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty [@@deriving show, ord] diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 1cb35613..c5eb3c64 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -123,7 +123,7 @@ type pn_ctx = { {[ let py = id(&mut x); *py = 2; - assert!(x == 2); + assert!(x = 2); ]} After desugaring, we get the following MIR: @@ -131,7 +131,7 @@ type pn_ctx = { ^0 = &mut x; // anonymous variable py = id(move ^0); *py += 2; - assert!(x == 2); + assert!(x = 2); ]} We want this to be translated as: @@ -1228,6 +1228,9 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = match def.body with | None -> def | Some body -> + let cnt = get_body_min_var_counter body in + let _, fresh_id = VarId.mk_stateful_generator cnt in + (* It is a very simple map *) let obj = object (_self) @@ -1257,8 +1260,18 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = * store in an enum ("monadic" should be an enum, not a bool). *) let re_ty = Option.get (opt_destruct_result re.ty) in assert (lv.ty = re_ty); - let fail_pat = mk_result_fail_pattern lv.ty in - let fail_value = mk_result_fail_texpression e.ty in + let err_vid = fresh_id () in + let err_var : var = + { + id = err_vid; + basename = Some ConstStrings.error_basename; + ty = mk_error_ty; + } + in + let err_pat = mk_typed_pattern_from_var err_var None in + let fail_pat = mk_result_fail_pattern err_pat.value lv.ty in + let err_v = mk_texpression_from_var err_var in + let fail_value = mk_result_fail_texpression err_v e.ty in let fail_branch = { pat = fail_pat; branch = fail_value } in let success_pat = mk_result_return_pattern lv in let success_branch = { pat = success_pat; branch = e } in diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 6b6a82ad..a1e4e834 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -26,9 +26,15 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) let ty = Collections.List.to_cons_nil tys in let variant_id = Option.get variant_id in if variant_id = result_return_id then [ ty ] - else if variant_id = result_fail_id then [] + else if variant_id = result_fail_id then [ mk_error_ty ] else raise (Failure "Unreachable: improper variant id for result type") + | Error -> + assert (tys = []); + let variant_id = Option.get variant_id in + assert ( + variant_id = error_failure_id || variant_id = error_out_of_fuel_id); + [] | Option -> let ty = Collections.List.to_cons_nil tys in let variant_id = Option.get variant_id in diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 728a4fe6..f5c280fb 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -421,13 +421,21 @@ let type_decl_is_enum (def : T.type_decl) : bool = let mk_state_ty : ty = Adt (Assumed State, []) let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ]) +let mk_error_ty : ty = Adt (Assumed Error, []) + +let mk_error (error : VariantId.id) : texpression = + let ty = mk_error_ty in + let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in + let qualif = { id; type_args = [] } in + let e = Qualif qualif in + { e; ty } let unwrap_result_ty (ty : ty) : ty = match ty with | Adt (Assumed Result, [ ty ]) -> ty | _ -> raise (Failure "not a result type") -let mk_result_fail_texpression (ty : ty) : texpression = +let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression = let type_args = [ ty ] in let ty = Adt (Assumed Result, type_args) in let id = @@ -435,9 +443,14 @@ let mk_result_fail_texpression (ty : ty) : texpression = in let qualif = { id; type_args } in let cons_e = Qualif qualif in - let cons_ty = ty in + let cons_ty = mk_arrow error.ty ty in let cons = { e = cons_e; ty = cons_ty } in - cons + mk_app cons error + +let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) : + texpression = + let error = mk_error error in + mk_result_fail_texpression error ty let mk_result_return_texpression (v : texpression) : texpression = let type_args = [ v.ty ] in @@ -451,11 +464,20 @@ let mk_result_return_texpression (v : texpression) : texpression = let cons = { e = cons_e; ty = cons_ty } in mk_app cons v -let mk_result_fail_pattern (ty : ty) : typed_pattern = +(** Create a [Fail err] pattern which captures the error *) +let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern = + let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in let ty = Adt (Assumed Result, [ ty ]) in - let value = PatAdt { variant_id = Some result_fail_id; field_values = [] } in + let value = + PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] } + in { value; ty } +(** Create a [Fail _] pattern (we ignore the error) *) +let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern = + let error_pat : pattern = PatDummy in + mk_result_fail_pattern error_pat ty + let mk_result_return_pattern (v : typed_pattern) : typed_pattern = let ty = Adt (Assumed Result, [ v.ty ]) in let value = diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 62be5efd..8fa66f93 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1134,9 +1134,11 @@ and translate_panic (ctx : bs_ctx) : texpression = if ctx.sg.info.effect_info.stateful then (* Create the [Fail] value *) let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in - let ret_v = mk_result_fail_texpression ret_ty in + let ret_v = + mk_result_fail_texpression_with_error_id error_failure_id ret_ty + in ret_v - else mk_result_fail_texpression output_ty + else mk_result_fail_texpression_with_error_id error_failure_id output_ty (** [opt_v]: the value to return, in case we translate a forward function *) and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression @@ -1661,7 +1663,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) | _ -> raise (Failure "Unreachable") in (* We simply introduce an assignment - the box type is the - * identity when extracted ([box a == a]) *) + * identity when extracted ([box a = a]) *) let monadic = false in mk_let monadic (mk_typed_pattern_from_var var None) |