From 191e8f1b2393e108da38470d7e22f0574567eb25 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 10 May 2023 14:19:59 +0200 Subject: Improve the performance of dependent rewrites by using nets --- backends/hol4/primitivesBaseTacLib.sml | 155 ++++++++++++++++++++++----------- backends/hol4/primitivesLib.sml | 23 +++-- 2 files changed, 120 insertions(+), 58 deletions(-) (limited to 'backends/hol4') diff --git a/backends/hol4/primitivesBaseTacLib.sml b/backends/hol4/primitivesBaseTacLib.sml index 1e874ad5..fe87e894 100644 --- a/backends/hol4/primitivesBaseTacLib.sml +++ b/backends/hol4/primitivesBaseTacLib.sml @@ -242,43 +242,53 @@ val th = SPEC_ALL NUM_SUB_1_EQ (* Call a matching function on all the subterms in the provided list of term. This is a generic function. - [try_match] should return an instantiated theorem, as well as a term which - identifies this theorem (the lhs of the equality, if this is a rewriting + [try_match] should return a list of instantiated theorems, as well as terms which + identify those theorem (the lhs of the equality, if this is a rewriting theorem for instance - we use this to check for collisions, and discard redundant instantiations). It takes as input the set of bound variables (it should not perform substitutions with variables belonging to this set). *) fun inst_match_in_terms - (try_match: string Redblackset.set -> term -> term * thm) + (try_match: string Redblackset.set -> term -> (term * thm) list) (tml : term list) : thm list = let (* We use a map when storing the theorems, to avoid storing the same theorem twice *) val inst_thms: (term, thm) Redblackmap.dict ref = ref (Redblackmap.mkDict Term.compare); fun try_instantiate bvars t = let - val (inst_th_tm, inst_th) = try_match bvars t; + val matched_thms = try_match bvars t; + fun insert_th (inst_th_tm, inst_th) = + inst_thms := Redblackmap.insert (!inst_thms, inst_th_tm, inst_th) in - inst_thms := Redblackmap.insert (!inst_thms, inst_th_tm, inst_th) + List.app insert_th matched_thms end - handle HOL_ERR _ => (); (* Explore the term *) val _ = List.app (dep_apply_in_subterms try_instantiate (Redblackset.empty String.compare)) tml; in map snd (Redblackmap.listItems (!inst_thms)) end -(* Given a rewriting theorem [th] which has premises, return all the - instantiations of this theorem which make its conclusion match subterms +(* Given a net of rewriting theorems [ths] which have premises, return all the + instantiations of those theorems which make its conclusion match subterms in the provided list of term. [keep]: if this function returns false on an instantiated theorem, ignore the theorem. + + The theorems in the Net should be of the shape: + {[ + H0, ..., Hn ⊢ x = y + ]} + (no implications, no quantified variables) + For the above theorem, the key term used in the net should be ‘x’. *) -fun inst_match_concl_in_terms (keep : thm -> bool) (th : thm) (tml : term list) : thm list = +fun inst_match_concl_in_terms (keep : thm -> bool) (ths : thm Net.net) (tml : term list) : thm list = let - val th = (UNDISCH_ALL o SPEC_ALL) th; - fun try_match bvars t = + (* First, find potential matches in the net *) + fun find_thms (t : term) : thm list = Net.match t ths + (* Then, match more precisely for every theorem found *) + fun try_match (bvars : string Redblackset.set) t th = let val _ = print_dbg ("inst_match_concl_in_terms: " ^ term_to_string t ^ "\n") val inst_th = inst_match_concl bvars th t @@ -290,42 +300,72 @@ fun inst_match_concl_in_terms (keep : thm -> bool) (th : thm) (tml : term list) else let val _ = print_dbg ("inst_match_concl_in_terms: matched failed\n") in failwith "inst_match_concl_in_terms: ignore theorem" end - end; + end + (* Compose *) + fun try_match_on_thms bvars t = + let + val matched_thms = find_thms t + in + mapfilter (try_match bvars t) matched_thms + end in - inst_match_in_terms try_match tml + inst_match_in_terms try_match_on_thms tml end (* val t = “!x. u32_to_int (int_to_u32 x) = u32_to_int (int_to_u32 y)” -val th = int_to_u32_id +val th = u32_to_int_int_to_u32 +val th = (UNDISCH_ALL o SPEC_ALL) th +val ths = Net.insert ((lhs o concl) th, th) Net.empty +val keep = fn _ => true -val thms = inst_match_concl_in_terms int_to_u32_id [t] +val thms = inst_match_concl_in_terms keep ths [t] *) -(* Given a theorem [th] which has premises, return all the - instantiations of this theorem which make its first premise match subterms +(* Given a net of theorems which have premises, return all the + instantiations of those theorems which make their first premise match subterms in the provided list of term. + + The theorems in the Net should be of the shape: + {[ + ⊢ H0 => ... => Hn => H + ]} + (no quantified variables) + For the above theorem, the matching term used in the net should be ‘H0’. *) -fun inst_match_first_premise_in_terms (keep : thm -> bool) (th : thm) (tml : term list) : thm list = +fun inst_match_first_premise_in_terms + (keep : thm -> bool) (ths : thm Net.net) (tml : term list) : thm list = let - val th = SPEC_ALL th; - fun try_match bvars t = + (* First, find potential matches in the net *) + fun find_thms (t : term) : thm list = Net.match t ths + (* Then, match more precisely for every theorem found *) + fun try_match bvars t th = let val inst_th = inst_match_first_premise bvars th t; in if keep inst_th then ((fst o dest_imp o concl) inst_th, inst_th) else failwith "inst_match_first_premise_in_terms: ignore theorem" - end; + end + (* Compose *) + fun try_match_on_thms bvars t = + let + val matched_thms = find_thms t + in + mapfilter (try_match bvars t) matched_thms + end in - inst_match_in_terms try_match tml + inst_match_in_terms try_match_on_thms tml end (* -val t = “x = y - 1 ==> T” -val th = SPEC_ALL NUM_SUB_1_EQ +val t = “n : int = m - 1 ==> T” +val th = prove (“x: int = y - 1 ==> x + 1 = y”, COOPER_TAC) +val th = SPEC_ALL th +val ths = Net.insert ((fst o dest_imp o concl) th, th) Net.empty +val keep = fn _ => true -val thms = inst_match_first_premise_in_terms th [t] +val thms = inst_match_first_premise_in_terms keep ths [t] *) @@ -333,39 +373,50 @@ val thms = inst_match_first_premise_in_terms th [t] conclusion with subterms in the given list of terms. Leaves the premises as subgoals if [prove_premise] doesn't prove them. + + The theorems in the Net should be of the shape: + {[ + H0, ..., Hn ⊢ x = y + ]} + (no implications, no quantified variables) + For the above theorem, the key term used in the net should be ‘x’. *) fun apply_dep_rewrites_match_concl_with_terms_tac (prove_premise : tactic) (then_tac : thm_tactic) (ignore_tml : term list) - (tml : term list) (th : thm) : tactic = + (tml : term list) (ths : thm Net.net) : tactic = let val ignore = Redblackset.fromList Term.compare ignore_tml fun keep th = not (Redblackset.member (ignore, concl th)) (* Discharge the assumptions so that the goal is one single term *) - val thms = inst_match_concl_in_terms keep th tml + val thms = inst_match_concl_in_terms keep ths tml val thms = map thm_to_conj_implies thms in - (* Apply each theorem *) + (* Try to prove each theorem, and insert the result in the subgoal *) map_every_tac (try_tac o sg_premise_then prove_premise then_tac) thms end -(* Attempt to apply dependent rewrites with a theorem by matching its +(* Attempt to apply dependent rewrites with theorems by matching their conclusion with subterms of the goal (including the assumptions). Leaves the premises as subgoals if [prove_premise] doesn't prove them. + + See [apply_dep_rewrites_match_concl_with_terms_tac] for the shape of + the theorems used in the net. *) fun apply_dep_rewrites_match_concl_with_all_tac - (prove_premise : tactic) (then_tac : thm_tactic) (th : thm) : tactic = + (prove_premise : tactic) (then_tac : thm_tactic) (ths : thm Net.net) : tactic = fn (asms, g) => - apply_dep_rewrites_match_concl_with_terms_tac prove_premise then_tac asms (g :: asms) th (asms, g) + apply_dep_rewrites_match_concl_with_terms_tac prove_premise then_tac + asms (g :: asms) ths (asms, g) (* Same as {!apply_dep_rewrites_match_concl_with_all_tac} but we only match the conclusion of the goal. *) fun apply_dep_rewrites_match_concl_with_goal_tac - (prove_premise : tactic) (then_tac : thm_tactic) (th : thm) : tactic = + (prove_premise : tactic) (then_tac : thm_tactic) (ths : thm Net.net) : tactic = fn (asms, g) => - apply_dep_rewrites_match_concl_with_terms_tac prove_premise then_tac asms [g] th (asms, g) + apply_dep_rewrites_match_concl_with_terms_tac prove_premise then_tac asms [g] ths (asms, g) (* A theorem might be of the shape: [H => A = B /\ C = D], in which case we can split it into: @@ -396,6 +447,21 @@ fun split_rewrite_thm (th : thm) : thm list = map (transform_th o mk_th) tml end +(* Create a net from a list of rewriting theorems, from which we will match + the conclusion against various subterms. *) +fun net_of_rewrite_thms (thml : thm list) : thm Net.net = + let + fun insert_th (th, net) = + let + val th = (UNDISCH_ALL o SPEC_ALL) th + val tm = (lhs o concl) th + in + Net.insert (tm, th) net + end + in + foldl insert_th Net.empty thml + end + (* A dependent rewrite tactic which introduces the premises in a new goal. We try to apply dependent rewrites to the whole goal, including its assumptions. @@ -407,8 +473,9 @@ fun sg_dep_rewrite_all_tac (th : thm) = let (* Split the theorem *) val thml = split_rewrite_thm th + val ths = net_of_rewrite_thms thml in - MAP_EVERY (apply_dep_rewrites_match_concl_with_all_tac all_tac assume_tac) thml + apply_dep_rewrites_match_concl_with_all_tac all_tac assume_tac ths end (* Same as {!sg_dep_rewrite_tac} but this time we apply rewrites only in the conclusion @@ -418,8 +485,9 @@ fun sg_dep_rewrite_goal_tac (th : thm) = let (* Split the theorem *) val thml = split_rewrite_thm th + val ths = net_of_rewrite_thms thml in - MAP_EVERY (apply_dep_rewrites_match_concl_with_goal_tac all_tac assume_tac) thml + apply_dep_rewrites_match_concl_with_goal_tac all_tac assume_tac ths end (* @@ -434,18 +502,18 @@ apply_dep_rewrites_match_concl_tac int_to_u32_id *) -(* Attempt to apply dependent rewrites with a theorem by matching its +(* Attempt to apply dependent rewrites with theorems by matching their first premise with subterms of the goal. Leaves the premises as subgoals if [prove_premise] doesn't prove them. *) fun apply_dep_rewrites_match_first_premise_with_all_tac (keep : thm -> bool) - (prove_premise : tactic) (then_tac : thm_tactic) (th : thm) : tactic = + (prove_premise : tactic) (then_tac : thm_tactic) (ths : thm Net.net) : tactic = fn (asms, g) => let (* Discharge the assumptions so that the goal is one single term *) - val thms = inst_match_first_premise_in_terms keep th (g :: asms); + val thms = inst_match_first_premise_in_terms keep ths (g :: asms); val thms = map thm_to_conj_implies thms; fun apply_tac th = let @@ -460,17 +528,6 @@ fun apply_dep_rewrites_match_first_premise_with_all_tac val cooper_tac = COOPER_TAC -(* TODO: COOPER_TAC fails in the proof below, because of x <> y. We should - create an issue/PR for HOL4. - -Theorem cooper_fail: - ∀(x y : 'a). x ≠ y ==> 0 ≤ i ==> i ≠ 0 ⇒ 0 < i -Proof - rw [] >> cooper_tac -QED - -*) - (* Filter the assumptions in the goal *) fun filter_assums_tac (keep : term -> bool) : tactic = fn (asms, g) => diff --git a/backends/hol4/primitivesLib.sml b/backends/hol4/primitivesLib.sml index 057c57bd..cf7368a6 100644 --- a/backends/hol4/primitivesLib.sml +++ b/backends/hol4/primitivesLib.sml @@ -150,6 +150,9 @@ val integer_conversion_lemmas_list = [ u128_to_int_int_to_u128 ] +(* Using a net for efficiency *) +val integer_conversion_lemmas_net = net_of_rewrite_thms integer_conversion_lemmas_list + (* Look for conversions from integers to machine integers and back. {[ u32_to_int (int_to_u32 x) @@ -161,17 +164,19 @@ val integer_conversion_lemmas_list = [ ]} *) val rewrite_with_dep_int_lemmas : tactic = - (* We're not trying to be smart: we just try to rewrite with each theorem at - a time *) let - val prove_premise = full_simp_tac simpLib.empty_ss integer_bounds_defs_list >> int_tac; - val then_tac1 = (fn th => full_simp_tac simpLib.empty_ss [th]); - val rewr_tac1 = apply_dep_rewrites_match_concl_with_all_tac prove_premise then_tac1; - val then_tac2 = (fn th => full_simp_tac simpLib.empty_ss [th]); - val rewr_tac2 = apply_dep_rewrites_match_first_premise_with_all_tac (fn _ => true) prove_premise then_tac2; + val prove_premise = full_simp_tac simpLib.empty_ss integer_bounds_defs_list >> int_tac + (* Rewriting based on matching the conclusion. *) + val then_tac1 = (fn th => full_simp_tac simpLib.empty_ss [th]) + val rewr_tac1 = apply_dep_rewrites_match_concl_with_all_tac prove_premise then_tac1 + (* Rewriting based on matching the first premise. + We're not trying to be smart: we just try to rewrite with each theorem at + a time. + Remark: this is not used for now. *) + val then_tac2 = (fn th => full_simp_tac simpLib.empty_ss [th]) + val rewr_tac2 = apply_dep_rewrites_match_first_premise_with_all_tac (fn _ => true) prove_premise then_tac2 in - map_every_tac rewr_tac1 integer_conversion_lemmas_list >> - map_every_tac rewr_tac2 [] + rewr_tac1 integer_conversion_lemmas_net end (* Massage a bit the goal, for instance by introducing integer bounds in the -- cgit v1.2.3