summaryrefslogtreecommitdiff
path: root/backends/hol4/divDefLib.sml
diff options
context:
space:
mode:
authorSon Ho2023-05-23 14:43:45 +0200
committerSon HO2023-06-04 21:54:38 +0200
commitdbc040b720862ddb40210c8ca5caf84123fb20fc (patch)
tree875c95104efb3f8ba76c271943a7d8d4a65f5e31 /backends/hol4/divDefLib.sml
parent057f68ea639c52c33cff36017fc3f1365503934b (diff)
Improve the parsing for divDefLib
Diffstat (limited to '')
-rw-r--r--backends/hol4/divDefLib.sml92
1 files changed, 91 insertions, 1 deletions
diff --git a/backends/hol4/divDefLib.sml b/backends/hol4/divDefLib.sml
index 128e0b0f..edeb63a4 100644
--- a/backends/hol4/divDefLib.sml
+++ b/backends/hol4/divDefLib.sml
@@ -852,10 +852,100 @@ fun prove_def_eqs (body_is_valid : thm) (def_tms : term list) (raw_defs : thm li
*
* ============================================================================*)
+type absyn = Absyn.absyn
+
+(* Helper: convert an absyn to a vstruct (i.e., turn a "standard" term into
+ a quantified term; we use it to transform function arguments into abstracted
+ terms (in a lambda) *)
+fun absyn_to_vstruct (x : absyn) : Absyn.vstruct =
+ case x of
+ Absyn.AQ (l, t) => Absyn.VAQ (l, t)
+ | Absyn.IDENT (l, s) => Absyn.VIDENT (l, s)
+ | Absyn.QIDENT _ => raise (mk_HOL_ERR "divDefLib" "absyn_to_vstruct" "Unsupported: QIDENT")
+ | Absyn.APP _ => raise (mk_HOL_ERR "divDefLib" "absyn_to_vstruct" "Unsupported: APP")
+ | Absyn.LAM _ => raise (mk_HOL_ERR "divDefLib" "absyn_to_vstruct" "Unsupported: LAM")
+ | Absyn.TYPED (l, y, ty) => Absyn.VTYPED (l, absyn_to_vstruct y, ty)
+
+(* We need to parse the quotation in a specific manner.
+
+ The issue is that, with mutually recursive functions, the parser sometimes
+ gets confused if some funtions have parameters with the same name but with
+ different types.
+
+ For instance:
+ {[
+ f (x : int) = ... /\
+ g (x : bool) = ...
+ ]}
+
+ The solution is to rewrite the equations to make lambdas appear explicitely,
+ like so:
+ {[
+ f = λ(x : int) = ... /\
+ g = λ(x : bool) = ...
+ ]}
+
+ We do the following:
+ - we convert the quotation to an abstract syntax tree
+ - transform this tree into a shape where function bodies are abstractions
+ - parse this to a term
+ - change the shape of the term back to the original shape (with arguments
+ on the left of the “=”)
+ *)
+fun parse_quote (defs_qt : term quotation) : term =
+ let
+ val def_abs = Parse.Absyn defs_qt
+ val absl = Absyn.strip_conj def_abs
+
+ (* Turn an equation of the shape “f x = ...” into “f = \x. ...” *)
+ fun make_lambda_def (def_abs : absyn) : absyn =
+ let
+ (* Retrieve the body *)
+ val (app, body) = Absyn.dest_eq def_abs
+ (* Remove the typing annotation from around the lhs, if there is,
+ and put it around the rhs *)
+ val (app, body) =
+ if Absyn.is_typed app then
+ let val (app, ty) = Absyn.dest_typed app in (app, Absyn.mk_typed (body, ty)) end
+ else (app, body)
+ (* Strip the arguments *)
+ val (f, args) = Absyn.strip_app app
+ (* Make a lambda abstraction *)
+ val args = map absyn_to_vstruct args
+ val body = Absyn.list_mk_lam (args, body)
+ in
+ Absyn.mk_eq (f, body)
+ end
+ val absl = map make_lambda_def absl
+ val def_abs = Absyn.list_mk_conj absl
+
+ (* Parse the quote now that it is in the proper shape *)
+ val def_tm =
+ (* This is taken from Defn.sml: we removed the [sort_eqns] because it is not
+ useful in our case (its untangle the dependencies so that functions are
+ defined before use). *)
+ fst (Defn.parse_absyn def_abs)
+ handle e => raise wrap_exn "divDefLib" "parse_quote" e
+
+ (* Put the definition back into the original shape *)
+ fun make_args_def (tm : term) : term =
+ let
+ val (f, body) = dest_eq tm
+ val (args, body) = strip_abs body
+ in
+ mk_eq (list_mk_comb (f, args), body)
+ end
+ val def_tms = strip_conj def_tm
+ val def_tms = map make_args_def def_tms
+ val def_tm = list_mk_conj def_tms
+ in
+ def_tm
+ end
+
fun DefineDiv (def_qt : term quotation) =
let
(* Parse the definitions *)
- val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt)
+ val def_tms = strip_conj (parse_quote def_qt)
(* Compute the names and the input/output types of the functions *)
fun compute_names_in_out_tys (tm : term) : string * (hol_type * hol_type) =