aboutsummaryrefslogtreecommitdiff
path: root/spartan/core/ml
diff options
context:
space:
mode:
authorJosh Chen2020-06-15 11:53:44 +0200
committerJosh Chen2020-06-15 11:53:44 +0200
commite42b7b3c7d29160939a150b9ec94fc476f7d53e3 (patch)
treed2e404094ff77d1969eb1207f542095794246038 /spartan/core/ml
parent9050b7414021db31b23a034567ebc6da3f6c5f67 (diff)
parent69bf0744a5ce3ba144f59564ebf74d7d2f56b748 (diff)
Merge branch 'dev'
Diffstat (limited to '')
-rw-r--r--spartan/core/ml/cases.ML42
-rw-r--r--spartan/core/ml/congruence.ML82
-rw-r--r--spartan/core/ml/elimination.ML46
-rw-r--r--spartan/core/ml/eqsubst.ML434
-rw-r--r--spartan/core/ml/equality.ML90
-rw-r--r--spartan/core/ml/focus.ML125
-rw-r--r--spartan/core/ml/goals.ML214
-rw-r--r--spartan/core/ml/implicits.ML78
-rw-r--r--spartan/core/ml/lib.ML145
-rw-r--r--spartan/core/ml/rewrite.ML465
-rw-r--r--spartan/core/ml/tactics.ML228
-rw-r--r--spartan/core/ml/types.ML18
12 files changed, 1967 insertions, 0 deletions
diff --git a/spartan/core/ml/cases.ML b/spartan/core/ml/cases.ML
new file mode 100644
index 0000000..560a9f1
--- /dev/null
+++ b/spartan/core/ml/cases.ML
@@ -0,0 +1,42 @@
+(* Title: cases.ML
+ Author: Joshua Chen
+
+Case reasoning.
+*)
+
+structure Case: sig
+
+val rules: Proof.context -> thm list
+val lookup_rule: Proof.context -> Termtab.key -> thm option
+val register_rule: thm -> Context.generic -> Context.generic
+
+end = struct
+
+(* Context data *)
+
+(*Stores elimination rules together with a list of the indexnames of the
+ variables each rule eliminates. Keyed by head of the type being eliminated.*)
+structure Rules = Generic_Data (
+ type T = thm Termtab.table
+ val empty = Termtab.empty
+ val extend = I
+ val merge = Termtab.merge Thm.eq_thm_prop
+)
+
+val rules = map #2 o Termtab.dest o Rules.get o Context.Proof
+fun lookup_rule ctxt = Termtab.lookup (Rules.get (Context.Proof ctxt))
+fun register_rule rl =
+ let val hd = Term.head_of (Lib.type_of_typing (Thm.major_prem_of rl))
+ in Rules.map (Termtab.update (hd, rl)) end
+
+
+(* [cases] attribute *)
+val _ = Theory.setup (
+ Attrib.setup \<^binding>\<open>cases\<close>
+ (Scan.succeed (Thm.declaration_attribute register_rule))
+ ""
+ #> Global_Theory.add_thms_dynamic (\<^binding>\<open>cases\<close>, rules o Context.proof_of)
+)
+
+
+end
diff --git a/spartan/core/ml/congruence.ML b/spartan/core/ml/congruence.ML
new file mode 100644
index 0000000..d9f4ffa
--- /dev/null
+++ b/spartan/core/ml/congruence.ML
@@ -0,0 +1,82 @@
+structure Congruence = struct
+
+(* Congruence context data *)
+
+structure RHS = Generic_Data (
+ type T = (term * indexname) Termtab.table
+ val empty = Termtab.empty
+ val extend = I
+ val merge = Termtab.merge (Term.aconv o apply2 #1)
+)
+
+fun register_rhs t var =
+ let
+ val key = Term.head_of t
+ val idxname = #1 (dest_Var var)
+ in
+ RHS.map (Termtab.update (key, (t, idxname)))
+ end
+
+fun lookup_congruence ctxt t =
+ Termtab.lookup (RHS.get (Context.Proof ctxt)) (Term.head_of t)
+
+
+(* Congruence declarations *)
+
+local val Frees_to_Vars =
+ map_aterms (fn tm =>
+ case tm of
+ Free (name, T) => Var (("*!"^name, 0), T) (*Hacky naming!*)
+ | _ => tm)
+in
+
+(*Declare the "right-hand side" of types that are congruences.
+ Does not handle bound variables, so no dependent RHS in declarations!*)
+val _ = Outer_Syntax.local_theory \<^command_keyword>\<open>congruence\<close>
+ "declare right hand side of congruence"
+ (Parse.term -- (\<^keyword>\<open>rhs\<close> |-- Parse.term) >>
+ (fn (t_str, rhs_str) => fn lthy =>
+ let
+ val (t, rhs) = apply2 (Frees_to_Vars o Syntax.read_term lthy)
+ (t_str, rhs_str)
+ in lthy |>
+ Local_Theory.background_theory (
+ Context.theory_map (register_rhs t rhs))
+ end))
+
+end
+
+
+(* Calculational reasoning: ".." setup *)
+
+fun last_rhs ctxt = map_aterms (fn t =>
+ case t of
+ Const (\<^const_name>\<open>rhs\<close>, _) =>
+ let
+ val this_name = Name_Space.full_name (Proof_Context.naming_of ctxt)
+ (Binding.name Auto_Bind.thisN)
+ val this = #thms (the (Proof_Context.lookup_fact ctxt this_name))
+ handle Option => []
+ val rhs =
+ (case map Thm.prop_of this of
+ [prop] =>
+ (let
+ val typ = Lib.type_of_typing (Logic.strip_assums_concl prop)
+ val (cong_pttrn, varname) = the (lookup_congruence ctxt typ)
+ val unif_res = Pattern.unify (Context.Proof ctxt)
+ (cong_pttrn, typ) Envir.init
+ val rhs = #2 (the
+ (Vartab.lookup (Envir.term_env unif_res) varname))
+ in
+ rhs
+ end handle Option =>
+ error (".. can't match right-hand side of congruence"))
+ | _ => Term.dummy)
+ in rhs end
+ | _ => t)
+
+val _ = Context.>>
+ (Syntax_Phases.term_check 5 "" (fn ctxt => map (last_rhs ctxt)))
+
+
+end
diff --git a/spartan/core/ml/elimination.ML b/spartan/core/ml/elimination.ML
new file mode 100644
index 0000000..617f83e
--- /dev/null
+++ b/spartan/core/ml/elimination.ML
@@ -0,0 +1,46 @@
+(* Title: elimination.ML
+ Author: Joshua Chen
+
+Type elimination setup.
+*)
+
+structure Elim: sig
+
+val rules: Proof.context -> (thm * indexname list) list
+val lookup_rule: Proof.context -> Termtab.key -> (thm * indexname list) option
+val register_rule: term list -> thm -> Context.generic -> Context.generic
+
+end = struct
+
+(** Context data **)
+
+(* Elimination rule data *)
+
+(*Stores elimination rules together with a list of the indexnames of the
+ variables each rule eliminates. Keyed by head of the type being eliminated.*)
+structure Rules = Generic_Data (
+ type T = (thm * indexname list) Termtab.table
+ val empty = Termtab.empty
+ val extend = I
+ val merge = Termtab.merge (eq_fst Thm.eq_thm_prop)
+)
+
+fun rules ctxt = map (op #2) (Termtab.dest (Rules.get (Context.Proof ctxt)))
+fun lookup_rule ctxt = Termtab.lookup (Rules.get (Context.Proof ctxt))
+fun register_rule tms rl =
+ let val hd = Term.head_of (Lib.type_of_typing (Thm.major_prem_of rl))
+ in Rules.map (Termtab.update (hd, (rl, map (#1 o dest_Var) tms))) end
+
+
+(* [elims] attribute *)
+val _ = Theory.setup (
+ Attrib.setup \<^binding>\<open>elims\<close>
+ (Scan.repeat Args.term_pattern >>
+ (Thm.declaration_attribute o register_rule))
+ ""
+ #> Global_Theory.add_thms_dynamic (\<^binding>\<open>elims\<close>,
+ fn context => (map #1 (rules (Context.proof_of context))))
+)
+
+
+end
diff --git a/spartan/core/ml/eqsubst.ML b/spartan/core/ml/eqsubst.ML
new file mode 100644
index 0000000..ea6f098
--- /dev/null
+++ b/spartan/core/ml/eqsubst.ML
@@ -0,0 +1,434 @@
+(* Title: eqsubst.ML
+ Author: Lucas Dixon, University of Edinburgh
+ Modified: Joshua Chen, University of Innsbruck
+
+Perform a substitution using an equation.
+
+This code is slightly modified from the original at Tools/eqsubst..ML,
+to incorporate auto-typechecking for type theory.
+*)
+
+signature EQSUBST =
+sig
+ type match =
+ ((indexname * (sort * typ)) list (* type instantiations *)
+ * (indexname * (typ * term)) list) (* term instantiations *)
+ * (string * typ) list (* fake named type abs env *)
+ * (string * typ) list (* type abs env *)
+ * term (* outer term *)
+
+ type searchinfo =
+ Proof.context
+ * int (* maxidx *)
+ * Zipper.T (* focusterm to search under *)
+
+ datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq
+
+ val skip_first_asm_occs_search: ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq
+ val skip_first_occs_search: int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
+ val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq
+
+ (* tactics *)
+ val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
+ val eqsubst_asm_tac': Proof.context ->
+ (searchinfo -> int -> term -> match skipseq) -> int -> thm -> int -> tactic
+ val eqsubst_tac: Proof.context ->
+ int list -> (* list of occurrences to rewrite, use [0] for any *)
+ thm list -> int -> tactic
+ val eqsubst_tac': Proof.context ->
+ (searchinfo -> term -> match Seq.seq) (* search function *)
+ -> thm (* equation theorem to rewrite with *)
+ -> int (* subgoal number in goal theorem *)
+ -> thm (* goal theorem *)
+ -> thm Seq.seq (* rewritten goal theorem *)
+
+ (* search for substitutions *)
+ val valid_match_start: Zipper.T -> bool
+ val search_lr_all: Zipper.T -> Zipper.T Seq.seq
+ val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
+ val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
+ val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
+ val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
+end;
+
+structure EqSubst: EQSUBST =
+struct
+
+(* changes object "=" to meta "==" which prepares a given rewrite rule *)
+fun prep_meta_eq ctxt =
+ Simplifier.mksimps ctxt #> map Drule.zero_var_indexes;
+
+(* make free vars into schematic vars with index zero *)
+fun unfix_frees frees =
+ fold (K (Thm.forall_elim_var 0)) frees o Drule.forall_intr_list frees;
+
+
+type match =
+ ((indexname * (sort * typ)) list (* type instantiations *)
+ * (indexname * (typ * term)) list) (* term instantiations *)
+ * (string * typ) list (* fake named type abs env *)
+ * (string * typ) list (* type abs env *)
+ * term; (* outer term *)
+
+type searchinfo =
+ Proof.context
+ * int (* maxidx *)
+ * Zipper.T; (* focusterm to search under *)
+
+
+(* skipping non-empty sub-sequences but when we reach the end
+ of the seq, remembering how much we have left to skip. *)
+datatype 'a skipseq =
+ SkipMore of int |
+ SkipSeq of 'a Seq.seq Seq.seq;
+
+(* given a seqseq, skip the first m non-empty seq's, note deficit *)
+fun skipto_skipseq m s =
+ let
+ fun skip_occs n sq =
+ (case Seq.pull sq of
+ NONE => SkipMore n
+ | SOME (h, t) =>
+ (case Seq.pull h of
+ NONE => skip_occs n t
+ | SOME _ => if n <= 1 then SkipSeq (Seq.cons h t) else skip_occs (n - 1) t))
+ in skip_occs m s end;
+
+(* note: outerterm is the taget with the match replaced by a bound
+ variable : ie: "P lhs" beocmes "%x. P x"
+ insts is the types of instantiations of vars in lhs
+ and typinsts is the type instantiations of types in the lhs
+ Note: Final rule is the rule lifted into the ontext of the
+ taget thm. *)
+fun mk_foo_match mkuptermfunc Ts t =
+ let
+ val ty = Term.type_of t
+ val bigtype = rev (map snd Ts) ---> ty
+ fun mk_foo 0 t = t
+ | mk_foo i t = mk_foo (i - 1) (t $ (Bound (i - 1)))
+ val num_of_bnds = length Ts
+ (* foo_term = "fooabs y0 ... yn" where y's are local bounds *)
+ val foo_term = mk_foo num_of_bnds (Bound num_of_bnds)
+ in Abs ("fooabs", bigtype, mkuptermfunc foo_term) end;
+
+(* T is outer bound vars, n is number of locally bound vars *)
+(* THINK: is order of Ts correct...? or reversed? *)
+fun mk_fake_bound_name n = ":b_" ^ n;
+fun fakefree_badbounds Ts t =
+ let val (FakeTs, Ts, newnames) =
+ fold_rev (fn (n, ty) => fn (FakeTs, Ts, usednames) =>
+ let
+ val newname = singleton (Name.variant_list usednames) n
+ in
+ ((mk_fake_bound_name newname, ty) :: FakeTs,
+ (newname, ty) :: Ts,
+ newname :: usednames)
+ end) Ts ([], [], [])
+ in (FakeTs, Ts, Term.subst_bounds (map Free FakeTs, t)) end;
+
+(* before matching we need to fake the bound vars that are missing an
+ abstraction. In this function we additionally construct the
+ abstraction environment, and an outer context term (with the focus
+ abstracted out) for use in rewriting with RW_Inst.rw *)
+fun prep_zipper_match z =
+ let
+ val t = Zipper.trm z
+ val c = Zipper.ctxt z
+ val Ts = Zipper.C.nty_ctxt c
+ val (FakeTs', Ts', t') = fakefree_badbounds Ts t
+ val absterm = mk_foo_match (Zipper.C.apply c) Ts' t'
+ in
+ (t', (FakeTs', Ts', absterm))
+ end;
+
+(* Unification with exception handled *)
+(* given context, max var index, pat, tgt; returns Seq of instantiations *)
+fun clean_unify ctxt ix (a as (pat, tgt)) =
+ let
+ (* type info will be re-derived, maybe this can be cached
+ for efficiency? *)
+ val pat_ty = Term.type_of pat;
+ val tgt_ty = Term.type_of tgt;
+ (* FIXME is it OK to ignore the type instantiation info?
+ or should I be using it? *)
+ val typs_unify =
+ SOME (Sign.typ_unify (Proof_Context.theory_of ctxt) (pat_ty, tgt_ty) (Vartab.empty, ix))
+ handle Type.TUNIFY => NONE;
+ in
+ (case typs_unify of
+ SOME (typinsttab, ix2) =>
+ let
+ (* FIXME is it right to throw away the flexes?
+ or should I be using them somehow? *)
+ fun mk_insts env =
+ (Vartab.dest (Envir.type_env env),
+ Vartab.dest (Envir.term_env env));
+ val initenv =
+ Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
+ val useq = Unify.smash_unifiers (Context.Proof ctxt) [a] initenv
+ handle ListPair.UnequalLengths => Seq.empty
+ | Term.TERM _ => Seq.empty;
+ fun clean_unify' useq () =
+ (case (Seq.pull useq) of
+ NONE => NONE
+ | SOME (h, t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
+ handle ListPair.UnequalLengths => NONE
+ | Term.TERM _ => NONE;
+ in
+ (Seq.make (clean_unify' useq))
+ end
+ | NONE => Seq.empty)
+ end;
+
+(* Unification for zippers *)
+(* Note: Ts is a modified version of the original names of the outer
+ bound variables. New names have been introduced to make sure they are
+ unique w.r.t all names in the term and each other. usednames' is
+ oldnames + new names. *)
+fun clean_unify_z ctxt maxidx pat z =
+ let val (t, (FakeTs, Ts, absterm)) = prep_zipper_match z in
+ Seq.map (fn insts => (insts, FakeTs, Ts, absterm))
+ (clean_unify ctxt maxidx (t, pat))
+ end;
+
+
+fun bot_left_leaf_of (l $ _) = bot_left_leaf_of l
+ | bot_left_leaf_of (Abs (_, _, t)) = bot_left_leaf_of t
+ | bot_left_leaf_of x = x;
+
+(* Avoid considering replacing terms which have a var at the head as
+ they always succeed trivially, and uninterestingly. *)
+fun valid_match_start z =
+ (case bot_left_leaf_of (Zipper.trm z) of
+ Var _ => false
+ | _ => true);
+
+(* search from top, left to right, then down *)
+val search_lr_all = ZipperSearch.all_bl_ur;
+
+(* search from top, left to right, then down *)
+fun search_lr_valid validf =
+ let
+ fun sf_valid_td_lr z =
+ let val here = if validf z then [Zipper.Here z] else [] in
+ (case Zipper.trm z of
+ _ $ _ =>
+ [Zipper.LookIn (Zipper.move_down_left z)] @ here @
+ [Zipper.LookIn (Zipper.move_down_right z)]
+ | Abs _ => here @ [Zipper.LookIn (Zipper.move_down_abs z)]
+ | _ => here)
+ end;
+ in Zipper.lzy_search sf_valid_td_lr end;
+
+(* search from bottom to top, left to right *)
+fun search_bt_valid validf =
+ let
+ fun sf_valid_td_lr z =
+ let val here = if validf z then [Zipper.Here z] else [] in
+ (case Zipper.trm z of
+ _ $ _ =>
+ [Zipper.LookIn (Zipper.move_down_left z),
+ Zipper.LookIn (Zipper.move_down_right z)] @ here
+ | Abs _ => [Zipper.LookIn (Zipper.move_down_abs z)] @ here
+ | _ => here)
+ end;
+ in Zipper.lzy_search sf_valid_td_lr end;
+
+fun searchf_unify_gen f (ctxt, maxidx, z) lhs =
+ Seq.map (clean_unify_z ctxt maxidx lhs) (Zipper.limit_apply f z);
+
+(* search all unifications *)
+val searchf_lr_unify_all = searchf_unify_gen search_lr_all;
+
+(* search only for 'valid' unifiers (non abs subterms and non vars) *)
+val searchf_lr_unify_valid = searchf_unify_gen (search_lr_valid valid_match_start);
+
+val searchf_bt_unify_valid = searchf_unify_gen (search_bt_valid valid_match_start);
+
+(* apply a substitution in the conclusion of the theorem *)
+(* cfvs are certified free var placeholders for goal params *)
+(* conclthm is a theorem of for just the conclusion *)
+(* m is instantiation/match information *)
+(* rule is the equation for substitution *)
+fun apply_subst_in_concl ctxt i st (cfvs, conclthm) rule m =
+ RW_Inst.rw ctxt m rule conclthm
+ |> unfix_frees cfvs
+ |> Conv.fconv_rule Drule.beta_eta_conversion
+ |> (fn r => resolve_tac ctxt [r] i st);
+
+(* substitute within the conclusion of goal i of gth, using a meta
+equation rule. Note that we assume rule has var indicies zero'd *)
+fun prep_concl_subst ctxt i gth =
+ let
+ val th = Thm.incr_indexes 1 gth;
+ val tgt_term = Thm.prop_of th;
+
+ val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
+ val cfvs = rev (map (Thm.cterm_of ctxt) fvs);
+
+ val conclterm = Logic.strip_imp_concl fixedbody;
+ val conclthm = Thm.trivial (Thm.cterm_of ctxt conclterm);
+ val maxidx = Thm.maxidx_of th;
+ val ft =
+ (Zipper.move_down_right (* ==> *)
+ o Zipper.move_down_left (* Trueprop *)
+ o Zipper.mktop
+ o Thm.prop_of) conclthm
+ in
+ ((cfvs, conclthm), (ctxt, maxidx, ft))
+ end;
+
+(* substitute using an object or meta level equality *)
+fun eqsubst_tac' ctxt searchf instepthm i st =
+ let
+ val (cvfsconclthm, searchinfo) = prep_concl_subst ctxt i st;
+ val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
+ fun rewrite_with_thm r =
+ let val (lhs,_) = Logic.dest_equals (Thm.concl_of r) in
+ searchf searchinfo lhs
+ |> Seq.maps (apply_subst_in_concl ctxt i st cvfsconclthm r)
+ end;
+ in stepthms |> Seq.maps rewrite_with_thm end;
+
+
+(* General substitution of multiple occurrences using one of
+ the given theorems *)
+
+fun skip_first_occs_search occ srchf sinfo lhs =
+ (case skipto_skipseq occ (srchf sinfo lhs) of
+ SkipMore _ => Seq.empty
+ | SkipSeq ss => Seq.flat ss);
+
+(* The "occs" argument is a list of integers indicating which occurrence
+w.r.t. the search order, to rewrite. Backtracking will also find later
+occurrences, but all earlier ones are skipped. Thus you can use [0] to
+just find all rewrites. *)
+
+fun eqsubst_tac ctxt occs thms i st =
+ let val nprems = Thm.nprems_of st in
+ if nprems < i then Seq.empty else
+ let
+ val thmseq = Seq.of_list thms;
+ fun apply_occ occ st =
+ thmseq |> Seq.maps (fn r =>
+ eqsubst_tac' ctxt
+ (skip_first_occs_search occ searchf_lr_unify_valid) r
+ (i + (Thm.nprems_of st - nprems)) st);
+ val sorted_occs = Library.sort (rev_order o int_ord) occs;
+ in
+ Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
+ end
+ end;
+
+
+(* apply a substitution inside assumption j, keeps asm in the same place *)
+fun apply_subst_in_asm ctxt i st rule ((cfvs, j, _, pth),m) =
+ let
+ val st2 = Thm.rotate_rule (j - 1) i st; (* put premice first *)
+ val preelimrule =
+ RW_Inst.rw ctxt m rule pth
+ |> (Seq.hd o prune_params_tac ctxt)
+ |> Thm.permute_prems 0 ~1 (* put old asm first *)
+ |> unfix_frees cfvs (* unfix any global params *)
+ |> Conv.fconv_rule Drule.beta_eta_conversion; (* normal form *)
+ in
+ (* ~j because new asm starts at back, thus we subtract 1 *)
+ Seq.map (Thm.rotate_rule (~ j) (Thm.nprems_of rule + i))
+ (dresolve_tac ctxt [preelimrule] i st2)
+ end;
+
+
+(* prepare to substitute within the j'th premise of subgoal i of gth,
+using a meta-level equation. Note that we assume rule has var indicies
+zero'd. Note that we also assume that premt is the j'th premice of
+subgoal i of gth. Note the repetition of work done for each
+assumption, i.e. this can be made more efficient for search over
+multiple assumptions. *)
+fun prep_subst_in_asm ctxt i gth j =
+ let
+ val th = Thm.incr_indexes 1 gth;
+ val tgt_term = Thm.prop_of th;
+
+ val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
+ val cfvs = rev (map (Thm.cterm_of ctxt) fvs);
+
+ val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
+ val asm_nprems = length (Logic.strip_imp_prems asmt);
+
+ val pth = Thm.trivial ((Thm.cterm_of ctxt) asmt);
+ val maxidx = Thm.maxidx_of th;
+
+ val ft =
+ (Zipper.move_down_right (* trueprop *)
+ o Zipper.mktop
+ o Thm.prop_of) pth
+ in ((cfvs, j, asm_nprems, pth), (ctxt, maxidx, ft)) end;
+
+(* prepare subst in every possible assumption *)
+fun prep_subst_in_asms ctxt i gth =
+ map (prep_subst_in_asm ctxt i gth)
+ ((fn l => Library.upto (1, length l))
+ (Logic.prems_of_goal (Thm.prop_of gth) i));
+
+
+(* substitute in an assumption using an object or meta level equality *)
+fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i st =
+ let
+ val asmpreps = prep_subst_in_asms ctxt i st;
+ val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
+ fun rewrite_with_thm r =
+ let
+ val (lhs,_) = Logic.dest_equals (Thm.concl_of r);
+ fun occ_search occ [] = Seq.empty
+ | occ_search occ ((asminfo, searchinfo)::moreasms) =
+ (case searchf searchinfo occ lhs of
+ SkipMore i => occ_search i moreasms
+ | SkipSeq ss =>
+ Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
+ (occ_search 1 moreasms)) (* find later substs also *)
+ in
+ occ_search skipocc asmpreps |> Seq.maps (apply_subst_in_asm ctxt i st r)
+ end;
+ in stepthms |> Seq.maps rewrite_with_thm end;
+
+
+fun skip_first_asm_occs_search searchf sinfo occ lhs =
+ skipto_skipseq occ (searchf sinfo lhs);
+
+fun eqsubst_asm_tac ctxt occs thms i st =
+ let val nprems = Thm.nprems_of st in
+ if nprems < i then Seq.empty
+ else
+ let
+ val thmseq = Seq.of_list thms;
+ fun apply_occ occ st =
+ thmseq |> Seq.maps (fn r =>
+ eqsubst_asm_tac' ctxt
+ (skip_first_asm_occs_search searchf_lr_unify_valid) occ r
+ (i + (Thm.nprems_of st - nprems)) st);
+ val sorted_occs = Library.sort (rev_order o int_ord) occs;
+ in
+ Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
+ end
+ end;
+
+(* combination method that takes a flag (true indicates that subst
+ should be done to an assumption, false = apply to the conclusion of
+ the goal) as well as the theorems to use *)
+val _ =
+ Theory.setup
+ (Method.setup \<^binding>\<open>sub\<close>
+ (Scan.lift (Args.mode "asm" -- Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
+ Attrib.thms >> (fn ((asm, occs), inthms) => fn ctxt =>
+ SIMPLE_METHOD' ((if asm then eqsubst_asm_tac else eqsubst_tac) ctxt occs inthms)))
+ "single-step substitution"
+ #>
+ (Method.setup \<^binding>\<open>subst\<close>
+ (Scan.lift (Args.mode "asm" -- Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
+ Attrib.thms >> (fn ((asm, occs), inthms) => fn ctxt =>
+ SIMPLE_METHOD' (SIDE_CONDS
+ ((if asm then eqsubst_asm_tac else eqsubst_tac) ctxt occs inthms)
+ ctxt)))
+ "single-step substitution with auto-typechecking"))
+
+end;
diff --git a/spartan/core/ml/equality.ML b/spartan/core/ml/equality.ML
new file mode 100644
index 0000000..023147b
--- /dev/null
+++ b/spartan/core/ml/equality.ML
@@ -0,0 +1,90 @@
+(* Title: equality.ML
+ Author: Joshua Chen
+
+Equality reasoning with identity types.
+*)
+
+structure Equality:
+sig
+
+val dest_Id: term -> term * term * term
+
+val push_hyp_tac: term * term -> Proof.context -> int -> tactic
+val induction_tac: term -> term -> term -> term -> Proof.context -> tactic
+val equality_context_tac: Facts.ref -> Proof.context -> context_tactic
+
+end = struct
+
+fun dest_Id tm = case tm of
+ Const (\<^const_name>\<open>Id\<close>, _) $ A $ x $ y => (A, x, y)
+ | _ => error "dest_Id"
+
+(*Context assumptions that have already been pushed into the type family*)
+structure Inserts = Proof_Data (
+ type T = term Item_Net.T
+ val init = K (Item_Net.init Term.aconv_untyped single)
+)
+
+fun push_hyp_tac (t, _) =
+ Subgoal.FOCUS_PARAMS (fn {context = ctxt, concl, ...} =>
+ let
+ val (_, C) = Lib.dest_typing (Thm.term_of concl)
+ val B = Thm.cterm_of ctxt (Lib.lambda_var t C)
+ val a = Thm.cterm_of ctxt t
+ (*The resolvent is PiE[where ?B=B and ?a=a]*)
+ val resolvent =
+ Drule.infer_instantiate' ctxt [NONE, NONE, SOME B, SOME a] @{thm PiE}
+ in
+ HEADGOAL (resolve_tac ctxt [resolvent])
+ THEN SOMEGOAL (known_tac ctxt)
+ end)
+
+fun induction_tac p A x y ctxt =
+ let
+ val [p, A, x, y] = map (Thm.cterm_of ctxt) [p, A, x, y]
+ in
+ HEADGOAL (resolve_tac ctxt
+ [Drule.infer_instantiate' ctxt [SOME p, SOME A, SOME x, SOME y] @{thm IdE}])
+ end
+
+val side_conds_tac = TRY oo typechk_tac
+
+fun equality_context_tac fact ctxt =
+ let
+ val eq_th = Proof_Context.get_fact_single ctxt fact
+ val (p, (A, x, y)) = (Lib.dest_typing ##> dest_Id) (Thm.prop_of eq_th)
+
+ val hyps =
+ Facts.props (Proof_Context.facts_of ctxt)
+ |> filter (fn (th, _) => Lib.is_typing (Thm.prop_of th))
+ |> map (Lib.dest_typing o Thm.prop_of o fst)
+ |> filter_out (fn (t, _) =>
+ Term.aconv (t, p) orelse Item_Net.member (Inserts.get ctxt) t)
+ |> map (fn (t, T) => ((t, T), Lib.subterm_count_distinct [p, x, y] T))
+ |> filter (fn (_, i) => i > 0)
+ (*`t1: T1` comes before `t2: T2` if T1 contains t2 as subterm.
+ If they are incomparable, then order by decreasing
+ `subterm_count [p, x, y] T`*)
+ |> sort (fn (((t1, _), i), ((_, T2), j)) =>
+ Lib.cond_order (Lib.subterm_order T2 t1) (int_ord (j, i)))
+ |> map #1
+
+ val record_inserts =
+ Inserts.map (fold (fn (t, _) => fn net => Item_Net.update t net) hyps)
+
+ val tac =
+ fold (fn hyp => fn tac => tac THEN HEADGOAL (push_hyp_tac hyp ctxt))
+ hyps all_tac
+ THEN (
+ induction_tac p A x y ctxt
+ THEN RANGE (replicate 3 (typechk_tac ctxt) @ [side_conds_tac ctxt]) 1
+ )
+ THEN (
+ REPEAT_DETERM_N (length hyps) (SOMEGOAL (resolve_tac ctxt @{thms PiI}))
+ THEN ALLGOALS (side_conds_tac ctxt)
+ )
+ in
+ fn (ctxt, st) => Context_Tactic.TACTIC_CONTEXT (record_inserts ctxt) (tac st)
+ end
+
+end
diff --git a/spartan/core/ml/focus.ML b/spartan/core/ml/focus.ML
new file mode 100644
index 0000000..1d8de78
--- /dev/null
+++ b/spartan/core/ml/focus.ML
@@ -0,0 +1,125 @@
+(* Title: focus.ML
+ Author: Makarius Wenzel, Joshua Chen
+
+A modified version of the Isar `subgoal` command
+that keeps schematic variables in the goal state.
+
+Modified from code originally written by Makarius Wenzel.
+*)
+
+local
+
+fun param_bindings ctxt (param_suffix, raw_param_specs) st =
+ let
+ val _ = if Thm.no_prems st then error "No subgoals!" else ()
+ val subgoal = #1 (Logic.dest_implies (Thm.prop_of st))
+ val subgoal_params =
+ map (apfst (Name.internal o Name.clean)) (Term.strip_all_vars subgoal)
+ |> Term.variant_frees subgoal |> map #1
+
+ val n = length subgoal_params
+ val m = length raw_param_specs
+ val _ =
+ m <= n orelse
+ error ("Excessive subgoal parameter specification" ^
+ Position.here_list (map snd (drop n raw_param_specs)))
+
+ val param_specs =
+ raw_param_specs |> map
+ (fn (NONE, _) => NONE
+ | (SOME x, pos) =>
+ let
+ val b = #1 (#1 (Proof_Context.cert_var (Binding.make (x, pos), NONE, NoSyn) ctxt))
+ val _ = Variable.check_name b
+ in SOME b end)
+ |> param_suffix ? append (replicate (n - m) NONE)
+
+ fun bindings (SOME x :: xs) (_ :: ys) = x :: bindings xs ys
+ | bindings (NONE :: xs) (y :: ys) = Binding.name y :: bindings xs ys
+ | bindings _ ys = map Binding.name ys
+ in bindings param_specs subgoal_params end
+
+fun gen_schematic_subgoal prep_atts raw_result_binding raw_prems_binding param_specs state =
+ let
+ val _ = Proof.assert_backward state
+
+ val state1 = state
+ |> Proof.map_context (Proof_Context.set_mode Proof_Context.mode_schematic)
+ |> Proof.refine_insert []
+
+ val {context = ctxt, facts = facts, goal = st} = Proof.raw_goal state1
+
+ val result_binding = apsnd (map (prep_atts ctxt)) raw_result_binding
+ val (prems_binding, do_prems) =
+ (case raw_prems_binding of
+ SOME (b, raw_atts) => ((b, map (prep_atts ctxt) raw_atts), true)
+ | NONE => (Binding.empty_atts, false))
+
+ val (subgoal_focus, _) =
+ (if do_prems then Subgoal.focus_prems else Subgoal.focus_params) ctxt
+ 1 (SOME (param_bindings ctxt param_specs st)) st
+
+ fun after_qed (ctxt'', [[result]]) =
+ Proof.end_block #> (fn state' =>
+ let
+ val ctxt' = Proof.context_of state'
+ val results' =
+ Proof_Context.export ctxt'' ctxt' (Conjunction.elim_conjunctions result)
+ in
+ state'
+ |> Proof.refine_primitive (fn _ => fn _ =>
+ Subgoal.retrofit ctxt'' ctxt' (#params subgoal_focus) (#asms subgoal_focus) 1
+ (Goal.protect 0 result) st
+ |> Seq.hd)
+ |> Proof.map_context
+ (#2 o Proof_Context.note_thmss "" [(result_binding, [(results', [])])])
+ end)
+ #> Proof.reset_facts
+ #> Proof.enter_backward
+ in
+ state1
+ |> Proof.enter_forward
+ |> Proof.using_facts []
+ |> Proof.begin_block
+ |> Proof.map_context (fn _ =>
+ #context subgoal_focus
+ |> Proof_Context.note_thmss "" [(prems_binding, [(#prems subgoal_focus, [])])] |> #2)
+ |> Proof.internal_goal (K (K ())) (Proof_Context.get_mode ctxt) true "subgoal"
+ NONE after_qed [] [] [(Binding.empty_atts, [(Thm.term_of (#concl subgoal_focus), [])])] |> #2
+ |> Proof.using_facts facts
+ |> pair subgoal_focus
+ end
+
+val opt_fact_binding =
+ Scan.optional (Parse.binding -- Parse.opt_attribs || Parse.attribs >> pair Binding.empty)
+ Binding.empty_atts
+
+val for_params =
+ Scan.optional
+ (\<^keyword>\<open>vars\<close> |--
+ Parse.!!! ((Scan.option Parse.dots >> is_some) --
+ (Scan.repeat1 (Parse.maybe_position Parse.name_position))))
+ (false, [])
+
+val schematic_subgoal_cmd = gen_schematic_subgoal Attrib.attribute_cmd
+
+val parser =
+ opt_fact_binding
+ -- (Scan.option (\<^keyword>\<open>prems\<close> |-- Parse.!!! opt_fact_binding))
+ -- for_params >> (fn ((a, b), c) =>
+ Toplevel.proofs (Seq.make_results o Seq.single o #2 o schematic_subgoal_cmd a b c))
+
+in
+
+(** Outer syntax commands **)
+
+val _ = Outer_Syntax.command \<^command_keyword>\<open>focus\<close>
+ "focus on first subgoal within backward refinement, without instantiating schematic vars"
+ parser
+
+val _ = Outer_Syntax.command \<^command_keyword>\<open>\<guillemotright>\<close> "focus bullet" parser
+val _ = Outer_Syntax.command \<^command_keyword>\<open>\<^item>\<close> "focus bullet" parser
+val _ = Outer_Syntax.command \<^command_keyword>\<open>\<^enum>\<close> "focus bullet" parser
+val _ = Outer_Syntax.command \<^command_keyword>\<open>~\<close> "focus bullet" parser
+
+end
diff --git a/spartan/core/ml/goals.ML b/spartan/core/ml/goals.ML
new file mode 100644
index 0000000..9f394f0
--- /dev/null
+++ b/spartan/core/ml/goals.ML
@@ -0,0 +1,214 @@
+(* Title: goals.ML
+ Author: Makarius Wenzel, Joshua Chen
+
+Goal statements and proof term export.
+
+Modified from code originally written by Makarius Wenzel.
+*)
+
+local
+
+val long_keyword =
+ Parse_Spec.includes >> K "" ||
+ Parse_Spec.long_statement_keyword
+
+val long_statement =
+ Scan.optional
+ (Parse_Spec.opt_thm_name ":" --| Scan.ahead long_keyword)
+ Binding.empty_atts --
+ Scan.optional Parse_Spec.includes [] -- Parse_Spec.long_statement
+ >> (fn ((binding, includes), (elems, concl)) =>
+ (true, binding, includes, elems, concl))
+
+val short_statement =
+ Parse_Spec.statement -- Parse_Spec.if_statement -- Parse.for_fixes
+ >> (fn ((shows, assumes), fixes) =>
+ (false, Binding.empty_atts, [],
+ [Element.Fixes fixes, Element.Assumes assumes],
+ Element.Shows shows))
+
+fun prep_statement prep_att prep_stmt raw_elems raw_stmt ctxt =
+ let
+ val (stmt, elems_ctxt) = prep_stmt raw_elems raw_stmt ctxt
+ val prems = Assumption.local_prems_of elems_ctxt ctxt
+ val stmt_ctxt = fold (fold (Proof_Context.augment o fst) o snd)
+ stmt elems_ctxt
+ in
+ case raw_stmt of
+ Element.Shows _ =>
+ let val stmt' = Attrib.map_specs (map prep_att) stmt
+ in (([], prems, stmt', NONE), stmt_ctxt) end
+ | Element.Obtains raw_obtains =>
+ let
+ val asms_ctxt = stmt_ctxt
+ |> fold (fn ((name, _), asm) =>
+ snd o Proof_Context.add_assms Assumption.assume_export
+ [((name, [Context_Rules.intro_query NONE]), asm)]) stmt
+ val that = Assumption.local_prems_of asms_ctxt stmt_ctxt
+ val ([(_, that')], that_ctxt) = asms_ctxt
+ |> Proof_Context.set_stmt true
+ |> Proof_Context.note_thmss ""
+ [((Binding.name Auto_Bind.thatN, []), [(that, [])])]
+ ||> Proof_Context.restore_stmt asms_ctxt
+
+ val stmt' = [
+ (Binding.empty_atts,
+ [(#2 (#1 (Obtain.obtain_thesis ctxt)), [])])
+ ]
+ in
+ ((Obtain.obtains_attribs raw_obtains, prems, stmt', SOME that'),
+ that_ctxt)
+ end
+ end
+
+fun define_proof_term name (local_name, [th]) lthy =
+ let
+ fun make_name_binding suffix local_name =
+ let val base_local_name = Long_Name.base_name local_name
+ in
+ Binding.qualified_name
+ ((case base_local_name of
+ "" => name
+ | _ => base_local_name)
+ ^(case suffix of
+ SOME "prf" => "_prf"
+ | SOME "def" => "_def"
+ | _ => ""))
+ end
+
+ val (prems, concl) =
+ (Logic.strip_assums_hyp (Thm.prop_of th),
+ Logic.strip_assums_concl (Thm.prop_of th))
+ in
+ if not (Lib.is_typing concl) then
+ ([], lthy)
+ else let
+ val prems_vars = distinct Term.aconv (flat
+ (map (Lib.collect_subterms is_Var) prems))
+
+ val concl_vars = Lib.collect_subterms is_Var
+ (Lib.term_of_typing concl)
+
+ val params = inter Term.aconv concl_vars prems_vars
+
+ val prf_tm =
+ fold_rev lambda params (Lib.term_of_typing concl)
+
+ val ((_, (_, raw_def)), lthy') = Local_Theory.define
+ ((make_name_binding NONE local_name, Mixfix.NoSyn),
+ ((make_name_binding (SOME "prf") local_name, []), prf_tm)) lthy
+
+ val def =
+ fold
+ (fn th1 => fn th2 => Thm.combination th2 th1)
+ (map (Thm.reflexive o Thm.cterm_of lthy) params)
+ raw_def
+
+ val ((_, def'), lthy'') = Local_Theory.note
+ ((make_name_binding (SOME "def") local_name, []), [def])
+ lthy'
+ in
+ (def', lthy'')
+ end
+ end
+ | define_proof_term _ _ _ = error
+ ("Unimplemented: handling proof terms of multiple facts in"
+ ^" single result")
+
+fun gen_schematic_theorem
+ bundle_includes prep_att prep_stmt
+ gen_prf long kind before_qed after_qed (name, raw_atts)
+ raw_includes raw_elems raw_concl int lthy =
+ let
+ val _ = Local_Theory.assert lthy;
+
+ val elems = raw_elems |> map (Element.map_ctxt_attrib (prep_att lthy))
+ val ((more_atts, prems, stmt, facts), goal_ctxt) = lthy
+ |> bundle_includes raw_includes
+ |> prep_statement (prep_att lthy) prep_stmt elems raw_concl
+ val atts = more_atts @ map (prep_att lthy) raw_atts
+ val pos = Position.thread_data ()
+
+ val prems_name = if long then Auto_Bind.assmsN else Auto_Bind.thatN
+
+ fun after_qed' results goal_ctxt' =
+ let
+ val results' = burrow
+ (map (Goal.norm_result lthy) o Proof_Context.export goal_ctxt' lthy)
+ results
+
+ val ((res, lthy'), substmts) =
+ if forall (Binding.is_empty_atts o fst) stmt
+ then ((map (pair "") results', lthy), false)
+ else
+ (Local_Theory.notes_kind kind
+ (map2 (fn (b, _) => fn ths => (b, [(ths, [])])) stmt results')
+ lthy,
+ true)
+
+ val (res', lthy'') =
+ if gen_prf
+ then
+ let
+ val (prf_tm_defs, lthy'') =
+ fold
+ (fn result => fn (defs, lthy) =>
+ apfst (fn new_defs => defs @ new_defs)
+ (define_proof_term (Binding.name_of name) result lthy))
+ res ([], lthy')
+
+ val res_folded =
+ map (apsnd (map (Local_Defs.fold lthy'' prf_tm_defs))) res
+ in
+ Local_Theory.notes_kind kind
+ [((name, @{attributes [typechk]} @ atts),
+ [(maps #2 res_folded, [])])]
+ lthy''
+ end
+ else
+ Local_Theory.notes_kind kind
+ [((name, atts), [(maps #2 res, [])])]
+ lthy'
+
+ val _ = Proof_Display.print_results int pos lthy''
+ ((kind, Binding.name_of name), map (fn (_, ths) => ("", ths)) res')
+
+ val _ =
+ if substmts then map
+ (fn (name, ths) => Proof_Display.print_results int pos lthy''
+ (("and", name), [("", ths)]))
+ res
+ else []
+ in
+ after_qed results' lthy''
+ end
+ in
+ goal_ctxt
+ |> not (null prems) ?
+ (Proof_Context.note_thmss "" [((Binding.name prems_name, []), [(prems, [])])] #> snd)
+ |> Proof.theorem before_qed after_qed' (map snd stmt)
+ |> (case facts of NONE => I | SOME ths => Proof.refine_insert ths)
+ end
+
+val schematic_theorem_cmd =
+ gen_schematic_theorem
+ Bundle.includes_cmd
+ Attrib.check_src
+ Expression.read_statement
+
+fun theorem spec descr =
+ Outer_Syntax.local_theory_to_proof' spec ("state " ^ descr)
+ (Scan.option (Args.parens (Args.$$$ "derive"))
+ -- (long_statement || short_statement) >>
+ (fn (opt_derive, (long, binding, includes, elems, concl)) =>
+ schematic_theorem_cmd
+ (case opt_derive of SOME "derive" => true | _ => false)
+ long descr NONE (K I) binding includes elems concl))
+in
+
+val _ = theorem \<^command_keyword>\<open>Theorem\<close> "Theorem"
+val _ = theorem \<^command_keyword>\<open>Lemma\<close> "Lemma"
+val _ = theorem \<^command_keyword>\<open>Corollary\<close> "Corollary"
+val _ = theorem \<^command_keyword>\<open>Proposition\<close> "Proposition"
+
+end
diff --git a/spartan/core/ml/implicits.ML b/spartan/core/ml/implicits.ML
new file mode 100644
index 0000000..4d73c8d
--- /dev/null
+++ b/spartan/core/ml/implicits.ML
@@ -0,0 +1,78 @@
+structure Implicits :
+sig
+
+val implicit_defs: Proof.context -> (term * term) Symtab.table
+val implicit_defs_attr: attribute
+val make_holes: Proof.context -> term -> term
+
+end = struct
+
+structure Defs = Generic_Data (
+ type T = (term * term) Symtab.table
+ val empty = Symtab.empty
+ val extend = I
+ val merge = Symtab.merge (Term.aconv o apply2 #1)
+)
+
+val implicit_defs = Defs.get o Context.Proof
+
+val implicit_defs_attr = Thm.declaration_attribute (fn th =>
+ let
+ val (t, def) = Lib.dest_eq (Thm.prop_of th)
+ val (head, args) = Term.strip_comb t
+ val def' = fold_rev lambda args def
+ in
+ Defs.map (Symtab.update (Term.term_name head, (head, def')))
+ end)
+
+fun make_holes ctxt =
+ let
+ fun iarg_to_hole (Const (\<^const_name>\<open>iarg\<close>, T)) =
+ Const (\<^const_name>\<open>hole\<close>, T)
+ | iarg_to_hole t = t
+
+ fun expand head args =
+ let
+ fun betapplys (head', args') =
+ Term.betapplys (map_aterms iarg_to_hole head', args')
+ in
+ case head of
+ Abs (x, T, t) =>
+ list_comb (Abs (x, T, Lib.traverse_term expand t), args)
+ | _ =>
+ case Symtab.lookup (implicit_defs ctxt) (Term.term_name head) of
+ SOME (t, def) => betapplys
+ (Envir.expand_atom
+ (Term.fastype_of head)
+ (Term.fastype_of t, def),
+ args)
+ | NONE => list_comb (head, args)
+ end
+
+ fun holes_to_vars t =
+ let
+ val count = Lib.subterm_count (Const (\<^const_name>\<open>hole\<close>, dummyT))
+
+ fun subst (Const (\<^const_name>\<open>hole\<close>, T)) (Var (idx, _)::_) Ts =
+ let
+ val bounds = map Bound (0 upto (length Ts - 1))
+ val T' = foldr1 (op -->) (Ts @ [T])
+ in
+ foldl1 (op $) (Var (idx, T')::bounds)
+ end
+ | subst (Abs (x, T, t)) vs Ts = Abs (x, T, subst t vs (T::Ts))
+ | subst (t $ u) vs Ts =
+ let val n = count t
+ in subst t (take n vs) Ts $ subst u (drop n vs) Ts end
+ | subst t _ _ = t
+
+ val vars = map (fn n => Var ((n, 0), dummyT))
+ (Name.invent (Variable.names_of ctxt) "*" (count t))
+ in
+ subst t vars []
+ end
+ in
+ Lib.traverse_term expand #> holes_to_vars
+ end
+
+end
diff --git a/spartan/core/ml/lib.ML b/spartan/core/ml/lib.ML
new file mode 100644
index 0000000..615f601
--- /dev/null
+++ b/spartan/core/ml/lib.ML
@@ -0,0 +1,145 @@
+structure Lib :
+sig
+
+(*Lists*)
+val max: ('a * 'a -> bool) -> 'a list -> 'a
+val maxint: int list -> int
+
+(*Terms*)
+val is_rigid: term -> bool
+val dest_eq: term -> term * term
+val mk_Var: string -> int -> typ -> term
+val lambda_var: term -> term -> term
+
+val is_typing: term -> bool
+val dest_typing: term -> term * term
+val term_of_typing: term -> term
+val type_of_typing: term -> term
+val mk_Pi: term -> term -> term -> term
+
+val typing_of_term: term -> term
+
+(*Goals*)
+val rigid_typing_concl: term -> bool
+
+(*Subterms*)
+val has_subterm: term list -> term -> bool
+val subterm_count: term -> term -> int
+val subterm_count_distinct: term list -> term -> int
+val traverse_term: (term -> term list -> term) -> term -> term
+val collect_subterms: (term -> bool) -> term -> term list
+
+(*Orderings*)
+val subterm_order: term -> term -> order
+val cond_order: order -> order -> order
+
+end = struct
+
+
+(** Lists **)
+
+fun max gt (x::xs) = fold (fn a => fn b => if gt (a, b) then a else b) xs x
+ | max _ [] = error "max of empty list"
+
+val maxint = max (op >)
+
+
+(** Terms **)
+
+(* Meta *)
+
+val is_rigid = not o is_Var o head_of
+
+fun dest_eq (Const (\<^const_name>\<open>Pure.eq\<close>, _) $ t $ def) = (t, def)
+ | dest_eq _ = error "dest_eq"
+
+fun mk_Var name idx T = Var ((name, idx), T)
+
+fun lambda_var x tm =
+ let
+ fun var_args (Var (idx, T)) = Var (idx, \<^typ>\<open>o\<close> --> T) $ x
+ | var_args t = t
+ in
+ tm |> map_aterms var_args
+ |> lambda x
+ end
+
+(* Object *)
+
+fun is_typing (Const (\<^const_name>\<open>has_type\<close>, _) $ _ $ _) = true
+ | is_typing _ = false
+
+fun dest_typing (Const (\<^const_name>\<open>has_type\<close>, _) $ t $ T) = (t, T)
+ | dest_typing _ = error "dest_typing"
+
+val term_of_typing = #1 o dest_typing
+val type_of_typing = #2 o dest_typing
+
+fun mk_Pi v typ body = Const (\<^const_name>\<open>Pi\<close>, dummyT) $ typ $ lambda v body
+
+fun typing_of_term tm = \<^const>\<open>has_type\<close> $ tm $ Var (("*?", 0), \<^typ>\<open>o\<close>)
+(*The above is a bit hacky; basically we need to guarantee that the schematic
+ var is fresh*)
+
+
+(** Goals **)
+
+fun rigid_typing_concl goal =
+ let val concl = Logic.strip_assums_concl goal
+ in is_typing concl andalso is_rigid (term_of_typing concl) end
+
+
+(** Subterms **)
+
+fun has_subterm tms =
+ Term.exists_subterm
+ (foldl1 (op orf) (map (fn t => fn s => Term.aconv_untyped (s, t)) tms))
+
+fun subterm_count s t =
+ let
+ fun count (t1 $ t2) i = i + count t1 0 + count t2 0
+ | count (Abs (_, _, t)) i = i + count t 0
+ | count t i = if Term.aconv_untyped (s, t) then i + 1 else i
+ in
+ count t 0
+ end
+
+(*Number of distinct subterms in `tms` that appear in `tm`*)
+fun subterm_count_distinct tms tm =
+ length (filter I (map (fn t => has_subterm [t] tm) tms))
+
+(*
+ "Folds" a function f over the term structure of t by traversing t from child
+ nodes upwards through parents. At each node n in the term syntax tree, f is
+ additionally passed a list of the results of f at all children of n.
+*)
+fun traverse_term f t =
+ let
+ fun map_aux (Abs (x, T, t)) = Abs (x, T, map_aux t)
+ | map_aux t =
+ let
+ val (head, args) = Term.strip_comb t
+ val args' = map map_aux args
+ in
+ f head args'
+ end
+ in
+ map_aux t
+ end
+
+fun collect_subterms f (t $ u) = collect_subterms f t @ collect_subterms f u
+ | collect_subterms f (Abs (_, _, t)) = collect_subterms f t
+ | collect_subterms f t = if f t then [t] else []
+
+
+(** Orderings **)
+
+fun subterm_order t1 t2 =
+ if has_subterm [t1] t2 then LESS
+ else if has_subterm [t2] t1 then GREATER
+ else EQUAL
+
+fun cond_order o1 o2 = case o1 of EQUAL => o2 | _ => o1
+
+
+end
diff --git a/spartan/core/ml/rewrite.ML b/spartan/core/ml/rewrite.ML
new file mode 100644
index 0000000..f9c5d8e
--- /dev/null
+++ b/spartan/core/ml/rewrite.ML
@@ -0,0 +1,465 @@
+(* Title: rewrite.ML
+ Author: Christoph Traut, Lars Noschinski, TU Muenchen
+ Modified: Joshua Chen, University of Innsbruck
+
+This is a rewrite method that supports subterm-selection based on patterns.
+
+The patterns accepted by rewrite are of the following form:
+ <atom> ::= <term> | "concl" | "asm" | "for" "(" <names> ")"
+ <pattern> ::= (in <atom> | at <atom>) [<pattern>]
+ <args> ::= [<pattern>] ("to" <term>) <thms>
+
+This syntax was clearly inspired by Gonthier's and Tassi's language of
+patterns but has diverged significantly during its development.
+
+We also allow introduction of identifiers for bound variables,
+which can then be used to match arbitrary subterms inside abstractions.
+
+This code is slightly modified from the original at HOL/Library/rewrite.ML,
+to incorporate auto-typechecking for type theory.
+*)
+
+infix 1 then_pconv;
+infix 0 else_pconv;
+
+signature REWRITE =
+sig
+ type patconv = Proof.context -> Type.tyenv * (string * term) list -> cconv
+ val then_pconv: patconv * patconv -> patconv
+ val else_pconv: patconv * patconv -> patconv
+ val abs_pconv: patconv -> string option * typ -> patconv (*XXX*)
+ val fun_pconv: patconv -> patconv
+ val arg_pconv: patconv -> patconv
+ val imp_pconv: patconv -> patconv
+ val params_pconv: patconv -> patconv
+ val forall_pconv: patconv -> string option * typ option -> patconv
+ val all_pconv: patconv
+ val for_pconv: patconv -> (string option * typ option) list -> patconv
+ val concl_pconv: patconv -> patconv
+ val asm_pconv: patconv -> patconv
+ val asms_pconv: patconv -> patconv
+ val judgment_pconv: patconv -> patconv
+ val in_pconv: patconv -> patconv
+ val match_pconv: patconv -> term * (string option * typ) list -> patconv
+ val rewrs_pconv: term option -> thm list -> patconv
+
+ datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list
+
+ val mk_hole: int -> typ -> term
+
+ val rewrite_conv: Proof.context
+ -> (term * (string * typ) list, string * typ option) pattern list * term option
+ -> thm list
+ -> conv
+end
+
+structure Rewrite : REWRITE =
+struct
+
+datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list
+
+exception NO_TO_MATCH
+
+val holeN = Name.internal "_hole"
+
+fun prep_meta_eq ctxt = Simplifier.mksimps ctxt #> map Drule.zero_var_indexes
+
+
+(* holes *)
+
+fun mk_hole i T = Var ((holeN, i), T)
+
+fun is_hole (Var ((name, _), _)) = (name = holeN)
+ | is_hole _ = false
+
+fun is_hole_const (Const (\<^const_name>\<open>rewrite_HOLE\<close>, _)) = true
+ | is_hole_const _ = false
+
+val hole_syntax =
+ let
+ (* Modified variant of Term.replace_hole *)
+ fun replace_hole Ts (Const (\<^const_name>\<open>rewrite_HOLE\<close>, T)) i =
+ (list_comb (mk_hole i (Ts ---> T), map_range Bound (length Ts)), i + 1)
+ | replace_hole Ts (Abs (x, T, t)) i =
+ let val (t', i') = replace_hole (T :: Ts) t i
+ in (Abs (x, T, t'), i') end
+ | replace_hole Ts (t $ u) i =
+ let
+ val (t', i') = replace_hole Ts t i
+ val (u', i'') = replace_hole Ts u i'
+ in (t' $ u', i'') end
+ | replace_hole _ a i = (a, i)
+ fun prep_holes ts = #1 (fold_map (replace_hole []) ts 1)
+ in
+ Context.proof_map (Syntax_Phases.term_check 101 "hole_expansion" (K prep_holes))
+ #> Proof_Context.set_mode Proof_Context.mode_pattern
+ end
+
+
+(* pattern conversions *)
+
+type patconv = Proof.context -> Type.tyenv * (string * term) list -> cterm -> thm
+
+fun (cv1 then_pconv cv2) ctxt tytenv ct = (cv1 ctxt tytenv then_conv cv2 ctxt tytenv) ct
+
+fun (cv1 else_pconv cv2) ctxt tytenv ct = (cv1 ctxt tytenv else_conv cv2 ctxt tytenv) ct
+
+fun raw_abs_pconv cv ctxt tytenv ct =
+ case Thm.term_of ct of
+ Abs _ => CConv.abs_cconv (fn (x, ctxt') => cv x ctxt' tytenv) ctxt ct
+ | t => raise TERM ("raw_abs_pconv", [t])
+
+fun raw_fun_pconv cv ctxt tytenv ct =
+ case Thm.term_of ct of
+ _ $ _ => CConv.fun_cconv (cv ctxt tytenv) ct
+ | t => raise TERM ("raw_fun_pconv", [t])
+
+fun raw_arg_pconv cv ctxt tytenv ct =
+ case Thm.term_of ct of
+ _ $ _ => CConv.arg_cconv (cv ctxt tytenv) ct
+ | t => raise TERM ("raw_arg_pconv", [t])
+
+fun abs_pconv cv (s,T) ctxt (tyenv, ts) ct =
+ let val u = Thm.term_of ct
+ in
+ case try (fastype_of #> dest_funT) u of
+ NONE => raise TERM ("abs_pconv: no function type", [u])
+ | SOME (U, _) =>
+ let
+ val tyenv' =
+ if T = dummyT then tyenv
+ else Sign.typ_match (Proof_Context.theory_of ctxt) (T, U) tyenv
+ val eta_expand_cconv =
+ case u of
+ Abs _=> Thm.reflexive
+ | _ => CConv.rewr_cconv @{thm eta_expand}
+ fun add_ident NONE _ l = l
+ | add_ident (SOME name) ct l = (name, Thm.term_of ct) :: l
+ val abs_cv = CConv.abs_cconv (fn (ct, ctxt) => cv ctxt (tyenv', add_ident s ct ts)) ctxt
+ in (eta_expand_cconv then_conv abs_cv) ct end
+ handle Pattern.MATCH => raise TYPE ("abs_pconv: types don't match", [T,U], [u])
+ end
+
+fun fun_pconv cv ctxt tytenv ct =
+ case Thm.term_of ct of
+ _ $ _ => CConv.fun_cconv (cv ctxt tytenv) ct
+ | Abs (_, T, _ $ Bound 0) => abs_pconv (fun_pconv cv) (NONE, T) ctxt tytenv ct
+ | t => raise TERM ("fun_pconv", [t])
+
+local
+
+fun arg_pconv_gen cv0 cv ctxt tytenv ct =
+ case Thm.term_of ct of
+ _ $ _ => cv0 (cv ctxt tytenv) ct
+ | Abs (_, T, _ $ Bound 0) => abs_pconv (arg_pconv_gen cv0 cv) (NONE, T) ctxt tytenv ct
+ | t => raise TERM ("arg_pconv_gen", [t])
+
+in
+
+fun arg_pconv ctxt = arg_pconv_gen CConv.arg_cconv ctxt
+fun imp_pconv ctxt = arg_pconv_gen (CConv.concl_cconv 1) ctxt
+
+end
+
+(* Move to B in !!x_1 ... x_n. B. Do not eta-expand *)
+fun params_pconv cv ctxt tytenv ct =
+ let val pconv =
+ case Thm.term_of ct of
+ Const (\<^const_name>\<open>Pure.all\<close>, _) $ Abs _ => (raw_arg_pconv o raw_abs_pconv) (fn _ => params_pconv cv)
+ | Const (\<^const_name>\<open>Pure.all\<close>, _) => raw_arg_pconv (params_pconv cv)
+ | _ => cv
+ in pconv ctxt tytenv ct end
+
+fun forall_pconv cv ident ctxt tytenv ct =
+ case Thm.term_of ct of
+ Const (\<^const_name>\<open>Pure.all\<close>, T) $ _ =>
+ let
+ val def_U = T |> dest_funT |> fst |> dest_funT |> fst
+ val ident' = apsnd (the_default (def_U)) ident
+ in arg_pconv (abs_pconv cv ident') ctxt tytenv ct end
+ | t => raise TERM ("forall_pconv", [t])
+
+fun all_pconv _ _ = Thm.reflexive
+
+fun for_pconv cv idents ctxt tytenv ct =
+ let
+ fun f rev_idents (Const (\<^const_name>\<open>Pure.all\<close>, _) $ t) =
+ let val (rev_idents', cv') = f rev_idents (case t of Abs (_,_,u) => u | _ => t)
+ in
+ case rev_idents' of
+ [] => ([], forall_pconv cv' (NONE, NONE))
+ | (x :: xs) => (xs, forall_pconv cv' x)
+ end
+ | f rev_idents _ = (rev_idents, cv)
+ in
+ case f (rev idents) (Thm.term_of ct) of
+ ([], cv') => cv' ctxt tytenv ct
+ | _ => raise CTERM ("for_pconv", [ct])
+ end
+
+fun concl_pconv cv ctxt tytenv ct =
+ case Thm.term_of ct of
+ (Const (\<^const_name>\<open>Pure.imp\<close>, _) $ _) $ _ => imp_pconv (concl_pconv cv) ctxt tytenv ct
+ | _ => cv ctxt tytenv ct
+
+fun asm_pconv cv ctxt tytenv ct =
+ case Thm.term_of ct of
+ (Const (\<^const_name>\<open>Pure.imp\<close>, _) $ _) $ _ => CConv.with_prems_cconv ~1 (cv ctxt tytenv) ct
+ | t => raise TERM ("asm_pconv", [t])
+
+fun asms_pconv cv ctxt tytenv ct =
+ case Thm.term_of ct of
+ (Const (\<^const_name>\<open>Pure.imp\<close>, _) $ _) $ _ =>
+ ((CConv.with_prems_cconv ~1 oo cv) else_pconv imp_pconv (asms_pconv cv)) ctxt tytenv ct
+ | t => raise TERM ("asms_pconv", [t])
+
+fun judgment_pconv cv ctxt tytenv ct =
+ if Object_Logic.is_judgment ctxt (Thm.term_of ct)
+ then arg_pconv cv ctxt tytenv ct
+ else cv ctxt tytenv ct
+
+fun in_pconv cv ctxt tytenv ct =
+ (cv else_pconv
+ raw_fun_pconv (in_pconv cv) else_pconv
+ raw_arg_pconv (in_pconv cv) else_pconv
+ raw_abs_pconv (fn _ => in_pconv cv))
+ ctxt tytenv ct
+
+fun replace_idents idents t =
+ let
+ fun subst ((n1, s)::ss) (t as Free (n2, _)) = if n1 = n2 then s else subst ss t
+ | subst _ t = t
+ in Term.map_aterms (subst idents) t end
+
+fun match_pconv cv (t,fixes) ctxt (tyenv, env_ts) ct =
+ let
+ val t' = replace_idents env_ts t
+ val thy = Proof_Context.theory_of ctxt
+ val u = Thm.term_of ct
+
+ fun descend_hole fixes (Abs (_, _, t)) =
+ (case descend_hole fixes t of
+ NONE => NONE
+ | SOME (fix :: fixes', pos) => SOME (fixes', abs_pconv pos fix)
+ | SOME ([], _) => raise Match (* less fixes than abstractions on path to hole *))
+ | descend_hole fixes (t as l $ r) =
+ let val (f, _) = strip_comb t
+ in
+ if is_hole f
+ then SOME (fixes, cv)
+ else
+ (case descend_hole fixes l of
+ SOME (fixes', pos) => SOME (fixes', fun_pconv pos)
+ | NONE =>
+ (case descend_hole fixes r of
+ SOME (fixes', pos) => SOME (fixes', arg_pconv pos)
+ | NONE => NONE))
+ end
+ | descend_hole fixes t =
+ if is_hole t then SOME (fixes, cv) else NONE
+
+ val to_hole = descend_hole (rev fixes) #> the_default ([], cv) #> snd
+ in
+ case try (Pattern.match thy (apply2 Logic.mk_term (t',u))) (tyenv, Vartab.empty) of
+ NONE => raise TERM ("match_pconv: Does not match pattern", [t, t',u])
+ | SOME (tyenv', _) => to_hole t ctxt (tyenv', env_ts) ct
+ end
+
+fun rewrs_pconv to thms ctxt (tyenv, env_ts) =
+ let
+ fun instantiate_normalize_env ctxt env thm =
+ let
+ val prop = Thm.prop_of thm
+ val norm_type = Envir.norm_type o Envir.type_env
+ val insts = Term.add_vars prop []
+ |> map (fn x as (s, T) =>
+ ((s, norm_type env T), Thm.cterm_of ctxt (Envir.norm_term env (Var x))))
+ val tyinsts = Term.add_tvars prop []
+ |> map (fn x => (x, Thm.ctyp_of ctxt (norm_type env (TVar x))))
+ in Drule.instantiate_normalize (tyinsts, insts) thm end
+
+ fun unify_with_rhs context to env thm =
+ let
+ val (_, rhs) = thm |> Thm.concl_of |> Logic.dest_equals
+ val env' = Pattern.unify context (Logic.mk_term to, Logic.mk_term rhs) env
+ handle Pattern.Unif => raise NO_TO_MATCH
+ in env' end
+
+ fun inst_thm_to _ (NONE, _) thm = thm
+ | inst_thm_to (ctxt : Proof.context) (SOME to, env) thm =
+ instantiate_normalize_env ctxt (unify_with_rhs (Context.Proof ctxt) to env thm) thm
+
+ fun inst_thm ctxt idents (to, tyenv) thm =
+ let
+ (* Replace any identifiers with their corresponding bound variables. *)
+ val maxidx = Term.maxidx_typs (map (snd o snd) (Vartab.dest tyenv)) 0
+ val env = Envir.Envir {maxidx = maxidx, tenv = Vartab.empty, tyenv = tyenv}
+ val maxidx = Envir.maxidx_of env |> fold Term.maxidx_term (the_list to)
+ val thm' = Thm.incr_indexes (maxidx + 1) thm
+ in SOME (inst_thm_to ctxt (Option.map (replace_idents idents) to, env) thm') end
+ handle NO_TO_MATCH => NONE
+
+ in CConv.rewrs_cconv (map_filter (inst_thm ctxt env_ts (to, tyenv)) thms) end
+
+fun rewrite_conv ctxt (pattern, to) thms ct =
+ let
+ fun apply_pat At = judgment_pconv
+ | apply_pat In = in_pconv
+ | apply_pat Asm = params_pconv o asms_pconv
+ | apply_pat Concl = params_pconv o concl_pconv
+ | apply_pat (For idents) = (fn cv => for_pconv cv (map (apfst SOME) idents))
+ | apply_pat (Term x) = (fn cv => match_pconv cv (apsnd (map (apfst SOME)) x))
+
+ val cv = fold_rev apply_pat pattern
+
+ fun distinct_prems th =
+ case Seq.pull (distinct_subgoals_tac th) of
+ NONE => th
+ | SOME (th', _) => th'
+
+ val rewrite = rewrs_pconv to (maps (prep_meta_eq ctxt) thms)
+ in cv rewrite ctxt (Vartab.empty, []) ct |> distinct_prems end
+
+fun rewrite_export_tac ctxt (pat, pat_ctxt) thms =
+ let
+ val export = case pat_ctxt of
+ NONE => I
+ | SOME inner => singleton (Proof_Context.export inner ctxt)
+ in CCONVERSION (export o rewrite_conv ctxt pat thms) end
+
+val _ =
+ Theory.setup
+ let
+ fun mk_fix s = (Binding.name s, NONE, NoSyn)
+
+ val raw_pattern : (string, binding * string option * mixfix) pattern list parser =
+ let
+ val sep = (Args.$$$ "at" >> K At) || (Args.$$$ "in" >> K In)
+ val atom = (Args.$$$ "asm" >> K Asm) ||
+ (Args.$$$ "concl" >> K Concl) ||
+ (Args.$$$ "for" |-- Args.parens (Scan.optional Parse.vars []) >> For) ||
+ (Parse.term >> Term)
+ val sep_atom = sep -- atom >> (fn (s,a) => [s,a])
+
+ fun append_default [] = [Concl, In]
+ | append_default (ps as Term _ :: _) = Concl :: In :: ps
+ | append_default [For x, In] = [For x, Concl, In]
+ | append_default (For x :: (ps as In :: Term _:: _)) = For x :: Concl :: ps
+ | append_default ps = ps
+
+ in Scan.repeats sep_atom >> (rev #> append_default) end
+
+ fun context_lift (scan : 'a parser) f = fn (context : Context.generic, toks) =>
+ let
+ val (r, toks') = scan toks
+ val (r', context') = Context.map_proof_result (fn ctxt => f ctxt r) context
+ in (r', (context', toks' : Token.T list)) end
+
+ fun read_fixes fixes ctxt =
+ let fun read_typ (b, rawT, mx) = (b, Option.map (Syntax.read_typ ctxt) rawT, mx)
+ in Proof_Context.add_fixes (map read_typ fixes) ctxt end
+
+ fun prep_pats ctxt (ps : (string, binding * string option * mixfix) pattern list) =
+ let
+ fun add_constrs ctxt n (Abs (x, T, t)) =
+ let
+ val (x', ctxt') = yield_singleton Proof_Context.add_fixes (mk_fix x) ctxt
+ in
+ (case add_constrs ctxt' (n+1) t of
+ NONE => NONE
+ | SOME ((ctxt'', n', xs), t') =>
+ let
+ val U = Type_Infer.mk_param n []
+ val u = Type.constraint (U --> dummyT) (Abs (x, T, t'))
+ in SOME ((ctxt'', n', (x', U) :: xs), u) end)
+ end
+ | add_constrs ctxt n (l $ r) =
+ (case add_constrs ctxt n l of
+ SOME (c, l') => SOME (c, l' $ r)
+ | NONE =>
+ (case add_constrs ctxt n r of
+ SOME (c, r') => SOME (c, l $ r')
+ | NONE => NONE))
+ | add_constrs ctxt n t =
+ if is_hole_const t then SOME ((ctxt, n, []), t) else NONE
+
+ fun prep (Term s) (n, ctxt) =
+ let
+ val t = Syntax.parse_term ctxt s
+ val ((ctxt', n', bs), t') =
+ the_default ((ctxt, n, []), t) (add_constrs ctxt (n+1) t)
+ in (Term (t', bs), (n', ctxt')) end
+ | prep (For ss) (n, ctxt) =
+ let val (ns, ctxt') = read_fixes ss ctxt
+ in (For ns, (n, ctxt')) end
+ | prep At (n,ctxt) = (At, (n, ctxt))
+ | prep In (n,ctxt) = (In, (n, ctxt))
+ | prep Concl (n,ctxt) = (Concl, (n, ctxt))
+ | prep Asm (n,ctxt) = (Asm, (n, ctxt))
+
+ val (xs, (_, ctxt')) = fold_map prep ps (0, ctxt)
+
+ in (xs, ctxt') end
+
+ fun prep_args ctxt (((raw_pats, raw_to), raw_ths)) =
+ let
+
+ fun check_terms ctxt ps to =
+ let
+ fun safe_chop (0: int) xs = ([], xs)
+ | safe_chop n (x :: xs) = chop (n - 1) xs |>> cons x
+ | safe_chop _ _ = raise Match
+
+ fun reinsert_pat _ (Term (_, cs)) (t :: ts) =
+ let val (cs', ts') = safe_chop (length cs) ts
+ in (Term (t, map dest_Free cs'), ts') end
+ | reinsert_pat _ (Term _) [] = raise Match
+ | reinsert_pat ctxt (For ss) ts =
+ let val fixes = map (fn s => (s, Variable.default_type ctxt s)) ss
+ in (For fixes, ts) end
+ | reinsert_pat _ At ts = (At, ts)
+ | reinsert_pat _ In ts = (In, ts)
+ | reinsert_pat _ Concl ts = (Concl, ts)
+ | reinsert_pat _ Asm ts = (Asm, ts)
+
+ fun free_constr (s,T) = Type.constraint T (Free (s, dummyT))
+ fun mk_free_constrs (Term (t, cs)) = t :: map free_constr cs
+ | mk_free_constrs _ = []
+
+ val ts = maps mk_free_constrs ps @ the_list to
+ |> Syntax.check_terms (hole_syntax ctxt)
+ val ctxt' = fold Variable.declare_term ts ctxt
+ val (ps', (to', ts')) = fold_map (reinsert_pat ctxt') ps ts
+ ||> (fn xs => case to of NONE => (NONE, xs) | SOME _ => (SOME (hd xs), tl xs))
+ val _ = case ts' of (_ :: _) => raise Match | [] => ()
+ in ((ps', to'), ctxt') end
+
+ val (pats, ctxt') = prep_pats ctxt raw_pats
+
+ val ths = Attrib.eval_thms ctxt' raw_ths
+ val to = Option.map (Syntax.parse_term ctxt') raw_to
+
+ val ((pats', to'), ctxt'') = check_terms ctxt' pats to
+
+ in ((pats', ths, (to', ctxt)), ctxt'') end
+
+ val to_parser = Scan.option ((Args.$$$ "to") |-- Parse.term)
+
+ val subst_parser =
+ let val scan = raw_pattern -- to_parser -- Parse.thms1
+ in context_lift scan prep_args end
+ in
+ Method.setup \<^binding>\<open>rewr\<close> (subst_parser >>
+ (fn (pattern, inthms, (to, pat_ctxt)) => fn orig_ctxt =>
+ SIMPLE_METHOD'
+ (rewrite_export_tac orig_ctxt ((pattern, to), SOME pat_ctxt) inthms)))
+ "single-step rewriting, allowing subterm selection via patterns"
+ #>
+ (Method.setup \<^binding>\<open>rewrite\<close> (subst_parser >>
+ (fn (pattern, inthms, (to, pat_ctxt)) => fn orig_ctxt =>
+ SIMPLE_METHOD' (SIDE_CONDS
+ (rewrite_export_tac orig_ctxt ((pattern, to), SOME pat_ctxt) inthms)
+ orig_ctxt)))
+ "single-step rewriting with auto-typechecking")
+ end
+end
diff --git a/spartan/core/ml/tactics.ML b/spartan/core/ml/tactics.ML
new file mode 100644
index 0000000..0c71665
--- /dev/null
+++ b/spartan/core/ml/tactics.ML
@@ -0,0 +1,228 @@
+(* Title: tactics.ML
+ Author: Joshua Chen
+
+General tactics for dependent type theory.
+*)
+
+structure Tactics:
+sig
+
+val assumptions_tac: Proof.context -> int -> tactic
+val known_tac: Proof.context -> int -> tactic
+val typechk_tac: Proof.context -> int -> tactic
+val auto_typechk: bool Config.T
+val SIDE_CONDS: (int -> tactic) -> Proof.context -> int -> tactic
+val rule_tac: thm list -> Proof.context -> int -> tactic
+val dest_tac: int option -> thm list -> Proof.context -> int -> tactic
+val intro_tac: Proof.context -> int -> tactic
+val intros_tac: Proof.context -> int -> tactic
+val elim_context_tac: term list -> Proof.context -> int -> context_tactic
+val cases_tac: term -> Proof.context -> int -> tactic
+
+end = struct
+
+(*An assumption tactic that only solves typing goals with rigid terms and
+ judgmental equalities without schematic variables*)
+fun assumptions_tac ctxt = SUBGOAL (fn (goal, i) =>
+ let
+ val concl = Logic.strip_assums_concl goal
+ in
+ if
+ Lib.is_typing concl andalso Lib.is_rigid (Lib.term_of_typing concl)
+ orelse not ((exists_subterm is_Var) concl)
+ then assume_tac ctxt i
+ else no_tac
+ end)
+
+(*Solves typing goals with rigid term by resolving with context facts and
+ simplifier premises, or arbitrary goals by *non-unifying* assumption*)
+fun known_tac ctxt = SUBGOAL (fn (goal, i) =>
+ let
+ val concl = Logic.strip_assums_concl goal
+ in
+ ((if Lib.is_typing concl andalso Lib.is_rigid (Lib.term_of_typing concl)
+ then
+ let val ths = map fst (Facts.props (Proof_Context.facts_of ctxt))
+ in resolve_tac ctxt (ths @ Simplifier.prems_of ctxt) end
+ else K no_tac)
+ ORELSE' assumptions_tac ctxt) i
+ end)
+
+(*Typechecking: try to solve goals of the form "a: A" where a is rigid*)
+fun typechk_tac ctxt =
+ let
+ val tac = SUBGOAL (fn (goal, i) =>
+ if Lib.rigid_typing_concl goal
+ then
+ let val net = Tactic.build_net
+ ((Named_Theorems.get ctxt \<^named_theorems>\<open>typechk\<close>)
+ @(Named_Theorems.get ctxt \<^named_theorems>\<open>intros\<close>)
+ @(map #1 (Elim.rules ctxt)))
+ in (resolve_from_net_tac ctxt net) i end
+ else no_tac)
+ in
+ REPEAT_ALL_NEW (known_tac ctxt ORELSE' tac)
+ end
+
+fun typechk_context_tac (ctxt, st) =
+ let
+
+ in
+ ()
+ end
+
+(*Many methods try to automatically discharge side conditions by typechecking.
+ Switch this flag off to discharge by non-unifying assumption instead.*)
+val auto_typechk = Attrib.setup_config_bool \<^binding>\<open>auto_typechk\<close> (K true)
+
+fun side_cond_tac ctxt = CHANGED o REPEAT o
+ (if Config.get ctxt auto_typechk then typechk_tac ctxt else known_tac ctxt)
+
+(*Combinator runs tactic and tries to discharge all new typing side conditions*)
+fun SIDE_CONDS tac ctxt = tac THEN_ALL_NEW (TRY o side_cond_tac ctxt)
+
+local
+fun mk_rules _ ths [] = ths
+ | mk_rules n ths ths' =
+ let val ths'' = foldr1 (op @)
+ (map (fn th => [rotate_prems n (th RS @{thm PiE})] handle THM _ => []) ths')
+ in
+ mk_rules n (ths @ ths') ths''
+ end
+in
+
+(*Resolves with given rules, discharging as many side conditions as possible*)
+fun rule_tac ths ctxt = resolve_tac ctxt (mk_rules 0 [] ths)
+
+(*Attempts destruct-resolution with the n-th premise of the given rules*)
+fun dest_tac opt_n ths ctxt = dresolve_tac ctxt
+ (mk_rules (case opt_n of NONE => 0 | SOME 0 => 0 | SOME n => n-1) [] ths)
+
+end
+
+(*Applies some introduction rule*)
+fun intro_tac ctxt = SUBGOAL (fn (_, i) => SIDE_CONDS
+ (resolve_tac ctxt (Named_Theorems.get ctxt \<^named_theorems>\<open>intros\<close>)) ctxt i)
+
+fun intros_tac ctxt = SUBGOAL (fn (_, i) =>
+ (CHANGED o REPEAT o CHANGED o intro_tac ctxt) i)
+
+(* Induction/elimination *)
+
+(*Pushes a context/goal premise typing t:T into a \<Prod>-type*)
+fun internalize_fact_tac t =
+ Subgoal.FOCUS_PARAMS (fn {context = ctxt, concl = raw_concl, ...} =>
+ let
+ val concl = Logic.strip_assums_concl (Thm.term_of raw_concl)
+ val C = Lib.type_of_typing concl
+ val B = Thm.cterm_of ctxt (Lib.lambda_var t C)
+ val a = Thm.cterm_of ctxt t
+ (*The resolvent is PiE[where ?B=B and ?a=a]*)
+ val resolvent =
+ Drule.infer_instantiate' ctxt [NONE, NONE, SOME B, SOME a] @{thm PiE}
+ in
+ HEADGOAL (resolve_tac ctxt [resolvent])
+ (*known_tac infers the correct type T inferred by unification*)
+ THEN SOMEGOAL (known_tac ctxt)
+ end)
+
+(*Premises that have already been pushed into the \<Prod>-type*)
+structure Inserts = Proof_Data (
+ type T = term Item_Net.T
+ val init = K (Item_Net.init Term.aconv_untyped single)
+)
+
+local
+
+fun elim_core_tac tms types ctxt = SUBGOAL (K (
+ let
+ val rule_insts = map ((Elim.lookup_rule ctxt) o Term.head_of) types
+ val rules = flat (map
+ (fn rule_inst => case rule_inst of
+ NONE => []
+ | SOME (rl, idxnames) => [Drule.infer_instantiate ctxt
+ (idxnames ~~ map (Thm.cterm_of ctxt) tms) rl])
+ rule_insts)
+ in
+ HEADGOAL (resolve_tac ctxt rules)
+ THEN RANGE (replicate (length tms) (typechk_tac ctxt)) 1
+ end handle Option => no_tac))
+
+in
+
+fun elim_context_tac tms ctxt = case tms of
+ [] => CONTEXT_SUBGOAL (K (Context_Tactic.CONTEXT_TACTIC (HEADGOAL (
+ SIDE_CONDS (eresolve_tac ctxt (map #1 (Elim.rules ctxt))) ctxt))))
+ | major::_ => CONTEXT_SUBGOAL (fn (goal, _) =>
+ let
+ val facts = Proof_Context.facts_of ctxt
+ val prems = Logic.strip_assums_hyp goal
+ val template = Lib.typing_of_term major
+ val types =
+ map (Thm.prop_of o #1) (Facts.could_unify facts template)
+ @ filter (fn prem => Term.could_unify (template, prem)) prems
+ |> map Lib.type_of_typing
+ in case types of
+ [] => Context_Tactic.CONTEXT_TACTIC no_tac
+ | _ =>
+ let
+ val inserts = map (Thm.prop_of o fst) (Facts.props facts) @ prems
+ |> filter Lib.is_typing
+ |> map Lib.dest_typing
+ |> filter_out (fn (t, _) =>
+ Term.aconv (t, major) orelse Item_Net.member (Inserts.get ctxt) t)
+ |> map (fn (t, T) => ((t, T), Lib.subterm_count_distinct tms T))
+ |> filter (fn (_, i) => i > 0)
+ (*`t1: T1` comes before `t2: T2` if T1 contains t2 as subterm.
+ If they are incomparable, then order by decreasing
+ `subterm_count [p, x, y] T`*)
+ |> sort (fn (((t1, _), i), ((_, T2), j)) =>
+ Lib.cond_order (Lib.subterm_order T2 t1) (int_ord (j, i)))
+ |> map (#1 o #1)
+ val record_inserts = Inserts.map (fold Item_Net.update inserts)
+ val tac =
+ (*Push premises having a subterm in `tms` into a \<Prod>*)
+ fold (fn t => fn tac =>
+ tac THEN HEADGOAL (internalize_fact_tac t ctxt))
+ inserts all_tac
+ (*Apply elimination rule*)
+ THEN (HEADGOAL (
+ elim_core_tac tms types ctxt
+ (*Pull pushed premises back out*)
+ THEN_ALL_NEW (SUBGOAL (fn (_, i) =>
+ REPEAT_DETERM_N (length inserts)
+ (resolve_tac ctxt @{thms PiI} i)))
+ ))
+ (*Side conditions*)
+ THEN ALLGOALS (TRY o side_cond_tac ctxt)
+ in
+ fn (ctxt, st) => Context_Tactic.TACTIC_CONTEXT
+ (record_inserts ctxt) (tac st)
+ end
+ end)
+
+fun cases_tac tm ctxt = SUBGOAL (fn (goal, i) =>
+ let
+ val facts = Proof_Context.facts_of ctxt
+ val prems = Logic.strip_assums_hyp goal
+ val template = Lib.typing_of_term tm
+ val types =
+ map (Thm.prop_of o #1) (Facts.could_unify facts template)
+ @ filter (fn prem => Term.could_unify (template, prem)) prems
+ |> map Lib.type_of_typing
+ val res = (case types of
+ [typ] => Drule.infer_instantiate' ctxt [SOME (Thm.cterm_of ctxt tm)]
+ (the (Case.lookup_rule ctxt (Term.head_of typ)))
+ | [] => raise Option
+ | _ => raise error (Syntax.string_of_term ctxt tm ^ "not uniquely typed"))
+ handle Option => error ("no case rule known for "
+ ^ (Syntax.string_of_term ctxt tm))
+ in
+ SIDE_CONDS (resolve_tac ctxt [res]) ctxt i
+ end)
+
+end
+
+end
+
+open Tactics
diff --git a/spartan/core/ml/types.ML b/spartan/core/ml/types.ML
new file mode 100644
index 0000000..b0792fe
--- /dev/null
+++ b/spartan/core/ml/types.ML
@@ -0,0 +1,18 @@
+structure Types
+= struct
+
+structure Data = Generic_Data (
+ type T = thm Item_Net.T
+ val empty = Item_Net.init Thm.eq_thm
+ (single o Lib.term_of_typing o Thm.prop_of)
+ val extend = I
+ val merge = Item_Net.merge
+)
+
+fun put_type typing = Context.proof_map (Data.map (Item_Net.update typing))
+fun put_types typings = foldr1 (op o) (map put_type typings)
+
+fun get_types ctxt tm = Item_Net.retrieve (Data.get (Context.Proof ctxt)) tm
+
+
+end