diff options
| author | Son Ho | 2022-01-28 22:17:28 +0100 | 
|---|---|---|
| committer | Son Ho | 2022-01-28 22:17:28 +0100 | 
| commit | 2d40d81b4b9fde44fd924bad5a44b7392a1c9f1e (patch) | |
| tree | ae92f07a33a48bb794a633647d7efa111edb2a94 /src/SymbolicToPure.ml | |
| parent | 0b145dd4b0ab0ac5ed56121663a25801f20bed67 (diff) | |
Make the pure expressions typed
Diffstat (limited to '')
| -rw-r--r-- | src/SymbolicToPure.ml | 155 | 
1 files changed, 97 insertions, 58 deletions
| diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index f6f610dd..d48f732d 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -843,16 +843,23 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list =    let abs_ancestors = list_ancestor_abstractions ctx abs in    (call_info.forward, abs_ancestors) -let rec translate_expression (e : S.expression) (ctx : bs_ctx) : expression = +let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =    match e with    | S.Return opt_v -> translate_return opt_v ctx -  | Panic -> Value (mk_result_fail_rvalue ctx.ret_ty, None) +  | Panic -> +      (* Here we use the function return type - note that it is ok because +       * we don't match on panics which happen inside the function body - +       * but it won't be true anymore once we translate individual blocks *) +      let v = mk_result_fail_rvalue ctx.ret_ty in +      let e = Value (v, None) in +      let ty = v.ty in +      { e; ty }    | FunCall (call, e) -> translate_function_call call e ctx    | EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx    | Expansion (p, sv, exp) -> translate_expansion p sv exp ctx    | Meta (meta, e) -> translate_meta meta e ctx -and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression +and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression      =    (* There are two cases:       - either we are translating a forward function, in which case the optional @@ -864,7 +871,10 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression        (* Forward function *)        let v = Option.get opt_v in        let v = typed_value_to_rvalue ctx v in -      Value (mk_result_return_rvalue v, None) +      let v = mk_result_return_rvalue v in +      let e = Value (v, None) in +      let ty = v.ty in +      { e; ty }    | Some bid ->        (* Backward function *)        (* Sanity check *) @@ -880,10 +890,13 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression        let ret_tys = List.map (fun (v : typed_rvalue) -> v.ty) field_values in        let ret_ty = Adt (Tuple, ret_tys) in        let ret_value : typed_rvalue = { value = ret_value; ty = ret_ty } in -      Value (mk_result_return_rvalue ret_value, None) +      let ret_value = mk_result_return_rvalue ret_value in +      let e = Value (ret_value, None) in +      let ty = ret_value.ty in +      { e; ty }  and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : -    expression = +    texpression =    (* Translate the function call *)    let type_params = List.map (ctx_translate_fwd_ty ctx) call.type_params in    let args = List.map (typed_value_to_rvalue ctx) call.args in @@ -918,17 +931,22 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :          | _ -> failwith "Unreachable")    in    let args = -    List.map (fun (arg, mp) -> Value (arg, mp)) (List.combine args args_mplaces) +    List.map +      (fun (arg, mp) -> mk_value_expression arg mp) +      (List.combine args args_mplaces)    in +  let dest_v = mk_typed_lvalue_from_var dest dest_mplace in    let call = { func; type_params; args } in    let call = Call call in +  let call_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in +  let call = { e = call; ty = call_ty } in    (* Translate the next expression *) -  let e = translate_expression e ctx in +  let next_e = translate_expression e ctx in    (* Put together *) -  Let (monadic, mk_typed_lvalue_from_var dest dest_mplace, call, e) +  mk_let monadic dest_v call next_e  and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : -    expression = +    texpression =    log#ldebug      (lazy        ("translate_end_abstraction: abstraction kind: " @@ -976,14 +994,16 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :          (fun (var, v) -> assert ((var : var).ty = (v : typed_rvalue).ty))          variables_values;        (* Translate the next expression *) -      let e = translate_expression e ctx in +      let next_e = translate_expression e ctx in        (* Generate the assignemnts *)        let monadic = false in        List.fold_right -        (fun (var, value) e -> -          Let -            (monadic, mk_typed_lvalue_from_var var None, Value (value, None), e)) -        variables_values e +        (fun (var, value) (e : texpression) -> +          mk_let monadic +            (mk_typed_lvalue_from_var var None) +            (mk_value_expression value None) +            e) +        variables_values next_e    | V.FunCall ->        let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in        let call = call_info.forward in @@ -1039,17 +1059,19 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :         * if necessary *)        let ctx, func = bs_ctx_register_backward_call abs ctx in        (* Translate the next expression *) -      let e = translate_expression e ctx in +      let next_e = translate_expression e ctx in        (* Put everything together *)        let args_mplaces = List.map (fun _ -> None) inputs in        let args =          List.map -          (fun (arg, mp) -> Value (arg, mp)) +          (fun (arg, mp) -> mk_value_expression arg mp)            (List.combine inputs args_mplaces)        in -      let call = { func; type_params; args } in        let monadic = true in -      Let (monadic, output, Call call, e) +      let call = { func; type_params; args } in +      let call_ty = mk_result_ty output.ty in +      let call = { e = Call call; ty = call_ty } in +      mk_let monadic output call next_e    | V.SynthRet ->        (* If we end the abstraction which consumed the return value of the function         * we are synthesizing, we get back the borrows which were inside. Those borrows @@ -1100,20 +1122,18 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :            assert (given_back.ty = input.ty))          given_back_inputs;        (* Translate the next expression *) -      let e = translate_expression e ctx in +      let next_e = translate_expression e ctx in        (* Generate the assignments *)        let monadic = false in        List.fold_right          (fun (given_back, input_var) e -> -          Let -            ( monadic, -              given_back, -              Value (mk_typed_rvalue_from_var input_var, None), -              e )) -        given_back_inputs e +          mk_let monadic given_back +            (mk_value_expression (mk_typed_rvalue_from_var input_var) None) +            e) +        given_back_inputs next_e  and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) -    (exp : S.expansion) (ctx : bs_ctx) : expression = +    (exp : S.expansion) (ctx : bs_ctx) : texpression =    (* Translate the scrutinee *)    let scrutinee_var = lookup_var_for_symbolic_value sv ctx in    let scrutinee = mk_typed_rvalue_from_var scrutinee_var in @@ -1130,13 +1150,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)            (* The (mut/shared) borrow type is extracted to identity: we thus simply             * introduce an reassignment *)            let ctx, var = fresh_var_for_symbolic_value nsv ctx in -          let e = translate_expression e ctx in +          let next_e = translate_expression e ctx in            let monadic = false in -          Let -            ( monadic, -              mk_typed_lvalue_from_var var None, -              Value (scrutinee, scrutinee_mplace), -              e ) +          mk_let monadic +            (mk_typed_lvalue_from_var var None) +            (mk_value_expression scrutinee scrutinee_mplace) +            next_e        | SeAdt _ ->            (* Should be in the [ExpandAdt] case *)            failwith "Unreachable") @@ -1161,7 +1180,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)                  in                  let lv = mk_adt_lvalue scrutinee.ty variant_id lvars in                  let monadic = false in -                Let (monadic, lv, Value (scrutinee, scrutinee_mplace), branch) + +                mk_let monadic lv +                  (mk_value_expression scrutinee scrutinee_mplace) +                  branch                else                  (* This is not an enumeration: introduce let-bindings for every                   * field. @@ -1182,22 +1204,19 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)                  List.fold_right                    (fun (fid, var) e ->                      let field_proj = gen_field_proj fid var in -                    Let -                      ( monadic, -                        mk_typed_lvalue_from_var var None, -                        Value (field_proj, None), -                        e )) +                    mk_let monadic +                      (mk_typed_lvalue_from_var var None) +                      (mk_value_expression field_proj None) +                      e)                    id_var_pairs branch            | T.Tuple ->                let vars =                  List.map (fun x -> mk_typed_lvalue_from_var x None) vars                in                let monadic = false in -              Let -                ( monadic, -                  mk_tuple_lvalue vars, -                  Value (scrutinee, scrutinee_mplace), -                  branch ) +              mk_let monadic (mk_tuple_lvalue vars) +                (mk_value_expression scrutinee scrutinee_mplace) +                branch            | T.Assumed T.Box ->                (* There should be exactly one variable *)                let var = @@ -1206,11 +1225,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)                (* We simply introduce an assignment - the box type is the                 * identity when extracted (`box a == a`) *)                let monadic = false in -              Let -                ( monadic, -                  mk_typed_lvalue_from_var var None, -                  Value (scrutinee, scrutinee_mplace), -                  branch )) +              mk_let monadic +                (mk_typed_lvalue_from_var var None) +                (mk_value_expression scrutinee scrutinee_mplace) +                branch)        | branches ->            let translate_branch (variant_id : T.VariantId.id option)                (svl : V.symbolic_value list) (branch : S.expression) : @@ -1229,16 +1247,30 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)            let branches =              List.map (fun (vid, svl, e) -> translate_branch vid svl e) branches            in -          Switch (Value (scrutinee, scrutinee_mplace), Match branches)) +          let e = +            Switch +              (mk_value_expression scrutinee scrutinee_mplace, Match branches) +          in +          (* There should be at least one branch *) +          let branch = List.hd branches in +          let ty = branch.branch.ty in +          assert (List.for_all (fun br -> br.branch.ty = ty) branches); +          { e; ty })    | ExpandBool (true_e, false_e) ->        (* We don't need to update the context: we don't introduce any         * new values/variables *)        let true_e = translate_expression true_e ctx in        let false_e = translate_expression false_e ctx in -      Switch (Value (scrutinee, scrutinee_mplace), If (true_e, false_e)) +      let e = +        Switch +          (mk_value_expression scrutinee scrutinee_mplace, If (true_e, false_e)) +      in +      let ty = true_e.ty in +      assert (ty = false_e.ty); +      { e; ty }    | ExpandInt (int_ty, branches, otherwise) ->        let translate_branch ((v, branch_e) : V.scalar_value * S.expression) : -          scalar_value * expression = +          scalar_value * texpression =          (* We don't need to update the context: we don't introduce any           * new values/variables *)          let branch_e = translate_expression branch_e ctx in @@ -1246,13 +1278,18 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)        in        let branches = List.map translate_branch branches in        let otherwise = translate_expression otherwise ctx in -      Switch -        ( Value (scrutinee, scrutinee_mplace), -          SwitchInt (int_ty, branches, otherwise) ) +      let e = +        Switch +          ( mk_value_expression scrutinee scrutinee_mplace, +            SwitchInt (int_ty, branches, otherwise) ) +      in +      let ty = otherwise.ty in +      assert (List.for_all (fun (_, br) -> br.ty = ty) branches); +      { e; ty }  and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : -    expression = -  let e = translate_expression e ctx in +    texpression = +  let next_e = translate_expression e ctx in    let meta =      match meta with      | S.Assignment (p, rv) -> @@ -1260,7 +1297,9 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :          let rv = typed_value_to_rvalue ctx rv in          Assignment (p, rv)    in -  Meta (meta, e) +  let e = Meta (meta, next_e) in +  let ty = next_e.ty in +  { e; ty }  let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def =    let def = ctx.fun_def in | 
