aboutsummaryrefslogtreecommitdiff
path: root/spartan/core/eqsubst.ML
blob: ea6f09826b51090a73d5275e09c6b48a60635f9c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
(*  Title:      eqsubst.ML
    Author:     Lucas Dixon, University of Edinburgh
    Modified:   Joshua Chen, University of Innsbruck

Perform a substitution using an equation.

This code is slightly modified from the original at Tools/eqsubst.ML,
to incorporate auto-typechecking for type theory.
*)

signature EQSUBST =
sig
  (* One result of a search: the instantiations found by unification,
     the two bound-variable environments, and the outer term. *)
  type match =
    ((indexname * (sort * typ)) list (* type instantiations *)
      * (indexname * (typ * term)) list) (* term instantiations *)
    * (string * typ) list (* fake named type abs env *)
    * (string * typ) list (* type abs env *)
    * term (* outer term *)

  (* What a search function is given to work with. *)
  type searchinfo =
    Proof.context
    * int (* maxidx *)
    * Zipper.T (* focusterm to search under *)

  (* SkipMore n: the sequence ran out with n occurrences still to skip;
     SkipSeq ss: the remaining sequence of (sequences of) results. *)
  datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq

  (* occurrence-skipping wrappers around a search function *)
  val skip_first_asm_occs_search: ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq
  val skip_first_occs_search: int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
  val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq

  (* tactics *)
  val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
  val eqsubst_asm_tac': Proof.context ->
    (searchinfo -> int -> term -> match skipseq) -> int -> thm -> int -> tactic
  val eqsubst_tac: Proof.context ->
    int list -> (* list of occurrences to rewrite, use [0] for any *)
    thm list -> int -> tactic
  val eqsubst_tac': Proof.context ->
    (searchinfo -> term -> match Seq.seq) (* search function *)
    -> thm (* equation theorem to rewrite with *)
    -> int (* subgoal number in goal theorem *)
    -> thm (* goal theorem *)
    -> thm Seq.seq (* rewritten goal theorem *)

  (* search for substitutions *)
  val valid_match_start: Zipper.T -> bool
  val search_lr_all: Zipper.T -> Zipper.T Seq.seq
  val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
  (* the searchf_* functions return one inner sequence of matches per
     zipper position visited by the underlying search *)
  val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
  val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
  val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
end;

structure EqSubst: EQSUBST =
struct

(* Prepare a rewrite rule for use: pass it through Simplifier.mksimps in
   the given context (intended to turn an object-level "=" into meta-level
   "==") and apply Drule.zero_var_indexes to each resulting theorem. *)
fun prep_meta_eq ctxt th =
  map Drule.zero_var_indexes (Simplifier.mksimps ctxt th);

(* Make the given free variables schematic with index zero: universally
   quantify the theorem over all of frees, then eliminate one quantifier
   per element of frees with Thm.forall_elim_var 0. *)
fun unfix_frees frees th =
  let val quantified = Drule.forall_intr_list frees th
  in fold (fn _ => Thm.forall_elim_var 0) frees quantified end;


(* One search result, as assembled by clean_unify_z: the instantiations
   found by unification together with the two bound-variable environments
   and the abstracted outer term produced by prep_zipper_match. *)
type match =
  ((indexname * (sort * typ)) list (* type instantiations *)
   * (indexname * (typ * term)) list) (* term instantiations *)
  * (string * typ) list (* fake named type abs env *)
  * (string * typ) list (* type abs env *)
  * term; (* outer term *)

(* Input handed to every search function: the proof context, the maximal
   variable index to start unification from, and the zipper positioned at
   the term to search under. *)
type searchinfo =
  Proof.context
  * int (* maxidx *)
  * Zipper.T; (* focusterm to search under *)


(* Result of skipping non-empty sub-sequences of a sequence of sequences.
   SkipMore n: the end of the outer sequence was reached with n still
   left to skip.
   SkipSeq ss: skipping finished; ss is what remains. *)
datatype 'a skipseq =
  SkipMore of int |
  SkipSeq of 'a Seq.seq Seq.seq;

(* Advance a sequence of sequences to its m-th non-empty inner sequence
   (m <= 1 means the first one). Empty inner sequences are dropped without
   being counted. Returns SkipSeq with that inner sequence at the head, or
   SkipMore with the outstanding count if the outer sequence ends first. *)
fun skipto_skipseq m s =
  let
    fun nonempty sq =
      (case Seq.pull sq of
        NONE => false
      | SOME _ => true)
    fun advance remaining outer =
      (case Seq.pull outer of
        NONE => SkipMore remaining
      | SOME (inner, rest) =>
          if not (nonempty inner) then advance remaining rest
          else if remaining <= 1 then SkipSeq (Seq.cons inner rest)
          else advance (remaining - 1) rest)
  in advance m s end;

(* Build the outer term of a match: the target with the matched focus
   replaced by a bound variable, i.e. "P lhs" becomes "%x. P x".
   Ts lists the (name, type) pairs of the bound variables around the
   focus, t is the focus term, and mkuptermfunc rebuilds the surrounding
   term around whatever is put in place of the focus.
   The new abstraction is named "fooabs"; its type takes the types of Ts
   (in reverse order) as arguments and returns the type of t. *)
fun mk_foo_match mkuptermfunc Ts t =
  let
    val focus_ty = Term.type_of t
    val abs_ty = rev (map snd Ts) ---> focus_ty
    val depth = length Ts
    (* "fooabs y0 ... yn": Bound depth applied to Bound (depth - 1) down to Bound 0 *)
    val hole =
      fold (fn i => fn acc => acc $ Bound i)
        (rev (Library.upto (0, depth - 1))) (Bound depth)
  in Abs ("fooabs", abs_ty, mkuptermfunc hole) end;

(* T is outer bound vars, n is number of locally bound vars *)
(* THINK: is order of Ts correct...? or reversed? *)
(* name used for the Free that stands in for a bound variable: ":b_" followed by n *)
fun mk_fake_bound_name n = String.concat [":b_", n];
(* Replace the loose bound variables of t by Frees.
   Ts is a list of (name, type) pairs, processed from the right. Each name
   is replaced by a variant distinct from the variants chosen so far.
   Returns (fake environment with ":b_"-prefixed names, renamed Ts,
   t with its loose bounds substituted by the fake Frees). *)
fun fakefree_badbounds Ts t =
  let
    fun step (name, ty) (fakes, renamed, used) =
      let val fresh = singleton (Name.variant_list used) name
      in
        ((mk_fake_bound_name fresh, ty) :: fakes,
         (fresh, ty) :: renamed,
         fresh :: used)
      end
    val (fakes, renamed, _) = fold_rev step Ts ([], [], [])
  in (fakes, renamed, Term.subst_bounds (map Free fakes, t)) end;

(* Prepare the focus of zipper z for matching. Bound variables whose
   abstraction lies outside the focus are faked as Frees (see
   fakefree_badbounds), and an outer context term with the focus
   abstracted out is built for use in rewriting with RW_Inst.rw.
   Returns (faked focus term, (fake env, renamed bound env, outer term)). *)
fun prep_zipper_match z =
  let
    val zctxt = Zipper.ctxt z
    val (fakes, bounds, focus) =
      fakefree_badbounds (Zipper.C.nty_ctxt zctxt) (Zipper.trm z)
    val outer = mk_foo_match (Zipper.C.apply zctxt) bounds focus
  in (focus, (fakes, bounds, outer)) end;

(* Unification with exception handled *)
(* given context, max var index, pat, tgt; returns Seq of instantiations.
   Each element is a pair (type instantiations, term instantiations)
   obtained by flattening a unifier's environments with Vartab.dest.
   The result is Seq.empty if the types of pat and tgt do not unify. *)
fun clean_unify ctxt ix (a as (pat, tgt)) =
  let
    (* type info will be re-derived, maybe this can be cached
       for efficiency? *)
    val pat_ty = Term.type_of pat;
    val tgt_ty = Term.type_of tgt;
    (* FIXME is it OK to ignore the type instantiation info?
       or should I be using it? *)
    (* unify the two types first; a TUNIFY failure becomes NONE *)
    val typs_unify =
      SOME (Sign.typ_unify (Proof_Context.theory_of ctxt) (pat_ty, tgt_ty) (Vartab.empty, ix))
        handle Type.TUNIFY => NONE;
  in
    (case typs_unify of
      SOME (typinsttab, ix2) =>
        let
          (* FIXME is it right to throw away the flexes?
             or should I be using them somehow? *)
          (* flatten an environment into (type insts, term insts) lists *)
          fun mk_insts env =
            (Vartab.dest (Envir.type_env env),
             Vartab.dest (Envir.term_env env));
          (* start term unification from the type unifier found above *)
          val initenv =
            Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
          (* exceptions raised while setting up the unifier sequence
             yield an empty sequence *)
          val useq = Unify.smash_unifiers (Context.Proof ctxt) [a] initenv
            handle ListPair.UnequalLengths => Seq.empty
              | Term.TERM _ => Seq.empty;
          (* lazily pull unifiers one at a time; the same exceptions raised
             while forcing the next element end the sequence at that point *)
          fun clean_unify' useq () =
            (case (Seq.pull useq) of
               NONE => NONE
             | SOME (h, t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
            handle ListPair.UnequalLengths => NONE
              | Term.TERM _ => NONE;
        in
          (Seq.make (clean_unify' useq))
        end
    | NONE => Seq.empty)
  end;

(* Unification for zippers.
   Prepares the focus of z with prep_zipper_match, hands the pair
   (prepared focus, pat) to clean_unify, and tags every resulting
   instantiation with the fake environment, the renamed bound-variable
   environment and the outer term, giving a sequence of matches.
   Note: Ts is a modified version of the original names of the outer
   bound variables; new names are chosen by fakefree_badbounds so that
   they are distinct from each other. *)
fun clean_unify_z ctxt maxidx pat z =
  let
    val (focus, (FakeTs, Ts, absterm)) = prep_zipper_match z
    fun tag insts = (insts, FakeTs, Ts, absterm)
  in clean_unify ctxt maxidx (focus, pat) |> Seq.map tag end;


(* Leaf reached by repeatedly descending into the function part of an
   application and into the body of an abstraction. *)
fun bot_left_leaf_of t =
  (case t of
    f $ _ => bot_left_leaf_of f
  | Abs (_, _, body) => bot_left_leaf_of body
  | leaf => leaf);

(* A focus is a valid place to start matching unless its bottom-left leaf
   is a schematic variable: terms with a Var at the head always unify
   trivially and uninterestingly, so they are not considered. *)
fun valid_match_start z =
  let
    fun is_schematic (Var _) = true
      | is_schematic _ = false
  in not (is_schematic (bot_left_leaf_of (Zipper.trm z))) end;

(* Enumerate all positions under a zipper by delegating to
   ZipperSearch.all_bl_ur (described in the original as searching from
   the top, left to right, then down). *)
fun search_lr_all z = ZipperSearch.all_bl_ur z;

(* Lazy search that yields only positions accepted by validf.
   At an application the order is: left subterm, the node itself, right
   subterm. At an abstraction: the node itself, then its body. Any other
   term contributes only itself. *)
fun search_lr_valid validf =
  let
    fun visit z =
      let
        val self = if validf z then [Zipper.Here z] else []
        fun descend move = Zipper.LookIn (move z)
      in
        (case Zipper.trm z of
          _ $ _ =>
            descend Zipper.move_down_left :: self @ [descend Zipper.move_down_right]
        | Abs _ => self @ [descend Zipper.move_down_abs]
        | _ => self)
      end
  in Zipper.lzy_search visit end;

(* Lazy search from bottom to top, left to right, yielding only positions
   accepted by validf. Subterms are listed before the node itself:
   left then right for an application, the body for an abstraction. *)
fun search_bt_valid validf =
  let
    fun visit z =
      let
        val self = if validf z then [Zipper.Here z] else []
        fun descend move = Zipper.LookIn (move z)
      in
        (case Zipper.trm z of
          _ $ _ =>
            descend Zipper.move_down_left :: descend Zipper.move_down_right :: self
        | Abs _ => descend Zipper.move_down_abs :: self
        | _ => self)
      end
  in Zipper.lzy_search visit end;

(* Generic search-and-unify: enumerate positions under z with the search
   function f (applied through Zipper.limit_apply) and, for each position,
   produce the sequence of matches of lhs there. *)
fun searchf_unify_gen f (ctxt, maxidx, z) lhs =
  Zipper.limit_apply f z |> Seq.map (clean_unify_z ctxt maxidx lhs);

(* search all unifications, visiting every position given by search_lr_all *)
fun searchf_lr_unify_all sinfo lhs =
  searchf_unify_gen search_lr_all sinfo lhs;

(* search only for 'valid' unifiers: positions are enumerated with
   search_lr_valid and filtered by valid_match_start *)
fun searchf_lr_unify_valid sinfo lhs =
  searchf_unify_gen (search_lr_valid valid_match_start) sinfo lhs;

(* as searchf_lr_unify_valid, but positions are enumerated with search_bt_valid *)
fun searchf_bt_unify_valid sinfo lhs =
  searchf_unify_gen (search_bt_valid valid_match_start) sinfo lhs;

(* Apply a substitution in the conclusion of subgoal i of st.
   cfvs: certified free-variable placeholders for the goal parameters
   conclthm: a theorem for just the conclusion
   rule: the equation used for substitution
   m: instantiation/match information
   The conclusion theorem is rewritten with RW_Inst.rw, its placeholder
   frees are made schematic again, the result is beta-eta normalised, and
   it is then resolved against subgoal i. *)
fun apply_subst_in_concl ctxt i st (cfvs, conclthm) rule m =
  let
    val rewritten = RW_Inst.rw ctxt m rule conclthm
    val normalised =
      Conv.fconv_rule Drule.beta_eta_conversion (unfix_frees cfvs rewritten)
  in resolve_tac ctxt [normalised] i st end;

(* Prepare to substitute within the conclusion of goal i of gth using a
   meta equation rule. Note that we assume the rule has its var indices
   zero'd (the goal's own indexes are incremented by 1 here).
   Returns ((certified placeholder frees, trivial theorem of the
   conclusion), searchinfo focused inside that theorem's proposition). *)
fun prep_concl_subst ctxt i gth =
  let
    val th = Thm.incr_indexes 1 gth;
    val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i (Thm.prop_of th);
    val cfvs = rev (map (Thm.cterm_of ctxt) fvs);
    val conclthm =
      Thm.trivial (Thm.cterm_of ctxt (Logic.strip_imp_concl fixedbody));
    val ft =
      conclthm
      |> Thm.prop_of
      |> Zipper.mktop
      |> Zipper.move_down_left (* Trueprop *)
      |> Zipper.move_down_right; (* ==> *)
  in
    ((cfvs, conclthm), (ctxt, Thm.maxidx_of th, ft))
  end;

(* Substitute in the conclusion of subgoal i using an object or meta level
   equality. instepthm is first turned into a list of prepared rules; for
   each rule, searchf is asked for matches of the rule's left-hand side and
   every match is applied to the goal state. *)
fun eqsubst_tac' ctxt searchf instepthm i st =
  let
    val (concl_info, sinfo) = prep_concl_subst ctxt i st;
    val rules = Seq.of_list (prep_meta_eq ctxt instepthm);
    fun rewrite_with r =
      let val (lhs, _) = Logic.dest_equals (Thm.concl_of r)
      in Seq.maps (apply_subst_in_concl ctxt i st concl_info r) (searchf sinfo lhs) end;
  in Seq.maps rewrite_with rules end;


(* General substitution of multiple occurrences using one of
   the given theorems *)

(* Run the search srchf and advance to occurrence occ with skipto_skipseq.
   What remains is flattened into a single sequence; if the search ran out
   before reaching that occurrence, the result is empty. *)
fun skip_first_occs_search occ srchf sinfo lhs =
  let val skipped = skipto_skipseq occ (srchf sinfo lhs) in
    (case skipped of
      SkipSeq ss => Seq.flat ss
    | SkipMore _ => Seq.empty)
  end;

(* Substitute in the conclusion of subgoal i with any of thms.
   The "occs" argument is a list of integers indicating which occurrence,
   w.r.t. the search order, to rewrite. Backtracking will also find later
   occurrences, but all earlier ones are skipped. Thus you can use [0] to
   just find all rewrites.
   Occurrences are processed in descending order. Yields no result when
   the state has fewer than i subgoals. *)
fun eqsubst_tac ctxt occs thms i st =
  let
    val nprems = Thm.nprems_of st;
    fun subst_at occ st' =
      let
        (* shift the subgoal index by the change in the number of premises
           between the initial state and the current one *)
        val i' = i + (Thm.nprems_of st' - nprems);
        val searchf = skip_first_occs_search occ searchf_lr_unify_valid;
      in Seq.maps (fn r => eqsubst_tac' ctxt searchf r i' st') (Seq.of_list thms) end;
  in
    if i <= nprems then
      let val descending = Library.sort (rev_order o int_ord) occs
      in Seq.maps distinct_subgoals_tac (Seq.EVERY (map subst_at descending) st) end
    else Seq.empty
  end;


(* Apply a substitution inside assumption j of subgoal i, keeping the
   assumption in the same place. The second argument pairs the data
   prepared by prep_subst_in_asm with a match m. *)
fun apply_subst_in_asm ctxt i st rule ((cfvs, j, _, pth), m) =
  let
    (* put the premise first *)
    val rotated = Thm.rotate_rule (j - 1) i st;
    val pruned = Seq.hd (prune_params_tac ctxt (RW_Inst.rw ctxt m rule pth));
    val preelimrule =
      Conv.fconv_rule Drule.beta_eta_conversion (* normal form *)
        (unfix_frees cfvs (* unfix any global params *)
          (Thm.permute_prems 0 ~1 pruned)); (* put old asm first *)
    (* ~j because new asm starts at back, thus we subtract 1 *)
    val rotate_back = Thm.rotate_rule (~ j) (Thm.nprems_of rule + i);
  in
    Seq.map rotate_back (dresolve_tac ctxt [preelimrule] i rotated)
  end;


(* Prepare to substitute within the j'th premise of subgoal i of gth,
   using a meta-level equation. Note that we assume the rule has its var
   indices zero'd (the goal's own indexes are incremented by 1 here).
   Returns ((certified placeholder frees, j, number of premises of the
   assumption, trivial theorem of the assumption), searchinfo focused
   inside that theorem's proposition).
   Note the repetition of work done for each assumption, i.e. this can be
   made more efficient for search over multiple assumptions. *)
fun prep_subst_in_asm ctxt i gth j =
  let
    val th = Thm.incr_indexes 1 gth;
    val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i (Thm.prop_of th);
    val cfvs = rev (map (Thm.cterm_of ctxt) fvs);

    (* premises are numbered from 1 *)
    val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
    val asm_nprems = length (Logic.strip_imp_prems asmt);
    val pth = Thm.trivial (Thm.cterm_of ctxt asmt);

    val ft =
      pth
      |> Thm.prop_of
      |> Zipper.mktop
      |> Zipper.move_down_right; (* trueprop *)
  in
    ((cfvs, j, asm_nprems, pth), (ctxt, Thm.maxidx_of th, ft))
  end;

(* prepare subst in every possible assumption: one prep_subst_in_asm
   result for each premise index 1..n of subgoal i *)
fun prep_subst_in_asms ctxt i gth =
  let val n_asms = length (Logic.prems_of_goal (Thm.prop_of gth) i)
  in map (prep_subst_in_asm ctxt i gth) (Library.upto (1, n_asms)) end;


(* Substitute in an assumption of subgoal i using an object or meta level
   equality. The assumptions are searched in order. skipocc is the
   occurrence to start from, and the count carries over from one
   assumption to the next when an assumption has too few occurrences. *)
fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i st =
  let
    val asmpreps = prep_subst_in_asms ctxt i st;
    val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
    fun rewrite_with_thm r =
      let
        val (lhs, _) = Logic.dest_equals (Thm.concl_of r);
        fun occ_search occ asms =
          (case asms of
            [] => Seq.empty
          | (asminfo, searchinfo) :: rest =>
              (case searchf searchinfo occ lhs of
                (* not enough occurrences here: keep skipping in the next asm *)
                SkipMore remaining => occ_search remaining rest
              | SkipSeq ss =>
                  Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
                    (occ_search 1 rest))) (* find later substs also *)
      in
        Seq.maps (apply_subst_in_asm ctxt i st r) (occ_search skipocc asmpreps)
      end;
  in Seq.maps rewrite_with_thm stepthms end;


(* Run searchf and advance to occurrence occ, reporting through the
   skipseq result how many occurrences are still outstanding if the
   search ran out. *)
fun skip_first_asm_occs_search searchf sinfo occ lhs =
  searchf sinfo lhs |> skipto_skipseq occ;

(* Substitute in the assumptions of subgoal i with any of thms. occs lists
   the occurrences to rewrite; they are processed in descending order.
   Yields no result when the state has fewer than i subgoals. *)
fun eqsubst_asm_tac ctxt occs thms i st =
  let
    val nprems = Thm.nprems_of st;
    fun subst_at occ st' =
      let
        (* shift the subgoal index by the change in the number of premises
           between the initial state and the current one *)
        val i' = i + (Thm.nprems_of st' - nprems);
        val searchf = skip_first_asm_occs_search searchf_lr_unify_valid;
      in Seq.maps (fn r => eqsubst_asm_tac' ctxt searchf occ r i' st') (Seq.of_list thms) end;
  in
    if i <= nprems then
      let val descending = Library.sort (rev_order o int_ord) occs
      in Seq.maps distinct_subgoals_tac (Seq.EVERY (map subst_at descending) st) end
    else Seq.empty
  end;

(* combination method that takes a flag (true indicates that subst
   should be done to an assumption, false = apply to the conclusion of
   the goal) as well as the theorems to use.
   Two methods are registered, with identical argument syntax:
     - an optional "asm" mode flag,
     - an optional parenthesised list of occurrence numbers (default [0]),
     - the theorems to rewrite with.
   "sub" runs the chosen tactic as is; "subst" additionally wraps it in
   SIDE_CONDS (defined outside this file; presumably this is what
   discharges the typechecking side conditions -- confirm there). *)
val _ =
  Theory.setup
    (Method.setup \<^binding>\<open>sub\<close>
      (Scan.lift (Args.mode "asm" -- Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
        Attrib.thms >> (fn ((asm, occs), inthms) => fn ctxt =>
          SIMPLE_METHOD' ((if asm then eqsubst_asm_tac else eqsubst_tac) ctxt occs inthms)))
      "single-step substitution"
    #>
    (Method.setup \<^binding>\<open>subst\<close>
      (Scan.lift (Args.mode "asm" -- Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
        Attrib.thms >> (fn ((asm, occs), inthms) => fn ctxt =>
          SIMPLE_METHOD' (SIDE_CONDS
            ((if asm then eqsubst_asm_tac else eqsubst_tac) ctxt occs inthms)
            ctxt)))
      "single-step substitution with auto-typechecking"))

end;