summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-04-29 16:10:19 +0200
committerSon Ho2022-04-29 16:10:19 +0200
commitf64397c472e82d6b001cf6507d7786d7ee90999d (patch)
tree6463792b63137edefbf076c928082c2d68edd619
parent7d24471866e5e486989d78676287bed267c4e5b4 (diff)
Merge the rvalues with the expressions
Diffstat (limited to '')
-rw-r--r--src/PrintPure.ml143
-rw-r--r--src/Pure.ml151
-rw-r--r--src/PureMicroPasses.ml2
-rw-r--r--src/PureUtils.ml564
-rw-r--r--src/SymbolicToPure.ml8
5 files changed, 497 insertions, 371 deletions
diff --git a/src/PrintPure.ml b/src/PrintPure.ml
index 1a38e504..cf8c3f57 100644
--- a/src/PrintPure.ml
+++ b/src/PrintPure.ml
@@ -234,11 +234,62 @@ let mplace_to_string (fmt : ast_formatter) (p : mplace) : string =
let name = name ^ "^" ^ V.VarId.to_string p.var_id ^ "llbc" in
projection_to_string fmt name p.projection
-let place_to_string (fmt : ast_formatter) (p : place) : string =
- (* TODO: improve that *)
- let var = fmt.var_id_to_string p.var in
- projection_to_string fmt var p.projection
+let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id)
+ (variant_id : VariantId.id option) : string =
+ match adt_id with
+ | Tuple -> "Tuple"
+ | AdtId def_id -> (
+ (* "Regular" ADT *)
+ match variant_id with
+ | Some vid -> fmt.adt_variant_to_string def_id vid
+ | None -> fmt.type_decl_id_to_string def_id)
+ | Assumed aty -> (
+ (* Assumed type *)
+ match aty with
+ | State ->
+ (* The `State` type is opaque: we can't get there *)
+ raise (Failure "Unreachable")
+ | Result ->
+ let variant_id = Option.get variant_id in
+ if variant_id = result_return_id then "@Result::Return"
+ else if variant_id = result_fail_id then "@Result::Fail"
+ else
+ raise (Failure "Unreachable: improper variant id for result type")
+ | Option ->
+ let variant_id = Option.get variant_id in
+ if variant_id = option_some_id then "@Option::Some "
+ else if variant_id = option_none_id then "@Option::None"
+ else
+ raise (Failure "Unreachable: improper variant id for result type")
+ | Vec ->
+ assert (variant_id = None);
+ "Vec")
+
+let adt_field_to_string (fmt : value_formatter) (adt_id : type_id)
+ (field_id : FieldId.id) : string =
+ match adt_id with
+ | Tuple ->
+ raise (Failure "Unreachable")
+ (* Tuples don't use the opaque field id for the field indices, but `int` *)
+ | AdtId def_id -> (
+ (* "Regular" ADT *)
+ let fields = fmt.adt_field_names def_id None in
+ match fields with
+ | None -> FieldId.to_string field_id
+ | Some fields -> FieldId.nth fields field_id)
+ | Assumed aty -> (
+ (* Assumed type *)
+ match aty with
+ | State | Vec ->
+ (* Opaque types: we can't get there *)
+ raise (Failure "Unreachable")
+ | Result | Option ->
+ (* Enumerations: we can't get there *)
+ raise (Failure "Unreachable"))
+(** TODO: we don't need a general function anymore (it is now only used for
+ lvalues (i.e., patterns)
+ *)
let adt_g_value_to_string (fmt : value_formatter)
(value_to_string : 'v -> string) (variant_id : VariantId.id option)
(field_values : 'v list) (ty : ty) : string =
@@ -311,17 +362,6 @@ let adt_g_value_to_string (fmt : value_formatter)
^ "\n- ty: " ^ ty_to_string fmt ty ^ "\n- variant_id: "
^ Print.option_to_string VariantId.to_string variant_id))
-let rec typed_rvalue_to_string (fmt : ast_formatter) (v : typed_rvalue) : string
- =
- match v.value with
- | RvConcrete cv -> Print.Values.constant_value_to_string cv
- | RvPlace p -> place_to_string fmt p
- | RvAdt av ->
- adt_g_value_to_string
- (ast_to_value_formatter fmt)
- (typed_rvalue_to_string fmt)
- av.variant_id av.field_values v.ty
-
let var_or_dummy_to_string (fmt : ast_formatter) (v : var_or_dummy) : string =
match v with
| Var (v, None) -> var_to_string (ast_to_type_formatter fmt) v
@@ -411,33 +451,19 @@ let fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string =
| Binop (binop, int_ty) ->
binop_to_string binop ^ "<" ^ integer_type_to_string int_ty ^ ">"
-let meta_to_string (fmt : ast_formatter) (meta : meta) : string =
- let meta =
- match meta with
- | Assignment (lp, rv, rp) ->
- let rp =
- match rp with
- | None -> ""
- | Some rp -> " [@src=" ^ mplace_to_string fmt rp ^ "]"
- in
- "@assign(" ^ mplace_to_string fmt lp ^ " := "
- ^ typed_rvalue_to_string fmt rv
- ^ rp ^ ")"
- in
- "@meta[" ^ meta ^ "]"
-
(** [inside]: controls the introduction of parentheses *)
let rec texpression_to_string (fmt : ast_formatter) (inside : bool)
(indent : string) (indent_incr : string) (e : texpression) : string =
match e.e with
- | Value (v, mp) ->
+ | Local (var_id, mp) ->
let mp =
match mp with
| None -> ""
| Some mp -> " [@mplace=" ^ mplace_to_string fmt mp ^ "]"
in
- let e = typed_rvalue_to_string fmt v ^ mp in
- if inside then "(" ^ e ^ ")" else e
+ let s = fmt.var_id_to_string var_id ^ mp in
+ if inside then "(" ^ s ^ ")" else s
+ | Const cv -> Print.Values.constant_value_to_string cv
| App _ ->
(* Recursively destruct the app, to have a pair (app, arguments list) *)
let app, args = destruct_apps e in
@@ -447,8 +473,8 @@ let rec texpression_to_string (fmt : ast_formatter) (inside : bool)
let xl, e = destruct_abs_list e in
let e = abs_to_string fmt indent indent_incr xl e in
if inside then "(" ^ e ^ ")" else e
- | Func _ ->
- (* Func without arguments *)
+ | Qualif _ ->
+ (* Qualifier without arguments *)
app_to_string fmt inside indent indent_incr e []
| Let (monadic, lv, re, e) ->
let e = let_to_string fmt indent indent_incr monadic lv re e in
@@ -466,18 +492,36 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string)
(indent_incr : string) (app : texpression) (args : texpression list) :
string =
(* There are two possibilities: either the `app` is an instantiated,
- * top-level function, or it is a "regular" expression *)
+ * top-level qualifier (function, ADT constructore...), or it is a "regular"
+ * expression *)
let app, tys =
match app.e with
- | Func func ->
- (* Function case *)
- (* Convert the function identifier *)
- let fun_id = fun_id_to_string fmt func.func in
+ | Qualif qualif ->
+ (* Qualifier case *)
+ (* Convert the qualifier identifier *)
+ let qualif_s =
+ match qualif.id with
+ | Func fun_id -> fun_id_to_string fmt fun_id
+ | AdtCons adt_cons_id ->
+ let variant_s =
+ adt_variant_to_string
+ (ast_to_value_formatter fmt)
+ adt_cons_id.adt_id adt_cons_id.variant_id
+ in
+ ConstStrings.constructor_prefix ^ variant_s
+ | Proj (ProjField (adt_id, field_id)) ->
+ let value_fmt = ast_to_value_formatter fmt in
+ let adt_s = adt_variant_to_string value_fmt adt_id None in
+ let field_s = adt_field_to_string value_fmt adt_id field_id in
+ (* Adopting an F*-like syntax *)
+ ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s
+ | Proj (ProjTuple field_id) -> "MkTuple?." ^ string_of_int field_id
+ in
(* Convert the type instantiation *)
let ty_fmt = ast_to_type_formatter fmt in
- let tys = List.map (ty_to_string ty_fmt) func.type_params in
+ let tys = List.map (ty_to_string ty_fmt) qualif.type_params in
(* *)
- (fun_id, tys)
+ (qualif_s, tys)
| _ ->
(* "Regular" expression case *)
let inside = args <> [] || (args = [] && inside) in
@@ -540,6 +584,21 @@ and switch_to_string (fmt : ast_formatter) (indent : string)
let branches = List.map branch_to_string branches in
"match " ^ scrut ^ " with\n" ^ String.concat "\n" branches
+and meta_to_string (fmt : ast_formatter) (meta : meta) : string =
+ let meta =
+ match meta with
+ | Assignment (lp, rv, rp) ->
+ let rp =
+ match rp with
+ | None -> ""
+ | Some rp -> " [@src=" ^ mplace_to_string fmt rp ^ "]"
+ in
+ "@assign(" ^ mplace_to_string fmt lp ^ " := "
+ ^ texpression_to_string fmt false "" "" rv
+ ^ rp ^ ")"
+ in
+ "@meta[" ^ meta ^ "]"
+
let fun_decl_to_string (fmt : ast_formatter) (def : fun_decl) : string =
let type_fmt = ast_to_type_formatter fmt in
let name = fun_name_to_string def.basename ^ fun_suffix def.back_id in
diff --git a/src/Pure.ml b/src/Pure.ml
index cd28b035..c72f9dd0 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -149,7 +149,9 @@ type var = {
(* TODO: we might want to redefine field_proj_kind here, to prevent field accesses
* on enumerations.
- * Also: tuples... *)
+ * Also: tuples...
+ * Rmk: projections are actually only used as meta-data.
+ * *)
type projection_elem = { pkind : E.field_proj_kind; field_id : FieldId.id }
[@@deriving show]
@@ -168,9 +170,6 @@ type mplace = {
we introduce.
*)
-(* TODO: there shouldn't be places *)
-type place = { var : VarId.id; projection : projection } [@@deriving show]
-
type variant_id = VariantId.id [@@deriving show]
(** Ancestor for [iter_var_or_dummy] visitor *)
@@ -182,8 +181,6 @@ class ['self] iter_value_base =
method visit_var : 'env -> var -> unit = fun _ _ -> ()
- method visit_place : 'env -> place -> unit = fun _ _ -> ()
-
method visit_mplace : 'env -> mplace -> unit = fun _ _ -> ()
method visit_ty : 'env -> ty -> unit = fun _ _ -> ()
@@ -201,8 +198,6 @@ class ['self] map_value_base =
method visit_var : 'env -> var -> var = fun _ x -> x
- method visit_place : 'env -> place -> place = fun _ x -> x
-
method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x
method visit_ty : 'env -> ty -> ty = fun _ x -> x
@@ -220,8 +215,6 @@ class virtual ['self] reduce_value_base =
method visit_var : 'env -> var -> 'a = fun _ _ -> self#zero
- method visit_place : 'env -> place -> 'a = fun _ _ -> self#zero
-
method visit_mplace : 'env -> mplace -> 'a = fun _ _ -> self#zero
method visit_ty : 'env -> ty -> 'a = fun _ _ -> self#zero
@@ -240,8 +233,6 @@ class virtual ['self] mapreduce_value_base =
method visit_var : 'env -> var -> var * 'a = fun _ x -> (x, self#zero)
- method visit_place : 'env -> place -> place * 'a = fun _ x -> (x, self#zero)
-
method visit_mplace : 'env -> mplace -> mplace * 'a =
fun _ x -> (x, self#zero)
@@ -251,57 +242,6 @@ class virtual ['self] mapreduce_value_base =
fun _ x -> (x, self#zero)
end
-(* TODO: merge with expressions *)
-type rvalue =
- | RvConcrete of constant_value
- | RvPlace of place (* TODO: field projectors should be expressions *)
- | RvAdt of adt_rvalue
-
-and adt_rvalue = {
- variant_id : variant_id option;
- (* TODO: variant constructors should be expressions, treated in a manner
- * similar to functions *)
- field_values : typed_rvalue list;
-}
-
-and typed_rvalue = { value : rvalue; ty : ty }
-[@@deriving
- show,
- visitors
- {
- name = "iter_typed_rvalue";
- variety = "iter";
- ancestors = [ "iter_value_base" ];
- nude = true (* Don't inherit [VisitorsRuntime.iter] *);
- concrete = true;
- polymorphic = false;
- },
- visitors
- {
- name = "map_typed_rvalue";
- variety = "map";
- ancestors = [ "map_value_base" ];
- nude = true (* Don't inherit [VisitorsRuntime.iter] *);
- concrete = true;
- polymorphic = false;
- },
- visitors
- {
- name = "reduce_typed_rvalue";
- variety = "reduce";
- ancestors = [ "reduce_value_base" ];
- nude = true (* Don't inherit [VisitorsRuntime.iter] *);
- polymorphic = false;
- },
- visitors
- {
- name = "mapreduce_typed_rvalue";
- variety = "mapreduce";
- ancestors = [ "mapreduce_value_base" ];
- nude = true (* Don't inherit [VisitorsRuntime.iter] *);
- polymorphic = false;
- }]
-
type var_or_dummy =
| Var of var * mplace option
(** Rk.: the mdplace is actually always a variable (i.e.: there are no projections).
@@ -316,7 +256,7 @@ type var_or_dummy =
{
name = "iter_var_or_dummy";
variety = "iter";
- ancestors = [ "iter_typed_rvalue" ];
+ ancestors = [ "iter_value_base" ];
nude = true (* Don't inherit [VisitorsRuntime.iter] *);
concrete = true;
polymorphic = false;
@@ -325,7 +265,7 @@ type var_or_dummy =
{
name = "map_var_or_dummy";
variety = "map";
- ancestors = [ "map_typed_rvalue" ];
+ ancestors = [ "map_value_base" ];
nude = true (* Don't inherit [VisitorsRuntime.map] *);
concrete = true;
polymorphic = false;
@@ -334,7 +274,7 @@ type var_or_dummy =
{
name = "reduce_var_or_dummy";
variety = "reduce";
- ancestors = [ "reduce_typed_rvalue" ];
+ ancestors = [ "reduce_value_base" ];
nude = true (* Don't inherit [VisitorsRuntime.reduce] *);
polymorphic = false;
},
@@ -342,7 +282,7 @@ type var_or_dummy =
{
name = "mapreduce_var_or_dummy";
variety = "mapreduce";
- ancestors = [ "mapreduce_typed_rvalue" ];
+ ancestors = [ "mapreduce_value_base" ];
nude = true (* Don't inherit [VisitorsRuntime.reduce] *);
polymorphic = false;
}]
@@ -419,30 +359,43 @@ type fun_id =
| Binop of E.binop * integer_type
[@@deriving show, ord]
-(** Meta-information stored in the AST *)
-type meta = Assignment of mplace * typed_rvalue * mplace option
+type adt_cons_id = { adt_id : type_id; variant_id : variant_id option }
+[@@deriving show]
+(** An identifier for an ADT constructor *)
+
+type proj_kind = ProjField of type_id * FieldId.id | ProjTuple of int
[@@deriving show]
-type func = { func : fun_id; type_params : ty list } [@@deriving show]
-(** A function.
+type qualif_id =
+ | Func of fun_id
+ | AdtCons of adt_cons_id (** A function or ADT constructor identifier *)
+ | Proj of proj_kind (** Field projector *)
+[@@deriving show]
+
+type qualif = {
+ id : qualif_id;
+ type_params : ty list; (* TODO: rename to type_args *)
+}
+[@@deriving show]
+(** An instantiated qualified.
Note that for now we have a clear separation between types and expressions,
- which explains why we have the `type_params` field: a function is always
- fully instantiated.
+ which explains why we have the `type_params` field: a function or ADT
+ constructor is always fully instantiated.
*)
+type var_id = VarId.id [@@deriving show]
+
(** Ancestor for [iter_expression] visitor *)
class ['self] iter_expression_base =
object (_self : 'self)
inherit [_] iter_typed_lvalue
- method visit_meta : 'env -> meta -> unit = fun _ _ -> ()
-
method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> ()
- method visit_scalar_value : 'env -> scalar_value -> unit = fun _ _ -> ()
+ method visit_var_id : 'env -> var_id -> unit = fun _ _ -> ()
- method visit_func : 'env -> func -> unit = fun _ _ -> ()
+ method visit_qualif : 'env -> qualif -> unit = fun _ _ -> ()
end
(** Ancestor for [map_expression] visitor *)
@@ -450,15 +403,12 @@ class ['self] map_expression_base =
object (_self : 'self)
inherit [_] map_typed_lvalue
- method visit_meta : 'env -> meta -> meta = fun _ x -> x
-
method visit_integer_type : 'env -> integer_type -> integer_type =
fun _ x -> x
- method visit_scalar_value : 'env -> scalar_value -> scalar_value =
- fun _ x -> x
+ method visit_var_id : 'env -> var_id -> var_id = fun _ x -> x
- method visit_func : 'env -> func -> func = fun _ x -> x
+ method visit_qualif : 'env -> qualif -> qualif = fun _ x -> x
end
(** Ancestor for [reduce_expression] visitor *)
@@ -466,15 +416,12 @@ class virtual ['self] reduce_expression_base =
object (self : 'self)
inherit [_] reduce_typed_lvalue
- method visit_meta : 'env -> meta -> 'a = fun _ _ -> self#zero
-
method visit_integer_type : 'env -> integer_type -> 'a =
fun _ _ -> self#zero
- method visit_scalar_value : 'env -> scalar_value -> 'a =
- fun _ _ -> self#zero
+ method visit_var_id : 'env -> var_id -> 'a = fun _ _ -> self#zero
- method visit_func : 'env -> func -> 'a = fun _ _ -> self#zero
+ method visit_qualif : 'env -> qualif -> 'a = fun _ _ -> self#zero
end
(** Ancestor for [mapreduce_expression] visitor *)
@@ -482,15 +429,14 @@ class virtual ['self] mapreduce_expression_base =
object (self : 'self)
inherit [_] mapreduce_typed_lvalue
- method visit_meta : 'env -> meta -> meta * 'a = fun _ x -> (x, self#zero)
-
method visit_integer_type : 'env -> integer_type -> integer_type * 'a =
fun _ x -> (x, self#zero)
- method visit_scalar_value : 'env -> scalar_value -> scalar_value * 'a =
+ method visit_var_id : 'env -> var_id -> var_id * 'a =
fun _ x -> (x, self#zero)
- method visit_func : 'env -> func -> func * 'a = fun _ x -> (x, self#zero)
+ method visit_qualif : 'env -> qualif -> qualif * 'a =
+ fun _ x -> (x, self#zero)
end
(** **Rk.:** here, [expression] is not at all equivalent to the expressions
@@ -498,7 +444,12 @@ class virtual ['self] mapreduce_expression_base =
more general than the LLBC statements, in a sense.
*)
type expression =
- | Value of typed_rvalue * mplace option
+ | Local of var_id * mplace option
+ (** Local variable - TODO: we name it "Local" because "Var" is used
+ in [var_or_dummy]: change the name. Maybe rename `Var` and `Dummy`
+ in `var_or_dummy` to `PatVar` and `PatDummy`.
+ *)
+ | Const of constant_value
| App of texpression * texpression
(** Application of a function to an argument.
@@ -509,7 +460,7 @@ type expression =
are clashes of field names, some provers like F* get pretty bad...)
*)
| Abs of typed_lvalue * texpression (** Lambda abstraction: `fun x -> e` *)
- | Func of func (** A function - TODO: change to Qualifier *)
+ | Qualif of qualif (** A top-level qualifier *)
| Let of bool * typed_lvalue * texpression * texpression
(** Let binding.
@@ -550,13 +501,25 @@ type expression =
```
*)
| Switch of texpression * switch_body
- | Meta of meta * texpression (** Meta-information *)
+ | Meta of (meta[@opaque]) * texpression (** Meta-information *)
and switch_body = If of texpression * texpression | Match of match_branch list
and match_branch = { pat : typed_lvalue; branch : texpression }
and texpression = { e : expression; ty : ty }
+
+and mvalue = (texpression[@opaque])
+(** Meta-value (converted to an expression) *)
+
+and meta =
+ | Assignment of mplace * mvalue * mplace option
+ (** Meta-information stored in the AST.
+
+ The first mplace stores the destination.
+ The mvalue stores the value which is put in the destination
+ The second (optional) mplace stores the origin.
+ *)
[@@deriving
show,
visitors
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 9ddc71ab..3c25e7b6 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -1160,7 +1160,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
in
let fresh_state_var () =
let id = fresh_var_id () in
- { id; basename = Some "st"; ty = mk_state_ty }
+ { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty }
in
(* It is a very simple map *)
let obj =
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index 873931be..c01dd5c9 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -43,96 +43,7 @@ let binop_can_fail (binop : E.binop) : bool =
| Div | Rem | Add | Sub | Mul -> true
| Shl | Shr -> raise Errors.Unimplemented
-let mk_place_from_var (v : var) : place = { var = v.id; projection = [] }
-
-(** Make a "simplified" tuple type from a list of types:
- - if there is exactly one type, just return it
- - if there is > one type: wrap them in a tuple
- *)
-let mk_simpl_tuple_ty (tys : ty list) : ty =
- match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys)
-
-let unit_ty : ty = Adt (Tuple, [])
-
-let unit_rvalue : typed_rvalue =
- let value = RvAdt { variant_id = None; field_values = [] } in
- let ty = unit_ty in
- { value; ty }
-
-let mk_typed_rvalue_from_var (v : var) : typed_rvalue =
- let value = RvPlace (mk_place_from_var v) in
- let ty = v.ty in
- { value; ty }
-
-let mk_typed_lvalue_from_var (v : var) (mp : mplace option) : typed_lvalue =
- let value = LvVar (Var (v, mp)) in
- let ty = v.ty in
- { value; ty }
-
-(** Make a "simplified" tuple value from a list of values:
- - if there is exactly one value, just return it
- - if there is > one value: wrap them in a tuple
- *)
-let mk_simpl_tuple_lvalue (vl : typed_lvalue list) : typed_lvalue =
- match vl with
- | [ v ] -> v
- | _ ->
- let tys = List.map (fun (v : typed_lvalue) -> v.ty) vl in
- let ty = Adt (Tuple, tys) in
- let value = LvAdt { variant_id = None; field_values = vl } in
- { value; ty }
-
-(** Similar to [mk_simpl_tuple_lvalue] *)
-let mk_simpl_tuple_rvalue (vl : typed_rvalue list) : typed_rvalue =
- match vl with
- | [ v ] -> v
- | _ ->
- let tys = List.map (fun (v : typed_rvalue) -> v.ty) vl in
- let ty = Adt (Tuple, tys) in
- let value = RvAdt { variant_id = None; field_values = vl } in
- { value; ty }
-
-let mk_adt_lvalue (adt_ty : ty) (variant_id : VariantId.id)
- (vl : typed_lvalue list) : typed_lvalue =
- let value = LvAdt { variant_id = Some variant_id; field_values = vl } in
- { value; ty = adt_ty }
-
-let ty_as_integer (t : ty) : T.integer_type =
- match t with Integer int_ty -> int_ty | _ -> raise (Failure "Unreachable")
-
-(* TODO: move *)
-let type_decl_is_enum (def : T.type_decl) : bool =
- match def.kind with T.Struct _ -> false | Enum _ -> true | Opaque -> false
-
-let mk_state_ty : ty = Adt (Assumed State, [])
-
-let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ])
-
-let mk_result_fail_rvalue (ty : ty) : typed_rvalue =
- let ty = Adt (Assumed Result, [ ty ]) in
- let value = RvAdt { variant_id = Some result_fail_id; field_values = [] } in
- { value; ty }
-
-let mk_result_return_rvalue (v : typed_rvalue) : typed_rvalue =
- let ty = Adt (Assumed Result, [ v.ty ]) in
- let value =
- RvAdt { variant_id = Some result_return_id; field_values = [ v ] }
- in
- { value; ty }
-
-let mk_result_fail_lvalue (ty : ty) : typed_lvalue =
- let ty = Adt (Assumed Result, [ ty ]) in
- let value = LvAdt { variant_id = Some result_fail_id; field_values = [] } in
- { value; ty }
-
-let mk_result_return_lvalue (v : typed_lvalue) : typed_lvalue =
- let ty = Adt (Assumed Result, [ v.ty ]) in
- let value =
- LvAdt { variant_id = Some result_return_id; field_values = [ v ] }
- in
- { value; ty }
-
-let mk_arrow_ty (arg_ty : ty) (ret_ty : ty) : ty = Arrow (arg_ty, ret_ty)
+(*let mk_arrow_ty (arg_ty : ty) (ret_ty : ty) : ty = Arrow (arg_ty, ret_ty)*)
let dest_arrow_ty (ty : ty) : ty * ty =
match ty with
@@ -150,10 +61,10 @@ let mk_typed_lvalue_from_constant_value (cv : constant_value) : typed_lvalue =
let ty = compute_constant_value_ty cv in
{ value = LvConcrete cv; ty }
-let mk_value_expression (v : typed_rvalue) (mp : mplace option) : texpression =
- let e = Value (v, mp) in
- let ty = v.ty in
- { e; ty }
+(*let mk_value_expression (v : typed_rvalue) (mp : mplace option) : texpression =
+ let e = Value (v, mp) in
+ let ty = v.ty in
+ { e; ty }*)
let mk_let (monadic : bool) (lv : typed_lvalue) (re : texpression)
(next_e : texpression) : texpression =
@@ -244,9 +155,12 @@ let functions_not_mutually_recursive (funs : fun_decl list) : bool =
object
inherit [_] iter_expression as super
- method! visit_func env func =
- if FunIdSet.mem func.func !ids then raise Utils.Found
- else super#visit_func env func
+ method! visit_qualif env qualif =
+ match qualif.id with
+ | Func fun_id ->
+ if FunIdSet.mem fun_id !ids then raise Utils.Found
+ else super#visit_qualif env qualif
+ | _ -> super#visit_qualif env qualif
end
in
@@ -266,126 +180,24 @@ let functions_not_mutually_recursive (funs : fun_decl list) : bool =
*)
let rec let_group_requires_parentheses (e : texpression) : bool =
match e.e with
- | Value _ | App _ | Func _ | Abs _ -> false
+ | Local _ | Const _ | App _ | Abs _ | Qualif _ -> false
| Let (monadic, _, _, next_e) ->
if monadic then true else let_group_requires_parentheses next_e
| Switch (_, _) -> false
| Meta (_, next_e) -> let_group_requires_parentheses next_e
-(** Module to perform type checking - we use this for sanity checks only *)
-module TypeCheck = struct
- type tc_ctx = { type_decls : type_decl TypeDeclId.Map.t }
-
- let check_constant_value (ty : ty) (v : constant_value) : unit =
- match (ty, v) with
- | Integer int_ty, V.Scalar sv -> assert (int_ty = sv.V.int_ty)
- | Bool, Bool _ | Char, Char _ | Str, String _ -> ()
- | _ -> raise (Failure "Inconsistent type")
-
- let check_adt_g_value (ctx : tc_ctx) (check_value : ty -> 'v -> unit)
- (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) :
- unit =
- (* Retrieve the field types *)
- let field_tys =
- match ty with
- | Adt (Tuple, tys) ->
- (* Tuple *)
- tys
- | Adt (AdtId def_id, tys) ->
- (* "Regular" ADT *)
- let def = TypeDeclId.Map.find def_id ctx.type_decls in
- type_decl_get_instantiated_fields_types def variant_id tys
- | Adt (Assumed aty, tys) -> (
- (* Assumed type *)
- match aty with
- | State ->
- (* `State` is opaque *)
- raise (Failure "Unreachable: `State` values are opaque")
- | Result ->
- let ty = Collections.List.to_cons_nil tys in
- let variant_id = Option.get variant_id in
- if variant_id = result_return_id then [ ty ]
- else if variant_id = result_fail_id then []
- else
- raise
- (Failure "Unreachable: improper variant id for result type")
- | Option ->
- let ty = Collections.List.to_cons_nil tys in
- let variant_id = Option.get variant_id in
- if variant_id = option_some_id then [ ty ]
- else if variant_id = option_none_id then []
- else
- raise
- (Failure "Unreachable: improper variant id for result type")
- | Vec ->
- assert (variant_id = None);
- let ty = Collections.List.to_cons_nil tys in
- List.map (fun _ -> ty) field_values)
- | _ -> raise (Failure "Inconsistently typed value")
- in
- (* Check that the field values have the expected types *)
- List.iter
- (fun (ty, v) -> check_value ty v)
- (List.combine field_tys field_values)
-
- let rec check_typed_lvalue (ctx : tc_ctx) (v : typed_lvalue) : unit =
- log#ldebug (lazy ("check_typed_lvalue: " ^ show_typed_lvalue v));
- match v.value with
- | LvConcrete cv -> check_constant_value v.ty cv
- | LvVar _ -> ()
- | LvAdt av ->
- check_adt_g_value ctx
- (fun ty (v : typed_lvalue) ->
- if ty <> v.ty then (
- log#serror
- ("check_typed_lvalue: not the same types:" ^ "\n- ty: "
- ^ show_ty ty ^ "\n- v.ty: " ^ show_ty v.ty);
- raise (Failure "Inconsistent types"));
- check_typed_lvalue ctx v)
- av.variant_id av.field_values v.ty
-
- let rec check_typed_rvalue (ctx : tc_ctx) (v : typed_rvalue) : unit =
- log#ldebug (lazy ("check_typed_rvalue: " ^ show_typed_rvalue v));
- match v.value with
- | RvConcrete cv -> check_constant_value v.ty cv
- | RvPlace _ ->
- (* TODO: *)
- ()
- | RvAdt av ->
- check_adt_g_value ctx
- (fun ty (v : typed_rvalue) ->
- if ty <> v.ty then (
- log#serror
- ("check_typed_rvalue: not the same types:" ^ "\n- ty: "
- ^ show_ty ty ^ "\n- v.ty: " ^ show_ty v.ty);
- raise (Failure "Inconsistent types"));
- check_typed_rvalue ctx v)
- av.variant_id av.field_values v.ty
-end
-
-let is_value (e : texpression) : bool =
- match e.e with Value _ -> true | _ -> false
-
let is_var (e : texpression) : bool =
- match e.e with
- | Value (v, _) -> (
- match v.value with
- | RvPlace { var = _; projection = [] } -> true
- | _ -> false)
- | _ -> false
+ match e.e with Local _ -> true | _ -> false
let as_var (e : texpression) : VarId.id =
- match e.e with
- | Value (v, _) -> (
- match v.value with
- | RvPlace { var; projection = [] } -> var
- | _ -> raise (Failure "Unreachable"))
- | _ -> raise (Failure "Unreachable")
+ match e.e with Local (v, _) -> v | _ -> raise (Failure "Unreachable")
(** Remove the external occurrences of [Meta] *)
let rec unmeta (e : texpression) : texpression =
match e.e with Meta (_, e) -> unmeta e | _ -> e
+let mk_arrow (ty0 : ty) (ty1 : ty) : ty = Arrow (ty0, ty1)
+
(** Construct a type as a list of arrows: ty1 -> ... tyn *)
let mk_arrows (inputs : ty list) (output : ty) =
let rec aux (tys : ty list) : ty =
@@ -419,16 +231,16 @@ let mk_app (app : texpression) (arg : texpression) : texpression =
let mk_apps (app : texpression) (args : texpression list) : texpression =
List.fold_left (fun app arg -> mk_app app arg) app args
-(** Destruct an expression into a function identifier and a list of arguments,
+(** Destruct an expression into a qualif identifier and a list of arguments,
* if possible *)
-let opt_destruct_function_call (e : texpression) :
- (func * texpression list) option =
+let opt_destruct_qualif_app (e : texpression) :
+ (qualif * texpression list) option =
let app, args = destruct_apps e in
- match app.e with Func func -> Some (func, args) | _ -> None
+ match app.e with Qualif qualif -> Some (qualif, args) | _ -> None
(** Destruct an expression into a function identifier and a list of arguments *)
-let destruct_function_call (e : texpression) : func * texpression list =
- Option.get (opt_destruct_function_call e)
+let destruct_qualif_app (e : texpression) : qualif * texpression list =
+ Option.get (opt_destruct_qualif_app e)
let opt_destruct_result (ty : ty) : ty option =
match ty with
@@ -440,26 +252,6 @@ let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty)
let opt_destruct_tuple (ty : ty) : ty list option =
match ty with Adt (Tuple, tys) -> Some tys | _ -> None
-let opt_destruct_state_monad_result (ty : ty) : ty option =
- (* Checking:
- * ty == state -> result (state & _) ? *)
- match ty with
- | Arrow (ty0, ty1) ->
- (* ty == ty0 -> ty1
- * Checking: ty0 == state ?
- * *)
- if ty0 = mk_state_ty then
- (* Checking: ty1 == result (state & _) *)
- match opt_destruct_result ty1 with
- | None -> None
- | Some ty2 -> (
- (* Checking: ty2 == state & _ *)
- match opt_destruct_tuple ty2 with
- | Some [ ty3; ty4 ] -> if ty3 = mk_state_ty then Some ty4 else None
- | _ -> None)
- else None
- | _ -> None
-
let mk_abs (x : typed_lvalue) (e : texpression) : texpression =
let ty = Arrow (x.ty, e.ty) in
let e = Abs (x, e) in
@@ -475,9 +267,14 @@ let rec destruct_abs_list (e : texpression) : typed_lvalue list * texpression =
let destruct_arrow (ty : ty) : ty * ty =
match ty with
| Arrow (ty0, ty1) -> (ty0, ty1)
- | _ -> raise (Failure "Unreachable")
+ | _ -> raise (Failure "Not an arrow type")
-let mk_arrow (ty0 : ty) (ty1 : ty) : ty = Arrow (ty0, ty1)
+let rec destruct_arrows (ty : ty) : ty list * ty =
+ match ty with
+ | Arrow (ty0, ty1) ->
+ let tys, out_ty = destruct_arrows ty1 in
+ (ty0 :: tys, out_ty)
+ | _ -> ([], ty)
let get_switch_body_ty (sb : switch_body) : ty =
match sb with
@@ -512,3 +309,306 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
(* Put together *)
let e = Switch (scrut, sb) in
{ e; ty }
+
+(** Make a "simplified" tuple type from a list of types:
+ - if there is exactly one type, just return it
+ - if there is > one type: wrap them in a tuple
+ *)
+let mk_simpl_tuple_ty (tys : ty list) : ty =
+ match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys)
+
+(** TODO: rename to "mk_..." *)
+let unit_ty : ty = Adt (Tuple, [])
+
+(** TODO: rename to "mk_..." *)
+let unit_rvalue : texpression =
+ let id = AdtCons { adt_id = Tuple; variant_id = None } in
+ let qualif = { id; type_params = [] } in
+ let e = Qualif qualif in
+ let ty = unit_ty in
+ { e; ty }
+
+let mk_texpression_from_var (v : var) (mp : mplace option) : texpression =
+ let e = Local (v.id, mp) in
+ let ty = v.ty in
+ { e; ty }
+
+let mk_typed_lvalue_from_var (v : var) (mp : mplace option) : typed_lvalue =
+ let value = LvVar (Var (v, mp)) in
+ let ty = v.ty in
+ { value; ty }
+
+(** Make a "simplified" tuple value from a list of values:
+ - if there is exactly one value, just return it
+ - if there is > one value: wrap them in a tuple
+ *)
+let mk_simpl_tuple_lvalue (vl : typed_lvalue list) : typed_lvalue =
+ match vl with
+ | [ v ] -> v
+ | _ ->
+ let tys = List.map (fun (v : typed_lvalue) -> v.ty) vl in
+ let ty = Adt (Tuple, tys) in
+ let value = LvAdt { variant_id = None; field_values = vl } in
+ { value; ty }
+
+(** Similar to [mk_simpl_tuple_lvalue] *)
+let mk_simpl_tuple_texpression (vl : texpression list) : texpression =
+ match vl with
+ | [ v ] -> v
+ | _ ->
+ (* Compute the types of the fields, and the type of the tuple constructor *)
+ let tys = List.map (fun (v : texpression) -> v.ty) vl in
+ let ty = Adt (Tuple, tys) in
+ let ty = mk_arrows tys ty in
+ (* Construct the tuple constructor qualifier *)
+ let id = AdtCons { adt_id = Tuple; variant_id = None } in
+ let qualif = { id; type_params = tys } in
+ (* Put everything together *)
+ let cons = { e = Qualif qualif; ty } in
+ mk_apps cons vl
+
+let mk_adt_lvalue (adt_ty : ty) (variant_id : VariantId.id)
+ (vl : typed_lvalue list) : typed_lvalue =
+ let value = LvAdt { variant_id = Some variant_id; field_values = vl } in
+ { value; ty = adt_ty }
+
+let ty_as_integer (t : ty) : T.integer_type =
+ match t with Integer int_ty -> int_ty | _ -> raise (Failure "Unreachable")
+
+(* TODO: move *)
+let type_decl_is_enum (def : T.type_decl) : bool =
+ match def.kind with T.Struct _ -> false | Enum _ -> true | Opaque -> false
+
+let mk_state_ty : ty = Adt (Assumed State, [])
+
+let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ])
+
+let mk_result_fail_rvalue (ty : ty) : texpression =
+ let type_args = [ ty ] in
+ let ty = Adt (Assumed Result, type_args) in
+ let id =
+ AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id }
+ in
+ let qualif = { id; type_params = type_args } in
+ let cons_e = Qualif qualif in
+ let cons_ty = ty in
+ let cons = { e = cons_e; ty = cons_ty } in
+ cons
+
+let mk_result_return_rvalue (v : texpression) : texpression =
+ let type_args = [ v.ty ] in
+ let ty = Adt (Assumed Result, type_args) in
+ let id =
+ AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id }
+ in
+ let qualif = { id; type_params = type_args } in
+ let cons_e = Qualif qualif in
+ let cons_ty = mk_arrow v.ty ty in
+ let cons = { e = cons_e; ty = cons_ty } in
+ mk_app cons v
+
+let mk_result_fail_lvalue (ty : ty) : typed_lvalue =
+ let ty = Adt (Assumed Result, [ ty ]) in
+ let value = LvAdt { variant_id = Some result_fail_id; field_values = [] } in
+ { value; ty }
+
+let mk_result_return_lvalue (v : typed_lvalue) : typed_lvalue =
+ let ty = Adt (Assumed Result, [ v.ty ]) in
+ let value =
+ LvAdt { variant_id = Some result_return_id; field_values = [ v ] }
+ in
+ { value; ty }
+
+let opt_destruct_state_monad_result (ty : ty) : ty option =
+ (* Checking:
+ * ty == state -> result (state & _) ? *)
+ match ty with
+ | Arrow (ty0, ty1) ->
+ (* ty == ty0 -> ty1
+ * Checking: ty0 == state ?
+ * *)
+ if ty0 = mk_state_ty then
+ (* Checking: ty1 == result (state & _) *)
+ match opt_destruct_result ty1 with
+ | None -> None
+ | Some ty2 -> (
+ (* Checking: ty2 == state & _ *)
+ match opt_destruct_tuple ty2 with
+ | Some [ ty3; ty4 ] -> if ty3 = mk_state_ty then Some ty4 else None
+ | _ -> None)
+ else None
+ | _ -> None
+
+(** Utility function, used for type checking - TODO: move *)
+let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
+ (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) :
+ ty list =
+ match type_id with
+ | Tuple ->
+ (* Tuple *)
+ assert (variant_id = None);
+ tys
+ | AdtId def_id ->
+ (* "Regular" ADT *)
+ let def = TypeDeclId.Map.find def_id type_decls in
+ type_decl_get_instantiated_fields_types def variant_id tys
+ | Assumed aty -> (
+ (* Assumed type *)
+ match aty with
+ | State ->
+ (* `State` is opaque *)
+ raise (Failure "Unreachable: `State` values are opaque")
+ | Result ->
+ let ty = Collections.List.to_cons_nil tys in
+ let variant_id = Option.get variant_id in
+ if variant_id = result_return_id then [ ty ]
+ else if variant_id = result_fail_id then []
+ else
+ raise (Failure "Unreachable: improper variant id for result type")
+ | Option ->
+ let ty = Collections.List.to_cons_nil tys in
+ let variant_id = Option.get variant_id in
+ if variant_id = option_some_id then [ ty ]
+ else if variant_id = option_none_id then []
+ else
+ raise (Failure "Unreachable: improper variant id for result type")
+ | Vec -> raise (Failure "Unreachable: `Vector` values are opaque"))
+
+(** Module to perform type checking - we use this for sanity checks only
+
+ TODO: move to a special file (so that we can also use PrintPure for
+ debugging)
+ *)
+module TypeCheck = struct
+ type tc_ctx = {
+ type_decls : type_decl TypeDeclId.Map.t; (** The type declarations *)
+ env : ty VarId.Map.t; (** Environment from variables to types *)
+ }
+
+ let check_constant_value (v : constant_value) (ty : ty) : unit =
+ match (ty, v) with
+ | Integer int_ty, V.Scalar sv -> assert (int_ty = sv.V.int_ty)
+ | Bool, Bool _ | Char, Char _ | Str, String _ -> ()
+ | _ -> raise (Failure "Inconsistent type")
+
+ let rec check_typed_lvalue (ctx : tc_ctx) (v : typed_lvalue) : tc_ctx =
+ log#ldebug (lazy ("check_typed_lvalue: " ^ show_typed_lvalue v));
+ match v.value with
+ | LvConcrete cv ->
+ check_constant_value cv v.ty;
+ ctx
+ | LvVar Dummy -> ctx
+ | LvVar (Var (var, _)) ->
+ assert (var.ty = v.ty);
+ let env = VarId.Map.add var.id var.ty ctx.env in
+ { ctx with env }
+ | LvAdt av ->
+ (* Compute the field types *)
+ let type_id, tys =
+ match v.ty with
+ | Adt (type_id, tys) -> (type_id, tys)
+ | _ -> raise (Failure "Inconsistently typed value")
+ in
+ let field_tys =
+ get_adt_field_types ctx.type_decls type_id av.variant_id tys
+ in
+ let check_value (ctx : tc_ctx) (ty : ty) (v : typed_lvalue) : tc_ctx =
+ if ty <> v.ty then (
+ log#serror
+ ("check_typed_lvalue: not the same types:" ^ "\n- ty: "
+ ^ show_ty ty ^ "\n- v.ty: " ^ show_ty v.ty);
+ raise (Failure "Inconsistent types"));
+ check_typed_lvalue ctx v
+ in
+ (* Check the field types - TODO: we might also want to check that the
+ * type of the applied constructor is correct *)
+ List.fold_left
+ (fun ctx (ty, v) -> check_value ctx ty v)
+ ctx
+ (List.combine field_tys av.field_values)
+
+ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
+ match e.e with
+ | Local (var_id, _) -> (
+ (* Lookup the variable - note that the variable may not be there,
+ * if we type-check a subexpression (i.e.: if the variable is introduced
+ * "outside" of the expression) - TODO: this won't happen once
+ * we use a locally nameless representation *)
+ match VarId.Map.find_opt var_id ctx.env with
+ | None -> ()
+ | Some ty -> assert (ty = e.ty))
+ | Const cv -> check_constant_value cv e.ty
+ | App (app, arg) ->
+ let input_ty, output_ty = destruct_arrow app.ty in
+ assert (input_ty = arg.ty);
+ assert (output_ty = e.ty);
+ check_texpression ctx app;
+ check_texpression ctx arg
+ | Abs (pat, body) ->
+ let pat_ty, body_ty = destruct_arrow e.ty in
+ assert (pat.ty = pat_ty);
+ assert (body.ty = body_ty);
+ (* Check the pattern and register the introduced variables at the same time *)
+ let ctx = check_typed_lvalue ctx pat in
+ check_texpression ctx body
+ | Qualif qualif -> (
+ match qualif.id with
+ | Func _ -> () (* TODO *)
+ | Proj (ProjField (type_id, field_id)) ->
+ (* Note we can only project fields of structurs (not enumerations) *)
+ let variant_id = None in
+ let expected_field_tys =
+ get_adt_field_types ctx.type_decls type_id variant_id
+ qualif.type_params
+ in
+ let expected_field_ty = FieldId.nth expected_field_tys field_id in
+ let _adt_ty, field_ty = destruct_arrow e.ty in
+ (* TODO: check the adt_ty *)
+ assert (expected_field_ty = field_ty)
+ | Proj (ProjTuple field_id) -> (
+ let tuple_ty, field_ty = destruct_arrow e.ty in
+ match tuple_ty with
+ | Adt (Tuple, tys) ->
+ let expected_field_ty = List.nth tys field_id in
+ assert (field_ty = expected_field_ty)
+ | _ -> raise (Failure "Inconsistently typed projector"))
+ | AdtCons id ->
+ (* TODO: we might also want to check the out type *)
+ let expected_field_tys =
+ get_adt_field_types ctx.type_decls id.adt_id id.variant_id
+ qualif.type_params
+ in
+ let field_tys, _ = destruct_arrows e.ty in
+ assert (expected_field_tys = field_tys))
+ | Let (monadic, pat, re, e_next) ->
+ let expected_pat_ty =
+ if monadic then destruct_result re.ty else re.ty
+ in
+ assert (pat.ty = expected_pat_ty);
+ assert (e.ty = e_next.ty);
+ (* Check the right-expression *)
+ check_texpression ctx re;
+ (* Check the pattern and register the introduced variables at the same time *)
+ let ctx = check_typed_lvalue ctx pat in
+ (* Check the next expression *)
+ check_texpression ctx e_next
+ | Switch (scrut, switch_body) -> (
+ check_texpression ctx scrut;
+ match switch_body with
+ | If (e_then, e_else) ->
+ assert (scrut.ty = Bool);
+ assert (e_then.ty = e.ty);
+ assert (e_else.ty = e.ty);
+ check_texpression ctx e_then;
+ check_texpression ctx e_else
+ | Match branches ->
+ let check_branch (br : match_branch) : unit =
+ assert (br.pat.ty = scrut.ty);
+ let ctx = check_typed_lvalue ctx br.pat in
+ check_texpression ctx br.branch
+ in
+ List.iter check_branch branches)
+ | Meta (_, e_next) ->
+ assert (e_next.ty = e.ty);
+ check_texpression ctx e_next
+end
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 64f8f481..6606ca25 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -999,7 +999,9 @@ and translate_panic (config : config) (ctx : bs_ctx) : texpression =
let ty = v.ty in
let e = { e; ty } in
(* Add the lambda *)
- let _, state_var = fresh_var (Some "st") mk_state_ty ctx in
+ let _, state_var =
+ fresh_var (Some ConstStrings.state_basename) mk_state_ty ctx
+ in
let state_lvalue = mk_typed_lvalue_from_var state_var None in
mk_abs state_lvalue e
else
@@ -1026,7 +1028,9 @@ and translate_return (config : config) (opt_v : V.typed_value option)
* *)
(* TODO: we should use a `return` function, it would be cleaner *)
if config.use_state_monad then
- let _, state_var = fresh_var (Some "st") mk_state_ty ctx in
+ let _, state_var =
+ fresh_var (Some ConstStrings.state_basename) mk_state_ty ctx
+ in
let state_rvalue = mk_typed_rvalue_from_var state_var in
let v =
mk_result_return_rvalue (mk_simpl_tuple_rvalue [ state_rvalue; v ])