summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-11-14 11:57:53 +0100
committerSon HO2022-11-14 14:21:04 +0100
commit868fa924a37a3af6e701bbc0a2d51fefc2dc7c33 (patch)
treee770fe4d89baf7b1017d2c88d9f866eb54a56ce3
parent019a9e34e6375a5e015e4978aad89aa8febc237c (diff)
Make [Result::Failure] type an [Error] parameter
Diffstat (limited to '')
-rw-r--r--backends/coq/Primitives.v42
-rw-r--r--backends/fstar/Makefile47
-rw-r--r--backends/fstar/Primitives.fst32
-rw-r--r--compiler/ConstStrings.ml3
-rw-r--r--compiler/Extract.ml13
-rw-r--r--compiler/ExtractBase.ml9
-rw-r--r--compiler/FunsAnalysis.ml2
-rw-r--r--compiler/InterpreterExpressions.ml8
-rw-r--r--compiler/InterpreterStatements.ml4
-rw-r--r--compiler/PrintPure.ml21
-rw-r--r--compiler/Pure.ml7
-rw-r--r--compiler/PureMicroPasses.ml21
-rw-r--r--compiler/PureTypeCheck.ml8
-rw-r--r--compiler/PureUtils.ml32
-rw-r--r--compiler/SymbolicToPure.ml8
15 files changed, 195 insertions, 62 deletions
diff --git a/backends/coq/Primitives.v b/backends/coq/Primitives.v
index c27b8aed..9a97d6c7 100644
--- a/backends/coq/Primitives.v
+++ b/backends/coq/Primitives.v
@@ -13,40 +13,44 @@ Module Primitives.
Declare Scope Primitives_scope.
(*** Result *)
-
+
+Inductive error :=
+ | Failure
+ | OutOfFuel.
+
Inductive result A :=
| Return : A -> result A
- | Fail_ : result A.
+ | Fail_ : error -> result A.
Arguments Return {_} a.
Arguments Fail_ {_}.
Definition bind {A B} (m: result A) (f: A -> result B) : result B :=
match m with
- | Fail_ => Fail_
+ | Fail_ e => Fail_ e
| Return x => f x
end.
-Definition return_ {A: Type} (x: A) : result A := Return x .
-Definition fail_ {A: Type} : result A := Fail_ .
+Definition return_ {A: Type} (x: A) : result A := Return x.
+Definition fail_ {A: Type} (e: error) : result A := Fail_ e.
Notation "x <- c1 ; c2" := (bind c1 (fun x => c2))
(at level 61, c1 at next level, right associativity).
(** Monadic assert *)
Definition massert (b: bool) : result unit :=
- if b then Return tt else Fail_.
+ if b then Return tt else Fail_ Failure.
(** Normalize and unwrap a successful result (used for globals) *)
Definition eval_result_refl {A} {x} (a: result A) (p: a = Return x) : A :=
match a as r return (r = Return x -> A) with
| Return a' => fun _ => a'
- | Fail_ => fun p' =>
- False_rect _ (eq_ind Fail_
+ | Fail_ e => fun p' =>
+ False_rect _ (eq_ind (Fail_ e)
(fun e : result A =>
match e with
| Return _ => False
- | Fail_ => True
+ | Fail_ e => True
end)
I (Return x) p')
end p.
@@ -55,7 +59,7 @@ Notation "x %global" := (eval_result_refl x eq_refl) (at level 40).
Notation "x %return" := (eval_result_refl x eq_refl) (at level 40).
(* Sanity check *)
-Check (if true then Return (1 + 2) else Fail_)%global = 3.
+Check (if true then Return (1 + 2) else Fail_ Failure)%global = 3.
(*** Misc *)
@@ -232,7 +236,7 @@ Import Sumbool.
Definition mk_scalar (ty: scalar_ty) (x: Z) : result (scalar ty) :=
match sumbool_of_bool (scalar_in_bounds ty x) with
| left H => Return (exist _ x (scalar_in_bounds_valid _ _ H))
- | right _ => Fail_
+ | right _ => Fail_ Failure
end.
Definition scalar_add {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (to_Z x + to_Z y).
@@ -242,7 +246,7 @@ Definition scalar_sub {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty
Definition scalar_mul {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (to_Z x * to_Z y).
Definition scalar_div {ty} (x y: scalar ty) : result (scalar ty) :=
- if to_Z y =? 0 then Fail_ else
+ if to_Z y =? 0 then Fail_ Failure else
mk_scalar ty (to_Z x / to_Z y).
Definition scalar_rem {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (Z.rem (to_Z x) (to_Z y)).
@@ -433,7 +437,7 @@ Definition vec_bind {A B} (v: vec A) (f: list A -> result (list B)) : result (ve
l <- f (vec_to_list v) ;
match sumbool_of_bool (scalar_le_max Usize (Z.of_nat (length l))) with
| left H => Return (exist _ l (scalar_le_max_valid _ _ H))
- | right _ => Fail_
+ | right _ => Fail_ Failure
end.
(* The **forward** function shouldn't be used *)
@@ -444,35 +448,35 @@ Definition vec_push_back (T: Type) (v: vec T) (x: T) : result (vec T) :=
(* The **forward** function shouldn't be used *)
Definition vec_insert_fwd (T: Type) (v: vec T) (i: usize) (x: T) : result unit :=
- if to_Z i <? vec_length v then Return tt else Fail_.
+ if to_Z i <? vec_length v then Return tt else Fail_ Failure.
Definition vec_insert_back (T: Type) (v: vec T) (i: usize) (x: T) : result (vec T) :=
vec_bind v (fun l =>
if to_Z i <? Z.of_nat (length l)
then Return (list_update l (usize_to_nat i) x)
- else Fail_).
+ else Fail_ Failure).
(* The **backward** function shouldn't be used *)
Definition vec_index_fwd (T: Type) (v: vec T) (i: usize) : result T :=
match nth_error (vec_to_list v) (usize_to_nat i) with
| Some n => Return n
- | None => Fail_
+ | None => Fail_ Failure
end.
Definition vec_index_back (T: Type) (v: vec T) (i: usize) (x: T) : result unit :=
- if to_Z i <? vec_length v then Return tt else Fail_.
+ if to_Z i <? vec_length v then Return tt else Fail_ Failure.
(* The **backward** function shouldn't be used *)
Definition vec_index_mut_fwd (T: Type) (v: vec T) (i: usize) : result T :=
match nth_error (vec_to_list v) (usize_to_nat i) with
| Some n => Return n
- | None => Fail_
+ | None => Fail_ Failure
end.
Definition vec_index_mut_back (T: Type) (v: vec T) (i: usize) (x: T) : result (vec T) :=
vec_bind v (fun l =>
if to_Z i <? Z.of_nat (length l)
then Return (list_update l (usize_to_nat i) x)
- else Fail_).
+ else Fail_ Failure).
End Primitives.
diff --git a/backends/fstar/Makefile b/backends/fstar/Makefile
new file mode 100644
index 00000000..a16b0edb
--- /dev/null
+++ b/backends/fstar/Makefile
@@ -0,0 +1,47 @@
+INCLUDE_DIRS = .
+
+FSTAR_INCLUDES = $(addprefix --include ,$(INCLUDE_DIRS))
+
+FSTAR_HINTS ?= --use_hints --use_hint_hashes --record_hints
+
+FSTAR_OPTIONS = $(FSTAR_HINTS) \
+ --cache_checked_modules $(FSTAR_INCLUDES) --cmi \
+ --warn_error '+241@247+285-274' \
+
+FSTAR_NO_FLAGS = fstar.exe --already_cached 'Prims FStar LowStar Steel' --odir obj --cache_dir obj
+
+FSTAR = $(FSTAR_NO_FLAGS) $(FSTAR_OPTIONS)
+
+# The F* roots are used to compute the dependency graph, and generate the .depend file
+FSTAR_ROOTS ?= $(wildcard *.fst *.fsti)
+
+# Build all the files
+all: $(addprefix obj/,$(addsuffix .checked,$(FSTAR_ROOTS)))
+
+# This is the right way to ensure the .depend file always gets re-built.
+ifeq (,$(filter %-in,$(MAKECMDGOALS)))
+ifndef NODEPEND
+ifndef MAKE_RESTARTS
+.depend: .FORCE
+ $(FSTAR_NO_FLAGS) --dep full $(notdir $(FSTAR_ROOTS)) > $@
+
+.PHONY: .FORCE
+.FORCE:
+endif
+endif
+
+include .depend
+endif
+
+# For the interactive mode
+%.fst-in %.fsti-in:
+ @echo $(FSTAR_OPTIONS)
+
+# Generete the .checked files in batch mode
+%.checked:
+ $(FSTAR) $(FSTAR_OPTIONS) $< && \
+ touch -c $@
+
+.PHONY: clean
+clean:
+ rm -f obj/*
diff --git a/backends/fstar/Primitives.fst b/backends/fstar/Primitives.fst
index 96138e46..82622656 100644
--- a/backends/fstar/Primitives.fst
+++ b/backends/fstar/Primitives.fst
@@ -18,9 +18,13 @@ let rec list_update #a ls i x =
#pop-options
(*** Result *)
+type error : Type0 =
+| Failure
+| OutOfFuel
+
type result (a : Type0) : Type0 =
| Return : v:a -> result a
-| Fail : result a
+| Fail : e:error -> result a
// Monadic bind and return.
// Re-definining those allows us to customize the result of the monadic notations
@@ -29,10 +33,10 @@ let return (#a : Type0) (x:a) : result a = Return x
let bind (#a #b : Type0) (m : result a) (f : a -> result b) : result b =
match m with
| Return x -> f x
- | Fail -> Fail
+ | Fail e -> Fail e
// Monadic assert(...)
-let massert (b:bool) : result unit = if b then Return () else Fail
+let massert (b:bool) : result unit = if b then Return () else Fail Failure
// Normalize and unwrap a successful result (used for globals).
let eval_global (#a : Type0) (x : result a{Return? (normalize_term x)}) : a = Return?.v x
@@ -119,12 +123,12 @@ let scalar_max (ty : scalar_ty) : int =
type scalar (ty : scalar_ty) : eqtype = x:int{scalar_min ty <= x && x <= scalar_max ty}
let mk_scalar (ty : scalar_ty) (x : int) : result (scalar ty) =
- if scalar_min ty <= x && scalar_max ty >= x then Return x else Fail
+ if scalar_min ty <= x && scalar_max ty >= x then Return x else Fail Failure
let scalar_neg (#ty : scalar_ty) (x : scalar ty) : result (scalar ty) = mk_scalar ty (-x)
let scalar_div (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) =
- if y <> 0 then mk_scalar ty (x / y) else Fail
+ if y <> 0 then mk_scalar ty (x / y) else Fail Failure
/// The remainder operation
let int_rem (x : int) (y : int{y <> 0}) : int =
@@ -137,7 +141,7 @@ let _ = assert_norm(int_rem 1 (-2) = 1)
let _ = assert_norm(int_rem (-1) (-2) = -1)
let scalar_rem (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) =
- if y <> 0 then mk_scalar ty (int_rem x y) else Fail
+ if y <> 0 then mk_scalar ty (int_rem x y) else Fail Failure
let scalar_add (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) =
mk_scalar ty (x + y)
@@ -258,7 +262,7 @@ let vec_push_back (a : Type0) (v : vec a) (x : a) :
(requires True)
(ensures (fun res ->
match res with
- | Fail -> True
+ | Fail e -> e == Failure
| Return v' -> length v' = length v + 1)) =
if length v < usize_max then begin
(**) assert_norm(length [x] == 1);
@@ -266,22 +270,22 @@ let vec_push_back (a : Type0) (v : vec a) (x : a) :
(**) assert(length (append v [x]) = length v + 1);
Return (append v [x])
end
- else Fail
+ else Fail Failure
// The **forward** function shouldn't be used
let vec_insert_fwd (a : Type0) (v : vec a) (i : usize) (x : a) : result unit =
- if i < length v then Return () else Fail
+ if i < length v then Return () else Fail Failure
let vec_insert_back (a : Type0) (v : vec a) (i : usize) (x : a) : result (vec a) =
- if i < length v then Return (list_update v i x) else Fail
+ if i < length v then Return (list_update v i x) else Fail Failure
// The **backward** function shouldn't be used
let vec_index_fwd (a : Type0) (v : vec a) (i : usize) : result a =
- if i < length v then Return (index v i) else Fail
+ if i < length v then Return (index v i) else Fail Failure
let vec_index_back (a : Type0) (v : vec a) (i : usize) (x : a) : result unit =
- if i < length v then Return () else Fail
+ if i < length v then Return () else Fail Failure
let vec_index_mut_fwd (a : Type0) (v : vec a) (i : usize) : result a =
- if i < length v then Return (index v i) else Fail
+ if i < length v then Return (index v i) else Fail Failure
let vec_index_mut_back (a : Type0) (v : vec a) (i : usize) (nx : a) : result (vec a) =
- if i < length v then Return (list_update v i nx) else Fail
+ if i < length v then Return (list_update v i nx) else Fail Failure
diff --git a/compiler/ConstStrings.ml b/compiler/ConstStrings.ml
index ae169a2e..6cf57fe4 100644
--- a/compiler/ConstStrings.ml
+++ b/compiler/ConstStrings.ml
@@ -5,3 +5,6 @@ let state_basename = "st"
(** ADT constructor prefix (used when pretty-printing) *)
let constructor_prefix = "Mk"
+
+(** Basename for error variables *)
+let error_basename = "e"
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 13c02bca..17b6aa54 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -126,7 +126,13 @@ let keywords () =
List.concat [ named_unops; named_binops; misc ]
let assumed_adts : (assumed_ty * string) list =
- [ (State, "state"); (Result, "result"); (Option, "option"); (Vec, "vec") ]
+ [
+ (State, "state");
+ (Result, "result");
+ (Error, "error");
+ (Option, "option");
+ (Vec, "vec");
+ ]
let assumed_structs : (assumed_ty * string) list = []
@@ -136,6 +142,8 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list =
[
(Result, result_return_id, "Return");
(Result, result_fail_id, "Fail");
+ (Error, error_failure_id, "Failure");
+ (Error, error_out_of_fuel_id, "OutOfFuel");
(Option, option_some_id, "Some");
(Option, option_none_id, "None");
]
@@ -143,6 +151,8 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list =
[
(Result, result_return_id, "Return");
(Result, result_fail_id, "Fail_");
+ (Error, error_failure_id, "Failure");
+ (Error, error_out_of_fuel_id, "OutOfFuel");
(Option, option_some_id, "Some");
(Option, option_none_id, "None");
]
@@ -429,6 +439,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
(* The "pair" case is frequent enough to have its special treatment *)
if List.length tys = 2 then "p" else "t"
| Assumed Result -> "r"
+ | Assumed Error -> "e"
| Assumed Option -> "opt"
| Assumed Vec -> "v"
| Assumed State -> "st"
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 33939e6a..9690d9fc 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -413,7 +413,7 @@ type extraction_ctx = {
(** The indent increment we insert whenever we need to indent more *)
}
-(** Debugging function *)
+(** Debugging function, used when communicating name collisions to the user *)
let id_to_string (id : id) (ctx : extraction_ctx) : string =
let global_decls = ctx.trans_ctx.global_context.global_decls in
let fun_decls = ctx.trans_ctx.fun_context.fun_decls in
@@ -467,6 +467,10 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
if variant_id = result_return_id then "@result::Return"
else if variant_id = result_fail_id then "@result::Fail"
else raise (Failure "Unreachable")
+ | Assumed Error ->
+ if variant_id = error_failure_id then "@error::Failure"
+ else if variant_id = error_out_of_fuel_id then "@error::OutOfFuel"
+ else raise (Failure "Unreachable")
| Assumed Option ->
if variant_id = option_some_id then "@option::Some"
else if variant_id = option_none_id then "@option::None"
@@ -485,7 +489,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
let field_name =
match id with
| Tuple -> raise (Failure "Unreachable")
- | Assumed (State | Result | Option) -> raise (Failure "Unreachable")
+ | Assumed (State | Result | Error | Option) ->
+ raise (Failure "Unreachable")
| Assumed Vec ->
(* We can't directly have access to the fields of a vector *)
raise (Failure "Unreachable")
diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml
index 4d33056b..75a6c0ce 100644
--- a/compiler/FunsAnalysis.ml
+++ b/compiler/FunsAnalysis.ml
@@ -103,7 +103,7 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
(* We need to know if the declaration group contains a global - note that
* groups containing globals contain exactly one declaration *)
let is_global_decl_body = List.exists (fun f -> f.is_global_decl_body) d in
- assert ((not is_global_decl_body) || List.length d == 1);
+ assert ((not is_global_decl_body) || List.length d = 1);
(* We ignore on purpose functions that cannot fail and consider they *can*
* fail: the result of the analysis is not used yet to adjust the translation
* so that the functions which syntactically can't fail don't use an error monad.
diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml
index 5bc440e7..5d1a3cfe 100644
--- a/compiler/InterpreterExpressions.ml
+++ b/compiler/InterpreterExpressions.ml
@@ -357,7 +357,7 @@ let eval_unary_op_concrete (config : C.config) (unop : E.unop) (op : E.operand)
| Error _ -> cf (Error EPanic)
| Ok sv -> cf (Ok { v with V.value = V.Primitive (PV.Scalar sv) }))
| E.Cast (src_ty, tgt_ty), V.Primitive (PV.Scalar sv) -> (
- assert (src_ty == sv.int_ty);
+ assert (src_ty = sv.int_ty);
let i = sv.PV.value in
match mk_scalar tgt_ty i with
| Error _ -> cf (Error EPanic)
@@ -637,9 +637,9 @@ let eval_rvalue_aggregate (config : C.config)
cf aggregated ctx
| E.AggregatedOption (variant_id, ty) ->
(* Sanity check *)
- if variant_id == T.option_none_id then assert (values == [])
- else if variant_id == T.option_some_id then
- assert (List.length values == 1)
+ if variant_id = T.option_none_id then assert (values = [])
+ else if variant_id = T.option_some_id then
+ assert (List.length values = 1)
else raise (Failure "Unreachable");
(* Construt the value *)
let aty = T.Adt (T.Assumed T.Option, [], [ ty ]) in
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index 14dd59b1..3bf7b723 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -469,8 +469,8 @@ let eval_box_deref_mut_or_shared_concrete (config : C.config)
:: Var (_ret_var, _)
:: C.Frame :: _ ) ->
(* Required type checking. We must have:
- - input_value.ty == & (mut) Box<ty>
- - boxed_ty == ty
+ - input_value.ty = & (mut) Box<ty>
+ - boxed_ty = ty
for some ty
*)
(let _, input_ty, ref_kind = ty_get_ref input_value.V.ty in
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index b4ab26b8..0879f553 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -128,6 +128,7 @@ let type_id_to_string (fmt : type_formatter) (id : type_id) : string =
match aty with
| State -> "State"
| Result -> "Result"
+ | Error -> "Error"
| Option -> "Option"
| Vec -> "Vec")
@@ -247,6 +248,11 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id)
else if variant_id = result_fail_id then "@Result::Fail"
else
raise (Failure "Unreachable: improper variant id for result type")
+ | Error ->
+ let variant_id = Option.get variant_id in
+ if variant_id = error_failure_id then "@Error::Failure"
+ else if variant_id = error_out_of_fuel_id then "@Error::OutOfFuel"
+ else raise (Failure "Unreachable: improper variant id for error type")
| Option ->
let variant_id = Option.get variant_id in
if variant_id = option_some_id then "@Option::Some "
@@ -275,7 +281,7 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id)
| State | Vec ->
(* Opaque types: we can't get there *)
raise (Failure "Unreachable")
- | Result | Option ->
+ | Result | Error | Option ->
(* Enumerations: we can't get there *)
raise (Failure "Unreachable"))
@@ -324,11 +330,18 @@ let adt_g_value_to_string (fmt : value_formatter)
match field_values with
| [ v ] -> "@Result::Return " ^ v
| _ -> raise (Failure "Result::Return takes exactly one value")
- else if variant_id = result_fail_id then (
- assert (field_values = []);
- "@Result::Fail")
+ else if variant_id = result_fail_id then
+ match field_values with
+ | [ v ] -> "@Result::Fail " ^ v
+ | _ -> raise (Failure "Result::Fail takes exactly one value")
else
raise (Failure "Unreachable: improper variant id for result type")
+ | Error ->
+ assert (field_values = []);
+ let variant_id = Option.get variant_id in
+ if variant_id = error_failure_id then "@Error::Failure"
+ else if variant_id = error_out_of_fuel_id then "@Error::OutOfFuel"
+ else raise (Failure "Unreachable: improper variant id for error type")
| Option ->
let variant_id = Option.get variant_id in
if variant_id = option_some_id then
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index b0114baa..6cc73bef 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -26,16 +26,17 @@ type integer_type = T.integer_type [@@deriving show, ord]
(** The assumed types for the pure AST.
In comparison with LLBC:
- - we removed [Box] (because it is translated as the identity: [Box T == T])
+ - we removed [Box] (because it is translated as the identity: [Box T = T])
- we added:
- [Result]: the type used in the error monad. This allows us to have a
unified treatment of expressions (especially when we have to unfold the
monadic binds)
+ - [Error]: the kind of error, in case of failure (used by [Result])
- [State]: the type of the state, when using state-error monads. Note that
this state is opaque to Aeneas (the user can define it, or leave it as
assumed)
*)
-type assumed_ty = State | Result | Vec | Option [@@deriving show, ord]
+type assumed_ty = State | Result | Error | Vec | Option [@@deriving show, ord]
(* TODO: we should never directly manipulate [Return] and [Fail], but rather
* the monadic functions [return] and [fail] (makes treatment of error and
@@ -44,6 +45,8 @@ let result_return_id = VariantId.of_int 0
let result_fail_id = VariantId.of_int 1
let option_some_id = T.option_some_id
let option_none_id = T.option_none_id
+let error_failure_id = VariantId.of_int 0
+let error_out_of_fuel_id = VariantId.of_int 1
type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty
[@@deriving show, ord]
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 1cb35613..c5eb3c64 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -123,7 +123,7 @@ type pn_ctx = {
{[
let py = id(&mut x);
*py = 2;
- assert!(x == 2);
+ assert!(x = 2);
]}
After desugaring, we get the following MIR:
@@ -131,7 +131,7 @@ type pn_ctx = {
^0 = &mut x; // anonymous variable
py = id(move ^0);
*py += 2;
- assert!(x == 2);
+ assert!(x = 2);
]}
We want this to be translated as:
@@ -1228,6 +1228,9 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
match def.body with
| None -> def
| Some body ->
+ let cnt = get_body_min_var_counter body in
+ let _, fresh_id = VarId.mk_stateful_generator cnt in
+
(* It is a very simple map *)
let obj =
object (_self)
@@ -1257,8 +1260,18 @@ let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
* store in an enum ("monadic" should be an enum, not a bool). *)
let re_ty = Option.get (opt_destruct_result re.ty) in
assert (lv.ty = re_ty);
- let fail_pat = mk_result_fail_pattern lv.ty in
- let fail_value = mk_result_fail_texpression e.ty in
+ let err_vid = fresh_id () in
+ let err_var : var =
+ {
+ id = err_vid;
+ basename = Some ConstStrings.error_basename;
+ ty = mk_error_ty;
+ }
+ in
+ let err_pat = mk_typed_pattern_from_var err_var None in
+ let fail_pat = mk_result_fail_pattern err_pat.value lv.ty in
+ let err_v = mk_texpression_from_var err_var in
+ let fail_value = mk_result_fail_texpression err_v e.ty in
let fail_branch = { pat = fail_pat; branch = fail_value } in
let success_pat = mk_result_return_pattern lv in
let success_branch = { pat = success_pat; branch = e } in
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index 6b6a82ad..a1e4e834 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -26,9 +26,15 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
let ty = Collections.List.to_cons_nil tys in
let variant_id = Option.get variant_id in
if variant_id = result_return_id then [ ty ]
- else if variant_id = result_fail_id then []
+ else if variant_id = result_fail_id then [ mk_error_ty ]
else
raise (Failure "Unreachable: improper variant id for result type")
+ | Error ->
+ assert (tys = []);
+ let variant_id = Option.get variant_id in
+ assert (
+ variant_id = error_failure_id || variant_id = error_out_of_fuel_id);
+ []
| Option ->
let ty = Collections.List.to_cons_nil tys in
let variant_id = Option.get variant_id in
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 728a4fe6..f5c280fb 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -421,13 +421,21 @@ let type_decl_is_enum (def : T.type_decl) : bool =
let mk_state_ty : ty = Adt (Assumed State, [])
let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ])
+let mk_error_ty : ty = Adt (Assumed Error, [])
+
+let mk_error (error : VariantId.id) : texpression =
+ let ty = mk_error_ty in
+ let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in
+ let qualif = { id; type_args = [] } in
+ let e = Qualif qualif in
+ { e; ty }
let unwrap_result_ty (ty : ty) : ty =
match ty with
| Adt (Assumed Result, [ ty ]) -> ty
| _ -> raise (Failure "not a result type")
-let mk_result_fail_texpression (ty : ty) : texpression =
+let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression =
let type_args = [ ty ] in
let ty = Adt (Assumed Result, type_args) in
let id =
@@ -435,9 +443,14 @@ let mk_result_fail_texpression (ty : ty) : texpression =
in
let qualif = { id; type_args } in
let cons_e = Qualif qualif in
- let cons_ty = ty in
+ let cons_ty = mk_arrow error.ty ty in
let cons = { e = cons_e; ty = cons_ty } in
- cons
+ mk_app cons error
+
+let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) :
+ texpression =
+ let error = mk_error error in
+ mk_result_fail_texpression error ty
let mk_result_return_texpression (v : texpression) : texpression =
let type_args = [ v.ty ] in
@@ -451,11 +464,20 @@ let mk_result_return_texpression (v : texpression) : texpression =
let cons = { e = cons_e; ty = cons_ty } in
mk_app cons v
-let mk_result_fail_pattern (ty : ty) : typed_pattern =
+(** Create a [Fail err] pattern which captures the error *)
+let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern =
+ let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in
let ty = Adt (Assumed Result, [ ty ]) in
- let value = PatAdt { variant_id = Some result_fail_id; field_values = [] } in
+ let value =
+ PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] }
+ in
{ value; ty }
+(** Create a [Fail _] pattern (we ignore the error) *)
+let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern =
+ let error_pat : pattern = PatDummy in
+ mk_result_fail_pattern error_pat ty
+
let mk_result_return_pattern (v : typed_pattern) : typed_pattern =
let ty = Adt (Assumed Result, [ v.ty ]) in
let value =
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 62be5efd..8fa66f93 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -1134,9 +1134,11 @@ and translate_panic (ctx : bs_ctx) : texpression =
if ctx.sg.info.effect_info.stateful then
(* Create the [Fail] value *)
let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
- let ret_v = mk_result_fail_texpression ret_ty in
+ let ret_v =
+ mk_result_fail_texpression_with_error_id error_failure_id ret_ty
+ in
ret_v
- else mk_result_fail_texpression output_ty
+ else mk_result_fail_texpression_with_error_id error_failure_id output_ty
(** [opt_v]: the value to return, in case we translate a forward function *)
and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression
@@ -1661,7 +1663,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
| _ -> raise (Failure "Unreachable")
in
(* We simply introduce an assignment - the box type is the
- * identity when extracted ([box a == a]) *)
+ * identity when extracted ([box a = a]) *)
let monadic = false in
mk_let monadic
(mk_typed_pattern_from_var var None)