From 868fa924a37a3af6e701bbc0a2d51fefc2dc7c33 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 14 Nov 2022 11:57:53 +0100 Subject: Make [Result::Failure] type an [Error] parameter --- backends/coq/Primitives.v | 42 +++++++++++++++++++++----------------- backends/fstar/Makefile | 47 +++++++++++++++++++++++++++++++++++++++++++ backends/fstar/Primitives.fst | 32 ++++++++++++++++------------- 3 files changed, 88 insertions(+), 33 deletions(-) create mode 100644 backends/fstar/Makefile (limited to 'backends') diff --git a/backends/coq/Primitives.v b/backends/coq/Primitives.v index c27b8aed..9a97d6c7 100644 --- a/backends/coq/Primitives.v +++ b/backends/coq/Primitives.v @@ -13,40 +13,44 @@ Module Primitives. Declare Scope Primitives_scope. (*** Result *) - + +Inductive error := + | Failure + | OutOfFuel. + Inductive result A := | Return : A -> result A - | Fail_ : result A. + | Fail_ : error -> result A. Arguments Return {_} a. Arguments Fail_ {_}. Definition bind {A B} (m: result A) (f: A -> result B) : result B := match m with - | Fail_ => Fail_ + | Fail_ e => Fail_ e | Return x => f x end. -Definition return_ {A: Type} (x: A) : result A := Return x . -Definition fail_ {A: Type} : result A := Fail_ . +Definition return_ {A: Type} (x: A) : result A := Return x. +Definition fail_ {A: Type} (e: error) : result A := Fail_ e. Notation "x <- c1 ; c2" := (bind c1 (fun x => c2)) (at level 61, c1 at next level, right associativity). (** Monadic assert *) Definition massert (b: bool) : result unit := - if b then Return tt else Fail_. + if b then Return tt else Fail_ Failure. (** Normalize and unwrap a successful result (used for globals) *) Definition eval_result_refl {A} {x} (a: result A) (p: a = Return x) : A := match a as r return (r = Return x -> A) with | Return a' => fun _ => a' - | Fail_ => fun p' => - False_rect _ (eq_ind Fail_ + | Fail_ e => fun p' => + False_rect _ (eq_ind (Fail_ e) (fun e : result A => match e with | Return _ => False - | Fail_ => True + | Fail_ e => True end) I (Return x) p') end p. @@ -55,7 +59,7 @@ Notation "x %global" := (eval_result_refl x eq_refl) (at level 40). Notation "x %return" := (eval_result_refl x eq_refl) (at level 40). (* Sanity check *) -Check (if true then Return (1 + 2) else Fail_)%global = 3. +Check (if true then Return (1 + 2) else Fail_ Failure)%global = 3. (*** Misc *) @@ -232,7 +236,7 @@ Import Sumbool. Definition mk_scalar (ty: scalar_ty) (x: Z) : result (scalar ty) := match sumbool_of_bool (scalar_in_bounds ty x) with | left H => Return (exist _ x (scalar_in_bounds_valid _ _ H)) - | right _ => Fail_ + | right _ => Fail_ Failure end. Definition scalar_add {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (to_Z x + to_Z y). @@ -242,7 +246,7 @@ Definition scalar_sub {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty Definition scalar_mul {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (to_Z x * to_Z y). Definition scalar_div {ty} (x y: scalar ty) : result (scalar ty) := - if to_Z y =? 0 then Fail_ else + if to_Z y =? 0 then Fail_ Failure else mk_scalar ty (to_Z x / to_Z y). Definition scalar_rem {ty} (x y: scalar ty) : result (scalar ty) := mk_scalar ty (Z.rem (to_Z x) (to_Z y)). @@ -433,7 +437,7 @@ Definition vec_bind {A B} (v: vec A) (f: list A -> result (list B)) : result (ve l <- f (vec_to_list v) ; match sumbool_of_bool (scalar_le_max Usize (Z.of_nat (length l))) with | left H => Return (exist _ l (scalar_le_max_valid _ _ H)) - | right _ => Fail_ + | right _ => Fail_ Failure end. (* The **forward** function shouldn't be used *) @@ -444,35 +448,35 @@ Definition vec_push_back (T: Type) (v: vec T) (x: T) : result (vec T) := (* The **forward** function shouldn't be used *) Definition vec_insert_fwd (T: Type) (v: vec T) (i: usize) (x: T) : result unit := - if to_Z i if to_Z i Return n - | None => Fail_ + | None => Fail_ Failure end. Definition vec_index_back (T: Type) (v: vec T) (i: usize) (x: T) : result unit := - if to_Z i Return n - | None => Fail_ + | None => Fail_ Failure end. Definition vec_index_mut_back (T: Type) (v: vec T) (i: usize) (x: T) : result (vec T) := vec_bind v (fun l => if to_Z i $@ + +.PHONY: .FORCE +.FORCE: +endif +endif + +include .depend +endif + +# For the interactive mode +%.fst-in %.fsti-in: + @echo $(FSTAR_OPTIONS) + +# Generete the .checked files in batch mode +%.checked: + $(FSTAR) $(FSTAR_OPTIONS) $< && \ + touch -c $@ + +.PHONY: clean +clean: + rm -f obj/* diff --git a/backends/fstar/Primitives.fst b/backends/fstar/Primitives.fst index 96138e46..82622656 100644 --- a/backends/fstar/Primitives.fst +++ b/backends/fstar/Primitives.fst @@ -18,9 +18,13 @@ let rec list_update #a ls i x = #pop-options (*** Result *) +type error : Type0 = +| Failure +| OutOfFuel + type result (a : Type0) : Type0 = | Return : v:a -> result a -| Fail : result a +| Fail : e:error -> result a // Monadic bind and return. // Re-definining those allows us to customize the result of the monadic notations @@ -29,10 +33,10 @@ let return (#a : Type0) (x:a) : result a = Return x let bind (#a #b : Type0) (m : result a) (f : a -> result b) : result b = match m with | Return x -> f x - | Fail -> Fail + | Fail e -> Fail e // Monadic assert(...) -let massert (b:bool) : result unit = if b then Return () else Fail +let massert (b:bool) : result unit = if b then Return () else Fail Failure // Normalize and unwrap a successful result (used for globals). let eval_global (#a : Type0) (x : result a{Return? (normalize_term x)}) : a = Return?.v x @@ -119,12 +123,12 @@ let scalar_max (ty : scalar_ty) : int = type scalar (ty : scalar_ty) : eqtype = x:int{scalar_min ty <= x && x <= scalar_max ty} let mk_scalar (ty : scalar_ty) (x : int) : result (scalar ty) = - if scalar_min ty <= x && scalar_max ty >= x then Return x else Fail + if scalar_min ty <= x && scalar_max ty >= x then Return x else Fail Failure let scalar_neg (#ty : scalar_ty) (x : scalar ty) : result (scalar ty) = mk_scalar ty (-x) let scalar_div (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) = - if y <> 0 then mk_scalar ty (x / y) else Fail + if y <> 0 then mk_scalar ty (x / y) else Fail Failure /// The remainder operation let int_rem (x : int) (y : int{y <> 0}) : int = @@ -137,7 +141,7 @@ let _ = assert_norm(int_rem 1 (-2) = 1) let _ = assert_norm(int_rem (-1) (-2) = -1) let scalar_rem (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) = - if y <> 0 then mk_scalar ty (int_rem x y) else Fail + if y <> 0 then mk_scalar ty (int_rem x y) else Fail Failure let scalar_add (#ty : scalar_ty) (x : scalar ty) (y : scalar ty) : result (scalar ty) = mk_scalar ty (x + y) @@ -258,7 +262,7 @@ let vec_push_back (a : Type0) (v : vec a) (x : a) : (requires True) (ensures (fun res -> match res with - | Fail -> True + | Fail e -> e == Failure | Return v' -> length v' = length v + 1)) = if length v < usize_max then begin (**) assert_norm(length [x] == 1); @@ -266,22 +270,22 @@ let vec_push_back (a : Type0) (v : vec a) (x : a) : (**) assert(length (append v [x]) = length v + 1); Return (append v [x]) end - else Fail + else Fail Failure // The **forward** function shouldn't be used let vec_insert_fwd (a : Type0) (v : vec a) (i : usize) (x : a) : result unit = - if i < length v then Return () else Fail + if i < length v then Return () else Fail Failure let vec_insert_back (a : Type0) (v : vec a) (i : usize) (x : a) : result (vec a) = - if i < length v then Return (list_update v i x) else Fail + if i < length v then Return (list_update v i x) else Fail Failure // The **backward** function shouldn't be used let vec_index_fwd (a : Type0) (v : vec a) (i : usize) : result a = - if i < length v then Return (index v i) else Fail + if i < length v then Return (index v i) else Fail Failure let vec_index_back (a : Type0) (v : vec a) (i : usize) (x : a) : result unit = - if i < length v then Return () else Fail + if i < length v then Return () else Fail Failure let vec_index_mut_fwd (a : Type0) (v : vec a) (i : usize) : result a = - if i < length v then Return (index v i) else Fail + if i < length v then Return (index v i) else Fail Failure let vec_index_mut_back (a : Type0) (v : vec a) (i : usize) (nx : a) : result (vec a) = - if i < length v then Return (list_update v i nx) else Fail + if i < length v then Return (list_update v i nx) else Fail Failure -- cgit v1.2.3