diff options
author | Jonathan Protzenko | 2023-01-25 17:57:52 -0800 |
---|---|---|
committer | Son HO | 2023-06-04 21:44:33 +0200 |
commit | 20c076b2bae86450dbc63a0d4976e6338f5c9aa0 (patch) | |
tree | 818ccda7a4ec1c6d4fb54ffcead8beca48c15871 /compiler | |
parent | d841397d93c06310a7e91087e15ba441c2b74f26 (diff) |
Custom syntax support for structures in Lean
Diffstat (limited to '')
-rw-r--r-- | compiler/Extract.ml | 43 | ||||
-rw-r--r-- | compiler/Pure.ml | 22 |
2 files changed, 45 insertions, 20 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 98a5f41a..7ba64155 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1624,15 +1624,40 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) ctx_get_variant adt_cons.adt_id vid ctx | None -> ctx_get_struct adt_cons.adt_id ctx in - let use_parentheses = inside && args <> [] in - if use_parentheses then F.pp_print_string fmt "("; - F.pp_print_string fmt cons; - Collections.List.iter - (fun v -> - F.pp_print_space fmt (); - extract_texpression ctx fmt true v) - args; - if use_parentheses then F.pp_print_string fmt ")" + let is_lean_struct = !backend = Lean && adt_cons.variant_id = None in + if is_lean_struct then + let decl_id = match adt_cons.adt_id with | AdtId id -> id | _ -> assert false in + let def_kind = (TypeDeclId.Map.find decl_id ctx.trans_ctx.type_context.type_decls).kind in + let fields = match def_kind with | T.Struct fields -> fields | _ -> assert false in + let fields = FieldId.mapi (fun fid f -> (fid, f)) fields in + F.pp_open_vbox fmt ctx.indent_incr; + F.pp_print_string fmt "{"; + F.pp_print_space fmt (); + Collections.List.iter_link + (fun () -> + F.pp_print_string fmt ","; + F.pp_print_space fmt () + ) + (fun ((fid, _), e) -> + let f = ctx_get_field adt_cons.adt_id fid ctx in + F.pp_print_string fmt f; + F.pp_print_string fmt " := "; + extract_texpression ctx fmt true e + ) + (List.combine fields args); + F.pp_print_space fmt (); + F.pp_close_box fmt (); + F.pp_print_string fmt "}"; + else + let use_parentheses = inside && args <> [] in + if use_parentheses then F.pp_print_string fmt "("; + F.pp_print_string fmt cons; + Collections.List.iter + (fun v -> + F.pp_print_space fmt (); + extract_texpression ctx fmt true v) + args; + if use_parentheses then F.pp_print_string fmt ")" (** Subcase of the app case: ADT field projector. *) and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 912e05fb..5b2fca7d 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -92,7 +92,7 @@ class ['self] map_ty_base = type ty = | Adt of type_id * ty list (** {!Adt} encodes ADTs and tuples and assumed types. - + TODO: what about the ended regions? (ADTs may be parameterized with several region variables. When giving back an ADT value, we may be able to only give back part of the ADT. We need a way to encode @@ -342,7 +342,7 @@ type fun_or_op_id = type adt_cons_id = { adt_id : type_id; variant_id : variant_id option } [@@deriving show] -(** Projection - For now we don't support projection of tuple fields +(** Projection - For now we don't support projection of tuple fields (because not all the backends have syntax for this). *) type projection = { adt_id : type_id; field_id : FieldId.id } [@@deriving show] @@ -438,7 +438,7 @@ type expression = | Qualif of qualif (** A top-level qualifier *) | Let of bool * typed_pattern * texpression * texpression (** Let binding. - + TODO: the boolean should be replaced by an enum: sometimes we use the error-monad, sometimes we use the state-error monad (and we should do this an a per-function basis! For instance, arithmetic @@ -459,14 +459,14 @@ type expression = ]} (not all languages have syntax like [p.0], [p.1]... and it is more readable anyway). - + 2. When expanding an enumeration with one variant. In this case, {!Let} has to be understood as: {[ let Cons x tl = ls in ... ]} - + Note that later, depending on the language we extract to, we can eventually update it to something like this (for F*, for instance): {[ @@ -512,7 +512,7 @@ and texpression = { e : expression; ty : ty } (** Meta-value (converted to an expression). It is important that the content is opaque. - + TODO: is it possible to mark the whole mvalue type as opaque? *) and mvalue = (texpression[@opaque]) @@ -620,7 +620,7 @@ type fun_sig_info = { result (back_out0 & ... & back_outp)] (* error-monad *) [in_ty0 -> ... -> in_tyn -> state -> back_in0 -> ... back_inm -> state -> result (state & (back_out0 & ... & back_outp))] (* state-error *) - + Note that a stateful backward function may take two states as inputs: the state received by the associated forward function, and the state at which the backward is called. This leads to code of the following shape: @@ -674,11 +674,11 @@ type fun_sig = { In case of a forward function, the list has length = 1, for the type of the returned value. - + In case of backward function, the list contains all the types of all the given back values (there is at most one type per forward input argument). - + Ex.: {[ fn choose<'a, T>(b : bool, x : &'a mut T, y : &'a mut T) -> &'a mut T; @@ -686,7 +686,7 @@ type fun_sig = { Decomposed outputs: - forward function: [[T]] - backward function: [[T; T]] (for "x" and "y") - + Non-decomposed ouputs (if the function can fail, but is not stateful): - [result T] - [[result (T * T)]] @@ -725,7 +725,7 @@ type fun_decl = { back_id : T.RegionGroupId.id option; basename : fun_name; (** The "base" name of the function. - + The base name is the original name of the Rust function. We add suffixes (to identify the forward/backward functions) later. *) |