diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/PrintPure.ml | 143 | ||||
-rw-r--r-- | src/Pure.ml | 151 | ||||
-rw-r--r-- | src/PureMicroPasses.ml | 2 | ||||
-rw-r--r-- | src/PureUtils.ml | 564 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 8 |
5 files changed, 497 insertions, 371 deletions
diff --git a/src/PrintPure.ml b/src/PrintPure.ml index 1a38e504..cf8c3f57 100644 --- a/src/PrintPure.ml +++ b/src/PrintPure.ml @@ -234,11 +234,62 @@ let mplace_to_string (fmt : ast_formatter) (p : mplace) : string = let name = name ^ "^" ^ V.VarId.to_string p.var_id ^ "llbc" in projection_to_string fmt name p.projection -let place_to_string (fmt : ast_formatter) (p : place) : string = - (* TODO: improve that *) - let var = fmt.var_id_to_string p.var in - projection_to_string fmt var p.projection +let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id) + (variant_id : VariantId.id option) : string = + match adt_id with + | Tuple -> "Tuple" + | AdtId def_id -> ( + (* "Regular" ADT *) + match variant_id with + | Some vid -> fmt.adt_variant_to_string def_id vid + | None -> fmt.type_decl_id_to_string def_id) + | Assumed aty -> ( + (* Assumed type *) + match aty with + | State -> + (* The `State` type is opaque: we can't get there *) + raise (Failure "Unreachable") + | Result -> + let variant_id = Option.get variant_id in + if variant_id = result_return_id then "@Result::Return" + else if variant_id = result_fail_id then "@Result::Fail" + else + raise (Failure "Unreachable: improper variant id for result type") + | Option -> + let variant_id = Option.get variant_id in + if variant_id = option_some_id then "@Option::Some " + else if variant_id = option_none_id then "@Option::None" + else + raise (Failure "Unreachable: improper variant id for result type") + | Vec -> + assert (variant_id = None); + "Vec") + +let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) + (field_id : FieldId.id) : string = + match adt_id with + | Tuple -> + raise (Failure "Unreachable") + (* Tuples don't use the opaque field id for the field indices, but `int` *) + | AdtId def_id -> ( + (* "Regular" ADT *) + let fields = fmt.adt_field_names def_id None in + match fields with + | None -> FieldId.to_string field_id + | Some fields -> FieldId.nth fields field_id) + | Assumed aty -> ( + (* Assumed type *) + match aty with + | State | Vec -> + (* Opaque types: we can't get there *) + raise (Failure "Unreachable") + | Result | Option -> + (* Enumerations: we can't get there *) + raise (Failure "Unreachable")) +(** TODO: we don't need a general function anymore (it is now only used for + lvalues (i.e., patterns) + *) let adt_g_value_to_string (fmt : value_formatter) (value_to_string : 'v -> string) (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : string = @@ -311,17 +362,6 @@ let adt_g_value_to_string (fmt : value_formatter) ^ "\n- ty: " ^ ty_to_string fmt ty ^ "\n- variant_id: " ^ Print.option_to_string VariantId.to_string variant_id)) -let rec typed_rvalue_to_string (fmt : ast_formatter) (v : typed_rvalue) : string - = - match v.value with - | RvConcrete cv -> Print.Values.constant_value_to_string cv - | RvPlace p -> place_to_string fmt p - | RvAdt av -> - adt_g_value_to_string - (ast_to_value_formatter fmt) - (typed_rvalue_to_string fmt) - av.variant_id av.field_values v.ty - let var_or_dummy_to_string (fmt : ast_formatter) (v : var_or_dummy) : string = match v with | Var (v, None) -> var_to_string (ast_to_type_formatter fmt) v @@ -411,33 +451,19 @@ let fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string = | Binop (binop, int_ty) -> binop_to_string binop ^ "<" ^ integer_type_to_string int_ty ^ ">" -let meta_to_string (fmt : ast_formatter) (meta : meta) : string = - let meta = - match meta with - | Assignment (lp, rv, rp) -> - let rp = - match rp with - | None -> "" - | Some rp -> " [@src=" ^ mplace_to_string fmt rp ^ "]" - in - "@assign(" ^ mplace_to_string fmt lp ^ " := " - ^ typed_rvalue_to_string fmt rv - ^ rp ^ ")" - in - "@meta[" ^ meta ^ "]" - (** [inside]: controls the introduction of parentheses *) let rec texpression_to_string (fmt : ast_formatter) (inside : bool) (indent : string) (indent_incr : string) (e : texpression) : string = match e.e with - | Value (v, mp) -> + | Local (var_id, mp) -> let mp = match mp with | None -> "" | Some mp -> " [@mplace=" ^ mplace_to_string fmt mp ^ "]" in - let e = typed_rvalue_to_string fmt v ^ mp in - if inside then "(" ^ e ^ ")" else e + let s = fmt.var_id_to_string var_id ^ mp in + if inside then "(" ^ s ^ ")" else s + | Const cv -> Print.Values.constant_value_to_string cv | App _ -> (* Recursively destruct the app, to have a pair (app, arguments list) *) let app, args = destruct_apps e in @@ -447,8 +473,8 @@ let rec texpression_to_string (fmt : ast_formatter) (inside : bool) let xl, e = destruct_abs_list e in let e = abs_to_string fmt indent indent_incr xl e in if inside then "(" ^ e ^ ")" else e - | Func _ -> - (* Func without arguments *) + | Qualif _ -> + (* Qualifier without arguments *) app_to_string fmt inside indent indent_incr e [] | Let (monadic, lv, re, e) -> let e = let_to_string fmt indent indent_incr monadic lv re e in @@ -466,18 +492,36 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string) (indent_incr : string) (app : texpression) (args : texpression list) : string = (* There are two possibilities: either the `app` is an instantiated, - * top-level function, or it is a "regular" expression *) + * top-level qualifier (function, ADT constructore...), or it is a "regular" + * expression *) let app, tys = match app.e with - | Func func -> - (* Function case *) - (* Convert the function identifier *) - let fun_id = fun_id_to_string fmt func.func in + | Qualif qualif -> + (* Qualifier case *) + (* Convert the qualifier identifier *) + let qualif_s = + match qualif.id with + | Func fun_id -> fun_id_to_string fmt fun_id + | AdtCons adt_cons_id -> + let variant_s = + adt_variant_to_string + (ast_to_value_formatter fmt) + adt_cons_id.adt_id adt_cons_id.variant_id + in + ConstStrings.constructor_prefix ^ variant_s + | Proj (ProjField (adt_id, field_id)) -> + let value_fmt = ast_to_value_formatter fmt in + let adt_s = adt_variant_to_string value_fmt adt_id None in + let field_s = adt_field_to_string value_fmt adt_id field_id in + (* Adopting an F*-like syntax *) + ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s + | Proj (ProjTuple field_id) -> "MkTuple?." ^ string_of_int field_id + in (* Convert the type instantiation *) let ty_fmt = ast_to_type_formatter fmt in - let tys = List.map (ty_to_string ty_fmt) func.type_params in + let tys = List.map (ty_to_string ty_fmt) qualif.type_params in (* *) - (fun_id, tys) + (qualif_s, tys) | _ -> (* "Regular" expression case *) let inside = args <> [] || (args = [] && inside) in @@ -540,6 +584,21 @@ and switch_to_string (fmt : ast_formatter) (indent : string) let branches = List.map branch_to_string branches in "match " ^ scrut ^ " with\n" ^ String.concat "\n" branches +and meta_to_string (fmt : ast_formatter) (meta : meta) : string = + let meta = + match meta with + | Assignment (lp, rv, rp) -> + let rp = + match rp with + | None -> "" + | Some rp -> " [@src=" ^ mplace_to_string fmt rp ^ "]" + in + "@assign(" ^ mplace_to_string fmt lp ^ " := " + ^ texpression_to_string fmt false "" "" rv + ^ rp ^ ")" + in + "@meta[" ^ meta ^ "]" + let fun_decl_to_string (fmt : ast_formatter) (def : fun_decl) : string = let type_fmt = ast_to_type_formatter fmt in let name = fun_name_to_string def.basename ^ fun_suffix def.back_id in diff --git a/src/Pure.ml b/src/Pure.ml index cd28b035..c72f9dd0 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -149,7 +149,9 @@ type var = { (* TODO: we might want to redefine field_proj_kind here, to prevent field accesses * on enumerations. - * Also: tuples... *) + * Also: tuples... + * Rmk: projections are actually only used as meta-data. + * *) type projection_elem = { pkind : E.field_proj_kind; field_id : FieldId.id } [@@deriving show] @@ -168,9 +170,6 @@ type mplace = { we introduce. *) -(* TODO: there shouldn't be places *) -type place = { var : VarId.id; projection : projection } [@@deriving show] - type variant_id = VariantId.id [@@deriving show] (** Ancestor for [iter_var_or_dummy] visitor *) @@ -182,8 +181,6 @@ class ['self] iter_value_base = method visit_var : 'env -> var -> unit = fun _ _ -> () - method visit_place : 'env -> place -> unit = fun _ _ -> () - method visit_mplace : 'env -> mplace -> unit = fun _ _ -> () method visit_ty : 'env -> ty -> unit = fun _ _ -> () @@ -201,8 +198,6 @@ class ['self] map_value_base = method visit_var : 'env -> var -> var = fun _ x -> x - method visit_place : 'env -> place -> place = fun _ x -> x - method visit_mplace : 'env -> mplace -> mplace = fun _ x -> x method visit_ty : 'env -> ty -> ty = fun _ x -> x @@ -220,8 +215,6 @@ class virtual ['self] reduce_value_base = method visit_var : 'env -> var -> 'a = fun _ _ -> self#zero - method visit_place : 'env -> place -> 'a = fun _ _ -> self#zero - method visit_mplace : 'env -> mplace -> 'a = fun _ _ -> self#zero method visit_ty : 'env -> ty -> 'a = fun _ _ -> self#zero @@ -240,8 +233,6 @@ class virtual ['self] mapreduce_value_base = method visit_var : 'env -> var -> var * 'a = fun _ x -> (x, self#zero) - method visit_place : 'env -> place -> place * 'a = fun _ x -> (x, self#zero) - method visit_mplace : 'env -> mplace -> mplace * 'a = fun _ x -> (x, self#zero) @@ -251,57 +242,6 @@ class virtual ['self] mapreduce_value_base = fun _ x -> (x, self#zero) end -(* TODO: merge with expressions *) -type rvalue = - | RvConcrete of constant_value - | RvPlace of place (* TODO: field projectors should be expressions *) - | RvAdt of adt_rvalue - -and adt_rvalue = { - variant_id : variant_id option; - (* TODO: variant constructors should be expressions, treated in a manner - * similar to functions *) - field_values : typed_rvalue list; -} - -and typed_rvalue = { value : rvalue; ty : ty } -[@@deriving - show, - visitors - { - name = "iter_typed_rvalue"; - variety = "iter"; - ancestors = [ "iter_value_base" ]; - nude = true (* Don't inherit [VisitorsRuntime.iter] *); - concrete = true; - polymorphic = false; - }, - visitors - { - name = "map_typed_rvalue"; - variety = "map"; - ancestors = [ "map_value_base" ]; - nude = true (* Don't inherit [VisitorsRuntime.iter] *); - concrete = true; - polymorphic = false; - }, - visitors - { - name = "reduce_typed_rvalue"; - variety = "reduce"; - ancestors = [ "reduce_value_base" ]; - nude = true (* Don't inherit [VisitorsRuntime.iter] *); - polymorphic = false; - }, - visitors - { - name = "mapreduce_typed_rvalue"; - variety = "mapreduce"; - ancestors = [ "mapreduce_value_base" ]; - nude = true (* Don't inherit [VisitorsRuntime.iter] *); - polymorphic = false; - }] - type var_or_dummy = | Var of var * mplace option (** Rk.: the mdplace is actually always a variable (i.e.: there are no projections). @@ -316,7 +256,7 @@ type var_or_dummy = { name = "iter_var_or_dummy"; variety = "iter"; - ancestors = [ "iter_typed_rvalue" ]; + ancestors = [ "iter_value_base" ]; nude = true (* Don't inherit [VisitorsRuntime.iter] *); concrete = true; polymorphic = false; @@ -325,7 +265,7 @@ type var_or_dummy = { name = "map_var_or_dummy"; variety = "map"; - ancestors = [ "map_typed_rvalue" ]; + ancestors = [ "map_value_base" ]; nude = true (* Don't inherit [VisitorsRuntime.map] *); concrete = true; polymorphic = false; @@ -334,7 +274,7 @@ type var_or_dummy = { name = "reduce_var_or_dummy"; variety = "reduce"; - ancestors = [ "reduce_typed_rvalue" ]; + ancestors = [ "reduce_value_base" ]; nude = true (* Don't inherit [VisitorsRuntime.reduce] *); polymorphic = false; }, @@ -342,7 +282,7 @@ type var_or_dummy = { name = "mapreduce_var_or_dummy"; variety = "mapreduce"; - ancestors = [ "mapreduce_typed_rvalue" ]; + ancestors = [ "mapreduce_value_base" ]; nude = true (* Don't inherit [VisitorsRuntime.reduce] *); polymorphic = false; }] @@ -419,30 +359,43 @@ type fun_id = | Binop of E.binop * integer_type [@@deriving show, ord] -(** Meta-information stored in the AST *) -type meta = Assignment of mplace * typed_rvalue * mplace option +type adt_cons_id = { adt_id : type_id; variant_id : variant_id option } +[@@deriving show] +(** An identifier for an ADT constructor *) + +type proj_kind = ProjField of type_id * FieldId.id | ProjTuple of int [@@deriving show] -type func = { func : fun_id; type_params : ty list } [@@deriving show] -(** A function. +type qualif_id = + | Func of fun_id + | AdtCons of adt_cons_id (** A function or ADT constructor identifier *) + | Proj of proj_kind (** Field projector *) +[@@deriving show] + +type qualif = { + id : qualif_id; + type_params : ty list; (* TODO: rename to type_args *) +} +[@@deriving show] +(** An instantiated qualified. Note that for now we have a clear separation between types and expressions, - which explains why we have the `type_params` field: a function is always - fully instantiated. + which explains why we have the `type_params` field: a function or ADT + constructor is always fully instantiated. *) +type var_id = VarId.id [@@deriving show] + (** Ancestor for [iter_expression] visitor *) class ['self] iter_expression_base = object (_self : 'self) inherit [_] iter_typed_lvalue - method visit_meta : 'env -> meta -> unit = fun _ _ -> () - method visit_integer_type : 'env -> integer_type -> unit = fun _ _ -> () - method visit_scalar_value : 'env -> scalar_value -> unit = fun _ _ -> () + method visit_var_id : 'env -> var_id -> unit = fun _ _ -> () - method visit_func : 'env -> func -> unit = fun _ _ -> () + method visit_qualif : 'env -> qualif -> unit = fun _ _ -> () end (** Ancestor for [map_expression] visitor *) @@ -450,15 +403,12 @@ class ['self] map_expression_base = object (_self : 'self) inherit [_] map_typed_lvalue - method visit_meta : 'env -> meta -> meta = fun _ x -> x - method visit_integer_type : 'env -> integer_type -> integer_type = fun _ x -> x - method visit_scalar_value : 'env -> scalar_value -> scalar_value = - fun _ x -> x + method visit_var_id : 'env -> var_id -> var_id = fun _ x -> x - method visit_func : 'env -> func -> func = fun _ x -> x + method visit_qualif : 'env -> qualif -> qualif = fun _ x -> x end (** Ancestor for [reduce_expression] visitor *) @@ -466,15 +416,12 @@ class virtual ['self] reduce_expression_base = object (self : 'self) inherit [_] reduce_typed_lvalue - method visit_meta : 'env -> meta -> 'a = fun _ _ -> self#zero - method visit_integer_type : 'env -> integer_type -> 'a = fun _ _ -> self#zero - method visit_scalar_value : 'env -> scalar_value -> 'a = - fun _ _ -> self#zero + method visit_var_id : 'env -> var_id -> 'a = fun _ _ -> self#zero - method visit_func : 'env -> func -> 'a = fun _ _ -> self#zero + method visit_qualif : 'env -> qualif -> 'a = fun _ _ -> self#zero end (** Ancestor for [mapreduce_expression] visitor *) @@ -482,15 +429,14 @@ class virtual ['self] mapreduce_expression_base = object (self : 'self) inherit [_] mapreduce_typed_lvalue - method visit_meta : 'env -> meta -> meta * 'a = fun _ x -> (x, self#zero) - method visit_integer_type : 'env -> integer_type -> integer_type * 'a = fun _ x -> (x, self#zero) - method visit_scalar_value : 'env -> scalar_value -> scalar_value * 'a = + method visit_var_id : 'env -> var_id -> var_id * 'a = fun _ x -> (x, self#zero) - method visit_func : 'env -> func -> func * 'a = fun _ x -> (x, self#zero) + method visit_qualif : 'env -> qualif -> qualif * 'a = + fun _ x -> (x, self#zero) end (** **Rk.:** here, [expression] is not at all equivalent to the expressions @@ -498,7 +444,12 @@ class virtual ['self] mapreduce_expression_base = more general than the LLBC statements, in a sense. *) type expression = - | Value of typed_rvalue * mplace option + | Local of var_id * mplace option + (** Local variable - TODO: we name it "Local" because "Var" is used + in [var_or_dummy]: change the name. Maybe rename `Var` and `Dummy` + in `var_or_dummy` to `PatVar` and `PatDummy`. + *) + | Const of constant_value | App of texpression * texpression (** Application of a function to an argument. @@ -509,7 +460,7 @@ type expression = are clashes of field names, some provers like F* get pretty bad...) *) | Abs of typed_lvalue * texpression (** Lambda abstraction: `fun x -> e` *) - | Func of func (** A function - TODO: change to Qualifier *) + | Qualif of qualif (** A top-level qualifier *) | Let of bool * typed_lvalue * texpression * texpression (** Let binding. @@ -550,13 +501,25 @@ type expression = ``` *) | Switch of texpression * switch_body - | Meta of meta * texpression (** Meta-information *) + | Meta of (meta[@opaque]) * texpression (** Meta-information *) and switch_body = If of texpression * texpression | Match of match_branch list and match_branch = { pat : typed_lvalue; branch : texpression } and texpression = { e : expression; ty : ty } + +and mvalue = (texpression[@opaque]) +(** Meta-value (converted to an expression) *) + +and meta = + | Assignment of mplace * mvalue * mplace option + (** Meta-information stored in the AST. + + The first mplace stores the destination. + The mvalue stores the value which is put in the destination + The second (optional) mplace stores the origin. + *) [@@deriving show, visitors diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 9ddc71ab..3c25e7b6 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -1160,7 +1160,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) in let fresh_state_var () = let id = fresh_var_id () in - { id; basename = Some "st"; ty = mk_state_ty } + { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty } in (* It is a very simple map *) let obj = diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 873931be..c01dd5c9 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -43,96 +43,7 @@ let binop_can_fail (binop : E.binop) : bool = | Div | Rem | Add | Sub | Mul -> true | Shl | Shr -> raise Errors.Unimplemented -let mk_place_from_var (v : var) : place = { var = v.id; projection = [] } - -(** Make a "simplified" tuple type from a list of types: - - if there is exactly one type, just return it - - if there is > one type: wrap them in a tuple - *) -let mk_simpl_tuple_ty (tys : ty list) : ty = - match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys) - -let unit_ty : ty = Adt (Tuple, []) - -let unit_rvalue : typed_rvalue = - let value = RvAdt { variant_id = None; field_values = [] } in - let ty = unit_ty in - { value; ty } - -let mk_typed_rvalue_from_var (v : var) : typed_rvalue = - let value = RvPlace (mk_place_from_var v) in - let ty = v.ty in - { value; ty } - -let mk_typed_lvalue_from_var (v : var) (mp : mplace option) : typed_lvalue = - let value = LvVar (Var (v, mp)) in - let ty = v.ty in - { value; ty } - -(** Make a "simplified" tuple value from a list of values: - - if there is exactly one value, just return it - - if there is > one value: wrap them in a tuple - *) -let mk_simpl_tuple_lvalue (vl : typed_lvalue list) : typed_lvalue = - match vl with - | [ v ] -> v - | _ -> - let tys = List.map (fun (v : typed_lvalue) -> v.ty) vl in - let ty = Adt (Tuple, tys) in - let value = LvAdt { variant_id = None; field_values = vl } in - { value; ty } - -(** Similar to [mk_simpl_tuple_lvalue] *) -let mk_simpl_tuple_rvalue (vl : typed_rvalue list) : typed_rvalue = - match vl with - | [ v ] -> v - | _ -> - let tys = List.map (fun (v : typed_rvalue) -> v.ty) vl in - let ty = Adt (Tuple, tys) in - let value = RvAdt { variant_id = None; field_values = vl } in - { value; ty } - -let mk_adt_lvalue (adt_ty : ty) (variant_id : VariantId.id) - (vl : typed_lvalue list) : typed_lvalue = - let value = LvAdt { variant_id = Some variant_id; field_values = vl } in - { value; ty = adt_ty } - -let ty_as_integer (t : ty) : T.integer_type = - match t with Integer int_ty -> int_ty | _ -> raise (Failure "Unreachable") - -(* TODO: move *) -let type_decl_is_enum (def : T.type_decl) : bool = - match def.kind with T.Struct _ -> false | Enum _ -> true | Opaque -> false - -let mk_state_ty : ty = Adt (Assumed State, []) - -let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ]) - -let mk_result_fail_rvalue (ty : ty) : typed_rvalue = - let ty = Adt (Assumed Result, [ ty ]) in - let value = RvAdt { variant_id = Some result_fail_id; field_values = [] } in - { value; ty } - -let mk_result_return_rvalue (v : typed_rvalue) : typed_rvalue = - let ty = Adt (Assumed Result, [ v.ty ]) in - let value = - RvAdt { variant_id = Some result_return_id; field_values = [ v ] } - in - { value; ty } - -let mk_result_fail_lvalue (ty : ty) : typed_lvalue = - let ty = Adt (Assumed Result, [ ty ]) in - let value = LvAdt { variant_id = Some result_fail_id; field_values = [] } in - { value; ty } - -let mk_result_return_lvalue (v : typed_lvalue) : typed_lvalue = - let ty = Adt (Assumed Result, [ v.ty ]) in - let value = - LvAdt { variant_id = Some result_return_id; field_values = [ v ] } - in - { value; ty } - -let mk_arrow_ty (arg_ty : ty) (ret_ty : ty) : ty = Arrow (arg_ty, ret_ty) +(*let mk_arrow_ty (arg_ty : ty) (ret_ty : ty) : ty = Arrow (arg_ty, ret_ty)*) let dest_arrow_ty (ty : ty) : ty * ty = match ty with @@ -150,10 +61,10 @@ let mk_typed_lvalue_from_constant_value (cv : constant_value) : typed_lvalue = let ty = compute_constant_value_ty cv in { value = LvConcrete cv; ty } -let mk_value_expression (v : typed_rvalue) (mp : mplace option) : texpression = - let e = Value (v, mp) in - let ty = v.ty in - { e; ty } +(*let mk_value_expression (v : typed_rvalue) (mp : mplace option) : texpression = + let e = Value (v, mp) in + let ty = v.ty in + { e; ty }*) let mk_let (monadic : bool) (lv : typed_lvalue) (re : texpression) (next_e : texpression) : texpression = @@ -244,9 +155,12 @@ let functions_not_mutually_recursive (funs : fun_decl list) : bool = object inherit [_] iter_expression as super - method! visit_func env func = - if FunIdSet.mem func.func !ids then raise Utils.Found - else super#visit_func env func + method! visit_qualif env qualif = + match qualif.id with + | Func fun_id -> + if FunIdSet.mem fun_id !ids then raise Utils.Found + else super#visit_qualif env qualif + | _ -> super#visit_qualif env qualif end in @@ -266,126 +180,24 @@ let functions_not_mutually_recursive (funs : fun_decl list) : bool = *) let rec let_group_requires_parentheses (e : texpression) : bool = match e.e with - | Value _ | App _ | Func _ | Abs _ -> false + | Local _ | Const _ | App _ | Abs _ | Qualif _ -> false | Let (monadic, _, _, next_e) -> if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false | Meta (_, next_e) -> let_group_requires_parentheses next_e -(** Module to perform type checking - we use this for sanity checks only *) -module TypeCheck = struct - type tc_ctx = { type_decls : type_decl TypeDeclId.Map.t } - - let check_constant_value (ty : ty) (v : constant_value) : unit = - match (ty, v) with - | Integer int_ty, V.Scalar sv -> assert (int_ty = sv.V.int_ty) - | Bool, Bool _ | Char, Char _ | Str, String _ -> () - | _ -> raise (Failure "Inconsistent type") - - let check_adt_g_value (ctx : tc_ctx) (check_value : ty -> 'v -> unit) - (variant_id : VariantId.id option) (field_values : 'v list) (ty : ty) : - unit = - (* Retrieve the field types *) - let field_tys = - match ty with - | Adt (Tuple, tys) -> - (* Tuple *) - tys - | Adt (AdtId def_id, tys) -> - (* "Regular" ADT *) - let def = TypeDeclId.Map.find def_id ctx.type_decls in - type_decl_get_instantiated_fields_types def variant_id tys - | Adt (Assumed aty, tys) -> ( - (* Assumed type *) - match aty with - | State -> - (* `State` is opaque *) - raise (Failure "Unreachable: `State` values are opaque") - | Result -> - let ty = Collections.List.to_cons_nil tys in - let variant_id = Option.get variant_id in - if variant_id = result_return_id then [ ty ] - else if variant_id = result_fail_id then [] - else - raise - (Failure "Unreachable: improper variant id for result type") - | Option -> - let ty = Collections.List.to_cons_nil tys in - let variant_id = Option.get variant_id in - if variant_id = option_some_id then [ ty ] - else if variant_id = option_none_id then [] - else - raise - (Failure "Unreachable: improper variant id for result type") - | Vec -> - assert (variant_id = None); - let ty = Collections.List.to_cons_nil tys in - List.map (fun _ -> ty) field_values) - | _ -> raise (Failure "Inconsistently typed value") - in - (* Check that the field values have the expected types *) - List.iter - (fun (ty, v) -> check_value ty v) - (List.combine field_tys field_values) - - let rec check_typed_lvalue (ctx : tc_ctx) (v : typed_lvalue) : unit = - log#ldebug (lazy ("check_typed_lvalue: " ^ show_typed_lvalue v)); - match v.value with - | LvConcrete cv -> check_constant_value v.ty cv - | LvVar _ -> () - | LvAdt av -> - check_adt_g_value ctx - (fun ty (v : typed_lvalue) -> - if ty <> v.ty then ( - log#serror - ("check_typed_lvalue: not the same types:" ^ "\n- ty: " - ^ show_ty ty ^ "\n- v.ty: " ^ show_ty v.ty); - raise (Failure "Inconsistent types")); - check_typed_lvalue ctx v) - av.variant_id av.field_values v.ty - - let rec check_typed_rvalue (ctx : tc_ctx) (v : typed_rvalue) : unit = - log#ldebug (lazy ("check_typed_rvalue: " ^ show_typed_rvalue v)); - match v.value with - | RvConcrete cv -> check_constant_value v.ty cv - | RvPlace _ -> - (* TODO: *) - () - | RvAdt av -> - check_adt_g_value ctx - (fun ty (v : typed_rvalue) -> - if ty <> v.ty then ( - log#serror - ("check_typed_rvalue: not the same types:" ^ "\n- ty: " - ^ show_ty ty ^ "\n- v.ty: " ^ show_ty v.ty); - raise (Failure "Inconsistent types")); - check_typed_rvalue ctx v) - av.variant_id av.field_values v.ty -end - -let is_value (e : texpression) : bool = - match e.e with Value _ -> true | _ -> false - let is_var (e : texpression) : bool = - match e.e with - | Value (v, _) -> ( - match v.value with - | RvPlace { var = _; projection = [] } -> true - | _ -> false) - | _ -> false + match e.e with Local _ -> true | _ -> false let as_var (e : texpression) : VarId.id = - match e.e with - | Value (v, _) -> ( - match v.value with - | RvPlace { var; projection = [] } -> var - | _ -> raise (Failure "Unreachable")) - | _ -> raise (Failure "Unreachable") + match e.e with Local (v, _) -> v | _ -> raise (Failure "Unreachable") (** Remove the external occurrences of [Meta] *) let rec unmeta (e : texpression) : texpression = match e.e with Meta (_, e) -> unmeta e | _ -> e +let mk_arrow (ty0 : ty) (ty1 : ty) : ty = Arrow (ty0, ty1) + (** Construct a type as a list of arrows: ty1 -> ... tyn *) let mk_arrows (inputs : ty list) (output : ty) = let rec aux (tys : ty list) : ty = @@ -419,16 +231,16 @@ let mk_app (app : texpression) (arg : texpression) : texpression = let mk_apps (app : texpression) (args : texpression list) : texpression = List.fold_left (fun app arg -> mk_app app arg) app args -(** Destruct an expression into a function identifier and a list of arguments, +(** Destruct an expression into a qualif identifier and a list of arguments, * if possible *) -let opt_destruct_function_call (e : texpression) : - (func * texpression list) option = +let opt_destruct_qualif_app (e : texpression) : + (qualif * texpression list) option = let app, args = destruct_apps e in - match app.e with Func func -> Some (func, args) | _ -> None + match app.e with Qualif qualif -> Some (qualif, args) | _ -> None (** Destruct an expression into a function identifier and a list of arguments *) -let destruct_function_call (e : texpression) : func * texpression list = - Option.get (opt_destruct_function_call e) +let destruct_qualif_app (e : texpression) : qualif * texpression list = + Option.get (opt_destruct_qualif_app e) let opt_destruct_result (ty : ty) : ty option = match ty with @@ -440,26 +252,6 @@ let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty) let opt_destruct_tuple (ty : ty) : ty list option = match ty with Adt (Tuple, tys) -> Some tys | _ -> None -let opt_destruct_state_monad_result (ty : ty) : ty option = - (* Checking: - * ty == state -> result (state & _) ? *) - match ty with - | Arrow (ty0, ty1) -> - (* ty == ty0 -> ty1 - * Checking: ty0 == state ? - * *) - if ty0 = mk_state_ty then - (* Checking: ty1 == result (state & _) *) - match opt_destruct_result ty1 with - | None -> None - | Some ty2 -> ( - (* Checking: ty2 == state & _ *) - match opt_destruct_tuple ty2 with - | Some [ ty3; ty4 ] -> if ty3 = mk_state_ty then Some ty4 else None - | _ -> None) - else None - | _ -> None - let mk_abs (x : typed_lvalue) (e : texpression) : texpression = let ty = Arrow (x.ty, e.ty) in let e = Abs (x, e) in @@ -475,9 +267,14 @@ let rec destruct_abs_list (e : texpression) : typed_lvalue list * texpression = let destruct_arrow (ty : ty) : ty * ty = match ty with | Arrow (ty0, ty1) -> (ty0, ty1) - | _ -> raise (Failure "Unreachable") + | _ -> raise (Failure "Not an arrow type") -let mk_arrow (ty0 : ty) (ty1 : ty) : ty = Arrow (ty0, ty1) +let rec destruct_arrows (ty : ty) : ty list * ty = + match ty with + | Arrow (ty0, ty1) -> + let tys, out_ty = destruct_arrows ty1 in + (ty0 :: tys, out_ty) + | _ -> ([], ty) let get_switch_body_ty (sb : switch_body) : ty = match sb with @@ -512,3 +309,306 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression = (* Put together *) let e = Switch (scrut, sb) in { e; ty } + +(** Make a "simplified" tuple type from a list of types: + - if there is exactly one type, just return it + - if there is > one type: wrap them in a tuple + *) +let mk_simpl_tuple_ty (tys : ty list) : ty = + match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys) + +(** TODO: rename to "mk_..." *) +let unit_ty : ty = Adt (Tuple, []) + +(** TODO: rename to "mk_..." *) +let unit_rvalue : texpression = + let id = AdtCons { adt_id = Tuple; variant_id = None } in + let qualif = { id; type_params = [] } in + let e = Qualif qualif in + let ty = unit_ty in + { e; ty } + +let mk_texpression_from_var (v : var) (mp : mplace option) : texpression = + let e = Local (v.id, mp) in + let ty = v.ty in + { e; ty } + +let mk_typed_lvalue_from_var (v : var) (mp : mplace option) : typed_lvalue = + let value = LvVar (Var (v, mp)) in + let ty = v.ty in + { value; ty } + +(** Make a "simplified" tuple value from a list of values: + - if there is exactly one value, just return it + - if there is > one value: wrap them in a tuple + *) +let mk_simpl_tuple_lvalue (vl : typed_lvalue list) : typed_lvalue = + match vl with + | [ v ] -> v + | _ -> + let tys = List.map (fun (v : typed_lvalue) -> v.ty) vl in + let ty = Adt (Tuple, tys) in + let value = LvAdt { variant_id = None; field_values = vl } in + { value; ty } + +(** Similar to [mk_simpl_tuple_lvalue] *) +let mk_simpl_tuple_texpression (vl : texpression list) : texpression = + match vl with + | [ v ] -> v + | _ -> + (* Compute the types of the fields, and the type of the tuple constructor *) + let tys = List.map (fun (v : texpression) -> v.ty) vl in + let ty = Adt (Tuple, tys) in + let ty = mk_arrows tys ty in + (* Construct the tuple constructor qualifier *) + let id = AdtCons { adt_id = Tuple; variant_id = None } in + let qualif = { id; type_params = tys } in + (* Put everything together *) + let cons = { e = Qualif qualif; ty } in + mk_apps cons vl + +let mk_adt_lvalue (adt_ty : ty) (variant_id : VariantId.id) + (vl : typed_lvalue list) : typed_lvalue = + let value = LvAdt { variant_id = Some variant_id; field_values = vl } in + { value; ty = adt_ty } + +let ty_as_integer (t : ty) : T.integer_type = + match t with Integer int_ty -> int_ty | _ -> raise (Failure "Unreachable") + +(* TODO: move *) +let type_decl_is_enum (def : T.type_decl) : bool = + match def.kind with T.Struct _ -> false | Enum _ -> true | Opaque -> false + +let mk_state_ty : ty = Adt (Assumed State, []) + +let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ]) + +let mk_result_fail_rvalue (ty : ty) : texpression = + let type_args = [ ty ] in + let ty = Adt (Assumed Result, type_args) in + let id = + AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id } + in + let qualif = { id; type_params = type_args } in + let cons_e = Qualif qualif in + let cons_ty = ty in + let cons = { e = cons_e; ty = cons_ty } in + cons + +let mk_result_return_rvalue (v : texpression) : texpression = + let type_args = [ v.ty ] in + let ty = Adt (Assumed Result, type_args) in + let id = + AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id } + in + let qualif = { id; type_params = type_args } in + let cons_e = Qualif qualif in + let cons_ty = mk_arrow v.ty ty in + let cons = { e = cons_e; ty = cons_ty } in + mk_app cons v + +let mk_result_fail_lvalue (ty : ty) : typed_lvalue = + let ty = Adt (Assumed Result, [ ty ]) in + let value = LvAdt { variant_id = Some result_fail_id; field_values = [] } in + { value; ty } + +let mk_result_return_lvalue (v : typed_lvalue) : typed_lvalue = + let ty = Adt (Assumed Result, [ v.ty ]) in + let value = + LvAdt { variant_id = Some result_return_id; field_values = [ v ] } + in + { value; ty } + +let opt_destruct_state_monad_result (ty : ty) : ty option = + (* Checking: + * ty == state -> result (state & _) ? *) + match ty with + | Arrow (ty0, ty1) -> + (* ty == ty0 -> ty1 + * Checking: ty0 == state ? + * *) + if ty0 = mk_state_ty then + (* Checking: ty1 == result (state & _) *) + match opt_destruct_result ty1 with + | None -> None + | Some ty2 -> ( + (* Checking: ty2 == state & _ *) + match opt_destruct_tuple ty2 with + | Some [ ty3; ty4 ] -> if ty3 = mk_state_ty then Some ty4 else None + | _ -> None) + else None + | _ -> None + +(** Utility function, used for type checking - TODO: move *) +let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) + (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) : + ty list = + match type_id with + | Tuple -> + (* Tuple *) + assert (variant_id = None); + tys + | AdtId def_id -> + (* "Regular" ADT *) + let def = TypeDeclId.Map.find def_id type_decls in + type_decl_get_instantiated_fields_types def variant_id tys + | Assumed aty -> ( + (* Assumed type *) + match aty with + | State -> + (* `State` is opaque *) + raise (Failure "Unreachable: `State` values are opaque") + | Result -> + let ty = Collections.List.to_cons_nil tys in + let variant_id = Option.get variant_id in + if variant_id = result_return_id then [ ty ] + else if variant_id = result_fail_id then [] + else + raise (Failure "Unreachable: improper variant id for result type") + | Option -> + let ty = Collections.List.to_cons_nil tys in + let variant_id = Option.get variant_id in + if variant_id = option_some_id then [ ty ] + else if variant_id = option_none_id then [] + else + raise (Failure "Unreachable: improper variant id for result type") + | Vec -> raise (Failure "Unreachable: `Vector` values are opaque")) + +(** Module to perform type checking - we use this for sanity checks only + + TODO: move to a special file (so that we can also use PrintPure for + debugging) + *) +module TypeCheck = struct + type tc_ctx = { + type_decls : type_decl TypeDeclId.Map.t; (** The type declarations *) + env : ty VarId.Map.t; (** Environment from variables to types *) + } + + let check_constant_value (v : constant_value) (ty : ty) : unit = + match (ty, v) with + | Integer int_ty, V.Scalar sv -> assert (int_ty = sv.V.int_ty) + | Bool, Bool _ | Char, Char _ | Str, String _ -> () + | _ -> raise (Failure "Inconsistent type") + + let rec check_typed_lvalue (ctx : tc_ctx) (v : typed_lvalue) : tc_ctx = + log#ldebug (lazy ("check_typed_lvalue: " ^ show_typed_lvalue v)); + match v.value with + | LvConcrete cv -> + check_constant_value cv v.ty; + ctx + | LvVar Dummy -> ctx + | LvVar (Var (var, _)) -> + assert (var.ty = v.ty); + let env = VarId.Map.add var.id var.ty ctx.env in + { ctx with env } + | LvAdt av -> + (* Compute the field types *) + let type_id, tys = + match v.ty with + | Adt (type_id, tys) -> (type_id, tys) + | _ -> raise (Failure "Inconsistently typed value") + in + let field_tys = + get_adt_field_types ctx.type_decls type_id av.variant_id tys + in + let check_value (ctx : tc_ctx) (ty : ty) (v : typed_lvalue) : tc_ctx = + if ty <> v.ty then ( + log#serror + ("check_typed_lvalue: not the same types:" ^ "\n- ty: " + ^ show_ty ty ^ "\n- v.ty: " ^ show_ty v.ty); + raise (Failure "Inconsistent types")); + check_typed_lvalue ctx v + in + (* Check the field types - TODO: we might also want to check that the + * type of the applied constructor is correct *) + List.fold_left + (fun ctx (ty, v) -> check_value ctx ty v) + ctx + (List.combine field_tys av.field_values) + + let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit = + match e.e with + | Local (var_id, _) -> ( + (* Lookup the variable - note that the variable may not be there, + * if we type-check a subexpression (i.e.: if the variable is introduced + * "outside" of the expression) - TODO: this won't happen once + * we use a locally nameless representation *) + match VarId.Map.find_opt var_id ctx.env with + | None -> () + | Some ty -> assert (ty = e.ty)) + | Const cv -> check_constant_value cv e.ty + | App (app, arg) -> + let input_ty, output_ty = destruct_arrow app.ty in + assert (input_ty = arg.ty); + assert (output_ty = e.ty); + check_texpression ctx app; + check_texpression ctx arg + | Abs (pat, body) -> + let pat_ty, body_ty = destruct_arrow e.ty in + assert (pat.ty = pat_ty); + assert (body.ty = body_ty); + (* Check the pattern and register the introduced variables at the same time *) + let ctx = check_typed_lvalue ctx pat in + check_texpression ctx body + | Qualif qualif -> ( + match qualif.id with + | Func _ -> () (* TODO *) + | Proj (ProjField (type_id, field_id)) -> + (* Note we can only project fields of structurs (not enumerations) *) + let variant_id = None in + let expected_field_tys = + get_adt_field_types ctx.type_decls type_id variant_id + qualif.type_params + in + let expected_field_ty = FieldId.nth expected_field_tys field_id in + let _adt_ty, field_ty = destruct_arrow e.ty in + (* TODO: check the adt_ty *) + assert (expected_field_ty = field_ty) + | Proj (ProjTuple field_id) -> ( + let tuple_ty, field_ty = destruct_arrow e.ty in + match tuple_ty with + | Adt (Tuple, tys) -> + let expected_field_ty = List.nth tys field_id in + assert (field_ty = expected_field_ty) + | _ -> raise (Failure "Inconsistently typed projector")) + | AdtCons id -> + (* TODO: we might also want to check the out type *) + let expected_field_tys = + get_adt_field_types ctx.type_decls id.adt_id id.variant_id + qualif.type_params + in + let field_tys, _ = destruct_arrows e.ty in + assert (expected_field_tys = field_tys)) + | Let (monadic, pat, re, e_next) -> + let expected_pat_ty = + if monadic then destruct_result re.ty else re.ty + in + assert (pat.ty = expected_pat_ty); + assert (e.ty = e_next.ty); + (* Check the right-expression *) + check_texpression ctx re; + (* Check the pattern and register the introduced variables at the same time *) + let ctx = check_typed_lvalue ctx pat in + (* Check the next expression *) + check_texpression ctx e_next + | Switch (scrut, switch_body) -> ( + check_texpression ctx scrut; + match switch_body with + | If (e_then, e_else) -> + assert (scrut.ty = Bool); + assert (e_then.ty = e.ty); + assert (e_else.ty = e.ty); + check_texpression ctx e_then; + check_texpression ctx e_else + | Match branches -> + let check_branch (br : match_branch) : unit = + assert (br.pat.ty = scrut.ty); + let ctx = check_typed_lvalue ctx br.pat in + check_texpression ctx br.branch + in + List.iter check_branch branches) + | Meta (_, e_next) -> + assert (e_next.ty = e.ty); + check_texpression ctx e_next +end diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 64f8f481..6606ca25 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -999,7 +999,9 @@ and translate_panic (config : config) (ctx : bs_ctx) : texpression = let ty = v.ty in let e = { e; ty } in (* Add the lambda *) - let _, state_var = fresh_var (Some "st") mk_state_ty ctx in + let _, state_var = + fresh_var (Some ConstStrings.state_basename) mk_state_ty ctx + in let state_lvalue = mk_typed_lvalue_from_var state_var None in mk_abs state_lvalue e else @@ -1026,7 +1028,9 @@ and translate_return (config : config) (opt_v : V.typed_value option) * *) (* TODO: we should use a `return` function, it would be cleaner *) if config.use_state_monad then - let _, state_var = fresh_var (Some "st") mk_state_ty ctx in + let _, state_var = + fresh_var (Some ConstStrings.state_basename) mk_state_ty ctx + in let state_rvalue = mk_typed_rvalue_from_var state_var in let v = mk_result_return_rvalue (mk_simpl_tuple_rvalue [ state_rvalue; v ]) |