summaryrefslogtreecommitdiff
path: root/src/ExtractToFStar.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/ExtractToFStar.ml')
-rw-r--r--src/ExtractToFStar.ml26
1 files changed, 22 insertions, 4 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 1f59075a..9d96d058 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -1192,7 +1192,7 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter)
*)
let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
(qualif : fun_decl_qualif) (has_decreases_clause : bool)
- (fwd_def : fun_decl) (def : fun_decl) : unit =
+ (has_state_param : bool) (fwd_def : fun_decl) (def : fun_decl) : unit =
(* Retrieve the function name *)
let def_name = ctx_get_local_function def.def_id def.back_id ctx in
(* (* Add the type parameters - note that we need those bindings only for the
@@ -1279,12 +1279,30 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
* backward functions have no influence on termination: we thus
* share the decrease clauses between the forward and the backward
* functions).
+ * Something annoying is that there may be a state parameter, for
+ * the state-error monad, in which case we need to forget the input
+ * parameter before last:
+ * ```
+ * val f_fwd (x : u32) (st : state) :
+ * Tot (result (state & u32)) (decreases (f_decreases x st))
+ *
+ * We ignore this parameter
+ * VVV
+ * val f_back (x : u32) (ret : u32) (st : state) :
+ * Tot (result (state & u32)) (decreases (f_decreases x st))
+ * ```
* Rk.: if a function has a decreases clause, it is necessarily
* a transparent function *)
let inputs_lvs =
- Collections.List.prefix
- (List.length (Option.get fwd_def.body).inputs_lvs)
- (Option.get def.body).inputs_lvs
+ let num_fwd_inputs = List.length (Option.get fwd_def.body).inputs_lvs in
+ let num_fwd_inputs =
+ if has_state_param then num_fwd_inputs - 1 else num_fwd_inputs
+ in
+ let all_inputs = (Option.get def.body).inputs_lvs in
+ let inputs = Collections.List.prefix num_fwd_inputs all_inputs in
+ if has_state_param then
+ inputs @ [ List.nth all_inputs (List.length all_inputs - 1) ]
+ else inputs
in
let _ =
List.fold_left