Specialize equality for [T] and comparison for [u8]

Where T is a type whose values can be compared for equality bytewise, we can
use memcmp. We can also use memcmp to implement PartialOrd and Ord for [u8]
and, by extension, &str.

This is an improvement, for example, for the comparison [u8] == [u8], which
used to emit a loop that compared the slices byte by byte.

One worry here could be that this introduces function calls to memcmp
in contexts where the comparison should really be inlined or even
optimized out, but LLVM takes care of recognizing memcmp specifically.
This commit is contained in:
Ulrik Sverdrup 2016-04-05 14:06:20 +02:00
parent a09f386e8d
commit 5d56e1daed
3 changed files with 147 additions and 46 deletions

View File

@ -75,6 +75,7 @@
#![feature(unwind_attributes)] #![feature(unwind_attributes)]
#![feature(repr_simd, platform_intrinsics)] #![feature(repr_simd, platform_intrinsics)]
#![feature(rustc_attrs)] #![feature(rustc_attrs)]
#![feature(specialization)]
#![feature(staged_api)] #![feature(staged_api)]
#![feature(unboxed_closures)] #![feature(unboxed_closures)]
#![feature(question_mark)] #![feature(question_mark)]

View File

@ -1630,12 +1630,59 @@ pub unsafe fn from_raw_parts_mut<'a, T>(p: *mut T, len: usize) -> &'a mut [T] {
} }
// //
// Boilerplate traits // Comparison traits
// //
extern {
/// Calls the implementation-provided `memcmp`.
///
/// Interprets the data behind `s1` and `s2` as `u8` and compares `n` bytes.
///
/// Returns 0 if the two ranges are equal, a value < 0 if `s1` is less
/// than `s2`, and a value > 0 if `s1` is greater than `s2`.
// FIXME(#32610): Return type should be c_int
fn memcmp(s1: *const u8, s2: *const u8, n: usize) -> i32;
}
#[stable(feature = "rust1", since = "1.0.0")] #[stable(feature = "rust1", since = "1.0.0")]
impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> { impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> {
fn eq(&self, other: &[B]) -> bool { fn eq(&self, other: &[B]) -> bool {
SlicePartialEq::equal(self, other)
}
fn ne(&self, other: &[B]) -> bool {
SlicePartialEq::not_equal(self, other)
}
}
// Slices are `Eq` whenever their element type is `Eq`; the impl adds no
// methods of its own.
#[stable(feature = "rust1", since = "1.0.0")]
impl<T: Eq> Eq for [T] {}
// Total ordering for slices. The comparison itself lives in the private
// `SliceOrd` trait so that it can be specialized for particular element
// types.
#[stable(feature = "rust1", since = "1.0.0")]
impl<T: Ord> Ord for [T] {
fn cmp(&self, other: &[T]) -> Ordering {
// Forward to the (possibly specialized) `SliceOrd` implementation.
SliceOrd::compare(self, other)
}
}
// Partial ordering for slices. The comparison itself lives in the private
// `SlicePartialOrd` trait so that it can be specialized for particular
// element types.
#[stable(feature = "rust1", since = "1.0.0")]
impl<T: PartialOrd> PartialOrd for [T] {
fn partial_cmp(&self, other: &[T]) -> Option<Ordering> {
// Forward to the (possibly specialized) `SlicePartialOrd` implementation.
SlicePartialOrd::partial_compare(self, other)
}
}
// Intermediate trait for specialization of slice's PartialEq.
//
// The `PartialEq<[B]>` impl for `[A]` forwards `eq` to `equal` and `ne` to
// `not_equal`.
trait SlicePartialEq<B> {
// Backs `PartialEq::eq` for slices.
fn equal(&self, other: &[B]) -> bool;
// Backs `PartialEq::ne` for slices.
fn not_equal(&self, other: &[B]) -> bool;
}
// Generic slice equality
impl<A, B> SlicePartialEq<B> for [A]
where A: PartialEq<B>
{
default fn equal(&self, other: &[B]) -> bool {
if self.len() != other.len() { if self.len() != other.len() {
return false; return false;
} }
@ -1648,7 +1695,8 @@ impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> {
true true
} }
fn ne(&self, other: &[B]) -> bool {
default fn not_equal(&self, other: &[B]) -> bool {
if self.len() != other.len() { if self.len() != other.len() {
return true; return true;
} }
@ -1663,12 +1711,69 @@ impl<A, B> PartialEq<[B]> for [A] where A: PartialEq<B> {
} }
} }
#[stable(feature = "rust1", since = "1.0.0")] // Use memcmp for bytewise equality when the types allow
impl<T: Eq> Eq for [T] {} impl<A> SlicePartialEq<A> for [A]
where A: PartialEq<A> + BytewiseEquality
{
fn equal(&self, other: &[A]) -> bool {
if self.len() != other.len() {
return false;
}
unsafe {
let size = mem::size_of_val(self);
memcmp(self.as_ptr() as *const u8,
other.as_ptr() as *const u8, size) == 0
}
}
#[stable(feature = "rust1", since = "1.0.0")] fn not_equal(&self, other: &[A]) -> bool {
impl<T: Ord> Ord for [T] { !self.equal(other)
fn cmp(&self, other: &[T]) -> Ordering { }
}
// Intermediate trait for specialization of slice's PartialOrd.
//
// The `PartialOrd` impl for `[T]` forwards `partial_cmp` to
// `partial_compare`.
trait SlicePartialOrd<B> {
// Backs `PartialOrd::partial_cmp` for slices.
fn partial_compare(&self, other: &[B]) -> Option<Ordering>;
}
// Generic slice partial ordering: lexicographic over the elements, with the
// slice lengths compared once the common prefix is exhausted.
impl<A> SlicePartialOrd<A> for [A]
    where A: PartialOrd
{
    default fn partial_compare(&self, other: &[A]) -> Option<Ordering> {
        let shared = cmp::min(self.len(), other.len());

        // Pre-slice both operands to the iterated range so the compiler can
        // eliminate the bounds checks inside the loop.
        let left = &self[..shared];
        let right = &other[..shared];

        for idx in 0..shared {
            let ord = left[idx].partial_cmp(&right[idx]);
            if let Some(Ordering::Equal) = ord {
                continue;
            }
            // First pair that is not equal (or not comparable) decides.
            return ord;
        }

        // All shared elements are equal: the lengths decide.
        self.len().partial_cmp(&other.len())
    }
}
// Byte slices: reuse the `SliceOrd` comparison and wrap its result in
// `Some`, so this implementation never returns `None`.
impl SlicePartialOrd<u8> for [u8] {
#[inline]
fn partial_compare(&self, other: &[u8]) -> Option<Ordering> {
Some(SliceOrd::compare(self, other))
}
}
// Intermediate trait for specialization of slice's Ord.
//
// The `Ord` impl for `[T]` forwards `cmp` to `compare`.
trait SliceOrd<B> {
// Backs `Ord::cmp` for slices.
fn compare(&self, other: &[B]) -> Ordering;
}
impl<A> SliceOrd<A> for [A]
where A: Ord
{
default fn compare(&self, other: &[A]) -> Ordering {
let l = cmp::min(self.len(), other.len()); let l = cmp::min(self.len(), other.len());
// Slice to the loop iteration range to enable bound check // Slice to the loop iteration range to enable bound check
@ -1687,23 +1792,37 @@ impl<T: Ord> Ord for [T] {
} }
} }
#[stable(feature = "rust1", since = "1.0.0")] // memcmp compares a sequence of unsigned bytes lexicographically.
impl<T: PartialOrd> PartialOrd for [T] { // this matches the order we want for [u8], but no others (not even [i8]).
fn partial_cmp(&self, other: &[T]) -> Option<Ordering> { impl SliceOrd<u8> for [u8] {
let l = cmp::min(self.len(), other.len()); #[inline]
fn compare(&self, other: &[u8]) -> Ordering {
// Slice to the loop iteration range to enable bound check let order = unsafe {
// elimination in the compiler memcmp(self.as_ptr(), other.as_ptr(),
let lhs = &self[..l]; cmp::min(self.len(), other.len()))
let rhs = &other[..l]; };
if order == 0 {
for i in 0..l { self.len().cmp(&other.len())
match lhs[i].partial_cmp(&rhs[i]) { } else if order < 0 {
Some(Ordering::Equal) => (), Less
non_eq => return non_eq, } else {
} Greater
} }
self.len().partial_cmp(&other.len())
} }
} }
/// Marker trait implemented for types that can be compared for equality
/// using their bytewise representation.
///
/// It has no methods; it is used as a bound on the `memcmp`-based
/// `SlicePartialEq` implementation.
trait BytewiseEquality { }
// Implements the method-less trait `$marker` for every type in a
// whitespace-separated list, e.g. `impl_marker_for!(Trait, u8 i8);`.
macro_rules! impl_marker_for {
    ($marker:ident, $($target:ty)*) => {
        $( impl $marker for $target { } )*
    }
}
// Opt the integer types, `char` and `bool` in to bytewise equality.
// NOTE(review): floats are absent from this list, presumably because `==`
// on floats does not coincide with bytewise equality (NaN, signed zero) —
// confirm before extending the list.
impl_marker_for!(BytewiseEquality,
u8 i8 u16 i16 u32 i32 u64 i64 usize isize char bool);

View File

@ -1150,16 +1150,7 @@ Section: Comparing strings
#[lang = "str_eq"] #[lang = "str_eq"]
#[inline] #[inline]
fn eq_slice(a: &str, b: &str) -> bool { fn eq_slice(a: &str, b: &str) -> bool {
a.len() == b.len() && unsafe { cmp_slice(a, b, a.len()) == 0 } a.as_bytes() == b.as_bytes()
}
/// Bytewise slice comparison.
/// NOTE: This uses the system's memcmp, which is currently dramatically
/// faster than comparing each byte in a loop.
#[inline]
unsafe fn cmp_slice(a: &str, b: &str, len: usize) -> i32 {
extern { fn memcmp(s1: *const i8, s2: *const i8, n: usize) -> i32; }
memcmp(a.as_ptr() as *const i8, b.as_ptr() as *const i8, len)
} }
/* /*
@ -1328,8 +1319,7 @@ Section: Trait implementations
*/ */
mod traits { mod traits {
use cmp::{self, Ordering, Ord, PartialEq, PartialOrd, Eq}; use cmp::{Ord, Ordering, PartialEq, PartialOrd, Eq};
use cmp::Ordering::{Less, Greater};
use iter::Iterator; use iter::Iterator;
use option::Option; use option::Option;
use option::Option::Some; use option::Option::Some;
@ -1340,16 +1330,7 @@ mod traits {
impl Ord for str { impl Ord for str {
#[inline] #[inline]
fn cmp(&self, other: &str) -> Ordering { fn cmp(&self, other: &str) -> Ordering {
let cmp = unsafe { self.as_bytes().cmp(other.as_bytes())
super::cmp_slice(self, other, cmp::min(self.len(), other.len()))
};
if cmp == 0 {
self.len().cmp(&other.len())
} else if cmp < 0 {
Less
} else {
Greater
}
} }
} }