summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2024-06-12 18:40:27 +0200
committerSon Ho2024-06-12 18:40:27 +0200
commit19abb19134efe0b16409f955b13af36262f231a8 (patch)
tree51d7148c8a5ce7464dde94844d77a1e03dfde65e
parent79e19aa701086de9f080357d817284559f900bcc (diff)
Update the code extraction and regenerate the tests
-rw-r--r--compiler/Extract.ml6
-rw-r--r--compiler/ExtractTypes.ml14
-rw-r--r--tests/lean/Matches.lean4
3 files changed, 15 insertions, 9 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index eab85054..4acf3f99 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -241,11 +241,12 @@ let rec extract_typed_pattern (span : Meta.span) (ctx : extraction_ctx)
(fmt : F.formatter) (is_let : bool) (inside : bool) ?(with_type = false)
(v : typed_pattern) : extraction_ctx =
if with_type then F.pp_print_string fmt "(";
+ let is_pattern = true in
let inside = inside && not with_type in
let ctx =
match v.value with
| PatConstant cv ->
- extract_literal span fmt inside cv;
+ extract_literal span fmt is_pattern inside cv;
ctx
| PatVar (v, _) ->
let vname = ctx_compute_var_basename span ctx v.basename v.ty in
@@ -307,6 +308,7 @@ let extract_texpression_errors (fmt : F.formatter) =
let rec extract_texpression (span : Meta.span) (ctx : extraction_ctx)
(fmt : F.formatter) (inside : bool) (e : texpression) : unit =
+ let is_pattern = false in
match e.e with
| Var var_id ->
let var_name = ctx_get_var span var_id ctx in
@@ -314,7 +316,7 @@ let rec extract_texpression (span : Meta.span) (ctx : extraction_ctx)
| CVar var_id ->
let var_name = ctx_get_const_generic_var span var_id ctx in
F.pp_print_string fmt var_name
- | Const cv -> extract_literal span fmt inside cv
+ | Const cv -> extract_literal span fmt is_pattern inside cv
| App _ ->
let app, args = destruct_apps e in
extract_App span ctx fmt inside app args
diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml
index 15e75da2..631db13e 100644
--- a/compiler/ExtractTypes.ml
+++ b/compiler/ExtractTypes.ml
@@ -11,12 +11,13 @@ include ExtractBase
Inputs:
- formatter
+ - [is_pattern]: if [true], it means we are generating a (match) pattern
- [inside]: if [true], the value should be wrapped in parentheses
if it is made of an application (ex.: [U32 3])
- the constant value
*)
-let extract_literal (span : Meta.span) (fmt : F.formatter) (inside : bool)
- (cv : literal) : unit =
+let extract_literal (span : Meta.span) (fmt : F.formatter) (is_pattern : bool)
+ (inside : bool) (cv : literal) : unit =
match cv with
| VScalar sv -> (
match backend () with
@@ -39,8 +40,11 @@ let extract_literal (span : Meta.span) (fmt : F.formatter) (inside : bool)
let iname = int_name sv.int_ty in
F.pp_print_string fmt ("%" ^ iname)
| Lean ->
- let iname = String.lowercase_ascii (int_name sv.int_ty) in
- F.pp_print_string fmt ("#" ^ iname)
+ (* We don't use the same notation for patterns and regular literals *)
+ if is_pattern then F.pp_print_string fmt "#scalar"
+ else
+ let iname = String.lowercase_ascii (int_name sv.int_ty) in
+ F.pp_print_string fmt ("#" ^ iname)
| HOL4 -> ()
| _ -> craise __FILE__ __LINE__ span "Unreachable");
if print_brackets then F.pp_print_string fmt ")")
@@ -409,7 +413,7 @@ let extract_const_generic (span : Meta.span) (ctx : extraction_ctx)
| CgGlobal id ->
let s = ctx_get_global span id ctx in
F.pp_print_string fmt s
- | CgValue v -> extract_literal span fmt inside v
+ | CgValue v -> extract_literal span fmt false inside v
| CgVar id ->
let s = ctx_get_const_generic_var span id ctx in
F.pp_print_string fmt s
diff --git a/tests/lean/Matches.lean b/tests/lean/Matches.lean
index 3e3a558b..9233841b 100644
--- a/tests/lean/Matches.lean
+++ b/tests/lean/Matches.lean
@@ -9,8 +9,8 @@ namespace matches
Source: 'tests/src/matches.rs', lines 4:0-4:27 -/
def match_u32 (x : U32) : Result U32 :=
match x with
- | 0#u32 => Result.ok 0#u32
- | 1#u32 => Result.ok 1#u32
+ | 0#scalar => Result.ok 0#u32
+ | 1#scalar => Result.ok 1#u32
| _ => Result.ok 2#u32
end matches