diff options
-rw-r--r-- | compiler/Extract.ml | 6 | ||||
-rw-r--r-- | compiler/ExtractTypes.ml | 14 | ||||
-rw-r--r-- | tests/lean/Matches.lean | 4 |
3 files changed, 15 insertions, 9 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml index eab85054..4acf3f99 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -241,11 +241,12 @@ let rec extract_typed_pattern (span : Meta.span) (ctx : extraction_ctx) (fmt : F.formatter) (is_let : bool) (inside : bool) ?(with_type = false) (v : typed_pattern) : extraction_ctx = if with_type then F.pp_print_string fmt "("; + let is_pattern = true in let inside = inside && not with_type in let ctx = match v.value with | PatConstant cv -> - extract_literal span fmt inside cv; + extract_literal span fmt is_pattern inside cv; ctx | PatVar (v, _) -> let vname = ctx_compute_var_basename span ctx v.basename v.ty in @@ -307,6 +308,7 @@ let extract_texpression_errors (fmt : F.formatter) = let rec extract_texpression (span : Meta.span) (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (e : texpression) : unit = + let is_pattern = false in match e.e with | Var var_id -> let var_name = ctx_get_var span var_id ctx in @@ -314,7 +316,7 @@ let rec extract_texpression (span : Meta.span) (ctx : extraction_ctx) | CVar var_id -> let var_name = ctx_get_const_generic_var span var_id ctx in F.pp_print_string fmt var_name - | Const cv -> extract_literal span fmt inside cv + | Const cv -> extract_literal span fmt is_pattern inside cv | App _ -> let app, args = destruct_apps e in extract_App span ctx fmt inside app args diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml index 15e75da2..631db13e 100644 --- a/compiler/ExtractTypes.ml +++ b/compiler/ExtractTypes.ml @@ -11,12 +11,13 @@ include ExtractBase Inputs: - formatter + - [is_pattern]: if [true], it means we are generating a (match) pattern - [inside]: if [true], the value should be wrapped in parentheses if it is made of an application (ex.: [U32 3]) - the constant value *) -let extract_literal (span : Meta.span) (fmt : F.formatter) (inside : bool) - (cv : literal) : unit = +let extract_literal (span : Meta.span) (fmt : F.formatter) (is_pattern : bool) + (inside : bool) (cv : literal) : unit = match cv with | VScalar sv -> ( match backend () with @@ -39,8 +40,11 @@ let extract_literal (span : Meta.span) (fmt : F.formatter) (inside : bool) let iname = int_name sv.int_ty in F.pp_print_string fmt ("%" ^ iname) | Lean -> - let iname = String.lowercase_ascii (int_name sv.int_ty) in - F.pp_print_string fmt ("#" ^ iname) + (* We don't use the same notation for patterns and regular literals *) + if is_pattern then F.pp_print_string fmt "#scalar" + else + let iname = String.lowercase_ascii (int_name sv.int_ty) in + F.pp_print_string fmt ("#" ^ iname) | HOL4 -> () | _ -> craise __FILE__ __LINE__ span "Unreachable"); if print_brackets then F.pp_print_string fmt ")") @@ -409,7 +413,7 @@ let extract_const_generic (span : Meta.span) (ctx : extraction_ctx) | CgGlobal id -> let s = ctx_get_global span id ctx in F.pp_print_string fmt s - | CgValue v -> extract_literal span fmt inside v + | CgValue v -> extract_literal span fmt false inside v | CgVar id -> let s = ctx_get_const_generic_var span id ctx in F.pp_print_string fmt s diff --git a/tests/lean/Matches.lean b/tests/lean/Matches.lean index 3e3a558b..9233841b 100644 --- a/tests/lean/Matches.lean +++ b/tests/lean/Matches.lean @@ -9,8 +9,8 @@ namespace matches Source: 'tests/src/matches.rs', lines 4:0-4:27 -/ def match_u32 (x : U32) : Result U32 := match x with - | 0#u32 => Result.ok 0#u32 - | 1#u32 => Result.ok 1#u32 + | 0#scalar => Result.ok 0#u32 + | 1#scalar => Result.ok 1#u32 | _ => Result.ok 2#u32 end matches |