Support returning aggregates from platform intrinsics.
This also involved adding `[TYPE;N]` syntax and aggregate indexing support to the generator script: it's the only way to be able to have a parameterised intrinsic that returns an aggregate, since one can't refer to previous elements of the current aggregate (and that was harder to implement).
This commit is contained in:
parent
c19e7b629b
commit
7241ae9112
|
@ -19,7 +19,7 @@ import itertools
|
|||
SPEC = re.compile(
|
||||
r'^(?:(?P<void>V)|(?P<id>[iusfIUSF])(?:\((?P<start>\d+)-(?P<end>\d+)\)|'
|
||||
r'(?P<width>\d+)(:?/(?P<llvm_width>\d+))?)'
|
||||
r'|(?P<reference>\d+))(?P<modifiers>[vShdnwusDMC]*)(?P<force_width>x\d+)?'
|
||||
r'|(?P<reference>\d+))(?P<index>\.\d+)?(?P<modifiers>[vShdnwusfDMC]*)(?P<force_width>x\d+)?'
|
||||
r'(?:(?P<pointer>Pm|Pc)(?P<llvm_pointer>/.*)?|(?P<bitcast>->.*))?$'
|
||||
)
|
||||
|
||||
|
@ -70,9 +70,14 @@ class IntrinsicSet(object):
|
|||
{k: lookup(v) for k, v in data.items()})
|
||||
|
||||
class PlatformTypeInfo(object):
|
||||
def __init__(self, llvm_name, properties):
|
||||
self.properties = properties
|
||||
self.llvm_name = llvm_name
|
||||
def __init__(self, llvm_name, properties, elems = None):
|
||||
if elems is None:
|
||||
self.properties = properties
|
||||
self.llvm_name = llvm_name
|
||||
else:
|
||||
assert properties is None and llvm_name is None
|
||||
self.properties = {}
|
||||
self.elems = elems
|
||||
|
||||
def __repr__(self):
|
||||
return '<PlatformTypeInfo {}, {}>'.format(self.llvm_name, self.properties)
|
||||
|
@ -80,13 +85,17 @@ class PlatformTypeInfo(object):
|
|||
def __getattr__(self, name):
|
||||
return self.properties[name]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.elems[idx]
|
||||
|
||||
def vectorize(self, length, width_info):
|
||||
props = self.properties.copy()
|
||||
props.update(width_info)
|
||||
return PlatformTypeInfo('v{}{}'.format(length, self.llvm_name), props)
|
||||
|
||||
def pointer(self):
|
||||
return PlatformTypeInfo('p0{}'.format(self.llvm_name), self.properties)
|
||||
def pointer(self, llvm_elem):
|
||||
name = self.llvm_name if llvm_elem is None else llvm_elem.llvm_name
|
||||
return PlatformTypeInfo('p0{}'.format(name), self.properties)
|
||||
|
||||
BITWIDTH_POINTER = '<pointer>'
|
||||
|
||||
|
@ -128,6 +137,8 @@ class Number(Type):
|
|||
return Unsigned(self.bitwidth())
|
||||
elif spec == 's':
|
||||
return Signed(self.bitwidth())
|
||||
elif spec == 'f':
|
||||
return Float(self.bitwidth())
|
||||
elif spec == 'w':
|
||||
return self.__class__(self.bitwidth() * 2)
|
||||
elif spec == 'n':
|
||||
|
@ -283,7 +294,11 @@ class Pointer(Type):
|
|||
self._elem.rust_name())
|
||||
|
||||
def type_info(self, platform_info):
|
||||
return self._elem.type_info(platform_info).pointer()
|
||||
if self._llvm_elem is None:
|
||||
llvm_elem = None
|
||||
else:
|
||||
llvm_elem = self._llvm_elem.type_info(platform_info)
|
||||
return self._elem.type_info(platform_info).pointer(llvm_elem)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Pointer) and self._const == other._const \
|
||||
|
@ -298,6 +313,14 @@ class Aggregate(Type):
|
|||
def __repr__(self):
|
||||
return '<Aggregate {}>'.format(self._elems)
|
||||
|
||||
def modify(self, spec, width, previous):
|
||||
if spec.startswith('.'):
|
||||
num = int(spec[1:])
|
||||
return self._elems[num]
|
||||
else:
|
||||
print(spec)
|
||||
raise NotImplementedError()
|
||||
|
||||
def compiler_ctor(self):
|
||||
return 'agg({}, vec![{}])'.format('true' if self._flatten else 'false',
|
||||
', '.join(elem.compiler_ctor() for elem in self._elems))
|
||||
|
@ -306,8 +329,7 @@ class Aggregate(Type):
|
|||
return '({})'.format(', '.join(elem.rust_name() for elem in self._elems))
|
||||
|
||||
def type_info(self, platform_info):
|
||||
#return PlatformTypeInfo(None, None, self._llvm_name)
|
||||
return None
|
||||
return PlatformTypeInfo(None, None, [elem.type_info(platform_info) for elem in self._elems])
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Aggregate) and self._flatten == other._flatten and \
|
||||
|
@ -349,7 +371,11 @@ class TypeSpec(object):
|
|||
id = match.group('id')
|
||||
reference = match.group('reference')
|
||||
|
||||
modifiers = list(match.group('modifiers') or '')
|
||||
modifiers = []
|
||||
index = match.group('index')
|
||||
if index is not None:
|
||||
modifiers.append(index)
|
||||
modifiers += list(match.group('modifiers') or '')
|
||||
force = match.group('force_width')
|
||||
if force is not None:
|
||||
modifiers.append(force)
|
||||
|
@ -407,16 +433,32 @@ class TypeSpec(object):
|
|||
else:
|
||||
assert False, 'matched `{}`, but didn\'t understand it?'.format(spec)
|
||||
elif spec.startswith('('):
|
||||
assert bitcast is None
|
||||
if spec.endswith(')'):
|
||||
raise NotImplementedError()
|
||||
true_spec = spec[1:-1]
|
||||
flatten = False
|
||||
elif spec.endswith(')f'):
|
||||
true_spec = spec[1:-2]
|
||||
flatten = True
|
||||
else:
|
||||
assert False, 'found unclosed aggregate `{}`'.format(spec)
|
||||
|
||||
for elems in itertools.product(*(TypeSpec(subspec).enumerate(width, previous)
|
||||
for subspec in true_spec.split(','))):
|
||||
yield Aggregate(flatten, elems)
|
||||
elif spec.startswith('['):
|
||||
if spec.endswith(']'):
|
||||
true_spec = spec[1:-1]
|
||||
flatten = False
|
||||
elif spec.endswith(']f'):
|
||||
true_spec = spec[1:-2]
|
||||
flatten = True
|
||||
else:
|
||||
assert False, 'found unclosed aggregate `{}`'.format(spec)
|
||||
elem_spec, count = true_spec.split(';')
|
||||
|
||||
count = int(count)
|
||||
for elem in TypeSpec(elem_spec).enumerate(width, previous):
|
||||
yield Aggregate(flatten, [elem] * count)
|
||||
else:
|
||||
assert False, 'Failed to parse `{}`'.format(spec)
|
||||
|
||||
|
@ -514,7 +556,7 @@ def parse_args():
|
|||
core_type := void | vector | scalar | aggregate | reference
|
||||
|
||||
modifier := 'v' | 'h' | 'd' | 'n' | 'w' | 'u' | 's' |
|
||||
'x' number
|
||||
'x' number | '.' number
|
||||
suffix := pointer | bitcast
|
||||
pointer := 'Pm' llvm_pointer? | 'Pc' llvm_pointer?
|
||||
llvm_pointer := '/' type
|
||||
|
@ -529,7 +571,7 @@ def parse_args():
|
|||
scalar_type := 'U' | 'S' | 'F'
|
||||
llvm_width := '/' number
|
||||
|
||||
aggregate := '(' (type),* ')' 'f'?
|
||||
aggregate := '(' (type),* ')' 'f'? | '[' type ';' number ']' 'f'?
|
||||
|
||||
reference := number
|
||||
|
||||
|
@ -586,6 +628,12 @@ def parse_args():
|
|||
- no `f` corresponds to `declare ... @llvm.foo({float, i32})`.
|
||||
- having an `f` corresponds to `declare ... @llvm.foo(float, i32)`.
|
||||
|
||||
The `[type;number]` form is just a shorter way to write
|
||||
`(...)`, except that it avoids doing a cartesian product of generic
|
||||
types, e.g. `[S32;2]` is the same as `(S32, S32)`, while
|
||||
`[I32;2]` is describing just the two types `(S32,S32)` and
|
||||
`(U32,U32)` (i.e. doesn't include `(S32,U32)`, `(U32,S32)` as
|
||||
`(I32,I32)` would).
|
||||
|
||||
(Currently aggregates can not contain other aggregates.)
|
||||
|
||||
|
@ -604,13 +652,16 @@ def parse_args():
|
|||
### Modifiers
|
||||
|
||||
- 'v': put a scalar into a vector of the current width (u32 -> u32x4, when width == 128)
|
||||
- 'S': get the scalar element of a vector (u32x4 -> u32)
|
||||
- 'h': halve the length of the vector (u32x4 -> u32x2)
|
||||
- 'd': double the length of the vector (u32x2 -> u32x4)
|
||||
- 'n': narrow the element of the vector (u32x4 -> u16x4)
|
||||
- 'w': widen the element of the vector (u16x4 -> u32x4)
|
||||
- 'u': force an integer (vector or scalar) to be unsigned (i32x4 -> u32x4)
|
||||
- 's': force an integer (vector or scalar) to be signed (u32x4 -> i32x4)
|
||||
- 'u': force a number (vector or scalar) to be unsigned int (f32x4 -> u32x4)
|
||||
- 's': force a number (vector or scalar) to be signed int (u32x4 -> i32x4)
|
||||
- 'f': force a number (vector or scalar) to be float (u32x4 -> f32x4)
|
||||
- 'x' number: force the type to be a vector of bitwidth `number`.
|
||||
- '.' number: get the `number`th element of an aggregate
|
||||
- 'D': dereference a pointer (*mut u32 -> u32)
|
||||
- 'C': make a pointer const (*mut u32 -> *const u32)
|
||||
- 'M': make a pointer mut (*const u32 -> *mut u32)
|
||||
|
|
|
@ -965,7 +965,12 @@ pub fn trans_intrinsic_call<'a, 'blk, 'tcx>(mut bcx: Block<'blk, 'tcx>,
|
|||
vec![Type::vector(&elem,
|
||||
length as u64)]
|
||||
}
|
||||
Aggregate(false, _) => unimplemented!(),
|
||||
Aggregate(false, ref contents) => {
|
||||
let elems = contents.iter()
|
||||
.map(|t| one(ty_to_type(ccx, t, any_changes_needed)))
|
||||
.collect::<Vec<_>>();
|
||||
vec![Type::struct_(ccx, &elems, false)]
|
||||
}
|
||||
Aggregate(true, ref contents) => {
|
||||
*any_changes_needed = true;
|
||||
contents.iter()
|
||||
|
@ -1049,7 +1054,7 @@ pub fn trans_intrinsic_call<'a, 'blk, 'tcx>(mut bcx: Block<'blk, 'tcx>,
|
|||
};
|
||||
assert_eq!(inputs.len(), llargs.len());
|
||||
|
||||
match intr.definition {
|
||||
let val = match intr.definition {
|
||||
intrinsics::IntrinsicDef::Named(name) => {
|
||||
let f = declare::declare_cfn(ccx,
|
||||
name,
|
||||
|
@ -1057,6 +1062,20 @@ pub fn trans_intrinsic_call<'a, 'blk, 'tcx>(mut bcx: Block<'blk, 'tcx>,
|
|||
tcx.mk_nil());
|
||||
Call(bcx, f, &llargs, None, call_debug_location)
|
||||
}
|
||||
};
|
||||
|
||||
match intr.output {
|
||||
intrinsics::Type::Aggregate(flatten, ref elems) => {
|
||||
// the output is a tuple so we need to munge it properly
|
||||
assert!(!flatten);
|
||||
|
||||
for i in 0..elems.len() {
|
||||
let val = ExtractValue(bcx, val, i);
|
||||
Store(bcx, val, StructGEP(bcx, llresult, i));
|
||||
}
|
||||
C_nil(ccx)
|
||||
}
|
||||
_ => val,
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue