diff options
author | Son Ho | 2023-10-24 17:34:17 +0200 |
---|---|---|
committer | Son Ho | 2023-10-24 17:34:17 +0200 |
commit | c3c7ca132b0dc0c4ea9205876932decda63baca1 (patch) | |
tree | 9b4842202b9f3cb06ae43e6619154e36a5ba01c3 /compiler | |
parent | c27c3052ec3f9a093b06a41f56b3a361cb65e950 (diff) | |
parent | f11d5186b467df318f7c09eedf8b5629c165b453 (diff) |
Merge branch 'son_traits_arrow' into protz_numeric
Diffstat (limited to '')
-rw-r--r-- | compiler/AssociatedTypes.ml | 12 | ||||
-rw-r--r-- | compiler/Assumed.ml | 2 | ||||
-rw-r--r-- | compiler/Extract.ml | 4 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 2 | ||||
-rw-r--r-- | compiler/FunsAnalysis.ml | 2 | ||||
-rw-r--r-- | compiler/InterpreterExpansion.ml | 2 | ||||
-rw-r--r-- | compiler/InterpreterExpressions.ml | 9 | ||||
-rw-r--r-- | compiler/InterpreterStatements.ml | 78 | ||||
-rw-r--r-- | compiler/InterpreterUtils.ml | 5 | ||||
-rw-r--r-- | compiler/Print.ml | 23 | ||||
-rw-r--r-- | compiler/PrintPure.ml | 22 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 32 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 50 | ||||
-rw-r--r-- | compiler/SynthesizeSymbolic.ml | 2 | ||||
-rw-r--r-- | compiler/Translate.ml | 10 | ||||
-rw-r--r-- | compiler/TypesAnalysis.ml | 8 |
16 files changed, 163 insertions, 100 deletions
diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml index 992dade9..022aad2f 100644 --- a/compiler/AssociatedTypes.ml +++ b/compiler/AssociatedTypes.ml @@ -157,7 +157,8 @@ let ctx_add_norm_trait_types_from_preds (ctx : C.eval_ctx) let rec trait_instance_id_is_local_clause (id : 'r T.trait_instance_id) : bool = match id with | T.Self | Clause _ -> true - | TraitImpl _ | BuiltinOrAuto _ | TraitRef _ | UnknownTrait _ -> false + | TraitImpl _ | BuiltinOrAuto _ | TraitRef _ | UnknownTrait _ | FnPointer _ -> + false | ParentClause (id, _, _) | ItemClause (id, _, _, _) -> trait_instance_id_is_local_clause id @@ -187,6 +188,10 @@ let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty = | Ref (r, ty, rkind) -> let ty = ctx_normalize_ty ctx ty in T.Ref (r, ty, rkind) + | Arrow (inputs, output) -> + let inputs = List.map (ctx_normalize_ty ctx) inputs in + let output = ctx_normalize_ty ctx output in + Arrow (inputs, output) | TraitType (trait_ref, generics, type_name) -> ( log#ldebug (lazy @@ -401,6 +406,11 @@ and ctx_normalize_trait_instance_id : assert (trait_instance_id_is_local_clause trait_ref.trait_id); assert (trait_ref.generics = TypesUtils.mk_empty_generic_args); (trait_ref.trait_id, None) + | FnPointer ty -> + let ty = ctx_normalize_ty ctx ty in + (* TODO: we might want to return the ref to the function pointer, + in order to later normalize a call to this function pointer *) + (FnPointer ty, None) | UnknownTrait _ -> (* This is actually an error case *) (id, None) diff --git a/compiler/Assumed.ml b/compiler/Assumed.ml index 109175af..b1ec0660 100644 --- a/compiler/Assumed.ml +++ b/compiler/Assumed.ml @@ -347,7 +347,7 @@ let assumed_infos : assumed_info list = let vec_pre = [ "alloc"; "vec"; "Vec" ] in let index_pre = [ "core"; "ops"; "index" ] in [ - (A.Replace, Sig.mem_replace_sig, false, to_name [ "core"; "mem"; "replace" ]); + (Replace, Sig.mem_replace_sig, false, to_name [ "core"; "mem"; "replace" ]); (BoxNew, Sig.box_new_sig, false, to_name [ "alloc"; "boxed"; "Box"; "new" ]); ( BoxFree, Sig.box_free_sig, diff --git a/compiler/Extract.ml b/compiler/Extract.ml index b842aea1..e24cae16 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -2699,7 +2699,7 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter) (* Provided method: we see it as a regular function call, and use the function name *) let fun_id = - FromLlbc (FunId (A.Regular method_id.id), lp_id, rg_id) + FromLlbc (FunId (Regular method_id.id), lp_id, rg_id) in let fun_name = ctx_get_function with_opaque_pre fun_id ctx in F.pp_print_string fmt fun_name; @@ -3522,7 +3522,7 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : unit = let { keep_fwd; num_backs } = PureUtils.RegularFunIdMap.find - (Pure.FunId (A.Regular def.def_id), def.loop_id, def.back_id) + (Pure.FunId (Regular def.def_id), def.loop_id, def.back_id) ctx.fun_name_info in let comment_pre = "[" ^ Print.fun_name_to_string def.basename ^ "]: " in diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 1586e6ed..a921515b 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -1482,7 +1482,7 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map = let assumed_functions = List.map (fun (fid, rg, name) -> - (FromLlbc (Pure.FunId (A.Assumed fid), None, rg), name)) + (FromLlbc (Pure.FunId (Assumed fid), None, rg), name)) init.assumed_llbc_functions @ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions in diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml index f8aa06dc..9f82b5c9 100644 --- a/compiler/FunsAnalysis.ml +++ b/compiler/FunsAnalysis.ml @@ -94,7 +94,7 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) can_fail := EU.binop_can_fail bop || !can_fail method! visit_Call env call = - (match call.func with + (match call.func.func with | FunId (Regular id) -> if FunDeclId.Set.mem id fun_ids then ( can_diverge := true; diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml index ea692386..c1041fa3 100644 --- a/compiler/InterpreterExpansion.ml +++ b/compiler/InterpreterExpansion.ml @@ -709,7 +709,7 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun = raise (Failure "Attempted to greedily expand an ADT which can't be expanded ") - | T.TypeVar _ | T.Literal _ | Never | T.TraitType _ -> + | T.TypeVar _ | T.Literal _ | Never | T.TraitType _ | T.Arrow _ -> raise (Failure "Unreachable") in (* Compose and continue *) diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml index 29826233..a42c552a 100644 --- a/compiler/InterpreterExpressions.ml +++ b/compiler/InterpreterExpressions.ml @@ -271,7 +271,7 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand) match cv.value with | E.CLiteral lit -> cf (literal_to_typed_value (TypesUtils.ty_as_literal cv.ty) lit) ctx - | E.TraitConst (trait_ref, generics, const_name) -> ( + | E.CTraitConst (trait_ref, generics, const_name) -> ( assert (generics = TypesUtils.mk_empty_generic_args); match trait_ref.trait_id with | T.TraitImpl _ -> @@ -329,7 +329,8 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand) None, value_as_symbolic v.value, SymbolicAst.ConstGenericValue vid, - e )))) + e ))) + | E.CFnPtr _ -> raise (Failure "TODO")) | E.Copy p -> (* Access the value *) let access = Read in @@ -426,7 +427,7 @@ let eval_unary_op_concrete (config : C.config) (unop : E.unop) (op : E.operand) match mk_scalar sv.int_ty i with | Error _ -> cf (Error EPanic) | Ok sv -> cf (Ok { v with V.value = V.Literal (PV.Scalar sv) })) - | E.Cast (src_ty, tgt_ty), V.Literal (PV.Scalar sv) -> ( + | E.Cast (E.CastInteger (src_ty, tgt_ty)), V.Literal (PV.Scalar sv) -> ( assert (src_ty = sv.int_ty); let i = sv.PV.value in match mk_scalar tgt_ty i with @@ -452,7 +453,7 @@ let eval_unary_op_symbolic (config : C.config) (unop : E.unop) (op : E.operand) match (unop, v.V.ty) with | E.Not, (T.Literal Bool as lty) -> lty | E.Neg, (T.Literal (Integer _) as lty) -> lty - | E.Cast (_, tgt_ty), _ -> T.Literal (Integer tgt_ty) + | E.Cast (E.CastInteger (_, tgt_ty)), _ -> T.Literal (Integer tgt_ty) | _ -> raise (Failure "Invalid input for unop") in let res_sv = diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index 36bc3492..9f35c6f2 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -306,7 +306,7 @@ let get_assumed_function_return_type (ctx : C.eval_ctx) (fid : A.assumed_fun_id) assert (generics.trait_refs = []); (* [Box::free] has a special treatment *) match fid with - | A.BoxFree -> + | BoxFree -> assert (generics.regions = []); assert (List.length generics.types = 1); assert (generics.const_generics = []); @@ -583,7 +583,7 @@ let eval_vec_function_concrete (_config : C.config) (_fid : A.assumed_fun_id) (** Evaluate a non-local function call in concrete mode *) let eval_assumed_function_call_concrete (config : C.config) (fid : A.assumed_fun_id) (call : A.call) : cm_fun = - let generics = call.generics in + let generics = call.func.generics in let args = call.args in let dest = call.dest in (* Sanity check: we don't fully handle the const generic vars environment @@ -595,7 +595,7 @@ let eval_assumed_function_call_concrete (config : C.config) See {!eval_box_free} *) match fid with - | A.BoxFree -> + | BoxFree -> (* Degenerate case: box_free *) eval_box_free config generics args dest | _ -> @@ -636,7 +636,7 @@ let eval_assumed_function_call_concrete (config : C.config) * access to a body. *) let cf_eval_body : cm_fun = match fid with - | A.Replace -> eval_replace_concrete config generics + | Replace -> eval_replace_concrete config generics | BoxNew -> eval_box_new_concrete config generics | BoxDeref -> eval_box_deref_concrete config generics | BoxDerefMut -> eval_box_deref_mut_concrete config generics @@ -854,15 +854,14 @@ and eval_global (config : C.config) (dest : E.place) (gid : LA.GlobalDeclId.id) match config.mode with | ConcreteMode -> (* Treat the evaluation of the global as a call to the global body (without arguments) *) - let call = + let func = { - A.func = A.FunId (A.Regular global.body_id); + E.func = FunId (Regular global.body_id); generics = TypesUtils.mk_empty_generic_args; trait_and_method_generic_args = None; - args = []; - dest; } in + let call = { A.func; args = []; dest } in (eval_transparent_function_call_concrete config global.body_id call) cf ctx | SymbolicMode -> @@ -1019,29 +1018,28 @@ and eval_function_call (config : C.config) (call : A.call) : st_cm_fun = and eval_function_call_concrete (config : C.config) (call : A.call) : st_cm_fun = fun cf ctx -> - match call.func with - | A.FunId (A.Regular fid) -> + match call.func.func with + | FunId (Regular fid) -> eval_transparent_function_call_concrete config fid call cf ctx - | A.FunId (A.Assumed fid) -> + | FunId (Assumed fid) -> (* Continue - note that we do as if the function call has been successful, * by giving {!Unit} to the continuation, because we place us in the case * where we haven't panicked. Of course, the translation needs to take the * panic case into account... *) eval_assumed_function_call_concrete config fid call (cf Unit) ctx - | A.TraitMethod _ -> raise (Failure "Unimplemented") + | TraitMethod _ -> raise (Failure "Unimplemented") and eval_function_call_symbolic (config : C.config) (call : A.call) : st_cm_fun = - match call.func with - | A.FunId (A.Regular _) | A.TraitMethod _ -> + match call.func.func with + | FunId (Regular _) | TraitMethod _ -> eval_transparent_function_call_symbolic config call - | A.FunId (A.Assumed fid) -> - eval_assumed_function_call_symbolic config fid call + | FunId (Assumed fid) -> eval_assumed_function_call_symbolic config fid call (** Evaluate a local (i.e., non-assumed) function call in concrete mode *) and eval_transparent_function_call_concrete (config : C.config) (fid : A.FunDeclId.id) (call : A.call) : st_cm_fun = - let generics = call.A.generics in + let generics = call.func.generics in let args = call.A.args in let dest = call.A.dest in (* Sanity check: we don't fully handle the const generic vars environment @@ -1195,29 +1193,29 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call) to a trait clause but directly to the method provided in the trait declaration. *) let func, generics, def, inst_sg = - match call.func with - | A.FunId (A.Regular fid) -> + match call.func.func with + | FunId (Regular fid) -> let def = C.ctx_lookup_fun_decl ctx fid in let tr_self = T.UnknownTrait __FUNCTION__ in let inst_sg = - instantiate_fun_sig ctx call.generics tr_self def.A.signature + instantiate_fun_sig ctx call.func.generics tr_self def.A.signature in - (call.func, call.generics, def, inst_sg) - | A.FunId (A.Assumed _) -> + (call.func.func, call.func.generics, def, inst_sg) + | FunId (Assumed _) -> (* Unreachable: must be a transparent function *) raise (Failure "Unreachable") - | A.TraitMethod (trait_ref, method_name, _) -> ( + | TraitMethod (trait_ref, method_name, _) -> ( log#ldebug (lazy ("trait method call:\n- call: " ^ call_to_string ctx call ^ "\n- method name: " ^ method_name ^ "\n- call.generics:\n" - ^ egeneric_args_to_string ctx call.generics + ^ egeneric_args_to_string ctx call.func.generics ^ "\n- trait and method generics:\n" ^ egeneric_args_to_string ctx - (Option.get call.trait_and_method_generic_args))); + (Option.get call.func.trait_and_method_generic_args))); (* When instantiating, we need to group the generics for the trait ref and the method *) - let generics = Option.get call.trait_and_method_generic_args in + let generics = Option.get call.func.trait_and_method_generic_args in (* Lookup the trait method signature - there are several possibilities depending on whethere we call a top-level trait method impl or the method from a local clause *) @@ -1251,7 +1249,7 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call) which implements the method. In order to do this properly, we also need to update the generics. *) - let func = A.FunId (A.Regular id) in + let func = E.FunId (Regular id) in (func, generics, method_def, inst_sg) | None -> (* If not found, lookup the methods provided by the trait *declaration* @@ -1287,7 +1285,7 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call) *) let all_generics = TypesUtils.merge_generic_args - trait_ref.trait_decl_ref.decl_generics call.generics + trait_ref.trait_decl_ref.decl_generics call.func.generics in log#ldebug (lazy @@ -1304,7 +1302,7 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call) instantiate_fun_sig ctx all_generics tr_self method_def.A.signature in - (call.func, call.generics, method_def, inst_sg)) + (call.func.func, call.func.generics, method_def, inst_sg)) | _ -> (* We are using a local clause - we lookup the trait decl *) let trait_decl = @@ -1333,7 +1331,7 @@ and eval_transparent_function_call_symbolic (config : C.config) (call : A.call) let inst_sg = instantiate_fun_sig ctx generics tr_self method_def.A.signature in - (call.func, call.generics, method_def, inst_sg)) + (call.func.func, call.func.generics, method_def, inst_sg)) in (* Sanity check *) assert (List.length call.args = List.length def.A.signature.inputs); @@ -1357,6 +1355,18 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config) (generics : T.egeneric_args) (args : E.operand list) (dest : E.place) : st_cm_fun = fun cf ctx -> + log#ldebug + (lazy + ("eval_function_call_symbolic_from_inst_sig:\n- fid: " + ^ fun_id_or_trait_method_ref_to_string ctx fid + ^ "\n- inst_sg:\n" + ^ inst_fun_sig_to_string ctx inst_sg + ^ "\n- call.generics:\n" + ^ egeneric_args_to_string ctx generics + ^ "\n- args:\n" + ^ String.concat ", " (List.map (operand_to_string ctx) args) + ^ "\n- dest:\n" ^ place_to_string ctx dest)); + (* Generate a fresh symbolic value for the return value *) let ret_sv_ty = inst_sg.A.output in let ret_spc = mk_fresh_symbolic_value V.FunCallRet ret_sv_ty in @@ -1487,7 +1497,7 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config) and eval_assumed_function_call_symbolic (config : C.config) (fid : A.assumed_fun_id) (call : A.call) : st_cm_fun = fun cf ctx -> - let generics = call.generics in + let generics = call.func.generics in let args = call.args in let dest = call.dest in (* Sanity check: make sure the type parameters don't contain regions - @@ -1503,7 +1513,7 @@ and eval_assumed_function_call_symbolic (config : C.config) See {!eval_box_free} *) match fid with - | A.BoxFree -> + | BoxFree -> (* Degenerate case: box_free - note that this is not really a function * call: no need to call a "synthesize_..." function *) eval_box_free config generics args dest (cf Unit) ctx @@ -1514,7 +1524,7 @@ and eval_assumed_function_call_symbolic (config : C.config) * instantiated signatures, and delegate the work to an auxiliary function *) let inst_sig = match fid with - | A.BoxFree -> + | BoxFree -> (* should have been treated above *) raise (Failure "Unreachable") | _ -> @@ -1525,7 +1535,7 @@ and eval_assumed_function_call_symbolic (config : C.config) in (* Evaluate the function call *) - eval_function_call_symbolic_from_inst_sig config (A.FunId (A.Assumed fid)) + eval_function_call_symbolic_from_inst_sig config (FunId (Assumed fid)) inst_sig generics args dest cf ctx (** Evaluate a statement seen as a function body *) diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml index 7aaee6ff..6e08e553 100644 --- a/compiler/InterpreterUtils.ml +++ b/compiler/InterpreterUtils.ml @@ -46,6 +46,11 @@ let operand_to_string = PA.operand_to_string let egeneric_args_to_string = PA.egeneric_args_to_string let rtrait_instance_id_to_string = PA.rtrait_instance_id_to_string let fun_sig_to_string = PA.fun_sig_to_string +let inst_fun_sig_to_string = PA.inst_fun_sig_to_string + +let fun_id_or_trait_method_ref_to_string = + PA.fun_id_or_trait_method_ref_to_string + let fun_decl_to_string = PA.fun_decl_to_string let call_to_string = PA.call_to_string diff --git a/compiler/Print.ml b/compiler/Print.ml index 5d5c16ee..1d5ddc50 100644 --- a/compiler/Print.ml +++ b/compiler/Print.ml @@ -359,6 +359,18 @@ module Values = struct ^ "}" ^ "{regions=" ^ T.RegionId.Set.to_string None abs.regions ^ "}" ^ " {\n" ^ avs ^ "\n" ^ indent ^ "}" + + let inst_fun_sig_to_string (fmt : value_formatter) (sg : LlbcAst.inst_fun_sig) + : string = + (* TODO: print the trait type constraints? *) + let ty_fmt = value_to_rtype_formatter fmt in + let ty_to_string = PT.ty_to_string ty_fmt in + + let inputs = + "(" ^ String.concat ", " (List.map ty_to_string sg.inputs) ^ ")" + in + let output = ty_to_string sg.output in + inputs ^ " -> " ^ output end module PV = Values (* local module *) @@ -755,6 +767,17 @@ module EvalCtxLlbcAst = struct let fmt = PC.eval_ctx_to_ast_formatter ctx in PA.fun_sig_to_string fmt "" " " x + let inst_fun_sig_to_string (ctx : C.eval_ctx) (x : LlbcAst.inst_fun_sig) : + string = + let fmt = PC.eval_ctx_to_ast_formatter ctx in + let fmt = PC.ast_to_value_formatter fmt in + PV.inst_fun_sig_to_string fmt x + + let fun_id_or_trait_method_ref_to_string (ctx : C.eval_ctx) + (x : E.fun_id_or_trait_method_ref) : string = + let fmt = PC.eval_ctx_to_ast_formatter ctx in + PE.fun_id_or_trait_method_ref_to_string fmt x "..." + let statement_to_string (ctx : C.eval_ctx) (indent : string) (indent_incr : string) (e : A.statement) : string = let fmt = PC.eval_ctx_to_ast_formatter ctx in diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 5fb5978b..be7b3cb4 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -593,17 +593,17 @@ let fun_suffix (lp_id : LoopId.id option) (rg_id : T.RegionGroupId.id option) : let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string = match fid with - | A.Replace -> "core::mem::replace" - | A.BoxNew -> "alloc::boxed::Box::new" - | A.BoxDeref -> "core::ops::deref::Deref::deref" - | A.BoxDerefMut -> "core::ops::deref::DerefMut::deref_mut" - | A.BoxFree -> "alloc::alloc::box_free" - | A.VecNew -> "alloc::vec::Vec::new" - | A.VecPush -> "alloc::vec::Vec::push" - | A.VecInsert -> "alloc::vec::Vec::insert" - | A.VecLen -> "alloc::vec::Vec::len" - | A.VecIndex -> "core::ops::index::Index<alloc::vec::Vec>::index" - | A.VecIndexMut -> "core::ops::index::IndexMut<alloc::vec::Vec>::index_mut" + | Replace -> "core::mem::replace" + | BoxNew -> "alloc::boxed::Box::new" + | BoxDeref -> "core::ops::deref::Deref::deref" + | BoxDerefMut -> "core::ops::deref::DerefMut::deref_mut" + | BoxFree -> "alloc::alloc::box_free" + | VecNew -> "alloc::vec::Vec::new" + | VecPush -> "alloc::vec::Vec::push" + | VecInsert -> "alloc::vec::Vec::insert" + | VecLen -> "alloc::vec::Vec::len" + | VecIndex -> "core::ops::index::Index<alloc::vec::Vec>::index" + | VecIndexMut -> "core::ops::index::IndexMut<alloc::vec::Vec>::index_mut" | ArrayIndexShared -> "@ArrayIndexShared" | ArrayIndexMut -> "@ArrayIndexMut" | ArrayToSliceShared -> "@ArrayToSliceShared" diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index cedc3559..b00509a6 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -791,7 +791,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) let id0 = match id0 with | FunId fun_id -> fun_id - | TraitMethod (_, _, fun_decl_id) -> A.Regular fun_decl_id + | TraitMethod (_, _, fun_decl_id) -> Regular fun_decl_id in LlbcAstUtils.lookup_fun_sig id0 ctx.fun_ctx.fun_decls in @@ -1523,29 +1523,29 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = match opt_destruct_function_call e with | Some (fun_id, _tys, args) -> ( match fun_id with - | Fun (FromLlbc (FunId (A.Assumed aid), _lp_id, rg_id)) -> ( + | Fun (FromLlbc (FunId (Assumed aid), _lp_id, rg_id)) -> ( (* Below, when dealing with the arguments: we consider the very * general case, where functions could be boxed (meaning we * could have: [box_new f x]) * *) match (aid, rg_id) with - | A.BoxNew, _ -> + | BoxNew, _ -> assert (rg_id = None); let arg, args = Collections.List.pop args in mk_apps arg args - | A.BoxDeref, None -> + | BoxDeref, None -> (* [Box::deref] forward is the identity *) let arg, args = Collections.List.pop args in mk_apps arg args - | A.BoxDeref, Some _ -> + | BoxDeref, Some _ -> (* [Box::deref] backward is [()] (doesn't give back anything) *) assert (args = []); mk_unit_rvalue - | A.BoxDerefMut, None -> + | BoxDerefMut, None -> (* [Box::deref_mut] forward is the identity *) let arg, args = Collections.List.pop args in mk_apps arg args - | A.BoxDerefMut, Some _ -> + | BoxDerefMut, Some _ -> (* [Box::deref_mut] back is almost the identity: * let box_deref_mut (x_init : t) (x_back : t) : t = x_back * *) @@ -1555,15 +1555,15 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = | _ -> raise (Failure "Unreachable") in mk_apps arg args - | A.BoxFree, _ -> + | BoxFree, _ -> assert (args = []); mk_unit_rvalue - | ( ( A.Replace | VecNew | VecPush | VecInsert | VecLen - | VecIndex | VecIndexMut | ArraySubsliceShared - | ArraySubsliceMut | SliceIndexShared | SliceIndexMut - | SliceSubsliceShared | SliceSubsliceMut | ArrayIndexShared - | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut - | ArrayRepeat | SliceLen ), + | ( ( Replace | VecNew | VecPush | VecInsert | VecLen | VecIndex + | VecIndexMut | ArraySubsliceShared | ArraySubsliceMut + | SliceIndexShared | SliceIndexMut | SliceSubsliceShared + | SliceSubsliceMut | ArrayIndexShared | ArrayIndexMut + | ArrayToSliceShared | ArrayToSliceMut | ArrayRepeat + | SliceLen ), _ ) -> super#visit_texpression env e) | _ -> super#visit_texpression env e) @@ -2046,7 +2046,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : let inputs_set = VarId.Set.of_list (List.map var_get_id inputs_prefix) in assert (Option.is_some decl.loop_id); - let fun_id = (A.Regular decl.def_id, decl.loop_id) in + let fun_id = (E.Regular decl.def_id, decl.loop_id) in let set_used vid = used := List.map (fun (vid', b) -> (vid', b || vid = vid')) !used @@ -2130,7 +2130,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : (* We then apply the filtering to all the function definitions at once *) let filter_in_one (decl : fun_decl) : fun_decl = (* Filter the function signature *) - let fun_id = (A.Regular decl.def_id, decl.loop_id) in + let fun_id = (E.Regular decl.def_id, decl.loop_id) in let decl = match FunLoopIdMap.find_opt fun_id !used_map with | None -> (* Nothing to filter *) decl diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 429198ad..54221cb1 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -343,7 +343,7 @@ let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) : (* TODO: move *) let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id) (back_id : T.RegionGroupId.id option) (ctx : bs_ctx) : fun_sig = - let id = (A.Regular def_id, back_id) in + let id = (E.Regular def_id, back_id) in (RegularFunIdNotLoopMap.find id ctx.fun_context.fun_sigs).sg (* Some generic translation functions (we need to translate different "flavours" @@ -390,6 +390,7 @@ and translate_trait_instance_id (translate_ty : 'r T.ty -> ty) let inst_id = translate_trait_instance_id inst_id in ItemClause (inst_id, decl_id, item_name, clause_id) | TraitRef tr -> TraitRef (translate_trait_ref translate_ty tr) + | FnPointer _ -> raise (Failure "TODO") | UnknownTrait s -> raise (Failure ("Unknown trait found: " ^ s)) let rec translate_sty (ty : T.sty) : ty = @@ -427,6 +428,7 @@ let rec translate_sty (ty : T.sty) : ty = let trait_ref = translate_strait_ref trait_ref in let generics = translate_sgeneric_args generics in TraitType (trait_ref, generics, type_name) + | Arrow _ -> raise (Failure "TODO") and translate_sgeneric_args (generics : T.sgeneric_args) : generic_args = translate_generic_args translate_sty generics @@ -569,6 +571,7 @@ let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty = let trait_ref = translate_fwd_trait_ref type_infos trait_ref in let generics = translate_fwd_generic_args type_infos generics in TraitType (trait_ref, generics, type_name) + | Arrow _ -> raise (Failure "TODO") and translate_fwd_generic_args (type_infos : TA.type_infos) (generics : 'r T.generic_args) : generic_args = @@ -658,6 +661,7 @@ let rec translate_back_ty (type_infos : TA.type_infos) let trait_ref = translate_fwd_trait_ref type_infos trait_ref in let generics = translate_fwd_generic_args type_infos generics in Some (TraitType (trait_ref, generics, type_name)) + | Arrow _ -> raise (Failure "TODO") (** Simply calls [translate_back_ty] *) let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) @@ -694,7 +698,7 @@ let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit = let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) (id : A.fun_id_or_trait_method_ref) : fun_id_or_trait_method_ref = match id with - | A.FunId fun_id -> FunId fun_id + | FunId fun_id -> FunId fun_id | TraitMethod (trait_ref, method_name, fun_decl_id) -> let type_infos = ctx.type_context.type_infos in let trait_ref = translate_fwd_trait_ref type_infos trait_ref in @@ -795,7 +799,7 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with - | A.TraitMethod (_, _, fid) | A.FunId (A.Regular fid) -> + | TraitMethod (_, _, fid) | FunId (Regular fid) -> let info = A.FunDeclId.Map.find fid fun_infos in let stateful_group = info.stateful in let stateful = @@ -808,7 +812,7 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) can_diverge = info.can_diverge; is_rec = info.is_rec || Option.is_some lid; } - | A.FunId (A.Assumed aid) -> + | FunId (Assumed aid) -> assert (lid = None); { can_fail = Assumed.assumed_can_fail aid; @@ -841,7 +845,7 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) in (* Is the function stateful, and can it fail? *) let lid = None in - let effect_info = get_fun_effect_info fun_infos (A.FunId fun_id) lid bid in + let effect_info = get_fun_effect_info fun_infos (FunId fun_id) lid bid in (* We need an evaluation context to normalize the types (to normalize the associated types, etc. - for instance it may happen that the types refer to the types associated to a trait ref, but where the trait ref @@ -1706,18 +1710,20 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : in (ctx, Unop (Neg int_ty), effect_info, args, None) | _ -> raise (Failure "Unreachable")) - | S.Unop (E.Cast (src_ty, tgt_ty)) -> - (* Note that cast can fail *) - let effect_info = - { - can_fail = true; - stateful_group = false; - stateful = false; - can_diverge = false; - is_rec = false; - } - in - (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) + | S.Unop (E.Cast cast_kind) -> ( + match cast_kind with + | CastInteger (src_ty, tgt_ty) -> + (* Note that cast can fail *) + let effect_info = + { + can_fail = true; + stateful_group = false; + stateful = false; + can_diverge = false; + is_rec = false; + } + in + (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None)) | S.Binop binop -> ( match args with | [ arg0; arg1 ] -> @@ -1925,7 +1931,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *) (if (* TODO: normalize the types *) !Config.type_check_pure_code then match fun_id with - | A.FunId fun_id -> + | FunId fun_id -> let inst_sg = get_instantiated_fun_sig fun_id (Some rg_id) generics ctx in @@ -2088,9 +2094,9 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) (* Actually the same case as [SynthInput] *) translate_end_abstraction_synth_input ectx abs e ctx rg_id | V.LoopCall -> - let fun_id = A.Regular ctx.fun_decl.A.def_id in + let fun_id = E.Regular ctx.fun_decl.A.def_id in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (A.FunId fun_id) + get_fun_effect_info ctx.fun_context.fun_infos (FunId fun_id) (Some vloop_id) (Some rg_id) in let loop_info = LoopId.Map.find loop_id ctx.loops in @@ -2553,9 +2559,9 @@ and translate_forward_end (ectx : C.eval_ctx) let org_args = args in (* Lookup the effect info for the loop function *) - let fid = A.Regular ctx.fun_decl.A.def_id in + let fid = E.Regular ctx.fun_decl.A.def_id in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (A.FunId fid) None ctx.bid + get_fun_effect_info ctx.fun_context.fun_infos (FunId fid) None ctx.bid in (* Introduce a fresh output value for the forward function *) diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index aeb6899f..9084f2b3 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -85,7 +85,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) match ls with | [ (Some see, exp) ] -> ExpandNoBranch (see, exp) | _ -> raise (Failure "Ill-formed borrow expansion")) - | T.TypeVar _ | T.Literal Char | Never | T.TraitType _ -> + | T.TypeVar _ | T.Literal Char | Never | T.TraitType _ | T.Arrow _ -> raise (Failure "Ill-formed symbolic expansion") in Some (Expansion (place, sv, expansion)) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index e69abee1..8e01c869 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -61,7 +61,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Initialize the context *) let forward_sig = - RegularFunIdNotLoopMap.find (A.Regular def_id, None) fun_sigs + RegularFunIdNotLoopMap.find (E.Regular def_id, None) fun_sigs in let sv_to_var = V.SymbolicValueId.Map.empty in let var_counter = Pure.VarId.generator_zero in @@ -200,7 +200,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Initialize the context - note that the ret_ty is not really * useful as we don't translate a body *) let backward_sg = - RegularFunIdNotLoopMap.find (A.Regular def_id, Some back_id) fun_sigs + RegularFunIdNotLoopMap.find (Regular def_id, Some back_id) fun_sigs in let ctx = { ctx with bid = Some back_id; sg = backward_sg.sg } in @@ -211,7 +211,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) variables required by the backward function. *) let backward_sg = - RegularFunIdNotLoopMap.find (A.Regular def_id, Some back_id) fun_sigs + RegularFunIdNotLoopMap.find (Regular def_id, Some back_id) fun_sigs in (* We need to ignore the forward inputs, and the state input (if there is) *) let backward_inputs = @@ -298,7 +298,7 @@ let translate_crate_to_pure (crate : A.crate) : let assumed_sigs = List.map (fun (id, sg, _, _) -> - (A.Assumed id, List.map (fun _ -> None) (sg : A.fun_sig).inputs, sg)) + (E.Assumed id, List.map (fun _ -> None) (sg : A.fun_sig).inputs, sg)) Assumed.assumed_infos in let local_sigs = @@ -312,7 +312,7 @@ let translate_crate_to_pure (crate : A.crate) : (fun (v : A.var) -> v.name) (LlbcAstUtils.fun_body_get_input_vars body) in - (A.Regular fdef.def_id, input_names, fdef.signature)) + (E.Regular fdef.def_id, input_names, fdef.signature)) (A.FunDeclId.Map.values crate.functions) in let sigs = List.append assumed_sigs local_sigs in diff --git a/compiler/TypesAnalysis.ml b/compiler/TypesAnalysis.ml index 95c7206a..4a187893 100644 --- a/compiler/TypesAnalysis.ml +++ b/compiler/TypesAnalysis.ml @@ -232,6 +232,14 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref) in (* Return *) ty_info + | Arrow (inputs, output) -> + (* Just dive into the arrow *) + let ty_info = + List.fold_left + (fun ty_info ty -> analyze expl_info ty_info ty) + ty_info inputs + in + analyze expl_info ty_info output in (* Explore *) analyze expl_info_init ty_info ty |