summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-09-03 15:18:36 +0200
committerSon Ho2023-09-03 15:18:36 +0200
commitb42c0a8fa4708d6bf8424d63b6a7fe4964ba0e3d (patch)
tree5d1c87cbc924de09fafae1823f9e0e7563ff48d6 /compiler
parent0cafb31dd42c95f22e0b6680531c27fa0508e376 (diff)
Make progress on the extraction
Diffstat (limited to '')
-rw-r--r--compiler/Extract.ml13
-rw-r--r--compiler/Pure.ml44
-rw-r--r--compiler/PureMicroPasses.ml8
-rw-r--r--compiler/SymbolicToPure.ml110
-rw-r--r--compiler/Translate.ml109
-rw-r--r--compiler/TranslateCore.ml1
6 files changed, 264 insertions, 21 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index ad89a59e..e07305f1 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -1301,7 +1301,8 @@ and extract_trait_instance_id (ctx : extraction_ctx) (fmt : F.formatter)
let name = ctx_get_trait_item_clause decl_id item_name clause_id ctx in
extract_trait_instance_id ctx fmt no_params_tys true inst_id;
F.pp_print_string fmt ("." ^ name)
- | TraitRef trait_ref -> extract_trait_ref ctx fmt no_params_tys true trait_ref
+ | TraitRef trait_ref ->
+ extract_trait_ref ctx fmt no_params_tys inside trait_ref
| UnknownTrait _ ->
(* This is an error case *)
raise (Failure "Unexpected")
@@ -3774,6 +3775,16 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
(* Add a break to insert lines between declarations *)
F.pp_print_break fmt 0 0
+(** Extract a trait declaration *)
+let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter)
+ (trait_decl : trait_decl) : unit =
+ raise (Failure "TODO")
+
+(** Extract a trait implementation *)
+let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter)
+ (trait_impl : trait_impl) : unit =
+ raise (Failure "TODO")
+
(** Extract a unit test, if the function is a unit function (takes no
parameters, returns unit).
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 725f71ad..6c9f41f1 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -45,6 +45,8 @@ type trait_decl_id = T.trait_decl_id [@@deriving show, ord]
type trait_impl_id = T.trait_impl_id [@@deriving show, ord]
type trait_clause_id = T.trait_clause_id [@@deriving show, ord]
type trait_item_name = T.trait_item_name [@@deriving show, ord]
+type global_decl_id = T.global_decl_id [@@deriving show, ord]
+type fun_decl_id = A.fun_decl_id [@@deriving show, ord]
(** The assumed types for the pure AST.
@@ -361,11 +363,23 @@ type generic_params = {
}
[@@deriving show]
+type trait_type_constraint = {
+ trait_ref : trait_ref;
+ generics : generic_args;
+ type_name : trait_item_name;
+ ty : ty;
+}
+[@@deriving show]
+
+type predicates = { trait_type_constraints : trait_type_constraint list }
+[@@deriving show]
+
type type_decl = {
def_id : TypeDeclId.id;
name : name;
generics : generic_params;
kind : type_decl_kind;
+ preds : predicates;
}
[@@deriving show]
@@ -881,6 +895,7 @@ type fun_sig_info = {
type fun_sig = {
generics : generic_params;
(** TODO: we should analyse the signature to make the type parameters implicit whenever possible *)
+ preds : predicates;
inputs : ty list;
(** The types of the inputs.
@@ -952,8 +967,11 @@ type fun_body = {
}
[@@deriving show]
+type fun_kind = A.fun_kind [@@deriving show]
+
type fun_decl = {
def_id : FunDeclId.id;
+ kind : fun_kind;
num_loops : int;
(** The number of loops in the parent forward function (basically the number
of loops appearing in the original Rust functions, unless some loops are
@@ -973,3 +991,29 @@ type fun_decl = {
body : fun_body option;
}
[@@deriving show]
+
+type trait_decl = {
+ def_id : trait_decl_id;
+ name : name;
+ generics : generic_params;
+ preds : predicates;
+ all_trait_clauses : trait_clause list;
+ consts : (trait_item_name * (ty * global_decl_id option)) list;
+ types : (trait_item_name * (trait_clause list * ty option)) list;
+ required_methods : (trait_item_name * fun_decl_id) list;
+ provided_methods : trait_item_name list;
+}
+[@@deriving show]
+
+type trait_impl = {
+ def_id : trait_impl_id;
+ name : name;
+ impl_trait : trait_decl_ref;
+ generics : generic_params;
+ preds : predicates;
+ consts : (trait_item_name * (ty * global_decl_id)) list;
+ types : (trait_item_name * (trait_ref list * ty)) list;
+ required_methods : (trait_item_name * fun_decl_id) list;
+ provided_methods : (trait_item_name * fun_decl_id) list;
+}
+[@@deriving show]
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 45e4ea98..93609695 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -1355,6 +1355,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let loop_sig =
{
generics = fun_sig.generics;
+ preds = fun_sig.preds;
inputs = inputs_tys;
output;
doutputs;
@@ -1419,6 +1420,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let loop_def =
{
def_id = def.def_id;
+ kind = def.kind;
num_loops;
loop_id = Some loop.loop_id;
back_id = def.back_id;
@@ -2135,7 +2137,9 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
let num_filtered =
List.length (List.filter (fun b -> not b) used_info)
in
- let { generics; inputs; output; doutputs; info } = decl.signature in
+ let { generics; preds; inputs; output; doutputs; info } =
+ decl.signature
+ in
let {
has_fuel;
num_fwd_inputs_with_fuel_no_state;
@@ -2161,7 +2165,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
effect_info;
}
in
- let signature = { generics; inputs; output; doutputs; info } in
+ let signature = { generics; preds; inputs; output; doutputs; info } in
{ decl with signature }
in
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 166f08a0..1a981de1 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -4,6 +4,7 @@ open Pure
open PureUtils
module Id = Identifiers
module C = Contexts
+module A = LlbcAst
module S = SymbolicAst
module TA = TypesAnalysis
module L = Logging
@@ -473,6 +474,20 @@ let translate_trait_clause (clause : T.trait_clause) : trait_clause =
let generics = translate_sgeneric_args generics in
{ clause_id; trait_id; generics }
+let translate_strait_type_constraint (ttc : T.strait_type_constraint) :
+ trait_type_constraint =
+ let { T.trait_ref; generics; type_name; ty } = ttc in
+ let trait_ref = translate_strait_ref trait_ref in
+ let generics = translate_sgeneric_args generics in
+ let ty = translate_sty ty in
+ { trait_ref; generics; type_name; ty }
+
+let translate_predicates (preds : T.predicates) : predicates =
+ let trait_type_constraints =
+ List.map translate_strait_type_constraint preds.trait_type_constraints
+ in
+ { trait_type_constraints }
+
let translate_generic_params (generics : T.generic_params) : generic_params =
let { T.regions = _; types; const_generics; trait_clauses } = generics in
let trait_clauses = List.map translate_trait_clause trait_clauses in
@@ -515,7 +530,8 @@ let translate_type_decl (def : T.type_decl) : type_decl =
let trait_clauses = List.map translate_trait_clause trait_clauses in
let generics = { types; const_generics; trait_clauses } in
let kind = translate_type_decl_kind def.T.kind in
- { def_id; name; generics; kind }
+ let preds = translate_predicates def.preds in
+ { def_id; name; generics; kind; preds }
let translate_type_id (id : T.type_id) : type_id =
match id with
@@ -952,7 +968,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
effect_info;
}
in
- let sg = { generics; inputs; output; doutputs; info } in
+ let preds = translate_predicates sg.A.preds in
+ let sg = { generics; preds; inputs; output; doutputs; info } in
{ sg; output_names }
let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
@@ -2932,6 +2949,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
let def =
{
def_id;
+ kind = def.kind;
num_loops;
loop_id;
back_id = bid;
@@ -3002,3 +3020,91 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
List.fold_left
(fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m)
RegularFunIdNotLoopMap.empty translated
+
+let translate_trait_decl (type_infos : TA.type_infos)
+ (trait_decl : A.trait_decl) : trait_decl =
+ let {
+ A.def_id;
+ name;
+ generics;
+ preds;
+ all_trait_clauses;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ } =
+ trait_decl
+ in
+ let generics = translate_generic_params generics in
+ let preds = translate_predicates preds in
+ let all_trait_clauses = List.map translate_trait_clause all_trait_clauses in
+ let consts =
+ List.map
+ (fun (name, (ty, id)) -> (name, (translate_fwd_ty type_infos ty, id)))
+ consts
+ in
+ let types =
+ List.map
+ (fun (name, (trait_clauses, ty)) ->
+ ( name,
+ ( List.map translate_trait_clause trait_clauses,
+ Option.map (translate_fwd_ty type_infos) ty ) ))
+ types
+ in
+ {
+ def_id;
+ name;
+ generics;
+ preds;
+ all_trait_clauses;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ }
+
+let translate_trait_impl (type_infos : TA.type_infos)
+ (trait_impl : A.trait_impl) : trait_impl =
+ let {
+ A.def_id;
+ name;
+ impl_trait;
+ generics;
+ preds;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ } =
+ trait_impl
+ in
+ let impl_trait =
+ translate_trait_decl_ref (translate_fwd_ty type_infos) impl_trait
+ in
+ let generics = translate_generic_params generics in
+ let preds = translate_predicates preds in
+ let consts =
+ List.map
+ (fun (name, (ty, id)) -> (name, (translate_fwd_ty type_infos ty, id)))
+ consts
+ in
+ let types =
+ List.map
+ (fun (name, (trait_refs, ty)) ->
+ ( name,
+ ( List.map (translate_fwd_trait_ref type_infos) trait_refs,
+ translate_fwd_ty type_infos ty ) ))
+ types
+ in
+ {
+ def_id;
+ name;
+ impl_trait;
+ generics;
+ preds;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ }
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index ca661108..f4f59187 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -5,6 +5,7 @@ module T = Types
module A = LlbcAst
module SA = SymbolicAst
module Micro = PureMicroPasses
+module C = Contexts
open PureUtils
open TranslateCore
@@ -28,18 +29,34 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : A.fun_decl)
("translate_function_to_symbolics: "
^ Print.fun_name_to_string fdef.A.name));
- let { type_context; fun_context; global_context } = trans_ctx in
+ let {
+ type_context;
+ fun_context;
+ global_context;
+ trait_decls_context;
+ trait_impls_context;
+ } =
+ trans_ctx
+ in
let fun_context = { C.fun_decls = fun_context.fun_decls } in
+ (* TODO: we should merge trans_ctx and decls_ctx *)
+ let decls_ctx : C.decls_ctx =
+ {
+ C.type_ctx = type_context;
+ fun_ctx = fun_context;
+ global_ctx = global_context;
+ trait_decls_ctx = trait_decls_context;
+ trait_impls_ctx = trait_impls_context;
+ }
+ in
+
match fdef.body with
| None -> None
| Some _ ->
(* Evaluate *)
let synthesize = true in
- let inputs, symb =
- evaluate_function_symbolic synthesize type_context fun_context
- global_context fdef
- in
+ let inputs, symb = evaluate_function_symbolic synthesize decls_ctx fdef in
Some (inputs, Option.get symb)
(** Translate a function, by generating its forward and backward translations.
@@ -57,7 +74,15 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(lazy
("translate_function_to_pure: " ^ Print.fun_name_to_string fdef.A.name));
- let { type_context; fun_context; global_context } = trans_ctx in
+ let {
+ type_context;
+ fun_context;
+ global_context;
+ trait_decls_context;
+ trait_impls_context;
+ } =
+ trans_ctx
+ in
let def_id = fdef.def_id in
(* Compute the symbolic ASTs, if the function is transparent *)
@@ -148,6 +173,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
type_context;
fun_context;
global_context;
+ trait_decls_ctx = trait_decls_context.trait_decls;
+ trait_impls_ctx = trait_impls_context.trait_impls;
fun_decl = fdef;
forward_inputs = [];
(* Empty for now *)
@@ -280,13 +307,21 @@ let translate_crate_to_pure (crate : A.crate) :
log#ldebug (lazy "translate_crate_to_pure");
(* Compute the type and function contexts *)
- let type_context, fun_context, global_context = compute_contexts crate in
+ let decls_ctx = compute_contexts crate in
let fun_infos =
- FA.analyze_module crate fun_context.C.fun_decls
- global_context.C.global_decls !Config.use_state
+ FA.analyze_module crate decls_ctx.fun_ctx.C.fun_decls
+ decls_ctx.global_ctx.C.global_decls !Config.use_state
+ in
+ let fun_context = { fun_decls = decls_ctx.fun_ctx.fun_decls; fun_infos } in
+ let trans_ctx =
+ {
+ type_context = decls_ctx.type_ctx;
+ fun_context;
+ global_context = decls_ctx.global_ctx;
+ trait_decls_context = decls_ctx.trait_decls_ctx;
+ trait_impls_context = decls_ctx.trait_impls_ctx;
+ }
in
- let fun_context = { fun_decls = fun_context.fun_decls; fun_infos } in
- let trans_ctx = { type_context; fun_context; global_context } in
(* Translate all the type definitions *)
let type_decls =
@@ -323,7 +358,7 @@ let translate_crate_to_pure (crate : A.crate) :
let sigs = List.append assumed_sigs local_sigs in
let fun_sigs =
SymbolicToPure.translate_fun_signatures fun_context.fun_infos
- type_context.type_infos sigs
+ decls_ctx.type_ctx.type_infos sigs
in
(* Translate all the *transparent* functions *)
@@ -696,6 +731,36 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
Extract.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd)
pure_ls
+(** Export a trait declaration. *)
+let export_trait_decl (fmt : Format.formatter) (_config : gen_config)
+ (ctx : gen_ctx) (trait_decl_id : Pure.trait_decl_id) : unit =
+ let trait_decl =
+ T.TraitDeclId.Map.find trait_decl_id
+ ctx.extract_ctx.trans_ctx.trait_decls_context.trait_decls
+ in
+ (* We translate the trait declaration on the fly (note that
+ trait declarations do not directly contain functions, constants,
+ etc.: they simply refer to them). *)
+ let type_infos = ctx.extract_ctx.trans_ctx.type_context.type_infos in
+ let trait_decl = SymbolicToPure.translate_trait_decl type_infos trait_decl in
+ let ctx = ctx.extract_ctx in
+ let ctx = { ctx with trait_decl_id = Some trait_decl.def_id } in
+ Extract.extract_trait_decl ctx fmt trait_decl
+
+(** Export a trait implementation. *)
+let export_trait_impl (fmt : Format.formatter) (_config : gen_config)
+ (ctx : gen_ctx) (trait_impl_id : Pure.trait_impl_id) : unit =
+ let trait_impl =
+ T.TraitImplId.Map.find trait_impl_id
+ ctx.extract_ctx.trans_ctx.trait_impls_context.trait_impls
+ in
+ (* We translate the trait implementation on the fly (note that
+ trait implementations do not directly contain functions, constants,
+ etc.: they simply refer to them). *)
+ let type_infos = ctx.extract_ctx.trans_ctx.type_context.type_infos in
+ let trait_impl = SymbolicToPure.translate_trait_impl type_infos trait_impl in
+ Extract.extract_trait_impl ctx.extract_ctx fmt trait_impl
+
(** A generic utility to generate the extracted definitions: as we may want to
split the definitions between different files (or not), we can control
what is precisely extracted.
@@ -710,6 +775,8 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
let export_functions_group = export_functions_group fmt config ctx in
let export_global = export_global fmt config ctx in
let export_types_group = export_types_group fmt config ctx in
+ let export_trait_decl = export_trait_decl fmt config ctx in
+ let export_trait_impl = export_trait_impl fmt config ctx in
let export_state_type () : unit =
let kind =
@@ -723,11 +790,18 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
| Type (NonRec id) ->
if config.extract_types then export_types_group false [ id ]
| Type (Rec ids) -> if config.extract_types then export_types_group true ids
- | Fun (NonRec id) ->
+ | Fun (NonRec id) -> (
(* Lookup *)
let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in
- (* Translate *)
- export_functions_group [ pure_fun ]
+ (* Special case: we skip trait method *declarations* (we will
+ extract their type directly in the records we generate for
+ the trait declarations themselves, there is no point in having
+ separate type definitions) *)
+ match (fst (fst (snd pure_fun))).Pure.kind with
+ | TraitMethodDecl _ -> ()
+ | _ ->
+ (* Translate *)
+ export_functions_group [ pure_fun ])
| Fun (Rec ids) ->
(* General case of mutually recursive functions *)
(* Lookup *)
@@ -737,11 +811,13 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
(* Translate *)
export_functions_group pure_funs
| Global id -> export_global id
+ | TraitDecl id -> export_trait_decl id
+ | TraitImpl id -> export_trait_impl id
in
(* If we need to export the state type: we try to export it after we defined
* the type definitions, because if the user wants to define a model for the
- * type, he might want to reuse those in the state type.
+ * type, they might want to reuse those in the state type.
* More specifically: if we extract functions in the same file as the type,
* we have no choice but to define the state type before the functions,
* because they may reuse this state type: in this case, we define/declare
@@ -930,6 +1006,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
use_opaque_pre = !Config.split_files;
use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses;
fun_name_info = PureUtils.RegularFunIdMap.empty;
+ trait_decl_id = None (* None by default *);
}
in
diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml
index 1b1572d6..34a6434f 100644
--- a/compiler/TranslateCore.ml
+++ b/compiler/TranslateCore.ml
@@ -22,6 +22,7 @@ type trait_decls_context = C.trait_decls_context [@@deriving show]
type trait_impls_context = C.trait_impls_context [@@deriving show]
type global_context = C.global_context [@@deriving show]
+(* TODO: we should use Contexts.decls_ctx *)
type trans_ctx = {
type_context : type_context;
fun_context : fun_context;