From 9c001af07c658b9583bde8d138d1d9408274d741 Mon Sep 17 00:00:00 2001 From: Tim Chevalier Date: Thu, 7 Apr 2011 18:15:56 -0700 Subject: [PATCH] Implemented computing prestates and poststates for a few expression forms. The typestate checker (if it's uncommented) now correctly rejects a trivial example program that has an uninitialized variable. --- src/comp/middle/typestate_check.rs | 287 ++++++++++++++++++++++------- src/comp/rustc.rc | 4 + src/comp/util/common.rs | 12 ++ src/comp/util/typestate_ann.rs | 14 +- src/lib/bitv.rs | 15 ++ 5 files changed, 268 insertions(+), 64 deletions(-) diff --git a/src/comp/middle/typestate_check.rs b/src/comp/middle/typestate_check.rs index 24bee1bd034..22ffcc35053 100644 --- a/src/comp/middle/typestate_check.rs +++ b/src/comp/middle/typestate_check.rs @@ -19,6 +19,8 @@ import front.ast.def_id; import front.ast.ann; import front.ast.expr; import front.ast.expr_call; +import front.ast.expr_vec; +import front.ast.expr_tup; import front.ast.expr_path; import front.ast.expr_log; import front.ast.expr_block; @@ -59,6 +61,8 @@ import util.common.span; import util.common.spanned; import util.common.new_str_hash; import util.common.new_def_hash; +import util.common.uistr; +import util.common.elt_exprs; import util.typestate_ann; import util.typestate_ann.ts_ann; import util.typestate_ann.empty_pre_post; @@ -75,6 +79,8 @@ import util.typestate_ann.ann_precond; import util.typestate_ann.ann_prestate; import util.typestate_ann.set_precondition; import util.typestate_ann.set_postcondition; +import util.typestate_ann.set_prestate; +import util.typestate_ann.set_poststate; import util.typestate_ann.set_in_postcond; import util.typestate_ann.implies; import util.typestate_ann.pre_and_post_state; @@ -92,6 +98,7 @@ import middle.ty.ty_to_str; import pretty.pprust.print_block; import pretty.pprust.print_expr; +import pretty.pprust.print_decl; import pretty.pp.mkstate; import std.io.stdout; import std.io.str_writer; @@ -102,11 +109,14 @@ import 
std._vec.len; import std._vec.pop; import std._vec.push; import std._vec.slice; +import std._vec.unzip; import std.option; import std.option.t; import std.option.some; import std.option.none; import std.option.from_maybe; +import std.option.is_none; +import std.option.get; import std.map.hashmap; import std.list; import std.list.list; @@ -137,6 +147,36 @@ fn log_expr(@expr e) -> () { log(s.get_str()); } +fn log_stmt(stmt st) -> () { + let str_writer s = string_writer(); + auto out_ = mkstate(s.get_writer(), 80u); + auto out = @rec(s=out_, + comments=option.none[vec[front.lexer.cmnt]], + mutable cur_cmnt=0u); + alt (st.node) { + case (ast.stmt_decl(?decl,_)) { + print_decl(out, decl); + } + case (ast.stmt_expr(?ex,_)) { + print_expr(out, ex); + } + case (_) { /* do nothing */ } + } + log(s.get_str()); +} + +fn log_bitv(fn_info enclosing, bitv.t v) { + auto s = ""; + + for each (@tup(def_id, tup(uint, ident)) p in enclosing.items()) { + if (bitv.get(v, p._1._0)) { + s += " " + p._1._1 + " "; + } + } + + log(s); +} + fn log_cond(vec[uint] v) -> () { auto res = ""; for (uint i in v) { @@ -173,14 +213,16 @@ fn print_idents(vec[ident] idents) -> () { } /**********************************************************************/ /* mapping from variable name (def_id is assumed to be for a local - variable in a given function) to bit number */ -type fn_info = std.map.hashmap[def_id, uint]; + variable in a given function) to bit number + (also remembers the ident for error-logging purposes) */ +type var_info = tup(uint, ident); +type fn_info = std.map.hashmap[def_id, var_info]; /* mapping from function name to fn_info map */ type _fn_info_map = std.map.hashmap[def_id, fn_info]; fn bit_num(def_id v, fn_info m) -> uint { check (m.contains_key(v)); - ret m.get(v); + ret m.get(v)._0; } fn var_is_local(def_id v, fn_info m) -> bool { @@ -191,14 +233,14 @@ fn num_locals(fn_info m) -> uint { ret m.size(); } -fn find_locals(_fn f) -> vec[def_id] { - auto res = _vec.alloc[def_id](0u); 
+fn find_locals(_fn f) -> vec[tup(ident,def_id)] { + auto res = _vec.alloc[tup(ident,def_id)](0u); for each (@tup(ident, block_index_entry) p in f.body.node.index.items()) { alt (p._1) { case (ast.bie_local(?loc)) { - res += vec(loc.id); + res += vec(tup(loc.ident,loc.id)); } case (_) { } } @@ -207,26 +249,25 @@ fn find_locals(_fn f) -> vec[def_id] { ret res; } -fn add_var(def_id v, uint next, fn_info tbl) -> uint { - tbl.insert(v, next); - // log(v + " |-> " + _uint.to_str(next, 10u)); +fn add_var(def_id v, ident nm, uint next, fn_info tbl) -> uint { + tbl.insert(v, tup(next,nm)); ret (next + 1u); } /* builds a table mapping each local var defined in f to a bit number in the precondition/postcondition vectors */ fn mk_fn_info(_fn f) -> fn_info { - auto res = new_def_hash[uint](); + auto res = new_def_hash[var_info](); let uint next = 0u; let vec[ast.arg] f_args = f.decl.inputs; for (ast.arg v in f_args) { - next = add_var(v.id, next, res); + next = add_var(v.id, v.ident, next, res); } - let vec[def_id] locals = find_locals(f); - for (def_id v in locals) { - next = add_var(v, next, res); + let vec[tup(ident,def_id)] locals = find_locals(f); + for (tup(ident,def_id) p in locals) { + next = add_var(p._1, p._0, next, res); } ret res; @@ -403,7 +444,7 @@ fn expr_states(&expr e) -> pre_and_post_state { fail; } case (some[@ts_ann](?p)) { - // ret p.states; + ret p.states; } } } @@ -691,7 +732,7 @@ fn find_pre_post_expr(&fn_info enclosing, &expr e) -> @expr { impure fn gen(&fn_info enclosing, ts_ann a, def_id id) { check(enclosing.contains_key(id)); - let uint i = enclosing.get(id); + let uint i = (enclosing.get(id))._0; set_in_postcond(i, a.conditions); } @@ -804,42 +845,152 @@ fn check_item_fn(&_fn_info_map fm, &span sp, ident i, &ast._fn f, } /* FIXME */ -fn find_pre_post_state_expr(&_fn_info_map fm, &fn_info enclosing, - &prestate pres, expr e) - -> tup(bool, @expr) { - log("Implement find_pre_post_state_expr!"); +fn find_pre_post_state_item(_fn_info_map fm, @item i) 
-> bool { + log("Implement find_pre_post_item!"); fail; } -/* FIXME: This isn't done yet. */ +impure fn set_prestate_ann(ann a, prestate pre) -> () { + alt (a) { + case (ann_type(_,_,?ts_a)) { + check (! is_none[@ts_ann](ts_a)); + set_prestate(*get[@ts_ann](ts_a), pre); + } + case (ann_none) { + log("set_prestate_ann: expected an ann_type here"); + fail; + } + } +} + +impure fn set_poststate_ann(ann a, poststate post) -> () { + alt (a) { + case (ann_type(_,_,?ts_a)) { + check (! is_none[@ts_ann](ts_a)); + set_poststate(*get[@ts_ann](ts_a), post); + } + case (ann_none) { + log("set_poststate_ann: expected an ann_type here"); + fail; + } + } +} + +fn seq_states(&_fn_info_map fm, &fn_info enclosing, + prestate pres, vec[@expr] exprs) -> tup(bool, poststate) { + auto changed = false; + auto post = pres; + + for (@expr e in exprs) { + changed = find_pre_post_state_expr(fm, enclosing, post, e) || changed; + post = expr_poststate(*e); + } + + ret tup(changed, post); +} + +fn find_pre_post_state_exprs(&_fn_info_map fm, + &fn_info enclosing, + &prestate pres, + &ann a, &vec[@expr] es) -> bool { + auto res = seq_states(fm, enclosing, pres, es); + set_prestate_ann(a, pres); + set_poststate_ann(a, res._1); + ret res._0; +} + +impure fn pure_exp(&ann a, &prestate p) -> () { + set_prestate_ann(a, p); + set_poststate_ann(a, p); +} + +fn find_pre_post_state_expr(&_fn_info_map fm, &fn_info enclosing, + &prestate pres, &@expr e) -> bool { + auto changed = false; + + alt (e.node) { + case (expr_vec(?elts, _, ?a)) { + be find_pre_post_state_exprs(fm, enclosing, pres, a, elts); + } + case (expr_tup(?elts, ?a)) { + be find_pre_post_state_exprs(fm, enclosing, pres, a, elt_exprs(elts)); + } + case (expr_call(?operator, ?operands, ?a)) { + /* do the prestate for the rator */ + changed = find_pre_post_state_expr(fm, enclosing, pres, operator) + || changed; + /* rands go left-to-right */ + ret(find_pre_post_state_exprs(fm, enclosing, + expr_poststate(*operator), a, operands) + || changed); + 
} + case (expr_path(_,_,?a)) { + pure_exp(a, pres); + ret false; + } + case (expr_log(?e,?a)) { + changed = find_pre_post_state_expr(fm, enclosing, pres, e); + set_prestate_ann(a, pres); + set_poststate_ann(a, expr_poststate(*e)); + ret changed; + } + case (_) { + log("find_pre_post_state_expr: implement this case!"); + fail; + } + } + +} + fn find_pre_post_state_stmt(&_fn_info_map fm, &fn_info enclosing, &prestate pres, @stmt s) -> bool { auto changed = false; alt (s.node) { case (stmt_decl(?adecl, ?a)) { + /* a must be some(a') at this point */ + check (! is_none[@ts_ann](a)); + auto stmt_ann = *(get[@ts_ann](a)); alt (adecl.node) { case (ast.decl_local(?alocal)) { alt (alocal.init) { case (some[ast.initializer](?an_init)) { - auto p = find_pre_post_state_expr(fm, enclosing, - pres, *an_init.expr); - fail; /* FIXME */ - /* Next: copy pres into a's prestate; - find the poststate by taking p's poststate - and setting the bit for alocal.id */ - } + changed = find_pre_post_state_expr + (fm, enclosing, pres, an_init.expr) || changed; + set_prestate(stmt_ann, expr_prestate(*an_init.expr)); + set_poststate(stmt_ann, expr_poststate(*an_init.expr)); + gen(enclosing, stmt_ann, alocal.id); + ret changed; + } + case (none[ast.initializer]) { + set_prestate(stmt_ann, pres); + set_poststate(stmt_ann, pres); + ret false; + } } } + case (ast.decl_item(?an_item)) { + be find_pre_post_state_item(fm, an_item); + } } } + case (stmt_expr(?e, ?a)) { + check (! 
is_none[@ts_ann](a)); + auto stmt_ann = *(get[@ts_ann](a)); + changed = find_pre_post_state_expr(fm, enclosing, pres, e) || changed; + set_prestate(stmt_ann, expr_prestate(*e)); + set_poststate(stmt_ann, expr_poststate(*e)); + ret changed; + } + case (_) { ret false; } } } -/* Returns a pair of a new block, with possibly a changed pre- or - post-state, and a boolean flag saying whether the function's pre- or - poststate changed */ +/* Updates the pre- and post-states of statements in the block, + returns a boolean flag saying whether any pre- or poststates changed */ fn find_pre_post_state_block(&_fn_info_map fm, &fn_info enclosing, block b) - -> tup(bool, block) { + -> bool { + log("pre_post_state_block: " + uistr(fm.size()) + " " + uistr(enclosing.size())); + auto changed = false; auto num_local_vars = num_locals(enclosing); @@ -857,43 +1008,35 @@ fn find_pre_post_state_block(&_fn_info_map fm, &fn_info enclosing, block b) extend_prestate(pres, stmt_poststate(*s, num_local_vars)); } - fn do_inner_(_fn_info_map fm, fn_info i, prestate p, &@expr e) - -> tup (bool, @expr) { - ret find_pre_post_state_expr(fm, i, p, *e); + alt (b.node.expr) { + case (none[@expr]) {} + case (some[@expr](?e)) { + changed = changed || find_pre_post_state_expr(fm, enclosing, pres, e); + } } - auto do_inner = bind do_inner_(fm, enclosing, pres, _); - let option.t[tup(bool, @expr)] e_ = - option.map[@expr, tup(bool, @expr)](do_inner, b.node.expr); - auto s = snd[bool, @expr]; - auto f = fst[bool, @expr]; - changed = changed || - from_maybe[bool](false, - option.map[tup(bool, @expr), bool](f, e_)); - let block_ b_res = rec(stmts=b.node.stmts, - expr=option.map[tup(bool, @expr), @expr](s, e_), - index=b.node.index); - ret tup(changed, respan(b.span, b_res)); + ret changed; } -fn find_pre_post_state_fn(_fn_info_map f_info, fn_info fi, &ast._fn f) - -> tup(bool, ast._fn) { - auto p = find_pre_post_state_block(f_info, fi, f.body); - ret tup(p._0, rec(decl=f.decl, proto=f.proto, body=p._1)); +fn 
find_pre_post_state_fn(&_fn_info_map f_info, &fn_info fi, &ast._fn f) + -> bool { + be find_pre_post_state_block(f_info, fi, f.body); } fn fixed_point_states(_fn_info_map fm, fn_info f_info, - fn (_fn_info_map, fn_info, &ast._fn) - -> tup(bool, ast._fn) f, - &ast._fn start) -> ast._fn { - auto next = f(fm, f_info, start); + // with no ampersands for the first two args, and likewise for find_pre_post_state_fn, + // I got a segfault + fn (&_fn_info_map, &fn_info, &ast._fn) -> bool f, + &ast._fn start) -> () { + log("fixed_point_states: " + uistr(fm.size()) + " " + uistr(f_info.size())); - if (next._0) { - // something changed - be fixed_point_states(fm, f_info, f, next._1); + auto changed = f(fm, f_info, start); + + if (changed) { + be fixed_point_states(fm, f_info, f, start); } else { // we're done! - ret next._1; + ret; } } @@ -917,7 +1060,12 @@ fn check_states_stmt(fn_info enclosing, &stmt s) -> () { let prestate pres = ann_prestate(*a); if (!implies(pres, prec)) { - log("check_states_stmt: unsatisfied precondition"); + log("check_states_stmt: unsatisfied precondition for "); + log_stmt(s); + log("Precondition: "); + log_bitv(enclosing, prec); + log("Prestate: "); + log_bitv(enclosing, pres); fail; } } @@ -947,16 +1095,18 @@ fn check_item_fn_state(&_fn_info_map f_info_map, &span sp, ident i, check(f_info_map.contains_key(id)); auto f_info = f_info_map.get(id); + log("check_item_fn_state: id = " + i + " " + uistr(f_info_map.size()) + " " + uistr(f_info.size())); + /* Compute the pre- and post-states for this function */ auto g = find_pre_post_state_fn; - auto res_f = fixed_point_states(f_info_map, f_info, g, f); + fixed_point_states(f_info_map, f_info, g, f); /* Now compare each expr's pre-state to its precondition and post-state to its postcondition */ - check_states_against_conditions(f_info, res_f); + check_states_against_conditions(f_info, f); /* Rebuild the same function */ - ret @respan(sp, ast.item_fn(i, res_f, ty_params, id, a)); + ret @respan(sp, 
ast.item_fn(i, f, ty_params, id, a)); } fn check_crate(@ast.crate crate) -> @ast.crate { @@ -978,3 +1128,14 @@ fn check_crate(@ast.crate crate) -> @ast.crate { ret fold.fold_crate[_fn_info_map](fn_info_map, fld1, with_pre_postconditions); } + +// +// Local Variables: +// mode: rust +// fill-column: 78; +// indent-tabs-mode: nil +// c-basic-offset: 4 +// buffer-file-coding-system: utf-8-unix +// compile-command: "make -k -C $RBUILD 2>&1 | sed -e 's/\\/x\\//x:\\//g'"; +// End: +// diff --git a/src/comp/rustc.rc b/src/comp/rustc.rc index e9d0b788205..14efb6671b0 100644 --- a/src/comp/rustc.rc +++ b/src/comp/rustc.rc @@ -65,10 +65,14 @@ auth lib.llvm = unsafe; auth pretty.pprust = impure; auth middle.typestate_check.find_pre_post_block = impure; auth middle.typestate_check.find_pre_post_state_block = impure; +auth middle.typestate_check.find_pre_post_state_stmt = impure; +auth middle.typestate_check.find_pre_post_state_expr = impure; +auth middle.typestate_check.find_pre_post_state_exprs = impure; auth middle.typestate_check.find_pre_post_expr = impure; auth middle.typestate_check.find_pre_post_stmt = impure; auth middle.typestate_check.check_states_against_conditions = impure; auth middle.typestate_check.check_states_stmt = impure; +auth middle.typestate_check.log_stmt = impure; auth util.typestate_ann.implies = impure; mod lib { diff --git a/src/comp/util/common.rs b/src/comp/util/common.rs index 5243e2f7df2..6dec6c00947 100644 --- a/src/comp/util/common.rs +++ b/src/comp/util/common.rs @@ -1,5 +1,6 @@ import std._uint; import std._int; +import std._vec; import front.ast; @@ -75,6 +76,17 @@ fn istr(int i) -> str { ret _int.to_str(i, 10u); } +fn uistr(uint i) -> str { + ret _uint.to_str(i, 10u); +} + +fn elt_expr(&ast.elt e) -> @ast.expr { ret e.expr; } + +fn elt_exprs(vec[ast.elt] elts) -> vec[@ast.expr] { + auto f = elt_expr; + be _vec.map[ast.elt, @ast.expr](f, elts); +} + // // Local Variables: // mode: rust diff --git a/src/comp/util/typestate_ann.rs 
b/src/comp/util/typestate_ann.rs index 53f9a71cf22..c8d23321620 100644 --- a/src/comp/util/typestate_ann.rs +++ b/src/comp/util/typestate_ann.rs @@ -104,6 +104,18 @@ impure fn set_postcondition(&ts_ann a, &postcond p) -> () { bitv.copy(p, a.conditions.postcondition); } +// Sets all the bits in a's prestate to equal the +// corresponding bit in p. +impure fn set_prestate(&ts_ann a, &prestate p) -> () { + bitv.copy(p, a.states.prestate); +} + +// Sets all the bits in a's poststate to equal the +// corresponding bit in p. +impure fn set_poststate(&ts_ann a, &poststate p) -> () { + bitv.copy(p, a.states.poststate); +} + // Set all the bits in p that are set in new impure fn extend_prestate(&prestate p, &poststate new) -> () { bitv.union(p, new); @@ -119,5 +131,5 @@ fn ann_prestate(&ts_ann a) -> prestate { impure fn implies(bitv.t a, bitv.t b) -> bool { bitv.difference(b, a); - be bitv.is_false(b); + ret bitv.is_false(b); } diff --git a/src/lib/bitv.rs b/src/lib/bitv.rs index 98e6c0401d6..75254ce75e8 100644 --- a/src/lib/bitv.rs +++ b/src/lib/bitv.rs @@ -170,6 +170,21 @@ fn to_vec(&t v) -> vec[uint] { ret _vec.init_fn[uint](sub, v.nbits); } +fn to_str(&t v) -> str { + auto res = ""; + + for(uint i in v.storage) { + if (i == 1u) { + res += "1"; + } + else { + res += "0"; + } + } + + ret res; +} + // FIXME: can we just use structural equality on to_vec? fn eq_vec(&t v0, &vec[uint] v1) -> bool { check (v0.nbits == _vec.len[uint](v1));