1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
|
structure divDefLib :> divDefLib =
struct

open primitivesBaseTacLib primitivesLib divDefTheory

(* Set to ‘true’ to print debugging information *)
val dbg = ref false

(* Print ‘s’ if debugging is enabled *)
fun print_dbg s = if (!dbg) then print s else ()

(* Lazy variant of ‘print_dbg’: the message is only built when debugging is
   enabled. SML is strict, so the argument given to ‘print_dbg’ is always
   computed, even when nothing gets printed: use this variant whenever building
   the message is expensive (e.g., when it requires pretty-printing terms). *)
fun print_dbg_fn (msg : unit -> string) = if (!dbg) then print (msg ()) else ()

val result_ty = “:'a result”
val error_ty = “:error”
val alpha_ty = “:'a”
val num_ty = “:num”

val zero_num_tm = “0:num”
val suc_tm = “SUC”

val return_tm = “Return : 'a -> 'a result”
val fail_tm = “Fail : error -> 'a result”
val fail_failure_tm = “Fail Failure : 'a result”
val diverge_tm = “Diverge : 'a result”

(* Switch between ‘fix_exec’ (leading to executable definitions) and ‘fix’
   (non-executable definitions) *)
val use_fix_exec = ref true
val fix_tm = “fix”
val fix_exec_tm = “fix_exec”

val is_valid_fp_body_tm = “is_valid_fp_body”

(* Build the type ‘:ty result’ *)
fun mk_result (ty : hol_type) : hol_type = Type.type_subst [ alpha_ty |-> ty ] result_ty

(* Retrieve ‘ty’ from ‘:ty result’ (fails if the input is not a ‘result’ type) *)
fun dest_result (ty : hol_type) : hol_type =
  let
    val {Args=out_ty, Thy=thy, Tyop=tyop} = dest_thy_type ty
  in
    if thy = "primitives" andalso tyop = "result" then hd out_ty
    else failwith "dest_result: not a result"
  end

fun mk_return (x : term) : term = mk_icomb (return_tm, x)
fun mk_fail (ty : hol_type) (e : term) : term = mk_comb (inst [ alpha_ty |-> ty ] fail_tm, e)
fun mk_fail_failure (ty : hol_type) : term = inst [ alpha_ty |-> ty ] fail_failure_tm
fun mk_diverge (ty : hol_type) : term = inst [ alpha_ty |-> ty ] diverge_tm
fun mk_suc (n : term) = mk_comb (suc_tm, n)

(* Pair every element of a list with its index: [(0, x0), (1, x1), ...] *)
fun enumerate (ls : 'a list) : (int * 'a) list =
  zip (List.tabulate (List.length ls, fn i => i)) ls

(* Flatten a list of pairs: [(x0, y0), (x1, y1), ...] ~~> [x0, y0, x1, y1, ...].
   We use it to compute the list of the variants of the "parameter type" from
   the list of (input type, output type) pairs of the functions in a group. *)
fun flatten_pairs (ls : ('a * 'a) list) : 'a list =
  List.concat (map (fn (x, y) => [x, y]) ls)

(*=============================================================================*
 *
 * Generate the (non-recursive) body to give to the fixed-point operator
 *
 * ============================================================================*)

(* Small helper to generate wrappers of the shape: ‘INL x’, ‘INR (INL x)’, etc.

   Note that we should have: ‘length before_tys + 1 + length after_tys >= 2’
   (otherwise this function fails).

   Ex.:
   ====
   The enumeration has type: “: 'a + 'b + 'c + 'd”.
   We want to generate the variant which injects “x:'c” into this enumeration.

   We need to split the list of types into:
   {[
     before_tys = [“:'a”, “:'b”]
     tm = “x: 'c”
     after_tys = [“:'d”]
   ]}

   The function will generate:
   {[
     INR (INR (INL x)) : 'a + 'b + 'c + 'd
   ]}

   (* Debug *)
   val before_tys = [“:'a”, “:'b”, “:'c”]
   val tm = “x:'d”
   val after_tys = [“:'e”, “:'f”]

   val before_tys = [“:'a”, “:'b”, “:'c”]
   val tm = “x:'d”
   val after_tys = []

   list_mk_inl_inr before_tys tm after_tys
 *)
fun list_mk_inl_inr (before_tys : hol_type list) (tm : term) (after_tys : hol_type list) :
  term =
  let
    val _ =
      if null before_tys andalso null after_tys
      then failwith "list_mk_inl_inr: the sum type must have at least two variants"
      else ()
    val (before_tys, pat) =
      if null after_tys
      then
        (* ‘tm’ goes into the last variant: the innermost wrapper is an ‘INR’ *)
        let
          val just_before_ty = List.last before_tys
          val before_tys = List.take (before_tys, List.length before_tys - 1)
          val pat = sumSyntax.mk_inr (tm, just_before_ty)
        in
          (before_tys, pat)
        end
      else
        (* ‘tm’ does not go into the last variant: the innermost wrapper is an ‘INL’ *)
        (before_tys, sumSyntax.mk_inl (tm, sumSyntax.list_mk_sum after_tys))
    (* Add one ‘INR’ per remaining variant preceding the one of ‘tm’ *)
    val pat = foldr (fn (ty, pat) => sumSyntax.mk_inr (pat, ty)) pat before_tys
  in
    pat
  end

(* This function wraps a term into the proper variant of the input/output
   sum.

   Ex.:
   ====
   For the input of the first function, we generate: ‘INL x’
   For the output of the first function, we generate: ‘INR (INL x)’ (provided
   there are at least two functions in the group)
   For the input of the 2nd function, we generate: ‘INR (INR (INL x))’
   etc.

   If ‘is_input’ is true: we are wrapping an input. Otherwise we are wrapping
   an output.

   (* Debug *)
   val tys = [(“:'a”, “:'b”), (“:'c”, “:'d”), (“:'e”, “:'f”)]
   val j = 1
   val tm = “x:'c”
   val tm = “y:'d”
   val is_input = true
 *)
fun inject_in_param_sum (tys : (hol_type * hol_type) list) (j : int) (is_input : bool)
  (tm : term) : term =
  let
    val before_tys = flatten_pairs (List.take (tys, j))
    val (input_ty, output_ty) = List.nth (tys, j)
    val after_tys = flatten_pairs (List.drop (tys, j + 1))
    val (before_tys, after_tys) =
      if is_input then (before_tys, output_ty :: after_tys)
      else (before_tys @ [input_ty], after_tys)
  in
    list_mk_inl_inr before_tys tm after_tys
  end

(* Remark: the order of the branches when creating matches is important.
   For instance, in the case of ‘result’ it must be: ‘Return’, ‘Fail’, ‘Diverge’.
   For the purpose of stability and maintainability, we introduce this small helper
   which reorders the cases in a pattern before actually creating the case
   expression.

   Fails if no pattern is given for one of the constructors of the type of
   the scrutinee.
 *)
fun unordered_mk_case (scrut: term, pats: (term * term) list) : term =
  let
    (* Retrieve the constructors *)
    val cl = TypeBase.constructors_of (type_of scrut)
    (* Retrieve the names of the constructors *)
    val names = map (fst o dest_const) cl
    (* Use those to reorder the patterns *)
    fun is_pat (name : string) (pat, _) =
      let
        val app = (fst o strip_comb) pat
        val app_name = (fst o dest_const) app
      in
        app_name = name
      end
    fun find_pat (name : string) =
      case List.find (is_pat name) pats of
        SOME p => p
      | NONE => failwith ("unordered_mk_case: no pattern for the constructor " ^ name)
    val pats = map find_pat names
  in
    (* Create the case *)
    TypeBase.mk_case (scrut, pats)
  end

(* Wrap a term of type “:'a result” into a ‘case of’ which matches over
   the result.

   Ex.:
   ====
   {[
     f x

       ~~>

     case f x of
     | Fail e => Fail e
     | Diverge => Diverge
     | Return y => ... (* The branch content is generated by the continuation *)
   ]}

   ‘gen_ret_branch’ is a *continuation* which generates the content of the
   ‘Return’ branch (i.e., the content of the ‘...’ in the example above).
   It receives as input the value contained by the ‘Return’ (i.e., the variable
   ‘y’ in the example above).

   Remark: the term generated by ‘gen_ret_branch’ must have a ‘result’ type,
   but it can change the content of the result (i.e., if ‘scrut’ has type
   ‘:'a result’, we can change the type of the wrapped expression to
   ‘:'b result’).

   (* Debug *)
   val scrut = “x: int result”
   fun gen_ret_branch x = mk_return x

   val scrut = “x: int result”
   fun gen_ret_branch _ = “Return T”

   mk_result_case scrut gen_ret_branch
 *)
fun mk_result_case (scrut : term) (gen_ret_branch : term -> term) : term =
  let
    val scrut_ty = dest_result (type_of scrut)
    (* Return branch *)
    val ret_var = genvar scrut_ty
    val ret_pat = mk_return ret_var
    val ret_br = gen_ret_branch ret_var
    val ret_ty = dest_result (type_of ret_br)
    (* Failure branch *)
    val fail_var = genvar error_ty
    val fail_pat = mk_fail scrut_ty fail_var
    val fail_br = mk_fail ret_ty fail_var
    (* Diverge branch *)
    val div_pat = mk_diverge scrut_ty
    val div_br = mk_diverge ret_ty
  in
    unordered_mk_case (scrut, [(ret_pat, ret_br), (fail_pat, fail_br), (div_pat, div_br)])
  end

(* Generate a ‘case ... of’ over a sum type.

   Ex.:
   ====
   If the scrutinee is: “x : 'a + 'b + 'c + 'd” (i.e., the tys list is:
   [“:'a”, “:'b”, “:'c”, “:'d”]), we generate:
   {[
     case x of
     | INL y0 => ... (* Branch of index 0 *)
     | INR (INL y1) => ... (* Branch of index 1 *)
     | INR (INR (INL y2)) => ... (* Branch of index 2 *)
     | INR (INR (INR y3)) => ... (* Branch of index 3 *)
   ]}

   The content of the branches is generated by the ‘gen_branch’ continuation,
   which receives as input the index of the branch as well as the variable
   introduced by the pattern (in the example above: ‘y0’ for the branch 0,
   ‘y1’ for the branch 1, etc.)

   (* Debug *)
   val tys = [“:'a”, “:'b”]
   val scrut = mk_var ("x", sumSyntax.list_mk_sum tys)
   fun gen_branch i (x : term) = “F”

   val tys = [“:'a”, “:'b”, “:'c”, “:'d”]
   val scrut = mk_var ("x", sumSyntax.list_mk_sum tys)
   fun gen_branch i (x : term) = if type_of x = “:'c” then mk_return x else mk_fail_failure “:'c”

   list_mk_sum_case scrut tys gen_branch
 *)
(* For debugging *)
val list_mk_sum_case_case = ref (“T”, [] : (term * term) list)
(*
val (scrut, [(pat1, br1), (pat2, br2)]) = !list_mk_sum_case_case
*)
fun list_mk_sum_case (scrut : term) (tys : hol_type list)
  (gen_branch : int -> term -> term) : term =
  let
    (* Create the cases. Note that without sugar, the match actually looks like this:
       {[
         case x of
         | INL y0 => ... (* Branch of index 0 *)
         | INR x1 =>
           case x1 of
           | INL y1 => ... (* Branch of index 1 *)
           | INR x2 =>
             case x2 of
             | INL y2 => ... (* Branch of index 2 *)
             | INR y3 => ... (* Branch of index 3 *)
       ]}
     *)
    fun create_case (j : int) (scrut : term) (tys : hol_type list) : term =
      let
        val _ = print_dbg_fn (fn () => "list_mk_sum_case: " ^
          String.concatWith ", " (map type_to_string tys) ^ "\n")
      in
        case tys of
          [] => failwith "tys is too short"
        | [ _ ] =>
          (* Last element: no match to perform *)
          gen_branch j scrut
        | ty1 :: tys =>
          (* Not last: we create a pattern:
             {[
               case scrut of
               | INL pat_var1 => ... (* Branch of index j *)
               | INR pat_var2 =>
                 ... (* Generate this term recursively *)
             ]}
           *)
          let
            (* INL branch *)
            val after_ty = sumSyntax.list_mk_sum tys
            val pat_var1 = genvar ty1
            val pat1 = sumSyntax.mk_inl (pat_var1, after_ty)
            val br1 = gen_branch j pat_var1
            (* INR branch *)
            val pat_var2 = genvar after_ty
            val pat2 = sumSyntax.mk_inr (pat_var2, ty1)
            val br2 = create_case (j+1) pat_var2 tys
            val _ = print_dbg_fn (fn () => "list_mk_sum_case: assembling:\n" ^
              term_to_string scrut ^ ",\n" ^
              "[(" ^ term_to_string pat1 ^ ",\n " ^ term_to_string br1 ^ "),\n\n" ^
              " (" ^ term_to_string pat2 ^ ",\n " ^ term_to_string br2 ^ ")]\n\n")
            val case_elems = (scrut, [(pat1, br1), (pat2, br2)])
            val _ = list_mk_sum_case_case := case_elems
          in
            (* Put everything together *)
            TypeBase.mk_case case_elems
          end
      end
  in
    create_case 0 scrut tys
  end

(* Generate a ‘case ... of’ to select the input/output of the ith variant of
   the param enumeration.

   Ex.:
   ====
   There are two functions in the group, and we select the input of the function of index 1:
   {[
     case x of
     | INL _ => Fail Failure (* Input of function of index 0 *)
     | INR (INL _) => Fail Failure (* Output of function of index 0 *)
     | INR (INR (INL y)) => Return y (* Input of the function of index 1: select this one *)
     | INR (INR (INR _)) => Fail Failure (* Output of the function of index 1 *)
   ]}

   (* Debug *)
   val tys = [(“:'a”, “:'b”)]
   val scrut = “x : 'a + 'b”
   val fi = 0
   val is_input = true

   val tys = [(“:'a”, “:'b”), (“:'c”, “:'d”)]
   val scrut = “x : 'a + 'b + 'c + 'd”
   val fi = 1
   val is_input = false

   val scrut = mk_var ("x", sumSyntax.list_mk_sum (flatten_pairs tys))
   list_mk_case_sum_select scrut tys fi is_input
 *)
fun list_mk_case_sum_select (scrut : term) (tys : (hol_type * hol_type) list)
  (fi : int) (is_input : bool) : term =
  let
    (* The index of the element in the enumeration that we will select *)
    val i = 2 * fi + (if is_input then 0 else 1)
    (* Flatten the types *)
    val tys = flatten_pairs tys
    (* Get the return type *)
    val ret_ty = List.nth (tys, i)
    (* The continuation which will generate the content of the branches *)
    fun gen_branch j var = if j = i then mk_return var else mk_fail_failure ret_ty
  in
    (* Generate the ‘case ... of’ *)
    list_mk_sum_case scrut tys gen_branch
  end

(* Same as ‘list_mk_case_sum_select’, but the scrutinee has a ‘result’ type:
   we first match over the result, then over the param enumeration.

   Ex.:
   ====
   There are two functions in the group, and we select the input of the function of index 1:
   {[
     case x of
     | Fail e => Fail e
     | Diverge => Diverge
     | Return r =>
       case r of
       | INL _ => Fail Failure (* Input of function of index 0 *)
       | INR (INL _) => Fail Failure (* Output of function of index 0 *)
       | INR (INR (INL y)) => Return y (* Input of the function of index 1: select this one *)
       | INR (INR (INR _)) => Fail Failure (* Output of the function of index 1 *)
   ]}
 *)
fun mk_case_select_result_sum (scrut : term) (tys : (hol_type * hol_type) list)
  (fi : int) (is_input : bool) : term =
  (* We match over the result, then over the enumeration *)
  mk_result_case scrut (fn x => list_mk_case_sum_select x tys fi is_input)

(* Generate a body for the fixed-point operator from a quoted group of mutually
   recursive definitions.

   From the quoted equations for ‘nth’ (or for [‘even’, ‘odd’]) we generate
   the body ‘nth_body’ (or ‘even_odd_body’, respectively).
   TODO: point to detailed explanations.
 *)
fun mk_body (fnames : string list) (in_out_tys : (hol_type * hol_type) list)
  (def_tms : term list) : term =
  let
    val fnames_set = Redblackset.fromList String.compare fnames
    (* Compute a map from function name to function index *)
    val fnames_map = Redblackmap.fromList String.compare
      (map (fn (x, y) => (y, x)) (enumerate fnames))
    (* Compute the input/output type, that we dub the "parameter type",
       as well as the list of its variants *)
    val param_tys = flatten_pairs in_out_tys
    val param_ty = sumSyntax.list_mk_sum param_tys
    (* Introduce a variable for the continuation *)
    val fcont = genvar (param_ty --> mk_result param_ty)

    (* In the function equations, replace all the recursive calls with calls to the continuation.

       When replacing a recursive call, we have to do two things:
       - we need to inject the input parameters into the parameter type
         Ex.:
         - ‘nth tl i’ becomes ‘f (INL (tl, i))’ where ‘f’ is the continuation
         - ‘even i’ becomes ‘f (INL i)’ where ‘f’ is the continuation
       - we need to wrap the call to the continuation into a ‘case ... of’
         to extract its output (we need to make sure that the transformation
         preserves the type of the expression!)
         Ex.: ‘nth tl i’ becomes:
         {[
           case f (INL (tl, i)) of
           | Fail e => Fail e
           | Diverge => Diverge
           | Return r =>
             case r of
             | INL _ => Fail Failure
             | INR x => Return x
         ]}
     *)
    (* For debugging *)
    val replace_rec_calls_rec_call_tm = ref “T”
    fun replace_rec_calls (fnames_set : string Redblackset.set) (tm : term) : term =
      let
        val _ = print_dbg_fn (fn () => "replace_rec_calls: original expression:\n" ^
          term_to_string tm ^ "\n\n")
        val ntm =
          case dest_term tm of
            VAR (name, _) =>
            (* Check that this is not one of the functions in the group - remark:
               we could handle that by introducing lambdas.
             *)
            if Redblackset.member (fnames_set, name)
            then failwith ("mk_body: not well-formed definition: found " ^ name ^
                           " in an improper position")
            else tm
          | CONST _ => tm
          | LAMB (x, body) =>
            let
              (* The variable might shadow one of the functions: remove it from
                 the set of function names - remark: Redblackset.delete raises
                 [NotFound] if the value is not present in the set *)
              val varname = (fst o dest_var) x
              val fnames_set =
                if Redblackset.member (fnames_set, varname)
                then Redblackset.delete (fnames_set, varname)
                else fnames_set
              (* Update the term in the lambda *)
              val body = replace_rec_calls fnames_set body
            in
              (* Reconstruct *)
              mk_abs (x, body)
            end
          | COMB (_, _) =>
            let
              (* Completely destruct the application, check if this is a recursive call *)
              val (app, args) = strip_comb tm
              val is_rec_call = Redblackset.member (fnames_set, (fst o dest_var) app)
                handle HOL_ERR _ => false
              (* Whatever the case, apply the transformation to all the inputs *)
              val args = map (replace_rec_calls fnames_set) args
            in
              (* If this is not a recursive call: apply the transformation to all the
                 terms. Otherwise, replace. *)
              if not is_rec_call then list_mk_comb (replace_rec_calls fnames_set app, args)
              else
                (* Rec call: replace *)
                let
                  val _ = print_dbg ("replace_rec_calls: rec call\n\n")
                  val _ = replace_rec_calls_rec_call_tm := tm
                  (* First, find the index of the function *)
                  val fname = (fst o dest_var) app
                  val fi = Redblackmap.find (fnames_map, fname)
                  (* Inject the input values into the param type *)
                  val input = pairSyntax.list_mk_pair args
                  val input = inject_in_param_sum in_out_tys fi true input
                  (* Create the recursive call *)
                  val call = mk_comb (fcont, input)
                  (* Wrap the call into a ‘case ... of’ to extract the output *)
                  val call = mk_case_select_result_sum call in_out_tys fi false
                in
                  (* Return *)
                  call
                end
            end
        val _ = print_dbg_fn (fn () => "replace_rec_calls: new expression:\n" ^
          term_to_string ntm ^ "\n\n")
      in
        ntm
      end
      handle HOL_ERR e =>
        let
          val _ = print_dbg_fn (fn () => "replace_rec_calls: failed on:\n" ^
            term_to_string tm ^ "\n\n")
        in
          raise (HOL_ERR e)
        end
    fun replace_rec_calls_in_eq (eq : term) : term =
      let
        val (l, r) = dest_eq eq
      in
        mk_eq (l, replace_rec_calls fnames_set r)
      end
    val def_tms_with_fcont = map replace_rec_calls_in_eq def_tms

    (* Wrap all the function bodies to inject their result into the param type.

       We collect the function inputs at the same time, because they will be
       grouped into a tuple that we will have to deconstruct.
     *)
    fun inject_body_to_enums (i : int, def_eq : term) : term list * term =
      let
        val (l, body) = dest_eq def_eq
        val (_, args) = strip_comb l
        (* We have to deconstruct the result, then, in the ‘Return’ branch,
           properly wrap the returned value *)
        val body = mk_result_case body (fn x => mk_return (inject_in_param_sum in_out_tys i false x))
      in
        (args, body)
      end
    val def_tms_inject = map inject_body_to_enums (enumerate def_tms_with_fcont)

    (* Group the body inputs into a single (tupled) input.

       For instance, if the body has inputs: ‘x’, ‘y’; we return the following:
       {[
         (‘z’, ‘case z of (x, y) => ... (* body *) ’)
       ]}
       where ‘z’ is fresh.

       We return: (tupled input, body).

       (* Debug *)
       val body = “(x:'a, y:'b, z:'c)”
       val args = [“x:'a”, “y:'b”, “z:'c”]
       currify_body_inputs (args, body)
     *)
    fun currify_body_inputs (args : term list, body : term) : term * term =
      let
        fun mk_curry (args : term list) (body : term) : term * term =
          case args of
            [] => failwith "no inputs"
          | [x] => (x, body)
          | x1 :: args =>
            let
              val (x2, body) = mk_curry args body
              val scrut = genvar (pairSyntax.list_mk_prod (map type_of (x1 :: args)))
              val pat = pairSyntax.mk_pair (x1, x2)
              val br = body
            in
              (scrut, TypeBase.mk_case (scrut, [(pat, br)]))
            end
      in
        mk_curry args body
      end
    val def_tms_currified = map currify_body_inputs def_tms_inject

    (* Group all the functions into a single body, with an outer ‘case .. of’
       which selects the appropriate body depending on the input *)
    val input = genvar param_ty
    fun mk_mut_rec_body_branch (i : int) (patvar : term) : term =
      (* Case disjunction on whether the branch is for an input value (in
         which case we call the proper body) or an output value (in which
         case we return ‘Fail ...’) *)
      if i mod 2 = 0 then
        let
          val fi = i div 2
          val (x, def_tm) = List.nth (def_tms_currified, fi)
          (* The variable in the pattern and the variable expected by the
             body may not be the same: we introduce a let binding *)
          val def_tm = mk_let (mk_abs (x, def_tm), patvar)
        in
          def_tm
        end
      else
        (* Output value: fail *)
        mk_fail_failure param_ty
    val mut_rec_body = list_mk_sum_case input param_tys mk_mut_rec_body_branch

    (* Abstract away the parameters to produce the final body of the fixed point *)
    val mut_rec_body = list_mk_abs ([fcont, input], mut_rec_body)
  in
    mut_rec_body
  end

(*=============================================================================*
 *
 * Prove that the body satisfies the validity condition
 *
 * ============================================================================*)

(* Tactic to prove that a body is valid: perform one step. *)
fun prove_body_is_valid_tac_step (asms, g) =
  let
    (* The goal has the shape:
       {[
         (∀g h. ... g x = ... h x) ∨
         ∃h y. is_valid_fp_body n h ∧ ∀g. ... g x = ... od
       ]}
     *)
    (* Retrieve the scrutinee in the goal (‘x’).
       There are two cases:
       - either the function has the shape:
         {[
           (λ(y,z). ...) x
         ]}
         in which case we need to destruct ‘x’
       - or we have a normal ‘case ... of’
     *)
    val body = (lhs o snd o strip_forall o fst o dest_disj) g
    val scrut =
      let
        val (app, x) = dest_comb body
        val (app, _) = dest_comb app
        val {Name=name, Thy=thy, Ty = _ } = dest_thy_const app
      in
        if thy = "pair" andalso name = "UNCURRY" then x else failwith "not a curried argument"
      end
      handle HOL_ERR _ => strip_all_cases_get_scrutinee body
    (* Retrieve the first quantified continuation from the goal (‘g’) *)
    val qc = (hd o fst o strip_forall o fst o dest_disj) g
    (* Check if the scrutinee is a recursive call *)
    val (scrut_app, _) = strip_comb scrut
    val _ = print_dbg_fn (fn () => "prove_body_is_valid_step: Scrutinee: " ^
      term_to_string scrut ^ "\n")
    (* For the recursive calls: *)
    fun step_rec () =
      let
        val _ = print_dbg ("prove_body_is_valid_step: rec call\n")
        (* We need to instantiate the ‘h’ existentially quantified function *)
        (* First, retrieve the body of the function: it is given by the ‘Return’ branch *)
        val (_, _, branches) = TypeBase.dest_case body
        (* Find the branch corresponding to the return *)
        val ret_branch = List.find (fn (pat, _) =>
          let
            val {Name=name, Thy=thy, Ty = _ } = (dest_thy_const o fst o strip_comb) pat
          in
            thy = "primitives" andalso name = "Return"
          end) branches
        val var = (hd o snd o strip_comb o fst o valOf) ret_branch
        val br = (snd o valOf) ret_branch
        (* Abstract away the input variable introduced by the pattern and the continuation ‘g’ *)
        val h = list_mk_abs ([qc, var], br)
        val _ = print_dbg_fn (fn () => "prove_body_is_valid_step: h: " ^ term_to_string h ^ "\n")
        (* Retrieve the input parameter ‘x’ *)
        val input = (snd o dest_comb) scrut
        val _ = print_dbg_fn (fn () => "prove_body_is_valid_step: y: " ^
          term_to_string input ^ "\n")
      in
        ((* Choose the right possibility (this is a recursive call) *)
         disj2_tac >>
         (* Instantiate the quantifiers *)
         qexists ‘^h’ >>
         qexists ‘^input’ >>
         (* Unfold the predicate once *)
         pure_once_rewrite_tac [is_valid_fp_body_def] >>
         (* We have two subgoals:
            - we have to prove that ‘h’ is valid
            - we have to finish the proof of validity for the current body
          *)
         conj_tac >> fs [case_result_switch_eq, bind_def] >>
         (* The first subgoal should have been eliminated *)
         gen_tac)
      end
  in
    (* If recursive call: special treatment. Otherwise, we do a simple disjunction *)
    (if term_eq scrut_app qc then step_rec ()
     else (Cases_on ‘^scrut’ >> fs [case_result_switch_eq])) (asms, g)
  end

(* Tactic to prove that a body is valid *)
fun prove_body_is_valid_tac (body_def : thm option) : tactic =
  let
    val body_def_thm = case body_def of SOME th => [th] | NONE => []
  in
    pure_once_rewrite_tac [is_valid_fp_body_def] >> gen_tac >>
    (* Expand *)
    fs body_def_thm >>
    fs [bind_def, case_result_switch_eq] >>
    (* Explore the body *)
    rpt prove_body_is_valid_tac_step
  end

(* Prove that a body satisfies the validity condition of the fixed point *)
fun prove_body_is_valid (body : term) : thm =
  let
    (* Explore the body and count the number of occurrences of recursive
       calls so that we can properly instantiate the ‘N’ argument of ‘is_valid_fp_body’
       (note: we compute an overapproximation).

       We first retrieve the name of the continuation parameter.
       Rem.: we generated fresh names so that, for instance, the continuation name
       doesn't collide with other names. Because of this, we don't need to look for
       collisions when exploring the body (and in the worst case, we would count
       an overapproximation of the number of recursive calls).
     *)
    val fcont = (hd o fst o strip_abs) body
    val fcont_name = (fst o dest_var) fcont
    fun count_body_rec_calls (body : term) : int =
      case dest_term body of
        VAR (name, _) => if name = fcont_name then 1 else 0
      | CONST _ => 0
      | COMB (x, y) => count_body_rec_calls x + count_body_rec_calls y
      | LAMB (_, x) => count_body_rec_calls x
    val num_rec_calls = count_body_rec_calls body

    (* Generate the term ‘SUC (SUC ... (SUC n))’ where ‘n’ is a fresh variable.

       Remark: we first prove ‘is_valid_fp_body (SUC ... n) body’ then substitute
       ‘n’ with ‘0’ to prevent the quantity from being rewritten to a bit
       representation, which would prevent unfolding of the ‘is_valid_fp_body’.
     *)
    val nvar = genvar num_ty
    (* Rem.: we stack num_rec_calls + 1 occurrences of ‘SUC’ (and the + 1 is important) *)
    fun mk_n i = if i = 0 then mk_suc nvar else mk_suc (mk_n (i-1))
    val n_tm = mk_n num_rec_calls

    (* Generate the lemma statement *)
    val is_valid_tm = list_mk_icomb (is_valid_fp_body_tm, [n_tm, body])
    val is_valid_thm = prove (is_valid_tm, prove_body_is_valid_tac NONE)

    (* Replace ‘nvar’ with ‘0’ *)
    val is_valid_thm = INST [nvar |-> zero_num_tm] is_valid_thm
  in
    is_valid_thm
  end

(*=============================================================================*
 *
 * Generate the definitions with the fixed-point operator
 *
 * ============================================================================*)

(* Generate the raw definitions by using the grouped definition body and the
   fixed point operator (‘fix_exec’ or ‘fix’, depending on ‘use_fix_exec’) *)
fun mk_raw_defs (in_out_tys : (hol_type * hol_type) list)
  (def_tms : term list) (body_is_valid : thm) : thm list =
  let
    (* Retrieve the body *)
    val body = (List.last o snd o strip_comb o concl) body_is_valid

    (* Create the term ‘fix_exec body’ (or ‘fix body’) *)
    val fixed_body = mk_icomb (if !use_fix_exec then fix_exec_tm else fix_tm, body)

    (* For every function in the group, generate the equation that we will
       use as definition. In particular:
       - add the properly injected input ‘x’ to ‘fix body’ (ex.: for ‘nth ls i’
         we add the input ‘INL (ls, i)’)
       - wrap ‘fix body x’ into a case disjunction to extract the relevant output

       For instance, in the case of ‘nth ls i’:
       {[
         nth (ls : 't list_t) (i : u32) =
           case fix nth_body (INL (ls, i)) of
           | Fail e => Fail e
           | Diverge => Diverge
           | Return r =>
             case r of
             | INL _ => Fail Failure
             | INR x => Return x
       ]}
     *)
    fun mk_def_eq (i : int, def_tm : term) : term =
      let
        (* Retrieve the lhs of the original definition equation, and in
           particular the inputs *)
        val def_lhs = lhs def_tm
        val args = (snd o strip_comb) def_lhs

        (* Inject the inputs into the param type *)
        val input = pairSyntax.list_mk_pair args
        val input = inject_in_param_sum in_out_tys i true input

        (* Compose *)
        val def_rhs = mk_comb (fixed_body, input)

        (* Wrap in the case disjunction *)
        val def_rhs = mk_case_select_result_sum def_rhs in_out_tys i false

        (* Create the equation *)
        val def_eq_tm = mk_eq (def_lhs, def_rhs)
      in
        def_eq_tm
      end
    val raw_def_tms = map mk_def_eq (enumerate def_tms)

    (* Generate the definitions *)
    val raw_defs = map (fn tm => Define ‘^tm’) raw_def_tms
  in
    raw_defs
  end

(*=============================================================================*
 *
 * Prove that the definitions satisfy the target equations
 *
 * ============================================================================*)

(* Tactic which makes progress in a proof by making a case disjunction (we use
   this to explore all the paths in a function body). *)
fun case_progress (asms, g) =
  let
    val scrut = (strip_all_cases_get_scrutinee o lhs) g
  in Cases_on ‘^scrut’ (asms, g) end

(* Prove the final equation, that we will use as definition. *)
fun prove_def_eq_tac
  (current_raw_def : thm) (all_raw_defs : thm list) (is_valid : thm)
  (body_def : thm option) : tactic =
  let
    val body_def_thm = case body_def of SOME th => [th] | NONE => []
    val fix_eq = if !use_fix_exec then fix_exec_fixed_eq else fix_fixed_eq
  in
    rpt gen_tac >>
    (* Expand the definition *)
    pure_once_rewrite_tac [current_raw_def] >>
    (* Use the fixed-point equality *)
    pure_once_rewrite_left_tac [HO_MATCH_MP fix_eq is_valid] >>
    (* Expand the body definition *)
    pure_rewrite_tac body_def_thm >>
    (* Expand all the definitions from the group *)
    pure_rewrite_tac all_raw_defs >>
    (* Explore all the paths - maybe we can be smarter, but this is fast and really easy *)
    fs [bind_def] >>
    rpt (case_progress >> fs [])
  end

(* Prove the final equations that we will give to the user as definitions *)
fun prove_def_eqs (body_is_valid : thm) (def_tms : term list) (raw_defs : thm list) : thm list =
  let
    val defs_tgt_raw = zip def_tms raw_defs
    (* Substitute the function variables with the constants introduced in the raw
       definitions *)
    fun compute_fsubst (def_tm, raw_def) : {redex: term, residue: term} =
      let
        val (fvar, _) = (strip_comb o lhs) def_tm
        val fconst = (fst o strip_comb o lhs o snd o strip_forall o concl) raw_def
      in
        (fvar |-> fconst)
      end
    val fsubst = map compute_fsubst defs_tgt_raw
    val defs_tgt_raw = map (fn (x, y) => (subst fsubst x, y)) defs_tgt_raw

    fun prove_def_eq (def_tm, raw_def) : thm =
      let
        (* Quantify the parameters *)
        val (_, params) = (strip_comb o lhs) def_tm
        val def_eq_tm = list_mk_forall (params, def_tm)
        (* Prove *)
        val def_eq = prove (def_eq_tm, prove_def_eq_tac raw_def raw_defs body_is_valid NONE)
      in
        def_eq
      end
    val def_eqs = map prove_def_eq defs_tgt_raw
  in
    def_eqs
  end

(*=============================================================================*
 *
 * The final DefineDiv function
 *
 * ============================================================================*)

(* Define a group of (possibly mutually recursive) functions returning a
   ‘result’, through the fixed-point operator.

   The equations given by the user are proved as theorems and saved under the
   names ‘<fname>_def’; they are also registered as custom unfoldings for
   ‘evalLib’. Returns the proved equations.
 *)
fun DefineDiv (def_qt : term quotation) =
  let
    (* Parse the definitions *)
    val def_tms = (strip_conj o list_mk_conj o rev) (Defn.parse_quote def_qt)

    (* Compute the names and the input/output types of the functions *)
    fun compute_names_in_out_tys (tm : term) : string * (hol_type * hol_type) =
      let
        val app = lhs tm
        val name = (fst o dest_var o fst o strip_comb) app
        val out_ty = dest_result (type_of app)
        val input_tys = pairSyntax.list_mk_prod (map type_of ((snd o strip_comb) app))
      in
        (name, (input_tys, out_ty))
      end
    val (fnames, in_out_tys) = unzip (map compute_names_in_out_tys def_tms)

    (* Generate the body to give to the fixed-point operator *)
    val body = mk_body fnames in_out_tys def_tms

    (* Prove that the body satisfies the validity property required by the fixed point *)
    val body_is_valid = prove_body_is_valid body

    (* Generate the definitions for the various functions by using the fixed point
       and the body *)
    val raw_defs = mk_raw_defs in_out_tys def_tms body_is_valid

    (* Prove the final equations *)
    val def_eqs = prove_def_eqs body_is_valid def_tms raw_defs

    (* Save the final equations as definitions. *)
    val thm_names = map (fn x => x ^ "_def") fnames
    (* Because [store_definition] overrides existing names, it seems that in
       practice we don't really need to delete the previous definitions
       (we still do it: it doesn't cost much). *)
    val _ = app delete_binding thm_names
    (* We only need the side effect of storing the theorems *)
    val _ = app (ignore o store_definition) (zip thm_names def_eqs)

    (* Also save the custom unfoldings, for evaluation (unit tests) *)
    val _ = evalLib.add_unfold_thms thm_names
  in
    def_eqs
  end

end
|