Refactor the autoderef cycle to be a reuseable base class

The autoderef cycle is used to coerce/select apropriate methods during
method resolution. This same cycle of autoderef is also used as coercion
sites. In order to avoid duplicating the logic this extracts out a useful
base class that can be implemented to support this elsewhere.
This commit is contained in:
Philip Herron 2022-07-29 14:29:51 +01:00
parent c788a80619
commit ffb419d6a4
4 changed files with 197 additions and 166 deletions

View File

@ -268,5 +268,129 @@ resolve_operator_overload_fn (
return true;
}
AutoderefCycle::AutoderefCycle (bool autoderef_flag)
: autoderef_flag (autoderef_flag)
{}
AutoderefCycle::~AutoderefCycle () {}
void
AutoderefCycle::try_hook (const TyTy::BaseType &)
{}
bool
AutoderefCycle::cycle (const TyTy::BaseType *receiver)
{
const TyTy::BaseType *r = receiver;
while (true)
{
if (try_autoderefed (r))
return true;
// 4. deref to to 1, if cannot deref then quit
if (autoderef_flag)
return false;
// try unsize
Adjustment unsize = Adjuster::try_unsize_type (r);
if (!unsize.is_error ())
{
adjustments.push_back (unsize);
auto unsize_r = unsize.get_expected ();
if (try_autoderefed (unsize_r))
return true;
adjustments.pop_back ();
}
Adjustment deref
= Adjuster::try_deref_type (r, Analysis::RustLangItem::ItemType::DEREF);
if (!deref.is_error ())
{
auto deref_r = deref.get_expected ();
adjustments.push_back (deref);
if (try_autoderefed (deref_r))
return true;
adjustments.pop_back ();
}
Adjustment deref_mut = Adjuster::try_deref_type (
r, Analysis::RustLangItem::ItemType::DEREF_MUT);
if (!deref_mut.is_error ())
{
auto deref_r = deref_mut.get_expected ();
adjustments.push_back (deref_mut);
if (try_autoderefed (deref_r))
return true;
adjustments.pop_back ();
}
if (!deref_mut.is_error ())
{
auto deref_r = deref_mut.get_expected ();
adjustments.push_back (deref_mut);
Adjustment raw_deref = Adjuster::try_raw_deref_type (deref_r);
adjustments.push_back (raw_deref);
deref_r = raw_deref.get_expected ();
if (try_autoderefed (deref_r))
return true;
adjustments.pop_back ();
adjustments.pop_back ();
}
if (!deref.is_error ())
{
r = deref.get_expected ();
adjustments.push_back (deref);
}
Adjustment raw_deref = Adjuster::try_raw_deref_type (r);
if (raw_deref.is_error ())
return false;
r = raw_deref.get_expected ();
adjustments.push_back (raw_deref);
}
return false;
}
bool
AutoderefCycle::try_autoderefed (const TyTy::BaseType *r)
{
try_hook (*r);
// 1. try raw
if (select (*r))
return true;
// 2. try ref
TyTy::ReferenceType *r1
= new TyTy::ReferenceType (r->get_ref (), TyTy::TyVar (r->get_ref ()),
Mutability::Imm);
adjustments.push_back (Adjustment (Adjustment::AdjustmentType::IMM_REF, r1));
if (select (*r1))
return true;
adjustments.pop_back ();
// 3. try mut ref
TyTy::ReferenceType *r2
= new TyTy::ReferenceType (r->get_ref (), TyTy::TyVar (r->get_ref ()),
Mutability::Mut);
adjustments.push_back (Adjustment (Adjustment::AdjustmentType::MUT_REF, r2));
if (select (*r2))
return true;
adjustments.pop_back ();
return false;
}
} // namespace Resolver
} // namespace Rust

View File

@ -144,6 +144,27 @@ private:
const TyTy::BaseType *base;
};
class AutoderefCycle
{
protected:
AutoderefCycle (bool autoderef_flag);
virtual ~AutoderefCycle ();
virtual bool select (const TyTy::BaseType &autoderefed) = 0;
// optional: this is a chance to hook in to grab predicate items on the raw
// type
virtual void try_hook (const TyTy::BaseType &);
bool cycle (const TyTy::BaseType *receiver);
bool try_autoderefed (const TyTy::BaseType *r);
bool autoderef_flag;
std::vector<Adjustment> adjustments;
};
} // namespace Resolver
} // namespace Rust

View File

@ -23,151 +23,32 @@
namespace Rust {
namespace Resolver {
MethodResolver::MethodResolver (bool autoderef_flag,
const HIR::PathIdentSegment &segment_name)
: AutoderefCycle (autoderef_flag), mappings (Analysis::Mappings::get ()),
context (TypeCheckContext::get ()), segment_name (segment_name),
try_result (MethodCandidate::get_error ())
{}
MethodCandidate
MethodResolver::Probe (const TyTy::BaseType *receiver,
const HIR::PathIdentSegment &segment_name,
bool autoderef_flag)
{
const TyTy::BaseType *r = receiver;
std::vector<Adjustment> adjustments;
while (true)
{
auto res = Try (r, segment_name, adjustments);
if (!res.is_error ())
return res;
// 4. deref to to 1, if cannot deref then quit
if (autoderef_flag)
return MethodCandidate::get_error ();
// try unsize
Adjustment unsize = Adjuster::try_unsize_type (r);
if (!unsize.is_error ())
{
adjustments.push_back (unsize);
auto unsize_r = unsize.get_expected ();
auto res = Try (unsize_r, segment_name, adjustments);
if (!res.is_error ())
{
return res;
}
adjustments.pop_back ();
}
Adjustment deref
= Adjuster::try_deref_type (r, Analysis::RustLangItem::ItemType::DEREF);
if (!deref.is_error ())
{
auto deref_r = deref.get_expected ();
adjustments.push_back (deref);
auto res = Try (deref_r, segment_name, adjustments);
if (!res.is_error ())
{
return res;
}
adjustments.pop_back ();
}
Adjustment deref_mut = Adjuster::try_deref_type (
r, Analysis::RustLangItem::ItemType::DEREF_MUT);
if (!deref_mut.is_error ())
{
auto deref_r = deref_mut.get_expected ();
adjustments.push_back (deref_mut);
auto res = Try (deref_r, segment_name, adjustments);
if (!res.is_error ())
{
return res;
}
adjustments.pop_back ();
}
if (!deref_mut.is_error ())
{
auto deref_r = deref_mut.get_expected ();
adjustments.push_back (deref_mut);
Adjustment raw_deref = Adjuster::try_raw_deref_type (deref_r);
adjustments.push_back (raw_deref);
deref_r = raw_deref.get_expected ();
auto res = Try (deref_r, segment_name, adjustments);
if (!res.is_error ())
{
return res;
}
adjustments.pop_back ();
adjustments.pop_back ();
}
if (!deref.is_error ())
{
r = deref.get_expected ();
adjustments.push_back (deref);
}
Adjustment raw_deref = Adjuster::try_raw_deref_type (r);
if (raw_deref.is_error ())
return MethodCandidate::get_error ();
r = raw_deref.get_expected ();
adjustments.push_back (raw_deref);
}
return MethodCandidate::get_error ();
MethodResolver resolver (autoderef_flag, segment_name);
bool ok = resolver.cycle (receiver);
return ok ? resolver.try_result : MethodCandidate::get_error ();
}
MethodCandidate
MethodResolver::Try (const TyTy::BaseType *r,
const HIR::PathIdentSegment &segment_name,
std::vector<Adjustment> &adjustments)
void
MethodResolver::try_hook (const TyTy::BaseType &r)
{
PathProbeCandidate c = PathProbeCandidate::get_error ();
const std::vector<TyTy::TypeBoundPredicate> &specified_bounds
= r->get_specified_bounds ();
const std::vector<MethodResolver::predicate_candidate> predicate_items
= get_predicate_items (segment_name, *r, specified_bounds);
// 1. try raw
MethodResolver raw (*r, segment_name, predicate_items);
c = raw.select ();
if (!c.is_error ())
{
return MethodCandidate{c, adjustments};
}
// 2. try ref
TyTy::ReferenceType *r1
= new TyTy::ReferenceType (r->get_ref (), TyTy::TyVar (r->get_ref ()),
Mutability::Imm);
MethodResolver imm_ref (*r1, segment_name, predicate_items);
c = imm_ref.select ();
if (!c.is_error ())
{
adjustments.push_back (
Adjustment (Adjustment::AdjustmentType::IMM_REF, r1));
return MethodCandidate{c, adjustments};
}
// 3. try mut ref
TyTy::ReferenceType *r2
= new TyTy::ReferenceType (r->get_ref (), TyTy::TyVar (r->get_ref ()),
Mutability::Mut);
MethodResolver mut_ref (*r2, segment_name, predicate_items);
c = mut_ref.select ();
if (!c.is_error ())
{
adjustments.push_back (
Adjustment (Adjustment::AdjustmentType::MUT_REF, r2));
return MethodCandidate{c, adjustments};
}
return MethodCandidate::get_error ();
const auto &specified_bounds = r.get_specified_bounds ();
predicate_items = get_predicate_items (segment_name, r, specified_bounds);
}
PathProbeCandidate
MethodResolver::select ()
bool
MethodResolver::select (const TyTy::BaseType &receiver)
{
struct impl_item_candidate
{
@ -300,9 +181,11 @@ MethodResolver::select ()
{
PathProbeCandidate::ImplItemCandidate c{impl_item.item,
impl_item.impl_block};
return PathProbeCandidate (
PathProbeCandidate::CandidateType::IMPL_FUNC, fn,
impl_item.item->get_locus (), c);
try_result = MethodCandidate{
PathProbeCandidate (PathProbeCandidate::CandidateType::IMPL_FUNC,
fn, impl_item.item->get_locus (), c),
adjustments};
return true;
}
}
@ -317,9 +200,11 @@ MethodResolver::select ()
PathProbeCandidate::TraitItemCandidate c{trait_item.reference,
trait_item.item_ref,
nullptr};
return PathProbeCandidate (
PathProbeCandidate::CandidateType::TRAIT_FUNC, fn,
trait_item.item->get_locus (), c);
try_result = MethodCandidate{
PathProbeCandidate (PathProbeCandidate::CandidateType::TRAIT_FUNC,
fn, trait_item.item->get_locus (), c),
adjustments};
return true;
}
}
@ -338,13 +223,15 @@ MethodResolver::select ()
PathProbeCandidate::TraitItemCandidate c{trait_ref, trait_item,
nullptr};
return PathProbeCandidate (
PathProbeCandidate::CandidateType::TRAIT_FUNC, fn->clone (),
trait_item->get_locus (), c);
try_result = MethodCandidate{
PathProbeCandidate (PathProbeCandidate::CandidateType::TRAIT_FUNC,
fn->clone (), trait_item->get_locus (), c),
adjustments};
return true;
}
}
return PathProbeCandidate::get_error ();
return false;
}
std::vector<MethodResolver::predicate_candidate>

View File

@ -37,43 +37,42 @@ struct MethodCandidate
bool is_error () const { return candidate.is_error (); }
};
class MethodResolver : public TypeCheckBase
class MethodResolver : protected AutoderefCycle
{
protected:
using Rust::Resolver::TypeCheckBase::visit;
public:
static MethodCandidate Probe (const TyTy::BaseType *receiver,
const HIR::PathIdentSegment &segment_name,
bool autoderef_flag = false);
protected:
struct predicate_candidate
{
TyTy::TypeBoundPredicateItem lookup;
TyTy::FnType *fntype;
};
static MethodCandidate Try (const TyTy::BaseType *r,
const HIR::PathIdentSegment &segment_name,
std::vector<Adjustment> &adjustments);
static MethodCandidate Probe (const TyTy::BaseType *receiver,
const HIR::PathIdentSegment &segment_name,
bool autoderef_flag = false);
static std::vector<predicate_candidate> get_predicate_items (
const HIR::PathIdentSegment &segment_name, const TyTy::BaseType &receiver,
const std::vector<TyTy::TypeBoundPredicate> &specified_bounds);
PathProbeCandidate select ();
protected:
MethodResolver (bool autoderef_flag,
const HIR::PathIdentSegment &segment_name);
MethodResolver (
const TyTy::BaseType &receiver, const HIR::PathIdentSegment &segment_name,
const std::vector<MethodResolver::predicate_candidate> &predicate_items)
: receiver (receiver), segment_name (segment_name),
predicate_items (predicate_items)
{}
void try_hook (const TyTy::BaseType &r) override;
const TyTy::BaseType &receiver;
bool select (const TyTy::BaseType &receiver) override;
private:
// context info
Analysis::Mappings *mappings;
TypeCheckContext *context;
// search
const HIR::PathIdentSegment &segment_name;
const std::vector<MethodResolver::predicate_candidate> &predicate_items;
std::vector<MethodResolver::predicate_candidate> predicate_items;
// mutable fields
MethodCandidate try_result;
};
} // namespace Resolver