summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-08-03 16:21:43 +0200
committerSon Ho2023-08-03 16:21:43 +0200
commit931fabe3e8590815548d606b33fc8db31e9f6010 (patch)
treeba99ca0412c8e08cd8e89edbbd287c3b306ebfd8
parentfa682c18c8ffc5fa7224d9e9d0e0dd94250ada57 (diff)
Fix an issue with the extraction of aggregated arrays
Diffstat (limited to '')
-rw-r--r--compiler/Extract.ml218
-rw-r--r--compiler/InterpreterExpressions.ml4
-rw-r--r--compiler/InterpreterLoopsFixedPoint.ml2
-rw-r--r--compiler/InterpreterLoopsMatchCtxs.mli6
-rw-r--r--compiler/PrintPure.ml38
-rw-r--r--compiler/Pure.ml139
-rw-r--r--compiler/PureMicroPasses.ml6
-rw-r--r--compiler/PureTypeCheck.ml55
-rw-r--r--compiler/SymbolicAst.ml9
-rw-r--r--compiler/SymbolicToPure.ml23
10 files changed, 320 insertions, 180 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 9ee94db2..f161cc13 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -2260,7 +2260,7 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
| Let (_, _, _, _) -> extract_lets ctx fmt inside e
| Switch (scrut, body) -> extract_Switch ctx fmt inside scrut body
| Meta (_, e) -> extract_texpression ctx fmt inside e
- | StructUpdate supd -> extract_StructUpdate ctx fmt inside supd
+ | StructUpdate supd -> extract_StructUpdate ctx fmt inside e.ty supd
| Loop _ ->
(* The loop nodes should have been eliminated in {!PureMicroPasses} *)
raise (Failure "Unreachable")
@@ -2723,104 +2723,152 @@ and extract_Switch (ctx : extraction_ctx) (fmt : F.formatter) (_inside : bool)
F.pp_close_box fmt ()
and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
- (inside : bool) (supd : struct_update) : unit =
+ (inside : bool) (e_ty : ty) (supd : struct_update) : unit =
(* We can't update a subset of the fields in Coq (i.e., we can do
[{| x:= 3; y := 4 |}], but there is no syntax for [{| s with x := 3 |}]) *)
assert (!backend <> Coq || supd.init = None);
(* In the case of HOL4, records with no fields are not supported and are
thus extracted to unit. We need to check that by looking up the definition *)
let extract_as_unit =
- if !backend = HOL4 then
- let d =
- TypeDeclId.Map.find supd.struct_id ctx.trans_ctx.type_context.type_decls
- in
- d.kind = Struct []
- else false
+ match (!backend, supd.struct_id) with
+ | HOL4, AdtId adt_id ->
+ let d =
+ TypeDeclId.Map.find adt_id ctx.trans_ctx.type_context.type_decls
+ in
+ d.kind = Struct []
+ | _ -> false
in
if extract_as_unit then
(* Remark: this is only valid for HOL4 (for instance the Coq unit value is [tt]) *)
F.pp_print_string fmt "()"
- else (
- F.pp_open_hvbox fmt 0;
- F.pp_open_hvbox fmt ctx.indent_incr;
- (* Inner/outer brackets: there are several syntaxes for the field updates.
-
- For instance, in F*:
- {[
- { x with f = ..., ... }
- ]}
-
- In HOL4:
- {[
- x with <| f = ..., ... |>
- ]}
-
- In the above examples:
- - in F*, the { } brackets are "outer" brackets
- - in HOL4, the <| |> brackets are "inner" brackets
+ else
+ (* There are two cases:
+ - this is a regular struct
+ - this is an array
*)
- (* Outer brackets *)
- let olb, orb =
- match !backend with
- | Lean | FStar -> (Some "{", Some "}")
- | Coq -> (Some "{|", Some "|}")
- | HOL4 -> (None, None)
- in
- (* Inner brackets *)
- let ilb, irb =
- match !backend with
- | Lean | FStar | Coq -> (None, None)
- | HOL4 -> (Some "<|", Some "|>")
- in
- (* Helper *)
- let print_bracket (is_left : bool) b =
- match b with
- | Some b ->
- if not is_left then F.pp_print_space fmt ();
- F.pp_print_string fmt b;
- if is_left then F.pp_print_space fmt ()
- | None -> ()
- in
- print_bracket true olb;
- let need_paren = inside && !backend = HOL4 in
- if need_paren then F.pp_print_string fmt "(";
- F.pp_open_hvbox fmt ctx.indent_incr;
- if supd.init <> None then (
- let var_name = ctx_get_var (Option.get supd.init) ctx in
- F.pp_print_string fmt var_name;
- F.pp_print_space fmt ();
- F.pp_print_string fmt "with";
- F.pp_print_space fmt ());
- print_bracket true ilb;
- F.pp_open_hvbox fmt 0;
- let delimiter =
- match !backend with Lean -> "," | Coq | FStar | HOL4 -> ";"
- in
- let assign =
- match !backend with Coq | Lean | HOL4 -> ":=" | FStar -> "="
- in
- Collections.List.iter_link
- (fun () ->
- F.pp_print_string fmt delimiter;
- F.pp_print_space fmt ())
- (fun (fid, fe) ->
+ match supd.struct_id with
+ | AdtId _ ->
+ F.pp_open_hvbox fmt 0;
F.pp_open_hvbox fmt ctx.indent_incr;
- let f = ctx_get_field (AdtId supd.struct_id) fid ctx in
- F.pp_print_string fmt f;
- F.pp_print_string fmt (" " ^ assign);
- F.pp_print_space fmt ();
+ (* Inner/outer brackets: there are several syntaxes for the field updates.
+
+ For instance, in F*:
+ {[
+ { x with f = ..., ... }
+ ]}
+
+ In HOL4:
+ {[
+ x with <| f = ..., ... |>
+ ]}
+
+ In the above examples:
+ - in F*, the { } brackets are "outer" brackets
+ - in HOL4, the <| |> brackets are "inner" brackets
+ *)
+ (* Outer brackets *)
+ let olb, orb =
+ match !backend with
+ | Lean | FStar -> (Some "{", Some "}")
+ | Coq -> (Some "{|", Some "|}")
+ | HOL4 -> (None, None)
+ in
+ (* Inner brackets *)
+ let ilb, irb =
+ match !backend with
+ | Lean | FStar | Coq -> (None, None)
+ | HOL4 -> (Some "<|", Some "|>")
+ in
+ (* Helper *)
+ let print_bracket (is_left : bool) b =
+ match b with
+ | Some b ->
+ if not is_left then F.pp_print_space fmt ();
+ F.pp_print_string fmt b;
+ if is_left then F.pp_print_space fmt ()
+ | None -> ()
+ in
+ print_bracket true olb;
+ let need_paren = inside && !backend = HOL4 in
+ if need_paren then F.pp_print_string fmt "(";
F.pp_open_hvbox fmt ctx.indent_incr;
- extract_texpression ctx fmt true fe;
+ if supd.init <> None then (
+ let var_name = ctx_get_var (Option.get supd.init) ctx in
+ F.pp_print_string fmt var_name;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "with";
+ F.pp_print_space fmt ());
+ print_bracket true ilb;
+ F.pp_open_hvbox fmt 0;
+ let delimiter =
+ match !backend with Lean -> "," | Coq | FStar | HOL4 -> ";"
+ in
+ let assign =
+ match !backend with Coq | Lean | HOL4 -> ":=" | FStar -> "="
+ in
+ Collections.List.iter_link
+ (fun () ->
+ F.pp_print_string fmt delimiter;
+ F.pp_print_space fmt ())
+ (fun (fid, fe) ->
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ let f = ctx_get_field supd.struct_id fid ctx in
+ F.pp_print_string fmt f;
+ F.pp_print_string fmt (" " ^ assign);
+ F.pp_print_space fmt ();
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ extract_texpression ctx fmt true fe;
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ())
+ supd.updates;
F.pp_close_box fmt ();
- F.pp_close_box fmt ())
- supd.updates;
- F.pp_close_box fmt ();
- print_bracket false irb;
- F.pp_close_box fmt ();
- F.pp_close_box fmt ();
- if need_paren then F.pp_print_string fmt ")";
- print_bracket false orb;
- F.pp_close_box fmt ())
+ print_bracket false irb;
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
+ if need_paren then F.pp_print_string fmt ")";
+ print_bracket false orb;
+ F.pp_close_box fmt ()
+ | Assumed Array ->
+ (* Open the boxes *)
+ F.pp_open_hvbox fmt ctx.indent_incr;
+ let need_paren = inside in
+ if need_paren then F.pp_print_string fmt "(";
+ (* Open the box for `Array.mk T N [` *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ (* Print the array constructor *)
+ let cs = ctx_get_struct false (Assumed Array) ctx in
+ F.pp_print_string fmt cs;
+ (* Print the parameters *)
+ let _, tys, cgs = ty_as_adt e_ty in
+ let ty = Collections.List.to_cons_nil tys in
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt TypeDeclId.Set.empty true ty;
+ let cg = Collections.List.to_cons_nil cgs in
+ F.pp_print_space fmt ();
+ extract_const_generic ctx fmt true cg;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "[";
+ (* Close the box for `Array.mk T N [` *)
+ F.pp_close_box fmt ();
+ (* Print the values *)
+ let delimiter =
+ match !backend with Lean -> "," | Coq | FStar | HOL4 -> ";"
+ in
+ F.pp_print_space fmt ();
+ F.pp_open_hovbox fmt 0;
+ Collections.List.iter_link
+ (fun () ->
+ F.pp_print_string fmt delimiter;
+ F.pp_print_space fmt ())
+ (fun (_, fe) -> extract_texpression ctx fmt false fe)
+ supd.updates;
+ (* Close the boxes *)
+ F.pp_close_box fmt ();
+ if supd.updates <> [] then F.pp_print_space fmt ();
+ F.pp_print_string fmt "]";
+ if need_paren then F.pp_print_string fmt ")";
+ F.pp_close_box fmt ()
+ | _ -> raise (Failure "Unreachable")
(** Insert a space, if necessary *)
let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml
index 0834cfe2..8b2070c6 100644
--- a/compiler/InterpreterExpressions.ml
+++ b/compiler/InterpreterExpressions.ml
@@ -719,9 +719,7 @@ let eval_rvalue_aggregate (config : C.config)
(* Sanity check: the number of values is consistent with the length *)
let len = (literal_as_scalar (const_generic_as_literal cg)).value in
assert (len = Z.of_int (List.length values));
- let v = V.Adt { variant_id = None; field_values = values } in
let ty = T.Adt (T.Assumed T.Array, [], [ ety ], [ cg ]) in
- let aggregated : V.typed_value = { V.value = v; ty } in
(* In order to generate a better AST, we introduce a symbolic
value equal to the array. The reason is that otherwise, the
array we introduce here might be duplicated in the generated
@@ -736,7 +734,7 @@ let eval_rvalue_aggregate (config : C.config)
| Some e ->
(* Introduce the symbolic value in the AST *)
let sv = ValuesUtils.value_as_symbolic saggregated.value in
- Some (SymbolicAst.IntroSymbolic (ctx, None, sv, aggregated, e)))
+ Some (SymbolicAst.IntroSymbolic (ctx, None, sv, Array values, e)))
in
(* Compose and apply *)
comp eval_ops compute cf
diff --git a/compiler/InterpreterLoopsFixedPoint.ml b/compiler/InterpreterLoopsFixedPoint.ml
index a9ec9ecf..4310f017 100644
--- a/compiler/InterpreterLoopsFixedPoint.ml
+++ b/compiler/InterpreterLoopsFixedPoint.ml
@@ -322,7 +322,7 @@ let prepare_ashared_loans (loop_id : V.LoopId.id option) : cm_fun =
let sv =
V.SymbolicValueId.Map.find sid new_ctx_ids_map.sids_to_values
in
- SymbolicAst.IntroSymbolic (ctx, None, sv, v, e))
+ SymbolicAst.IntroSymbolic (ctx, None, sv, SingleValue v, e))
e !sid_subst)
let prepare_ashared_loans_no_synth (loop_id : V.LoopId.id) (ctx : C.eval_ctx) :
diff --git a/compiler/InterpreterLoopsMatchCtxs.mli b/compiler/InterpreterLoopsMatchCtxs.mli
index d0f57f32..20b997ce 100644
--- a/compiler/InterpreterLoopsMatchCtxs.mli
+++ b/compiler/InterpreterLoopsMatchCtxs.mli
@@ -34,13 +34,13 @@ val compute_abs_borrows_loans_maps :
We use it for joins, to check if two environments are convertible, etc.
See for instance {!MakeJoinMatcher} and {!MakeCheckEquivMatcher}.
- The functor is parameterized by a {!InterpreterLoopsCore.PrimMatcher}, which implements the
- non-generic part of the match. More precisely, the role of {!InterpreterLoopsCore.PrimMatcher} is two
+ The functor is parameterized by a {!PrimMatcher}, which implements the
+ non-generic part of the match. More precisely, the role of {!PrimMatcher} is two
provide generic functions which recursively match two values (by recursively
matching the fields of ADT values for instance). When it does need to match
values in a non-trivial manner (if two ADT values don't have the same
variant for instance) it calls the corresponding specialized function from
- {!InterpreterLoopsCore.PrimMatcher}.
+ {!PrimMatcher}.
*)
module MakeMatcher : functor (_ : PrimMatcher) -> Matcher
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 211fb2c2..43b11aa5 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -583,25 +583,39 @@ let rec texpression_to_string (fmt : ast_formatter) (inside : bool)
| Loop loop ->
let e = loop_to_string fmt indent indent_incr loop in
if inside then "(" ^ e ^ ")" else e
- | StructUpdate supd ->
+ | StructUpdate supd -> (
let s =
match supd.init with
| None -> ""
| Some vid -> " " ^ fmt.var_id_to_string vid ^ " with"
in
- let field_names = Option.get (fmt.adt_field_names supd.struct_id None) in
let indent1 = indent ^ indent_incr in
let indent2 = indent1 ^ indent_incr in
- let fields =
- List.map
- (fun (fid, fe) ->
- let field = FieldId.nth field_names fid in
- let fe = texpression_to_string fmt false indent2 indent_incr fe in
- "\n" ^ indent1 ^ field ^ " := " ^ fe ^ ";")
- supd.updates
- in
- let bl = if fields = [] then "" else "\n" ^ indent in
- "{" ^ s ^ String.concat "" fields ^ bl ^ "}"
+ (* The id should be a custom type decl id or an array *)
+ match supd.struct_id with
+ | AdtId aid ->
+ let field_names = Option.get (fmt.adt_field_names aid None) in
+ let fields =
+ List.map
+ (fun (fid, fe) ->
+ let field = FieldId.nth field_names fid in
+ let fe =
+ texpression_to_string fmt false indent2 indent_incr fe
+ in
+ "\n" ^ indent1 ^ field ^ " := " ^ fe ^ ";")
+ supd.updates
+ in
+ let bl = if fields = [] then "" else "\n" ^ indent in
+ "{" ^ s ^ String.concat "" fields ^ bl ^ "}"
+ | Assumed Array ->
+ let fields =
+ List.map
+ (fun (_, fe) ->
+ texpression_to_string fmt false indent2 indent_incr fe)
+ supd.updates
+ in
+ "[ " ^ String.concat ", " fields ^ " ]"
+ | _ -> raise (Failure "Unexpected"))
| Meta (meta, e) -> (
let meta_s = meta_to_string fmt meta in
let e = texpression_to_string fmt inside indent indent_incr e in
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index e202b170..ac4ca081 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -105,6 +105,29 @@ class ['self] map_type_id_base =
method visit_assumed_ty : 'env -> assumed_ty -> assumed_ty = fun _ x -> x
end
+(** Ancestor for reduce visitor for [ty] *)
+class virtual ['self] reduce_type_id_base =
+ object (self : 'self)
+ inherit [_] VisitorsRuntime.reduce
+
+ method visit_type_decl_id : 'env -> type_decl_id -> 'a =
+ fun _ _ -> self#zero
+
+ method visit_assumed_ty : 'env -> assumed_ty -> 'a = fun _ _ -> self#zero
+ end
+
+(** Ancestor for mapreduce visitor for [ty] *)
+class virtual ['self] mapreduce_type_id_base =
+ object (self : 'self)
+ inherit [_] VisitorsRuntime.mapreduce
+
+ method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id * 'a =
+ fun _ x -> (x, self#zero)
+
+ method visit_assumed_ty : 'env -> assumed_ty -> assumed_ty * 'a =
+ fun _ x -> (x, self#zero)
+ end
+
type type_id = AdtId of type_decl_id | Tuple | Assumed of assumed_ty
[@@deriving
show,
@@ -126,6 +149,22 @@ type type_id = AdtId of type_decl_id | Tuple | Assumed of assumed_ty
nude = true (* Don't inherit {!VisitorsRuntime.iter} *);
concrete = true;
polymorphic = false;
+ },
+ visitors
+ {
+ name = "reduce_type_id";
+ variety = "reduce";
+ ancestors = [ "reduce_type_id_base" ];
+ nude = true (* Don't inherit {!VisitorsRuntime.iter} *);
+ polymorphic = false;
+ },
+ visitors
+ {
+ name = "mapreduce_type_id";
+ variety = "mapreduce";
+ ancestors = [ "mapreduce_type_id_base" ];
+ nude = true (* Don't inherit {!VisitorsRuntime.iter} *);
+ polymorphic = false;
}]
type literal_type = T.literal_type [@@deriving show, ord]
@@ -148,6 +187,26 @@ class ['self] map_ty_base =
method visit_type_var_id : 'env -> type_var_id -> type_var_id = fun _ x -> x
end
+(** Ancestor for reduce visitor for [ty] *)
+class virtual ['self] reduce_ty_base =
+ object (self : 'self)
+ inherit [_] reduce_type_id
+ inherit! [_] T.reduce_const_generic
+ inherit! [_] PV.reduce_literal_type
+ method visit_type_var_id : 'env -> type_var_id -> 'a = fun _ _ -> self#zero
+ end
+
+(** Ancestor for mapreduce visitor for [ty] *)
+class virtual ['self] mapreduce_ty_base =
+ object (self : 'self)
+ inherit [_] mapreduce_type_id
+ inherit! [_] T.mapreduce_const_generic
+ inherit! [_] PV.mapreduce_literal_type
+
+ method visit_type_var_id : 'env -> type_var_id -> type_var_id * 'a =
+ fun _ x -> (x, self#zero)
+ end
+
type ty =
| Adt of type_id * ty list * const_generic list
(** {!Adt} encodes ADTs and tuples and assumed types.
@@ -176,9 +235,25 @@ type ty =
name = "map_ty";
variety = "map";
ancestors = [ "map_ty_base" ];
- nude = true (* Don't inherit {!VisitorsRuntime.iter} *);
+ nude = true (* Don't inherit {!VisitorsRuntime.map} *);
concrete = true;
polymorphic = false;
+ },
+ visitors
+ {
+ name = "reduce_ty";
+ variety = "reduce";
+ ancestors = [ "reduce_ty_base" ];
+ nude = true (* Don't inherit {!VisitorsRuntime.reduce} *);
+ polymorphic = false;
+ },
+ visitors
+ {
+ name = "mapreduce_ty";
+ variety = "mapreduce";
+ ancestors = [ "mapreduce_ty_base" ];
+ nude = true (* Don't inherit {!VisitorsRuntime.mapreduce} *);
+ polymorphic = false;
}]
type field = { field_name : string option; field_ty : ty } [@@deriving show]
@@ -243,51 +318,39 @@ type variant_id = VariantId.id [@@deriving show]
(** Ancestor for {!iter_typed_pattern} visitor *)
class ['self] iter_typed_pattern_base =
object (_self : 'self)
- inherit [_] VisitorsRuntime.iter
- method visit_literal : 'env -> literal -> unit = fun _ _ -> ()
+ inherit [_] iter_ty
method visit_var : 'env -> var -> unit = fun _ _ -> ()
method visit_mplace : 'env -> mplace -> unit = fun _ _ -> ()
- method visit_ty : 'env -> ty -> unit = fun _ _ -> ()
method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> ()
end
(** Ancestor for {!map_typed_pattern} visitor *)
class ['self] map_typed_pattern_base =
object (_self : 'self)
- inherit [_] VisitorsRuntime.map
- method visit_literal : 'env -> literal -> literal = fun _ x -> x
+ inherit [_] map_ty
method visit_var : 'env -> var -> var = fun _ x -> x
method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x
- method visit_ty : 'env -> ty -> ty = fun _ x -> x
method visit_variant_id : 'env -> variant_id -> variant_id = fun _ x -> x
end
(** Ancestor for {!reduce_typed_pattern} visitor *)
class virtual ['self] reduce_typed_pattern_base =
object (self : 'self)
- inherit [_] VisitorsRuntime.reduce
- method visit_literal : 'env -> literal -> 'a = fun _ _ -> self#zero
+ inherit [_] reduce_ty
method visit_var : 'env -> var -> 'a = fun _ _ -> self#zero
method visit_mplace : 'env -> mplace -> 'a = fun _ _ -> self#zero
- method visit_ty : 'env -> ty -> 'a = fun _ _ -> self#zero
method visit_variant_id : 'env -> variant_id -> 'a = fun _ _ -> self#zero
end
(** Ancestor for {!mapreduce_typed_pattern} visitor *)
class virtual ['self] mapreduce_typed_pattern_base =
object (self : 'self)
- inherit [_] VisitorsRuntime.mapreduce
-
- method visit_literal : 'env -> literal -> literal * 'a =
- fun _ x -> (x, self#zero)
-
+ inherit [_] mapreduce_ty
method visit_var : 'env -> var -> var * 'a = fun _ x -> (x, self#zero)
method visit_mplace : 'env -> mplace -> mplace * 'a =
fun _ x -> (x, self#zero)
- method visit_ty : 'env -> ty -> ty * 'a = fun _ x -> (x, self#zero)
-
method visit_variant_id : 'env -> variant_id -> variant_id * 'a =
fun _ x -> (x, self#zero)
end
@@ -419,11 +482,10 @@ type var_id = VarId.id [@@deriving show, ord]
class ['self] iter_expression_base =
object (_self : 'self)
inherit [_] iter_typed_pattern
- method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> ()
+ inherit! [_] iter_type_id
method visit_var_id : 'env -> var_id -> unit = fun _ _ -> ()
method visit_qualif : 'env -> qualif -> unit = fun _ _ -> ()
method visit_loop_id : 'env -> loop_id -> unit = fun _ _ -> ()
- method visit_type_decl_id : 'env -> type_decl_id -> unit = fun _ _ -> ()
method visit_field_id : 'env -> field_id -> unit = fun _ _ -> ()
end
@@ -431,17 +493,10 @@ class ['self] iter_expression_base =
class ['self] map_expression_base =
object (_self : 'self)
inherit [_] map_typed_pattern
-
- method visit_integer_type : 'env -> integer_type -> integer_type =
- fun _ x -> x
-
+ inherit! [_] map_type_id
method visit_var_id : 'env -> var_id -> var_id = fun _ x -> x
method visit_qualif : 'env -> qualif -> qualif = fun _ x -> x
method visit_loop_id : 'env -> loop_id -> loop_id = fun _ x -> x
-
- method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id =
- fun _ x -> x
-
method visit_field_id : 'env -> field_id -> field_id = fun _ x -> x
end
@@ -449,17 +504,10 @@ class ['self] map_expression_base =
class virtual ['self] reduce_expression_base =
object (self : 'self)
inherit [_] reduce_typed_pattern
-
- method visit_integer_type : 'env -> integer_type -> 'a =
- fun _ _ -> self#zero
-
+ inherit! [_] reduce_type_id
method visit_var_id : 'env -> var_id -> 'a = fun _ _ -> self#zero
method visit_qualif : 'env -> qualif -> 'a = fun _ _ -> self#zero
method visit_loop_id : 'env -> loop_id -> 'a = fun _ _ -> self#zero
-
- method visit_type_decl_id : 'env -> type_decl_id -> 'a =
- fun _ _ -> self#zero
-
method visit_field_id : 'env -> field_id -> 'a = fun _ _ -> self#zero
end
@@ -467,9 +515,7 @@ class virtual ['self] reduce_expression_base =
class virtual ['self] mapreduce_expression_base =
object (self : 'self)
inherit [_] mapreduce_typed_pattern
-
- method visit_integer_type : 'env -> integer_type -> integer_type * 'a =
- fun _ x -> (x, self#zero)
+ inherit! [_] mapreduce_type_id
method visit_var_id : 'env -> var_id -> var_id * 'a =
fun _ x -> (x, self#zero)
@@ -480,9 +526,6 @@ class virtual ['self] mapreduce_expression_base =
method visit_loop_id : 'env -> loop_id -> loop_id * 'a =
fun _ x -> (x, self#zero)
- method visit_type_decl_id : 'env -> type_decl_id -> type_decl_id * 'a =
- fun _ x -> (x, self#zero)
-
method visit_field_id : 'env -> field_id -> field_id * 'a =
fun _ x -> (x, self#zero)
end
@@ -592,9 +635,21 @@ and loop = {
{[
{ s with x := 3 }
]}
+
+ We also use struct updates to encode array aggregates, so that whenever
+ the user writes code like:
+ {[
+ let a : [u32; 2] = [0, 1];
+ ...
+ ]}
+ this gets encoded to:
+ {[
+ let a : Array u32 2 = Array.mk [0, 1] in
+ ...
+ ]}
*)
and struct_update = {
- struct_id : type_decl_id;
+ struct_id : type_id;
init : var_id option;
updates : (field_id * texpression) list;
}
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 00620c58..58a5f9e2 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -611,7 +611,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(!Config.backend <> Lean && !Config.backend <> Coq)
|| not is_rec
then
- let struct_id = adt_id in
+ let struct_id = AdtId adt_id in
let init = None in
let updates =
FieldId.mapi
@@ -1168,8 +1168,8 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
Var v ) ->
(* We check that this is the proper ADT, and the proper field *)
if
- proj_adt_id = struct_id && field_id = fid
- && x.ty = adt_ty
+ AdtId proj_adt_id = struct_id
+ && field_id = fid && x.ty = adt_ty
then Some v
else None
| _ -> None)
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index 8f5b5df4..8d28bb8a 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -5,8 +5,8 @@ open PureUtils
(** Utility function, used for type checking.
- We need the number of fields for cases like `Slice`, when the number of fields
- varies.
+ This function should only be used for "regular" ADTs, where the number
+ of fields is fixed: it shouldn't be used for arrays, slices, etc.
*)
let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
(type_id : type_id) (variant_id : VariantId.id option) (tys : ty list)
@@ -55,17 +55,9 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
let ty = Collections.List.to_cons_nil tys in
assert (variant_id = None);
[ ty; ty ]
- | Array ->
- let ty = Collections.List.to_cons_nil tys in
- let cg = Collections.List.to_cons_nil cgs in
- let len =
- (PrimitiveValuesUtils.literal_as_scalar
- (TypesUtils.const_generic_as_literal cg))
- .value
- in
- let len = Z.to_int len in
- Collections.List.repeat len ty
- | Vec | Slice | Str ->
+ | Vec | Array | Slice | Str ->
+ (* Array: when not symbolic values (for instance, because of aggregates),
+ the array expressions are introduced as struct updates *)
raise (Failure "Attempting to access the fields of an opaque type"))
type tc_ctx = {
@@ -207,7 +199,7 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
assert (Option.is_some loop.back_output_tys || loop.loop_body.ty = e.ty);
check_texpression ctx loop.fun_end;
check_texpression ctx loop.loop_body
- | StructUpdate supd ->
+ | StructUpdate supd -> (
(* Check the init value *)
(if Option.is_some supd.init then
match VarId.Map.find_opt (Option.get supd.init) ctx.env with
@@ -216,18 +208,29 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
(* Check the fields *)
(* Retrieve and check the expected field type *)
let adt_id, adt_type_args, adt_cg_args = ty_as_adt e.ty in
- assert (adt_id = AdtId supd.struct_id);
- let variant_id = None in
- let expected_field_tys =
- get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args
- adt_cg_args
- in
- List.iter
- (fun (fid, fe) ->
- let expected_field_ty = FieldId.nth expected_field_tys fid in
- assert (expected_field_ty = fe.ty);
- check_texpression ctx fe)
- supd.updates
+ assert (adt_id = supd.struct_id);
+ (* The id can only be: a custom type decl or an array *)
+ match adt_id with
+ | AdtId _ ->
+ let variant_id = None in
+ let expected_field_tys =
+ get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args
+ adt_cg_args
+ in
+ List.iter
+ (fun (fid, fe) ->
+ let expected_field_ty = FieldId.nth expected_field_tys fid in
+ assert (expected_field_ty = fe.ty);
+ check_texpression ctx fe)
+ supd.updates
+ | Assumed Array ->
+ let expected_field_ty = Collections.List.to_cons_nil adt_type_args in
+ List.iter
+ (fun (_, fe) ->
+ assert (expected_field_ty = fe.ty);
+ check_texpression ctx fe)
+ supd.updates
+ | _ -> raise (Failure "Unexpected"))
| Meta (_, e_next) ->
assert (e_next.ty = e.ty);
check_texpression ctx e_next
diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml
index 787fefc7..7dc94dcd 100644
--- a/compiler/SymbolicAst.ml
+++ b/compiler/SymbolicAst.ml
@@ -167,7 +167,7 @@ type expression =
Contexts.eval_ctx
* mplace option
* V.symbolic_value
- * V.typed_value
+ * value_aggregate
* expression
(** We introduce a new symbolic value, equal to some other value.
@@ -246,6 +246,13 @@ and expansion =
T.integer_type * (V.scalar_value * expression) list * expression
(** An integer expansion (i.e, a switch over an integer). The last
expression is for the "otherwise" branch. *)
+
+(* Remark: this type doesn't have to be mutually recursive with the other
+ types, but it makes it easy to generate the visitors *)
+and value_aggregate =
+ | SingleValue of V.typed_value (** Regular case *)
+ | Array of V.typed_value list
+ (** This is used when introducing array aggregates *)
[@@deriving
show,
visitors
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 5e47459d..3512270a 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -2272,19 +2272,34 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
raise (Failure "Unreachable")
and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
- (sv : V.symbolic_value) (v : V.typed_value) (e : S.expression)
+ (sv : V.symbolic_value) (v : S.value_aggregate) (e : S.expression)
(ctx : bs_ctx) : texpression =
let mplace = translate_opt_mplace p in
(* Introduce a fresh variable for the symbolic value *)
let ctx, var = fresh_var_for_symbolic_value sv ctx in
- (* Translate the value *)
- let v = typed_value_to_texpression ctx ectx v in
-
(* Translate the next expression *)
let next_e = translate_expression e ctx in
+ (* Translate the value: there are two cases, depending on whether this
+ is a "regular" let-binding or an array aggregate.
+ *)
+ let v =
+ match v with
+ | SingleValue v -> typed_value_to_texpression ctx ectx v
+ | Array values ->
+ (* We use a struct update to encode the array aggregate, in order
+ to preserve the structure and allow generating code of the shape
+ `[x0, ...., xn]` *)
+ let values = List.map (typed_value_to_texpression ctx ectx) values in
+ let values = FieldId.mapi (fun fid v -> (fid, v)) values in
+ let su : struct_update =
+ { struct_id = Assumed Array; init = None; updates = values }
+ in
+ { e = StructUpdate su; ty = var.ty }
+ in
+
(* Make the let-binding *)
let monadic = false in
let var = mk_typed_pattern_from_var var mplace in