From 931fabe3e8590815548d606b33fc8db31e9f6010 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 3 Aug 2023 16:21:43 +0200 Subject: Fix an issue with the extraction of aggregated arrays --- compiler/Extract.ml | 218 ++++++++++++++++++++------------- compiler/InterpreterExpressions.ml | 4 +- compiler/InterpreterLoopsFixedPoint.ml | 2 +- compiler/InterpreterLoopsMatchCtxs.mli | 6 +- compiler/PrintPure.ml | 38 ++++-- compiler/Pure.ml | 139 ++++++++++++++------- compiler/PureMicroPasses.ml | 6 +- compiler/PureTypeCheck.ml | 55 +++++---- compiler/SymbolicAst.ml | 9 +- compiler/SymbolicToPure.ml | 23 +++- 10 files changed, 320 insertions(+), 180 deletions(-) diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 9ee94db2..f161cc13 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -2260,7 +2260,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) | Let (_, _, _, _) -> extract_lets ctx fmt inside e | Switch (scrut, body) -> extract_Switch ctx fmt inside scrut body | Meta (_, e) -> extract_texpression ctx fmt inside e - | StructUpdate supd -> extract_StructUpdate ctx fmt inside supd + | StructUpdate supd -> extract_StructUpdate ctx fmt inside e.ty supd | Loop _ -> (* The loop nodes should have been eliminated in {!PureMicroPasses} *) raise (Failure "Unreachable") @@ -2723,104 +2723,152 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool) F.pp_close_box fmt () and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter) - (inside : bool) (supd : struct_update) : unit = + (inside : bool) (e_ty : ty) (supd : struct_update) : unit = (* We can't update a subset of the fields in Coq (i.e., we can do [{| x:= 3; y := 4 |}], but there is no syntax for [{| s with x := 3 |}]) *) assert (!backend <> Coq || supd.init = None); (* In the case of HOL4, records with no fields are not supported and are thus extracted to unit. We need to check that by looking up the definition *) let extract_as_unit = - if !backend = HOL4 then - let d = - TypeDeclId.Map.find supd.struct_id ctx.trans_ctx.type_context.type_decls - in - d.kind = Struct [] - else false + match (!backend, supd.struct_id) with + | HOL4, AdtId adt_id -> + let d = + TypeDeclId.Map.find adt_id ctx.trans_ctx.type_context.type_decls + in + d.kind = Struct [] + | _ -> false in if extract_as_unit then (* Remark: this is only valid for HOL4 (for instance the Coq unit value is [tt]) *) F.pp_print_string fmt "()" - else ( - F.pp_open_hvbox fmt 0; - F.pp_open_hvbox fmt ctx.indent_incr; - (* Inner/outer brackets: there are several syntaxes for the field updates. - - For instance, in F*: - {[ - { x with f = ..., ... } - ]} - - In HOL4: - {[ - x with <| f = ..., ... |> - ]} - - In the above examples: - - in F*, the { } brackets are "outer" brackets - - in HOL4, the <| |> brackets are "inner" brackets + else + (* There are two cases: + - this is a regular struct + - this is an array *) - (* Outer brackets *) - let olb, orb = - match !backend with - | Lean | FStar -> (Some "{", Some "}") - | Coq -> (Some "{|", Some "|}") - | HOL4 -> (None, None) - in - (* Inner brackets *) - let ilb, irb = - match !backend with - | Lean | FStar | Coq -> (None, None) - | HOL4 -> (Some "<|", Some "|>") - in - (* Helper *) - let print_bracket (is_left : bool) b = - match b with - | Some b -> - if not is_left then F.pp_print_space fmt (); - F.pp_print_string fmt b; - if is_left then F.pp_print_space fmt () - | None -> () - in - print_bracket true olb; - let need_paren = inside && !backend = HOL4 in - if need_paren then F.pp_print_string fmt "("; - F.pp_open_hvbox fmt ctx.indent_incr; - if supd.init <> None then ( - let var_name = ctx_get_var (Option.get supd.init) ctx in - F.pp_print_string fmt var_name; - F.pp_print_space fmt (); - F.pp_print_string fmt "with"; - F.pp_print_space fmt ()); - print_bracket true ilb; - F.pp_open_hvbox fmt 0; - let delimiter = - match !backend with Lean -> "," | Coq | FStar | HOL4 -> ";" - in - let assign = - match !backend with Coq | Lean | HOL4 -> ":=" | FStar -> "=" - in - Collections.List.iter_link - (fun () -> - F.pp_print_string fmt delimiter; - F.pp_print_space fmt ()) - (fun (fid, fe) -> + match supd.struct_id with + | AdtId _ -> + F.pp_open_hvbox fmt 0; F.pp_open_hvbox fmt ctx.indent_incr; - let f = ctx_get_field (AdtId supd.struct_id) fid ctx in - F.pp_print_string fmt f; - F.pp_print_string fmt (" " ^ assign); - F.pp_print_space fmt (); + (* Inner/outer brackets: there are several syntaxes for the field updates. + + For instance, in F*: + {[ + { x with f = ..., ... } + ]} + + In HOL4: + {[ + x with <| f = ..., ... |> + ]} + + In the above examples: + - in F*, the { } brackets are "outer" brackets + - in HOL4, the <| |> brackets are "inner" brackets + *) + (* Outer brackets *) + let olb, orb = + match !backend with + | Lean | FStar -> (Some "{", Some "}") + | Coq -> (Some "{|", Some "|}") + | HOL4 -> (None, None) + in + (* Inner brackets *) + let ilb, irb = + match !backend with + | Lean | FStar | Coq -> (None, None) + | HOL4 -> (Some "<|", Some "|>") + in + (* Helper *) + let print_bracket (is_left : bool) b = + match b with + | Some b -> + if not is_left then F.pp_print_space fmt (); + F.pp_print_string fmt b; + if is_left then F.pp_print_space fmt () + | None -> () + in + print_bracket true olb; + let need_paren = inside && !backend = HOL4 in + if need_paren then F.pp_print_string fmt "("; F.pp_open_hvbox fmt ctx.indent_incr; - extract_texpression ctx fmt true fe; + if supd.init <> None then ( + let var_name = ctx_get_var (Option.get supd.init) ctx in + F.pp_print_string fmt var_name; + F.pp_print_space fmt (); + F.pp_print_string fmt "with"; + F.pp_print_space fmt ()); + print_bracket true ilb; + F.pp_open_hvbox fmt 0; + let delimiter = + match !backend with Lean -> "," | Coq | FStar | HOL4 -> ";" + in + let assign = + match !backend with Coq | Lean | HOL4 -> ":=" | FStar -> "=" + in + Collections.List.iter_link + (fun () -> + F.pp_print_string fmt delimiter; + F.pp_print_space fmt ()) + (fun (fid, fe) -> + F.pp_open_hvbox fmt ctx.indent_incr; + let f = ctx_get_field supd.struct_id fid ctx in + F.pp_print_string fmt f; + F.pp_print_string fmt (" " ^ assign); + F.pp_print_space fmt (); + F.pp_open_hvbox fmt ctx.indent_incr; + extract_texpression ctx fmt true fe; + F.pp_close_box fmt (); + F.pp_close_box fmt ()) + supd.updates; F.pp_close_box fmt (); - F.pp_close_box fmt ()) - supd.updates; - F.pp_close_box fmt (); - print_bracket false irb; - F.pp_close_box fmt (); - F.pp_close_box fmt (); - if need_paren then F.pp_print_string fmt ")"; - print_bracket false orb; - F.pp_close_box fmt ()) + print_bracket false irb; + F.pp_close_box fmt (); + F.pp_close_box fmt (); + if need_paren then F.pp_print_string fmt ")"; + print_bracket false orb; + F.pp_close_box fmt () + | Assumed Array -> + (* Open the boxes *) + F.pp_open_hvbox fmt ctx.indent_incr; + let need_paren = inside in + if need_paren then F.pp_print_string fmt "("; + (* Open the box for `Array.mk T N [` *) + F.pp_open_hovbox fmt ctx.indent_incr; + (* Print the array constructor *) + let cs = ctx_get_struct false (Assumed Array) ctx in + F.pp_print_string fmt cs; + (* Print the parameters *) + let _, tys, cgs = ty_as_adt e_ty in + let ty = Collections.List.to_cons_nil tys in + F.pp_print_space fmt (); + extract_ty ctx fmt TypeDeclId.Set.empty true ty; + let cg = Collections.List.to_cons_nil cgs in + F.pp_print_space fmt (); + extract_const_generic ctx fmt true cg; + F.pp_print_space fmt (); + F.pp_print_string fmt "["; + (* Close the box for `Array.mk T N [` *) + F.pp_close_box fmt (); + (* Print the values *) + let delimiter = + match !backend with Lean -> "," | Coq | FStar | HOL4 -> ";" + in + F.pp_print_space fmt (); + F.pp_open_hovbox fmt 0; + Collections.List.iter_link + (fun () -> + F.pp_print_string fmt delimiter; + F.pp_print_space fmt ()) + (fun (_, fe) -> extract_texpression ctx fmt false fe) + supd.updates; + (* Close the boxes *) + F.pp_close_box fmt (); + if supd.updates <> [] then F.pp_print_space fmt (); + F.pp_print_string fmt "]"; + if need_paren then F.pp_print_string fmt ")"; + F.pp_close_box fmt () + | _ -> raise (Failure "Unreachable") (** Insert a space, if necessary *) let insert_req_space (fmt : F.formatter) (space : bool ref) : unit = diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index 0834cfe2..8b2070c6 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -719,9 +719,7 @@ let eval_rvalue_aggregate (config : C.config) (* Sanity check: the number of values is consistent with the length *) let len = (literal_as_scalar (const_generic_as_literal cg)).value in assert (len = Z.of_int (List.length values)); - let v = V.Adt { variant_id = None; field_values = values } in let ty = T.Adt (T.Assumed T.Array, [], [ ety ], [ cg ]) in - let aggregated : V.typed_value = { V.value = v; ty } in (* In order to generate a better AST, we introduce a symbolic value equal to the array. The reason is that otherwise, the array we introduce here might be duplicated in the generated @@ -736,7 +734,7 @@ let eval_rvalue_aggregate (config : C.config) | Some e -> (* Introduce the symbolic value in the AST *) let sv = ValuesUtils.value_as_symbolic saggregated.value in - Some (SymbolicAst.IntroSymbolic (ctx, None, sv, aggregated, e))) + Some (SymbolicAst.IntroSymbolic (ctx, None, sv, Array values, e))) in (* Compose and apply *) comp eval_ops compute cf diff --git a/compiler/InterpreterLoopsFixedPoint.ml b/compiler/InterpreterLoopsFixedPoint.ml index a9ec9ecf..4310f017 100644 --- a/compiler/InterpreterLoopsFixedPoint.ml +++ b/compiler/InterpreterLoopsFixedPoint.ml @@ -322,7 +322,7 @@ let prepare_ashared_loans (loop_id : V.LoopId.id option) : cm_fun = let sv = V.SymbolicValueId.Map.find sid new_ctx_ids_map.sids_to_values in - SymbolicAst.IntroSymbolic (ctx, None, sv, v, e)) + SymbolicAst.IntroSymbolic (ctx, None, sv, SingleValue v, e)) e !sid_subst) let prepare_ashared_loans_no_synth (loop_id : V.LoopId.id) (ctx : C.eval_ctx) : diff --git a/compiler/InterpreterLoopsMatchCtxs.mli b/compiler/InterpreterLoopsMatchCtxs.mli index d0f57f32..20b997ce 100644 --- a/compiler/InterpreterLoopsMatchCtxs.mli +++ b/compiler/InterpreterLoopsMatchCtxs.mli @@ -34,13 +34,13 @@ val compute_abs_borrows_loans_maps : We use it for joins, to check if two environments are convertible, etc. See for instance {!MakeJoinMatcher} and {!MakeCheckEquivMatcher}. - The functor is parameterized by a {!InterpreterLoopsCore.PrimMatcher}, which implements the - non-generic part of the match. More precisely, the role of {!InterpreterLoopsCore.PrimMatcher} is two + The functor is parameterized by a {!PrimMatcher}, which implements the + non-generic part of the match. More precisely, the role of {!PrimMatcher} is two provide generic functions which recursively match two values (by recursively matching the fields of ADT values for instance). When it does need to match values in a non-trivial manner (if two ADT values don't have the same variant for instance) it calls the corresponding specialized function from - {!InterpreterLoopsCore.PrimMatcher}. + {!PrimMatcher}. *) module MakeMatcher : functor (_ : PrimMatcher) -> Matcher diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 211fb2c2..43b11aa5 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -583,25 +583,39 @@ let rec texpression_to_string (fmt : ast_formatter) (inside : bool) | Loop loop -> let e = loop_to_string fmt indent indent_incr loop in if inside then "(" ^ e ^ ")" else e - | StructUpdate supd -> + | StructUpdate supd -> ( let s = match supd.init with | None -> "" | Some vid -> " " ^ fmt.var_id_to_string vid ^ " with" in - let field_names = Option.get (fmt.adt_field_names supd.struct_id None) in let indent1 = indent ^ indent_incr in let indent2 = indent1 ^ indent_incr in - let fields = - List.map - (fun (fid, fe) -> - let field = FieldId.nth field_names fid in - let fe = texpression_to_string fmt false indent2 indent_incr fe in - "\n" ^ indent1 ^ field ^ " := " ^ fe ^ ";") - supd.updates - in - let bl = if fields = [] then "" else "\n" ^ indent in - "{" ^ s ^ String.concat "" fields ^ bl ^ "}" + (* The id should be a custom type decl id or an array *) + match supd.struct_id with + | AdtId aid -> + let field_names = Option.get (fmt.adt_field_names aid None) in + let fields = + List.map + (fun (fid, fe) -> + let field = FieldId.nth field_names fid in + let fe = + texpression_to_string fmt false indent2 indent_incr fe + in + "\n" ^ indent1 ^ field ^ " := " ^ fe ^ ";") + supd.updates + in + let bl = if fields = [] then "" else "\n" ^ indent in + "{" ^ s ^ String.concat "" fields ^ bl ^ "}" + | Assumed Array -> + let fields = + List.map + (fun (_, fe) -> + texpression_to_string fmt false indent2 indent_incr fe) + supd.updates + in + "[ " ^ String.concat ", " fields ^ " ]" + | _ -> raise (Failure "Unexpected")) | Meta (meta, e) -> ( let meta_s = meta_to_string fmt meta in let e = texpression_to_string fmt inside indent indent_incr e in diff --git a/compiler/Pure.ml b/compiler/Pure.ml index e202b170..ac4ca081 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -105,6 +105,29 @@ class ['self] map_type_id_base = method visit_assumed_ty : 'env -> assumed_ty -> assumed_ty = fun _ x -> x end +(** Ancestor for reduce visitor for [ty] *) +class virtual ['self] reduce_type_id_base = + object (self : 'self) + inherit [_] VisitorsRuntime.reduce + + method visit_type_decl_id : 'env -> type_decl_id -> 'a = + fun _ _ -> self#zero + + method visit_assumed_ty : 'env -> assumed_ty -> 'a = fun _ _ -> self#zero + end + +(** Ancestor for mapreduce visitor for [ty] *) +class virtual ['self] mapreduce_type_id_base = + object (self : 'self) + inherit [_] VisitorsRuntime.mapreduce + + method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id * 'a = + fun _ x -> (x, self#zero) + + method visit_assumed_ty : 'env -> assumed_ty -> assumed_ty * 'a = + fun _ x -> (x, self#zero) + end + type type_id = AdtId of type_decl_id | Tuple | Assumed of assumed_ty [@@deriving show, @@ -126,6 +149,22 @@ type type_id = AdtId of type_decl_id | Tuple | Assumed of assumed_ty nude = true (* Don't inherit {!VisitorsRuntime.iter} *); concrete = true; polymorphic = false; + }, + visitors + { + name = "reduce_type_id"; + variety = "reduce"; + ancestors = [ "reduce_type_id_base" ]; + nude = true (* Don't inherit {!VisitorsRuntime.iter} *); + polymorphic = false; + }, + visitors + { + name = "mapreduce_type_id"; + variety = "mapreduce"; + ancestors = [ "mapreduce_type_id_base" ]; + nude = true (* Don't inherit {!VisitorsRuntime.iter} *); + polymorphic = false; }] type literal_type = T.literal_type [@@deriving show, ord] @@ -148,6 +187,26 @@ class ['self] map_ty_base = method visit_type_var_id : 'env -> type_var_id -> type_var_id = fun _ x -> x end +(** Ancestor for reduce visitor for [ty] *) +class virtual ['self] reduce_ty_base = + object (self : 'self) + inherit [_] reduce_type_id + inherit! [_] T.reduce_const_generic + inherit! [_] PV.reduce_literal_type + method visit_type_var_id : 'env -> type_var_id -> 'a = fun _ _ -> self#zero + end + +(** Ancestor for mapreduce visitor for [ty] *) +class virtual ['self] mapreduce_ty_base = + object (self : 'self) + inherit [_] mapreduce_type_id + inherit! [_] T.mapreduce_const_generic + inherit! [_] PV.mapreduce_literal_type + + method visit_type_var_id : 'env -> type_var_id -> type_var_id * 'a = + fun _ x -> (x, self#zero) + end + type ty = | Adt of type_id * ty list * const_generic list (** {!Adt} encodes ADTs and tuples and assumed types. @@ -176,9 +235,25 @@ type ty = name = "map_ty"; variety = "map"; ancestors = [ "map_ty_base" ]; - nude = true (* Don't inherit {!VisitorsRuntime.iter} *); + nude = true (* Don't inherit {!VisitorsRuntime.map} *); concrete = true; polymorphic = false; + }, + visitors + { + name = "reduce_ty"; + variety = "reduce"; + ancestors = [ "reduce_ty_base" ]; + nude = true (* Don't inherit {!VisitorsRuntime.reduce} *); + polymorphic = false; + }, + visitors + { + name = "mapreduce_ty"; + variety = "mapreduce"; + ancestors = [ "mapreduce_ty_base" ]; + nude = true (* Don't inherit {!VisitorsRuntime.mapreduce} *); + polymorphic = false; }] type field = { field_name : string option; field_ty : ty } [@@deriving show] @@ -243,51 +318,39 @@ type variant_id = VariantId.id [@@deriving show] (** Ancestor for {!iter_typed_pattern} visitor *) class ['self] iter_typed_pattern_base = object (_self : 'self) - inherit [_] VisitorsRuntime.iter - method visit_literal : 'env -> literal -> unit = fun _ _ -> () + inherit [_] iter_ty method visit_var : 'env -> var -> unit = fun _ _ -> () method visit_mplace : 'env -> mplace -> unit = fun _ _ -> () - method visit_ty : 'env -> ty -> unit = fun _ _ -> () method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> () end (** Ancestor for {!map_typed_pattern} visitor *) class ['self] map_typed_pattern_base = object (_self : 'self) - inherit [_] VisitorsRuntime.map - method visit_literal : 'env -> literal -> literal = fun _ x -> x + inherit [_] map_ty method visit_var : 'env -> var -> var = fun _ x -> x method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x - method visit_ty : 'env -> ty -> ty = fun _ x -> x method visit_variant_id : 'env -> variant_id -> variant_id = fun _ x -> x end (** Ancestor for {!reduce_typed_pattern} visitor *) class virtual ['self] reduce_typed_pattern_base = object (self : 'self) - inherit [_] VisitorsRuntime.reduce - method visit_literal : 'env -> literal -> 'a = fun _ _ -> self#zero + inherit [_] reduce_ty method visit_var : 'env -> var -> 'a = fun _ _ -> self#zero method visit_mplace : 'env -> mplace -> 'a = fun _ _ -> self#zero - method visit_ty : 'env -> ty -> 'a = fun _ _ -> self#zero method visit_variant_id : 'env -> variant_id -> 'a = fun _ _ -> self#zero end (** Ancestor for {!mapreduce_typed_pattern} visitor *) class virtual ['self] mapreduce_typed_pattern_base = object (self : 'self) - inherit [_] VisitorsRuntime.mapreduce - - method visit_literal : 'env -> literal -> literal * 'a = - fun _ x -> (x, self#zero) - + inherit [_] mapreduce_ty method visit_var : 'env -> var -> var * 'a = fun _ x -> (x, self#zero) method visit_mplace : 'env -> mplace -> mplace * 'a = fun _ x -> (x, self#zero) - method visit_ty : 'env -> ty -> ty * 'a = fun _ x -> (x, self#zero) - method visit_variant_id : 'env -> variant_id -> variant_id * 'a = fun _ x -> (x, self#zero) end @@ -419,11 +482,10 @@ type var_id = VarId.id [@@deriving show, ord] class ['self] iter_expression_base = object (_self : 'self) inherit [_] iter_typed_pattern - method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> () + inherit! [_] iter_type_id method visit_var_id : 'env -> var_id -> unit = fun _ _ -> () method visit_qualif : 'env -> qualif -> unit = fun _ _ -> () method visit_loop_id : 'env -> loop_id -> unit = fun _ _ -> () - method visit_type_decl_id : 'env -> type_decl_id -> unit = fun _ _ -> () method visit_field_id : 'env -> field_id -> unit = fun _ _ -> () end @@ -431,17 +493,10 @@ class ['self] iter_expression_base = class ['self] map_expression_base = object (_self : 'self) inherit [_] map_typed_pattern - - method visit_integer_type : 'env -> integer_type -> integer_type = - fun _ x -> x - + inherit! [_] map_type_id method visit_var_id : 'env -> var_id -> var_id = fun _ x -> x method visit_qualif : 'env -> qualif -> qualif = fun _ x -> x method visit_loop_id : 'env -> loop_id -> loop_id = fun _ x -> x - - method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id = - fun _ x -> x - method visit_field_id : 'env -> field_id -> field_id = fun _ x -> x end @@ -449,17 +504,10 @@ class ['self] map_expression_base = class virtual ['self] reduce_expression_base = object (self : 'self) inherit [_] reduce_typed_pattern - - method visit_integer_type : 'env -> integer_type -> 'a = - fun _ _ -> self#zero - + inherit! [_] reduce_type_id method visit_var_id : 'env -> var_id -> 'a = fun _ _ -> self#zero method visit_qualif : 'env -> qualif -> 'a = fun _ _ -> self#zero method visit_loop_id : 'env -> loop_id -> 'a = fun _ _ -> self#zero - - method visit_type_decl_id : 'env -> type_decl_id -> 'a = - fun _ _ -> self#zero - method visit_field_id : 'env -> field_id -> 'a = fun _ _ -> self#zero end @@ -467,9 +515,7 @@ class virtual ['self] reduce_expression_base = class virtual ['self] mapreduce_expression_base = object (self : 'self) inherit [_] mapreduce_typed_pattern - - method visit_integer_type : 'env -> integer_type -> integer_type * 'a = - fun _ x -> (x, self#zero) + inherit! [_] mapreduce_type_id method visit_var_id : 'env -> var_id -> var_id * 'a = fun _ x -> (x, self#zero) @@ -480,9 +526,6 @@ class virtual ['self] mapreduce_expression_base = method visit_loop_id : 'env -> loop_id -> loop_id * 'a = fun _ x -> (x, self#zero) - method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id * 'a = - fun _ x -> (x, self#zero) - method visit_field_id : 'env -> field_id -> field_id * 'a = fun _ x -> (x, self#zero) end @@ -592,9 +635,21 @@ and loop = { {[ { s with x := 3 } ]} + + We also use struct updates to encode array aggregates, so that whenever + the user writes code like: + {[ + let a : [u32; 2] = [0, 1]; + ... + ]} + this gets encoded to: + {[ + let a : Array u32 2 = Array.mk [0, 1] in + ... + ]} *) and struct_update = { - struct_id : type_decl_id; + struct_id : type_id; init : var_id option; updates : (field_id * texpression) list; } diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 00620c58..58a5f9e2 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -611,7 +611,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = (!Config.backend <> Lean && !Config.backend <> Coq) || not is_rec then - let struct_id = adt_id in + let struct_id = AdtId adt_id in let init = None in let updates = FieldId.mapi @@ -1168,8 +1168,8 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = Var v ) -> (* We check that this is the proper ADT, and the proper field *) if - proj_adt_id = struct_id && field_id = fid - && x.ty = adt_ty + AdtId proj_adt_id = struct_id + && field_id = fid && x.ty = adt_ty then Some v else None | _ -> None) diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index 8f5b5df4..8d28bb8a 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -5,8 +5,8 @@ open PureUtils (** Utility function, used for type checking. - We need the number of fields for cases like `Slice`, when the number of fields - varies. + This function should only be used for "regular" ADTs, where the number + of fields is fixed: it shouldn't be used for arrays, slices, etc. *) let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) @@ -55,17 +55,9 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) let ty = Collections.List.to_cons_nil tys in assert (variant_id = None); [ ty; ty ] - | Array -> - let ty = Collections.List.to_cons_nil tys in - let cg = Collections.List.to_cons_nil cgs in - let len = - (PrimitiveValuesUtils.literal_as_scalar - (TypesUtils.const_generic_as_literal cg)) - .value - in - let len = Z.to_int len in - Collections.List.repeat len ty - | Vec | Slice | Str -> + | Vec | Array | Slice | Str -> + (* Array: when not symbolic values (for instance, because of aggregates), + the array expressions are introduced as struct updates *) raise (Failure "Attempting to access the fields of an opaque type")) type tc_ctx = { @@ -207,7 +199,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = assert (Option.is_some loop.back_output_tys || loop.loop_body.ty = e.ty); check_texpression ctx loop.fun_end; check_texpression ctx loop.loop_body - | StructUpdate supd -> + | StructUpdate supd -> ( (* Check the init value *) (if Option.is_some supd.init then match VarId.Map.find_opt (Option.get supd.init) ctx.env with @@ -216,18 +208,29 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = (* Check the fields *) (* Retrieve and check the expected field type *) let adt_id, adt_type_args, adt_cg_args = ty_as_adt e.ty in - assert (adt_id = AdtId supd.struct_id); - let variant_id = None in - let expected_field_tys = - get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args - adt_cg_args - in - List.iter - (fun (fid, fe) -> - let expected_field_ty = FieldId.nth expected_field_tys fid in - assert (expected_field_ty = fe.ty); - check_texpression ctx fe) - supd.updates + assert (adt_id = supd.struct_id); + (* The id can only be: a custom type decl or an array *) + match adt_id with + | AdtId _ -> + let variant_id = None in + let expected_field_tys = + get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args + adt_cg_args + in + List.iter + (fun (fid, fe) -> + let expected_field_ty = FieldId.nth expected_field_tys fid in + assert (expected_field_ty = fe.ty); + check_texpression ctx fe) + supd.updates + | Assumed Array -> + let expected_field_ty = Collections.List.to_cons_nil adt_type_args in + List.iter + (fun (_, fe) -> + assert (expected_field_ty = fe.ty); + check_texpression ctx fe) + supd.updates + | _ -> raise (Failure "Unexpected")) | Meta (_, e_next) -> assert (e_next.ty = e.ty); check_texpression ctx e_next diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 787fefc7..7dc94dcd 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -167,7 +167,7 @@ type expression = Contexts.eval_ctx * mplace option * V.symbolic_value - * V.typed_value + * value_aggregate * expression (** We introduce a new symbolic value, equal to some other value. @@ -246,6 +246,13 @@ and expansion = T.integer_type * (V.scalar_value * expression) list * expression (** An integer expansion (i.e, a switch over an integer). The last expression is for the "otherwise" branch. *) + +(* Remark: this type doesn't have to be mutually recursive with the other + types, but it makes it easy to generate the visitors *) +and value_aggregate = + | SingleValue of V.typed_value (** Regular case *) + | Array of V.typed_value list + (** This is used when introducing array aggregates *) [@@deriving show, visitors diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 5e47459d..3512270a 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -2272,19 +2272,34 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) raise (Failure "Unreachable") and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) - (sv : V.symbolic_value) (v : V.typed_value) (e : S.expression) + (sv : V.symbolic_value) (v : S.value_aggregate) (e : S.expression) (ctx : bs_ctx) : texpression = let mplace = translate_opt_mplace p in (* Introduce a fresh variable for the symbolic value *) let ctx, var = fresh_var_for_symbolic_value sv ctx in - (* Translate the value *) - let v = typed_value_to_texpression ctx ectx v in - (* Translate the next expression *) let next_e = translate_expression e ctx in + (* Translate the value: there are two cases, depending on whether this + is a "regular" let-binding or an array aggregate. + *) + let v = + match v with + | SingleValue v -> typed_value_to_texpression ctx ectx v + | Array values -> + (* We use a struct update to encode the array aggregate, in order + to preserve the structure and allow generating code of the shape + `[x0, ...., xn]` *) + let values = List.map (typed_value_to_texpression ctx ectx) values in + let values = FieldId.mapi (fun fid v -> (fid, v)) values in + let su : struct_update = + { struct_id = Assumed Array; init = None; updates = values } + in + { e = StructUpdate su; ty = var.ty } + in + (* Make the let-binding *) let monadic = false in let var = mk_typed_pattern_from_var var mplace in -- cgit v1.2.3