translate function shims using MIR

This commit is contained in:
Ariel Ben-Yehuda 2017-03-06 12:58:51 +02:00
parent ffee9566bb
commit bf80fec326
6 changed files with 199 additions and 140 deletions

View File

@ -12,30 +12,42 @@ use rustc::hir;
use rustc::infer;
use rustc::mir::*;
use rustc::mir::transform::MirSource;
use rustc::ty;
use rustc::ty::{self, Ty};
use rustc::ty::maps::Providers;
use rustc_data_structures::indexed_vec::{IndexVec, Idx};
use syntax::abi::Abi;
use syntax::ast;
use syntax::codemap::DUMMY_SP;
use syntax_pos::Span;
use std::cell::RefCell;
use std::iter;
use std::mem;
pub fn provide(providers: &mut Providers) {
providers.mir_shims = make_shim;
}
fn make_shim<'a, 'tcx>(_tcx: ty::TyCtxt<'a, 'tcx, 'tcx>,
fn make_shim<'a, 'tcx>(tcx: ty::TyCtxt<'a, 'tcx, 'tcx>,
instance: ty::InstanceDef<'tcx>)
-> &'tcx RefCell<Mir<'tcx>>
{
match instance {
debug!("make_shim({:?})", instance);
let result = match instance {
ty::InstanceDef::Item(..) =>
bug!("item {:?} passed to make_shim", instance),
ty::InstanceDef::FnPtrShim(..) => unimplemented!()
}
ty::InstanceDef::FnPtrShim(_, ty) => {
build_fn_ptr_shim(tcx, ty, instance.def_ty(tcx))
}
};
debug!("make_shim({:?}) = {:?}", instance, result);
let result = tcx.alloc_mir(result);
// Perma-borrow MIR from shims to prevent mutation.
mem::forget(result.borrow());
result
}
fn local_decls_for_sig<'tcx>(sig: &ty::FnSig<'tcx>)
@ -54,6 +66,111 @@ fn local_decls_for_sig<'tcx>(sig: &ty::FnSig<'tcx>)
})).collect()
}
fn build_fn_ptr_shim<'a, 'tcx>(tcx: ty::TyCtxt<'a, 'tcx, 'tcx>,
fn_ty: Ty<'tcx>,
sig_ty: Ty<'tcx>)
-> Mir<'tcx>
{
debug!("build_fn_ptr_shim(fn_ty={:?}, sig_ty={:?})", fn_ty, sig_ty);
let trait_sig = match sig_ty.sty {
ty::TyFnDef(_, _, fty) => tcx.erase_late_bound_regions(&fty),
_ => bug!("unexpected type for shim {:?}", sig_ty)
};
let self_ty = match trait_sig.inputs()[0].sty {
ty::TyParam(..) => fn_ty,
ty::TyRef(r, mt) => tcx.mk_ref(r, ty::TypeAndMut {
ty: fn_ty,
mutbl: mt.mutbl
}),
_ => bug!("unexpected self_ty {:?}", trait_sig),
};
let fn_ptr_sig = match fn_ty.sty {
ty::TyFnPtr(fty) |
ty::TyFnDef(_, _, fty) =>
tcx.erase_late_bound_regions_and_normalize(&fty),
_ => bug!("non-fn-ptr {:?} in build_fn_ptr_shim", fn_ty)
};
let sig = tcx.mk_fn_sig(
[
self_ty,
tcx.intern_tup(fn_ptr_sig.inputs(), false)
].iter().cloned(),
fn_ptr_sig.output(),
false,
hir::Unsafety::Normal,
Abi::RustCall,
);
let local_decls = local_decls_for_sig(&sig);
let source_info = SourceInfo {
span: DUMMY_SP,
scope: ARGUMENT_VISIBILITY_SCOPE
};
let fn_ptr = Lvalue::Local(Local::new(1+0));
let fn_ptr = match trait_sig.inputs()[0].sty {
ty::TyParam(..) => fn_ptr,
ty::TyRef(..) => Lvalue::Projection(box Projection {
base: fn_ptr, elem: ProjectionElem::Deref
}),
_ => bug!("unexpected self_ty {:?}", trait_sig),
};
let fn_args = Local::new(1+1);
let return_block_id = BasicBlock::new(1);
// return = ADT(arg0, arg1, ...); return
let start_block = BasicBlockData {
statements: vec![],
terminator: Some(Terminator {
source_info: source_info,
kind: TerminatorKind::Call {
func: Operand::Consume(fn_ptr),
args: fn_ptr_sig.inputs().iter().enumerate().map(|(i, ity)| {
Operand::Consume(Lvalue::Projection(box Projection {
base: Lvalue::Local(fn_args),
elem: ProjectionElem::Field(
Field::new(i), *ity
)
}))
}).collect(),
// FIXME: can we pass a Some destination for an uninhabited ty?
destination: Some((Lvalue::Local(RETURN_POINTER),
return_block_id)),
cleanup: None
}
}),
is_cleanup: false
};
let return_block = BasicBlockData {
statements: vec![],
terminator: Some(Terminator {
source_info: source_info,
kind: TerminatorKind::Return
}),
is_cleanup: false
};
let mut mir = Mir::new(
vec![start_block, return_block].into_iter().collect(),
IndexVec::from_elem_n(
VisibilityScopeData { span: DUMMY_SP, parent_scope: None }, 1
),
IndexVec::new(),
sig.output(),
local_decls,
sig.inputs().len(),
vec![],
DUMMY_SP
);
mir.spread_arg = Some(fn_args);
mir
}
pub fn build_adt_ctor<'a, 'gcx, 'tcx>(infcx: &infer::InferCtxt<'a, 'gcx, 'tcx>,
ctor_id: ast::NodeId,
fields: &[hir::StructField],

View File

@ -36,7 +36,6 @@ use back::symbol_names::symbol_name;
use trans_item::TransItem;
use type_of;
use rustc::ty::{self, Ty, TypeFoldable};
use rustc::hir;
use std::iter;
use syntax_pos::DUMMY_SP;
@ -130,15 +129,14 @@ impl<'tcx> Callee<'tcx> {
let method_ty = instance_ty(ccx.shared(), &instance);
Callee::ptr(llfn, method_ty)
}
traits::VtableFnPointer(vtable_fn_pointer) => {
let trait_closure_kind = tcx.lang_items.fn_trait_kind(trait_id).unwrap();
let instance = Instance::new(def_id, substs);
let llfn = trans_fn_pointer_shim(ccx, instance,
trait_closure_kind,
vtable_fn_pointer.fn_ty);
traits::VtableFnPointer(data) => {
let instance = ty::Instance {
def: ty::InstanceDef::FnPtrShim(def_id, data.fn_ty),
substs: substs,
};
let method_ty = instance_ty(ccx.shared(), &instance);
Callee::ptr(llfn, method_ty)
let (llfn, ty) = get_fn(ccx, instance);
Callee::ptr(llfn, ty)
}
traits::VtableObject(ref data) => {
Callee {
@ -363,124 +361,6 @@ fn trans_fn_once_adapter_shim<'a, 'tcx>(
lloncefn
}
/// Translates an adapter that implements the `Fn` trait for a fn
/// pointer. This is basically the equivalent of something like:
///
/// ```
/// impl<'a> Fn(&'a int) -> &'a int for fn(&int) -> &int {
/// extern "rust-abi" fn call(&self, args: (&'a int,)) -> &'a int {
/// (*self)(args.0)
/// }
/// }
/// ```
///
/// but for the bare function type given.
fn trans_fn_pointer_shim<'a, 'tcx>(
ccx: &'a CrateContext<'a, 'tcx>,
method_instance: Instance<'tcx>,
closure_kind: ty::ClosureKind,
bare_fn_ty: Ty<'tcx>)
-> ValueRef
{
let tcx = ccx.tcx();
// Normalize the type for better caching.
let bare_fn_ty = tcx.normalize_associated_type(&bare_fn_ty);
// If this is an impl of `Fn` or `FnMut` trait, the receiver is `&self`.
let is_by_ref = match closure_kind {
ty::ClosureKind::Fn | ty::ClosureKind::FnMut => true,
ty::ClosureKind::FnOnce => false,
};
let llfnpointer = match bare_fn_ty.sty {
ty::TyFnDef(def_id, substs, _) => {
// Function definitions have to be turned into a pointer.
let llfn = Callee::def(ccx, def_id, substs).reify(ccx);
if !is_by_ref {
// A by-value fn item is ignored, so the shim has
// the same signature as the original function.
return llfn;
}
Some(llfn)
}
_ => None
};
let bare_fn_ty_maybe_ref = if is_by_ref {
tcx.mk_imm_ref(tcx.mk_region(ty::ReErased), bare_fn_ty)
} else {
bare_fn_ty
};
// Check if we already trans'd this shim.
if let Some(&llval) = ccx.fn_pointer_shims().borrow().get(&bare_fn_ty_maybe_ref) {
return llval;
}
debug!("trans_fn_pointer_shim(bare_fn_ty={:?})",
bare_fn_ty);
// Construct the "tuply" version of `bare_fn_ty`. It takes two arguments: `self`,
// which is the fn pointer, and `args`, which is the arguments tuple.
let sig = bare_fn_ty.fn_sig();
let sig = tcx.erase_late_bound_regions_and_normalize(&sig);
assert_eq!(sig.unsafety, hir::Unsafety::Normal);
assert_eq!(sig.abi, Abi::Rust);
let tuple_input_ty = tcx.intern_tup(sig.inputs(), false);
let sig = tcx.mk_fn_sig(
[bare_fn_ty_maybe_ref, tuple_input_ty].iter().cloned(),
sig.output(),
false,
hir::Unsafety::Normal,
Abi::RustCall
);
let fn_ty = FnType::new(ccx, sig, &[]);
let tuple_fn_ty = tcx.mk_fn_ptr(ty::Binder(sig));
debug!("tuple_fn_ty: {:?}", tuple_fn_ty);
//
let function_name = symbol_name(method_instance, ccx.shared());
let llfn = declare::define_internal_fn(ccx, &function_name, tuple_fn_ty);
attributes::set_frame_pointer_elimination(ccx, llfn);
//
let bcx = Builder::new_block(ccx, llfn, "entry-block");
let mut llargs = get_params(llfn);
let self_arg = llargs.remove(fn_ty.ret.is_indirect() as usize);
let llfnpointer = llfnpointer.unwrap_or_else(|| {
// the first argument (`self`) will be ptr to the fn pointer
if is_by_ref {
bcx.load(self_arg, None)
} else {
self_arg
}
});
let callee = Callee {
data: Fn(llfnpointer),
ty: bare_fn_ty
};
let fn_ret = callee.ty.fn_ret();
let fn_ty = callee.direct_fn_type(ccx, &[]);
let llret = bcx.call(llfnpointer, &llargs, None);
fn_ty.apply_attrs_callsite(llret);
if fn_ret.0.is_never() {
bcx.unreachable();
} else {
if fn_ty.ret.is_indirect() || fn_ty.ret.is_ignore() {
bcx.ret_void();
} else {
bcx.ret(llret);
}
}
ccx.fn_pointer_shims().borrow_mut().insert(bare_fn_ty_maybe_ref, llfn);
llfn
}
/// Translates a reference to a fn/method item, monomorphizing and
/// inlining as it goes.

View File

@ -907,14 +907,12 @@ fn do_static_trait_method_dispatch<'a, 'tcx>(scx: &SharedCrateContext<'a, 'tcx>,
}
}
traits::VtableFnPointer(ref data) => {
// If we know the destination of this fn-pointer, we'll have to make
// sure that this destination actually gets instantiated.
if let ty::TyFnDef(def_id, substs, _) = data.fn_ty.sty {
// The destination of the pointer might be something that needs
// further dispatching, such as a trait method, so we do that.
do_static_dispatch(scx, def_id, substs)
} else {
StaticDispatchResult::Unknown
StaticDispatchResult::Dispatched {
instance: Instance {
def: ty::InstanceDef::FnPtrShim(trait_method.def_id, data.fn_ty),
substs: trait_ref.substs
},
fn_once_adjustment: None,
}
}
// Trait object shims are always instantiated in-place, and as they are

View File

@ -28,10 +28,12 @@ fn main() {
//~ TRANS_ITEM fn function_as_argument::take_fn_once[0]<u32, &str, fn(u32, &str)>
//~ TRANS_ITEM fn function_as_argument::function[0]<u32, &str>
//~ TRANS_ITEM fn core::ops[0]::FnOnce[0]::call_once[0]<fn(u32, &str), (u32, &str)>
take_fn_once(function, 0u32, "abc");
//~ TRANS_ITEM fn function_as_argument::take_fn_once[0]<char, f64, fn(char, f64)>
//~ TRANS_ITEM fn function_as_argument::function[0]<char, f64>
//~ TRANS_ITEM fn core::ops[0]::FnOnce[0]::call_once[0]<fn(char, f64), (char, f64)>
take_fn_once(function, 'c', 0f64);
//~ TRANS_ITEM fn function_as_argument::take_fn_pointer[0]<i32, ()>

View File

@ -40,22 +40,28 @@ fn take_foo_mut<T, F: FnMut(T) -> T>(mut f: F, arg: T) -> T {
fn main() {
//~ TRANS_ITEM fn trait_method_as_argument::take_foo_once[0]<u32, fn(u32) -> u32>
//~ TRANS_ITEM fn trait_method_as_argument::{{impl}}[0]::foo[0]
//~ TRANS_ITEM fn core::ops[0]::FnOnce[0]::call_once[0]<fn(u32) -> u32, (u32)>
take_foo_once(Trait::foo, 0u32);
//~ TRANS_ITEM fn trait_method_as_argument::take_foo_once[0]<char, fn(char) -> char>
//~ TRANS_ITEM fn trait_method_as_argument::Trait[0]::foo[0]<char>
//~ TRANS_ITEM fn core::ops[0]::FnOnce[0]::call_once[0]<fn(char) -> char, (char)>
take_foo_once(Trait::foo, 'c');
//~ TRANS_ITEM fn trait_method_as_argument::take_foo[0]<u32, fn(u32) -> u32>
//~ TRANS_ITEM fn core::ops[0]::Fn[0]::call[0]<fn(u32) -> u32, (u32)>
take_foo(Trait::foo, 0u32);
//~ TRANS_ITEM fn trait_method_as_argument::take_foo[0]<char, fn(char) -> char>
//~ TRANS_ITEM fn core::ops[0]::Fn[0]::call[0]<fn(char) -> char, (char)>
take_foo(Trait::foo, 'c');
//~ TRANS_ITEM fn trait_method_as_argument::take_foo_mut[0]<u32, fn(u32) -> u32>
//~ TRANS_ITEM fn core::ops[0]::FnMut[0]::call_mut[0]<fn(char) -> char, (char)>
take_foo_mut(Trait::foo, 0u32);
//~ TRANS_ITEM fn trait_method_as_argument::take_foo_mut[0]<char, fn(char) -> char>
//~ TRANS_ITEM fn core::ops[0]::FnMut[0]::call_mut[0]<fn(u32) -> u32, (u32)>
take_foo_mut(Trait::foo, 'c');
}

View File

@ -0,0 +1,56 @@
// Copyright 2016 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
#![feature(fn_traits)]
#![feature(never_type)]
use std::panic;
fn foo(x: u32, y: u32) -> u32 { x/y }
fn foo_diverges() -> ! { panic!() }
fn test_fn_ptr<T>(mut t: T)
where T: Fn(u32, u32) -> u32,
{
let as_fn = <T as Fn<(u32, u32)>>::call;
assert_eq!(as_fn(&t, (9, 3)), 3);
let as_fn_mut = <T as FnMut<(u32, u32)>>::call_mut;
assert_eq!(as_fn_mut(&mut t, (18, 3)), 6);
let as_fn_once = <T as FnOnce<(u32, u32)>>::call_once;
assert_eq!(as_fn_once(t, (24, 3)), 8);
}
fn assert_panics<F>(f: F) where F: FnOnce() {
let f = panic::AssertUnwindSafe(f);
let result = panic::catch_unwind(move || {
f.0()
});
if let Ok(..) = result {
panic!("diverging function returned");
}
}
fn test_fn_ptr_panic<T>(mut t: T)
where T: Fn() -> !
{
let as_fn = <T as Fn<()>>::call;
assert_panics(|| as_fn(&t, ()));
let as_fn_mut = <T as FnMut<()>>::call_mut;
assert_panics(|| as_fn_mut(&mut t, ()));
let as_fn_once = <T as FnOnce<()>>::call_once;
assert_panics(|| as_fn_once(t, ()));
}
fn main() {
test_fn_ptr(foo);
test_fn_ptr(foo as fn(u32, u32) -> u32);
test_fn_ptr_panic(foo_diverges);
test_fn_ptr_panic(foo_diverges as fn() -> !);
}