From 04e15bf8f8b5a8138a2adbac6bff843f183f0301 Mon Sep 17 00:00:00 2001 From: Patrick Walton Date: Sun, 12 Dec 2010 20:02:49 -0800 Subject: [PATCH] rustc: Typecheck "alt" expressions and patterns --- src/comp/middle/typeck.rs | 175 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 175 insertions(+) diff --git a/src/comp/middle/typeck.rs b/src/comp/middle/typeck.rs index 2ffd2bedf6b..346e40ab20b 100644 --- a/src/comp/middle/typeck.rs +++ b/src/comp/middle/typeck.rs @@ -759,6 +759,15 @@ fn block_ty(&ast.block b) -> @ty { } } +fn pat_ty(@ast.pat pat) -> @ty { + alt (pat.node) { + case (ast.pat_wild(?ann)) { ret ann_to_type(ann); } + case (ast.pat_bind(_, _, ?ann)) { ret ann_to_type(ann); } + case (ast.pat_tag(_, _, _, ?ann)) { ret ann_to_type(ann); } + } + fail; // not reached +} + fn expr_ty(@ast.expr expr) -> @ty { alt (expr.node) { case (ast.expr_vec(_, ?ann)) { ret ann_to_type(ann); } @@ -772,6 +781,7 @@ fn expr_ty(@ast.expr expr) -> @ty { case (ast.expr_if(_, _, _, ?ann)) { ret ann_to_type(ann); } case (ast.expr_while(_, _, ?ann)) { ret ann_to_type(ann); } case (ast.expr_do_while(_, _, ?ann)) { ret ann_to_type(ann); } + case (ast.expr_alt(_, _, ?ann)) { ret ann_to_type(ann); } case (ast.expr_block(_, ?ann)) { ret ann_to_type(ann); } case (ast.expr_assign(_, _, ?ann)) { ret ann_to_type(ann); } case (ast.expr_assign_op(_, _, _, ?ann)) @@ -847,6 +857,20 @@ fn unify(&fn_ctxt fcx, @ty expected, @ty actual) -> unify_result { case (ty_char) { ret struct_cmp(expected, actual); } case (ty_str) { ret struct_cmp(expected, actual); } + case (ty_tag(?expected_id)) { + alt (actual.struct) { + case (ty_tag(?actual_id)) { + if (expected_id._0 == actual_id._0 && + expected_id._1 == actual_id._1) { + ret ures_ok(expected); + } + } + case (_) { /* fall through */ } + } + + ret ures_err(terr_mismatch, expected, actual); + } + case (ty_box(?expected_sub)) { alt (actual.struct) { case (ty_box(?actual_sub)) { @@ -1147,6 +1171,64 @@ fn are_compatible(&fn_ctxt fcx, @ty expected, @ty actual) -> bool { } } +// Type unification over typed patterns. Note that the pattern that you pass +// to this function must have been passed to check_pat() first. +// +// TODO: enforce this via a predicate. + +fn demand_pat(&fn_ctxt fcx, @ty expected, @ast.pat pat) -> @ast.pat { + auto p_1 = ast.pat_wild(ast.ann_none); // FIXME: typestate botch + + alt (pat.node) { + case (ast.pat_wild(?ann)) { + auto t = demand(fcx, pat.span, expected, ann_to_type(ann)); + p_1 = ast.pat_wild(ast.ann_type(t)); + } + case (ast.pat_bind(?id, ?did, ?ann)) { + auto t = demand(fcx, pat.span, expected, ann_to_type(ann)); + p_1 = ast.pat_bind(id, did, ast.ann_type(t)); + } + case (ast.pat_tag(?id, ?subpats, ?vdef_opt, ?ann)) { + auto t = demand(fcx, pat.span, expected, ann_to_type(ann)); + + // The type of the tag isn't enough; we also have to get the type + // of the variant, which is either a tag type in the case of + // nullary variants or a function type in the case of n-ary + // variants. + // + // TODO: When we have type-parametric tags, this will get a little + // trickier. Basically, we have to instantiate the variant type we + // acquire here with the type parameters provided to us by + // "expected". + + auto vdef = option.get[ast.variant_def](vdef_opt); + auto variant_ty = fcx.ccx.item_types.get(vdef._1); + + auto subpats_len = _vec.len[@ast.pat](subpats); + alt (variant_ty.struct) { + case (ty_tag(_)) { + // Nullary tag variant. + check (subpats_len == 0u); + p_1 = ast.pat_tag(id, subpats, vdef_opt, ast.ann_type(t)); + } + case (ty_fn(?args, ?tag_ty)) { + let vec[@ast.pat] new_subpats = vec(); + auto i = 0u; + for (arg a in args) { + auto new_subpat = demand_pat(fcx, a.ty, subpats.(i)); + new_subpats += vec(new_subpat); + i += 1u; + } + p_1 = ast.pat_tag(id, new_subpats, vdef_opt, + ast.ann_type(tag_ty)); + } + } + } + } + + ret @fold.respan[ast.pat_](pat.span, p_1); +} + // Type unification over typed expressions. Note that the expression that you // pass to this function must have been passed to check_expr() first. // @@ -1351,6 +1433,68 @@ fn check_lit(@ast.lit lit) -> @ty { ret plain_ty(sty); } +fn check_pat(&fn_ctxt fcx, @ast.pat pat) -> @ast.pat { + auto new_pat; + alt (pat.node) { + case (ast.pat_wild(_)) { + new_pat = ast.pat_wild(ast.ann_type(next_ty_var(fcx))); + } + case (ast.pat_bind(?id, ?def_id, _)) { + auto ann = ast.ann_type(next_ty_var(fcx)); + new_pat = ast.pat_bind(id, def_id, ann); + } + case (ast.pat_tag(?id, ?subpats, ?vdef_opt, _)) { + auto vdef = option.get[ast.variant_def](vdef_opt); + auto t = fcx.ccx.item_types.get(vdef._1); + alt (t.struct) { + // N-ary variants have function types. + case (ty_fn(?args, ?tag_ty)) { + auto arg_len = _vec.len[arg](args); + auto subpats_len = _vec.len[@ast.pat](subpats); + if (arg_len != subpats_len) { + // TODO: pluralize properly + auto err_msg = "tag type " + id + " has " + + _uint.to_str(subpats_len, 10u) + + " fields, but this pattern has " + + _uint.to_str(arg_len, 10u) + " fields"; + + fcx.ccx.sess.span_err(pat.span, err_msg); + fail; // TODO: recover + } + + let vec[@ast.pat] new_subpats = vec(); + for (@ast.pat subpat in subpats) { + new_subpats += vec(check_pat(fcx, subpat)); + } + + auto ann = ast.ann_type(tag_ty); + new_pat = ast.pat_tag(id, new_subpats, vdef_opt, ann); + } + + // Nullary variants have tag types. + case (ty_tag(?tid)) { + auto subpats_len = _vec.len[@ast.pat](subpats); + if (subpats_len > 0u) { + // TODO: pluralize properly + auto err_msg = "tag type " + id + " has no fields," + + " but this pattern has " + + _uint.to_str(subpats_len, 10u) + + " fields"; + + fcx.ccx.sess.span_err(pat.span, err_msg); + fail; // TODO: recover + } + + auto ann = ast.ann_type(plain_ty(ty_tag(tid))); + new_pat = ast.pat_tag(id, subpats, vdef_opt, ann); + } + } + } + } + + ret @fold.respan[ast.pat_](pat.span, new_pat); +} + fn check_expr(&fn_ctxt fcx, @ast.expr expr) -> @ast.expr { alt (expr.node) { case (ast.expr_lit(?lit, _)) { @@ -1528,6 +1672,37 @@ fn check_expr(&fn_ctxt fcx, @ast.expr expr) -> @ast.expr { ann)); } + case (ast.expr_alt(?expr, ?arms, _)) { + auto expr_0 = check_expr(fcx, expr); + auto pattern_ty = expr_ty(expr_0); + auto result_ty = next_ty_var(fcx); + + let vec[ast.arm] arms_0 = vec(); + for (ast.arm arm in arms) { + auto pat_0 = check_pat(fcx, arm.pat); + pattern_ty = demand(fcx, pat_0.span, pattern_ty, + pat_ty(pat_0)); + auto block_0 = check_block(fcx, arm.block); + result_ty = demand(fcx, block_0.span, result_ty, + block_ty(block_0)); + arms_0 += vec(rec(pat=pat_0, block=block_0, index=arm.index)); + } + + auto expr_1 = demand_expr(fcx, pattern_ty, expr); + + let vec[ast.arm] arms_1 = vec(); + for (ast.arm arm_0 in arms_0) { + auto pat_1 = demand_pat(fcx, pattern_ty, arm_0.pat); + auto block_1 = demand_block(fcx, result_ty, arm_0.block); + auto arm_1 = rec(pat=pat_1, block=block_1, index=arm_0.index); + arms_1 += vec(arm_1); + } + + auto ann = ast.ann_type(result_ty); + ret @fold.respan[ast.expr_](expr.span, + ast.expr_alt(expr_1, arms_1, ann)); + } + case (ast.expr_call(?f, ?args, _)) { // Check the function. auto f_0 = check_expr(fcx, f);