diff options
author | Son Ho | 2023-05-23 14:43:45 +0200 |
---|---|---|
committer | Son HO | 2023-06-04 21:54:38 +0200 |
commit | dbc040b720862ddb40210c8ca5caf84123fb20fc (patch) | |
tree | 875c95104efb3f8ba76c271943a7d8d4a65f5e31 /backends | |
parent | 057f68ea639c52c33cff36017fc3f1365503934b (diff) |
Improve the parsing for divDefLib
Diffstat (limited to 'backends')
-rw-r--r-- | backends/hol4/divDefLib.sml | 92 |
1 files changed, 91 insertions, 1 deletions
diff --git a/backends/hol4/divDefLib.sml b/backends/hol4/divDefLib.sml index 128e0b0f..edeb63a4 100644 --- a/backends/hol4/divDefLib.sml +++ b/backends/hol4/divDefLib.sml @@ -852,10 +852,100 @@ fun prove_def_eqs (body_is_valid : thm) (def_tms : term list) (raw_defs : thm li * * ============================================================================*) +type absyn = Absyn.absyn + +(* Helper: convert an absyn to a vstruct (i.e., turn a "standard" term into + a quantified term; we use it to transform function arguments into abstracted + terms (in a lambda) *) +fun absyn_to_vstruct (x : absyn) : Absyn.vstruct = + case x of + Absyn.AQ (l, t) => Absyn.VAQ (l, t) + | Absyn.IDENT (l, s) => Absyn.VIDENT (l, s) + | Absyn.QIDENT _ => raise (mk_HOL_ERR "divDefLib" "absyn_to_vstruct" "Unsupported: QIDENT") + | Absyn.APP _ => raise (mk_HOL_ERR "divDefLib" "absyn_to_vstruct" "Unsupported: APP") + | Absyn.LAM _ => raise (mk_HOL_ERR "divDefLib" "absyn_to_vstruct" "Unsupported: LAM") + | Absyn.TYPED (l, y, ty) => Absyn.VTYPED (l, absyn_to_vstruct y, ty) + +(* We need to parse the quotation in a specific manner. + + The issue is that, with mutually recursive functions, the parser sometimes + gets confused if some funtions have parameters with the same name but with + different types. + + For instance: + {[ + f (x : int) = ... /\ + g (x : bool) = ... + ]} + + The solution is to rewrite the equations to make lambdas appear explicitely, + like so: + {[ + f = λ(x : int) = ... /\ + g = λ(x : bool) = ... + ]} + + We do the following: + - we convert the quotation to an abstract syntax tree + - transform this tree into a shape where function bodies are abstractions + - parse this to a term + - change the shape of the term back to the original shape (with arguments + on the left of the “=”) + *) +fun parse_quote (defs_qt : term quotation) : term = + let + val def_abs = Parse.Absyn defs_qt + val absl = Absyn.strip_conj def_abs + + (* Turn an equation of the shape “f x = ...” into “f = \x. ...” *) + fun make_lambda_def (def_abs : absyn) : absyn = + let + (* Retrieve the body *) + val (app, body) = Absyn.dest_eq def_abs + (* Remove the typing annotation from around the lhs, if there is, + and put it around the rhs *) + val (app, body) = + if Absyn.is_typed app then + let val (app, ty) = Absyn.dest_typed app in (app, Absyn.mk_typed (body, ty)) end + else (app, body) + (* Strip the arguments *) + val (f, args) = Absyn.strip_app app + (* Make a lambda abstraction *) + val args = map absyn_to_vstruct args + val body = Absyn.list_mk_lam (args, body) + in + Absyn.mk_eq (f, body) + end + val absl = map make_lambda_def absl + val def_abs = Absyn.list_mk_conj absl + + (* Parse the quote now that it is in the proper shape *) + val def_tm = + (* This is taken from Defn.sml: we removed the [sort_eqns] because it is not + useful in our case (its untangle the dependencies so that functions are + defined before use). *) + fst (Defn.parse_absyn def_abs) + handle e => raise wrap_exn "divDefLib" "parse_quote" e + + (* Put the definition back into the original shape *) + fun make_args_def (tm : term) : term = + let + val (f, body) = dest_eq tm + val (args, body) = strip_abs body + in + mk_eq (list_mk_comb (f, args), body) + end + val def_tms = strip_conj def_tm + val def_tms = map make_args_def def_tms + val def_tm = list_mk_conj def_tms + in + def_tm + end + fun DefineDiv (def_qt : term quotation) = let (* Parse the definitions *) - val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt) + val def_tms = strip_conj (parse_quote def_qt) (* Compute the names and the input/output types of the functions *) fun compute_names_in_out_tys (tm : term) : string * (hol_type * hol_type) = |