aboutsummaryrefslogtreecommitdiff
path: root/mltt/core/goals.ML
blob: 23a6c28a0044bd0ad0b8bca3e275207ec2fb8c6c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
(*  Title:      goals.ML
    Author:     Joshua Chen

Goal statements and proof term export.

Modified from code contained in ~~/Pure/Isar/specification.ML.
*)

local

(*Lookahead parser for the start of a long statement: either an "includes"
  clause (whose value is discarded and replaced by "") or a keyword accepted
  by Parse_Spec.long_statement_keyword.*)
val long_keyword =
  (Parse_Spec.includes >> (fn _ => "")) || Parse_Spec.long_statement_keyword

(*Long-form statement, as parsed by Parse_Spec.long_statement (context
  elements plus a conclusion). It may be preceded by an optional theorem
  name with attributes (only taken if followed by long_keyword) and by an
  optional includes clause.
  Result: (true, binding, includes, elems, concl); the leading true marks
  the statement as long.*)
val long_statement =
  let
    val opt_name =
      Scan.optional (Parse_Spec.opt_thm_name ":" --| Scan.ahead long_keyword)
        Binding.empty_atts
    val opt_includes = Scan.optional Parse_Spec.includes []
    fun mk_long ((binding, includes), (elems, concl)) =
      (true, binding, includes, elems, concl)
  in
    opt_name -- opt_includes -- Parse_Spec.long_statement >> mk_long
  end

(*Short-form statement: propositions, an if-clause and for-fixes, as parsed
  by Parse_Spec.statement, Parse_Spec.if_statement and Parse.for_fixes.
  The fixes and assumptions become Fixes/Assumes context elements and the
  propositions a Shows conclusion. The statement is unnamed, has no
  includes, and the leading false marks it as not long.*)
val short_statement =
  let
    fun mk_short ((shows, assumes), fixes) =
      let val elems = [Element.Fixes fixes, Element.Assumes assumes]
      in (false, Binding.empty_atts, [], elems, Element.Shows shows) end
  in
    Parse_Spec.statement -- Parse_Spec.if_statement -- Parse.for_fixes >> mk_short
  end

(*Optional "where" clause. After the keyword, a statement is required
  (Parse.!!! makes the parse commit). If the clause is absent, the result
  is the empty list.*)
val where_statement =
  Scan.optional
    (Parse.$$$ "where" |-- Parse.!!! Parse_Spec.statement)
    []

(*Statement form used by the Definition command: propositions optionally
  followed by a where clause. The where-propositions become an Assumes
  element (with no fixes) and the main propositions a Shows conclusion.
  The statement is unnamed, has no includes, and is marked not long.*)
val def_statement =
  let
    fun mk_def (shows, assumes) =
      (false, Binding.empty_atts, [],
        [Element.Fixes [], Element.Assumes assumes], Element.Shows shows)
  in
    (Parse_Spec.statement -- where_statement) >> mk_def
  end

(*Prepare a goal statement for gen_schematic_theorem.

  prep_stmt elaborates raw_elems and raw_stmt in ctxt. It returns the
  prepared statement and the context extended by the elements. prep_att
  is applied to every attribute of the statement's bindings.

  Returns ((atts, prems, stmt, facts), goal_ctxt), where prems are the local
  premises introduced by the context elements (relative to ctxt):
    - Shows conclusion: atts = [], stmt is the prepared statement,
      facts = NONE, and goal_ctxt is the elements' context augmented with
      the statement's propositions.
    - Obtains conclusion: each case is added as an assumption carrying the
      Context_Rules.intro_query NONE attribute. These assumptions are noted
      as Auto_Bind.thatN. stmt is the single obtain thesis, facts = SOME
      that', and atts come from Obtain.obtains_attribs.*)
fun prep_statement prep_att prep_stmt raw_elems raw_stmt ctxt =
  let
    val (stmt, elems_ctxt) = prep_stmt raw_elems raw_stmt ctxt
    (*premises introduced by the context elements, relative to ctxt*)
    val prems = Assumption.local_prems_of elems_ctxt ctxt
    (*make the statement's propositions known to the context*)
    val stmt_ctxt =
      fold (fold (Proof_Context.augment o fst) o snd) stmt elems_ctxt
  in case raw_stmt
    of Element.Shows _ =>
        let val stmt' = Attrib.map_specs (map prep_att) stmt
        in (([], prems, stmt', NONE), stmt_ctxt) end
    | Element.Obtains raw_obtains =>
        let
          (*assume each obtains case, tagged as an intro rule*)
          val asms_ctxt = stmt_ctxt
            |> fold (fn ((name, _), asm) =>
                snd o Proof_Context.add_assms Assumption.assume_export
                  [((name, [Context_Rules.intro_query NONE]), asm)]) stmt
          val that = Assumption.local_prems_of asms_ctxt stmt_ctxt
          (*note the case assumptions as "that" in statement mode, then
            restore the statement mode of asms_ctxt*)
          val ([(_, that')], that_ctxt) = asms_ctxt
            |> Proof_Context.set_stmt true
            |> Proof_Context.note_thmss ""
                [((Binding.name Auto_Bind.thatN, []), [(that, [])])]
            ||> Proof_Context.restore_stmt asms_ctxt
          (*the goal to prove is the generic obtain thesis*)
          val stmt' =
            [(Binding.empty_atts, [(#2 (#1 (Obtain.obtain_thesis ctxt)), [])])]
        in
          ((Obtain.obtains_attribs raw_obtains, prems, stmt', SOME that'),
          that_ctxt)
        end
  end

(*Build a (possibly qualified) binding for a generated fact or constant.
  The base is the base name of local_name, falling back to name when that
  base name is empty. If suffix is SOME s, "_" ^ s is appended.*)
fun make_name_binding name suffix local_name =
  let
    val stem =
      (case Long_Name.base_name local_name of
        "" => name
      | base => base)
    val tail =
      (case suffix of
        NONE => ""
      | SOME s => "_" ^ s)
  in Binding.qualified_name (stem ^ tail) end

(*Define a proof-term constant for one noted fact of a schematic theorem.

  The argument (local_name, [th]) must consist of exactly one theorem.
  If the conclusion of th (after stripping its assumptions) is not a typing
  judgment according to Lib.is_typing, nothing is defined and ([], lthy)
  is returned. Otherwise the term t = Lib.term_of_typing concl is
  abstracted over:

    params  schematic variables occurring both in t and in the premises
            of th, sorted with Lib.lvl_order;
    levels  the remaining schematic variables of the abstracted term that
            satisfy Lib.is_lvl.

  This gives  %levels. %params. t,  with levels as the OUTER binders. The
  result is defined as a constant named via make_name_binding (the
  definition fact gets suffix "prf"). The definitional equation is applied
  to the abstracted variables in binder order and noted with suffix "def".
  The result is (noted thms, updated lthy).

  Raises ERROR (via error) if the fact does not consist of exactly one
  theorem.*)
fun define_proof_term name (local_name, [th]) lthy =
      let
        val prop = Thm.prop_of th
        val prems = Logic.strip_assums_hyp prop
        val concl = Logic.strip_assums_concl prop
      in
        if not (Lib.is_typing concl) then ([], lthy)
        else let
          val tm = Lib.term_of_typing concl

          val prems_vars = distinct Term.aconv
            (maps (Lib.collect_subterms is_Var) prems)

          val concl_vars = distinct Term.aconv (Lib.collect_subterms is_Var tm)

          val params = sort (uncurry Lib.lvl_order) (inter Term.aconv concl_vars prems_vars)

          val prf_tm = fold_rev lambda params tm

          (*params are already abstracted, so these are the leftover levels*)
          val levels = filter Lib.is_lvl (distinct Term.aconv (
            Lib.collect_subterms is_Var prf_tm))

          (*levels become the outermost binders: %levels. %params. tm*)
          val prf_tm' = fold_rev lambda levels prf_tm

          val ((_, (_, raw_def)), lthy') = Local_Theory.define
            ((make_name_binding name NONE local_name, Mixfix.NoSyn),
            ((make_name_binding name (SOME "prf") local_name, []), prf_tm')) lthy

          (*Apply both sides of the definition to the abstracted variables.
            Each combination step consumes the outermost remaining binder,
            so levels must come before params. Applying params first makes
            Thm.combination fail on mismatched types, or pairs arguments
            with the wrong binders, whenever both lists are non-empty.*)
          val def = fold
            (fn arg => fn eq => Thm.combination eq arg)
            (map (Thm.reflexive o Thm.cterm_of lthy) (levels @ params))
            raw_def

          val ((_, def'), lthy'') = Local_Theory.note
            ((make_name_binding name (SOME "def") local_name, []), [def])
            lthy'
        in
          (def', lthy'')
        end
      end
  | define_proof_term _ _ _ = error
      ("Can't generate proof terms for multiple facts in one statement")

(*Generic schematic-theorem command, adapted from the theorem command code
  in Pure/Isar/specification.ML.

  bundle_includes   activates raw_includes in lthy
  prep_att          checks attributes (applied as prep_att lthy)
  prep_stmt         elaborates elements and conclusion (see prep_statement)
  gen_prf_tm        if true, define proof-term constants for the results
                    (define_proof_term), fold them into the results, and
                    note the combined fact with the type attribute
  long              long statement form; selects the premises name
                    (Auto_Bind.assmsN vs. Auto_Bind.thatN) and the display
  kind              theorem kind used when noting and printing facts
  defn              if true, the combined fact is noted as <name>_type and
                    the proof-term definitions are printed instead
  before_qed, after_qed
                    passed to Proof.theorem / run on the exported results
  (name, raw_atts)  binding and attributes of the combined fact
  do_print          passed to Proof_Display.print_results

  Returns the proof state opened by Proof.theorem.*)
fun gen_schematic_theorem
  bundle_includes prep_att prep_stmt
  gen_prf_tm long kind defn
  before_qed after_qed
  (name, raw_atts) raw_includes raw_elems raw_concl
  do_print lthy
  =
  let
    val _ = Local_Theory.assert lthy
    val elems = raw_elems |> map (Element.map_ctxt_attrib (prep_att lthy))
    val ((more_atts, prems, stmt, facts), goal_ctxt) = lthy
      |> bundle_includes raw_includes
      |> prep_statement (prep_att lthy) prep_stmt elems raw_concl
    val atts = more_atts @ map (prep_att lthy) raw_atts
    val pos = Position.thread_data ()
    val prems_name = if long then Auto_Bind.assmsN else Auto_Bind.thatN

    fun gen_and_after_qed results goal_ctxt' =
      let
        (*export the proved results back to lthy and normalize them*)
        val results' = burrow
          (map (Goal.norm_result lthy) o Proof_Context.export goal_ctxt' lthy)
          results

        (*note each substatement under its own binding, unless none is
          named; substmts records whether any were noted*)
        val ((res, lthy'), substmts) =
          if forall (Binding.is_empty_atts o fst) stmt
          then ((map (pair "") results', lthy), false)
          else
            (Local_Theory.notes_kind kind
              (map2 (fn (b, _) => fn ths => (b, [(ths, [])])) stmt results')
              lthy,
            true)

        (*note the combined fact under name; with gen_prf_tm, first define
          proof terms for every result and fold those definitions in*)
        val ((name_def, defs), (res', lthy'')) =
          if gen_prf_tm
          then
            let
              val (prf_tm_defs, new_lthy) = fold
                (fn result => fn (defs, lthy) =>
                  apfst (fn new_defs => defs @ new_defs)
                    (define_proof_term (Binding.name_of name) result lthy))
                res
                ([], lthy')

              val res_folded =
                map (apsnd (map (Local_Defs.fold new_lthy prf_tm_defs))) res

              val name_def =
                make_name_binding (Binding.name_of name) (SOME "def") (#1 (hd res_folded))

              val name_type =
                if defn then
                  make_name_binding (Binding.name_of name) (SOME "type") (#1 (hd res_folded))
                else name
            in
              ((name_def, prf_tm_defs),
              Local_Theory.notes_kind kind
                [((name_type, @{attributes [type]} @ atts),
                  [(maps #2 res_folded, [])])]
                new_lthy)
            end
          else
            ((Binding.empty, []),
            Local_Theory.notes_kind kind
              [((name, atts), [(maps #2 res, [])])]
              lthy')

        (*Display theorems*)
        val _ =
          if defn then
            single (Proof_Display.print_results do_print pos lthy''
              ((kind, Binding.name_of name_def), [("", defs)]))
          else if not long andalso not substmts then
            single (Proof_Display.print_results do_print pos lthy''
              ((kind, Binding.name_of name), map (fn (_, ths) => ("", ths)) res'))
          else
            (if long then
              Proof_Display.print_results do_print pos lthy''
                ((kind, Binding.name_of name), map (fn (_, ths) => ("", ths)) res')
            else ();
            if substmts then
              map (fn (name, ths) =>
                    Proof_Display.print_results do_print pos lthy''
                      ((kind, name), [("", ths)]))
                  res
            else [])
      in
        after_qed results' lthy''
      end
  in
    (*record the premises (if any) as assms/that and pass them to
      Context_Facts.register_facts, then open the proof; for obtains
      statements, the "that" facts are inserted into the goal*)
    goal_ctxt
    |> not (null prems) ?
      (Proof_Context.note_thmss "" [((Binding.name prems_name, []), [(prems, [])])]
        #> snd #> Context_Facts.register_facts prems)
    |> Proof.theorem before_qed gen_and_after_qed (map snd stmt)
    |> (case facts of NONE => I | SOME ths => Proof.refine_insert ths)
  end

(*gen_schematic_theorem instantiated for the outer-syntax commands. Raw
  includes, attribute sources and goal statements are processed by
  Bundle.includes_cmd, Attrib.check_src and
  Elaborated_Statement.read_goal_statement respectively.*)
fun schematic_theorem_cmd gen_prf_tm =
  gen_schematic_theorem Bundle.includes_cmd Attrib.check_src
    Elaborated_Statement.read_goal_statement gen_prf_tm

(*Register a theorem-like command spec, with description "state " ^ descr
  and theorem kind descr. An optional "(def)" flag enables proof-term
  generation. Both the long and the short statement forms are accepted.
  The result is never treated as a definition (defn = false).*)
fun theorem spec descr =
  let
    val def_flag = Scan.option (Args.parens (Args.$$$ "def"))
    fun is_def (SOME "def") = true
      | is_def _ = false
    fun mk_cmd (opt_derive, (long, binding, includes, elems, concl)) =
      schematic_theorem_cmd (is_def opt_derive)
        long descr false NONE (K I) binding includes elems concl
  in
    Outer_Syntax.local_theory_to_proof' spec ("state " ^ descr)
      (def_flag -- (long_statement || short_statement) >> mk_cmd)
  end

(*Register the definition command spec. Its statement has the
  def_statement form. Proof terms are always generated (gen_prf_tm = true)
  and the result is treated as a definition (defn = true). descr is used as
  the theorem kind.*)
fun definition spec descr =
  let
    fun mk_cmd (long, binding, includes, elems, concl) =
      schematic_theorem_cmd true
        long descr true NONE (K I) binding includes elems concl
  in
    Outer_Syntax.local_theory_to_proof' spec
      "definition with explicit type checking obligation"
      (def_statement >> mk_cmd)
  end

in

(*Register the theorem-like commands (in order), then Definition.*)
val _ =
  List.app (fn (keyword, descr) => theorem keyword descr)
    [(\<^command_keyword>\<open>Theorem\<close>, "Theorem"),
     (\<^command_keyword>\<open>Lemma\<close>, "Lemma"),
     (\<^command_keyword>\<open>Corollary\<close>, "Corollary"),
     (\<^command_keyword>\<open>Proposition\<close>, "Proposition")]

val _ = definition \<^command_keyword>\<open>Definition\<close> "Definition"

end